blob: 4cca1bcb5e04af994b2fc56bf7387f5fa54f37f5 [file] [log] [blame]
package ochttp
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"golang.org/x/net/http2"
"go.opencensus.io/stats/view"
"go.opencensus.io/trace"
)
// httpHandler returns a handler that answers every request with statusCode
// followed by a body of respSize zero bytes. The write error is ignored.
func httpHandler(statusCode, respSize int) http.Handler {
	handle := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(statusCode)
		w.Write(make([]byte, respSize))
	}
	return http.HandlerFunc(handle)
}
// updateMean folds sample into the running mean of count samples (sample
// being the count-th one) and returns the new mean. The first sample
// (count == 1) becomes the mean outright, discarding the previous value.
func updateMean(mean float64, sample, count int) float64 {
	s := float64(sample)
	if count != 1 {
		return mean + (s-mean)/float64(count)
	}
	return s
}
// TestHandlerStatsCollection drives Handler through several request shapes and
// checks that DefaultServerViews record the expected request count and the
// expected request/response byte totals.
func TestHandlerStatsCollection(t *testing.T) {
	if err := view.Register(DefaultServerViews...); err != nil {
		t.Fatalf("Failed to register ochttp.DefaultServerViews error: %v", err)
	}
	// Registration is process-global: unregister on exit so the data collected
	// here does not leak into other tests or into a repeated run (e.g.
	// go test -count=2), where accumulated counts would break the totals below.
	defer view.Unregister(DefaultServerViews...)

	views := []string{
		"opencensus.io/http/server/request_count",
		"opencensus.io/http/server/latency",
		"opencensus.io/http/server/request_bytes",
		"opencensus.io/http/server/response_bytes",
	}
	// TODO: test latency measurements?
	tests := []struct {
		name, method, target                 string
		count, statusCode, reqSize, respSize int
	}{
		{"get 200", "GET", "http://opencensus.io/request/one", 10, 200, 512, 512},
		{"post 503", "POST", "http://opencensus.io/request/two", 5, 503, 1024, 16384},
		{"no body 302", "GET", "http://opencensus.io/request/three", 2, 302, 0, 0},
	}
	totalCount, meanReqSize, meanRespSize := 0, 0.0, 0.0
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			body := bytes.NewBuffer(make([]byte, test.reqSize))
			r := httptest.NewRequest(test.method, test.target, body)
			w := httptest.NewRecorder()
			mux := http.NewServeMux()
			mux.Handle("/request/", httpHandler(test.statusCode, test.respSize))
			h := &Handler{
				Handler: mux,
				StartOptions: trace.StartOptions{
					Sampler: trace.NeverSample(),
				},
			}
			// The same request and recorder are reused for every iteration;
			// only the recorded view data is checked afterwards.
			for i := 0; i < test.count; i++ {
				h.ServeHTTP(w, r)
				totalCount++
				// Distributions do not track sum directly, we must
				// mimic their behaviour to avoid rounding failures.
				meanReqSize = updateMean(meanReqSize, test.reqSize, totalCount)
				meanRespSize = updateMean(meanRespSize, test.respSize, totalCount)
			}
		})
	}

	// Report every mismatching view rather than stopping at the first one.
	for _, viewName := range views {
		if view.Find(viewName) == nil {
			t.Errorf("view not found %q", viewName)
			continue
		}
		rows, err := view.RetrieveData(viewName)
		if err != nil {
			t.Error(err)
			continue
		}
		if got, want := len(rows), 1; got != want {
			t.Errorf("len(%q) = %d; want %d", viewName, got, want)
			continue
		}
		var count int
		var sum float64
		switch data := rows[0].Data.(type) {
		case *view.CountData:
			count = int(data.Value)
		case *view.DistributionData:
			count = int(data.Count)
			sum = data.Sum()
		default:
			t.Errorf("Unknown data type: %v", data)
			continue
		}
		if got, want := count, totalCount; got != want {
			t.Errorf("%s = %d; want %d", viewName, got, want)
			continue
		}
		// We can only check sum for distribution views.
		switch viewName {
		case "opencensus.io/http/server/request_bytes":
			if got, want := sum, meanReqSize*float64(totalCount); got != want {
				t.Errorf("%s = %g; want %g", viewName, got, want)
			}
		case "opencensus.io/http/server/response_bytes":
			if got, want := sum, meanRespSize*float64(totalCount); got != want {
				t.Errorf("%s = %g; want %g", viewName, got, want)
			}
		}
	}
}
// testResponseWriterHijacker is an http.ResponseWriter (through the embedded
// httptest.ResponseRecorder) that also has a Hijack method, so a
// *testResponseWriterHijacker type-asserts as an http.Hijacker.
type testResponseWriterHijacker struct {
	httptest.ResponseRecorder
}

