Merge pull request #874 from smola/patchcontext

plumbing: add context to allow canceling diff/patch computation
diff --git a/plumbing/object/change.go b/plumbing/object/change.go
index 729ff5a..a1b4c27 100644
--- a/plumbing/object/change.go
+++ b/plumbing/object/change.go
@@ -2,6 +2,7 @@
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"strings"
 
@@ -81,7 +82,15 @@
 // Patch returns a Patch with all the file changes in chunks. This
 // representation can be used to create several diff outputs.
 func (c *Change) Patch() (*Patch, error) {
-	return getPatch("", c)
+	return c.PatchContext(context.Background())
+}
+
+// PatchContext returns a Patch with all the file changes in chunks. This
+// representation can be used to create several diff outputs.
+// If ctx is canceled or expires, computation stops and ErrCanceled is returned.
+// The provided context must be non-nil.
+func (c *Change) PatchContext(ctx context.Context) (*Patch, error) {
+	return getPatchContext(ctx, "", c)
 }
 
 func (c *Change) name() string {
@@ -136,5 +145,13 @@
 // Patch returns a Patch with all the changes in chunks. This
 // representation can be used to create several diff outputs.
 func (c Changes) Patch() (*Patch, error) {
-	return getPatch("", c...)
+	return c.PatchContext(context.Background())
+}
+
+// PatchContext returns a Patch with all the changes in chunks. This
+// representation can be used to create several diff outputs.
+// If ctx is canceled or expires, computation stops and ErrCanceled is returned.
+// The provided context must be non-nil.
+func (c Changes) PatchContext(ctx context.Context) (*Patch, error) {
+	return getPatchContext(ctx, "", c...)
 }
diff --git a/plumbing/object/change_test.go b/plumbing/object/change_test.go
index 7036fa3..b0e89c7 100644
--- a/plumbing/object/change_test.go
+++ b/plumbing/object/change_test.go
@@ -1,6 +1,7 @@
 package object
 
 import (
+	"context"
 	"sort"
 
 	"gopkg.in/src-d/go-git.v4/plumbing"
@@ -82,6 +83,12 @@
 	c.Assert(len(p.FilePatches()[0].Chunks()), Equals, 1)
 	c.Assert(p.FilePatches()[0].Chunks()[0].Type(), Equals, diff.Add)
 
+	p, err = change.PatchContext(context.Background())
+	c.Assert(err, IsNil)
+	c.Assert(len(p.FilePatches()), Equals, 1)
+	c.Assert(len(p.FilePatches()[0].Chunks()), Equals, 1)
+	c.Assert(p.FilePatches()[0].Chunks()[0].Type(), Equals, diff.Add)
+
 	str := change.String()
 	c.Assert(str, Equals, "<Action: Insert, Path: examples/clone/main.go>")
 }
@@ -134,6 +141,12 @@
 	c.Assert(len(p.FilePatches()[0].Chunks()), Equals, 1)
 	c.Assert(p.FilePatches()[0].Chunks()[0].Type(), Equals, diff.Delete)
 
+	p, err = change.PatchContext(context.Background())
+	c.Assert(err, IsNil)
+	c.Assert(len(p.FilePatches()), Equals, 1)
+	c.Assert(len(p.FilePatches()[0].Chunks()), Equals, 1)
+	c.Assert(p.FilePatches()[0].Chunks()[0].Type(), Equals, diff.Delete)
+
 	str := change.String()
 	c.Assert(str, Equals, "<Action: Delete, Path: utils/difftree/difftree.go>")
 }
@@ -206,6 +219,18 @@
 	c.Assert(p.FilePatches()[0].Chunks()[5].Type(), Equals, diff.Add)
 	c.Assert(p.FilePatches()[0].Chunks()[6].Type(), Equals, diff.Equal)
 
+	p, err = change.PatchContext(context.Background())
+	c.Assert(err, IsNil)
+	c.Assert(len(p.FilePatches()), Equals, 1)
+	c.Assert(len(p.FilePatches()[0].Chunks()), Equals, 7)
+	c.Assert(p.FilePatches()[0].Chunks()[0].Type(), Equals, diff.Equal)
+	c.Assert(p.FilePatches()[0].Chunks()[1].Type(), Equals, diff.Delete)
+	c.Assert(p.FilePatches()[0].Chunks()[2].Type(), Equals, diff.Add)
+	c.Assert(p.FilePatches()[0].Chunks()[3].Type(), Equals, diff.Equal)
+	c.Assert(p.FilePatches()[0].Chunks()[4].Type(), Equals, diff.Delete)
+	c.Assert(p.FilePatches()[0].Chunks()[5].Type(), Equals, diff.Add)
+	c.Assert(p.FilePatches()[0].Chunks()[6].Type(), Equals, diff.Equal)
+
 	str := change.String()
 	c.Assert(str, Equals, "<Action: Modify, Path: utils/difftree/difftree.go>")
 }
