blob: 607135eee3d8d0ba27ee7047a5e62da431f48334 [file] [log] [blame]
// Copyright 2013 M-Lab
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package digest
import (
"crypto/md5"
"crypto/rand"
"errors"
"fmt"
"io"
"net/http"
"strings"
)
// Sentinel errors reported by Transport.RoundTrip. Compare against them with
// == or errors.Is.
var (
	// ErrNilTransport is returned when the Transport field of a Transport is
	// nil, so there is no underlying RoundTripper to send requests through.
	ErrNilTransport = errors.New("transport is nil")

	// ErrBadChallenge is returned when the WWW-Authenticate header of a 401
	// response cannot be parsed as a Digest challenge.
	ErrBadChallenge = errors.New("challenge is bad")

	// ErrAlgNotImplemented is returned when the challenge asks for an
	// algorithm other than MD5, or a qop other than "auth" or none.
	ErrAlgNotImplemented = errors.New("algorithm not implemented")
)
// Transport is an http.RoundTripper that answers HTTP digest authentication
// challenges: it sends a request, and if the server replies 401 with a Digest
// challenge, it retries the request with an Authorization header.
type Transport struct {
	// Username and Password are the credentials used to answer digest
	// challenges.
	Username, Password string

	// Transport is the RoundTripper that carries both the initial request
	// and the authenticated retry. RoundTrip fails with ErrNilTransport
	// when it is nil.
	Transport http.RoundTripper
}
// NewTransport returns a digest Transport that authenticates as
// username/password and sends its requests through http.DefaultTransport.
func NewTransport(username, password string) *Transport {
	tr := new(Transport)
	tr.Username, tr.Password = username, password
	tr.Transport = http.DefaultTransport
	return tr
}
// errBodyNotReplayable is returned by RoundTrip for a request that has a body
// but no GetBody func: the challenge request consumes the body, so nothing
// would be left to send with the authenticated retry.
var errBodyNotReplayable = errors.New("request body cannot be replayed: GetBody is nil")

// RoundTrip makes a request expecting a 401 response that will require digest
// authentication. It creates the credentials it needs and makes a follow-up
// request.
//
// r is sent first exactly as given. If that round trip fails or the response
// is not a 401, the result is returned unchanged. Otherwise the
// WWW-Authenticate challenge is answered with a follow-up request. That
// request is a shallow copy of r, so context, Host and ContentLength are
// preserved. It has its own header map, an added Authorization header, and a
// fresh body from r.GetBody. r itself is never modified.
//
// Requests with a body must have GetBody set. http.NewRequest sets it for
// *bytes.Buffer, *bytes.Reader and *strings.Reader bodies.
//
// If the challenge cannot be parsed or answered, the 401 body is closed and a
// nil response is returned with the error. This keeps a response and an error
// from being returned together, since http.RoundTripper reserves a non-nil
// error for failing to obtain a response. Request bodies that will not be
// sent are closed.
func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
	if t.Transport == nil {
		closeRequestBody(r)
		return nil, ErrNilTransport
	}

	// Prepare the follow-up request before the first round trip consumes
	// r.Body. The struct copy keeps r's context, Host, ContentLength, etc.;
	// the header map is copied so setting Authorization cannot touch r.
	req := new(http.Request)
	*req = *r
	if req.Method == "" {
		// For client requests "" means GET, but the digest hashes the method.
		req.Method = http.MethodGet
	}
	req.Header = make(http.Header, len(r.Header))
	for k, s := range r.Header {
		req.Header[k] = append([]string(nil), s...)
	}
	if r.Body != nil && r.Body != http.NoBody {
		if r.GetBody == nil {
			closeRequestBody(r)
			return nil, errBodyNotReplayable
		}
		body, err := r.GetBody()
		if err != nil {
			closeRequestBody(r)
			return nil, err
		}
		req.Body = body
	}

	// Make a request to get the 401 that contains the challenge.
	res, err := t.Transport.RoundTrip(r)
	if err != nil || res.StatusCode != http.StatusUnauthorized {
		closeRequestBody(req) // the replay body will not be sent
		return res, err
	}
	// Only the challenge header is needed; release the 401 body now.
	chal := res.Header.Get("WWW-Authenticate")
	res.Body.Close()

	c, err := parseChallenge(chal)
	if err != nil {
		closeRequestBody(req)
		return nil, err
	}
	// Generate credentials based on the challenge.
	auth, err := t.authenticate(req, c).authorize()
	if err != nil {
		closeRequestBody(req)
		return nil, err
	}
	// Make authenticated request.
	req.Header.Set("Authorization", auth)
	return t.Transport.RoundTrip(req)
}

