[testsharder] Add -max-shard-size flag

We can play with this value to trade off machine utilization against
test execution latency.
In the future we can hopefully get signals from the inputs to this
tool that will let us make more intelligent decisions.

Bug: IN-1433
Change-Id: I1b97c19fae187ed17c86dbadf9a24864a6b6496e
diff --git a/cmd/testsharder/main.go b/cmd/testsharder/main.go
index a80d71c..c715cdb 100644
--- a/cmd/testsharder/main.go
+++ b/cmd/testsharder/main.go
@@ -29,6 +29,9 @@
 
 	// The path to the json manifest file containing the tests to mutiply.
 	multipliersPath string
+
+	// Maximum number of tests per shard; values <= 0 disable the limit.
+	maxShardSize int
 )
 
 func usage() {
@@ -46,6 +49,7 @@
 	flag.Var(&mode, "mode", "mode in which to run the testsharder (e.g., normal or restricted).")
 	flag.Var(&tags, "tag", "environment tags on which to filter; only the tests that match all tags will be sharded")
 	flag.StringVar(&multipliersPath, "multipliers", "", "path to the json manifest containing tests to multiply")
+	flag.IntVar(&maxShardSize, "max-shard-size", 0, "maximum number of tests per shard. If <= 0, will be ignored. Otherwise, tests will be placed into more, smaller shards")
 	flag.Usage = usage
 }
 
@@ -82,6 +86,7 @@
 			log.Fatal(err)
 		}
 	}
+	shards = testsharder.WithMaxSize(shards, maxShardSize)
 	f := os.Stdout
 	if outputFile != "" {
 		var err error
diff --git a/testsharder/shard.go b/testsharder/shard.go
index fb7b159..b5c71b2 100644
--- a/testsharder/shard.go
+++ b/testsharder/shard.go
@@ -24,11 +24,11 @@
 
 // MakeShards is the core algorithm to this tool. It takes a set of test specs and produces
 // a set of shards which may then be converted into Swarming tasks.
+// A single output Shard will contain only tests that have the same Envs.
 //
-// Environments that do not match all specified tags will be ignored.
+// Environments that do not match all tags will be ignored.
 //
-// This is the most naive algorithm at the moment. It just merges all tests together which
-// have the same environment setting into the same shard.
+// In Restricted mode, environments that don't specify a ServiceAccount will be ignored.
 func MakeShards(specs []TestSpec, mode Mode, tags []string) []*Shard {
 	// Collect the order of the shards so our shard ordering is deterministic with
 	// respect to the input.
@@ -98,6 +98,34 @@
 	return shards, nil
 }
 
+func min(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
+}
+
+// WithMaxSize returns a list of shards such that each shard contains at most maxShardSize tests.
+// If maxShardSize <= 0, just returns its input.
+func WithMaxSize(shards []*Shard, maxShardSize int) []*Shard {
+	if maxShardSize <= 0 {
+		return shards
+	}
+	output := make([]*Shard, 0, len(shards))
+	for _, shard := range shards {
+		for i := 0; i*maxShardSize < len(shard.Tests); i++ {
+			sliceStart := i * maxShardSize
+			sliceLimit := min((i+1)*maxShardSize, len(shard.Tests))
+			output = append(output, &Shard{
+				Name:  fmt.Sprintf("%s-(%d)", shard.Name, i),
+				Tests: shard.Tests[sliceStart:sliceLimit],
+				Env:   shard.Env,
+			})
+		}
+	}
+	return output
+}
+
 // Removes leading slashes and replaces all other `/` with `_`. This allows the
 // shard name to appear in filepaths.
 func normalizeTestName(name string) string {
diff --git a/testsharder/shard_test.go b/testsharder/shard_test.go
index 5ff524f..eea167c 100644
--- a/testsharder/shard_test.go
+++ b/testsharder/shard_test.go
@@ -25,6 +25,31 @@
 	}
 }
 
+func spec(id int, envs ...Environment) TestSpec {
+	return TestSpec{
+		Test: Test{
+			Location: fmt.Sprintf("/path/to/test/%d", id),
+		},
+		Envs: envs,
+	}
+}
+
+func shard(env Environment, ids ...int) *Shard {
+	return namedShard(env, env.Name(), ids...)
+}
+
+func namedShard(env Environment, name string, ids ...int) *Shard {
+	var tests []Test
+	for _, id := range ids {
+		tests = append(tests, spec(id, env).Test)
+	}
+	return &Shard{
+		Name:  name,
+		Tests: tests,
+		Env:   env,
+	}
+}
+
 func TestMakeShards(t *testing.T) {
 	env1 := Environment{
 		Dimensions: DimensionSet{DeviceType: "QEMU"},
@@ -47,27 +72,6 @@
 		}
 	})
 
-	spec := func(id int, envs ...Environment) TestSpec {
-		return TestSpec{
-			Test: Test{
-				Location: fmt.Sprintf("/path/to/test/%d", id),
-			},
-			Envs: envs,
-		}
-	}
-
-	shard := func(env Environment, ids ...int) *Shard {
-		var tests []Test
-		for _, id := range ids {
-			tests = append(tests, spec(id, env).Test)
-		}
-		return &Shard{
-			Name:  env.Name(),
-			Tests: tests,
-			Env:   env,
-		}
-	}
-
 	t.Run("tests of same environment are grouped", func(t *testing.T) {
 		actual := MakeShards(
 			[]TestSpec{spec(1, env1, env2), spec(2, env1, env3), spec(3, env3)},
@@ -274,3 +278,46 @@
 		}
 	})
 }
+
+func TestWithMaxSize(t *testing.T) {
+	env1 := Environment{
+		Tags: []string{"env1"},
+	}
+	env2 := Environment{
+		Dimensions: DimensionSet{DeviceType: "env2"},
+		Tags:       []string{"env2"},
+	}
+	input := []*Shard{namedShard(env1, "env1", 1, 2, 3, 4, 5), namedShard(env2, "env2", 6, 7, 8)}
+	t.Run("does nothing if max is 0", func(t *testing.T) {
+		assertEqual(t, input, WithMaxSize(input, 0))
+	})
+	t.Run("does nothing if max is < 0", func(t *testing.T) {
+		assertEqual(t, input, WithMaxSize(input, -7))
+	})
+	assertShardsLessThanSize := func(t *testing.T, actual []*Shard, maxSize int) {
+		for _, s := range actual {
+			if len(s.Tests) > maxSize {
+				t.Errorf("Shard %s has %d tests, expected at most %d", s.Name, len(s.Tests), maxSize)
+			}
+		}
+	}
+	t.Run("max is larger than all shards", func(t *testing.T) {
+		maxSize := len(input[0].Tests)+len(input[1].Tests)
+		actual := WithMaxSize(input, maxSize)
+		assertEqual(t, []*Shard{
+			// Returns equivalent shards, but renamed.
+			namedShard(env1, "env1-(0)", 1, 2, 3, 4, 5), namedShard(env2, "env2-(0)", 6, 7, 8)},
+			actual)
+			assertShardsLessThanSize(t, actual, maxSize)
+	})
+	t.Run("applies max", func(t *testing.T) {
+		maxSize := 2
+		actual := WithMaxSize(input, maxSize)
+		assertEqual(t, []*Shard{
+			namedShard(env1, "env1-(0)", 1, 2), namedShard(env1, "env1-(1)", 3, 4),
+			namedShard(env1, "env1-(2)", 5),
+			namedShard(env2, "env2-(0)", 6, 7), namedShard(env2, "env2-(1)", 8)},
+			actual)
+		assertShardsLessThanSize(t, actual, maxSize)
+	})
+}