// Hijack is a stub: it returns a nil connection, a nil reader/writer and a
// nil error.
func (trw *testResponseWriterHijacker) Hijack() (conn net.Conn, rw *bufio.ReadWriter, err error) {
	return conn, rw, err
}
// TestUnitTestHandlerProxiesHijack checks that the writer returned by
// trackingResponseWriter.wrappedResponseWriter implements http.Hijacker
// exactly when the wrapped http.ResponseWriter does (including a nil writer).
func TestUnitTestHandlerProxiesHijack(t *testing.T) {
	cases := []struct {
		w         http.ResponseWriter
		hasHijack bool
	}{
		{w: httptest.NewRecorder(), hasHijack: false},
		{w: nil, hasHijack: false},
		{w: new(testResponseWriterHijacker), hasHijack: true},
	}
	for idx, c := range cases {
		wrapped := (&trackingResponseWriter{writer: c.w}).wrappedResponseWriter()
		_, isHijacker := wrapped.(http.Hijacker)
		if isHijacker != c.hasHijack {
			t.Errorf("#%d Hijack got %t, want %t", idx, isHijacker, c.hasHijack)
		}
	}
}
// TestHandlerProxiesHijack_HTTP1 is an integration test with net/http ensuring
// that Handler proxies (http.Hijacker).Hijack to the underlying ResponseWriter
// and that hijacking succeeds on HTTP/1.1 connections. The server is plain
// HTTP, so both the HTTP/1 and the HTTP/2-configured client transports end up
// speaking HTTP/1.1. See Issue #642.
func TestHandlerProxiesHijack_HTTP1(t *testing.T) {
	cst := httptest.NewServer(&Handler{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// conn is set only once Hijack has succeeded.
			var conn net.Conn
			defer func() {
				// Report whether hijacking panicked; recover must be called
				// directly by this deferred function to take effect.
				msg := fmt.Sprintf("Proto=%s\npanic=%v", r.Proto, recover() != nil)
				if conn == nil {
					// Hijack never succeeded, so the regular ResponseWriter
					// still owns the connection; reply through it.
					fmt.Fprint(w, msg)
					return
				}
				defer conn.Close()
				// net/http no longer manages a hijacked connection, so write
				// a well-formed HTTP/1.x response by hand.
				fmt.Fprintf(conn, "%s 200 OK\r\nContent-Length: %d\r\n\r\n%s", r.Proto, len(msg), msg)
			}()
			hj, ok := w.(http.Hijacker)
			if !ok {
				panic("ResponseWriter does not implement http.Hijacker")
			}
			c, _, err := hj.Hijack()
			if err != nil {
				panic(err)
			}
			conn = c
		}),
	})
	defer cst.Close()
	testCases := []struct {
		name string
		tr   *http.Transport
		want string
	}{
		{
			name: "http1-transport",
			tr:   new(http.Transport),
			want: "Proto=HTTP/1.1\npanic=false",
		},
		{
			name: "http2-transport",
			tr: func() *http.Transport {
				tr := new(http.Transport)
				if err := http2.ConfigureTransport(tr); err != nil {
					t.Fatalf("ConfigureTransport error: %v", err)
				}
				return tr
			}(),
			want: "Proto=HTTP/1.1\npanic=false",
		},
	}
	for _, tc := range testCases {
		c := &http.Client{Transport: &Transport{Base: tc.tr}}
		res, err := c.Get(cst.URL)
		if err != nil {
			t.Errorf("(%s) unexpected error %v", tc.name, err)
			continue
		}
		blob, err := ioutil.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			t.Errorf("(%s) reading body: %v", tc.name, err)
			continue
		}
		if g, w := string(blob), tc.want; g != w {
			t.Errorf("(%s) got = %q; want = %q", tc.name, g, w)
		}
	}
}
// TestHandlerProxiesHijack_HTTP2 is an integration test with net/http and
// x/net/http2 ensuring that Handler proxies (http.Hijacker).Hijack to the
// underlying ResponseWriter. Hijacking is incompatible with HTTP/2, but that is
// only detected at runtime, so the handler must either see no Hijacker or get
// an error from Hijack, and must still be able to stream and flush to the
// HTTP/2 connection afterwards. See Issue #642.
func TestHandlerProxiesHijack_HTTP2(t *testing.T) {
	cst := httptest.NewUnstartedServer(&Handler{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// stream writes and flushes the numbered lines the client expects.
			stream := func() {
				for i := 0; i < 5; i++ {
					fmt.Fprintf(w, "%d\n", i)
					w.(http.Flusher).Flush()
				}
			}
			hj, ok := w.(http.Hijacker)
			if !ok {
				// Confirmed HTTP/2.0, let's stream to it
				stream()
				return
			}
			conn, _, err := hj.Hijack()
			if conn != nil {
				// Hijacking should not succeed here; if it does, answer on the
				// raw connection with a body that fails the check below.
				data := fmt.Sprintf("Surprisingly got the Hijacker() Proto: %s", r.Proto)
				fmt.Fprintf(conn, "%s 200 OK\r\nContent-Length: %d\r\n\r\n%s", r.Proto, len(data), data)
				conn.Close()
				return
			}
			switch {
			case err == nil:
				fmt.Fprint(w, "Unexpectedly did not encounter an error!")
			case strings.Contains(err.Error(), "Hijack"):
				// Confirmed HTTP/2.0, let's stream to it
				stream()
			default:
				fmt.Fprintf(w, "Unexpected error: %v", err)
			}
		}),
	})
	cst.TLS = &tls.Config{NextProtos: []string{"h2"}}
	cst.StartTLS()
	defer cst.Close()
	if wantPrefix := "https://"; !strings.HasPrefix(cst.URL, wantPrefix) {
		t.Fatalf("URL got = %q wantPrefix = %q", cst.URL, wantPrefix)
	}
	tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
	if err := http2.ConfigureTransport(tr); err != nil {
		t.Fatalf("ConfigureTransport error: %v", err)
	}
	c := &http.Client{Transport: tr}
	res, err := c.Get(cst.URL)
	if err != nil {
		t.Fatalf("Unexpected error %v", err)
	}
	blob, err := ioutil.ReadAll(res.Body)
	res.Body.Close()
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	if g, w := string(blob), "0\n1\n2\n3\n4\n"; g != w {
		t.Errorf("got = %q; want = %q", g, w)
	}
}
func TestEnsureTrackingResponseWriterSetsStatusCode(t *testing.T) {
	// Ensure that the trackingResponseWriter always sets the spanStatus on ending the span.
	// Because we can only examine the Status after exporting, this test roundtrips a
	// couple of requests and then later examines the exported spans.
	// See Issue #700.
	exporter := &spanExporter{cur: make(chan *trace.SpanData, 1)}
	trace.RegisterExporter(exporter)
	defer trace.UnregisterExporter(exporter)
	tests := []struct {
		res  *http.Response
		want trace.Status
	}{
		{res: &http.Response{StatusCode: 200}, want: trace.Status{Code: trace.StatusCodeOK, Message: `OK`}},
		{res: &http.Response{StatusCode: 500}, want: trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}},
		{res: &http.Response{StatusCode: 403}, want: trace.Status{Code: trace.StatusCodePermissionDenied, Message: `PERMISSION_DENIED`}},
		{res: &http.Response{StatusCode: 401}, want: trace.Status{Code: trace.StatusCodeUnauthenticated, Message: `UNAUTHENTICATED`}},
		{res: &http.Response{StatusCode: 429}, want: trace.Status{Code: trace.StatusCodeResourceExhausted, Message: `RESOURCE_EXHAUSTED`}},
	}
	for _, tt := range tests {
		t.Run(tt.want.Message, func(t *testing.T) {
			ctx := context.Background()
			prc, pwc := io.Pipe()
			// Closing the reader unblocks the writer goroutine below if the
			// body is never fully read (e.g. RoundTrip fails).
			defer prc.Close()
			go func() {
				pwc.Write([]byte("Foo"))
				pwc.Close()
			}()
			// Copy the table's response so attaching a Body does not mutate
			// the shared test case.
			inRes := *tt.res
			inRes.Body = prc
			tr := &traceTransport{
				base:           &testResponseTransport{res: &inRes},
				formatSpanName: spanNameFromURL,
				startOptions: trace.StartOptions{
					Sampler: trace.AlwaysSample(),
				},
			}
			req, err := http.NewRequest("POST", "https://example.org", bytes.NewReader([]byte("testing")))
			if err != nil {
				t.Fatalf("NewRequest error: %v", err)
			}
			req = req.WithContext(ctx)
			res, err := tr.RoundTrip(req)
			if err != nil {
				t.Fatalf("RoundTrip error: %v", err)
			}
			_, _ = ioutil.ReadAll(res.Body)
			res.Body.Close()
			// Bound the wait so a span that is never exported fails the test
			// instead of hanging it forever.
			select {
			case cur := <-exporter.cur:
				if got, want := cur.Status, tt.want; got != want {
					t.Fatalf("SpanData:\ngot = (%#v)\nwant = (%#v)", got, want)
				}
			case <-time.After(5 * time.Second):
				t.Fatalf("timed out waiting for the span to be exported")
			}
		})
	}
}
// spanExporter is a trace.Exporter that hands each exported span to the test
// through the cur channel. The embedded mutex serializes ExportSpan calls so
// that sends on cur happen one at a time.
type spanExporter struct {
	sync.Mutex
	cur chan *trace.SpanData
}

