Exemplar: Record with sampled SpanContext in gRPC plugin. (#1127)

diff --git a/plugin/ocgrpc/client_stats_handler_test.go b/plugin/ocgrpc/client_stats_handler_test.go
index 53f9248..e9197fc 100644
--- a/plugin/ocgrpc/client_stats_handler_test.go
+++ b/plugin/ocgrpc/client_stats_handler_test.go
@@ -16,14 +16,19 @@
 package ocgrpc
 
 import (
+	"reflect"
 	"testing"
 
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
+
 	"go.opencensus.io/trace"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 
 	"golang.org/x/net/context"
 
+	"go.opencensus.io/metric/metricdata"
 	"go.opencensus.io/stats/view"
 	"go.opencensus.io/tag"
 
@@ -334,6 +339,72 @@
 	}
 }
 
+func TestClientRecordExemplar(t *testing.T) {
+	key, _ := tag.NewKey("test_key")
+	tagInfo := &stats.RPCTagInfo{FullMethodName: "/package.service/method"}
+	out := &stats.OutPayload{Length: 2000}
+	end := &stats.End{Error: nil}
+
+	if err := view.Register(ClientSentBytesPerRPCView); err != nil {
+		t.Error(err)
+	}
+	h := &ClientHandler{}
+	h.StartOptions.Sampler = trace.AlwaysSample()
+	ctx, err := tag.New(context.Background(), tag.Upsert(key, "test_val"))
+	if err != nil {
+		t.Error(err)
+	}
+	encoded := tag.Encode(tag.FromContext(ctx))
+	ctx = stats.SetTags(context.Background(), encoded)
+	ctx = h.TagRPC(ctx, tagInfo)
+
+	out.Client = true
+	h.HandleRPC(ctx, out)
+	end.Client = true
+	h.HandleRPC(ctx, end)
+
+	span := trace.FromContext(ctx)
+	if span == nil {
+		t.Fatal("expected non-nil span, got nil")
+	}
+	if !span.IsRecordingEvents() {
+		t.Errorf("span should be sampled")
+	}
+	attachments := map[string]interface{}{metricdata.AttachmentKeySpanContext: span.SpanContext()}
+	wantExemplar := &metricdata.Exemplar{Value: 2000, Attachments: attachments}
+
+	rows, err := view.RetrieveData(ClientSentBytesPerRPCView.Name)
+	if err != nil {
+		t.Fatal("Error RetrieveData ", err)
+	}
+	if len(rows) == 0 {
+		t.Fatal("No data was recorded.")
+	}
+	data := rows[0].Data
+	dis, ok := data.(*view.DistributionData)
+	if !ok {
+		t.Fatal("want DistributionData, got ", data)
+	}
+	// Only recorded value is 2000, which falls into the second bucket (1024, 2048].
+	wantBuckets := []int64{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+	if !reflect.DeepEqual(dis.CountPerBucket, wantBuckets) {
+		t.Errorf("want buckets %v, got %v", wantBuckets, dis.CountPerBucket)
+	}
+	for i, e := range dis.ExemplarsPerBucket {
+		// Only the second bucket should have an exemplar.
+		if i == 1 {
+			if diff := cmpExemplar(e, wantExemplar); diff != "" {
+				t.Fatalf("Unexpected Exemplar -got +want: %s", diff)
+			}
+		} else if e != nil {
+			t.Errorf("want nil exemplar, got %v", e)
+		}
+	}
+
+	// Unregister views to cleanup.
+	view.Unregister(ClientSentBytesPerRPCView)
+}
+
 // containsRow returns true if rows contain r.
 func containsRow(rows []*view.Row, r *view.Row) bool {
 	for _, x := range rows {
@@ -343,3 +414,8 @@
 	}
 	return false
 }
+
+// Compare exemplars while ignoring exemplar timestamp, since timestamp is non-deterministic.
+func cmpExemplar(got, want *metricdata.Exemplar) string {
+	return cmp.Diff(got, want, cmpopts.IgnoreFields(metricdata.Exemplar{}, "Timestamp"), cmpopts.IgnoreUnexported(metricdata.Exemplar{}))
+}
diff --git a/plugin/ocgrpc/server_stats_handler_test.go b/plugin/ocgrpc/server_stats_handler_test.go
index cab232a..921155e 100644
--- a/plugin/ocgrpc/server_stats_handler_test.go
+++ b/plugin/ocgrpc/server_stats_handler_test.go
@@ -16,11 +16,13 @@
 package ocgrpc
 
 import (
+	"reflect"
 	"testing"
 
 	"go.opencensus.io/trace"
 	"golang.org/x/net/context"
 
+	"go.opencensus.io/metric/metricdata"
 	"go.opencensus.io/stats/view"
 	"go.opencensus.io/tag"
 
@@ -334,3 +336,69 @@
 		CountPerBucket:  countPerBucket,
 	}
 }
