| // Copyright 2017 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // +build go1.7 |
| |
| package trace |
| |
| import ( |
| "io/ioutil" |
| "net/http" |
| "net/http/httptest" |
| "strings" |
| "testing" |
| ) |
| |
// recorderTransport is a stub http.RoundTripper used by these tests. Every
// request handed to it is pushed onto ch so a test can inspect exactly what
// would have been sent, and it always answers with a canned "200 OK"
// response whose body is "{}". It performs no network I/O.
type recorderTransport struct {
	// ch receives each outgoing request. The send in RoundTrip blocks while
	// the channel is full, so tests must drain it or give it enough buffer.
	ch chan *http.Request
}

// RoundTrip records req on rt.ch and then returns a fixed "200 OK" response
// with the body "{}". It never returns an error.
func (rt *recorderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	rt.ch <- req

	body := ioutil.NopCloser(strings.NewReader("{}"))
	return &http.Response{
		StatusCode: http.StatusOK,
		Status:     "200 OK",
		Body:       body,
	}, nil
}
| |
| func TestNewHTTPClient(t *testing.T) { |
| rt := &recorderTransport{ |
| ch: make(chan *http.Request, 1), |
| } |
| |
| tc := newTestClient(&noopTransport{}) |
| client := &http.Client{ |
| Transport: &Transport{ |
| Base: rt, |
| }, |
| } |
| req, _ := http.NewRequest("GET", "http://example.com", nil) |
| |
| t.Run("NoTrace", func(t *testing.T) { |
| _, err := client.Do(req) |
| if err != nil { |
| t.Error(err) |
| } |
| outgoing := <-rt.ch |
| if got, want := outgoing.Header.Get(httpHeader), ""; want != got { |
| t.Errorf("got trace header = %q; want none", got) |
| } |
| }) |
| |
| t.Run("Trace", func(t *testing.T) { |
| span := tc.NewSpan("/foo") |
| |
| req = req.WithContext(NewContext(req.Context(), span)) |
| _, err := client.Do(req) |
| if err != nil { |
| t.Error(err) |
| } |
| outgoing := <-rt.ch |
| |
| s := tc.SpanFromHeader("/foo", outgoing.Header.Get(httpHeader)) |
| if got, want := s.TraceID(), span.TraceID(); got != want { |
| t.Errorf("trace ID = %q; want %q", got, want) |
| } |
| }) |
| } |
| |
| func TestHTTPHandlerNoTrace(t *testing.T) { |
| tc := newTestClient(&noopTransport{}) |
| client := &http.Client{ |
| Transport: &Transport{}, |
| } |
| handler := tc.HTTPHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| span := FromContext(r.Context()) |
| if span == nil { |
| t.Errorf("span is nil; want non-nil span") |
| } |
| })) |
| |
| ts := httptest.NewServer(handler) |
| defer ts.Close() |
| |
| req, _ := http.NewRequest("GET", ts.URL, nil) |
| _, err := client.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| func TestHTTPHandler_response(t *testing.T) { |
| tc := newTestClient(&noopTransport{}) |
| p, _ := NewLimitedSampler(1, 1<<32) // all |
| tc.SetSamplingPolicy(p) |
| handler := tc.HTTPHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) |
| ts := httptest.NewServer(handler) |
| defer ts.Close() |
| |
| tests := []struct { |
| name string |
| traceHeader string |
| wantTraceHeader string |
| }{ |
| { |
| name: "no global", |
| traceHeader: "0123456789ABCDEF0123456789ABCDEF/123", |
| wantTraceHeader: "0123456789ABCDEF0123456789ABCDEF/123;o=1", |
| }, |
| { |
| name: "global=1", |
| traceHeader: "0123456789ABCDEF0123456789ABCDEF/123;o=1", |
| wantTraceHeader: "", |
| }, |
| { |
| name: "global=0", |
| traceHeader: "0123456789ABCDEF0123456789ABCDEF/123;o=0", |
| wantTraceHeader: "", |
| }, |
| { |
| name: "no trace context", |
| traceHeader: "", |
| wantTraceHeader: "", |
| }, |
| } |
| |
| for _, tt := range tests { |
| req, _ := http.NewRequest("GET", ts.URL, nil) |
| req.Header.Set(httpHeader, tt.traceHeader) |
| |
| res, err := http.DefaultClient.Do(req) |
| if err != nil { |
| t.Errorf("failed to request: %v", err) |
| } |
| if got, want := res.Header.Get(httpHeader), tt.wantTraceHeader; got != want { |
| t.Errorf("%v: response context header = %q; want %q", tt.name, got, want) |
| } |
| } |
| } |