Move remote callbacks to use a class for callbacks
This represents what's going on much better than the remnants from the older methods. What we do is pass a list of callbacks to libgit2 for it to call, and they are valid for a single operation, not for the remote itself. This should also make it easier to re-use callbacks which have already been set up.
This commit is contained in:
		@@ -35,10 +35,10 @@ from _pygit2 import *
 | 
			
		||||
from .blame import Blame, BlameHunk
 | 
			
		||||
from .config import Config
 | 
			
		||||
from .credentials import *
 | 
			
		||||
from .errors import check_error
 | 
			
		||||
from .errors import check_error, Passthrough
 | 
			
		||||
from .ffi import ffi, C
 | 
			
		||||
from .index import Index, IndexEntry
 | 
			
		||||
from .remote import Remote, get_credentials
 | 
			
		||||
from .remote import Remote, RemoteCallbacks, get_credentials
 | 
			
		||||
from .repository import Repository
 | 
			
		||||
from .settings import Settings
 | 
			
		||||
from .submodule import Submodule
 | 
			
		||||
 
 | 
			
		||||
@@ -62,3 +62,6 @@ def check_error(err, io=False):
 | 
			
		||||
 | 
			
		||||
    # Generic Git error
 | 
			
		||||
    raise GitError(message)
 | 
			
		||||
 | 
			
		||||
# Indicate that we want libgit2 to pretend a function was not set
 | 
			
		||||
Passthrough = Exception("The function asked for pass-through")
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										228
									
								
								pygit2/remote.py
									
									
									
									
									
								
							
							
						
						
									
										228
									
								
								pygit2/remote.py
									
									
									
									
									
								
							@@ -30,7 +30,7 @@ from __future__ import absolute_import
 | 
			
		||||
 | 
			
		||||
# Import from pygit2
 | 
			
		||||
from _pygit2 import Oid
 | 
			
		||||
from .errors import check_error, GitError
 | 
			
		||||
from .errors import check_error, GitError, Passthrough
 | 
			
		||||
from .ffi import ffi, C
 | 
			
		||||
from .credentials import KeypairFromAgent
 | 
			
		||||
from .refspec import Refspec
 | 
			
		||||
