blob: dbe507ab4f99f445cf8b19e24580b0f527a32ab7 [file] [log] [blame]
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package state provides functionality related to saving and loading object
// graphs. For most types, it provides a set of default saving / loading logic
// that will be invoked automatically if custom logic is not defined.
//
// Kind Support
// ---- -------
// Bool default
// Int default
// Int8 default
// Int16 default
// Int32 default
// Int64 default
// Uint default
// Uint8 default
// Uint16 default
// Uint32 default
// Uint64 default
// Float32 default
// Float64 default
// Complex64 custom
// Complex128 custom
// Array default
// Chan custom
// Func custom
// Interface custom
// Map default (*)
// Ptr default
// Slice default
// String default
// Struct custom
// UnsafePointer custom
//
// (*) Maps are treated as value types by this package, even if they are
// pointers internally. If you want to save two independent references
// to the same map value, you must explicitly use a pointer to a map.
package state
import (
"context"
"fmt"
"io"
"reflect"
"runtime"
pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
)
// ErrState is returned when an error is encountered during encode/decode.
//
// Values are built by recoverable.safely from a recovered panic. Use
// UnwrapErrState to retrieve the underlying error.
type ErrState struct {
	// err is the underlying error: the panic value itself when it implements
	// error, otherwise the panic value formatted with %v.
	err error

	// path is the visit path from root to the current object, as rendered by
	// recoverable.path ("root" when the failure is at the root object).
	path string

	// trace is the stack trace of the goroutine that recovered the panic.
	trace string
}
// Error returns a sensible description of the state error: the underlying
// error, followed by the visit path at which it occurred and the captured
// stack trace, each on its own line.
func (e *ErrState) Error() string {
	return fmt.Sprintf("%v:\nstate path: %s\n%s", e.err, e.path, e.trace)
}

