Merge pull request #572 from mcuadros/reset

Refactor Worktree.Reset and support the Soft, Mixed, Merge and Hard reset modes
diff --git a/options.go b/options.go
index 0ec18d4..02fb926 100644
--- a/options.go
+++ b/options.go
@@ -238,13 +238,13 @@
 type ResetMode int8
 
 const (
-	// HardReset resets the index and working tree. Any changes to tracked files
-	// in the working tree are discarded.
-	HardReset ResetMode = iota
 	// MixedReset resets the index but not the working tree (i.e., the changed
 	// files are preserved but not marked for commit) and reports what has not
 	// been updated. This is the default action.
-	MixedReset
+	MixedReset ResetMode = iota
+	// HardReset resets the index and working tree. Any changes to tracked files
+	// in the working tree are discarded.
+	HardReset
 	// MergeReset resets the index and updates the files in the working tree
 	// that are different between Commit and HEAD, but keeps those which are
 	// different between the index and working tree (i.e. which have changes
@@ -253,6 +253,10 @@
 	// If a file that is different between Commit and the index has unstaged
 	// changes, reset is aborted.
 	MergeReset
+	// SoftReset does not touch the index file or the working tree at all (but
+	// resets the head to <commit>, just like all modes do). This leaves all
+	// your changed files "Changes to be committed", as git status would put it.
+	SoftReset
 )
 
 // ResetOptions describes how a reset operation should be performed.
diff --git a/repository.go b/repository.go
index 932b8d4..fbc7871 100644
--- a/repository.go
+++ b/repository.go
@@ -456,7 +456,10 @@
 			return err
 		}
 
