blob: 590c241a318fabbab2216ebf25c25c1783d6f67b [file] [log] [blame]
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package state
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"reflect"
"sort"
"github.com/golang/protobuf/proto"
pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
)
// objectState represents an object that may be in the process of being
// decoded. Specifically, it represents either a decoded object, or an
// interest in a future object that will be decoded. When that interest is
// registered (via register), the storage for the object will be created, but
// it will not be decoded until the object is encountered in the stream.
type objectState struct {
// id is the id for this object.
//
// If this field is zero, then this is an anonymous (unregistered,
// non-reference primitive) object. This is immutable.
id uint64
// obj is the object. This may or may not be valid yet, depending on
// whether complete returns true. However, regardless of whether the
// object is valid, obj contains a final storage location for the
// object. This is immutable.
//
// Note that this must be addressable (obj.Addr() must not panic).
//
// The obj passed to the decode methods below will equal this obj only
// in the case of decoding the top-level object. However, the passed
// obj may represent individual fields, elements of a slice, etc. that
// are effectively embedded within the reflect.Value below but with
// distinct types.
obj reflect.Value
// blockedBy is the number of dependencies this object is still waiting
// on. It is incremented by waitFor and decremented by checkComplete
// when a dependency finishes.
blockedBy int
// blocking is a list of the objects blocked by this one. It is
// appended to by waitFor and cleared by checkComplete.
blocking []*objectState
// callbacks is a set of callbacks to execute on load. They are run,
// and then cleared, by checkComplete once blockedBy is zero.
callbacks []func()
// path is the decoding path to the object. It is a copy of the
// decoder's recoverable state taken when the objectState is created.
path recoverable
}
// complete reports whether the object is fully loaded: it is not waiting on
// any other object and it has no callbacks left to run.
func (os *objectState) complete() bool {
	if os.blockedBy != 0 {
		return false
	}
	return len(os.callbacks) == 0
}
// checkComplete finishes the object if nothing blocks it any longer.
//
// When blockedBy is zero, every queued callback is run in the order it was
// added, and then each object that was waiting on this one has its blockedBy
// count decremented and is itself checked. An object that is still blocked is
// left untouched.
func (os *objectState) checkComplete(stats *Stats) {
	if os.blockedBy > 0 {
		// Still waiting on at least one dependency.
		return
	}

	stats.Start(os.obj)

	// Run the queued callbacks, then drop them so they cannot fire twice.
	pending := os.callbacks
	for i := 0; i < len(pending); i++ {
		pending[i]()
	}
	os.callbacks = nil

	// Release every object that was waiting on this one.
	waiters := os.blocking
	for i := 0; i < len(waiters); i++ {
		waiter := waiters[i]
		waiter.blockedBy--
		waiter.checkComplete(stats)
	}
	os.blocking = nil

	stats.Done()
}
// waitFor records that os depends on other.
//
// os gains one outstanding dependency and is added to other's list of blocked
// objects. A non-nil callback is queued on other; waitFor itself never
// invokes it.
func (os *objectState) waitFor(other *objectState, callback func()) {
	os.blockedBy += 1
	other.blocking = append(other.blocking, os)
	if callback == nil {
		return
	}
	other.callbacks = append(other.callbacks, callback)
}
// findCycleFor searches the objects transitively blocked by os for target.
//
// If target is reachable, the returned slice holds the chain in reverse
// order: target first, followed by each intermediate object walking back
// towards (but not including) os. If target is not reachable, nil is
// returned.
//
// Every object is explored at most once. Without that, a cycle among the
// blocked objects that does not pass through target would make the search
// recurse forever.
func (os *objectState) findCycleFor(target *objectState) []*objectState {
	visited := map[*objectState]struct{}{os: {}}

	var search func(from *objectState) []*objectState
	search = func(from *objectState) []*objectState {
		for _, other := range from.blocking {
			if other == target {
				return []*objectState{target}
			}
			if _, seen := visited[other]; seen {
				// Already explored (or being explored further up the
				// stack); it cannot yield a new path to target.
				continue
			}
			visited[other] = struct{}{}
			if chain := search(other); chain != nil {
				return append(chain, other)
			}
		}
		return nil
	}
	return search(os)
}