+
+func TestServerRecordExemplar(t *testing.T) {
+	key, _ := tag.NewKey("test_key")
+	tagInfo := &stats.RPCTagInfo{FullMethodName: "/package.service/method"}
+	out := &stats.OutPayload{Length: 2000}
+	end := &stats.End{Error: nil}
+
+	if err := view.Register(ServerSentBytesPerRPCView); err != nil {
+		t.Error(err)
+	}
+	h := &ServerHandler{}
+	h.StartOptions.Sampler = trace.AlwaysSample()
+	ctx, err := tag.New(context.Background(), tag.Upsert(key, "test_val"))
+	if err != nil {
+		t.Error(err)
+	}
+	encoded := tag.Encode(tag.FromContext(ctx))
+	ctx = stats.SetTags(context.Background(), encoded)
+	ctx = h.TagRPC(ctx, tagInfo)
+
+	out.Client = false
+	h.HandleRPC(ctx, out)
+	end.Client = false
+	h.HandleRPC(ctx, end)
+
+	span := trace.FromContext(ctx)
+	if span == nil {
+		t.Fatal("expected non-nil span, got nil")
+	}
+	if !span.IsRecordingEvents() {
+		t.Errorf("span should be sampled")
+	}
+	attachments := map[string]interface{}{metricdata.AttachmentKeySpanContext: span.SpanContext()}
+	wantExemplar := &metricdata.Exemplar{Value: 2000, Attachments: attachments}
+
+	rows, err := view.RetrieveData(ServerSentBytesPerRPCView.Name)
+	if err != nil {
+		t.Fatal("Error RetrieveData ", err)
+	}
+	if len(rows) == 0 {
+		t.Fatal("No data was recorded.")
+	}
+	data := rows[0].Data
+	dis, ok := data.(*view.DistributionData)
+	if !ok {
+		t.Fatal("want DistributionData, got ", data)
+	}
+	// Only recorded value is 2000, which falls into the second bucket (1024, 2048].
+	wantBuckets := []int64{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+	if !reflect.DeepEqual(dis.CountPerBucket, wantBuckets) {
+		t.Errorf("want buckets %v, got %v", wantBuckets, dis.CountPerBucket)
+	}
+	for i, e := range dis.ExemplarsPerBucket {
+		// Only the second bucket should have an exemplar.
+		if i == 1 {
+			if diff := cmpExemplar(e, wantExemplar); diff != "" {
+				t.Fatalf("Unexpected Exemplar -got +want: %s", diff)
+			}
+		} else if e != nil {
+			t.Errorf("want nil exemplar, got %v", e)
+		}
+	}
+
+	// Unregister views to cleanup.
+	view.Unregister(ServerSentBytesPerRPCView)
+}
diff --git a/plugin/ocgrpc/stats_common.go b/plugin/ocgrpc/stats_common.go
index e9991fe..0ae5691 100644
--- a/plugin/ocgrpc/stats_common.go
+++ b/plugin/ocgrpc/stats_common.go
@@ -22,9 +22,11 @@
 	"sync/atomic"
 	"time"
 
+	"go.opencensus.io/metric/metricdata"
 	ocstats "go.opencensus.io/stats"
 	"go.opencensus.io/stats/view"
 	"go.opencensus.io/tag"
+	"go.opencensus.io/trace"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/grpclog"
 	"google.golang.org/grpc/stats"
@@ -141,27 +143,31 @@
 	}
 
 	latencyMillis := float64(elapsedTime) / float64(time.Millisecond)
