| // Copyright 2018 The gVisor Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // Package state provides functionality related to saving and loading object |
| // graphs. For most types, it provides a set of default saving / loading logic |
| // that will be invoked automatically if custom logic is not defined. |
| // |
| // Kind Support |
| // ---- ------- |
| // Bool default |
| // Int default |
| // Int8 default |
| // Int16 default |
| // Int32 default |
| // Int64 default |
| // Uint default |
| // Uint8 default |
| // Uint16 default |
| // Uint32 default |
| // Uint64 default |
| // Float32 default |
| // Float64 default |
| // Complex64 custom |
| // Complex128 custom |
| // Array default |
| // Chan custom |
| // Func custom |
| // Interface custom |
| // Map default (*) |
| // Ptr default |
| // Slice default |
| // String default |
| // Struct custom |
| // UnsafePointer custom |
| // |
| // (*) Maps are treated as value types by this package, even if they are |
| // pointers internally. If you want to save two independent references |
| // to the same map value, you must explicitly use a pointer to a map. |
| package state |
| |
| import ( |
| "context" |
| "fmt" |
| "io" |
| "reflect" |
| "runtime" |
| |
| pb "gvisor.dev/gvisor/pkg/state/object_go_proto" |
| ) |
| |
// ErrState is returned when an error is encountered during encode/decode.
//
// Besides the underlying error, it records the visit path from the root
// object and a stack trace, both captured when the failure was recovered.
type ErrState struct {
	// err is the underlying error.
	err error

	// path is the visit path from root to the current object.
	path string

	// trace is the stack trace.
	trace string
}

// Error returns a sensible description of the state error: the underlying
// error, followed by the state path and the captured stack trace.
func (e *ErrState) Error() string {
	return fmt.Sprintf("%v:\nstate path: %s\n%s", e.err, e.path, e.trace)
}