var _ trace.Exporter = (*spanExporter)(nil)

// ExportSpan forwards sd on se.cur, blocking while the channel is full.
func (se *spanExporter) ExportSpan(sd *trace.SpanData) {
	se.Lock()
	defer se.Unlock()
	se.cur <- sd
}
// testResponseTransport is an http.RoundTripper stub that ignores the request
// and always answers with the preconfigured res (possibly nil) and no error.
type testResponseTransport struct {
	res *http.Response
}

var _ http.RoundTripper = (*testResponseTransport)(nil)

// RoundTrip returns rb.res unchanged without inspecting the request.
func (rb *testResponseTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	canned := rb.res
	return canned, nil
}
// TestHandlerImplementsHTTPPusher checks that the ResponseWriter seen by a
// handler wrapped in Handler implements http.Pusher for HTTP/2 requests but
// not for requests made with the HTTP/1 transport, both with and without an
// ochttp Transport on the client side.
func TestHandlerImplementsHTTPPusher(t *testing.T) {
	cst := setupAndStartServer(func(w http.ResponseWriter, r *http.Request) {
		pusher, ok := w.(http.Pusher)
		if !ok {
			w.Write([]byte("false"))
			return
		}
		// TODO: (@odeke-em) consult with Go stdlib for why trying
		// to configure even an HTTP/2 server and HTTP/2 transport
		// still return http.ErrNotSupported even without using ochttp.Handler.
		// NOTE(review): possibly because the x/net/http2 client disables
		// server push in its initial SETTINGS — confirm.
		// Until that is resolved the Push error is deliberately ignored: this
		// test only checks that http.Pusher is exposed.
		_ = pusher.Push("/static.css", &http.PushOptions{
			Method: "GET",
			Header: http.Header{"Accept-Encoding": r.Header["Accept-Encoding"]},
		})
		w.Write([]byte("true"))
	}, asHTTP2)
	defer cst.Close()
	tests := []struct {
		rt       http.RoundTripper
		wantBody string
	}{
		{
			rt:       h1Transport(),
			wantBody: "false",
		},
		{
			rt:       h2Transport(),
			wantBody: "true",
		},
		{
			rt:       &Transport{Base: h1Transport()},
			wantBody: "false",
		},
		{
			rt:       &Transport{Base: h2Transport()},
			wantBody: "true",
		},
	}
	for i, tt := range tests {
		// Use tt.rt as-is: the table already lists both bare and
		// ochttp-wrapped transports, so wrapping here would double-wrap.
		c := &http.Client{Transport: tt.rt}
		res, err := c.Get(cst.URL)
		if err != nil {
			t.Errorf("#%d: Unexpected error %v", i, err)
			continue
		}
		body, err := ioutil.ReadAll(res.Body)
		_ = res.Body.Close()
		if err != nil {
			t.Errorf("#%d: reading body: %v", i, err)
			continue
		}
		if g, w := string(body), tt.wantBody; g != w {
			t.Errorf("#%d: got = %q; want = %q", i, g, w)
		}
	}
}
// Markers used by the close-notify tests; handleCloseNotify writes isNil,
// ended or nonNotifier (one per line) to its log.
const (
	isNil       = "isNil"
	hang        = "hang"
	ended       = "ended"
	nonNotifier = "nonNotifier"
)