-		if err := w.Reset(&ResetOptions{Commit: head.Hash()}); err != nil {
+		if err := w.Reset(&ResetOptions{
+			Mode:   MergeReset,
+			Commit: head.Hash(),
+		}); err != nil {
 			return err
 		}
 
diff --git a/submodule.go b/submodule.go
index fd3d173..de8ac73 100644
--- a/submodule.go
+++ b/submodule.go
@@ -62,14 +62,17 @@
 }
 
 func (s *Submodule) status(idx *index.Index) (*SubmoduleStatus, error) {
+	status := &SubmoduleStatus{
+		Path: s.c.Path,
+	}
+
 	e, err := idx.Entry(s.c.Path)
-	if err != nil {
+	if err != nil && err != index.ErrEntryNotFound {
 		return nil, err
 	}
 
-	status := &SubmoduleStatus{
-		Path:     s.c.Path,
-		Expected: e.Hash,
+	if e != nil {
+		status.Expected = e.Hash
 	}
 
 	if !s.initialized {
diff --git a/worktree.go b/worktree.go
index 4f8e740..e2f8562 100644
--- a/worktree.go
+++ b/worktree.go
@@ -107,7 +107,10 @@
 		return err
 	}
 
-	if err := w.Reset(&ResetOptions{Commit: ref.Hash()}); err != nil {
+	if err := w.Reset(&ResetOptions{
+		Mode:   MergeReset,
+		Commit: ref.Hash(),
+	}); err != nil {
 		return err
 	}
 
@@ -270,17 +273,88 @@
 		}
 	}
 
-	changes, err := w.diffCommitWithStaging(opts.Commit, true)
+	if err := w.setHEADCommit(opts.Commit); err != nil {
+		return err
+	}
+
+	if opts.Mode == SoftReset {
+		return nil
+	}
+
+	t, err := w.getTreeFromCommitHash(opts.Commit)
 	if err != nil {
 		return err
 	}
 
+	if opts.Mode == MixedReset || opts.Mode == MergeReset || opts.Mode == HardReset {
+		if err := w.resetIndex(t); err != nil {
+			return err
+		}
+	}
+
+	if opts.Mode == MergeReset || opts.Mode == HardReset {
+		if err := w.resetWorktree(t); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (w *Worktree) resetIndex(t *object.Tree) error {
 	idx, err := w.r.Storer.Index()
 	if err != nil {
 		return err
 	}
 
-	t, err := w.getTreeFromCommitHash(opts.Commit)
+	changes, err := w.diffTreeWithStaging(t, true)
+	if err != nil {
+		return err
+	}
+
+	for _, ch := range changes {
+		a, err := ch.Action()
+		if err != nil {
+			return err
+		}
+
+		var name string
+		var e *object.TreeEntry
+
+		switch a {
+		case merkletrie.Modify, merkletrie.Insert:
+			name = ch.To.String()
+			e, err = t.FindEntry(name)
+			if err != nil {
+				return err
+			}
+		case merkletrie.Delete:
+			name = ch.From.String()
+		}
+
+		_, _ = idx.Remove(name)
+		if e == nil {
+			continue
+		}
+
+		idx.Entries = append(idx.Entries, &index.Entry{
+			Name: name,
+			Hash: e.Hash,
+			Mode: e.Mode,
+		})
+
+	}
+
+	return w.r.Storer.SetIndex(idx)
+}
+
+func (w *Worktree) resetWorktree(t *object.Tree) error {
+	changes, err := w.diffStagingWithWorktree(true)
+	if err != nil {
+		return err
+	}
+
+	idx, err := w.r.Storer.Index()
 	if err != nil {
 		return err
 	}
@@ -291,20 +365,59 @@
 		}
 	}
 
-	if err := w.r.Storer.SetIndex(idx); err != nil {
+	return w.r.Storer.SetIndex(idx)
+}
+
+func (w *Worktree) checkoutChange(ch merkletrie.Change, t *object.Tree, idx *index.Index) error {
+	a, err := ch.Action()
+	if err != nil {
 		return err
 	}
 
-	return w.setHEADCommit(opts.Commit)
+	var e *object.TreeEntry
+	var name string
+	var isSubmodule bool
+
+	switch a {
+	case merkletrie.Modify, merkletrie.Insert:
+		name = ch.To.String()
+		e, err = t.FindEntry(name)
+		if err != nil {
+			return err
+		}
+
+		isSubmodule = e.Mode == filemode.Submodule
+	case merkletrie.Delete:
+		return rmFileAndDirIfEmpty(w.Filesystem, ch.From.String())
+	}
+
+	if isSubmodule {
+		return w.checkoutChangeSubmodule(name, a, e, idx)
+	}
+
+	return w.checkoutChangeRegularFile(name, a, t, e, idx)
 }
 
 func (w *Worktree) containsUnstagedChanges() (bool, error) {
-	ch, err := w.diffStagingWithWorktree()
+	ch, err := w.diffStagingWithWorktree(false)
 	if err != nil {
 		return false, err
 	}
 
-	return len(ch) != 0, nil
+	for _, c := range ch {
+		a, err := c.Action()
+		if err != nil {
+			return false, err
+		}
+
+		if a == merkletrie.Insert {
+			continue
+		}
+
+		return true, nil
+	}
+
+	return false, nil
 }
 
 func (w *Worktree) setHEADCommit(commit plumbing.Hash) error {
@@ -331,42 +444,6 @@
 	return w.r.Storer.SetReference(branch)
 }
 
-func (w *Worktree) checkoutChange(ch merkletrie.Change, t *object.Tree, idx *index.Index) error {
-	a, err := ch.Action()
-	if err != nil {
-		return err
-	}
-
-	var e *object.TreeEntry
-	var name string
-	var isSubmodule bool
-
-	switch a {
-	case merkletrie.Modify, merkletrie.Insert:
-		name = ch.To.String()
-		e, err = t.FindEntry(name)
-		if err != nil {
-			return err
-		}
-
-		isSubmodule = e.Mode == filemode.Submodule
-	case merkletrie.Delete:
-		name = ch.From.String()
-		ie, err := idx.Entry(name)
-		if err != nil {
-			return err
-		}
-
-		isSubmodule = ie.Mode == filemode.Submodule
-	}
-
-	if isSubmodule {
-		return w.checkoutChangeSubmodule(name, a, e, idx)
-	}
-
-	return w.checkoutChangeRegularFile(name, a, t, e, idx)
-}
-
 func (w *Worktree) checkoutChangeSubmodule(name string,
 	a merkletrie.Action,
 	e *object.TreeEntry,
@@ -383,17 +460,7 @@
 			return nil
 		}
 
-		if err := w.rmIndexFromFile(name, idx); err != nil {
-			return err
-		}
-
-		if err := w.addIndexFromTreeEntry(name, e, idx); err != nil {
-			return err
-		}
-
-		// TODO: the submodule update should be reviewed as reported at:
-		// https://github.com/src-d/go-git/issues/415
-		return sub.update(context.TODO(), &SubmoduleUpdateOptions{}, e.Hash)
+		return w.addIndexFromTreeEntry(name, e, idx)
 	case merkletrie.Insert:
 		mode, err := e.Mode.ToOSFileMode()
 		if err != nil {
@@ -405,12 +472,6 @@
 		}
 
 		return w.addIndexFromTreeEntry(name, e, idx)
-	case merkletrie.Delete:
-		if err := rmFileAndDirIfEmpty(w.Filesystem, name); err != nil {
-			return err
-		}
-
-		return w.rmIndexFromFile(name, idx)
 	}
 
 	return nil
@@ -424,9 +485,7 @@
 ) error {
 	switch a {
 	case merkletrie.Modify:
-		if err := w.rmIndexFromFile(name, idx); err != nil {
-			return err
-		}
+		_, _ = idx.Remove(name)
 
 		// to apply perm changes the file is deleted, billy doesn't implement
 		// chmod
@@ -446,12 +505,6 @@
 		}
 
 		return w.addIndexFromFile(name, e.Hash, idx)
-	case merkletrie.Delete:
-		if err := rmFileAndDirIfEmpty(w.Filesystem, name); err != nil {
-			return err
-		}
-
-		return w.rmIndexFromFile(name, idx)
 	}
 
 	return nil