@@ -71,7 +71,12 @@ class TransferProgress(object):
 | 
			
		||||
        """"Number of bytes received up to now"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Remote(object):
 | 
			
		||||
class RemoteCallbacks(object):
 | 
			
		||||
    """Base class for pygit2 remote callbacks.
 | 
			
		||||
 | 
			
		||||
    Inherit from this class and override the callbacks which you want to use
 | 
			
		||||
    in your class, which you can then pass to the network operations.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def sideband_progress(self, string):
 | 
			
		||||
        """Progress output callback
 | 
			
		||||
@@ -100,7 +105,7 @@ class Remote(object):
 | 
			
		||||
 | 
			
		||||
        Return value: credential
 | 
			
		||||
        """
 | 
			
		||||
        pass
 | 
			
		||||
        raise Passthrough
 | 
			
		||||
 | 
			
		||||
    def transfer_progress(self, stats):
 | 
			
		||||
        """Transfer progress callback
 | 
			
		||||
@@ -131,48 +136,7 @@ class Remote(object):
 | 
			
		||||
        :param str messsage: rejection message from the remote. If None, the update was accepted.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, repo, ptr):
 | 
			
		||||
        """The constructor is for internal use only"""
 | 
			
		||||
 | 
			
		||||
        self._repo = repo
 | 
			
		||||
        self._remote = ptr
 | 
			
		||||
        self._stored_exception = None
 | 
			
		||||
 | 
			
		||||
    def __del__(self):
 | 
			
		||||
        C.git_remote_free(self._remote)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def name(self):
 | 
			
		||||
        """Name of the remote"""
 | 
			
		||||
 | 
			
		||||
        return maybe_string(C.git_remote_name(self._remote))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def url(self):
 | 
			
		||||
        """Url of the remote"""
 | 
			
		||||
 | 
			
		||||
        return maybe_string(C.git_remote_url(self._remote))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def push_url(self):
 | 
			
		||||
        """Push url of the remote"""
 | 
			
		||||
 | 
			
		||||
        return maybe_string(C.git_remote_pushurl(self._remote))
 | 
			
		||||
 | 
			
		||||
    def save(self):
 | 
			
		||||
        """Save a remote to its repository's configuration."""
 | 
			
		||||
 | 
			
		||||
        err = C.git_remote_save(self._remote)
 | 
			
		||||
        check_error(err)
 | 
			
		||||
 | 
			
		||||
    def fetch(self, refspecs=None, message=None):
 | 
			
		||||
        """Perform a fetch against this remote. Returns a <TransferProgress>
 | 
			
		||||
        object.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        fetch_opts = ffi.new('git_fetch_options *')
 | 
			
		||||
        err = C.git_fetch_init_options(fetch_opts, C.GIT_FETCH_OPTIONS_VERSION)
 | 
			
		||||
 | 
			
		||||
    def _fill_fetch_options(self, fetch_opts):
 | 
			
		||||
        fetch_opts.callbacks.sideband_progress = self._sideband_progress_cb
 | 
			
		||||
        fetch_opts.callbacks.transfer_progress = self._transfer_progress_cb
 | 
			
		||||
        fetch_opts.callbacks.update_tips = self._update_tips_cb
 | 
			
		||||
@@ -183,58 +147,7 @@ class Remote(object):
 | 
			
		||||
 | 
			
		||||
        self._stored_exception = None
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            with StrArray(refspecs) as arr:
 | 
			
		||||
                err = C.git_remote_fetch(self._remote, arr, fetch_opts, to_bytes(message))
 | 
			
		||||
                if self._stored_exception:
 | 
			
		||||
                    raise self._stored_exception
 | 
			
		||||
                check_error(err)
 | 
			
		||||
        finally:
 | 
			
		||||
            self._self_handle = None
 | 
			
		||||
 | 
			
		||||
        return TransferProgress(C.git_remote_stats(self._remote))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def refspec_count(self):
 | 
			
		||||
        """Total number of refspecs in this remote"""
 | 
			
		||||
 | 
			
		||||
        return C.git_remote_refspec_count(self._remote)
 | 
			
		||||
 | 
			
		||||
    def get_refspec(self, n):
 | 
			
		||||
        """Return the <Refspec> object at the given position."""
 | 
			
		||||
        spec = C.git_remote_get_refspec(self._remote, n)
 | 
			
		||||
        return Refspec(self, spec)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def fetch_refspecs(self):
 | 
			
		||||
        """Refspecs that will be used for fetching"""
 | 
			
		||||
 | 
			
		||||
        specs = ffi.new('git_strarray *')
 | 
			
		||||
        err = C.git_remote_get_fetch_refspecs(specs, self._remote)
 | 
			
		||||
        check_error(err)
 | 
			
		||||
 | 
			
		||||
        return strarray_to_strings(specs)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def push_refspecs(self):
 | 
			
		||||
        """Refspecs that will be used for pushing"""
 | 
			
		||||
 | 
			
		||||
        specs = ffi.new('git_strarray *')
 | 
			
		||||
        err = C.git_remote_get_push_refspecs(specs, self._remote)
 | 
			
		||||
        check_error(err)
 | 
			
		||||
 | 
			
		||||
        return strarray_to_strings(specs)
 | 
			
		||||
 | 
			
		||||
    def push(self, specs):
 | 
			
		||||
        """Push the given refspec to the remote. Raises ``GitError`` on
 | 
			
		||||
        protocol error or unpack failure.
 | 
			
		||||
 | 
			
		||||
        :param [str] specs: push refspecs to use
 | 
			
		||||
        """
 | 
			
		||||
        push_opts = ffi.new('git_push_options *')
 | 
			
		||||
        err = C.git_push_init_options(push_opts, C.GIT_PUSH_OPTIONS_VERSION)
 | 
			
		||||
 | 
			
		||||
        # Build custom callback structure
 | 
			
		||||
    def _fill_push_options(self, push_opts):
 | 
			
		||||
        push_opts.callbacks.sideband_progress = self._sideband_progress_cb
 | 
			
		||||
        push_opts.callbacks.transfer_progress = self._transfer_progress_cb
 | 
			
		||||
        push_opts.callbacks.update_tips = self._update_tips_cb
 | 
			
		||||
@@ -244,13 +157,6 @@ class Remote(object):
 | 
			
		||||
        self._self_handle = ffi.new_handle(self)
 | 
			
		||||
        push_opts.callbacks.payload = self._self_handle
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            with StrArray(specs) as refspecs:
 | 
			
		||||
                err = C.git_remote_push(self._remote, refspecs, push_opts)
 | 
			
		||||
                check_error(err)
 | 
			
		||||
        finally:
 | 
			
		||||
            self._self_handle = None
 | 
			
		||||
 | 
			
		||||
    # These functions exist to be called by the git_remote as
 | 
			
		||||
    # callbacks. They proxy the call to whatever the user set
 | 
			
		||||
 | 
			
		||||
@@ -337,11 +243,125 @@ class Remote(object):
 | 
			
		||||
            cred_out[0] = ccred[0]
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            if e is Passthrough:
 | 
			
		||||
                return C.GIT_PASSTHROUGH
 | 
			
		||||
 | 
			
		||||
            self._stored_exception = e
 | 
			
		||||
            return C.GIT_EUSER
 | 
			
		||||
 | 
			
		||||
        return 0
 | 
			
		||||
 | 
			
		||||
class Remote(object):
 | 
			
		||||
    def __init__(self, repo, ptr):
 | 
			
		||||
        """The constructor is for internal use only"""
 | 
			
		||||
 | 
			
		||||
        self._repo = repo
 | 
			
		||||
        self._remote = ptr
 | 
			
		||||
        self._stored_exception = None
 | 
			
		||||
 | 
			
		||||
    def __del__(self):
 | 
			
		||||
        C.git_remote_free(self._remote)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def name(self):
 | 
			
		||||
        """Name of the remote"""
 | 
			
		||||
 | 
			
		||||
        return maybe_string(C.git_remote_name(self._remote))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def url(self):
 | 
			
		||||
        """Url of the remote"""
 | 
			
		||||
 | 
			
		||||
        return maybe_string(C.git_remote_url(self._remote))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def push_url(self):
 | 
			
		||||
        """Push url of the remote"""
 | 
			
		||||
 | 
			
		||||
        return maybe_string(C.git_remote_pushurl(self._remote))
 | 
			
		||||
 | 
			
		||||
    def save(self):
 | 
			
		||||
        """Save a remote to its repository's configuration."""
 | 
			
		||||
 | 
			
		||||
        err = C.git_remote_save(self._remote)
 | 
			
		||||
        check_error(err)
 | 
			
		||||
 | 
			
		||||
    def fetch(self, refspecs=None, callbacks=None, message=None):
 | 
			
		||||
        """Perform a fetch against this remote. Returns a <TransferProgress>
 | 
			
		||||
        object.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        fetch_opts = ffi.new('git_fetch_options *')
 | 
			
		||||
        err = C.git_fetch_init_options(fetch_opts, C.GIT_FETCH_OPTIONS_VERSION)
 | 
			
		||||
 | 
			
		||||
        if callbacks is None:
 | 
			
		||||
            callbacks = RemoteCallbacks()
 | 
			
		||||
 | 
			
		||||
        callbacks._fill_fetch_options(fetch_opts)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            with StrArray(refspecs) as arr:
 | 
			
		||||
                err = C.git_remote_fetch(self._remote, arr, fetch_opts, to_bytes(message))
 | 
			
		||||
                if callbacks._stored_exception:
 | 
			
		||||
                    raise callbacks._stored_exception
 | 
			
		||||
                check_error(err)
 | 
			
		||||
        finally:
 | 
			
		||||
            callbacks._self_handle = None
 | 
			
		||||
 | 
			
		||||
        return TransferProgress(C.git_remote_stats(self._remote))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def refspec_count(self):
 | 
			
		||||
        """Total number of refspecs in this remote"""
 | 
			
		||||
 | 
			
		||||
        return C.git_remote_refspec_count(self._remote)
 | 
			
		||||
 | 
			
		||||
    def get_refspec(self, n):
 | 
			
		||||
        """Return the <Refspec> object at the given position."""
 | 
			
		||||
        spec = C.git_remote_get_refspec(self._remote, n)
 | 
			
		||||
        return Refspec(self, spec)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def fetch_refspecs(self):
 | 
			
		||||
        """Refspecs that will be used for fetching"""
 | 
			
		||||
 | 
			
		||||
        specs = ffi.new('git_strarray *')
 | 
			
		||||
        err = C.git_remote_get_fetch_refspecs(specs, self._remote)
 | 
			
		||||
        check_error(err)
 | 
			
		||||
 | 
			
		||||
        return strarray_to_strings(specs)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def push_refspecs(self):
 | 
			
		||||
        """Refspecs that will be used for pushing"""
 | 
			
		||||
 | 
			
		||||
        specs = ffi.new('git_strarray *')
 | 
			
		||||
        err = C.git_remote_get_push_refspecs(specs, self._remote)
 | 
			
		||||
        check_error(err)
 | 
			
		||||
 | 
			
		||||
        return strarray_to_strings(specs)
 | 
			
		||||
 | 
			
		||||
    def push(self, specs, callbacks=None):
 | 
			
		||||
        """Push the given refspec to the remote. Raises ``GitError`` on
 | 
			
		||||
        protocol error or unpack failure.
 | 
			
		||||
 | 
			
		||||
        :param [str] specs: push refspecs to use
 | 
			
		||||
        """
 | 
			
		||||
        push_opts = ffi.new('git_push_options *')
 | 
			
		||||
        err = C.git_push_init_options(push_opts, C.GIT_PUSH_OPTIONS_VERSION)
 | 
			
		||||
 | 
			
		||||
        if callbacks is None:
 | 
			
		||||
            callbacks = RemoteCallbacks()
 | 
			
		||||
 | 
			
		||||
        callbacks._fill_push_options(push_opts)
 | 
			
		||||
        # Build custom callback structure
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            with StrArray(specs) as refspecs:
 | 
			
		||||
                err = C.git_remote_push(self._remote, refspecs, push_opts)
 | 
			
		||||
                check_error(err)
 | 
			
		||||
        finally:
 | 
			
		||||
            callbacks._self_handle = None
 | 
			
		||||
 | 
			
		||||
def get_credentials(fn, url, username, allowed):
 | 
			
		||||
    """Call fn and return the credentials object"""
 | 
			
		||||
 
 | 
			
		||||
@@ -70,32 +70,33 @@ class CredentialCreateTest(utils.NoRepoTestCase):
 | 
			
		||||
 | 
			
		||||
class CredentialCallback(utils.RepoTestCase):
 | 
			
		||||
    def test_callback(self):
 | 
			
		||||
        def credentials_cb(url, username, allowed):
 | 
			
		||||
            self.assertTrue(allowed & GIT_CREDTYPE_USERPASS_PLAINTEXT)
 | 
			
		||||
            raise Exception("I don't know the password")
 | 
			
		||||
        class MyCallbacks(pygit2.RemoteCallbacks):
 | 
			
		||||
            def credentials(url, username, allowed):
 | 
			
		||||
                self.assertTrue(allowed & GIT_CREDTYPE_USERPASS_PLAINTEXT)
 | 
			
		||||
                raise Exception("I don't know the password")
 | 
			
		||||
 | 
			
		||||
        remote = self.repo.create_remote("github", "https://github.com/github/github")
 | 
			
		||||
        remote.credentials = credentials_cb
 | 
			
		||||
 | 
			
		||||
        self.assertRaises(Exception, remote.fetch)
 | 
			
		||||
        self.assertRaises(Exception, lambda: remote.fetch(callbacks=MyCallbacks()))
 | 
			
		||||
 | 
			
		||||
    def test_bad_cred_type(self):
 | 
			
		||||
        def credentials_cb(url, username, allowed):
 | 
			
		||||
            self.assertTrue(allowed & GIT_CREDTYPE_USERPASS_PLAINTEXT)
 | 
			
		||||
            return Keypair("git", "foo.pub", "foo", "sekkrit")
 | 
			
		||||
        class MyCallbacks(pygit2.RemoteCallbacks):
 | 
			
		||||
            def credentials(url, username, allowed):
 | 
			
		||||
                self.assertTrue(allowed & GIT_CREDTYPE_USERPASS_PLAINTEXT)
 | 
			
		||||
                return Keypair("git", "foo.pub", "foo", "sekkrit")
 | 
			
		||||
 | 
			
		||||
        remote = self.repo.create_remote("github", "https://github.com/github/github")
 | 
			
		||||
        remote.credentials = credentials_cb
 | 
			
		||||
 | 
			
		||||
        self.assertRaises(TypeError, remote.fetch)
 | 
			
		||||
        self.assertRaises(TypeError, lambda: remote.fetch(callbacks=MyCallbacks()))
 | 
			
		||||
 | 
			
		||||
class CallableCredentialTest(utils.RepoTestCase):
 | 
			
		||||
 | 
			
		||||
    def test_user_pass(self):
 | 
			
		||||
        remote = self.repo.create_remote("bb", "https://bitbucket.org/libgit2/testgitrepository.git")
 | 
			
		||||
        remote.credentials = UserPass("libgit2", "libgit2")
 | 
			
		||||
        class MyCallbacks(pygit2.RemoteCallbacks):
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
                self.credentials = UserPass("libgit2", "libgit2")
 | 
			
		||||
 | 
			
		||||
        remote.fetch()
 | 
			
		||||
        remote = self.repo.create_remote("bb", "https://bitbucket.org/libgit2/testgitrepository.git")
 | 
			
		||||
        remote.fetch(callbacks=MyCallbacks())
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
@@ -213,30 +213,37 @@ class EmptyRepositoryTest(utils.EmptyRepoTestCase):
 | 
			
		||||
 | 
			
		||||
    def test_transfer_progress(self):
 | 
			
		||||
        self.tp = None
 | 
			
		||||
        def tp_cb(stats):
 | 
			
		||||
            self.tp = stats
 | 
			
		||||
        class MyCallbacks(pygit2.RemoteCallbacks):
 | 
			
		||||
            def transfer_progress(self, stats):
 | 
			
		||||
                self.tp = stats
 | 
			
		||||
 | 
			
		||||
        callbacks = MyCallbacks()
 | 
			
		||||
        remote = self.repo.remotes[0]
 | 
			
		||||
        remote.transfer_progress = tp_cb
 | 
			
		||||
        stats = remote.fetch()
 | 
			
		||||
        self.assertEqual(stats.received_bytes, self.tp.received_bytes)
 | 
			
		||||
        self.assertEqual(stats.indexed_objects, self.tp.indexed_objects)
 | 
			
		||||
        self.assertEqual(stats.received_objects, self.tp.received_objects)
 | 
			
		||||
        stats = remote.fetch(callbacks=callbacks)
 | 
			
		||||
        self.assertEqual(stats.received_bytes, callbacks.tp.received_bytes)
 | 
			
		||||
        self.assertEqual(stats.indexed_objects, callbacks.tp.indexed_objects)
 | 
			
		||||
        self.assertEqual(stats.received_objects, callbacks.tp.received_objects)
 | 
			
		||||
 | 
			
		||||
    def test_update_tips(self):
 | 
			
		||||
        remote = self.repo.remotes[0]
 | 
			
		||||
        self.i = 0
 | 
			
		||||
        self.tips = [('refs/remotes/origin/master', Oid(hex='0'*40),
 | 
			
		||||
                      Oid(hex='784855caf26449a1914d2cf62d12b9374d76ae78')),
 | 
			
		||||
                     ('refs/tags/root', Oid(hex='0'*40),
 | 
			
		||||
                      Oid(hex='3d2962987c695a29f1f80b6c3aa4ec046ef44369'))]
 | 
			
		||||
        tips = [('refs/remotes/origin/master', Oid(hex='0'*40),
 | 
			
		||||
                 Oid(hex='784855caf26449a1914d2cf62d12b9374d76ae78')),
 | 
			
		||||
                ('refs/tags/root', Oid(hex='0'*40),
 | 
			
		||||
                 Oid(hex='3d2962987c695a29f1f80b6c3aa4ec046ef44369'))]
 | 
			
		||||
 | 
			
		||||
        def ut_cb(name, old, new):
 | 
			
		||||
            self.assertEqual(self.tips[self.i], (name, old, new))
 | 
			
		||||
            self.i += 1
 | 
			
		||||
        class MyCallbacks(pygit2.RemoteCallbacks):
 | 
			
		||||
            def __init__(self, test_self, tips):
 | 
			
		||||
                self.test = test_self
 | 
			
		||||
                self.tips = tips
 | 
			
		||||
                self.i = 0
 | 
			
		||||
 | 
			
		||||
        remote.update_tips = ut_cb
 | 
			
		||||
        remote.fetch()
 | 
			
		||||
            def update_tips(self, name, old, new):
 | 
			
		||||
                self.test.assertEqual(self.tips[self.i], (name, old, new))
 | 
			
		||||
                self.i += 1
 | 
			
		||||
 | 
			
		||||
        callbacks = MyCallbacks(self, tips)
 | 
			
		||||
        remote.fetch(callbacks=callbacks)
 | 
			
		||||
        self.assertTrue(callbacks.i > 0)
 | 
			
		||||
 | 
			
		||||
class PushTestCase(unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user