diff --git a/src/oid.c b/src/oid.c index b40e157..f1a6f52 100644 --- a/src/oid.c +++ b/src/oid.c @@ -45,63 +45,71 @@ git_oid_to_python(const git_oid *oid) return (PyObject*)py_oid; } - int -py_str_to_git_oid(PyObject *py_str, git_oid *oid) +_oid_from_hex(PyObject *py_oid, git_oid *oid) { PyObject *py_hex; - char *hex_or_bin; int err; + char *hex; Py_ssize_t len; - /* Case 1: Git Oid */ - if (PyObject_TypeCheck(py_str, (PyTypeObject*)&OidType)) { - git_oid_cpy(oid, &((Oid*)py_str)->oid); - return GIT_OID_RAWSZ; - } - - /* Case 2: raw sha (bytes) */ - if (PyBytes_Check(py_str)) { - err = PyBytes_AsStringAndSize(py_str, &hex_or_bin, &len); +#if PY_MAJOR_VERSION == 2 + /* Bytes (only supported in Python 2) */ + if (PyBytes_Check(py_oid)) { + err = PyBytes_AsStringAndSize(py_oid, &hex, &len); if (err) return -1; - if (len > GIT_OID_RAWSZ) { - PyErr_SetObject(PyExc_ValueError, py_str); + + err = git_oid_fromstrn(oid, hex, len); + if (err < 0) { + PyErr_SetObject(Error_type(err), py_oid); return -1; } - memcpy(oid->id, (const unsigned char*)hex_or_bin, len); - return len * 2; - } - /* Case 3: hex sha (unicode) */ - if (PyUnicode_Check(py_str)) { - py_hex = PyUnicode_AsASCIIString(py_str); + return len; + } +#endif + + /* Unicode */ + if (PyUnicode_Check(py_oid)) { + py_hex = PyUnicode_AsASCIIString(py_oid); if (py_hex == NULL) return -1; - err = PyBytes_AsStringAndSize(py_hex, &hex_or_bin, &len); + + err = PyBytes_AsStringAndSize(py_hex, &hex, &len); if (err) { Py_DECREF(py_hex); return -1; } - err = git_oid_fromstrn(oid, hex_or_bin, len); - + err = git_oid_fromstrn(oid, hex, len); Py_DECREF(py_hex); - if (err < 0) { - PyErr_SetObject(Error_type(err), py_str); + PyErr_SetObject(Error_type(err), py_oid); return -1; } + return len; } /* Type error */ - PyErr_Format(PyExc_TypeError, - "Git object id must be byte or a text string, not: %.200s", - Py_TYPE(py_str)->tp_name); + PyErr_SetObject(PyExc_TypeError, py_oid); return -1; } +int +py_str_to_git_oid(PyObject *py_oid, git_oid *oid) +{ + /* Oid */ + if (PyObject_TypeCheck(py_oid, (PyTypeObject*)&OidType)) { + git_oid_cpy(oid, &((Oid*)py_oid)->oid); + return GIT_OID_RAWSZ; + } + + /* Hex */ + return _oid_from_hex(py_oid, oid); +} + int py_str_to_git_oid_expand(git_repository *repo, PyObject *py_str, git_oid *oid) { @@ -152,6 +160,8 @@ Oid_init(Oid *self, PyObject *args, PyObject *kw) char *keywords[] = {"raw", "hex", NULL}; PyObject *raw = NULL, *hex = NULL; int err; + char *bytes; + Py_ssize_t len; if (!PyArg_ParseTupleAndKeywords(args, kw, "|OO", keywords, &raw, &hex)) return -1; @@ -166,12 +176,23 @@ Oid_init(Oid *self, PyObject *args, PyObject *kw) return -1; } - /* Get the oid. */ - if (raw != NULL) - err = py_str_to_git_oid(raw, &self->oid); - else - err = py_str_to_git_oid(hex, &self->oid); + /* Case 1: raw */ + if (raw != NULL) { + err = PyBytes_AsStringAndSize(raw, &bytes, &len); + if (err) + return -1; + if (len > GIT_OID_RAWSZ) { + PyErr_SetObject(PyExc_ValueError, raw); + return -1; + } + + memcpy(self->oid.id, (const unsigned char*)bytes, len); + return 0; + } + + /* Case 2: hex */ + err = _oid_from_hex(hex, &self->oid); if (err < 0) return -1; diff --git a/test/test_oid.py b/test/test_oid.py index 22dcb0b..97c7979 100644 --- a/test/test_oid.py +++ b/test/test_oid.py @@ -33,6 +33,7 @@ from __future__ import unicode_literals # Import from the Standard Library from binascii import unhexlify +from sys import version_info import unittest # Import from pygit2 @@ -55,6 +56,16 @@ class OidTest(utils.BareRepoTestCase): self.assertEqual(oid.raw, RAW) self.assertEqual(oid.hex, HEX) + def test_hex_bytes(self): + if version_info.major == 2: + hex = bytes(HEX) + oid = Oid(hex=hex) + self.assertEqual(oid.raw, RAW) + self.assertEqual(oid.hex, HEX) + else: + hex = bytes(HEX, "ascii") + self.assertRaises(TypeError, Oid, hex=hex) + def test_none(self): self.assertRaises(ValueError, Oid) diff --git a/test/test_repository.py b/test/test_repository.py index 39fea40..373730e 100644 --- a/test/test_repository.py +++ b/test/test_repository.py @@ -48,8 +48,9 @@ from . import utils HEAD_SHA = '784855caf26449a1914d2cf62d12b9374d76ae78' PARENT_SHA = 'f5e5aa4e36ab0fe62ee1ccc6eb8f79b866863b87' # HEAD^ -A_HEX_SHA = 'af431f20fc541ed6d5afede3e2dc7160f6f01f16' -A_BIN_SHA = binascii.unhexlify(A_HEX_SHA.encode('ascii')) +BLOB_HEX = 'af431f20fc541ed6d5afede3e2dc7160f6f01f16' +BLOB_RAW = binascii.unhexlify(BLOB_HEX.encode('ascii')) +BLOB_OID = Oid(raw=BLOB_RAW) class RepositoryTest(utils.BareRepoTestCase): @@ -71,15 +72,15 @@ class RepositoryTest(utils.BareRepoTestCase): self.assertRaises(TypeError, self.repo.read, 123) self.assertRaisesWithArg(KeyError, '1' * 40, self.repo.read, '1' * 40) - ab = self.repo.read(A_BIN_SHA) - a = self.repo.read(A_HEX_SHA) + ab = self.repo.read(BLOB_OID) + a = self.repo.read(BLOB_HEX) self.assertEqual(ab, a) self.assertEqual((GIT_OBJ_BLOB, b'a contents\n'), a) a2 = self.repo.read('7f129fd57e31e935c6d60a0c794efe4e6927664b') self.assertEqual((GIT_OBJ_BLOB, b'a contents 2\n'), a2) - a_hex_prefix = A_HEX_SHA[:4] + a_hex_prefix = BLOB_HEX[:4] a3 = self.repo.read(a_hex_prefix) self.assertEqual((GIT_OBJ_BLOB, b'a contents\n'), a3) @@ -93,29 +94,28 @@ class RepositoryTest(utils.BareRepoTestCase): def test_contains(self): self.assertRaises(TypeError, lambda: 123 in self.repo) - self.assertTrue(A_BIN_SHA in self.repo) - self.assertTrue(A_BIN_SHA[:10] in self.repo) - self.assertTrue(A_HEX_SHA in self.repo) - self.assertTrue(A_HEX_SHA[:10] in self.repo) + self.assertTrue(BLOB_OID in self.repo) + self.assertTrue(BLOB_HEX in self.repo) + self.assertTrue(BLOB_HEX[:10] in self.repo) self.assertFalse('a' * 40 in self.repo) self.assertFalse('a' * 20 in self.repo) def test_iterable(self): l = [ obj for obj in self.repo ] - self.assertTrue(A_HEX_SHA in l) + self.assertTrue(BLOB_HEX in l) def test_lookup_blob(self): self.assertRaises(TypeError, lambda: self.repo[123]) - self.assertEqual(self.repo[A_BIN_SHA].hex, A_HEX_SHA) - a = self.repo[A_HEX_SHA] + self.assertEqual(self.repo[BLOB_OID].hex, BLOB_HEX) + a = self.repo[BLOB_HEX] self.assertEqual(b'a contents\n', a.read_raw()) - self.assertEqual(A_HEX_SHA, a.hex) + self.assertEqual(BLOB_HEX, a.hex) self.assertEqual(GIT_OBJ_BLOB, a.type) def test_lookup_blob_prefix(self): - a = self.repo[A_HEX_SHA[:5]] + a = self.repo[BLOB_HEX[:5]] self.assertEqual(b'a contents\n', a.read_raw()) - self.assertEqual(A_HEX_SHA, a.hex) + self.assertEqual(BLOB_HEX, a.hex) self.assertEqual(GIT_OBJ_BLOB, a.type) def test_lookup_commit(self):