blob: d0573490e76239e2d35ac9a99fc73bb4474488d5 [file] [log] [blame]
package main
import (
"os"
"reflect"
"testing"
"golang.org/x/tools/go/vcs"
)
var (
root = &vcs.RepoRoot{
VCS: vcs.ByCmd("git"),
Repo: "https://github.com/bazeltest/rules_go",
Root: "github.com/bazeltest/rules_go",
}
)
func TestMain(m *testing.M) {
// Replace vcs.RepoRootForImportPath to disable any network calls.
repoRootForImportPath = func(_ string, _ bool) (*vcs.RepoRoot, error) {
return root, nil
}
os.Exit(m.Run())
}
func TestGetRepoRoot(t *testing.T) {
for _, tc := range []struct {
label string
remote string
cmd string
importpath string
r *vcs.RepoRoot
}{
{
label: "all",
remote: "https://github.com/bazeltest/rules_go",
cmd: "git",
importpath: "github.com/bazeltest/rules_go",
r: root,
},
{
label: "different remote",
remote: "https://example.com/rules_go",
cmd: "git",
importpath: "github.com/bazeltest/rules_go",
r: &vcs.RepoRoot{
VCS: vcs.ByCmd("git"),
Repo: "https://example.com/rules_go",
Root: "github.com/bazeltest/rules_go",
},
},
{
label: "only importpath",
importpath: "github.com/bazeltest/rules_go",
r: root,
},
} {
r, err := getRepoRoot(tc.remote, tc.cmd, tc.importpath)
if err != nil {
t.Errorf("[%s] %v", tc.label, err)
}
if !reflect.DeepEqual(r, tc.r) {
t.Errorf("[%s] Expected %+v, got %+v", tc.label, tc.r, r)
}
}
}
func TestGetRepoRoot_error(t *testing.T) {
for _, tc := range []struct {
label string
remote string
cmd string
importpath string
}{
{
label: "importpath as remote",
remote: "github.com/bazeltest/rules_go",
},
{
label: "missing vcs",
remote: "https://github.com/bazeltest/rules_go",
importpath: "github.com/bazeltest/rules_go",
},
{
label: "missing remote",
cmd: "git",
importpath: "github.com/bazeltest/rules_go",
},
} {
r, err := getRepoRoot(tc.remote, tc.cmd, tc.importpath)
if err == nil {
t.Errorf("[%s] expected error. Got %+v", tc.label, r)
}
}
}