diff --git a/pygit2/__init__.py b/pygit2/__init__.py index 3e437d7..b4c7d0d 100644 --- a/pygit2/__init__.py +++ b/pygit2/__init__.py @@ -35,10 +35,10 @@ from _pygit2 import * from .blame import Blame, BlameHunk from .config import Config from .credentials import * -from .errors import check_error +from .errors import check_error, Passthrough from .ffi import ffi, C from .index import Index, IndexEntry -from .remote import Remote, get_credentials +from .remote import Remote, RemoteCallbacks, get_credentials from .repository import Repository from .settings import Settings from .submodule import Submodule diff --git a/pygit2/errors.py b/pygit2/errors.py index e0a6767..3d7f6f9 100644 --- a/pygit2/errors.py +++ b/pygit2/errors.py @@ -62,3 +62,6 @@ def check_error(err, io=False): # Generic Git error raise GitError(message) + +# Indicate that we want libgit2 to pretend a function was not set +Passthrough = Exception("The function asked for pass-through") diff --git a/pygit2/remote.py b/pygit2/remote.py index def4497..c38dea0 100644 --- a/pygit2/remote.py +++ b/pygit2/remote.py @@ -30,7 +30,7 @@ from __future__ import absolute_import # Import from pygit2 from _pygit2 import Oid -from .errors import check_error, GitError +from .errors import check_error, GitError, Passthrough from .ffi import ffi, C from .credentials import KeypairFromAgent from .refspec import Refspec @@ -71,7 +71,12 @@ class TransferProgress(object): """"Number of bytes received up to now""" -class Remote(object): +class RemoteCallbacks(object): + """Base class for pygit2 remote callbacks. + + Inherit from this class and override the callbacks which you want to use + in your class, which you can then pass to the network operations. + """ def sideband_progress(self, string): """Progress output callback @@ -100,7 +105,7 @@ class Remote(object): Return value: credential """ - pass + raise Passthrough def transfer_progress(self, stats): """Transfer progress callback @@ -131,48 +136,7 @@ class Remote(object): :param str messsage: rejection message from the remote. If None, the update was accepted. """ - def __init__(self, repo, ptr): - """The constructor is for internal use only""" - - self._repo = repo - self._remote = ptr - self._stored_exception = None - - def __del__(self): - C.git_remote_free(self._remote) - - @property - def name(self): - """Name of the remote""" - - return maybe_string(C.git_remote_name(self._remote)) - - @property - def url(self): - """Url of the remote""" - - return maybe_string(C.git_remote_url(self._remote)) - - @property - def push_url(self): - """Push url of the remote""" - - return maybe_string(C.git_remote_pushurl(self._remote)) - - def save(self): - """Save a remote to its repository's configuration.""" - - err = C.git_remote_save(self._remote) - check_error(err) - - def fetch(self, refspecs=None, message=None): - """Perform a fetch against this remote. Returns a - object. - """ - - fetch_opts = ffi.new('git_fetch_options *') - err = C.git_fetch_init_options(fetch_opts, C.GIT_FETCH_OPTIONS_VERSION) - + def _fill_fetch_options(self, fetch_opts): fetch_opts.callbacks.sideband_progress = self._sideband_progress_cb fetch_opts.callbacks.transfer_progress = self._transfer_progress_cb fetch_opts.callbacks.update_tips = self._update_tips_cb @@ -183,58 +147,7 @@ class Remote(object): self._stored_exception = None - try: - with StrArray(refspecs) as arr: - err = C.git_remote_fetch(self._remote, arr, fetch_opts, to_bytes(message)) - if self._stored_exception: - raise self._stored_exception - check_error(err) - finally: - self._self_handle = None - - return TransferProgress(C.git_remote_stats(self._remote)) - - @property - def refspec_count(self): - """Total number of refspecs in this remote""" - - return C.git_remote_refspec_count(self._remote) - - def get_refspec(self, n): - """Return the object at the given position.""" - spec = C.git_remote_get_refspec(self._remote, n) - return Refspec(self, spec) - - @property - def fetch_refspecs(self): - """Refspecs that will be used for fetching""" - - specs = ffi.new('git_strarray *') - err = C.git_remote_get_fetch_refspecs(specs, self._remote) - check_error(err) - - return strarray_to_strings(specs) - - @property - def push_refspecs(self): - """Refspecs that will be used for pushing""" - - specs = ffi.new('git_strarray *') - err = C.git_remote_get_push_refspecs(specs, self._remote) - check_error(err) - - return strarray_to_strings(specs) - - def push(self, specs): - """Push the given refspec to the remote. Raises ``GitError`` on - protocol error or unpack failure. - - :param [str] specs: push refspecs to use - """ - push_opts = ffi.new('git_push_options *') - err = C.git_push_init_options(push_opts, C.GIT_PUSH_OPTIONS_VERSION) - - # Build custom callback structure + def _fill_push_options(self, push_opts): push_opts.callbacks.sideband_progress = self._sideband_progress_cb push_opts.callbacks.transfer_progress = self._transfer_progress_cb push_opts.callbacks.update_tips = self._update_tips_cb @@ -244,13 +157,6 @@ class Remote(object): self._self_handle = ffi.new_handle(self) push_opts.callbacks.payload = self._self_handle - try: - with StrArray(specs) as refspecs: - err = C.git_remote_push(self._remote, refspecs, push_opts) - check_error(err) - finally: - self._self_handle = None - # These functions exist to be called by the git_remote as # callbacks. They proxy the call to whatever the user set @@ -337,11 +243,125 @@ class Remote(object): cred_out[0] = ccred[0] except Exception as e: + if e is Passthrough: + return C.GIT_PASSTHROUGH + self._stored_exception = e return C.GIT_EUSER return 0 +class Remote(object): + def __init__(self, repo, ptr): + """The constructor is for internal use only""" + + self._repo = repo + self._remote = ptr + self._stored_exception = None + + def __del__(self): + C.git_remote_free(self._remote) + + @property + def name(self): + """Name of the remote""" + + return maybe_string(C.git_remote_name(self._remote)) + + @property + def url(self): + """Url of the remote""" + + return maybe_string(C.git_remote_url(self._remote)) + + @property + def push_url(self): + """Push url of the remote""" + + return maybe_string(C.git_remote_pushurl(self._remote)) + + def save(self): + """Save a remote to its repository's configuration.""" + + err = C.git_remote_save(self._remote) + check_error(err) + + def fetch(self, refspecs=None, callbacks=None, message=None): + """Perform a fetch against this remote. Returns a + object. + """ + + fetch_opts = ffi.new('git_fetch_options *') + err = C.git_fetch_init_options(fetch_opts, C.GIT_FETCH_OPTIONS_VERSION) + + if callbacks is None: + callbacks = RemoteCallbacks() + + callbacks._fill_fetch_options(fetch_opts) + + try: + with StrArray(refspecs) as arr: + err = C.git_remote_fetch(self._remote, arr, fetch_opts, to_bytes(message)) + if callbacks._stored_exception: + raise callbacks._stored_exception + check_error(err) + finally: + callbacks._self_handle = None + + return TransferProgress(C.git_remote_stats(self._remote)) + + @property + def refspec_count(self): + """Total number of refspecs in this remote""" + + return C.git_remote_refspec_count(self._remote) + + def get_refspec(self, n): + """Return the object at the given position.""" + spec = C.git_remote_get_refspec(self._remote, n) + return Refspec(self, spec) + + @property + def fetch_refspecs(self): + """Refspecs that will be used for fetching""" + + specs = ffi.new('git_strarray *') + err = C.git_remote_get_fetch_refspecs(specs, self._remote) + check_error(err) + + return strarray_to_strings(specs) + + @property + def push_refspecs(self): + """Refspecs that will be used for pushing""" + + specs = ffi.new('git_strarray *') + err = C.git_remote_get_push_refspecs(specs, self._remote) + check_error(err) + + return strarray_to_strings(specs) + + def push(self, specs, callbacks=None): + """Push the given refspec to the remote. Raises ``GitError`` on + protocol error or unpack failure. + + :param [str] specs: push refspecs to use + """ + push_opts = ffi.new('git_push_options *') + err = C.git_push_init_options(push_opts, C.GIT_PUSH_OPTIONS_VERSION) + + if callbacks is None: + callbacks = RemoteCallbacks() + + callbacks._fill_push_options(push_opts) + # Build custom callback structure + + try: + with StrArray(specs) as refspecs: + err = C.git_remote_push(self._remote, refspecs, push_opts) + check_error(err) + finally: + callbacks._self_handle = None def get_credentials(fn, url, username, allowed): """Call fn and return the credentials object""" diff --git a/test/test_credentials.py b/test/test_credentials.py index 376032f..461fe5e 100644 --- a/test/test_credentials.py +++ b/test/test_credentials.py @@ -70,32 +70,33 @@ class CredentialCreateTest(utils.NoRepoTestCase): class CredentialCallback(utils.RepoTestCase): def test_callback(self): - def credentials_cb(url, username, allowed): - self.assertTrue(allowed & GIT_CREDTYPE_USERPASS_PLAINTEXT) - raise Exception("I don't know the password") + class MyCallbacks(pygit2.RemoteCallbacks): + def credentials(url, username, allowed): + self.assertTrue(allowed & GIT_CREDTYPE_USERPASS_PLAINTEXT) + raise Exception("I don't know the password") remote = self.repo.create_remote("github", "https://github.com/github/github") - remote.credentials = credentials_cb - self.assertRaises(Exception, remote.fetch) + self.assertRaises(Exception, lambda: remote.fetch(callbacks=MyCallbacks())) def test_bad_cred_type(self): - def credentials_cb(url, username, allowed): - self.assertTrue(allowed & GIT_CREDTYPE_USERPASS_PLAINTEXT) - return Keypair("git", "foo.pub", "foo", "sekkrit") + class MyCallbacks(pygit2.RemoteCallbacks): + def credentials(url, username, allowed): + self.assertTrue(allowed & GIT_CREDTYPE_USERPASS_PLAINTEXT) + return Keypair("git", "foo.pub", "foo", "sekkrit") remote = self.repo.create_remote("github", "https://github.com/github/github") - remote.credentials = credentials_cb - - self.assertRaises(TypeError, remote.fetch) + self.assertRaises(TypeError, lambda: remote.fetch(callbacks=MyCallbacks())) class CallableCredentialTest(utils.RepoTestCase): def test_user_pass(self): - remote = self.repo.create_remote("bb", "https://bitbucket.org/libgit2/testgitrepository.git") - remote.credentials = UserPass("libgit2", "libgit2") + class MyCallbacks(pygit2.RemoteCallbacks): + def __init__(self): + self.credentials = UserPass("libgit2", "libgit2") - remote.fetch() + remote = self.repo.create_remote("bb", "https://bitbucket.org/libgit2/testgitrepository.git") + remote.fetch(callbacks=MyCallbacks()) if __name__ == '__main__': unittest.main() diff --git a/test/test_remote.py b/test/test_remote.py index c2e4f8f..b39312a 100644 --- a/test/test_remote.py +++ b/test/test_remote.py @@ -213,30 +213,37 @@ class EmptyRepositoryTest(utils.EmptyRepoTestCase): def test_transfer_progress(self): self.tp = None - def tp_cb(stats): - self.tp = stats + class MyCallbacks(pygit2.RemoteCallbacks): + def transfer_progress(self, stats): + self.tp = stats + callbacks = MyCallbacks() remote = self.repo.remotes[0] - remote.transfer_progress = tp_cb - stats = remote.fetch() - self.assertEqual(stats.received_bytes, self.tp.received_bytes) - self.assertEqual(stats.indexed_objects, self.tp.indexed_objects) - self.assertEqual(stats.received_objects, self.tp.received_objects) + stats = remote.fetch(callbacks=callbacks) + self.assertEqual(stats.received_bytes, callbacks.tp.received_bytes) + self.assertEqual(stats.indexed_objects, callbacks.tp.indexed_objects) + self.assertEqual(stats.received_objects, callbacks.tp.received_objects) def test_update_tips(self): remote = self.repo.remotes[0] - self.i = 0 - self.tips = [('refs/remotes/origin/master', Oid(hex='0'*40), - Oid(hex='784855caf26449a1914d2cf62d12b9374d76ae78')), - ('refs/tags/root', Oid(hex='0'*40), - Oid(hex='3d2962987c695a29f1f80b6c3aa4ec046ef44369'))] + tips = [('refs/remotes/origin/master', Oid(hex='0'*40), + Oid(hex='784855caf26449a1914d2cf62d12b9374d76ae78')), + ('refs/tags/root', Oid(hex='0'*40), + Oid(hex='3d2962987c695a29f1f80b6c3aa4ec046ef44369'))] - def ut_cb(name, old, new): - self.assertEqual(self.tips[self.i], (name, old, new)) - self.i += 1 + class MyCallbacks(pygit2.RemoteCallbacks): + def __init__(self, test_self, tips): + self.test = test_self + self.tips = tips + self.i = 0 - remote.update_tips = ut_cb - remote.fetch() + def update_tips(self, name, old, new): + self.test.assertEqual(self.tips[self.i], (name, old, new)) + self.i += 1 + + callbacks = MyCallbacks(self, tips) + remote.fetch(callbacks=callbacks) + self.assertTrue(callbacks.i > 0) class PushTestCase(unittest.TestCase): def setUp(self):