panic for MustRegister errors
diff --git a/v2/delay/delay.go b/v2/delay/delay.go
index 0b0393d..f9f3fa9 100644
--- a/v2/delay/delay.go
+++ b/v2/delay/delay.go
@@ -182,11 +182,11 @@
func MustRegister(key string, i interface{}) *Function {
f, err := registerFunction(key, i)
if err != nil {
- return f
+ panic(err)
}
if old := funcs[f.key]; old != nil {
- panic(fmt.Sprintf("multiple functions registered for %q", key))
+ panic(errors.New(fmt.Sprintf("multiple functions registered for %q", key)))
}
funcs[f.key] = f
return f
diff --git a/v2/delay/delay_test.go b/v2/delay/delay_test.go
index c20d9d8..7708bef 100644
--- a/v2/delay/delay_test.go
+++ b/v2/delay/delay_test.go
@@ -42,9 +42,6 @@
}
var (
- invalidFunc = Func("invalid", func() {})
- invalidRegister = MustRegister("invalid", func() {})
-
regFRuns = 0
regFMsg = ""
regF = func(c context.Context, arg string) {
@@ -101,13 +98,13 @@
dupeWhich = 1
}
}
- dupe1Func = Func("dupe", dupe1F)
- dupe2F = func(c context.Context) {
+ dupe1Func = Func("dupe", dupe1F)
+ dupe2F = func(c context.Context) {
if dupeWhich == 0 {
dupeWhich = 2
}
}
- dupe2Func = Func("dupe", dupe2F)
+ dupe2Func = Func("dupe", dupe2F)
requestFuncRuns = 0
requestFuncHeaders *taskqueue.RequestHeaders
@@ -151,11 +148,10 @@
func TestInvalidFunction(t *testing.T) {
c := newFakeContext()
+ invalidFunc := Func("invalid", func() {})
- for _, testTarget := range []*Function{invalidFunc, invalidRegister} {
- if got, want := testTarget.Call(c.ctx), fmt.Errorf("delay: func is invalid: %s", errFirstArg); got.Error() != want.Error() {
- t.Errorf("Incorrect error: got %q, want %q", got, want)
- }
+ if got, want := invalidFunc.Call(c.ctx), fmt.Errorf("delay: func is invalid: %s", errFirstArg); got.Error() != want.Error() {
+ t.Errorf("Incorrect error: got %q, want %q", got, want)
}
}
@@ -425,13 +421,33 @@
func TestMustRegisterDuplicateFunction(t *testing.T) {
MustRegister("dupe", dupe1F)
defer func() {
- if r := recover(); r == nil {
+ err := recover()
+ if err == nil {
t.Error("MustRegister did not panic")
}
+ got := fmt.Sprintf("%s", err)
+ want := fmt.Sprintf("multiple functions registered for %q", "dupe")
+ if got != want {
+ t.Errorf("Incorrect error: got %q, want %q", got, want)
+ }
}()
MustRegister("dupe", dupe2F)
}
+func TestInvalidFunction_MustRegister(t *testing.T) {
+ defer func() {
+ err := recover()
+ if err == nil {
+ t.Error("MustRegister did not panic")
+ }
+ if err != errFirstArg {
+ t.Errorf("Incorrect error: got %q, want %q", err, errFirstArg)
+ }
+ }()
+ MustRegister("invalid", func() {})
+}
+
+
func TestGetRequestHeadersFromContext(t *testing.T) {
for _, testTarget := range []*Function{requestFunc, requestRegister} {
c := newFakeContext()