logging: add an option to supply the context used for WriteLogEntries

Provide a LoggingOption that lets the user supply the context for the
WriteLogEntries call, instead of always using context.Background.

Fixes #1136.

Change-Id: Ia4043b2bda26908ab456616ad51320bcb73cabc5
Reviewed-on: https://code-review.googlesource.com/c/33930
Reviewed-by: Jean de Klerk <deklerk@google.com>
diff --git a/logging/examples_test.go b/logging/examples_test.go
index b4bb452..d4a2c3a 100644
--- a/logging/examples_test.go
+++ b/logging/examples_test.go
@@ -20,6 +20,7 @@
 	"os"
 
 	"cloud.google.com/go/logging"
+	"go.opencensus.io/trace"
 	"golang.org/x/net/context"
 )
 
@@ -164,3 +165,19 @@
 	fmt.Println(sev)
 	// Output: Alert
 }
+
+// This example shows how to create a Logger that disables OpenCensus tracing of the
+// WriteLogEntries RPC.
+func ExampleContextFunc() {
+	ctx := context.Background()
+	client, err := logging.NewClient(ctx, "my-project")
+	if err != nil {
+		// TODO: Handle error.
+	}
+	lg := client.Logger("logID", logging.ContextFunc(func() (context.Context, func()) {
+		ctx, span := trace.StartSpan(context.Background(), "this span will not be exported",
+			trace.WithSampler(trace.NeverSample()))
+		return ctx, span.End
+	}))
+	_ = lg // TODO: Use lg
+}
diff --git a/logging/logging.go b/logging/logging.go
index b09a0b8..f7a8251 100644
--- a/logging/logging.go
+++ b/logging/logging.go
@@ -234,6 +234,7 @@
 	commonResource *mrpb.MonitoredResource
 	commonLabels   map[string]string
 	writeTimeout   time.Duration
+	ctxFunc        func() (context.Context, func())
 }
 
 // A LoggerOption is a configuration option for a Logger.
@@ -397,6 +398,23 @@
 
 func (b bufferedByteLimit) set(l *Logger) { l.bundler.BufferedByteLimit = int(b) }
 
+// ContextFunc is a function that will be called to obtain a context.Context for the
+// WriteLogEntries RPC executed in the background for calls to Logger.Log. The
+// default is a function that always returns context.Background. The second return
+// value of the function is a function to call after the RPC completes.
+//
+// The function is not used for calls to Logger.LogSync, since the caller can pass
+// in the context directly.
+//
+// This option is EXPERIMENTAL. It may be changed or removed.
+func ContextFunc(f func() (ctx context.Context, afterCall func())) LoggerOption {
+	return contextFunc(f)
+}
+
+type contextFunc func() (ctx context.Context, afterCall func())
+
+func (c contextFunc) set(l *Logger) { l.ctxFunc = c }
+
 // Logger returns a Logger that will write entries with the given log ID, such as
 // "syslog". A log ID must be less than 512 characters long and can only
 // include the following characters: upper and lower case alphanumeric
@@ -411,6 +429,7 @@
 		client:         c,
 		logName:        internal.LogPath(c.parent, logID),
 		commonResource: r,
+		ctxFunc:        func() (context.Context, func()) { return context.Background(), nil },
 	}
 	l.bundler = bundler.NewBundler(&logpb.LogEntry{}, func(entries interface{}) {
 		l.writeLogEntries(entries.([]*logpb.LogEntry))
@@ -759,12 +778,16 @@
 		Labels:   l.commonLabels,
 		Entries:  entries,
 	}
-	ctx, cancel := context.WithTimeout(context.Background(), defaultWriteTimeout)
+	ctx, afterCall := l.ctxFunc()
+	ctx, cancel := context.WithTimeout(ctx, defaultWriteTimeout)
 	defer cancel()
 	_, err := l.client.client.WriteLogEntries(ctx, req)
 	if err != nil {
 		l.client.error(err)
 	}
+	if afterCall != nil {
+		afterCall()
+	}
 }
 
 // StandardLogger returns a *log.Logger for the provided severity.
diff --git a/logging/logging_test.go b/logging/logging_test.go
index c20481e..4f05930 100644
--- a/logging/logging_test.go
+++ b/logging/logging_test.go
@@ -24,6 +24,7 @@
 	"os"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -225,6 +226,23 @@
 	}
 }
 
+func TestContextFunc(t *testing.T) {
+	initLogs(ctx)
+	var contextFuncCalls, cleanupCalls int32 //atomic
+
+	lg := client.Logger(testLogID, logging.ContextFunc(func() (context.Context, func()) {
+		atomic.AddInt32(&contextFuncCalls, 1)
+		return context.Background(), func() { atomic.AddInt32(&cleanupCalls, 1) }
+	}))
+	lg.Log(logging.Entry{Payload: "p"})
+	lg.Flush()
+	got1 := atomic.LoadInt32(&contextFuncCalls)
+	got2 := atomic.LoadInt32(&cleanupCalls)
+	if got1 != 1 || got1 != got2 {
+		t.Errorf("got %d calls to context func, %d calls to cleanup func; want 1, 1", got1, got2)
+	}
+}
+
 // compareEntries compares most fields list of Entries against expected. compareEntries does not compare:
 //   - HTTPRequest
 //   - Operation