diff --git a/src/tree.c b/src/tree.c index 13309be..4ffa5c9 100644 --- a/src/tree.c +++ b/src/tree.c @@ -332,21 +332,29 @@ Tree_diff_to_tree(Tree *self, PyObject *args, PyObject *kwds) { git_diff_options opts = GIT_DIFF_OPTIONS_INIT; git_diff_list *diff; - git_tree* tree; + git_tree *from, *to, *tmp; git_repository* repo; - int err; - char *keywords[] = {"obj", "flags", NULL}; + int err, swap = 0; + char *keywords[] = {"obj", "flags", "swap", NULL}; Diff *py_diff; Tree *py_tree = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O!i", keywords, - &TreeType, &py_tree, &opts.flags)) + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O!ii", keywords, + &TreeType, &py_tree, &opts.flags, + &swap)) return NULL; repo = self->repo->repo; - tree = (py_tree == NULL) ? NULL : py_tree->tree; - err = git_diff_tree_to_tree(&diff, repo, self->tree, tree, &opts); + to = (py_tree == NULL) ? NULL : py_tree->tree; + from = self->tree; + if (swap > 0) { + tmp = from; + from = to; + to = tmp; + } + + err = git_diff_tree_to_tree(&diff, repo, from, to, &opts); if (err < 0) return Error_set(err); diff --git a/test/test_diff.py b/test/test_diff.py index d68d7ed..b336624 100644 --- a/test/test_diff.py +++ b/test/test_diff.py @@ -33,6 +33,7 @@ import unittest import pygit2 from pygit2 import GIT_DIFF_INCLUDE_UNMODIFIED from . import utils +from itertools import chain COMMIT_SHA1_1 = '5fe808e8953c12735680c257f56600cb0de44b10' @@ -145,8 +146,20 @@ class DiffTest(utils.BareRepoTestCase): def test_diff_empty_tree(self): commit_a = self.repo[COMMIT_SHA1_1] diff = commit_a.tree.diff_to_tree() + + def get_context_for_lines(diff): + hunks = chain(*map(lambda x: x.hunks, [p for p in diff])) + lines = chain(*map(lambda x: x.lines, hunks)) + return map(lambda x: x[0], lines) + entries = [p.new_file_path for p in diff] self.assertAll(lambda x: commit_a.tree[x], entries) + self.assertAll(lambda x: '-' == x, get_context_for_lines(diff)) + + diff_swaped = commit_a.tree.diff_to_tree(swap=True) + entries = [p.new_file_path for p in diff_swaped] + self.assertAll(lambda x: commit_a.tree[x], entries) + self.assertAll(lambda x: '+' == x, get_context_for_lines(diff_swaped)) def test_diff_tree_opts(self): commit_c = self.repo[COMMIT_SHA1_3]