blob: a7dd9d82d1d6041692a18ca480fdd965e29500b1 [file] [log] [blame]
// Copyright 2013 Dario Castañé. All rights reserved.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Based on src/pkg/reflect/deepequal.go from official
// golang's stdlib.
package mergo
import (
"reflect"
)
// deepMerge walks dst and src in parallel and copies values from src into dst.
// A leaf value is copied when src is non-empty and either overwrite is set or
// the corresponding dst value is empty (as reported by isEmptyValue). Structs,
// maps, pointers and interfaces are descended into recursively.
//
// visited records the address and type of every addressable dst value that has
// already been entered, which lets the walk terminate on self-referential
// data. depth is the current recursion level and is only forwarded to nested
// calls. The error result only propagates what a nested call returned.
//
// Maps reached through unexported struct fields are skipped: reflection allows
// neither reading their elements through Interface nor writing to them.
func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, overwrite bool) (err error) {
	if !src.IsValid() {
		return nil
	}
	if dst.CanAddr() {
		addr := dst.UnsafeAddr()
		// Bucket key into visited; entries sharing a bucket are chained
		// through visit.next.
		h := 17 * addr
		seen := visited[h]
		typ := dst.Type()
		for p := seen; p != nil; p = p.next {
			if p.ptr == addr && p.typ == typ {
				// This exact value was already entered; stop here to
				// avoid looping forever on cyclic data.
				return nil
			}
		}
		// Remember, remember...
		visited[h] = &visit{addr, typ, seen}
	}
	switch dst.Kind() {
	case reflect.Struct:
		for i, n := 0, dst.NumField(); i < n; i++ {
			if err = deepMerge(dst.Field(i), src.Field(i), visited, depth+1, overwrite); err != nil {
				return err
			}
		}
	case reflect.Map:
		// Values obtained via unexported fields make Interface and
		// SetMapIndex panic, so such maps are left alone.
		if !dst.CanInterface() || !src.CanInterface() {
			break
		}
		for _, key := range src.MapKeys() {
			srcElement := src.MapIndex(key)
			if !srcElement.IsValid() {
				continue
			}
			switch srcElement.Kind() {
			case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Slice:
				if srcElement.IsNil() {
					continue
				}
			}
			dstElement := dst.MapIndex(key)
			// Going through Interface unwraps interface-typed elements, so
			// the kind inspected here is the element's dynamic kind.
			switch reflect.TypeOf(srcElement.Interface()).Kind() {
			case reflect.Struct, reflect.Ptr, reflect.Map:
				if err = deepMerge(dstElement, srcElement, visited, depth+1, overwrite); err != nil {
					return err
				}
			}
			if isEmptyValue(srcElement) {
				continue
			}
			// Without overwrite, an entry is only replaced when it is
			// missing or empty. The emptiness test must look at the entry
			// itself, not at the enclosing map.
			if !overwrite && dstElement.IsValid() && !isEmptyValue(dstElement) {
				continue
			}
			if dst.IsNil() {
				if !dst.CanSet() {
					// A nil map that is not settable (for example one
					// stored as an element of another map) cannot be
					// allocated in place, so nothing can be written here.
					return nil
				}
				dst.Set(reflect.MakeMap(dst.Type()))
			}
			dst.SetMapIndex(key, srcElement)
		}
	case reflect.Ptr, reflect.Interface:
		switch {
		case src.IsNil():
			// Nothing to merge from.
		case dst.IsNil():
			if dst.CanSet() && (overwrite || isEmptyValue(dst)) {
				dst.Set(src)
			}
		default:
			if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, overwrite); err != nil {
				return err
			}
		}
	default:
		if dst.CanSet() && !isEmptyValue(src) && (overwrite || isEmptyValue(dst)) {
			dst.Set(src)
		}
	}
	return nil
}
// Merge copies values from src into those fields of dst that still hold the
// zero value of their type.
// Both arguments must be valid structs of the same type, and dst must be
// passed as a pointer to struct.
// Unexported (private) fields are not merged; exported fields are merged
// recursively.
func Merge(dst, src interface{}) error {
	const overwrite = false
	return merge(dst, src, overwrite)
}
// MergeWithOverwrite merges src into dst with overwriting enabled: non-empty
// values from src replace the corresponding values in dst even when dst
// already holds a value.
func MergeWithOverwrite(dst, src interface{}) error {
	const overwrite = true
	return merge(dst, src, overwrite)
}
// merge is the shared implementation behind Merge and MergeWithOverwrite.
// It turns dst and src into reflect values with resolveValues, returns
// ErrDifferentArgumentsTypes when the two resolved values differ in type, and
// otherwise runs deepMerge from depth 0 with an empty visited table.
func merge(dst, src interface{}, overwrite bool) error {
	var (
		target reflect.Value
		source reflect.Value
	)
	target, source, err := resolveValues(dst, src)
	switch {
	case err != nil:
		return err
	case target.Type() != source.Type():
		return ErrDifferentArgumentsTypes
	}
	return deepMerge(target, source, map[uintptr]*visit{}, 0, overwrite)
}