@@ -503,6 +556,7 @@
 }
 
 func (w *Worktree) addIndexFromTreeEntry(name string, f *object.TreeEntry, idx *index.Index) error {
+	_, _ = idx.Remove(name)
 	idx.Entries = append(idx.Entries, &index.Entry{
 		Hash: f.Hash,
 		Name: name,
@@ -513,6 +567,7 @@
 }
 
 func (w *Worktree) addIndexFromFile(name string, h plumbing.Hash, idx *index.Index) error {
+	_, _ = idx.Remove(name)
 	fi, err := w.Filesystem.Lstat(name)
 	if err != nil {
 		return err
@@ -541,19 +596,6 @@
 	return nil
 }
 
-func (w *Worktree) rmIndexFromFile(name string, idx *index.Index) error {
-	for i, e := range idx.Entries {
-		if e.Name != name {
-			continue
-		}
-
-		idx.Entries = append(idx.Entries[:i], idx.Entries[i+1:]...)
-		return nil
-	}
-
-	return nil
-}
-
 func (w *Worktree) getTreeFromCommitHash(commit plumbing.Hash) (*object.Tree, error) {
 	c, err := w.r.CommitObject(commit)
 	if err != nil {
diff --git a/worktree_status.go b/worktree_status.go
index 9b0773e..24d0534 100644
--- a/worktree_status.go
+++ b/worktree_status.go
@@ -65,7 +65,7 @@
 		}
 	}
 
-	right, err := w.diffStagingWithWorktree()
+	right, err := w.diffStagingWithWorktree(false)
 	if err != nil {
 		return nil, err
 	}
@@ -104,7 +104,7 @@
 	return name
 }
 
-func (w *Worktree) diffStagingWithWorktree() (merkletrie.Changes, error) {
+func (w *Worktree) diffStagingWithWorktree(reverse bool) (merkletrie.Changes, error) {
 	idx, err := w.r.Storer.Index()
 	if err != nil {
 		return nil, err
@@ -117,11 +117,19 @@
 	}
 
 	to := filesystem.NewRootNode(w.Filesystem, submodules)
-	res, err := merkletrie.DiffTree(from, to, diffTreeIsEquals)
-	if err == nil {
-		res = w.excludeIgnoredChanges(res)
+
+	var c merkletrie.Changes
+	if reverse {
+		c, err = merkletrie.DiffTree(to, from, diffTreeIsEquals)
+	} else {
+		c, err = merkletrie.DiffTree(from, to, diffTreeIsEquals)
 	}
-	return res, err
+
+	if err != nil {
+		return nil, err
+	}
+
+	return w.excludeIgnoredChanges(c), nil
 }
 
 func (w *Worktree) excludeIgnoredChanges(changes merkletrie.Changes) merkletrie.Changes {
@@ -179,27 +187,35 @@
 }
 
 func (w *Worktree) diffCommitWithStaging(commit plumbing.Hash, reverse bool) (merkletrie.Changes, error) {
-	idx, err := w.r.Storer.Index()
-	if err != nil {
-		return nil, err
-	}
-
-	var from noder.Noder
+	var t *object.Tree
 	if !commit.IsZero() {
 		c, err := w.r.CommitObject(commit)
 		if err != nil {
 			return nil, err
 		}
 
-		t, err := c.Tree()
+		t, err = c.Tree()
 		if err != nil {
 			return nil, err
 		}
+	}
 
+	return w.diffTreeWithStaging(t, reverse)
+}
+
+func (w *Worktree) diffTreeWithStaging(t *object.Tree, reverse bool) (merkletrie.Changes, error) {
+	var from noder.Noder
+	if t != nil {
 		from = object.NewTreeRootNode(t)
 	}
 
+	idx, err := w.r.Storer.Index()
+	if err != nil {
+		return nil, err
+	}
+
 	to := mindex.NewRootNode(idx)
+
 	if reverse {
 		return merkletrie.DiffTree(to, from, diffTreeIsEquals)
 	}
diff --git a/worktree_test.go b/worktree_test.go
index 70167f0..1eb305d 100644
--- a/worktree_test.go
+++ b/worktree_test.go
@@ -259,7 +259,9 @@
 		Filesystem: fs,
 	}
 
-	err := w.Checkout(&CheckoutOptions{})
+	err := w.Checkout(&CheckoutOptions{
+		Force: true,
+	})
 	c.Assert(err, IsNil)
 
 	entries, err := fs.ReadDir("/")
@@ -278,6 +280,27 @@
 	c.Assert(idx.Entries, HasLen, 9)
 }
 
+func (s *WorktreeSuite) TestCheckoutForce(c *C) {
+	w := &Worktree{
+		r:          s.Repository,
+		Filesystem: memfs.New(),
+	}
+
+	err := w.Checkout(&CheckoutOptions{})
+	c.Assert(err, IsNil)
+
+	w.Filesystem = memfs.New()
+
+	err = w.Checkout(&CheckoutOptions{
+		Force: true,
+	})
+	c.Assert(err, IsNil)
+
+	entries, err := w.Filesystem.ReadDir("/")
+	c.Assert(err, IsNil)
+	c.Assert(entries, HasLen, 8)
+}
+
 func (s *WorktreeSuite) TestCheckoutSymlink(c *C) {
 	if runtime.GOOS == "windows" {
 		c.Skip("git doesn't support symlinks by default in windows")
@@ -608,35 +631,6 @@
 	})
 }
 
-func (s *WorktreeSuite) TestCheckoutWithGitignore(c *C) {
-	fs := memfs.New()
-	w := &Worktree{
-		r:          s.Repository,
-		Filesystem: fs,
-	}
-
-	err := w.Checkout(&CheckoutOptions{})
-	c.Assert(err, IsNil)
-
-	f, _ := fs.Create("file")
-	f.Close()
-
-	err = w.Checkout(&CheckoutOptions{})
-	c.Assert(err.Error(), Equals, "worktree contains unstagged changes")
-
-	f, _ = fs.Create(".gitignore")
-	f.Write([]byte("file"))
-	f.Close()
-
-	err = w.Checkout(&CheckoutOptions{})
-	c.Assert(err.Error(), Equals, "worktree contains unstagged changes")
-
-	w.Add(".gitignore")
-
-	err = w.Checkout(&CheckoutOptions{})
-	c.Assert(err, IsNil)
-}
-
 func (s *WorktreeSuite) TestStatus(c *C) {
 	fs := memfs.New()
 	w := &Worktree{
@@ -702,15 +696,19 @@
 	c.Assert(err, IsNil)
 	c.Assert(branch.Hash(), Not(Equals), commit)
 
-	err = w.Reset(&ResetOptions{Commit: commit})
+	err = w.Reset(&ResetOptions{Mode: MergeReset, Commit: commit})
 	c.Assert(err, IsNil)
 
 	branch, err = w.r.Reference(plumbing.Master, false)
 	c.Assert(err, IsNil)
 	c.Assert(branch.Hash(), Equals, commit)
+
+	status, err := w.Status()
+	c.Assert(err, IsNil)
+	c.Assert(status.IsClean(), Equals, true)
 }
 
-func (s *WorktreeSuite) TestResetMerge(c *C) {
+func (s *WorktreeSuite) TestResetWithUntracked(c *C) {
 	fs := memfs.New()
 	w := &Worktree{
 		r:          s.Repository,
@@ -722,6 +720,87 @@
 	err := w.Checkout(&CheckoutOptions{})
 	c.Assert(err, IsNil)
 
+	err = util.WriteFile(fs, "foo", nil, 0755)
+	c.Assert(err, IsNil)
+
+	err = w.Reset(&ResetOptions{Mode: MergeReset, Commit: commit})
+	c.Assert(err, IsNil)
+
+	status, err := w.Status()
+	c.Assert(err, IsNil)
+	c.Assert(status.IsClean(), Equals, true)
+}
+
+func (s *WorktreeSuite) TestResetSoft(c *C) {
+	fs := memfs.New()
+	w := &Worktree{
+		r:          s.Repository,
+		Filesystem: fs,
+	}
+
+	commit := plumbing.NewHash("35e85108805c84807bc66a02d91535e1e24b38b9")
+
+	err := w.Checkout(&CheckoutOptions{})
+	c.Assert(err, IsNil)
+
+	err = w.Reset(&ResetOptions{Mode: SoftReset, Commit: commit})
+	c.Assert(err, IsNil)
+
+	branch, err := w.r.Reference(plumbing.Master, false)
+	c.Assert(err, IsNil)
+	c.Assert(branch.Hash(), Equals, commit)
+
+	status, err := w.Status()
+	c.Assert(err, IsNil)
+	c.Assert(status.IsClean(), Equals, false)
+	c.Assert(status.File("CHANGELOG").Staging, Equals, Added)
+}
+
+func (s *WorktreeSuite) TestResetMixed(c *C) {
+	fs := memfs.New()
+	w := &Worktree{
+		r:          s.Repository,
+		Filesystem: fs,
+	}
+
+	commit := plumbing.NewHash("35e85108805c84807bc66a02d91535e1e24b38b9")
+
+	err := w.Checkout(&CheckoutOptions{})
+	c.Assert(err, IsNil)
+
+	err = w.Reset(&ResetOptions{Mode: MixedReset, Commit: commit})
+	c.Assert(err, IsNil)
+
+	branch, err := w.r.Reference(plumbing.Master, false)
+	c.Assert(err, IsNil)
+	c.Assert(branch.Hash(), Equals, commit)
+
+	status, err := w.Status()
+	c.Assert(err, IsNil)
+	c.Assert(status.IsClean(), Equals, false)
+	c.Assert(status.File("CHANGELOG").Staging, Equals, Untracked)
+}
+
+func (s *WorktreeSuite) TestResetMerge(c *C) {
+	fs := memfs.New()
+	w := &Worktree{
+		r:          s.Repository,
+		Filesystem: fs,
+	}
+
+	commitA := plumbing.NewHash("918c48b83bd081e863dbe1b80f8998f058cd8294")
+	commitB := plumbing.NewHash("35e85108805c84807bc66a02d91535e1e24b38b9")
+
+	err := w.Checkout(&CheckoutOptions{})
+	c.Assert(err, IsNil)
+
+	err = w.Reset(&ResetOptions{Mode: MergeReset, Commit: commitA})
+	c.Assert(err, IsNil)
+
+	branch, err := w.r.Reference(plumbing.Master, false)
+	c.Assert(err, IsNil)
+	c.Assert(branch.Hash(), Equals, commitA)
+
 	f, err := fs.Create(".gitignore")
 	c.Assert(err, IsNil)
 	_, err = f.Write([]byte("foo"))
@@ -729,12 +808,12 @@
 	err = f.Close()
 	c.Assert(err, IsNil)
 
-	err = w.Reset(&ResetOptions{Mode: MergeReset, Commit: commit})
+	err = w.Reset(&ResetOptions{Mode: MergeReset, Commit: commitB})
 	c.Assert(err, Equals, ErrUnstaggedChanges)
 
-	branch, err := w.r.Reference(plumbing.Master, false)
+	branch, err = w.r.Reference(plumbing.Master, false)
 	c.Assert(err, IsNil)
-	c.Assert(branch.Hash(), Not(Equals), commit)
+	c.Assert(branch.Hash(), Equals, commitA)
 }
 
 func (s *WorktreeSuite) TestResetHard(c *C) {