| // Copyright 2018, OpenCensus Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package ochttp |
| |
| import ( |
| "bytes" |
| "context" |
| "encoding/hex" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "log" |
| "net" |
| "net/http" |
| "net/http/httptest" |
| "net/url" |
| "reflect" |
| "strings" |
| "testing" |
| "time" |
| |
| "go.opencensus.io/plugin/ochttp/propagation/b3" |
| "go.opencensus.io/plugin/ochttp/propagation/tracecontext" |
| "go.opencensus.io/trace" |
| ) |
| |
| type testExporter struct { |
| spans []*trace.SpanData |
| } |
| |
| func (t *testExporter) ExportSpan(s *trace.SpanData) { |
| t.spans = append(t.spans, s) |
| } |
| |
| type testTransport struct { |
| ch chan *http.Request |
| } |
| |
| func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { |
| t.ch <- req |
| return nil, errors.New("noop") |
| } |
| |
| type testPropagator struct{} |
| |
| func (t testPropagator) SpanContextFromRequest(req *http.Request) (sc trace.SpanContext, ok bool) { |
| header := req.Header.Get("trace") |
| buf, err := hex.DecodeString(header) |
| if err != nil { |
| log.Fatalf("Cannot decode trace header: %q", header) |
| } |
| r := bytes.NewReader(buf) |
| r.Read(sc.TraceID[:]) |
| r.Read(sc.SpanID[:]) |
| opts, err := r.ReadByte() |
| if err != nil { |
| log.Fatalf("Cannot read trace options from trace header: %q", header) |
| } |
| sc.TraceOptions = trace.TraceOptions(opts) |
| return sc, true |
| } |
| |
| func (t testPropagator) SpanContextToRequest(sc trace.SpanContext, req *http.Request) { |
| var buf bytes.Buffer |
| buf.Write(sc.TraceID[:]) |
| buf.Write(sc.SpanID[:]) |
| buf.WriteByte(byte(sc.TraceOptions)) |
| req.Header.Set("trace", hex.EncodeToString(buf.Bytes())) |
| } |
| |
| func TestTransport_RoundTrip_Race(t *testing.T) { |
| // This tests that we don't modify the request in accordance with the |
| // specification for http.RoundTripper. |
| // We attempt to trigger a race by reading the request from a separate |
| // goroutine. If the request is modified by Transport, this should trigger |
| // the race detector. |
| |
| transport := &testTransport{ch: make(chan *http.Request, 1)} |
| rt := &Transport{ |
| Propagation: &testPropagator{}, |
| Base: transport, |
| } |
| req, _ := http.NewRequest("GET", "http://foo.com", nil) |
| go func() { |
| fmt.Println(*req) |
| }() |
| rt.RoundTrip(req) |
| _ = <-transport.ch |
| } |
| |
| func TestTransport_RoundTrip(t *testing.T) { |
| _, parent := trace.StartSpan(context.Background(), "parent") |
| tests := []struct { |
| name string |
| parent *trace.Span |
| }{ |
| { |
| name: "no parent", |
| parent: nil, |
| }, |
| { |
| name: "parent", |
| parent: parent, |
| }, |
| } |
| |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| transport := &testTransport{ch: make(chan *http.Request, 1)} |
| |
| rt := &Transport{ |
| Propagation: &testPropagator{}, |
| Base: transport, |
| } |
| |
| req, _ := http.NewRequest("GET", "http://foo.com", nil) |
| if tt.parent != nil { |
| req = req.WithContext(trace.NewContext(req.Context(), tt.parent)) |
| } |
| rt.RoundTrip(req) |
| |
| req = <-transport.ch |
| span := trace.FromContext(req.Context()) |
| |
| if header := req.Header.Get("trace"); header == "" { |
| t.Fatalf("Trace header = empty; want valid trace header") |
| } |
| if span == nil { |
| t.Fatalf("Got no spans in req context; want one") |
| } |
| if tt.parent != nil { |
| if got, want := span.SpanContext().TraceID, tt.parent.SpanContext().TraceID; got != want { |
| t.Errorf("span.SpanContext().TraceID=%v; want %v", got, want) |
| } |
| } |
| }) |
| } |
| } |
| |
| func TestHandler(t *testing.T) { |
| traceID := [16]byte{16, 84, 69, 170, 120, 67, 188, 139, 242, 6, 177, 32, 0, 16, 0, 0} |
| tests := []struct { |
| header string |
| wantTraceID trace.TraceID |
| wantTraceOptions trace.TraceOptions |
| }{ |
| { |
| header: "105445aa7843bc8bf206b12000100000000000000000000000", |
| wantTraceID: traceID, |
| wantTraceOptions: trace.TraceOptions(0), |
| }, |
| { |
| header: "105445aa7843bc8bf206b12000100000000000000000000001", |
| wantTraceID: traceID, |
| wantTraceOptions: trace.TraceOptions(1), |
| }, |
| } |
| |
| for _, tt := range tests { |
| t.Run(tt.header, func(t *testing.T) { |
| handler := &Handler{ |
| Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| span := trace.FromContext(r.Context()) |
| sc := span.SpanContext() |
| if got, want := sc.TraceID, tt.wantTraceID; got != want { |
| t.Errorf("TraceID = %q; want %q", got, want) |
| } |
| if got, want := sc.TraceOptions, tt.wantTraceOptions; got != want { |
| t.Errorf("TraceOptions = %v; want %v", got, want) |
| } |
| }), |
| StartOptions: trace.StartOptions{Sampler: trace.ProbabilitySampler(0.0)}, |
| Propagation: &testPropagator{}, |
| } |
| req, _ := http.NewRequest("GET", "http://foo.com", nil) |
| req.Header.Add("trace", tt.header) |
| handler.ServeHTTP(nil, req) |
| }) |
| } |
| } |
| |
| var _ http.RoundTripper = (*traceTransport)(nil) |
| |
| type collector []*trace.SpanData |
| |
| func (c *collector) ExportSpan(s *trace.SpanData) { |
| *c = append(*c, s) |
| } |
| |
| func TestEndToEnd(t *testing.T) { |
| tc := []struct { |
| name string |
| handler *Handler |
| transport *Transport |
| wantSameTraceID bool |
| wantLinks bool // expect a link between client and server span |
| }{ |
| { |
| name: "internal default propagation", |
| handler: &Handler{}, |
| transport: &Transport{}, |
| wantSameTraceID: true, |
| }, |
| { |
| name: "external default propagation", |
| handler: &Handler{IsPublicEndpoint: true}, |
| transport: &Transport{}, |
| wantSameTraceID: false, |
| wantLinks: true, |
| }, |
| { |
| name: "internal TraceContext propagation", |
| handler: &Handler{Propagation: &tracecontext.HTTPFormat{}}, |
| transport: &Transport{Propagation: &tracecontext.HTTPFormat{}}, |
| wantSameTraceID: true, |
| }, |
| { |
| name: "misconfigured propagation", |
| handler: &Handler{IsPublicEndpoint: true, Propagation: &tracecontext.HTTPFormat{}}, |
| transport: &Transport{Propagation: &b3.HTTPFormat{}}, |
| wantSameTraceID: false, |
| wantLinks: false, |
| }, |
| } |
| |
| for _, tt := range tc { |
| t.Run(tt.name, func(t *testing.T) { |
| var spans collector |
| trace.RegisterExporter(&spans) |
| defer trace.UnregisterExporter(&spans) |
| |
| // Start the server. |
| serverDone := make(chan struct{}) |
| serverReturn := make(chan time.Time) |
| tt.handler.StartOptions.Sampler = trace.AlwaysSample() |
| url := serveHTTP(tt.handler, serverDone, serverReturn, 200) |
| |
| ctx := context.Background() |
| // Make the request. |
| req, err := http.NewRequest( |
| http.MethodPost, |
| fmt.Sprintf("%s/example/url/path?qparam=val", url), |
| strings.NewReader("expected-request-body")) |
| if err != nil { |
| t.Fatal(err) |
| } |
| req = req.WithContext(ctx) |
| tt.transport.StartOptions.Sampler = trace.AlwaysSample() |
| c := &http.Client{ |
| Transport: tt.transport, |
| } |
| resp, err := c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if resp.StatusCode != http.StatusOK { |
| t.Fatalf("resp.StatusCode = %d", resp.StatusCode) |
| } |
| |
| // Tell the server to return from request handling. |
| serverReturn <- time.Now().Add(time.Millisecond) |
| |
| respBody, err := ioutil.ReadAll(resp.Body) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if got, want := string(respBody), "expected-response"; got != want { |
| t.Fatalf("respBody = %q; want %q", got, want) |
| } |
| |
| resp.Body.Close() |
| |
| <-serverDone |
| trace.UnregisterExporter(&spans) |
| |
| if got, want := len(spans), 2; got != want { |
| t.Fatalf("len(spans) = %d; want %d", got, want) |
| } |
| |
| var client, server *trace.SpanData |
| for _, sp := range spans { |
| switch sp.SpanKind { |
| case trace.SpanKindClient: |
| client = sp |
| if got, want := client.Name, "/example/url/path"; got != want { |
| t.Errorf("Span name: %q; want %q", got, want) |
| } |
| case trace.SpanKindServer: |
| server = sp |
| if got, want := server.Name, "/example/url/path"; got != want { |
| t.Errorf("Span name: %q; want %q", got, want) |
| } |
| default: |
| t.Fatalf("server or client span missing; kind = %v", sp.SpanKind) |
| } |
| } |
| |
| if tt.wantSameTraceID { |
| if server.TraceID != client.TraceID { |
| t.Errorf("TraceID does not match: server.TraceID=%q client.TraceID=%q", server.TraceID, client.TraceID) |
| } |
| if !server.HasRemoteParent { |
| t.Errorf("server span should have remote parent") |
| } |
| if server.ParentSpanID != client.SpanID { |
| t.Errorf("server span should have client span as parent") |
| } |
| } |
| if !tt.wantSameTraceID { |
| if server.TraceID == client.TraceID { |
| t.Errorf("TraceID should not be trusted") |
| } |
| } |
| if tt.wantLinks { |
| if got, want := len(server.Links), 1; got != want { |
| t.Errorf("len(server.Links) = %d; want %d", got, want) |
| } else { |
| link := server.Links[0] |
| if got, want := link.Type, trace.LinkTypeParent; got != want { |
| t.Errorf("link.Type = %v; want %v", got, want) |
| } |
| } |
| } |
| if server.StartTime.Before(client.StartTime) { |
| t.Errorf("server span starts before client span") |
| } |
| if server.EndTime.After(client.EndTime) { |
| t.Errorf("client span ends before server span") |
| } |
| }) |
| } |
| } |
| |
| func serveHTTP(handler *Handler, done chan struct{}, wait chan time.Time, statusCode int) string { |
| handler.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.WriteHeader(statusCode) |
| w.(http.Flusher).Flush() |
| |
| // Simulate a slow-responding server. |
| sleepUntil := <-wait |
| for time.Now().Before(sleepUntil) { |
| time.Sleep(sleepUntil.Sub(time.Now())) |
| } |
| |
| io.WriteString(w, "expected-response") |
| close(done) |
| }) |
| server := httptest.NewServer(handler) |
| go func() { |
| <-done |
| server.Close() |
| }() |
| return server.URL |
| } |
| |
| func TestSpanNameFromURL(t *testing.T) { |
| tests := []struct { |
| u string |
| want string |
| }{ |
| { |
| u: "http://localhost:80/hello?q=a", |
| want: "/hello", |
| }, |
| { |
| u: "/a/b?q=c", |
| want: "/a/b", |
| }, |
| } |
| for _, tt := range tests { |
| t.Run(tt.u, func(t *testing.T) { |
| req, err := http.NewRequest("GET", tt.u, nil) |
| if err != nil { |
| t.Errorf("url issue = %v", err) |
| } |
| if got := spanNameFromURL(req); got != tt.want { |
| t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want) |
| } |
| }) |
| } |
| } |
| |
| func TestFormatSpanName(t *testing.T) { |
| formatSpanName := func(r *http.Request) string { |
| return r.Method + " " + r.URL.Path |
| } |
| |
| handler := &Handler{ |
| Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { |
| resp.Write([]byte("Hello, world!")) |
| }), |
| FormatSpanName: formatSpanName, |
| } |
| |
| server := httptest.NewServer(handler) |
| defer server.Close() |
| |
| client := &http.Client{ |
| Transport: &Transport{ |
| FormatSpanName: formatSpanName, |
| StartOptions: trace.StartOptions{ |
| Sampler: trace.AlwaysSample(), |
| }, |
| }, |
| } |
| |
| tests := []struct { |
| u string |
| want string |
| }{ |
| { |
| u: "/hello?q=a", |
| want: "GET /hello", |
| }, |
| { |
| u: "/a/b?q=c", |
| want: "GET /a/b", |
| }, |
| } |
| |
| for _, tt := range tests { |
| t.Run(tt.u, func(t *testing.T) { |
| var te testExporter |
| trace.RegisterExporter(&te) |
| res, err := client.Get(server.URL + tt.u) |
| if err != nil { |
| t.Fatalf("error creating request: %v", err) |
| } |
| res.Body.Close() |
| trace.UnregisterExporter(&te) |
| if want, got := 2, len(te.spans); want != got { |
| t.Fatalf("got exported spans %#v, wanted two spans", te.spans) |
| } |
| if got := te.spans[0].Name; got != tt.want { |
| t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want) |
| } |
| if got := te.spans[1].Name; got != tt.want { |
| t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want) |
| } |
| }) |
| } |
| } |
| |
| func TestRequestAttributes(t *testing.T) { |
| tests := []struct { |
| name string |
| makeReq func() *http.Request |
| wantAttrs []trace.Attribute |
| }{ |
| { |
| name: "GET example.com/hello", |
| makeReq: func() *http.Request { |
| req, _ := http.NewRequest("GET", "http://example.com:779/hello", nil) |
| req.Header.Add("User-Agent", "ua") |
| return req |
| }, |
| wantAttrs: []trace.Attribute{ |
| trace.StringAttribute("http.path", "/hello"), |
| trace.StringAttribute("http.url", "http://example.com:779/hello"), |
| trace.StringAttribute("http.host", "example.com:779"), |
| trace.StringAttribute("http.method", "GET"), |
| trace.StringAttribute("http.user_agent", "ua"), |
| }, |
| }, |
| } |
| |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| req := tt.makeReq() |
| attrs := requestAttrs(req) |
| |
| if got, want := attrs, tt.wantAttrs; !reflect.DeepEqual(got, want) { |
| t.Errorf("Request attributes = %#v; want %#v", got, want) |
| } |
| }) |
| } |
| } |
| |
| func TestResponseAttributes(t *testing.T) { |
| tests := []struct { |
| name string |
| resp *http.Response |
| wantAttrs []trace.Attribute |
| }{ |
| { |
| name: "non-zero HTTP 200 response", |
| resp: &http.Response{StatusCode: 200}, |
| wantAttrs: []trace.Attribute{ |
| trace.Int64Attribute("http.status_code", 200), |
| }, |
| }, |
| { |
| name: "zero HTTP 500 response", |
| resp: &http.Response{StatusCode: 500}, |
| wantAttrs: []trace.Attribute{ |
| trace.Int64Attribute("http.status_code", 500), |
| }, |
| }, |
| } |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| attrs := responseAttrs(tt.resp) |
| if got, want := attrs, tt.wantAttrs; !reflect.DeepEqual(got, want) { |
| t.Errorf("Response attributes = %#v; want %#v", got, want) |
| } |
| }) |
| } |
| } |
| |
| type TestCase struct { |
| Name string |
| Method string |
| URL string |
| Headers map[string]string |
| ResponseCode int |
| SpanName string |
| SpanStatus string |
| SpanKind string |
| SpanAttributes map[string]string |
| } |
| |
| func TestAgainstSpecs(t *testing.T) { |
| |
| fmt.Println("start") |
| |
| dat, err := ioutil.ReadFile("testdata/http-out-test-cases.json") |
| if err != nil { |
| t.Fatalf("error reading file: %v", err) |
| } |
| |
| tests := make([]TestCase, 0) |
| err = json.Unmarshal(dat, &tests) |
| if err != nil { |
| t.Fatalf("error parsing json: %v", err) |
| } |
| |
| trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()}) |
| |
| for _, tt := range tests { |
| t.Run(tt.Name, func(t *testing.T) { |
| var spans collector |
| trace.RegisterExporter(&spans) |
| defer trace.UnregisterExporter(&spans) |
| |
| handler := &Handler{} |
| transport := &Transport{} |
| |
| serverDone := make(chan struct{}) |
| serverReturn := make(chan time.Time) |
| host := "" |
| port := "" |
| serverRequired := strings.Contains(tt.URL, "{") |
| if serverRequired { |
| // Start the server. |
| localServerURL := serveHTTP(handler, serverDone, serverReturn, tt.ResponseCode) |
| u, _ := url.Parse(localServerURL) |
| host, port, _ = net.SplitHostPort(u.Host) |
| |
| tt.URL = strings.Replace(tt.URL, "{host}", host, 1) |
| tt.URL = strings.Replace(tt.URL, "{port}", port, 1) |
| } |
| |
| // Start a root Span in the client. |
| ctx, _ := trace.StartSpan( |
| context.Background(), |
| "top-level") |
| // Make the request. |
| req, err := http.NewRequest( |
| tt.Method, |
| tt.URL, |
| nil) |
| for headerName, headerValue := range tt.Headers { |
| req.Header.Add(headerName, headerValue) |
| } |
| if err != nil { |
| t.Fatal(err) |
| } |
| req = req.WithContext(ctx) |
| resp, err := transport.RoundTrip(req) |
| if err != nil { |
| // do not fail. We want to validate DNS issues |
| //t.Fatal(err) |
| } |
| |
| if serverRequired { |
| // Tell the server to return from request handling. |
| serverReturn <- time.Now().Add(time.Millisecond) |
| } |
| |
| if resp != nil { |
| // If it simply closes body without reading |
| // synchronization problem may happen for spans slice. |
| // Server span and client span will write themselves |
| // at the same time |
| ioutil.ReadAll(resp.Body) |
| resp.Body.Close() |
| if serverRequired { |
| <-serverDone |
| } |
| } |
| trace.UnregisterExporter(&spans) |
| |
| var client *trace.SpanData |
| for _, sp := range spans { |
| if sp.SpanKind == trace.SpanKindClient { |
| client = sp |
| } |
| } |
| |
| if client.Name != tt.SpanName { |
| t.Errorf("span names don't match: expected: %s, actual: %s", tt.SpanName, client.Name) |
| } |
| |
| spanKindToStr := map[int]string{ |
| trace.SpanKindClient: "Client", |
| trace.SpanKindServer: "Server", |
| } |
| |
| if !strings.EqualFold(codeToStr[client.Status.Code], tt.SpanStatus) { |
| t.Errorf("span status don't match: expected: %s, actual: %d (%s)", tt.SpanStatus, client.Status.Code, codeToStr[client.Status.Code]) |
| } |
| |
| if !strings.EqualFold(spanKindToStr[client.SpanKind], tt.SpanKind) { |
| t.Errorf("span kind don't match: expected: %s, actual: %d (%s)", tt.SpanKind, client.SpanKind, spanKindToStr[client.SpanKind]) |
| } |
| |
| normalizedActualAttributes := map[string]string{} |
| for k, v := range client.Attributes { |
| normalizedActualAttributes[k] = fmt.Sprintf("%v", v) |
| } |
| |
| normalizedExpectedAttributes := map[string]string{} |
| for k, v := range tt.SpanAttributes { |
| normalizedValue := v |
| normalizedValue = strings.Replace(normalizedValue, "{host}", host, 1) |
| normalizedValue = strings.Replace(normalizedValue, "{port}", port, 1) |
| |
| normalizedExpectedAttributes[k] = normalizedValue |
| } |
| |
| if got, want := normalizedActualAttributes, normalizedExpectedAttributes; !reflect.DeepEqual(got, want) { |
| t.Errorf("Request attributes = %#v; want %#v", got, want) |
| } |
| }) |
| } |
| } |
| |
| func TestStatusUnitTest(t *testing.T) { |
| tests := []struct { |
| in int |
| want trace.Status |
| }{ |
| {200, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}}, |
| {204, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}}, |
| {100, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}}, |
| {500, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}}, |
| {400, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: `INVALID_ARGUMENT`}}, |
| {422, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: `INVALID_ARGUMENT`}}, |
| {499, trace.Status{Code: trace.StatusCodeCancelled, Message: `CANCELLED`}}, |
| {404, trace.Status{Code: trace.StatusCodeNotFound, Message: `NOT_FOUND`}}, |
| {600, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}}, |
| {401, trace.Status{Code: trace.StatusCodeUnauthenticated, Message: `UNAUTHENTICATED`}}, |
| {403, trace.Status{Code: trace.StatusCodePermissionDenied, Message: `PERMISSION_DENIED`}}, |
| {301, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}}, |
| {501, trace.Status{Code: trace.StatusCodeUnimplemented, Message: `UNIMPLEMENTED`}}, |
| } |
| |
| for _, tt := range tests { |
| got, want := TraceStatus(tt.in, ""), tt.want |
| if got != want { |
| t.Errorf("status(%d) got = (%#v) want = (%#v)", tt.in, got, want) |
| } |
| } |
| } |