+	attachments := getSpanCtxAttachment(ctx)
 	if s.Client {
-		ocstats.RecordWithTags(ctx,
-			[]tag.Mutator{
+		ocstats.RecordWithOptions(ctx,
+			ocstats.WithTags(
 				tag.Upsert(KeyClientMethod, methodName(d.method)),
-				tag.Upsert(KeyClientStatus, st),
-			},
-			ClientSentBytesPerRPC.M(atomic.LoadInt64(&d.sentBytes)),
-			ClientSentMessagesPerRPC.M(atomic.LoadInt64(&d.sentCount)),
-			ClientReceivedMessagesPerRPC.M(atomic.LoadInt64(&d.recvCount)),
-			ClientReceivedBytesPerRPC.M(atomic.LoadInt64(&d.recvBytes)),
-			ClientRoundtripLatency.M(latencyMillis))
+				tag.Upsert(KeyClientStatus, st)),
+			ocstats.WithAttachments(attachments),
+			ocstats.WithMeasurements(
+				ClientSentBytesPerRPC.M(atomic.LoadInt64(&d.sentBytes)),
+				ClientSentMessagesPerRPC.M(atomic.LoadInt64(&d.sentCount)),
+				ClientReceivedMessagesPerRPC.M(atomic.LoadInt64(&d.recvCount)),
+				ClientReceivedBytesPerRPC.M(atomic.LoadInt64(&d.recvBytes)),
+				ClientRoundtripLatency.M(latencyMillis)))
 	} else {
-		ocstats.RecordWithTags(ctx,
-			[]tag.Mutator{
+		ocstats.RecordWithOptions(ctx,
+			ocstats.WithTags(
 				tag.Upsert(KeyServerStatus, st),
-			},
-			ServerSentBytesPerRPC.M(atomic.LoadInt64(&d.sentBytes)),
-			ServerSentMessagesPerRPC.M(atomic.LoadInt64(&d.sentCount)),
-			ServerReceivedMessagesPerRPC.M(atomic.LoadInt64(&d.recvCount)),
-			ServerReceivedBytesPerRPC.M(atomic.LoadInt64(&d.recvBytes)),
-			ServerLatency.M(latencyMillis))
+			),
+			ocstats.WithAttachments(attachments),
+			ocstats.WithMeasurements(
+				ServerSentBytesPerRPC.M(atomic.LoadInt64(&d.sentBytes)),
+				ServerSentMessagesPerRPC.M(atomic.LoadInt64(&d.sentCount)),
+				ServerReceivedMessagesPerRPC.M(atomic.LoadInt64(&d.recvCount)),
+				ServerReceivedBytesPerRPC.M(atomic.LoadInt64(&d.recvBytes)),
+				ServerLatency.M(latencyMillis)))
 	}
 }
 
@@ -206,3 +212,16 @@
 		return "CODE_" + strconv.FormatInt(int64(c), 10)
 	}
 }
+
+func getSpanCtxAttachment(ctx context.Context) metricdata.Attachments {
+	attachments := map[string]interface{}{}
+	span := trace.FromContext(ctx)
+	if span == nil {
+		return attachments
+	}
+	spanCtx := span.SpanContext()
+	if spanCtx.IsSampled() {
+		attachments[metricdata.AttachmentKeySpanContext] = spanCtx
+	}
+	return attachments
+}