Merge pull request #485 from mcuadros/fetch-tags

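remote: fetch, correct tag-fetching behavior

Add FetchOptions.Tags (TagFollowing by default, or AllTags) to control which
tags are fetched, and report NoErrAlreadyUpToDate only when no local reference
was updated.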
diff --git a/options.go b/options.go
index 977e462..bbfe244 100644
--- a/options.go
+++ b/options.go
@@ -103,6 +103,18 @@
 	return nil
 }
 
+// TagFetchMode defines how tags are fetched; the zero value is TagFollowing.
+type TagFetchMode int
+
+const (
+	// TagFollowing fetches any tag that points into the histories being
+	// fetched. It requires a server with the `include-tag` capability in
+	// order to fetch the annotated tag objects.
+	TagFollowing TagFetchMode = iota
+	// AllTags fetches all remote tags (refs/tags/*) into local tags with the same name.
+	AllTags
+)
+
 // FetchOptions describes how a fetch should be performed
 type FetchOptions struct {
 	// Name of the remote to fetch from. Defaults to origin.
@@ -117,6 +129,9 @@
 	// stored, if nil nothing is stored and the capability (if supported)
 	// no-progress, is sent to the server to avoid send this information.
 	Progress sideband.Progress
+	// Tags describe how the tags will be fetched from the remote repository,
+	// by default is TagFollowing.
+	Tags TagFetchMode
 }
 
 // Validate validates the fields and sets the default values.
diff --git a/remote.go b/remote.go
index 4c2643b..8f1da2f 100644
--- a/remote.go
+++ b/remote.go
@@ -134,7 +134,7 @@
 	return rs.Error()
 }
 
-func (r *Remote) fetch(o *FetchOptions) (refs storer.ReferenceStorer, err error) {
+func (r *Remote) fetch(o *FetchOptions) (storer.ReferenceStorer, error) {
 	if o.RemoteName == "" {
 		o.RemoteName = r.c.Name
 	}
@@ -169,7 +169,12 @@
 		return nil, err
 	}
 
-	req.Wants, err = getWants(o.RefSpecs, r.s, remoteRefs)
+	refs, err := calculateRefs(o.RefSpecs, remoteRefs, o.Tags)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Wants, err = getWants(r.s, refs)
 	if len(req.Wants) > 0 {
 		req.Haves, err = getHaves(r.s)
 		if err != nil {
@@ -181,14 +186,15 @@
 		}
 	}
 
-	err = r.updateLocalReferenceStorage(o.RefSpecs, remoteRefs)
-	if err != nil && err != NoErrAlreadyUpToDate {
+	updated, err := r.updateLocalReferenceStorage(o.RefSpecs, refs, remoteRefs)
+	if err != nil {
 		return nil, err
 	}
 
-	if len(req.Wants) == 0 {
-		return remoteRefs, err
+	if !updated {
+		return remoteRefs, NoErrAlreadyUpToDate
 	}
+
 	return remoteRefs, nil
 }
 
@@ -382,56 +388,52 @@
 	return result, nil
 }
 
-func getWants(
-	spec []config.RefSpec, localStorer storage.Storer, remoteRefs storer.ReferenceStorer,
-) ([]plumbing.Hash, error) {
-	wantTags := true
-	for _, s := range spec {
-		if !s.IsWildcard() {
-			wantTags = false
-			break
-		}
-	}
-
+func calculateRefs(spec []config.RefSpec,
+	remoteRefs storer.ReferenceStorer,
+	tags TagFetchMode,
+) (memory.ReferenceStorage, error) {
 	iter, err := remoteRefs.IterReferences()
 	if err != nil {
 		return nil, err
 	}
 
-	wants := map[plumbing.Hash]bool{}
-	err = iter.ForEach(func(ref *plumbing.Reference) error {
+	refs := make(memory.ReferenceStorage, 0)
+	return refs, iter.ForEach(func(ref *plumbing.Reference) error {
 		if !config.MatchAny(spec, ref.Name()) {
-			if !ref.IsTag() || !wantTags {
+			if !ref.IsTag() || tags != AllTags {
 				return nil
 			}
 		}
 
 		if ref.Type() == plumbing.SymbolicReference {
-			ref, err = storer.ResolveReference(remoteRefs, ref.Name())
+			target, err := storer.ResolveReference(remoteRefs, ref.Name())
 			if err != nil {
 				return err
 			}
+
+			ref = plumbing.NewHashReference(ref.Name(), target.Hash())
 		}
 
 		if ref.Type() != plumbing.HashReference {
 			return nil
 		}
 
-		hash := ref.Hash()
+		return refs.SetReference(ref)
+	})
+}
 
-		exists, err := objectExists(localStorer, hash)
+func getWants(localStorer storage.Storer, refs memory.ReferenceStorage) ([]plumbing.Hash, error) {
+	wants := map[plumbing.Hash]bool{}
+	for _, ref := range refs {
+		hash := ref.Hash()
+		exists, err := objectExists(localStorer, ref.Hash())
 		if err != nil {
-			return err
+			return nil, err
 		}
 
 		if !exists {
 			wants[hash] = true
 		}
-
-		return nil
-	})
-	if err != nil {
-		return nil, err
 	}
 
 	var result []plumbing.Hash
@@ -513,6 +515,19 @@
 		}
 	}
 
+	isWildcard := true
+	for _, s := range o.RefSpecs {
+		if !s.IsWildcard() {
+			isWildcard = false
+		}
+	}
+
+	if isWildcard && o.Tags == TagFollowing && ar.Capabilities.Supports(capability.IncludeTag) {
+		if err := req.Capabilities.Set(capability.IncludeTag); err != nil {
+			return nil, err
+		}
+	}
+
 	return req, nil
 }
 
@@ -534,10 +549,17 @@
 	return d
 }
 
-func (r *Remote) updateLocalReferenceStorage(specs []config.RefSpec, refs memory.ReferenceStorage) error {
-	updated := false
+func (r *Remote) updateLocalReferenceStorage(
+	specs []config.RefSpec,
+	fetchedRefs, remoteRefs memory.ReferenceStorage,
+) (updated bool, err error) {
+	isWildcard := true
 	for _, spec := range specs {
-		for _, ref := range refs {
+		if !spec.IsWildcard() {
+			isWildcard = false
+		}
+
+		for _, ref := range fetchedRefs {
 			if !spec.Match(ref.Name()) {
 				continue
 			}
@@ -546,33 +568,36 @@
 				continue
 			}
 
-			name := spec.Dst(ref.Name())
-			sref, err := r.s.Reference(name)
-			if err != nil && err != plumbing.ErrReferenceNotFound {
-				return err
+			new := plumbing.NewHashReference(spec.Dst(ref.Name()), ref.Hash())
+
+			refUpdated, err := updateReferenceStorerIfNeeded(r.s, new)
+			if err != nil {
+				return updated, err
 			}
-			if err == plumbing.ErrReferenceNotFound || sref.Hash() != ref.Hash() {
-				n := plumbing.NewHashReference(name, ref.Hash())
-				if err := r.s.SetReference(n); err != nil {
-					return err
-				}
+
+			if refUpdated {
 				updated = true
 			}
 		}
 	}
 
-	if err := r.buildFetchedTags(refs); err != nil {
-		return err
+	tags := fetchedRefs
+	if isWildcard {
+		tags = remoteRefs
+	}
+	tagUpdated, err := r.buildFetchedTags(tags)
+	if err != nil {
+		return updated, err
 	}
 
-	if !updated {
-		return NoErrAlreadyUpToDate
+	if tagUpdated {
+		updated = true
 	}
-	return nil
+
+	return
 }
 
-func (r *Remote) buildFetchedTags(refs memory.ReferenceStorage) error {
-	updated := false
+func (r *Remote) buildFetchedTags(refs memory.ReferenceStorage) (updated bool, err error) {
 	for _, ref := range refs {
 		if !ref.IsTag() {
 			continue
@@ -584,18 +609,20 @@
 		}
 
 		if err != nil {
-			return err
+			return false, err
 		}
 
-		if err = r.s.SetReference(ref); err != nil {
-			return err
+		refUpdated, err := updateReferenceStorerIfNeeded(r.s, ref)
+		if err != nil {
+			return updated, err
 		}
-		updated = true
+
+		if refUpdated {
+			updated = true
+		}
 	}
-	if !updated {
-		return NoErrAlreadyUpToDate
-	}
-	return nil
+
+	return
 }
 
 func objectsToPush(commands []*packp.Command) ([]plumbing.Hash, error) {
diff --git a/remote_test.go b/remote_test.go
index 501d06e..7ffe040 100644
--- a/remote_test.go
+++ b/remote_test.go
@@ -51,61 +51,109 @@
 	c.Assert(err, Equals, config.ErrRefSpecMalformedSeparator)
 }
 
-func (s *RemoteSuite) TestFetch(c *C) {
-	url := s.GetBasicLocalRepositoryURL()
-	sto := memory.NewStorage()
-	r := newRemote(sto, &config.RemoteConfig{Name: "foo", URL: url})
-
-	refspec := config.RefSpec("+refs/heads/*:refs/remotes/origin/*")
-	err := r.Fetch(&FetchOptions{
-		RefSpecs: []config.RefSpec{refspec},
+func (s *RemoteSuite) TestFetchWildcard(c *C) {
+	r := newRemote(memory.NewStorage(), &config.RemoteConfig{
+		URL: s.GetBasicLocalRepositoryURL(),
 	})
 
-	c.Assert(err, IsNil)
-	c.Assert(sto.Objects, HasLen, 31)
-
-	expectedRefs := []*plumbing.Reference{
+	s.testFetch(c, r, &FetchOptions{
+		RefSpecs: []config.RefSpec{
+			config.RefSpec("+refs/heads/*:refs/remotes/origin/*"),
+		},
+	}, []*plumbing.Reference{
 		plumbing.NewReferenceFromStrings("refs/remotes/origin/master", "6ecf0ef2c2dffb796033e5a02219af86ec6584e5"),
 		plumbing.NewReferenceFromStrings("refs/remotes/origin/branch", "e8d3ffab552895c19b9fcf7aa264d277cde33881"),
-	}
-
-	for _, exp := range expectedRefs {
-		r, _ := sto.Reference(exp.Name())
-		c.Assert(exp.String(), Equals, r.String())
-	}
+		plumbing.NewReferenceFromStrings("refs/tags/v1.0.0", "6ecf0ef2c2dffb796033e5a02219af86ec6584e5"),
+	})
 }
 
-func (s *RemoteSuite) TestFetchDepth(c *C) {
-	url := s.GetBasicLocalRepositoryURL()
-	sto := memory.NewStorage()
-	r := newRemote(sto, &config.RemoteConfig{Name: "foo", URL: url})
-
-	refspec := config.RefSpec("+refs/heads/*:refs/remotes/origin/*")
-	err := r.Fetch(&FetchOptions{
-		RefSpecs: []config.RefSpec{refspec},
-		Depth:    1,
+func (s *RemoteSuite) TestFetchWildcardTags(c *C) {
+	r := newRemote(memory.NewStorage(), &config.RemoteConfig{
+		URL: s.GetLocalRepositoryURL(fixtures.ByTag("tags").One()),
 	})
 
-	c.Assert(err, IsNil)
-	c.Assert(sto.Objects, HasLen, 18)
+	s.testFetch(c, r, &FetchOptions{
+		RefSpecs: []config.RefSpec{
+			config.RefSpec("+refs/heads/*:refs/remotes/origin/*"),
+		},
+	}, []*plumbing.Reference{
+		plumbing.NewReferenceFromStrings("refs/remotes/origin/master", "f7b877701fbf855b44c0a9e86f3fdce2c298b07f"),
+		plumbing.NewReferenceFromStrings("refs/tags/annotated-tag", "b742a2a9fa0afcfa9a6fad080980fbc26b007c69"),
+		plumbing.NewReferenceFromStrings("refs/tags/tree-tag", "152175bf7e5580299fa1f0ba41ef6474cc043b70"),
+		plumbing.NewReferenceFromStrings("refs/tags/commit-tag", "ad7897c0fb8e7d9a9ba41fa66072cf06095a6cfc"),
+		plumbing.NewReferenceFromStrings("refs/tags/blob-tag", "fe6cb94756faa81e5ed9240f9191b833db5f40ae"),
+		plumbing.NewReferenceFromStrings("refs/tags/lightweight-tag", "f7b877701fbf855b44c0a9e86f3fdce2c298b07f"),
+	})
+}
 
-	expectedRefs := []*plumbing.Reference{
+func (s *RemoteSuite) TestFetch(c *C) {
+	url := s.GetLocalRepositoryURL(fixtures.ByTag("tags").One())
+	r := newRemote(memory.NewStorage(), &config.RemoteConfig{URL: url})
+
+	// A single-branch refspec with the default tag mode must fetch no tags.
+	spec := config.RefSpec("+refs/heads/master:refs/remotes/origin/master")
+	o := &FetchOptions{RefSpecs: []config.RefSpec{spec}}
+
+	expected := []*plumbing.Reference{
+		plumbing.NewReferenceFromStrings("refs/remotes/origin/master", "f7b877701fbf855b44c0a9e86f3fdce2c298b07f"),
+	}
+	s.testFetch(c, r, o, expected)
+}
+
+func (s *RemoteSuite) TestFetchWithAllTags(c *C) {
+	url := s.GetLocalRepositoryURL(fixtures.ByTag("tags").One())
+	r := newRemote(memory.NewStorage(), &config.RemoteConfig{URL: url})
+
+	// With AllTags, tags are fetched even though the refspec names one branch.
+	spec := config.RefSpec("+refs/heads/master:refs/remotes/origin/master")
+	o := &FetchOptions{Tags: AllTags, RefSpecs: []config.RefSpec{spec}}
+
+	expected := []*plumbing.Reference{
+		plumbing.NewReferenceFromStrings("refs/remotes/origin/master", "f7b877701fbf855b44c0a9e86f3fdce2c298b07f"),
+		plumbing.NewReferenceFromStrings("refs/tags/annotated-tag", "b742a2a9fa0afcfa9a6fad080980fbc26b007c69"),
+		plumbing.NewReferenceFromStrings("refs/tags/tree-tag", "152175bf7e5580299fa1f0ba41ef6474cc043b70"),
+		plumbing.NewReferenceFromStrings("refs/tags/commit-tag", "ad7897c0fb8e7d9a9ba41fa66072cf06095a6cfc"),
+		plumbing.NewReferenceFromStrings("refs/tags/blob-tag", "fe6cb94756faa81e5ed9240f9191b833db5f40ae"),
+		plumbing.NewReferenceFromStrings("refs/tags/lightweight-tag", "f7b877701fbf855b44c0a9e86f3fdce2c298b07f"),
+	}
+
+	s.testFetch(c, r, o, expected)
+}
+
+func (s *RemoteSuite) TestFetchWithDepth(c *C) {
+	r := newRemote(memory.NewStorage(), &config.RemoteConfig{
+		URL: s.GetBasicLocalRepositoryURL(),
+	})
+
+	s.testFetch(c, r, &FetchOptions{
+		Depth: 1,
+		RefSpecs: []config.RefSpec{
+			config.RefSpec("+refs/heads/*:refs/remotes/origin/*"),
+		},
+	}, []*plumbing.Reference{
 		plumbing.NewReferenceFromStrings("refs/remotes/origin/master", "6ecf0ef2c2dffb796033e5a02219af86ec6584e5"),
 		plumbing.NewReferenceFromStrings("refs/remotes/origin/branch", "e8d3ffab552895c19b9fcf7aa264d277cde33881"),
-	}
+		plumbing.NewReferenceFromStrings("refs/tags/v1.0.0", "6ecf0ef2c2dffb796033e5a02219af86ec6584e5"),
+	})
 
-	for _, exp := range expectedRefs {
-		r, _ := sto.Reference(exp.Name())
+	c.Assert(r.s.(*memory.Storage).Objects, HasLen, 18)
+}
+
+// testFetch fetches with o and asserts that the remote's storage holds exactly the expected refs.
+func (s *RemoteSuite) testFetch(c *C, r *Remote, o *FetchOptions, expected []*plumbing.Reference) {
+	c.Assert(r.Fetch(o), IsNil)
+
+	var refs int
+	l, err := r.s.IterReferences()
+	c.Assert(err, IsNil)
+	c.Assert(l.ForEach(func(*plumbing.Reference) error { refs++; return nil }), IsNil)
+	c.Assert(refs, Equals, len(expected))
+
+	for _, exp := range expected {
+		r, err := r.s.Reference(exp.Name())
+		c.Assert(err, IsNil)
 		c.Assert(exp.String(), Equals, r.String())
 	}
-
-	h, err := sto.Shallow()
-	c.Assert(err, IsNil)
-	c.Assert(h, HasLen, 2)
-	c.Assert(h, DeepEquals, []plumbing.Hash{
-		plumbing.NewHash("e8d3ffab552895c19b9fcf7aa264d277cde33881"),
-		plumbing.NewHash("6ecf0ef2c2dffb796033e5a02219af86ec6584e5"),
-	})
 }
 
 func (s *RemoteSuite) TestFetchWithProgress(c *C) {
@@ -177,26 +225,33 @@
 }
 
 func (s *RemoteSuite) TestFetchNoErrAlreadyUpToDateButStillUpdateLocalRemoteRefs(c *C) {
-	url := s.GetBasicLocalRepositoryURL()
+	r := newRemote(memory.NewStorage(), &config.RemoteConfig{
+		URL: s.GetBasicLocalRepositoryURL(),
+	})
 
-	sto := memory.NewStorage()
-	r := newRemote(sto, &config.RemoteConfig{Name: "foo", URL: url})
-
-	refspec := config.RefSpec("+refs/heads/*:refs/remotes/origin/*")
 	o := &FetchOptions{
-		RefSpecs: []config.RefSpec{refspec},
+		RefSpecs: []config.RefSpec{
+			config.RefSpec("+refs/heads/*:refs/remotes/origin/*"),
+		},
 	}
 
 	err := r.Fetch(o)
 	c.Assert(err, IsNil)
 
 	// Simulate an out of date remote ref even though we have the new commit locally
-	sto.SetReference(plumbing.NewReferenceFromStrings("refs/remotes/origin/master", "918c48b83bd081e863dbe1b80f8998f058cd8294"))
+	r.s.SetReference(plumbing.NewReferenceFromStrings(
+		"refs/remotes/origin/master", "918c48b83bd081e863dbe1b80f8998f058cd8294",
+	))
 
 	err = r.Fetch(o)
 	c.Assert(err, IsNil)
-	exp := plumbing.NewReferenceFromStrings("refs/remotes/origin/master", "6ecf0ef2c2dffb796033e5a02219af86ec6584e5")
-	ref, _ := sto.Reference("refs/remotes/origin/master")
+
+	exp := plumbing.NewReferenceFromStrings(
+		"refs/remotes/origin/master", "6ecf0ef2c2dffb796033e5a02219af86ec6584e5",
+	)
+
+	ref, err := r.s.Reference("refs/remotes/origin/master")
+	c.Assert(err, IsNil)
 	c.Assert(exp.String(), Equals, ref.String())
 }
 
@@ -207,13 +262,12 @@
 }
 
 func (s *RemoteSuite) doTestFetchNoErrAlreadyUpToDate(c *C, url string) {
+	r := newRemote(memory.NewStorage(), &config.RemoteConfig{URL: url})
 
-	sto := memory.NewStorage()
-	r := newRemote(sto, &config.RemoteConfig{Name: "foo", URL: url})
-
-	refspec := config.RefSpec("+refs/heads/*:refs/remotes/origin/*")
 	o := &FetchOptions{
-		RefSpecs: []config.RefSpec{refspec},
+		RefSpecs: []config.RefSpec{
+			config.RefSpec("+refs/heads/*:refs/remotes/origin/*"),
+		},
 	}
 
 	err := r.Fetch(o)