// Values for setupAndStartServer's isHTTP2 parameter.
const (
	asHTTP1 = false
	asHTTP2 = true
)
// setupAndStartServer starts an httptest server whose handler is hf wrapped in
// an ochttp Handler. When isHTTP2 is true, the server is configured with
// http2.ConfigureServer and started with TLS; otherwise it serves plain HTTP.
// It panics if HTTP/2 configuration fails. The caller must Close the server.
func setupAndStartServer(hf func(http.ResponseWriter, *http.Request), isHTTP2 bool) *httptest.Server {
	cst := httptest.NewUnstartedServer(&Handler{
		Handler: http.HandlerFunc(hf),
	})
	if !isHTTP2 {
		cst.Start()
		return cst
	}
	if err := http2.ConfigureServer(cst.Config, new(http2.Server)); err != nil {
		panic(fmt.Sprintf("setupAndStartServer: configuring HTTP/2: %v", err))
	}
	// Reuse the TLS config prepared by ConfigureServer (expected to advertise
	// "h2") so StartTLS negotiates HTTP/2.
	cst.TLS = cst.Config.TLSConfig
	cst.StartTLS()
	return cst
}
// insecureTLS returns a fresh client TLS config that accepts any server
// certificate; only suitable for talking to local test servers.
func insecureTLS() *tls.Config {
	cfg := new(tls.Config)
	cfg.InsecureSkipVerify = true
	return cfg
}

