errgroup: use WithCancelCause to cancel context

Fixes golang/go#59355

Change-Id: Ib6a88e7e5fefe7b0d5672035af16d109aabcbf1e
Reviewed-on: https://go-review.googlesource.com/c/sync/+/481255
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Bryan Mills <bcmills@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
diff --git a/errgroup/errgroup.go b/errgroup/errgroup.go
index cbee7a4..b18efb7 100644
--- a/errgroup/errgroup.go
+++ b/errgroup/errgroup.go
@@ -20,7 +20,7 @@
 // A zero Group is valid, has no limit on the number of active goroutines,
 // and does not cancel on error.
 type Group struct {
-	cancel func()
+	cancel func(error)
 
 	wg sync.WaitGroup
 
@@ -43,7 +43,7 @@
 // returns a non-nil error or the first time Wait returns, whichever occurs
 // first.
 func WithContext(ctx context.Context) (*Group, context.Context) {
-	ctx, cancel := context.WithCancel(ctx)
+	ctx, cancel := withCancelCause(ctx)
 	return &Group{cancel: cancel}, ctx
 }
 
@@ -52,7 +52,7 @@
 func (g *Group) Wait() error {
 	g.wg.Wait()
 	if g.cancel != nil {
-		g.cancel()
+		g.cancel(g.err)
 	}
 	return g.err
 }
@@ -76,7 +76,7 @@
 			g.errOnce.Do(func() {
 				g.err = err
 				if g.cancel != nil {
-					g.cancel()
+					g.cancel(g.err)
 				}
 			})
 		}
@@ -105,7 +105,7 @@
 			g.errOnce.Do(func() {
 				g.err = err
 				if g.cancel != nil {
-					g.cancel()
+					g.cancel(g.err)
 				}
 			})
 		}
diff --git a/errgroup/go120.go b/errgroup/go120.go
new file mode 100644
index 0000000..7d419d3
--- /dev/null
+++ b/errgroup/go120.go
@@ -0,0 +1,14 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.20
+// +build go1.20
+
+package errgroup
+
+import "context"
+
+func withCancelCause(parent context.Context) (context.Context, func(error)) { // Go 1.20+ variant (see the go1.20 build constraint on this file).
+	return context.WithCancelCause(parent) // The returned cancel func records its argument as the cause reported by context.Cause.
+}
diff --git a/errgroup/go120_test.go b/errgroup/go120_test.go
new file mode 100644
index 0000000..0c354a1
--- /dev/null
+++ b/errgroup/go120_test.go
@@ -0,0 +1,55 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.20
+// +build go1.20
+
+package errgroup_test
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"golang.org/x/sync/errgroup"
+)
+
+func TestCancelCause(t *testing.T) {
+	errDoom := errors.New("group_test: doomed")
+	// Each case lists the errors returned by the group's functions and the error Wait should report.
+	cases := []struct {
+		errs []error
+		want error
+	}{
+		{want: nil},
+		{errs: []error{nil}, want: nil},
+		{errs: []error{errDoom}, want: errDoom},
+		{errs: []error{errDoom, nil}, want: errDoom},
+	}
+
+	for _, tc := range cases {
+		g, ctx := errgroup.WithContext(context.Background())
+
+		for _, err := range tc.errs {
+			err := err
+			g.TryGo(func() error { return err })
+		}
+
+		if err := g.Wait(); err != tc.want {
+			t.Errorf("after %T.TryGo(func() error { return err }) for err in %v\n"+
+				"g.Wait() = %v; want %v",
+				g, tc.errs, err, tc.want)
+		}
+		// With no error, Wait cancels with a nil cause, which context.Cause reports as context.Canceled.
+		if tc.want == nil {
+			tc.want = context.Canceled
+		}
+
+		if err := context.Cause(ctx); err != tc.want {
+			t.Errorf("after %T.TryGo(func() error { return err }) for err in %v\n"+
+				"context.Cause(ctx) = %v; want %v",
+				g, tc.errs, err, tc.want)
+		}
+	}
+}
diff --git a/errgroup/pre_go120.go b/errgroup/pre_go120.go
new file mode 100644
index 0000000..1795c18
--- /dev/null
+++ b/errgroup/pre_go120.go
@@ -0,0 +1,15 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !go1.20
+// +build !go1.20
+
+package errgroup
+
+import "context"
+
+func withCancelCause(parent context.Context) (context.Context, func(error)) { // Pre-Go 1.20 fallback with the same signature as the go1.20 variant.
+	ctx, cancel := context.WithCancel(parent)
+	return ctx, func(error) { cancel() } // The error argument is discarded; this performs a plain cancel with no recorded cause.
+}