// findCycle finds a dependency cycle through os.
//
// The result is the chain returned by findCycleFor(os) with os appended, so a
// cycle starts and ends with os. If no cycle through os exists, the result
// contains only os.
func (os *objectState) findCycle() []*objectState {
	return append(os.findCycleFor(os), os)
}
// decodeState is a graph of objects in the process of being decoded.
//
// The decode process involves loading the breadth-first graph generated by
// encode. This graph is read in its entirety, ensuring that all object
// storage is complete.
//
// As the graph is being deserialized, a set of completion callbacks are
// executed. These completion callbacks should form a set of acyclic subgraphs
// over the original one. After decoding is complete, the objects are scanned
// to ensure that all callbacks are executed, otherwise the callback graph was
// not acyclic.
type decodeState struct {
// ctx is the decode context.
ctx context.Context
// objectsByID is the set of objects in progress, keyed by object id.
objectsByID map[uint64]*objectState
// deferred are objects that have been read, but no interest has been
// registered yet. These will be decoded once interest is registered
// (see register).
deferred map[uint64]*pb.Object
// outstanding is the number of outstanding objects: those that have
// been registered but not yet read from the stream.
outstanding uint32
// r is the input stream.
r io.Reader
// stats is the passed stats object.
stats *Stats
// recoverable is the panic recover facility.
recoverable
}
// lookup returns the objectState registered under id, or nil when no object
// with that id has been registered so far.
func (ds *decodeState) lookup(id uint64) *objectState {
	if found, ok := ds.objectsByID[id]; ok {
		return found
	}
	return nil
}
// wait registers a dependency of waiter on the object with the given id.
//
// Three ids never require waiting, and for them a non-nil callback is run
// immediately instead of being queued:
//
//   - 0, which encodes a nil pointer;
//   - waiter.id, a trivial self reference that can be handled internally;
//   - 1, the root object, to which _useable_ references are always allowed
//     because it may have fields that are already decoded.
//
// For any other id the object is looked up and waiter is queued behind it.
// The object is expected to have been registered already.
func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
	if id == 0 || id == waiter.id || id == 1 {
		if callback != nil {
			callback()
		}
		return
	}

	// No nil can be returned here.
	target := ds.lookup(id)
	waiter.waitFor(target, callback)
}
// waitObject notes a blocking relationship between os and whatever the
// encoded object p refers to.
//
// References and slices wait on the id they carry, interfaces are unwrapped
// and handled recursively, and for every other kind of value there is nothing
// to wait for, so a non-nil callback is run immediately.
func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) {
	switch v := p.Value.(type) {
	case *pb.Object_RefValue:
		// Refs can encode pointers and maps.
		ds.wait(os, v.RefValue, callback)
	case *pb.Object_SliceValue:
		// See decodeObject; we need to wait for the array (if non-nil).
		ds.wait(os, v.SliceValue.RefValue, callback)
	case *pb.Object_InterfaceValue:
		// It's an interface (wait recursively).
		ds.waitObject(os, v.InterfaceValue.Value, callback)
	default:
		// Nothing to wait for: execute the callback immediately.
		if callback != nil {
			callback()
		}
	}
}
// register returns the objectState for id, creating it if necessary.
//
// typ is only used to allocate storage when id has not been registered
// before; for an already registered id the existing state is returned
// unchanged. If the encoded object for id was already read and deferred, it
// is decoded right away; otherwise the object is counted as outstanding.
func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState {
	if existing, ok := ds.objectsByID[id]; ok {
		return existing
	}

	// Allocate the final storage location. Maps are created directly;
	// everything else is an addressable zero value of typ.
	var storage reflect.Value
	if typ.Kind() == reflect.Map {
		storage = reflect.MakeMap(typ)
	} else {
		storage = reflect.New(typ).Elem()
	}

	// Record in the object index.
	os := &objectState{id: id, obj: storage, path: ds.recoverable.copy()}
	ds.objectsByID[id] = os

	pending, wasDeferred := ds.deferred[id]
	if !wasDeferred {
		// Not seen in the stream yet; it will be decoded when read.
		ds.outstanding++
		return os
	}

	// The object was read before anyone asked for it: decode it now.
	delete(ds.deferred, id) // Free memory.
	ds.decodeObject(os, os.obj, pending, "", nil)
	return os
}
// decodeStruct decodes a struct value into obj.
//
// The encoded fields are collected into a Map, sorted by name, and handed to
// the load function registered for obj's pointer type. A struct with no
// registered type is accepted only if it has no fields; otherwise decoding
// panics with an error.
func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) {
	// Collect the encoded fields.
	m := Map{newInternalMap(nil, ds, os)}
	defer internalMapPool.Put(m.internalMap)
	for i := range s.Fields {
		f := s.Fields[i]
		m.data = append(m.data, entry{name: f.Name, object: f.Value})
	}

	// Sort the fields for efficient searching.
	//
	// The fields should already arrive in sorted order, so in practice this
	// amounts to a single pass that confirms the ordering.
	if len(m.data) > 1 {
		byName := func(i, j int) bool { return m.data[i].name < m.data[j].name }
		sort.Slice(m.data, byName)
	}

	// Invoke the load; this will recursively decode other objects.
	addr := obj.Addr()
	fns, registered := registeredTypes.lookupFns(addr.Type())
	switch {
	case registered:
		fns.invokeLoad(addr, m)
	case obj.NumField() == 0:
		// Allow anonymous empty structs.
	default:
		// Propagate an error.
		panic(fmt.Errorf("unregistered type %s", obj.Type()))
	}
}
// decodeMap decodes a map value into obj.
//
// A nil destination map is allocated first. Each encoded key/value pair is
// decoded into fresh values of the map's key and element types, any
// references they contain are registered as dependencies of os, and the pair
// is then stored in the map.
func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) {
	if obj.IsNil() {
		obj.Set(reflect.MakeMap(obj.Type()))
	}

	for i, encodedKey := range m.Keys {
		// Decode the objects.
		mapType := obj.Type()
		kv := reflect.New(mapType.Key()).Elem()
		vv := reflect.New(mapType.Elem()).Elem()
		ds.decodeObject(os, kv, encodedKey, ".(key %d)", i)
		encodedValue := m.Values[i]
		ds.decodeObject(os, vv, encodedValue, "[%#v]", kv.Interface())
		ds.waitObject(os, encodedKey, nil)
		ds.waitObject(os, encodedValue, nil)

		// Set in the map.
		obj.SetMapIndex(kv, vv)
	}
}
// decodeArray decodes an array value into obj.
//
// The encoded array must have exactly as many elements as obj; a mismatch
// panics with an error. Each element is decoded in place and any reference
// it contains is registered as a dependency of os.
func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) {
	if encoded, want := len(a.Contents), obj.Len(); encoded != want {
		panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", want, encoded))
	}

	// Decode the contents into the array.
	for i, elem := range a.Contents {
		ds.decodeObject(os, obj.Index(i), elem, "[%d]", i)
		ds.waitObject(os, elem, nil)
	}
}
// decodeInterface decodes an interface value into obj.
//
// An empty type name encodes a nil interface, in which case obj is left
// alone. Otherwise the named type must be registered; a missing type panics
// with an error.
func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) {
	typeName := i.Type
	if typeName == "" {
		// Nil value: just leave obj alone.
		return
	}

	// Get the dispatchable type. This may not be used if the given
	// reference has already been resolved, but if not we need to know the
	// type to create.
	concrete, known := registeredTypes.lookupType(typeName)
	if !known {
		panic(fmt.Errorf("no valid type for %q", typeName))
	}

	if obj.Kind() != reflect.Map {
		// Set obj to a zero value of the concrete type. This makes obj
		// non-zero in the sense that it now carries type information.
		// There's no need to do this for maps.
		obj.Set(reflect.Zero(concrete))
	}

	// Decode the dereferenced element; there is no need to wait here, as
	// the interface object shares the current object state.
	ds.decodeObject(os, obj, i.Value, ".(%s)", typeName)
}
// decodeObject decodes an object value into obj.
//
// os is the objectState on whose behalf decoding happens, obj is the
// destination value (os.obj itself, or a field or element within it), and
// object is the encoded value. format and param are passed to ds.push to
// describe this step of the decode path.
//
// Malformed or mismatching input (integer or float truncation, length
// mismatches) panics with an error. An unknown object kind panics with a
// plain string, as that should not happen.
func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) {
	ds.push(false, format, param)
	ds.stats.Add(obj)
	ds.stats.Start(obj)

	switch x := object.GetValue().(type) {
	case *pb.Object_BoolValue:
		obj.SetBool(x.BoolValue)
	case *pb.Object_StringValue:
		obj.SetString(string(x.StringValue))
	case *pb.Object_Int64Value:
		obj.SetInt(x.Int64Value)
		if obj.Int() != x.Int64Value {
			panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type()))
		}
	case *pb.Object_Uint64Value:
		obj.SetUint(x.Uint64Value)
		if obj.Uint() != x.Uint64Value {
			panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type()))
		}
	case *pb.Object_DoubleValue:
		obj.SetFloat(x.DoubleValue)
		if got := obj.Float(); got != x.DoubleValue {
			// NaN never compares equal to itself, so a NaN that was
			// stored as NaN is not a truncation. Only report an error
			// when at least one side is an ordinary number.
			if got == got || x.DoubleValue == x.DoubleValue {
				panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type()))
			}
		}
	case *pb.Object_RefValue:
		// Resolve the pointer itself, even though the object may not
		// be decoded yet. You need to use wait() in order to ensure
		// that is the case. See wait above, and Map.Barrier.
		if id := x.RefValue; id != 0 {
			// Decoding the interface should have imparted type
			// information, so from this point it's safe to resolve
			// and use this dynamic information for actually
			// creating the object in register.
			//
			// (For non-interfaces this is a no-op).
			dyntyp := reflect.TypeOf(obj.Interface())
			if dyntyp.Kind() == reflect.Map {
				// Remove the map object count here to avoid
				// double counting, as this object will be
				// counted again when it gets processed later.
				// We do not add a reference count as the
				// reference is artificial.
				ds.stats.Remove(obj)
				obj.Set(ds.register(id, dyntyp).obj)
			} else if dyntyp.Kind() == reflect.Ptr {
				ds.push(true /* dereference */, "", nil)
				obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
				ds.pop()
			} else {
				obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
			}
		} else {
			// We leave obj alone here. That's because if obj
			// represents an interface, it may have been imbued
			// with type information in decodeInterface, and we
			// don't want to destroy that information.
		}
	case *pb.Object_SliceValue:
		// It's okay to slice the array here, since the contents will
		// still be provided later on. These semantics are a bit
		// strange but they are handled in the Map.Barrier properly.
		//
		// The special semantics of zero ref apply here too.
		if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 {
			v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem())
			obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity)))
		}
	case *pb.Object_ArrayValue:
		ds.decodeArray(os, obj, x.ArrayValue)
	case *pb.Object_StructValue:
		ds.decodeStruct(os, obj, x.StructValue)
	case *pb.Object_MapValue:
		ds.decodeMap(os, obj, x.MapValue)
	case *pb.Object_InterfaceValue:
		ds.decodeInterface(os, obj, x.InterfaceValue)
	case *pb.Object_ByteArrayValue:
		copyArray(obj, reflect.ValueOf(x.ByteArrayValue))
	case *pb.Object_Uint16ArrayValue:
		// 16-bit slices are serialized as 32-bit slices.
		// See object.proto for details.
		s := x.Uint16ArrayValue.Values
		t := obj.Slice(0, obj.Len()).Interface().([]uint16)
		if len(t) != len(s) {
			panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
		}
		for i := range s {
			t[i] = uint16(s[i])
		}
	case *pb.Object_Uint32ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values))
	case *pb.Object_Uint64ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values))
	case *pb.Object_UintptrArrayValue:
		copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
	case *pb.Object_Int8ArrayValue:
		copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
	case *pb.Object_Int16ArrayValue:
		// 16-bit slices are serialized as 32-bit slices.
		// See object.proto for details.
		s := x.Int16ArrayValue.Values
		t := obj.Slice(0, obj.Len()).Interface().([]int16)
		if len(t) != len(s) {
			panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
		}
		for i := range s {
			t[i] = int16(s[i])
		}
	case *pb.Object_Int32ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values))
	case *pb.Object_Int64ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values))
	case *pb.Object_BoolArrayValue:
		copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values))
	case *pb.Object_Float64ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values))
	case *pb.Object_Float32ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values))
	default:
		// Should not happen, not propagated as an error.
		panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type()))
	}

	ds.stats.Done()
	ds.pop()
}
// copyArray copies the contents of src into dest.
//
// Both values must have the same length; a mismatch panics with an error.
// src is passed through castSlice with dest's element type before copying.
func copyArray(dest reflect.Value, src reflect.Value) {
	if want, got := dest.Len(), src.Len(); want != got {
		panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", want, got))
	}
	converted := castSlice(src, dest.Type().Elem())
	reflect.Copy(dest, converted)
}
// Deserialize deserializes the object state into obj, which becomes the root
// object (id 1).
//
// Objects are read from the stream in id order for as long as there are
// outstanding registered objects. An object that nobody has registered
// interest in yet is kept in ds.deferred and decoded when it is registered.
// After the stream's zero-length terminal header has been verified, all
// pending callbacks are fired and any object that is still incomplete is
// reported as a dependency cycle.
//
// This function may panic and should be run in safely().
func (ds *decodeState) Deserialize(obj reflect.Value) {
	ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()}
	ds.outstanding = 1 // The root object.

	// Decode all objects in the stream.
	//
	// See above, we never process objects while we have no outstanding
	// interests (other than the very first object).
	for id := uint64(1); ds.outstanding > 0; id++ {
		// os is nil when no interest has been registered for this id
		// yet. In that case there is no storage to attribute stats to,
		// so stats are only tracked for registered objects.
		os := ds.lookup(id)
		if os != nil {
			ds.stats.Start(os.obj)
		}

		o, err := ds.readObject()
		if err != nil {
			panic(err)
		}

		if os == nil {
			// If an object hasn't had interest registered
			// previously, we defer decoding until interest is
			// registered.
			ds.deferred[id] = o
			continue
		}

		// Decode the object.
		ds.from = &os.path
		ds.decodeObject(os, os.obj, o, "", nil)
		ds.outstanding--
		ds.stats.Done()
	}

	// Check the zero-length header at the end.
	length, object, err := ReadHeader(ds.r)
	if err != nil {
		panic(err)
	}
	if length != 0 {
		panic(fmt.Sprintf("expected zero-length terminal, got %d", length))
	}
	if object {
		panic("expected non-object terminal")
	}

	// Check if we have any deferred objects.
	if count := len(ds.deferred); count > 0 {
		// Should not happen, not propagated as an error.
		panic(fmt.Sprintf("still have %d deferred objects", count))
	}

	// Scan and fire all callbacks.
	for _, os := range ds.objectsByID {
		os.checkComplete(ds.stats)
	}

	// Check if we have any remaining dependency cycles.
	for _, os := range ds.objectsByID {
		if !os.complete() {
			// This must be the result of a dependency cycle.
			cycle := os.findCycle()
			var buf bytes.Buffer
			buf.WriteString("dependency cycle: {")
			for i, cycleOS := range cycle {
				if i > 0 {
					buf.WriteString(" => ")
				}
				buf.WriteString(cycleOS.obj.Type().String())
			}
			buf.WriteString("}")
			// Panic as an error; propagate to the caller.
			panic(errors.New(buf.String()))
		}
	}
}
// byteReader wraps an io.Reader and adds a ReadByte method, reading one byte
// at a time from the underlying reader.
type byteReader struct {
	io.Reader
}

