diff --git a/docs/remotes.rst b/docs/remotes.rst index 089ab4a..b136fee 100644 --- a/docs/remotes.rst +++ b/docs/remotes.rst @@ -16,6 +16,9 @@ The Remote type .. autoattribute:: pygit2.Remote.refspec_count .. autoattribute:: pygit2.Remote.push_refspecs .. autoattribute:: pygit2.Remote.fetch_refspecs +.. autoattribute:: pygit2.Remote.progress +.. autoattribute:: pygit2.Remote.transfer_progress +.. autoattribute:: pygit2.Remote.update_tips .. automethod:: pygit2.Remote.get_push_refspecs .. automethod:: pygit2.Remote.get_fetch_refspecs .. automethod:: pygit2.Remote.set_push_refspecs diff --git a/src/remote.c b/src/remote.c index cafa648..1b82586 100644 --- a/src/remote.c +++ b/src/remote.c @@ -32,6 +32,7 @@ #include "utils.h" #include "types.h" #include "remote.h" +#include "oid.h" extern PyObject *GitError; @@ -294,6 +295,93 @@ PyTypeObject RefspecType = { 0, /* tp_new */ }; +static int +progress_cb(const char *str, int len, void *data) +{ + Remote *remote = (Remote *) data; + PyObject *arglist, *ret; + + if (remote->progress == NULL) + return 0; + + if (!PyCallable_Check(remote->progress)) { + PyErr_SetString(PyExc_TypeError, "progress callback is not callable"); + return -1; + } + + arglist = Py_BuildValue("(s#)", str, len); + ret = PyObject_CallObject(remote->progress, arglist); + Py_DECREF(arglist); + + if (!ret) + return -1; + + Py_DECREF(ret); + + return 0; +} + +static int +transfer_progress_cb(const git_transfer_progress *stats, void *data) +{ + Remote *remote = (Remote *) data; + PyObject *arglist, *ret; + + if (remote->transfer_progress == NULL) + return 0; + + if (!PyCallable_Check(remote->transfer_progress)) { + PyErr_SetString(PyExc_TypeError, "transfer progress callback is not callable"); + return -1; + } + + arglist = Py_BuildValue("({s:I,s:I,s:n})", + "indexed_objects", stats->indexed_objects, + "received_objects", stats->received_objects, + "received_bytes", stats->received_bytes); + + ret = PyObject_CallObject(remote->transfer_progress, arglist); + Py_DECREF(arglist); + + if (!ret) + return -1; + + Py_DECREF(ret); + + return 0; +} + +static int +update_tips_cb(const char *refname, const git_oid *a, const git_oid *b, void *data) +{ + Remote *remote = (Remote *) data; + PyObject *ret; + PyObject *old, *new; + + if (remote->update_tips == NULL) + return 0; + + if (!PyCallable_Check(remote->update_tips)) { + PyErr_SetString(PyExc_TypeError, "update tips callback is not callable"); + return -1; + } + + old = git_oid_to_python(a); + new = git_oid_to_python(b); + + ret = PyObject_CallFunction(remote->update_tips, "(s,O,O)", refname, old ,new); + + Py_DECREF(old); + Py_DECREF(new); + + if (!ret) + return -1; + + Py_DECREF(ret); + + return 0; +} + PyObject * Remote_init(Remote *self, PyObject *args, PyObject *kwds) { @@ -311,19 +399,37 @@ Remote_init(Remote *self, PyObject *args, PyObject *kwds) if (err < 0) return Error_set(err); + self->progress = NULL; + self->transfer_progress = NULL; + self->update_tips = NULL; + + Remote_set_callbacks(self); return (PyObject*) self; } +void +Remote_set_callbacks(Remote *self) +{ + git_remote_callbacks callbacks = GIT_REMOTE_CALLBACKS_INIT; + + self->progress = NULL; + + callbacks.progress = progress_cb; + callbacks.transfer_progress = transfer_progress_cb; + callbacks.update_tips = update_tips_cb; + callbacks.payload = self; + git_remote_set_callbacks(self->remote, &callbacks); +} static void Remote_dealloc(Remote *self) { Py_CLEAR(self->repo); + Py_CLEAR(self->progress); git_remote_free(self->remote); PyObject_Del(self); } - PyDoc_STRVAR(Remote_name__doc__, "Name of the remote refspec"); PyObject * @@ -671,24 +777,23 @@ Remote_fetch(Remote *self, PyObject *args) const git_transfer_progress *stats; int err; - err = git_remote_connect(self->remote, GIT_DIRECTION_FETCH); - if (err == GIT_OK) { - err = git_remote_download(self->remote); - if (err == GIT_OK) { - stats = git_remote_stats(self->remote); - py_stats = Py_BuildValue("{s:I,s:I,s:n}", - "indexed_objects", stats->indexed_objects, - "received_objects", stats->received_objects, - "received_bytes", stats->received_bytes); - - err = git_remote_update_tips(self->remote); - } - git_remote_disconnect(self->remote); - } - + PyErr_Clear(); + err = git_remote_fetch(self->remote); + /* + * XXX: We should be checking for GIT_EUSER, but on v0.20, this does not + * make it all the way to us for update_tips + */ + if (err < 0 && PyErr_Occurred()) + return NULL; if (err < 0) return Error_set(err); + stats = git_remote_stats(self->remote); + py_stats = Py_BuildValue("{s:I,s:I,s:n}", + "indexed_objects", stats->indexed_objects, + "received_objects", stats->received_objects, + "received_bytes", stats->received_bytes); + return (PyObject*) py_stats; } @@ -848,6 +953,13 @@ PyGetSetDef Remote_getseters[] = { {NULL} }; +PyMemberDef Remote_members[] = { + MEMBER(Remote, progress, T_OBJECT_EX, "Progress output callback"), + MEMBER(Remote, transfer_progress, T_OBJECT_EX, "Transfer progress callback"), + MEMBER(Remote, update_tips, T_OBJECT_EX, "update tips callback"), + {NULL}, +}; + PyDoc_STRVAR(Remote__doc__, "Remote object."); PyTypeObject RemoteType = { @@ -879,7 +991,7 @@ PyTypeObject RemoteType = { 0, /* tp_iter */ 0, /* tp_iternext */ Remote_methods, /* tp_methods */ - 0, /* tp_members */ + Remote_members, /* tp_members */ Remote_getseters, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ diff --git a/src/remote.h b/src/remote.h index 7f6e131..0ab41a5 100644 --- a/src/remote.h +++ b/src/remote.h @@ -36,4 +36,6 @@ PyObject* Remote_init(Remote *self, PyObject *args, PyObject *kwds); PyObject* Remote_fetch(Remote *self, PyObject *args); +void Remote_set_callbacks(Remote *self); + #endif diff --git a/src/repository.c b/src/repository.c index fa0dd7e..f52dd66 100644 --- a/src/repository.c +++ b/src/repository.c @@ -1295,6 +1295,7 @@ Repository_create_remote(Repository *self, PyObject *args) Py_INCREF(self); py_remote->repo = self; py_remote->remote = remote; + Remote_set_callbacks(py_remote); return (PyObject*) py_remote; } diff --git a/src/types.h b/src/types.h index 63a5672..c8f7a6f 100644 --- a/src/types.h +++ b/src/types.h @@ -196,7 +196,15 @@ typedef struct { /* git_remote */ -SIMPLE_TYPE(Remote, git_remote, remote) +typedef struct { + PyObject_HEAD + Repository *repo; + git_remote *remote; + /* Callbacks for network events */ + PyObject *progress; + PyObject *transfer_progress; + PyObject *update_tips; +} Remote; /* git_refspec */ typedef struct { diff --git a/test/test_remote.py b/test/test_remote.py index 267c22e..ce146e0 100644 --- a/test/test_remote.py +++ b/test/test_remote.py @@ -30,6 +30,7 @@ import unittest import pygit2 +from pygit2 import Oid from . import utils REMOTE_NAME = 'origin' @@ -173,6 +174,20 @@ class RepositoryTest(utils.RepoTestCase): self.assertEqual('+refs/heads/*:refs/remotes/test_refspec/*', remote.get_fetch_refspecs()[1]) + def test_remote_callback_typecheck(self): + remote = self.repo.remotes[0] + remote.progress = 5 + self.assertRaises(TypeError, remote, 'fetch') + + remote = self.repo.remotes[0] + remote.transfer_progress = 5 + self.assertRaises(TypeError, remote, 'fetch') + + remote = self.repo.remotes[0] + remote.update_tips = 5 + self.assertRaises(TypeError, remote, 'fetch') + + class EmptyRepositoryTest(utils.EmptyRepoTestCase): def test_fetch(self): @@ -182,6 +197,32 @@ class EmptyRepositoryTest(utils.EmptyRepoTestCase): self.assertEqual(stats['indexed_objects'], REMOTE_REPO_OBJECTS) self.assertEqual(stats['received_objects'], REMOTE_REPO_OBJECTS) + def test_transfer_progress(self): + self.tp = None + def tp_cb(stats): + self.tp = stats + + remote = self.repo.remotes[0] + remote.transfer_progress = tp_cb + stats = remote.fetch() + self.assertEqual(stats['received_bytes'], self.tp.received_bytes) + self.assertEqual(stats['indexed_objects'], self.tp.indexed_objects) + self.assertEqual(stats['received_objects'], self.tp.received_objects) + + def test_update_tips(self): + remote = self.repo.remotes[0] + self.i = 0 + self.tips = [('refs/remotes/origin/master', Oid(hex='0'*40), + Oid(hex='784855caf26449a1914d2cf62d12b9374d76ae78')), + ('refs/tags/root', Oid(hex='0'*40), + Oid(hex='3d2962987c695a29f1f80b6c3aa4ec046ef44369'))] + + def ut_cb(name, old, new): + self.assertEqual(self.tips[self.i], (name, old, new)) + self.i += 1 + + remote.update_tips = ut_cb + remote.fetch() class PushTestCase(unittest.TestCase): def setUp(self):