From b828eb926f38d98c7d4a0b41fbb75b3113816fa4 Mon Sep 17 00:00:00 2001
From: Nico von Geyso <Nico.Geyso@FU-Berlin.de>
Date: Sun, 17 Feb 2013 14:23:41 +0100
Subject: [PATCH] added functionality to checkout from HEAD

---
 src/repository.c        | 40 ++++++++++++++++++++++++----------------
 test/test_repository.py | 20 +++++++++++++++-----
 2 files changed, 39 insertions(+), 21 deletions(-)

diff --git a/src/repository.c b/src/repository.c
index 38451c8..cf816d9 100644
--- a/src/repository.c
+++ b/src/repository.c
@@ -1068,34 +1068,42 @@ Repository_remotes__get__(Repository *self)
 
 
 PyDoc_STRVAR(Repository_checkout__doc__,
-  "checkout(reference:Reference, [strategy:int])\n"
+  "checkout([strategy:int, reference:Reference])\n"
   "\n"
-  "Checks out a tree by a given reference and modifies the HEAD pointer.\n"
-  "Standard checkout strategy is pygit2.GIT_CHECKOUT_SAFE_CREATE");
+  "Checks out a tree by a given reference and modifies the HEAD pointer\n"
+  "Standard checkout strategy is pygit2.GIT_CHECKOUT_SAFE_CREATE\n"
+  "If no reference is given, checkout will use HEAD instead.");
 
 PyObject *
-Repository_checkout(Repository *self, PyObject *args)
+Repository_checkout(Repository *self, PyObject *args, PyObject *kw)
 {
     git_checkout_opts opts = GIT_CHECKOUT_OPTS_INIT;
     unsigned int strategy = GIT_CHECKOUT_SAFE_CREATE;
-    Reference* ref;
+    Reference* ref = NULL;
     git_object* object;
     const git_oid* id;
     int err;
 
-    if (!PyArg_ParseTuple(args, "O!|I", &ReferenceType, &ref, &strategy))
+    static char *kwlist[] = {"strategy", "reference", NULL};
+
+    if (!PyArg_ParseTupleAndKeywords(args, kw, "|IO!", kwlist,
+                                     &strategy, &ReferenceType, &ref))
         return NULL;
 
-    CHECK_REFERENCE(ref);
-
-    id = git_reference_target(ref->reference);
-    err = git_object_lookup(&object, self->repo, id, GIT_OBJ_COMMIT);
-    if(err == GIT_OK) {
+    if (ref != NULL) { // checkout from treeish
+        id = git_reference_target(ref->reference);
+        err = git_object_lookup(&object, self->repo, id, GIT_OBJ_COMMIT);
+        if(err == GIT_OK) {
+            opts.checkout_strategy = strategy;
+            err = git_checkout_tree(self->repo, object, &opts);
+            if (err == GIT_OK) {
+                err = git_repository_set_head(self->repo,
+                          git_reference_name(ref->reference));
+            }
+        }
+    } else { // checkout from head
         opts.checkout_strategy = strategy;
-        err = git_checkout_tree(self->repo, object, &opts);
-        if (err == GIT_OK)
-            err = git_repository_set_head(self->repo,
-                      git_reference_name(ref->reference));
+        err = git_checkout_head(self->repo, &opts);
     }
 
     if(err < 0)
@@ -1122,7 +1130,7 @@ PyMethodDef Repository_methods[] = {
     METHOD(Repository, status, METH_NOARGS),
     METHOD(Repository, status_file, METH_O),
     METHOD(Repository, create_remote, METH_VARARGS),
-    METHOD(Repository, checkout, METH_VARARGS),
+    METHOD(Repository, checkout, METH_VARARGS|METH_KEYWORDS),
     {NULL}
 };
 
diff --git a/test/test_repository.py b/test/test_repository.py
index a05d9e3..e85e6a5 100644
--- a/test/test_repository.py
+++ b/test/test_repository.py
@@ -187,16 +187,26 @@ class RepositoryTest_II(utils.RepoTestCase):
     def test_checkout(self):
         ref_i18n = self.repo.lookup_reference('refs/heads/i18n')
 
-        self.assertRaises(pygit2.GitError, self.repo.checkout, ref_i18n)
+        # checkout i18n with conflicts and default strategy should
+        # not be possible
+        self.assertRaises(pygit2.GitError,
+                          lambda: self.repo.checkout(reference=ref_i18n))
 
-        self.repo.checkout(ref_i18n, pygit2.GIT_CHECKOUT_FORCE)
+        # checkout i18n with GIT_CHECKOUT_FORCE
+        self.assertTrue('new' not in self.repo.head.tree)
+        self.repo.checkout(pygit2.GIT_CHECKOUT_FORCE, ref_i18n)
         self.assertEqual(self.repo.head.hex, self.repo[ref_i18n.target].hex)
         self.assertTrue('new' in self.repo.head.tree)
         self.assertTrue('bye.txt' not in self.repo.status())
 
-        ref_master = self.repo.lookup_reference('refs/heads/master')
-        self.repo.checkout(ref_master, pygit2.GIT_CHECKOUT_FORCE)
-        self.assertTrue('new' not in self.repo.head.tree)
+        # some changes to working dir
+        with open(os.path.join(self.repo.workdir, 'bye.txt'), 'w') as f:
+          f.write('new content')
+
+        # checkout head
+        self.assertTrue('bye.txt' in self.repo.status())
+        self.repo.checkout(pygit2.GIT_CHECKOUT_FORCE)
+        self.assertTrue('bye.txt' not in self.repo.status())
 
 class NewRepositoryTest(utils.NoRepoTestCase):
     def test_new_repo(self):