// h1Transport returns the transport these tests use as their HTTP/1 client:
// a plain http.Transport with certificate verification disabled.
func h1Transport() *http.Transport {
	tr := new(http.Transport)
	tr.TLSClientConfig = insecureTLS()
	return tr
}
// h2Transport returns an HTTP/2-capable client transport (configured via
// http2.ConfigureTransport) with certificate verification disabled.
// It panics if HTTP/2 cannot be configured, rather than silently falling back.
func h2Transport() *http.Transport {
	tr := &http.Transport{TLSClientConfig: insecureTLS()}
	if err := http2.ConfigureTransport(tr); err != nil {
		panic(fmt.Sprintf("h2Transport: configuring HTTP/2: %v", err))
	}
	return tr
}
// concurrentBuffer is an io.Writer over bw that is safe for concurrent use:
// server handlers write to it while the test goroutine reads it via String.
// bw must be set before use.
type concurrentBuffer struct {
	sync.RWMutex
	bw *bytes.Buffer
}

// Write appends b to the buffer while holding the exclusive lock.
func (cw *concurrentBuffer) Write(b []byte) (int, error) {
	cw.Lock()
	defer cw.Unlock()
	return cw.bw.Write(b)
}

// String returns the buffer's contents. It only reads, so the shared read
// lock suffices and concurrent String calls do not serialize on each other.
func (cw *concurrentBuffer) String() string {
	cw.RLock()
	defer cw.RUnlock()
	return cw.bw.String()
}
// handleCloseNotify returns a handler that records in outLog, as one line,
// how the ResponseWriter behaves as an http.CloseNotifier: nonNotifier when it
// does not implement the (deprecated) interface, isNil when CloseNotify yields
// a nil channel, and ended once the close-notification channel delivers.
func handleCloseNotify(outLog io.Writer) http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		outcome := func() string {
			notifier, ok := w.(http.CloseNotifier)
			if !ok {
				return nonNotifier
			}
			closed := notifier.CloseNotify()
			if closed == nil {
				return isNil
			}
			// Block until the client goes away.
			<-closed
			return ended
		}()
		fmt.Fprintln(outLog, outcome)
	}
}
// TestHandlerImplementsHTTPCloseNotify checks that, behind Handler, the
// ResponseWriter still implements http.CloseNotifier and that it fires when
// the client gives up, for HTTP/1 and HTTP/2 servers and for client transports
// with and without ochttp.
func TestHandlerImplementsHTTPCloseNotify(t *testing.T) {
	http1Log := &concurrentBuffer{bw: new(bytes.Buffer)}
	http1Server := setupAndStartServer(handleCloseNotify(http1Log), asHTTP1)
	http2Log := &concurrentBuffer{bw: new(bytes.Buffer)}
	http2Server := setupAndStartServer(handleCloseNotify(http2Log), asHTTP2)
	defer http1Server.Close()
	defer http2Server.Close()

	// Every request is expected to end with the client leaving, so each
	// server's log must hold one `want` line per transport.
	servers := []struct {
		name string
		url  string
		log  *concurrentBuffer
		want string
	}{
		{name: "HTTP1", url: http1Server.URL, log: http1Log, want: ended},
		{name: "HTTP2", url: http2Server.URL, log: http2Log, want: ended},
	}
	transports := []struct {
		name string
		rt   http.RoundTripper
	}{
		{name: "http2+ochttp", rt: &Transport{Base: h2Transport()}},
		{name: "http1+ochttp", rt: &Transport{Base: h1Transport()}},
		{name: "http1-ochttp", rt: h1Transport()},
		{name: "http2-ochttp", rt: h2Transport()},
	}
	// Each transport invokes one of two server types, either HTTP/1 or HTTP/2
	for _, trc := range transports {
		client := &http.Client{Transport: trc.rt}
		// Try out all the transport combinations
		for i, srv := range servers {
			// Run each request in its own function so the deferred cancel
			// fires at the end of the iteration, not at test exit.
			func() {
				req, err := http.NewRequest("GET", srv.url, nil)
				if err != nil {
					t.Errorf("#%d: Unexpected error making request: %v", i, err)
					return
				}
				// Using a timeout to ensure that the request is cancelled and the server
				// if its handler implements CloseNotify will see this as the client leaving.
				ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
				defer cancel()
				req = req.WithContext(ctx)
				res, err := client.Do(req)
				if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") {
					t.Errorf("#%d: %sClient Unexpected error %v", i, trc.name, err)
					return
				}
				if res != nil && res.Body != nil {
					_, _ = io.CopyN(ioutil.Discard, res.Body, 5)
					_ = res.Body.Close()
				}
			}()
		}
	}
	// Wait for a couple of milliseconds for the GoAway frames to be properly propagated
	<-time.After(200 * time.Millisecond)
	for _, srv := range servers {
		want := strings.Repeat(srv.want+"\n", len(transports))
		if got := srv.log.String(); got != want {
			t.Errorf("%sLog got\n\t%q\nwant\n\t%q", srv.name, got, want)
		}
	}
}
// TestIgnoreHealthz checks that Handler does not give health-check requests
// (/healthz and /_ah/health) a span in their context, even when the sampler
// would sample every request.
func TestIgnoreHealthz(t *testing.T) {
	// spans is incremented on the server's handler goroutine and read on the
	// test goroutine, so guard it with a mutex.
	var (
		mu    sync.Mutex
		spans int
	)
	ts := httptest.NewServer(&Handler{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if span := trace.FromContext(r.Context()); span != nil {
				mu.Lock()
				spans++
				mu.Unlock()
			}
			fmt.Fprint(w, "ok")
		}),
		StartOptions: trace.StartOptions{
			Sampler: trace.AlwaysSample(),
		},
	})
	defer ts.Close()
	client := &http.Client{}
	for _, path := range []string{"/healthz", "/_ah/health"} {
		resp, err := client.Get(ts.URL + path)
		if err != nil {
			t.Fatalf("Cannot GET %q: %v", path, err)
		}
		b, err := ioutil.ReadAll(resp.Body)
		// Close before any check so a failing assertion does not leave the
		// response body open.
		resp.Body.Close()
		if err != nil {
			t.Fatalf("Cannot read body for %q: %v", path, err)
		}
		if got, want := string(b), "ok"; got != want {
			t.Fatalf("Body for %q = %q; want %q", path, got, want)
		}
	}
	mu.Lock()
	defer mu.Unlock()
	if spans > 0 {
		t.Errorf("Got %v spans; want no spans", spans)
	}
}