// Unwrap returns the underlying error so that errors.Is, errors.As and
// errors.Unwrap can inspect the cause of a state failure without the
// package-specific UnwrapErrState helper.
func (e *ErrState) Unwrap() error {
	return e.err
}
// UnwrapErrState extracts the error wrapped by an *ErrState.
//
// Any err whose dynamic type is not *ErrState (including nil) is handed back
// to the caller unchanged.
func UnwrapErrState(err error) error {
	stateErr, isState := err.(*ErrState)
	if !isState {
		return err
	}
	return stateErr.err
}
// Save serializes the object graph rooted at the value rootPtr points to and
// writes it to w.
//
// rootPtr must be a pointer; it is dereferenced with reflect's Elem. Encoding
// (including that dereference) runs under es.safely, so panics are reported
// through the returned error instead of propagating. ctx and stats are handed
// to the encoder unchanged.
func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error {
	es := &encodeState{
		ctx:         ctx,
		w:           w,
		stats:       stats,
		idsByObject: make(map[uintptr]uint64),
	}
	serialize := func() {
		root := reflect.ValueOf(rootPtr).Elem()
		es.Serialize(root)
	}
	return es.safely(serialize)
}
// Load reads a checkpoint from r and decodes it into the value rootPtr
// points to.
//
// rootPtr must be a pointer; it is dereferenced with reflect's Elem. Decoding
// (including that dereference) runs under ds.safely, so panics are reported
// through the returned error instead of propagating. ctx and stats are handed
// to the decoder unchanged.
func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error {
	ds := &decodeState{
		ctx:         ctx,
		r:           r,
		stats:       stats,
		objectsByID: make(map[uint64]*objectState),
		deferred:    make(map[uint64]*pb.Object),
	}
	deserialize := func() {
		root := reflect.ValueOf(rootPtr).Elem()
		ds.Deserialize(root)
	}
	return ds.safely(deserialize)
}
// Fns are the state dispatch functions.
//
// Validate checks that each field holds a func(T, Map) with no results, where
// T is the type being registered. The functions are called through reflection
// by invokeSave and invokeLoad.
type Fns struct {
	// Save is a function like Save(concreteType, Map).
	Save interface{}

	// Load is a function like Load(concreteType, Map).
	Load interface{}
}
// invokeSave executes the save function, calling fns.Save(obj, m) through
// reflection.
//
// reflect.Value.Call panics if fns.Save is nil or its signature does not
// accept (obj, m). NOTE(review): presumably callers have already checked the
// functions with Validate — confirm.
func (fns *Fns) invokeSave(obj reflect.Value, m Map) {
	reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)})
}
// invokeLoad executes the load function, calling fns.Load(obj, m) through
// reflection.
//
// reflect.Value.Call panics if fns.Load is nil or its signature does not
// accept (obj, m). NOTE(review): presumably callers have already checked the
// functions with Validate — confirm.
func (fns *Fns) invokeLoad(obj reflect.Value, m Map) {
	reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)})
}
// validateStateFn reports whether fn is a function with exactly the signature
// func(typ, Map) and no results.
//
// A nil fn (for example an unset Fns.Save or Fns.Load) is reported as invalid.
// Previously reflect.TypeOf(nil) returned a nil Type, and calling Kind on it
// panicked.
func validateStateFn(fn interface{}, typ reflect.Type) bool {
	fnTyp := reflect.TypeOf(fn)
	if fnTyp == nil || fnTyp.Kind() != reflect.Func {
		return false
	}
	if fnTyp.NumIn() != 2 {
		return false
	}
	if fnTyp.NumOut() != 0 {
		return false
	}
	if fnTyp.In(0) != typ {
		return false
	}
	if fnTyp.In(1) != reflect.TypeOf(Map{}) {
		return false
	}
	return true
}
// Validate reports whether both fns.Save and fns.Load are functions of the
// form func(typ, Map) with no results. Load is only checked when Save passes.
func (fns *Fns) Validate(typ reflect.Type) bool {
	if !validateStateFn(fns.Save, typ) {
		return false
	}
	return validateStateFn(fns.Load, typ)
}
// typeDatabase maps between registered names, their reflect.Types, and the
// state dispatch functions for each type. register keeps nameToType and
// typeToName as exact inverses of each other.
//
// No locking is performed. NOTE(review): presumably safe because all
// registration happens during init — confirm.
type typeDatabase struct {
	// nameToType is a forward lookup table.
	nameToType map[string]reflect.Type

	// typeToName is the reverse lookup table.
	typeToName map[reflect.Type]string

	// typeToFns is the function lookup table.
	typeToFns map[reflect.Type]Fns
}
// registeredTypes is the package-wide type database consulted by
// SaveInterface and LoadInterface. It starts empty and is filled by Register.
var registeredTypes = typeDatabase{
	typeToFns:  map[reflect.Type]Fns{},
	typeToName: map[reflect.Type]string{},
	nameToType: map[string]reflect.Type{},
}
// register registers a type under the given name. This will generally be
// called via init() methods, and therefore uses panic to propagate errors.
//
// It panics if typ is nil, if name is already used by another type, or if
// typ is already registered under another name. Panic messages identify types
// with reflect.Type.String. Name would be empty for the pointer types that
// Register is normally given, for example (*Foo)(nil).
func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) {
	// A nil type (Register(name, nil, ...)) can never be looked up usefully.
	if typ == nil {
		panic(fmt.Sprintf("can't register a nil type as %q", name))
	}
	// We can't allow name collisions.
	if ot, ok := t.nameToType[name]; ok {
		panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.String(), name, ot.String()))
	}
	// Or multiple registrations.
	if on, ok := t.typeToName[typ]; ok {
		panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.String(), name, on))
	}
	t.nameToType[name] = typ
	t.typeToName[typ] = name
	t.typeToFns[typ] = fns
}
// lookupType returns the type registered under name. ok is false, and typ
// nil, when no type uses that name.
func (t *typeDatabase) lookupType(name string) (typ reflect.Type, ok bool) {
	typ, ok = t.nameToType[name]
	return typ, ok
}
// lookupName returns the name typ was registered under, or "" and false if
// typ has not been registered.
func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) {
	if name, found := t.typeToName[typ]; found {
		return name, true
	}
	return "", false
}
// lookupFns returns the dispatch functions registered for typ. When typ is
// unknown it returns the zero Fns and false.
func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) {
	if fns, found := t.typeToFns[typ]; found {
		return fns, true
	}
	return Fns{}, false
}
// Register records the dynamic type of instance under name, together with
// its save/load functions. It must be called for every interface
// implementation type that implements Loader.
//
// Call Register either immediately after startup or from init() functions.
// Registering the same name or the same type twice panics.
//
// There is no synchronization, so this should only be called during init.
//
// Example usage:
//
//	state.Register("Foo", (*Foo)(nil), state.Fns{
//		Save: (*Foo).Save,
//		Load: (*Foo).Load,
//	})
func Register(name string, instance interface{}, fns Fns) {
	typ := reflect.TypeOf(instance)
	registeredTypes.register(name, typ, fns)
}
// IsZeroValue reports whether val is deeply equal to the zero value of its
// dynamic type. A nil interface is treated as zero.
//
// This function is used by the stateify tool.
func IsZeroValue(val interface{}) bool {
	if val != nil {
		zero := reflect.Zero(reflect.TypeOf(val))
		return reflect.DeepEqual(val, zero.Interface())
	}
	return true
}
// step captures one encoding / decoding step. On each step at most one
// choice is made, and it is captured by a non-nil param. We intentionally do
// not eagerly create the final path string, as that will only be needed upon
// panic (see recoverable.path).
type step struct {
	// dereference indicates whether the current object is obtained by
	// dereferencing a pointer. recoverable.path renders this as "*(...)".
	dereference bool

	// format is the formatting string that takes param below, if non-nil.
	// For example, in the array indexing case, we have "[%d]". When param is
	// nil, format is appended to the path verbatim.
	format string

	// param stores the choice made at the current encoding / decoding step.
	// For example, in the array indexing case, param stores the index. When no
	// choice is made, e.g. dereference, param should be nil.
	param interface{}
}
// recoverable is the state encoding / decoding panic recovery facility. It is
// also used to store encoding / decoding steps as well as the reference to the
// original queued object from which the current object is dispatched. The
// complete encoding / decoding path is synthesised from the steps in all queued
// objects leading to the current object.
type recoverable struct {
	// from is the recoverable this one was dispatched from. When it is nil,
	// path() reports "root" and ignores steps.
	from *recoverable

	// steps are the steps taken within this object, appended by push and
	// trimmed by pop.
	steps []step
}
// push enters a new context level by appending one step built from the
// arguments; see step for the meaning of each field.
func (sr *recoverable) push(dereference bool, format string, param interface{}) {
	next := step{
		dereference: dereference,
		format:      format,
		param:       param,
	}
	sr.steps = append(sr.steps, next)
}
// pop exits the current context level by dropping the most recent step.
//
// The first step is never removed: once steps holds one entry (or none), pop
// does nothing. NOTE(review): presumably this keeps the step recorded when the
// object was dispatched — confirm against callers.
func (sr *recoverable) pop() {
	if n := len(sr.steps); n > 1 {
		sr.steps = sr.steps[:n-1]
	}
}
// path returns the complete encoding / decoding path from root. This is only
// called upon panic.
//
// The path of the originating recoverable comes first. Each local step then
// wraps the result in "*(...)" if it is a dereference, and appends its format,
// expanded with param when param is non-nil. A recoverable with no origin is
// always reported as "root".
func (sr *recoverable) path() string {
	if sr.from == nil {
		return "root"
	}
	result := sr.from.path()
	for i := range sr.steps {
		st := &sr.steps[i]
		if st.dereference {
			result = "*(" + result + ")"
		}
		switch st.param {
		case nil:
			result += st.format
		default:
			result += fmt.Sprintf(st.format, st.param)
		}
	}
	return result
}
// copy returns a recoverable with the same origin and an independent copy of
// the steps, so later push/pop on either value does not affect the other.
// When there are no steps, the copy's steps slice is nil.
func (sr *recoverable) copy() recoverable {
	var dup []step
	if len(sr.steps) > 0 {
		dup = make([]step, len(sr.steps))
		copy(dup, sr.steps)
	}
	return recoverable{steps: dup, from: sr.from}
}
// safely executes the given function, catching a panic and unpacking it as an
// error.
//
// The error flow through the state package uses panic and recover. There are
// two important reasons for this:
//
// 1) Many of the reflection methods will already panic with invalid data or
// violated assumptions. We would want to recover anyways here.
//
// 2) It allows us to eliminate boilerplate within Save() and Load() functions.
// In nearly all cases, when the low-level serialization functions fail, you
// will want the checkpoint to fail anyways. Plumbing errors through every
// method doesn't add a lot of value. If there are specific error conditions
// that you'd like to handle, you should add appropriate functionality to
// objects themselves prior to calling Save() and Load().
//
// On failure the returned error is an *ErrState. It carries the panic value
// (used directly if it is an error, otherwise formatted with %v), the visit
// path from sr.path(), and the current goroutine's stack trace. If fn does not
// return normally but recover() yields nil, the failure is still reported
// rather than treated as success. This happens with panic(nil) before Go 1.21,
// or with runtime.Goexit.
func (sr *recoverable) safely(fn func()) (err error) {
	// completed is set only after fn returns normally. Relying on
	// recover() != nil alone would mistake panic(nil) for success.
	completed := false
	defer func() {
		r := recover()
		if completed {
			return
		}
		es := new(ErrState)
		switch v := r.(type) {
		case error:
			es.err = v
		case nil:
			es.err = fmt.Errorf("panic with nil value or runtime.Goexit")
		default:
			es.err = fmt.Errorf("%v", v)
		}
		es.path = sr.path()

		// runtime.Stack silently truncates when the buffer is too small, so
		// keep doubling it until the whole trace fits.
		for sz := 1024; ; sz *= 2 {
			stack := make([]byte, sz)
			n := runtime.Stack(stack, false)
			if n < sz {
				es.trace = string(stack[:n])
				break
			}
		}

		// Set the error.
		err = es
	}()

	// Execute the function.
	fn()
	completed = true
	return nil
}