// closeRequestBody closes r.Body if r has one. RoundTrip uses it to close
// bodies it will not send, including on error paths.
func closeRequestBody(r *http.Request) {
	if r.Body != nil {
		r.Body.Close()
	}
}
// challenge holds the directives of a Digest WWW-Authenticate header, as
// extracted by parseChallenge: quoted values are stored without their
// surrounding quotes, and parseChallenge presets algorithm to "MD5" so it
// keeps that value when the header has no algorithm directive.
type challenge struct {
	realm, domain, nonce, opaque, stale, algorithm, qop string
}
// parseChallenge parses the value of a WWW-Authenticate header carrying a
// single Digest challenge, for example
//
//	Digest realm="testrealm@host.com", qop="auth,auth-int", nonce="dcd98b..."
//
// The scheme and directive names are matched case-insensitively. Values may be
// tokens or quoted-strings. Quoted-strings may contain commas and
// backslash-escaped characters, and are stored without their quotes or
// escapes. Unrecognized directives (such as charset or userhash) are ignored,
// as RFC 2617 §3.2.1 requires. algorithm defaults to "MD5" when absent.
// Malformed input yields ErrBadChallenge.
func parseChallenge(input string) (*challenge, error) {
	const ws = " \n\r\t"
	const scheme = "Digest"
	s := strings.Trim(input, ws)
	if len(s) <= len(scheme) || !strings.EqualFold(s[:len(scheme)], scheme) ||
		strings.IndexByte(ws, s[len(scheme)]) < 0 {
		return nil, ErrBadChallenge
	}
	params, err := parseAuthParams(s[len(scheme):])
	if err != nil {
		return nil, err
	}
	c := &challenge{
		algorithm: "MD5",
	}
	for name, value := range params {
		switch name {
		case "realm":
			c.realm = value
		case "domain":
			c.domain = value
		case "nonce":
			c.nonce = value
		case "opaque":
			c.opaque = value
		case "stale":
			c.stale = value
		case "algorithm":
			c.algorithm = value
		case "qop":
			c.qop = value
		}
	}
	return c, nil
}

// parseAuthParams splits a comma-separated list of auth-params
// (name=token or name="quoted-string") into a map keyed by lowercased name.
// Empty list elements are skipped, and a repeated name keeps its last value.
// A missing '=', an unterminated quoted-string, or garbage after a
// quoted-string yields ErrBadChallenge.
func parseAuthParams(s string) (map[string]string, error) {
	const ws = " \n\r\t"
	params := make(map[string]string)
	for {
		s = strings.TrimLeft(s, ws+",")
		if s == "" {
			return params, nil
		}
		eq := strings.IndexByte(s, '=')
		if eq < 0 {
			return nil, ErrBadChallenge
		}
		name := strings.ToLower(strings.TrimRight(s[:eq], ws))
		if name == "" || strings.ContainsAny(name, ws+`,"`) {
			return nil, ErrBadChallenge
		}
		s = strings.TrimLeft(s[eq+1:], ws)

		var value string
		if strings.HasPrefix(s, `"`) {
			buf := make([]byte, 0, len(s))
			i := 1
			for i < len(s) && s[i] != '"' {
				if s[i] == '\\' && i+1 < len(s) {
					i++ // quoted-pair: take the next byte literally
				}
				buf = append(buf, s[i])
				i++
			}
			if i == len(s) {
				return nil, ErrBadChallenge // unterminated quoted-string
			}
			value = string(buf)
			s = strings.TrimLeft(s[i+1:], ws)
			if s != "" && s[0] != ',' {
				return nil, ErrBadChallenge
			}
		} else {
			end := strings.IndexByte(s, ',')
			if end < 0 {
				end = len(s)
			}
			value = strings.TrimRight(s[:end], ws)
			s = s[end:]
		}
		params[name] = value
	}
}
// credentials holds everything needed to compute and format a digest
// Authorization header for one request.
type credentials struct {
	userhash  bool
	username  string
	realm     string
	nonce     string
	uri       string
	algorithm string
	cnonce    string
	opaque    string
	qop       string
	nc        int
	method    string
	password  string
}

// h returns the lowercase hex MD5 digest of data.
func h(data string) string {
	return fmt.Sprintf("%x", md5.Sum([]byte(data)))
}

// ha1 returns H(username:realm:password), the A1 hash for algorithm MD5.
func (c *credentials) ha1() string {
	return h(fmt.Sprintf("%s:%s:%s", c.username, c.realm, c.password))
}

// ha2 returns H(method:uri), the A2 hash used for qop "auth" or no qop.
func (c *credentials) ha2() string {
	return h(fmt.Sprintf("%s:%s", c.method, c.uri))
}

