blob: e519ddeca434af110930d860d688fe68836fe2e5 [file] [log] [blame]
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package state
import (
"bytes"
"context"
"fmt"
"math"
"reflect"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/state/wire"
)
// internalCallback is a interface called on object completion.
//
// There are two implementations: objectDecodeState & userCallback.
type internalCallback interface {
// source returns the dependent object. May be nil.
source() *objectDecodeState
// callbackRun executes the callback.
callbackRun()
}
// userCallback is an implementation of internalCallback.
type userCallback func()
// source implements internalCallback.source.
func (userCallback) source() *objectDecodeState {
return nil
}
// callbackRun implements internalCallback.callbackRun.
func (uc userCallback) callbackRun() {
uc()
}
// objectDecodeState represents an object that may be in the process of being
// decoded. Specifically, it represents either a decoded object, or an an
// interest in a future object that will be decoded. When that interest is
// registered (via register), the storage for the object will be created, but
// it will not be decoded until the object is encountered in the stream.
type objectDecodeState struct {
// id is the id for this object.
id objectID
// typ is the id for this typeID. This may be zero if this is not a
// type-registered structure.
typ typeID
// obj is the object. This may or may not be valid yet, depending on
// whether complete returns true. However, regardless of whether the
// object is valid, obj contains a final storage location for the
// object. This is immutable.
//
// Note that this must be addressable (obj.Addr() must not panic).
//
// The obj passed to the decode methods below will equal this obj only
// in the case of decoding the top-level object. However, the passed
// obj may represent individual fields, elements of a slice, etc. that
// are effectively embedded within the reflect.Value below but with
// distinct types.
obj reflect.Value
// blockedBy is the number of dependencies this object has.
blockedBy int
// callbacksInline is inline storage for callbacks.
callbacksInline [2]internalCallback
// callbacks is a set of callbacks to execute on load.
callbacks []internalCallback
completeEntry
}
// addCallback adds a callback to the objectDecodeState.
func (ods *objectDecodeState) addCallback(ic internalCallback) {
if ods.callbacks == nil {
ods.callbacks = ods.callbacksInline[:0]
}
ods.callbacks = append(ods.callbacks, ic)
}
// findCycleFor returns when the given object is found in the blocking set.
func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState {
for _, ic := range ods.callbacks {
other := ic.source()
if other != nil && other == target {
return []*objectDecodeState{target}
} else if childList := other.findCycleFor(target); childList != nil {
return append(childList, other)
}
}
// This should not occur.
Failf("no deadlock found?")
panic("unreachable")
}
// findCycle finds a dependency cycle.
func (ods *objectDecodeState) findCycle() []*objectDecodeState {
return append(ods.findCycleFor(ods), ods)
}
// source implements internalCallback.source.
func (ods *objectDecodeState) source() *objectDecodeState {
return ods
}
// callbackRun implements internalCallback.callbackRun.
func (ods *objectDecodeState) callbackRun() {
ods.blockedBy--
}
// decodeState is a graph of objects in the process of being decoded.
//
// The decode process involves loading the breadth-first graph generated by
// encode. This graph is read in it's entirety, ensuring that all object
// storage is complete.
//
// As the graph is being serialized, a set of completion callbacks are
// executed. These completion callbacks should form a set of acyclic subgraphs
// over the original one. After decoding is complete, the objects are scanned
// to ensure that all callbacks are executed, otherwise the callback graph was
// not acyclic.
type decodeState struct {
// ctx is the decode context.
ctx context.Context
// r is the input stream.
r wire.Reader
// types is the type database.
types typeDecodeDatabase
// objectByID is the set of objects in progress.
objectsByID []*objectDecodeState
// deferred are objects that have been read, by no interest has been
// registered yet. These will be decoded once interest in registered.
deferred map[objectID]wire.Object
// pending is the set of objects that are not yet complete.
pending completeList
// stats tracks time data.
stats Stats
}
// lookup looks up an object in decodeState or returns nil if no such object
// has been previously registered.
func (ds *decodeState) lookup(id objectID) *objectDecodeState {
if len(ds.objectsByID) < int(id) {
return nil
}
return ds.objectsByID[id-1]
}
// checkComplete checks for completion.
func (ds *decodeState) checkComplete(ods *objectDecodeState) bool {
// Still blocked?
if ods.blockedBy > 0 {
return false
}
// Track stats if relevant.
if ods.callbacks != nil && ods.typ != 0 {
ds.stats.start(ods.typ)
defer ds.stats.done()
}
// Fire all callbacks.
for _, ic := range ods.callbacks {
ic.callbackRun()
}
// Mark completed.
cbs := ods.callbacks
ods.callbacks = nil
ds.pending.Remove(ods)
// Recursively check others.
for _, ic := range cbs {
if other := ic.source(); other != nil && other.blockedBy == 0 {
ds.checkComplete(other)
}
}
return true // All set.
}
// wait registers a dependency on an object.
//
// As a special case, we always allow _useable_ references back to the first
// decoding object because it may have fields that are already decoded. We also
// allow trivial self reference, since they can be handled internally.
func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) {
switch id {
case waiter.id:
// Trivial self reference.
fallthrough
case 1:
// Root object; see above.
if callback != nil {
callback()
}
return
}
// Mark as blocked.
waiter.blockedBy++
// No nil can be returned here.
other := ds.lookup(id)
if callback != nil {
// Add the additional user callback.
other.addCallback(userCallback(callback))
}
// Mark waiter as unblocked.
other.addCallback(waiter)
}
// waitObject notes a blocking relationship.
func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) {
if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 {
// Refs can encode pointers and maps.
ds.wait(ods, objectID(rv.Root), callback)
} else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 {
// See decodeObject; we need to wait for the array (if non-nil).
ds.wait(ods, objectID(sv.Ref.Root), callback)
} else if iv, ok := encoded.(*wire.Interface); ok {
// It's an interface (wait recurisvely).
ds.waitObject(ods, iv.Value, callback)
} else if callback != nil {
// Nothing to wait for: execute the callback immediately.
callback()
}
}
// walkChild returns a child object from obj, given an accessor path. This is
// the decode-side equivalent to traverse in encode.go.
//
// For the purposes of this function, a child object is either a field within a
// struct or an array element, with one such indirection per element in
// path. The returned value may be an unexported field, so it may not be
// directly assignable. See decode_unsafe.go.
func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
// See wire.Ref.Dots. The path here is specified in reverse order.
for i := len(path) - 1; i >= 0; i-- {
switch pc := path[i].(type) {
case *wire.FieldName: // Must be a pointer.
if obj.Kind() != reflect.Struct {
Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj)
}
obj = obj.FieldByName(string(*pc))
case wire.Index: // Embedded.
if obj.Kind() != reflect.Array {
Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj)
}
obj = obj.Index(int(pc))
default:
panic("unreachable: switch should be exhaustive")
}
}
return obj
}
// register registers a decode with a type.
//
// This type is only used to instantiate a new object if it has not been
// registered previously. This depends on the type provided if none is
// available in the object itself.
func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value {
// Grow the objectsByID slice.
id := objectID(r.Root)
if len(ds.objectsByID) < int(id) {
ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...)
}
// Does this object already exist?
ods := ds.objectsByID[id-1]
if ods != nil {
return walkChild(r.Dots, ods.obj)
}
// Create the object.
if len(r.Dots) != 0 {
typ = ds.findType(r.Type)
}
v := reflect.New(typ)
ods = &objectDecodeState{
id: id,
obj: v.Elem(),
}
ds.objectsByID[id-1] = ods
ds.pending.PushBack(ods)
// Process any deferred objects & callbacks.
if encoded, ok := ds.deferred[id]; ok {
delete(ds.deferred, id)
ds.decodeObject(ods, ods.obj, encoded)
}
return walkChild(r.Dots, ods.obj)
}
// objectDecoder is for decoding structs.
type objectDecoder struct {
// ds is decodeState.
ds *decodeState
// ods is current object being decoded.
ods *objectDecodeState
// reconciledTypeEntry is the reconciled type information.
rte *reconciledTypeEntry
// encoded is the encoded object state.
encoded *wire.Struct
}
// load is helper for the public methods on Source.
func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) {
// Note that we have reconciled the type and may remap the fields here
// to match what's expected by the decoder. The "slot" parameter here
// is in terms of the local type, where the fields in the encoded
// object are in terms of the wire object's type, which might be in a
// different order (but will have the same fields).
v := *od.encoded.Field(od.rte.FieldOrder[slot])
od.ds.decodeObject(od.ods, objPtr.Elem(), v)
if wait {
// Mark this individual object a blocker.
od.ds.waitObject(od.ods, v, fn)
}
}
// aterLoad implements Source.AfterLoad.
func (od *objectDecoder) afterLoad(fn func()) {
// Queue the local callback; this will execute when all of the above
// data dependencies have been cleared.
od.ods.addCallback(userCallback(fn))
}
// decodeStruct decodes a struct value.
func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) {
if encoded.TypeID == 0 {
// Allow anonymous empty structs, but only if the encoded
// object also has no fields.
if encoded.Fields() == 0 && obj.NumField() == 0 {
return
}
// Propagate an error.
Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name())
}
// Lookup the object type.
rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type())
ods.typ = typeID(encoded.TypeID)
// Invoke the loader.
od := objectDecoder{
ds: ds,
ods: ods,
rte: rte,
encoded: encoded,
}
ds.stats.start(ods.typ)
defer ds.stats.done()
if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
// Note: may be a registered empty struct which does not
// implement the saver/loader interfaces.
sl.StateLoad(Source{internal: od})
}
}
// decodeMap decodes a map value.
func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) {
if obj.IsNil() {
// See pointerTo.
obj.Set(reflect.MakeMap(obj.Type()))
}
for i := 0; i < len(encoded.Keys); i++ {
// Decode the objects.
kv := reflect.New(obj.Type().Key()).Elem()
vv := reflect.New(obj.Type().Elem()).Elem()
ds.decodeObject(ods, kv, encoded.Keys[i])
ds.decodeObject(ods, vv, encoded.Values[i])
ds.waitObject(ods, encoded.Keys[i], nil)
ds.waitObject(ods, encoded.Values[i], nil)
// Set in the map.
obj.SetMapIndex(kv, vv)
}
}
// decodeArray decodes an array value.
func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) {
if len(encoded.Contents) != obj.Len() {
Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents))
}
// Decode the contents into the array.
for i := 0; i < len(encoded.Contents); i++ {
ds.decodeObject(ods, obj.Index(i), encoded.Contents[i])
ds.waitObject(ods, encoded.Contents[i], nil)
}
}
// findType finds the type for the given wire.TypeSpecs.
func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type {
switch x := t.(type) {
case wire.TypeID:
typ := ds.types.LookupType(typeID(x))
rte := ds.types.Lookup(typeID(x), typ)
return rte.LocalType
case *wire.TypeSpecPointer:
return reflect.PtrTo(ds.findType(x.Type))
case *wire.TypeSpecArray:
return reflect.ArrayOf(int(x.Count), ds.findType(x.Type))
case *wire.TypeSpecSlice:
return reflect.SliceOf(ds.findType(x.Type))
case *wire.TypeSpecMap:
return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value))
default:
// Should not happen.
Failf("unknown type %#v", t)
}
panic("unreachable")
}
// decodeInterface decodes an interface value.
func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) {
if _, ok := encoded.Type.(wire.TypeSpecNil); ok {
// Special case; the nil object. Just decode directly, which
// will read nil from the wire (if encoded correctly).
ds.decodeObject(ods, obj, encoded.Value)
return
}
// We now need to resolve the actual type.
typ := ds.findType(encoded.Type)
// We need to imbue type information here, then we can proceed to
// decode normally. In order to avoid issues with setting value-types,
// we create a new non-interface version of this object. We will then
// set the interface object to be equal to whatever we decode.
origObj := obj
obj = reflect.New(typ).Elem()
defer origObj.Set(obj)
// With the object now having sufficient type information to actually
// have Set called on it, we can proceed to decode the value.
ds.decodeObject(ods, obj, encoded.Value)
}
// isFloatEq determines if x and y represent the same value.
func isFloatEq(x float64, y float64) bool {
switch {
case math.IsNaN(x):
return math.IsNaN(y)
case math.IsInf(x, 1):
return math.IsInf(y, 1)
case math.IsInf(x, -1):
return math.IsInf(y, -1)
default:
return x == y
}
}
// isComplexEq determines if x and y represent the same value.
func isComplexEq(x complex128, y complex128) bool {
return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y))
}
// decodeObject decodes a object value.
func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) {
switch x := encoded.(type) {
case wire.Nil: // Fast path: first.
// We leave obj alone here. That's because if obj represents an
// interface, it may have been imbued with type information in
// decodeInterface, and we don't want to destroy that.
case *wire.Ref:
// Nil pointers may be encoded in a "forceValue" context. For
// those we just leave it alone as the value will already be
// correct (nil).
if id := objectID(x.Root); id == 0 {
return
}
// Note that if this is a map type, we go through a level of
// indirection to allow for map aliasing.
if obj.Kind() == reflect.Map {
v := ds.register(x, obj.Type())
if v.IsNil() {
// Note that we don't want to clobber the map
// if has already been decoded by decodeMap. We
// just make it so that we have a consistent
// reference when that eventually does happen.
v.Set(reflect.MakeMap(v.Type()))
}
obj.Set(v)
return
}
// Normal assignment: authoritative only if no dots.
v := ds.register(x, obj.Type().Elem())
obj.Set(reflectValueRWAddr(v))
case wire.Bool:
obj.SetBool(bool(x))
case wire.Int:
obj.SetInt(int64(x))
if obj.Int() != int64(x) {
Failf("signed integer truncated from %v to %v", int64(x), obj.Int())
}
case wire.Uint:
obj.SetUint(uint64(x))
if obj.Uint() != uint64(x) {
Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint())
}
case wire.Float32:
obj.SetFloat(float64(x))
case wire.Float64:
obj.SetFloat(float64(x))
if !isFloatEq(obj.Float(), float64(x)) {
Failf("floating point number truncated from %v to %v", float64(x), obj.Float())
}
case *wire.Complex64:
obj.SetComplex(complex128(*x))
case *wire.Complex128:
obj.SetComplex(complex128(*x))
if !isComplexEq(obj.Complex(), complex128(*x)) {
Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex())
}
case *wire.String:
obj.SetString(string(*x))
case *wire.Slice:
// See *wire.Ref above; same applies.
if id := objectID(x.Ref.Root); id == 0 {
return
}
// Note that it's fine to slice the array here and assume that
// contents will still be filled in later on.
typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
v := ds.register(&x.Ref, typ)
obj.Set(reflectValueRWSlice3(v, 0, int(x.Length), int(x.Capacity)))
case *wire.Array:
ds.decodeArray(ods, obj, x)
case *wire.Struct:
ds.decodeStruct(ods, obj, x)
case *wire.Map:
ds.decodeMap(ods, obj, x)
case *wire.Interface:
ds.decodeInterface(ods, obj, x)
default:
// Shoud not happen, not propagated as an error.
Failf("unknown object %#v for %q", encoded, obj.Type().Name())
}
}
// Load deserializes the object graph rooted at obj.
//
// This function may panic and should be run in safely().
func (ds *decodeState) Load(obj reflect.Value) {
ds.stats.init()
defer ds.stats.fini(func(id typeID) string {
return ds.types.LookupName(id)
})
// Create the root object.
rootOds := &objectDecodeState{
id: 1,
obj: obj,
}
ds.objectsByID = append(ds.objectsByID, rootOds)
ds.pending.PushBack(rootOds)
// Read the number of objects.
numObjects, object, err := ReadHeader(ds.r)
if err != nil {
Failf("header error: %w", err)
}
if !object {
Failf("object missing")
}
// Decode all objects.
var (
encoded wire.Object
ods *objectDecodeState
id objectID
tid = typeID(1)
)
if err := safely(func() {
// Decode all objects in the stream.
//
// Note that the structure of this decoding loop should match the raw
// decoding loop in state/pretty/pretty.printer.printStream().
for i := uint64(0); i < numObjects; {
// Unmarshal either a type object or object ID.
encoded = wire.Load(ds.r)
switch we := encoded.(type) {
case *wire.Type:
ds.types.Register(we)
tid++
encoded = nil
continue
case wire.Uint:
id = objectID(we)
i++
// Unmarshal and resolve the actual object.
encoded = wire.Load(ds.r)
ods = ds.lookup(id)
if ods != nil {
// Decode the object.
ds.decodeObject(ods, ods.obj, encoded)
} else {
// If an object hasn't had interest registered
// previously or isn't yet valid, we deferred
// decoding until interest is registered.
ds.deferred[id] = encoded
}
// For error handling.
ods = nil
encoded = nil
default:
Failf("wanted type or object ID, got %#v", encoded)
}
}
}); err != nil {
// Include as much information as we can, taking into account
// the possible state transitions above.
if ods != nil {
Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
} else if encoded != nil {
Failf("error decoding from %#v: %w", encoded, err)
} else {
Failf("general decoding error: %w", err)
}
}
// Check if we have any deferred objects.
numDeferred := 0
for id, encoded := range ds.deferred {
numDeferred++
if s, ok := encoded.(*wire.Struct); ok && s.TypeID != 0 {
typ := ds.types.LookupType(typeID(s.TypeID))
log.Warningf("unused deferred object: ID %d, type %v", id, typ)
} else {
log.Warningf("unused deferred object: ID %d, %#v", id, encoded)
}
}
if numDeferred != 0 {
Failf("still had %d deferred objects", numDeferred)
}
// Scan and fire all callbacks. We iterate over the list of incomplete
// objects until all have been finished. We stop iterating if no
// objects become complete (there is a dependency cycle).
//
// Note that we iterate backwards here, because there will be a strong
// tendendcy for blocking relationships to go from earlier objects to
// later (deeper) objects in the graph. This will reduce the number of
// iterations required to finish all objects.
if err := safely(func() {
for ds.pending.Back() != nil {
thisCycle := false
for ods = ds.pending.Back(); ods != nil; {
if ds.checkComplete(ods) {
thisCycle = true
break
}
ods = ods.Prev()
}
if !thisCycle {
break
}
}
}); err != nil {
Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err)
}
// Check if we have any remaining dependency cycles. If there are any
// objects left in the pending list, then it must be due to a cycle.
if ods := ds.pending.Front(); ods != nil {
// This must be the result of a dependency cycle.
cycle := ods.findCycle()
var buf bytes.Buffer
buf.WriteString("dependency cycle: {")
for i, cycleOS := range cycle {
if i > 0 {
buf.WriteString(" => ")
}
fmt.Fprintf(&buf, "%q", cycleOS.obj.Type())
}
buf.WriteString("}")
Failf("incomplete graph: %s", string(buf.Bytes()))
}
}
// ReadHeader reads an object header.
//
// Each object written to the statefile is prefixed with a header. See
// WriteHeader for more information; these functions are exported to allow
// non-state writes to the file to play nice with debugging tools.
func ReadHeader(r wire.Reader) (length uint64, object bool, err error) {
// Read the header.
err = safely(func() {
length = wire.LoadUint(r)
})
if err != nil {
// On the header, pass raw I/O errors.
if sErr, ok := err.(*ErrState); ok {
return 0, false, sErr.Unwrap()
}
}
// Decode whether the object is valid.
object = length&objectFlag != 0
length &^= objectFlag
return
}