// Unwrap returns the underlying error so that errors.Is and errors.As can
// see through an ErrState, including one wrapped by another error.
func (e *ErrState) Unwrap() error {
	return e.err
}
| |
| // UnwrapErrState returns the underlying error in ErrState. |
| // |
| // If err is not *ErrState, err is returned directly. |
| func UnwrapErrState(err error) error { |
| if e, ok := err.(*ErrState); ok { |
| return e.err |
| } |
| return err |
| } |
| |
| // Save saves the given object state. |
| func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error { |
| // Create the encoding state. |
| es := &encodeState{ |
| ctx: ctx, |
| idsByObject: make(map[uintptr]uint64), |
| w: w, |
| stats: stats, |
| } |
| |
| // Perform the encoding. |
| return es.safely(func() { |
| es.Serialize(reflect.ValueOf(rootPtr).Elem()) |
| }) |
| } |
| |
| // Load loads a checkpoint. |
| func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error { |
| // Create the decoding state. |
| ds := &decodeState{ |
| ctx: ctx, |
| objectsByID: make(map[uint64]*objectState), |
| deferred: make(map[uint64]*pb.Object), |
| r: r, |
| stats: stats, |
| } |
| |
| // Attempt our decode. |
| return ds.safely(func() { |
| ds.Deserialize(reflect.ValueOf(rootPtr).Elem()) |
| }) |
| } |
| |
| // Fns are the state dispatch functions. |
| type Fns struct { |
| // Save is a function like Save(concreteType, Map). |
| Save interface{} |
| |
| // Load is a function like Load(concreteType, Map). |
| Load interface{} |
| } |
| |
| // Save executes the save function. |
| func (fns *Fns) invokeSave(obj reflect.Value, m Map) { |
| reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)}) |
| } |
| |
| // Load executes the load function. |
| func (fns *Fns) invokeLoad(obj reflect.Value, m Map) { |
| reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)}) |
| } |
| |
| // validateStateFn ensures types are correct. |
| func validateStateFn(fn interface{}, typ reflect.Type) bool { |
| fnTyp := reflect.TypeOf(fn) |
| if fnTyp.Kind() != reflect.Func { |
| return false |
| } |
| if fnTyp.NumIn() != 2 { |
| return false |
| } |
| if fnTyp.NumOut() != 0 { |
| return false |
| } |
| if fnTyp.In(0) != typ { |
| return false |
| } |
| if fnTyp.In(1) != reflect.TypeOf(Map{}) { |
| return false |
| } |
| return true |
| } |
| |
| // Validate validates all state functions. |
| func (fns *Fns) Validate(typ reflect.Type) bool { |
| return validateStateFn(fns.Save, typ) && validateStateFn(fns.Load, typ) |
| } |
| |
| type typeDatabase struct { |
| // nameToType is a forward lookup table. |
| nameToType map[string]reflect.Type |
| |
| // typeToName is the reverse lookup table. |
| typeToName map[reflect.Type]string |
| |
| // typeToFns is the function lookup table. |
| typeToFns map[reflect.Type]Fns |
| } |
| |
| // registeredTypes is a database used for SaveInterface and LoadInterface. |
| var registeredTypes = typeDatabase{ |
| nameToType: make(map[string]reflect.Type), |
| typeToName: make(map[reflect.Type]string), |
| typeToFns: make(map[reflect.Type]Fns), |
| } |
| |
| // register registers a type under the given name. This will generally be |
| // called via init() methods, and therefore uses panic to propagate errors. |
| func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) { |
| // We can't allow name collisions. |
| if ot, ok := t.nameToType[name]; ok { |
| panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.Name(), name, ot.Name())) |
| } |
| |
| // Or multiple registrations. |
| if on, ok := t.typeToName[typ]; ok { |
| panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.Name(), name, on)) |
| } |
| |
| t.nameToType[name] = typ |
| t.typeToName[typ] = name |
| t.typeToFns[typ] = fns |
| } |
| |
| // lookupType finds a type given a name. |
| func (t *typeDatabase) lookupType(name string) (reflect.Type, bool) { |
| typ, ok := t.nameToType[name] |
| return typ, ok |
| } |
| |
| // lookupName finds a name given a type. |
| func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) { |
| name, ok := t.typeToName[typ] |
| return name, ok |
| } |
| |
| // lookupFns finds functions given a type. |
| func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) { |
| fns, ok := t.typeToFns[typ] |
| return fns, ok |
| } |
| |
| // Register must be called for any interface implementation types that |
| // implements Loader. |
| // |
| // Register should be called either immediately after startup or via init() |
| // methods. Double registration of either names or types will result in a panic. |
| // |
| // No synchronization is provided; this should only be called in init. |
| // |
| // Example usage: |
| // |
| // state.Register("Foo", (*Foo)(nil), state.Fns{ |
| // Save: (*Foo).Save, |
| // Load: (*Foo).Load, |
| // }) |
| // |
| func Register(name string, instance interface{}, fns Fns) { |
| registeredTypes.register(name, reflect.TypeOf(instance), fns) |
| } |
| |
// IsZeroValue reports whether val is nil, or deeply equal (in the sense of
// reflect.DeepEqual) to the zero value of its dynamic type.
//
// This function is used by the stateify tool.
func IsZeroValue(val interface{}) bool {
	if val != nil {
		zero := reflect.Zero(reflect.TypeOf(val)).Interface()
		return reflect.DeepEqual(val, zero)
	}
	return true
}
| |
// step captures one encoding / decoding step. On each step, there is up to one
// choice made, which is captured by a non-nil param. We intentionally do not
// eagerly create the final path string, as that will only be needed upon panic.
type step struct {
	// dereference indicates whether the current object is obtained by
	// dereferencing a pointer. When the path is rendered, the path built so
	// far is wrapped as "*(...)" for such a step.
	dereference bool

	// format is the formatting string that takes param below, if param is
	// non-nil; otherwise format is appended to the path verbatim. For
	// example, in the array indexing case, we have "[%d]".
	format string

	// param stores the choice made at the current encoding / decoding step.
	// For example, in the array indexing case, param stores the index. When no
	// choice is made, e.g. dereference, param should be nil.
	param interface{}
}
| |
// recoverable is the state encoding / decoding panic recovery facility. It is
// also used to store encoding / decoding steps as well as the reference to the
// original queued object from which the current object is dispatched. The
// complete encoding / decoding path is synthesised from the steps in all queued
// objects leading to the current object.
type recoverable struct {
	// from links to the recoverable of the queued object from which the
	// current object was dispatched. A recoverable with a nil from is
	// rendered as "root" by path, and its own steps are not rendered.
	from *recoverable

	// steps are the steps recorded at this level, oldest first (push appends,
	// pop removes from the end).
	steps []step
}
| |
| // push enters a new context level. |
| func (sr *recoverable) push(dereference bool, format string, param interface{}) { |
| sr.steps = append(sr.steps, step{dereference, format, param}) |
| } |
| |
| // pop exits the current context level. |
| func (sr *recoverable) pop() { |
| if len(sr.steps) <= 1 { |
| return |
| } |
| sr.steps = sr.steps[:len(sr.steps)-1] |
| } |
| |
| // path returns the complete encoding / decoding path from root. This is only |
| // called upon panic. |
| func (sr *recoverable) path() string { |
| if sr.from == nil { |
| return "root" |
| } |
| p := sr.from.path() |
| for _, s := range sr.steps { |
| if s.dereference { |
| p = fmt.Sprintf("*(%s)", p) |
| } |
| if s.param == nil { |
| p += s.format |
| } else { |
| p += fmt.Sprintf(s.format, s.param) |
| } |
| } |
| return p |
| } |
| |
| func (sr *recoverable) copy() recoverable { |
| return recoverable{from: sr.from, steps: append([]step(nil), sr.steps...)} |
| } |
| |
| // safely executes the given function, catching a panic and unpacking as an error. |
| // |
| // The error flow through the state package uses panic and recover. There are |
| // two important reasons for this: |
| // |
| // 1) Many of the reflection methods will already panic with invalid data or |
| // violated assumptions. We would want to recover anyways here. |
| // |
| // 2) It allows us to eliminate boilerplate within Save() and Load() functions. |
| // In nearly all cases, when the low-level serialization functions fail, you |
| // will want the checkpoint to fail anyways. Plumbing errors through every |
| // method doesn't add a lot of value. If there are specific error conditions |
| // that you'd like to handle, you should add appropriate functionality to |
| // objects themselves prior to calling Save() and Load(). |
| func (sr *recoverable) safely(fn func()) (err error) { |
| defer func() { |
| if r := recover(); r != nil { |
| es := new(ErrState) |
| if e, ok := r.(error); ok { |
| es.err = e |
| } else { |
| es.err = fmt.Errorf("%v", r) |
| } |
| |
| es.path = sr.path() |
| |
| // Make a stack. We don't know how big it will be ahead |
| // of time, but want to make sure we get the whole |
| // thing. So we just do a stupid brute force approach. |
| var stack []byte |
| for sz := 1024; ; sz *= 2 { |
| stack = make([]byte, sz) |
| n := runtime.Stack(stack, false) |
| if n < sz { |
| es.trace = string(stack[:n]) |
| break |
| } |
| } |
| |
| // Set the error. |
| err = es |
| } |
| }() |
| |
| // Execute the function. |
| fn() |
| return nil |
| } |