@@ -367,3 +392,39 @@
 	sort.Sort(changes)
 	c.Assert(changes.String(), Equals, expected)
 }
+
+func (s *ChangeSuite) TestCancel(c *C) {
+	// Commit a5078b19f08f63e7948abd0a5e2fb7d319d3a565 of the go-git
+	// fixture inserted "examples/clone/main.go".
+	//
+	// On that commit, the "examples/clone" tree is
+	//     6efca3ff41cab651332f9ebc0c96bb26be809615
+	//
+	// and the "examples/clone/main.go" is
+	//     f95dc8f7923add1a8b9f72ecb1e8db1402de601a
+
+	path := "examples/clone/main.go"
+	name := "main.go"
+	mode := filemode.Regular
+	blob := plumbing.NewHash("f95dc8f7923add1a8b9f72ecb1e8db1402de601a")
+	tree := plumbing.NewHash("6efca3ff41cab651332f9ebc0c96bb26be809615")
+
+	change := &Change{
+		From: empty,
+		To: ChangeEntry{
+			Name: path,
+			Tree: s.tree(c, tree),
+			TreeEntry: TreeEntry{
+				Name: name,
+				Mode: mode,
+				Hash: blob,
+			},
+		},
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel() // cancel before the call so PatchContext sees an already-done context
+	p, err := change.PatchContext(ctx)
+	c.Assert(p, IsNil)
+	c.Assert(err, ErrorMatches, "operation canceled")
+}
diff --git a/plumbing/object/commit.go b/plumbing/object/commit.go
index c9a4c0e..3ed85ba 100644
--- a/plumbing/object/commit.go
+++ b/plumbing/object/commit.go
@@ -3,6 +3,7 @@
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -75,7 +76,8 @@
 }
 
 // Patch returns the Patch between the actual commit and the provided one.
