net/http/httputil: run the ReverseProxy.ModifyResponse hook for upgrades

Fixes #29627

Change-Id: I08a5b45151a11b5a4f3b5a2d984c0322cf904697
Reviewed-on: https://go-review.googlesource.com/c/157098
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go
index 1c9feb7..4e10bf3 100644
--- a/src/net/http/httputil/reverseproxy.go
+++ b/src/net/http/httputil/reverseproxy.go
@@ -171,6 +171,20 @@
 	return p.defaultErrorHandler
 }
 
+// modifyResponse conditionally runs the optional ModifyResponse hook
+// and reports whether the request should proceed.
+func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
+	if p.ModifyResponse == nil {
+		return true
+	}
+	if err := p.ModifyResponse(res); err != nil {
+		res.Body.Close()
+		p.getErrorHandler()(rw, req, err)
+		return false
+	}
+	return true
+}
+
 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
 	transport := p.Transport
 	if transport == nil {
@@ -250,6 +264,9 @@
 
 	// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
 	if res.StatusCode == http.StatusSwitchingProtocols {
+		if !p.modifyResponse(rw, res, outreq) {
+			return
+		}
 		p.handleUpgradeResponse(rw, outreq, res)
 		return
 	}
@@ -260,12 +277,8 @@
 		res.Header.Del(h)
 	}
 
-	if p.ModifyResponse != nil {
-		if err := p.ModifyResponse(res); err != nil {
-			res.Body.Close()
-			p.getErrorHandler()(rw, outreq, err)
-			return
-		}
+	if !p.modifyResponse(rw, res, outreq) {
+		return
 	}
 
 	copyHeader(rw.Header(), res.Header)
diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go
index bda569a..5edefa0 100644
--- a/src/net/http/httputil/reverseproxy_test.go
+++ b/src/net/http/httputil/reverseproxy_test.go
@@ -1012,6 +1012,10 @@
 	backURL, _ := url.Parse(backendServer.URL)
 	rproxy := NewSingleHostReverseProxy(backURL)
 	rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
+	rproxy.ModifyResponse = func(res *http.Response) error {
+		res.Header.Add("X-Modified", "true")
+		return nil
+	}
 
 	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 		rw.Header().Set("X-Header", "X-Value")
@@ -1049,6 +1053,10 @@
 	}
 	defer rwc.Close()
 
+	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
+		t.Errorf("response X-Modified header = %q; want %q", got, want)
+	}
+
 	io.WriteString(rwc, "Hello\n")
 	bs := bufio.NewScanner(rwc)
 	if !bs.Scan() {