Improve the behavior of null decoding (#681).
A null value in v3 is considered a request to maintain the default
untouched. If the value being decoded into is a map and there's no
prior value for the field, a new key will be added with the zero
map value type as its value.
diff --git a/decode.go b/decode.go
index 21c0dac..df36e3a 100644
--- a/decode.go
+++ b/decode.go
@@ -399,7 +399,7 @@
//
// If n holds a null value, prepare returns before doing anything.
func (d *decoder) prepare(n *Node, out reflect.Value) (newout reflect.Value, unmarshaled, good bool) {
- if n.ShortTag() == nullTag || n.Kind == 0 && n.IsZero() {
+ if n.ShortTag() == nullTag {
return out, false, false
}
again := true
@@ -808,8 +808,10 @@
}
}
+ mapIsNew := false
if out.IsNil() {
out.Set(reflect.MakeMap(outt))
+ mapIsNew = true
}
for i := 0; i < l; i += 2 {
if isMerge(n.Content[i]) {
@@ -826,7 +828,7 @@
failf("invalid map key: %#v", k.Interface())
}
e := reflect.New(et).Elem()
- if d.unmarshal(n.Content[i+1], e) {
+ if d.unmarshal(n.Content[i+1], e) || n.Content[i+1].ShortTag() == nullTag && (mapIsNew || !out.MapIndex(k).IsValid()) {
out.SetMapIndex(k, e)
}
}
diff --git a/decode_test.go b/decode_test.go
index 71a848c..5ade550 100644
--- a/decode_test.go
+++ b/decode_test.go
@@ -505,7 +505,7 @@
map[string]*string{"foo": nil},
}, {
"foo: null",
- map[string]string{},
+ map[string]string{"foo": ""},
}, {
"foo: null",
map[string]interface{}{"foo": nil},
@@ -517,7 +517,7 @@
map[string]*string{"foo": nil},
}, {
"foo: ~",
- map[string]string{},
+ map[string]string{"foo": ""},
}, {
"foo: ~",
map[string]interface{}{"foo": nil},
@@ -1436,29 +1436,51 @@
}
}
-var unmarshalNullTests = []func() interface{}{
+var unmarshalNullTests = []struct{ input string; pristine, expected func() interface{} }{{
+ "null",
func() interface{} { var v interface{}; v = "v"; return &v },
+ func() interface{} { var v interface{}; v = nil; return &v },
+}, {
+ "null",
func() interface{} { var s = "s"; return &s },
+ func() interface{} { var s = "s"; return &s },
+}, {
+ "null",
func() interface{} { var s = "s"; sptr := &s; return &sptr },
+ func() interface{} { var sptr *string; return &sptr },
+}, {
+ "null",
func() interface{} { var i = 1; return &i },
+ func() interface{} { var i = 1; return &i },
+}, {
+ "null",
func() interface{} { var i = 1; iptr := &i; return &iptr },
- func() interface{} { m := map[string]int{"s": 1}; return &m },
- func() interface{} { m := map[string]int{"s": 1}; return m },
-}
+ func() interface{} { var iptr *int; return &iptr },
+}, {
+ "null",
+ func() interface{} { var m = map[string]int{"s": 1}; return &m },
+ func() interface{} { var m map[string]int; return &m },
+}, {
+ "null",
+ func() interface{} { var m = map[string]int{"s": 1}; return m },
+ func() interface{} { var m = map[string]int{"s": 1}; return m },
+}, {
+ "s2: null\ns3: null",
+ func() interface{} { var m = map[string]int{"s1": 1, "s2": 2}; return m },
+ func() interface{} { var m = map[string]int{"s1": 1, "s2": 2, "s3": 0}; return m },
+}, {
+ "s2: null\ns3: null",
+ func() interface{} { var m = map[string]interface{}{"s1": 1, "s2": 2}; return m },
+ func() interface{} { var m = map[string]interface{}{"s1": 1, "s2": nil, "s3": nil}; return m },
+}}
func (s *S) TestUnmarshalNull(c *C) {
for _, test := range unmarshalNullTests {
- pristine := test()
- decoded := test()
- zero := reflect.Zero(reflect.TypeOf(decoded).Elem()).Interface()
- err := yaml.Unmarshal([]byte("null"), decoded)
+ pristine := test.pristine()
+ expected := test.expected()
+ err := yaml.Unmarshal([]byte(test.input), pristine)
c.Assert(err, IsNil)
- switch pristine.(type) {
- case *interface{}, **string, **int, *map[string]int:
- c.Assert(reflect.ValueOf(decoded).Elem().Interface(), DeepEquals, zero)
- default:
- c.Assert(reflect.ValueOf(decoded).Interface(), DeepEquals, pristine)
- }
+ c.Assert(pristine, DeepEquals, expected)
}
}
diff --git a/node_test.go b/node_test.go
index 9927dad..147594b 100644
--- a/node_test.go
+++ b/node_test.go
@@ -2688,6 +2688,9 @@
c.Assert(n.Decode(&v), IsNil)
c.Assert(v, IsNil)
+ // ... and even when looking for its tag.
+ c.Assert(n.ShortTag(), Equals, "!!null")
+
// Kind zero is still unknown, though.
n.Line = 1
_, err = yaml.Marshal(&n)
diff --git a/yaml.go b/yaml.go
index 56e8a84..8cec6da 100644
--- a/yaml.go
+++ b/yaml.go
@@ -449,6 +449,11 @@
case ScalarNode:
tag, _ := resolve("", n.Value)
return tag
+ case 0:
+ // Special case to make the zero value convenient.
+ if n.IsZero() {
+ return nullTag
+ }
}
return ""
}