// response increments the nonce count and returns the request digest.
//
// With qop "auth" the digest also covers the nonce count and a client nonce.
// cnonce is used when non-empty; otherwise 8 random bytes are read and
// hex-encoded, and a failure to read them is returned as an error. Either
// way the client nonce is stored in c.cnonce so authorize can send it. With
// no qop the digest is H(HA1:nonce:HA2). Any other qop yields
// ErrAlgNotImplemented.
func (c *credentials) response(cnonce string) (string, error) {
	c.nc++
	switch c.qop {
	case "auth":
		if cnonce == "" {
			b := make([]byte, 8)
			if _, err := io.ReadFull(rand.Reader, b); err != nil {
				return "", err
			}
			cnonce = fmt.Sprintf("%x", b)
		}
		c.cnonce = cnonce
		return h(fmt.Sprintf("%s:%s:%08x:%s:%s:%s",
			c.ha1(), c.nonce, c.nc, c.cnonce, c.qop, c.ha2())), nil
	case "":
		return h(fmt.Sprintf("%s:%s:%s", c.ha1(), c.nonce, c.ha2())), nil
	default:
		return "", ErrAlgNotImplemented
	}
}

// authorize computes a fresh response and returns the complete value of the
// Authorization header ("Digest username=..., ...").
//
// Only algorithm MD5 (matched case-insensitively) is supported, with qop
// "auth" or no qop. Anything else yields ErrAlgNotImplemented.
//
// Per RFC 7616 §3.4, algorithm, qop and nc are sent as bare tokens, and
// string values are sent as escaped quoted-strings. cnonce and nc are only
// sent when a qop is in use (RFC 2617 §3.2.2).
func (c *credentials) authorize() (string, error) {
	if !strings.EqualFold(c.algorithm, "MD5") {
		return "", ErrAlgNotImplemented
	}
	if c.qop != "auth" && c.qop != "" {
		return "", ErrAlgNotImplemented
	}
	digest, err := c.response("")
	if err != nil {
		return "", err
	}
	sl := []string{
		fmt.Sprintf(`username="%s"`, quoteDirective(c.username)),
		fmt.Sprintf(`realm="%s"`, quoteDirective(c.realm)),
		fmt.Sprintf(`nonce="%s"`, quoteDirective(c.nonce)),
		fmt.Sprintf(`uri="%s"`, quoteDirective(c.uri)),
		fmt.Sprintf(`response="%s"`, digest),
		fmt.Sprintf(`algorithm=%s`, c.algorithm),
	}
	if c.opaque != "" {
		sl = append(sl, fmt.Sprintf(`opaque="%s"`, quoteDirective(c.opaque)))
	}
	if c.qop != "" {
		sl = append(sl,
			fmt.Sprintf(`cnonce="%s"`, quoteDirective(c.cnonce)),
			fmt.Sprintf(`qop=%s`, c.qop),
			fmt.Sprintf("nc=%08x", c.nc))
	}
	if c.userhash {
		sl = append(sl, `userhash=true`)
	}
	return fmt.Sprintf("Digest %s", strings.Join(sl, ", ")), nil
}

// quoteDirective escapes backslashes and double quotes so s can be placed
// inside a quoted-string directive value.
func quoteDirective(s string) string {
	if !strings.ContainsAny(s, `"\`) {
		return s
	}
	return strings.Replace(strings.Replace(s, `\`, `\\`, -1), `"`, `\"`, -1)
}
// authenticate builds the credentials for answering challenge c on behalf of
// req, using the Transport's Username and Password.
//
// If the server offers several qop options (e.g. "auth,auth-int"), "auth" is
// chosen when it is among them. Otherwise the server's value is kept so that
// authorize can reject it. An empty request method is treated as GET, which
// is what net/http sends for client requests.
func (t *Transport) authenticate(req *http.Request, c *challenge) *credentials {
	method := req.Method
	if method == "" {
		method = http.MethodGet
	}
	return &credentials{
		username:  t.Username,
		realm:     c.realm,
		nonce:     c.nonce,
		uri:       req.URL.RequestURI(),
		algorithm: c.algorithm,
		opaque:    c.opaque,
		qop:       chooseQop(c.qop),
		method:    method,
		password:  t.Password,
	}
}

// chooseQop returns "auth" if it appears in the comma-separated qop-options
// list (compared case-insensitively, ignoring surrounding whitespace), and
// the list unchanged otherwise.
func chooseQop(options string) string {
	for _, opt := range strings.Split(options, ",") {
		if strings.EqualFold(strings.TrimSpace(opt), "auth") {
			return "auth"
		}
	}
	return options
}