// ReadByte implements io.ByteReader.
//
// A byte that was read is returned even if the underlying Read also reported
// an error. If nothing was read, the underlying error is returned, or
// io.ErrUnexpectedEOF when the reader reported neither data nor an error.
func (br byteReader) ReadByte() (byte, error) {
	var scratch [1]byte
	n, err := br.Reader.Read(scratch[:])
	switch {
	case n > 0:
		return scratch[0], nil
	case err != nil:
		return 0, err
	default:
		return 0, io.ErrUnexpectedEOF
	}
}
// ReadHeader reads an object header.
//
// Each object written to the statefile is prefixed with a header. See
// WriteHeader for more information; these functions are exported to allow
// non-state writes to the file to play nice with debugging tools.
//
// The header is a single uvarint: its lowest bit is the object flag and the
// remaining bits are the length. On a read error the error is returned with
// object set to false.
func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
	// Read the header.
	raw, err := binary.ReadUvarint(byteReader{r})
	if err != nil {
		return raw, false, err
	}

	// Split the flag bit from the length.
	return raw >> 1, raw&0x1 != 0, nil
}
// readObject reads one encoded object from the stream.
//
// It reads the object header, requires it to be flagged as an object, reads
// exactly the number of bytes the header announces, and unmarshals them. If
// the stream ends before the whole object has been read, the error is
// io.ErrUnexpectedEOF.
func (ds *decodeState) readObject() (*pb.Object, error) {
	// Read the header.
	length, object, err := ReadHeader(ds.r)
	if err != nil {
		return nil, err
	}
	if !object {
		return nil, errors.New("invalid object header")
	}

	// Read the object.
	buf := make([]byte, length)
	if _, err := io.ReadFull(ds.r, buf); err != nil {
		return nil, err
	}

	// Unmarshal.
	obj := new(pb.Object)
	if err := proto.Unmarshal(buf, obj); err != nil {
		return nil, err
	}
	return obj, nil
}