-func (c *Commit) Patch(to *Commit) (*Patch, error) {
+// PatchContext honors ctx: if it is canceled or expires, ErrCanceled is returned. ctx must be non-nil.
+func (c *Commit) PatchContext(ctx context.Context, to *Commit) (*Patch, error) {
 	fromTree, err := c.Tree()
 	if err != nil {
 		return nil, err
@@ -86,7 +88,12 @@
 		return nil, err
 	}
 
-	return fromTree.Patch(toTree)
+	return fromTree.PatchContext(ctx, toTree)
+}
+
+// Patch returns the Patch between the actual commit and the provided one.
+func (c *Commit) Patch(to *Commit) (*Patch, error) {
+	return c.PatchContext(context.Background(), to)
 }
 
 // Parents return a CommitIter to the parent Commits.
diff --git a/plumbing/object/commit_test.go b/plumbing/object/commit_test.go
index 191b14d..996d481 100644
--- a/plumbing/object/commit_test.go
+++ b/plumbing/object/commit_test.go
@@ -2,6 +2,7 @@
 
 import (
 	"bytes"
+	"context"
 	"io"
 	"strings"
 	"time"
@@ -132,6 +133,59 @@
 	c.Assert(buf.String(), Equals, patch.String())
 }
 
+func (s *SuiteCommit) TestPatchContext(c *C) {
+	from := s.commit(c, plumbing.NewHash("918c48b83bd081e863dbe1b80f8998f058cd8294"))
+	to := s.commit(c, plumbing.NewHash("6ecf0ef2c2dffb796033e5a02219af86ec6584e5"))
+
+	patch, err := from.PatchContext(context.Background(), to)
+	c.Assert(err, IsNil)
+
+	buf := bytes.NewBuffer(nil)
+	err = patch.Encode(buf)
+	c.Assert(err, IsNil)
+
+	c.Assert(buf.String(), Equals, `diff --git a/vendor/foo.go b/vendor/foo.go
+new file mode 100644
+index 0000000000000000000000000000000000000000..9dea2395f5403188298c1dabe8bdafe562c491e3
+--- /dev/null
++++ b/vendor/foo.go
+@@ -0,0 +1,7 @@
++package main
++
++import "fmt"
++
++func main() {
++	fmt.Println("Hello, playground")
++}
+`)
+	c.Assert(buf.String(), Equals, patch.String())
+
+	from = s.commit(c, plumbing.NewHash("b8e471f58bcbca63b07bda20e428190409c2db47"))
+	to = s.commit(c, plumbing.NewHash("35e85108805c84807bc66a02d91535e1e24b38b9"))
+
+	patch, err = from.PatchContext(context.Background(), to)
+	c.Assert(err, IsNil)
+
+	buf.Reset()
+	err = patch.Encode(buf)
+	c.Assert(err, IsNil)
+
+	c.Assert(buf.String(), Equals, `diff --git a/CHANGELOG b/CHANGELOG
+deleted file mode 100644
+index d3ff53e0564a9f87d8e84b6e28e5060e517008aa..0000000000000000000000000000000000000000
+--- a/CHANGELOG
++++ /dev/null
+@@ -1 +0,0 @@
+-Initial changelog
+diff --git a/binary.jpg b/binary.jpg
+new file mode 100644
+index 0000000000000000000000000000000000000000..d5c0f4ab811897cadf03aec358ae60d21f91c50d
+Binary files /dev/null and b/binary.jpg differ
+`)
+
+	c.Assert(buf.String(), Equals, patch.String())
+}
+
 func (s *SuiteCommit) TestCommitEncodeDecodeIdempotent(c *C) {
 	ts, err := time.Parse(time.RFC3339, "2006-01-02T15:04:05-07:00")
 	c.Assert(err, IsNil)
@@ -363,3 +417,15 @@
 	_, ok := e.Identities["Sunny <me@darkowlzz.space>"]
 	c.Assert(ok, Equals, true)
 }
+
+func (s *SuiteCommit) TestPatchCancel(c *C) {
+	from := s.commit(c, plumbing.NewHash("918c48b83bd081e863dbe1b80f8998f058cd8294"))
+	to := s.commit(c, plumbing.NewHash("6ecf0ef2c2dffb796033e5a02219af86ec6584e5"))
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel() // cancel before the call so PatchContext sees an already-done context
+	patch, err := from.PatchContext(ctx, to)
+	c.Assert(patch, IsNil)
+	c.Assert(err, ErrorMatches, "operation canceled")
+
+}
diff --git a/plumbing/object/difftree.go b/plumbing/object/difftree.go
index ac58c4d..a30a29e 100644
--- a/plumbing/object/difftree.go
+++ b/plumbing/object/difftree.go
@@ -2,6 +2,7 @@
 
 import (
 	"bytes"
+	"context"
 
 	"gopkg.in/src-d/go-git.v4/utils/merkletrie"
 	"gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
@@ -10,6 +11,13 @@
 // DiffTree compares the content and mode of the blobs found via two
 // tree objects.
 func DiffTree(a, b *Tree) (Changes, error) {
+	return DiffTreeContext(context.Background(), a, b)
+}
+
+// DiffTreeContext compares the content and mode of the blobs found via two
+// tree objects. The provided context must be non-nil.
+// If ctx is canceled or expires, ErrCanceled is returned.
+func DiffTreeContext(ctx context.Context, a, b *Tree) (Changes, error) {
 	from := NewTreeRootNode(a)
 	to := NewTreeRootNode(b)
 
@@ -17,8 +25,11 @@
 		return bytes.Equal(a.Hash(), b.Hash())
 	}
 
-	merkletrieChanges, err := merkletrie.DiffTree(from, to, hashEqual)
+	merkletrieChanges, err := merkletrie.DiffTreeContext(ctx, from, to, hashEqual)
 	if err != nil {
+		if err == merkletrie.ErrCanceled { // surface this package's sentinel, not merkletrie's
+			return nil, ErrCanceled
+		}
 		return nil, err
 	}
 
diff --git a/plumbing/object/patch.go b/plumbing/object/patch.go
index aa96a96..adeaccb 100644
--- a/plumbing/object/patch.go
+++ b/plumbing/object/patch.go
@@ -2,6 +2,8 @@
 
 import (
 	"bytes"
+	"context"
+	"errors"
 	"fmt"
 	"io"
 	"math"
@@ -15,10 +17,25 @@
 	dmp "github.com/sergi/go-diff/diffmatchpatch"
 )
 
+var (
+	ErrCanceled = errors.New("operation canceled") // returned by this package's *Context functions when their ctx is done
+)
+
 func getPatch(message string, changes ...*Change) (*Patch, error) {
+	ctx := context.Background()
+	return getPatchContext(ctx, message, changes...)
+}
+
+func getPatchContext(ctx context.Context, message string, changes ...*Change) (*Patch, error) {
 	var filePatches []fdiff.FilePatch
 	for _, c := range changes {
-		fp, err := filePatch(c)
+		select { // checked before each change is diffed
+		case <-ctx.Done():
+			return nil, ErrCanceled // NOTE(review): discards ctx.Err(), so Canceled and DeadlineExceeded look the same
+		default:
+		}
+
+		fp, err := filePatchWithContext(ctx, c)
 		if err != nil {
 			return nil, err
 		}
@@ -29,7 +46,7 @@
 	return &Patch{message, filePatches}, nil
 }
 
-func filePatch(c *Change) (fdiff.FilePatch, error) {
+func filePatchWithContext(ctx context.Context, c *Change) (fdiff.FilePatch, error) {
 	from, to, err := c.Files()
 	if err != nil {
 		return nil, err
@@ -52,6 +69,12 @@
 
 	var chunks []fdiff.Chunk
 	for _, d := range diffs {
+		select { // checked per diff entry; diffs itself is computed before this loop
+		case <-ctx.Done():
+			return nil, ErrCanceled
+		default:
+		}
+
 		var op fdiff.Operation
 		switch d.Type {
 		case dmp.DiffEqual:
@@ -70,6 +93,11 @@
 		from:   c.From,
 		to:     c.To,
 	}, nil
+
+}
+
+func filePatch(c *Change) (fdiff.FilePatch, error) { // NOTE(review): its caller in getPatch was replaced in this diff; check whether it is still used
+	return filePatchWithContext(context.Background(), c)
 }
 
 func fileContent(f *File) (content string, isBinary bool, err error) {
diff --git a/plumbing/object/tree.go b/plumbing/object/tree.go
index 30bbcb0..86d19c0 100644
--- a/plumbing/object/tree.go
+++ b/plumbing/object/tree.go
@@ -2,6 +2,7 @@
 
 import (
 	"bufio"
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -295,15 +296,30 @@
 	return DiffTree(from, to)
 }
 
+// DiffContext returns a list of changes between this tree and the provided one.
+// If ctx is canceled or expires, ErrCanceled is returned.
+// The provided context must be non-nil.
+func (from *Tree) DiffContext(ctx context.Context, to *Tree) (Changes, error) {
+	return DiffTreeContext(ctx, from, to)
+}
+
 // Patch returns a slice of Patch objects with all the changes between trees
 // in chunks. This representation can be used to create several diff outputs.
 func (from *Tree) Patch(to *Tree) (*Patch, error) {
-	changes, err := DiffTree(from, to)
+	return from.PatchContext(context.Background(), to)
+}
+
+// PatchContext returns a Patch with all the changes between trees in
+// chunks. This representation can be used to create several diff outputs.
+// If ctx is canceled or expires, ErrCanceled is returned.
+// The provided context must be non-nil.
+func (from *Tree) PatchContext(ctx context.Context, to *Tree) (*Patch, error) {
+	changes, err := DiffTreeContext(ctx, from, to)
 	if err != nil {
 		return nil, err
 	}
 
-	return changes.Patch()
+	return changes.PatchContext(ctx)
 }
 
 // treeEntryIter facilitates iterating through the TreeEntry objects in a Tree.
diff --git a/utils/merkletrie/difftree.go b/utils/merkletrie/difftree.go
index 2294096..d57ed13 100644
--- a/utils/merkletrie/difftree.go
+++ b/utils/merkletrie/difftree.go
@@ -248,15 +248,30 @@
 // h: else of i
 
 import (
+	"context"
+	"errors"
 	"fmt"
 
 	"gopkg.in/src-d/go-git.v4/utils/merkletrie/noder"
 )
 
+var (
+	ErrCanceled = errors.New("operation canceled") // returned by DiffTreeContext when its ctx is done
+)
+
 // DiffTree calculates the list of changes between two merkletries.  It
 // uses the provided hashEqual callback to compare noders.
 func DiffTree(fromTree, toTree noder.Noder,
 	hashEqual noder.Equal) (Changes, error) {
+	return DiffTreeContext(context.Background(), fromTree, toTree, hashEqual)
+}
+
+// DiffTreeContext calculates the list of changes between two merkletries.  It
+// uses the provided hashEqual callback to compare noders.
+// If ctx is canceled or expires, ErrCanceled is returned.
+// The provided context must be non-nil.
+func DiffTreeContext(ctx context.Context, fromTree, toTree noder.Noder,
+	hashEqual noder.Equal) (Changes, error) {
 	ret := NewChanges()
 
 	ii, err := newDoubleIter(fromTree, toTree, hashEqual)
@@ -265,6 +280,12 @@
 	}
 
 	for {
+		select { // checked once per iteration of the comparison loop
+		case <-ctx.Done():
+			return nil, ErrCanceled
+		default:
+		}
+
 		from := ii.from.current
 		to := ii.to.current
 
diff --git a/utils/merkletrie/difftree_test.go b/utils/merkletrie/difftree_test.go
index 9f033b1..ab0eb57 100644
--- a/utils/merkletrie/difftree_test.go
+++ b/utils/merkletrie/difftree_test.go
@@ -2,6 +2,7 @@
 
 import (
 	"bytes"
+	ctx "context" // aliased because the test helpers use "context" as a string parameter name
 	"fmt"
 	"reflect"
 	"sort"
@@ -61,9 +62,45 @@
 	c.Assert(obtained, changesEquals, expected, comment)
 }
 
+func (t diffTreeTest) innerRunCtx(c *C, context string, reverse bool) {
+	comment := Commentf("\n%s", context)
+	if reverse {
+		comment = Commentf("%s [REVERSED]", comment.CheckCommentString())
+	}
+
+	a, err := fsnoder.New(t.from)
+	c.Assert(err, IsNil, comment)
+	comment = Commentf("%s\n\t    from = %s", comment.CheckCommentString(), a)
+
+	b, err := fsnoder.New(t.to)
+	c.Assert(err, IsNil, comment)
+	comment = Commentf("%s\n\t      to = %s", comment.CheckCommentString(), b)
+
+	expected, err := newChangesFromString(t.expected)
+	c.Assert(err, IsNil, comment)
+
+	if reverse {
+		a, b = b, a
+		expected = expected.reverse()
+	}
+	comment = Commentf("%s\n\texpected = %s", comment.CheckCommentString(), expected)
+
+	results, err := merkletrie.DiffTreeContext(ctx.Background(), a, b, fsnoder.HashEqual)
+	c.Assert(err, IsNil, comment)
+
+	obtained, err := newChanges(results)
+	c.Assert(err, IsNil, comment)
+
+	comment = Commentf("%s\n\tobtained = %s", comment.CheckCommentString(), obtained)
+
+	c.Assert(obtained, changesEquals, expected, comment)
+}
+
 func (t diffTreeTest) run(c *C, context string) {
 	t.innerRun(c, context, false)
 	t.innerRun(c, context, true)
+	t.innerRunCtx(c, context, false)
+	t.innerRunCtx(c, context, true)
 }
 
 type change struct {
@@ -437,3 +474,27 @@
 		},
 	})
 }
+
+func (s *DiffTreeSuite) TestCancel(c *C) {
+	t :=  diffTreeTest{"()", "(a<> b<1> c() d<> e<2> f())", "+a +b +d +e"}
+	comment := Commentf("\n%s", "test cancel:")
+
+	a, err := fsnoder.New(t.from)
+	c.Assert(err, IsNil, comment)
+	comment = Commentf("%s\n\t    from = %s", comment.CheckCommentString(), a)
+
+	b, err := fsnoder.New(t.to)
+	c.Assert(err, IsNil, comment)
+	comment = Commentf("%s\n\t      to = %s", comment.CheckCommentString(), b)
+
+	expected, err := newChangesFromString(t.expected)
+	c.Assert(err, IsNil, comment)
+
+	comment = Commentf("%s\n\texpected = %s", comment.CheckCommentString(), expected)
+	context, cancel := ctx.WithCancel(ctx.Background())
+	cancel() // cancel before the call so DiffTreeContext sees an already-done context
+	results, err := merkletrie.DiffTreeContext(context, a, b, fsnoder.HashEqual)
+	c.Assert(results, IsNil, comment)
+	c.Assert(err, ErrorMatches, "operation canceled")
+
+}