Merge pull request #16 from reinventer/improvements

Case-insensitive algorithm names and a small optimization
diff --git a/authorization.go b/authorization.go
index 9f8d8bc..741fd65 100644
--- a/authorization.go
+++ b/authorization.go
@@ -9,7 +9,6 @@
 	"hash"
 	"io"
 	"net/url"
-	"regexp"
 	"strings"
 	"time"
 )
@@ -49,6 +48,13 @@
 	return ah.refreshAuthorization(dr)
 }
 
+const ( // upper-case algorithm names, compared against strings.ToUpper(ah.Algorithm)
+	algorithmMD5        = "MD5"          // MD5 hashing; an empty Algorithm hashes the same way
+	algorithmMD5Sess    = "MD5-SESS"     // MD5 hashing; A1 also folds in nonce and cnonce
+	algorithmSHA256     = "SHA-256"      // SHA-256 hashing
+	algorithmSHA256Sess = "SHA-256-SESS" // SHA-256 hashing; A1 also folds in nonce and cnonce
+)
+
 func (ah *authorization) refreshAuthorization(dr *DigestRequest) (*authorization, error) {
 
 	ah.Username = dr.Username
@@ -82,11 +88,13 @@
 
 func (ah *authorization) computeA1(dr *DigestRequest) string {
 
-	if ah.Algorithm == "" || ah.Algorithm == "MD5" || ah.Algorithm == "SHA-256" {
+	algorithm := strings.ToUpper(ah.Algorithm)
+
+	if algorithm == "" || algorithm == algorithmMD5 || algorithm == algorithmSHA256 {
 		return fmt.Sprintf("%s:%s:%s", ah.Username, ah.Realm, dr.Password)
 	}
 
-	if ah.Algorithm == "MD5-sess" || ah.Algorithm == "SHA-256-sess" {
+	if algorithm == algorithmMD5Sess || algorithm == algorithmSHA256Sess {
 		upHash := ah.hash(fmt.Sprintf("%s:%s:%s", ah.Username, ah.Realm, dr.Password))
 		return fmt.Sprintf("%s:%s:%s", upHash, ah.Nonce, ah.Cnonce)
 	}
@@ -96,7 +104,7 @@
 
 func (ah *authorization) computeA2(dr *DigestRequest) string {
 
-	if matched, _ := regexp.MatchString("auth-int", dr.Wa.Qop); matched {
+	if strings.Contains(dr.Wa.Qop, "auth-int") {
 		ah.Qop = "auth-int"
 		return fmt.Sprintf("%s:%s:%s", dr.Method, ah.URI, ah.hash(dr.Body))
 	}
@@ -109,20 +117,21 @@
 	return ""
 }
 
-func (ah *authorization) hash(a string) (s string) {
-
+// hash returns the hex digest of a using the MD5 or SHA-256 family selected by
+// ah.Algorithm (matched case-insensitively), or "" when the algorithm is unknown.
+func (ah *authorization) hash(a string) string {
 	var h hash.Hash
-
-	if ah.Algorithm == "" || ah.Algorithm == "MD5" || ah.Algorithm == "MD5-sess" {
+	switch strings.ToUpper(ah.Algorithm) {
+	case "", algorithmMD5, algorithmMD5Sess:
 		h = md5.New()
-	} else if ah.Algorithm == "SHA-256" || ah.Algorithm == "SHA-256-sess" {
+	case algorithmSHA256, algorithmSHA256Sess:
 		h = sha256.New()
+	default:
+		return ""
 	}
 
 	io.WriteString(h, a)
-	s = hex.EncodeToString(h.Sum(nil))
-
-	return
+	return hex.EncodeToString(h.Sum(nil))
 }
 
 func (ah *authorization) toString() string {
diff --git a/authorization_test.go b/authorization_test.go
new file mode 100644
index 0000000..76cb721
--- /dev/null
+++ b/authorization_test.go
@@ -0,0 +1,174 @@
+package digest_auth_client
+
+import "testing"
+
+// TestHash checks that hash picks MD5 or SHA-256 from ah.Algorithm case-insensitively,
+// including the -sess variants, and returns "" for an unknown algorithm.
+func TestHash(t *testing.T) {
+	testCases := []struct {
+		name, algorithm, expRes string
+	}{
+		{
+			name:      "empty algorithm",
+			algorithm: "",
+			expRes:    "1a79a4d60de6718e8e5b326e338ae533",
+		},
+		{
+			name:      "MD5 algorithm",
+			algorithm: "MD5",
+			expRes:    "1a79a4d60de6718e8e5b326e338ae533",
+		},
+		{
+			name:      "MD5-sess algorithm",
+			algorithm: "MD5-sess",
+			expRes:    "1a79a4d60de6718e8e5b326e338ae533",
+		},
+		{
+			name:      "SHA256 algorithm",
+			algorithm: "SHA-256",
+			expRes:    "50d858e0985ecc7f60418aaf0cc5ab587f42c2570a884095a9e8ccacd0f6545c",
+		},
+		{
+			name:      "SHA256-sess algorithm",
+			algorithm: "SHA-256-sess",
+			expRes:    "50d858e0985ecc7f60418aaf0cc5ab587f42c2570a884095a9e8ccacd0f6545c",
+		},
+		{
+			name:      "md5 algorithm",
+			algorithm: "md5",
+			expRes:    "1a79a4d60de6718e8e5b326e338ae533",
+		},
+		{
+			name:      "unknown algorithm",
+			algorithm: "unknown",
+			expRes:    "",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			ah := &authorization{Algorithm: tc.algorithm}
+			res := ah.hash("example")
+			if res != tc.expRes {
+				t.Errorf("got: %q, want: %q", res, tc.expRes)
+			}
+		})
+	}
+}
+
+// TestComputeA1 checks A1: username:realm:password for plain algorithms,
+// H(username:realm:password):nonce:cnonce for -sess variants, "" if unknown.
+func TestComputeA1(t *testing.T) {
+	testCases := []struct {
+		name, algorithm, expRes string
+	}{
+		{
+			name:      "empty algorithm",
+			algorithm: "",
+			expRes:    "username:realm:secret",
+		},
+		{
+			name:      "MD5 algorithm",
+			algorithm: "MD5",
+			expRes:    "username:realm:secret",
+		},
+		{
+			name:      "MD5-sess algorithm",
+			algorithm: "MD5-sess",
+			expRes:    (&authorization{Algorithm: "MD5"}).hash("username:realm:secret") + ":nonce:cnonce",
+		},
+		{
+			name:      "SHA256 algorithm",
+			algorithm: "SHA-256",
+			expRes:    "username:realm:secret",
+		},
+		{
+			name:      "SHA256-sess algorithm",
+			algorithm: "SHA-256-sess",
+			expRes:    (&authorization{Algorithm: "SHA-256"}).hash("username:realm:secret") + ":nonce:cnonce",
+		},
+		{
+			name:      "md5 algorithm",
+			algorithm: "md5",
+			expRes:    "username:realm:secret",
+		},
+		{
+			name:      "unknown algorithm",
+			algorithm: "unknown",
+			expRes:    "",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			dr := &DigestRequest{Password: "secret"}
+			ah := &authorization{
+				Algorithm: tc.algorithm,
+				Nonce:     "nonce",
+				Cnonce:    "cnonce",
+				Username:  "username",
+				Realm:     "realm",
+			}
+			res := ah.computeA1(dr)
+			if res != tc.expRes {
+				t.Errorf("got: %q, want: %q", res, tc.expRes)
+			}
+		})
+	}
+}
+
+// TestComputeA2 checks that computeA2 selects auth-int (hashing the body) whenever
+// the challenge's qop list mentions auth-int, and otherwise yields method:uri
+// with ah.Qop set to auth.
+func TestComputeA2(t *testing.T) {
+	testCases := []struct {
+		name, qop, expRes, expAuthQop string
+	}{
+		{
+			name:       "empty qop",
+			qop:        "",
+			expRes:     "method:uri",
+			expAuthQop: "auth",
+		},
+		{
+			name:       "qop is auth",
+			qop:        "auth",
+			expRes:     "method:uri",
+			expAuthQop: "auth",
+		},
+		{
+			name:       "qop offers auth and auth-int",
+			qop:        "auth,auth-int",
+			expRes:     "method:uri:841a2d689ad86bd1611447453c22c6fc",
+			expAuthQop: "auth-int",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			dr := &DigestRequest{
+				Method: "method",
+				Body:   "body",
+				Wa: &wwwAuthenticate{
+					Qop: tc.qop,
+				},
+			}
+			ah := &authorization{
+				Algorithm: "MD5",
+				Nonce:     "nonce",
+				Cnonce:    "cnonce",
+				Username:  "username",
+				Realm:     "realm",
+				URI:       "uri",
+				Qop:       tc.qop,
+			}
+			res := ah.computeA2(dr)
+			if res != tc.expRes {
+				t.Errorf("wrong result, got: %q, want: %q", res, tc.expRes)
+			}
+			if ah.Qop != tc.expAuthQop {
+				t.Errorf("wrong qop, got: %q, want: %q", ah.Qop, tc.expAuthQop)
+			}
+		})
+	}
+}