| // Copyright 2018 The gVisor Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package state |
| |
import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"reflect"
	"sort"

	"github.com/golang/protobuf/proto"
	pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
)
| |
// objectState represents an object that may be in the process of being
// decoded. Specifically, it represents either a decoded object, or an
// interest in a future object that will be decoded. When that interest is
// registered (via register), the storage for the object will be created, but
// it will not be decoded until the object is encountered in the stream.
type objectState struct {
	// id is the id for this object.
	//
	// If this field is zero, then this is an anonymous (unregistered,
	// non-reference primitive) object. This is immutable.
	id uint64

	// obj is the object. This may or may not be valid yet, depending on
	// whether complete returns true. However, regardless of whether the
	// object is valid, obj contains a final storage location for the
	// object. This is immutable.
	//
	// Note that this must be addressable (obj.Addr() must not panic).
	//
	// The obj passed to the decode methods below will equal this obj only
	// in the case of decoding the top-level object. However, the passed
	// obj may represent individual fields, elements of a slice, etc. that
	// are effectively embedded within the reflect.Value below but with
	// distinct types.
	obj reflect.Value

	// blockedBy is the number of dependencies this object has. waitFor
	// increments it, and checkComplete on a dependency decrements it. This
	// object's callbacks and dependents are only processed by checkComplete
	// once it is no longer positive.
	blockedBy int

	// blocking is a list of the objects blocked by this one (those that
	// called waitFor on it). When this object completes, checkComplete
	// decrements each one's blockedBy and re-checks it.
	blocking []*objectState

	// callbacks is a set of callbacks to execute on load. They are queued by
	// waitFor and fired, then cleared, by checkComplete.
	callbacks []func()

	// path is the decoding path to the object. It is a copy of the decoder's
	// recoverable state taken when the object is registered.
	path recoverable
}
| |
| // complete indicates the object is complete. |
| func (os *objectState) complete() bool { |
| return os.blockedBy == 0 && len(os.callbacks) == 0 |
| } |
| |
| // checkComplete checks for completion. If the object is complete, pending |
| // callbacks will be executed and checkComplete will be called on downstream |
| // objects (those depending on this one). |
| func (os *objectState) checkComplete(stats *Stats) { |
| if os.blockedBy > 0 { |
| return |
| } |
| stats.Start(os.obj) |
| |
| // Fire all callbacks. |
| for _, fn := range os.callbacks { |
| fn() |
| } |
| os.callbacks = nil |
| |
| // Clear all blocked objects. |
| for _, other := range os.blocking { |
| other.blockedBy-- |
| other.checkComplete(stats) |
| } |
| os.blocking = nil |
| stats.Done() |
| } |
| |
| // waitFor queues a dependency on the given object. |
| func (os *objectState) waitFor(other *objectState, callback func()) { |
| os.blockedBy++ |
| other.blocking = append(other.blocking, os) |
| if callback != nil { |
| other.callbacks = append(other.callbacks, callback) |
| } |
| } |
| |
| // findCycleFor returns when the given object is found in the blocking set. |
| func (os *objectState) findCycleFor(target *objectState) []*objectState { |
| for _, other := range os.blocking { |
| if other == target { |
| return []*objectState{target} |
| } else if childList := other.findCycleFor(target); childList != nil { |
| return append(childList, other) |
| } |
| } |
| return nil |
| } |
| |
| // findCycle finds a dependency cycle. |
| func (os *objectState) findCycle() []*objectState { |
| return append(os.findCycleFor(os), os) |
| } |
| |
// decodeState is a graph of objects in the process of being decoded.
//
// The decode process involves loading the breadth-first graph generated by
// encode. This graph is read in its entirety, ensuring that all object
// storage is complete.
//
// As the graph is being deserialized, a set of completion callbacks are
// executed. These completion callbacks should form a set of acyclic subgraphs
// over the original one. After decoding is complete, the objects are scanned
// to ensure that all callbacks are executed, otherwise the callback graph was
// not acyclic.
type decodeState struct {
	// ctx is the decode context.
	ctx context.Context

	// objectsByID is the set of objects in progress, keyed by object id.
	objectsByID map[uint64]*objectState

	// deferred are objects that have been read, but no interest has been
	// registered yet. These will be decoded once interest is registered
	// (see register).
	deferred map[uint64]*pb.Object

	// outstanding is the number of objects that have been registered but
	// not yet read from the stream. Deserialize keeps reading objects until
	// it drops to zero.
	outstanding uint32

	// r is the input stream.
	r io.Reader

	// stats is the passed stats object.
	stats *Stats

	// recoverable is the panic recover facility.
	recoverable
}
| |
| // lookup looks up an object in decodeState or returns nil if no such object |
| // has been previously registered. |
| func (ds *decodeState) lookup(id uint64) *objectState { |
| return ds.objectsByID[id] |
| } |
| |
| // wait registers a dependency on an object. |
| // |
| // As a special case, we always allow _useable_ references back to the first |
| // decoding object because it may have fields that are already decoded. We also |
| // allow trivial self reference, since they can be handled internally. |
| func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { |
| switch id { |
| case 0: |
| // Nil pointer; nothing to wait for. |
| fallthrough |
| case waiter.id: |
| // Trivial self reference. |
| fallthrough |
| case 1: |
| // Root object; see above. |
| if callback != nil { |
| callback() |
| } |
| return |
| } |
| |
| // No nil can be returned here. |
| waiter.waitFor(ds.lookup(id), callback) |
| } |
| |
| // waitObject notes a blocking relationship. |
| func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) { |
| if rv, ok := p.Value.(*pb.Object_RefValue); ok { |
| // Refs can encode pointers and maps. |
| ds.wait(os, rv.RefValue, callback) |
| } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok { |
| // See decodeObject; we need to wait for the array (if non-nil). |
| ds.wait(os, sv.SliceValue.RefValue, callback) |
| } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok { |
| // It's an interface (wait recurisvely). |
| ds.waitObject(os, iv.InterfaceValue.Value, callback) |
| } else if callback != nil { |
| // Nothing to wait for: execute the callback immediately. |
| callback() |
| } |
| } |
| |
| // register registers a decode with a type. |
| // |
| // This type is only used to instantiate a new object if it has not been |
| // registered previously. |
| func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState { |
| os, ok := ds.objectsByID[id] |
| if ok { |
| return os |
| } |
| |
| // Record in the object index. |
| if typ.Kind() == reflect.Map { |
| os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()} |
| } else { |
| os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()} |
| } |
| ds.objectsByID[id] = os |
| |
| if o, ok := ds.deferred[id]; ok { |
| // There is a deferred object. |
| delete(ds.deferred, id) // Free memory. |
| ds.decodeObject(os, os.obj, o, "", nil) |
| } else { |
| // There is no deferred object. |
| ds.outstanding++ |
| } |
| |
| return os |
| } |
| |
| // decodeStruct decodes a struct value. |
| func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) { |
| // Set the fields. |
| m := Map{newInternalMap(nil, ds, os)} |
| defer internalMapPool.Put(m.internalMap) |
| for _, field := range s.Fields { |
| m.data = append(m.data, entry{ |
| name: field.Name, |
| object: field.Value, |
| }) |
| } |
| |
| // Sort the fields for efficient searching. |
| // |
| // Technically, these should already appear in sorted order in the |
| // state ordering, so this cost is effectively a single scan to ensure |
| // that the order is correct. |
| if len(m.data) > 1 { |
| sort.Slice(m.data, func(i, j int) bool { |
| return m.data[i].name < m.data[j].name |
| }) |
| } |
| |
| // Invoke the load; this will recursively decode other objects. |
| fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) |
| if ok { |
| // Invoke the loader. |
| fns.invokeLoad(obj.Addr(), m) |
| } else if obj.NumField() == 0 { |
| // Allow anonymous empty structs. |
| return |
| } else { |
| // Propagate an error. |
| panic(fmt.Errorf("unregistered type %s", obj.Type())) |
| } |
| } |
| |
| // decodeMap decodes a map value. |
| func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) { |
| if obj.IsNil() { |
| obj.Set(reflect.MakeMap(obj.Type())) |
| } |
| for i := 0; i < len(m.Keys); i++ { |
| // Decode the objects. |
| kv := reflect.New(obj.Type().Key()).Elem() |
| vv := reflect.New(obj.Type().Elem()).Elem() |
| ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i) |
| ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface()) |
| ds.waitObject(os, m.Keys[i], nil) |
| ds.waitObject(os, m.Values[i], nil) |
| |
| // Set in the map. |
| obj.SetMapIndex(kv, vv) |
| } |
| } |
| |
| // decodeArray decodes an array value. |
| func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) { |
| if len(a.Contents) != obj.Len() { |
| panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents))) |
| } |
| // Decode the contents into the array. |
| for i := 0; i < len(a.Contents); i++ { |
| ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i) |
| ds.waitObject(os, a.Contents[i], nil) |
| } |
| } |
| |
| // decodeInterface decodes an interface value. |
| func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) { |
| // Is this a nil value? |
| if i.Type == "" { |
| return // Just leave obj alone. |
| } |
| |
| // Get the dispatchable type. This may not be used if the given |
| // reference has already been resolved, but if not we need to know the |
| // type to create. |
| t, ok := registeredTypes.lookupType(i.Type) |
| if !ok { |
| panic(fmt.Errorf("no valid type for %q", i.Type)) |
| } |
| |
| if obj.Kind() != reflect.Map { |
| // Set the obj to be the given typed value; this actually sets |
| // obj to be a non-zero value -- namely, it inserts type |
| // information. There's no need to do this for maps. |
| obj.Set(reflect.Zero(t)) |
| } |
| |
| // Decode the dereferenced element; there is no need to wait here, as |
| // the interface object shares the current object state. |
| ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type) |
| } |
| |
| // decodeObject decodes a object value. |
| func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) { |
| ds.push(false, format, param) |
| ds.stats.Add(obj) |
| ds.stats.Start(obj) |
| |
| switch x := object.GetValue().(type) { |
| case *pb.Object_BoolValue: |
| obj.SetBool(x.BoolValue) |
| case *pb.Object_StringValue: |
| obj.SetString(string(x.StringValue)) |
| case *pb.Object_Int64Value: |
| obj.SetInt(x.Int64Value) |
| if obj.Int() != x.Int64Value { |
| panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type())) |
| } |
| case *pb.Object_Uint64Value: |
| obj.SetUint(x.Uint64Value) |
| if obj.Uint() != x.Uint64Value { |
| panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type())) |
| } |
| case *pb.Object_DoubleValue: |
| obj.SetFloat(x.DoubleValue) |
| if obj.Float() != x.DoubleValue { |
| panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type())) |
| } |
| case *pb.Object_RefValue: |
| // Resolve the pointer itself, even though the object may not |
| // be decoded yet. You need to use wait() in order to ensure |
| // that is the case. See wait above, and Map.Barrier. |
| if id := x.RefValue; id != 0 { |
| // Decoding the interface should have imparted type |
| // information, so from this point it's safe to resolve |
| // and use this dynamic information for actually |
| // creating the object in register. |
| // |
| // (For non-interfaces this is a no-op). |
| dyntyp := reflect.TypeOf(obj.Interface()) |
| if dyntyp.Kind() == reflect.Map { |
| // Remove the map object count here to avoid |
| // double counting, as this object will be |
| // counted again when it gets processed later. |
| // We do not add a reference count as the |
| // reference is artificial. |
| ds.stats.Remove(obj) |
| obj.Set(ds.register(id, dyntyp).obj) |
| } else if dyntyp.Kind() == reflect.Ptr { |
| ds.push(true /* dereference */, "", nil) |
| obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) |
| ds.pop() |
| } else { |
| obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) |
| } |
| } else { |
| // We leave obj alone here. That's because if obj |
| // represents an interface, it may have been embued |
| // with type information in decodeInterface, and we |
| // don't want to destroy that information. |
| } |
| case *pb.Object_SliceValue: |
| // It's okay to slice the array here, since the contents will |
| // still be provided later on. These semantics are a bit |
| // strange but they are handled in the Map.Barrier properly. |
| // |
| // The special semantics of zero ref apply here too. |
| if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 { |
| v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem()) |
| obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity))) |
| } |
| case *pb.Object_ArrayValue: |
| ds.decodeArray(os, obj, x.ArrayValue) |
| case *pb.Object_StructValue: |
| ds.decodeStruct(os, obj, x.StructValue) |
| case *pb.Object_MapValue: |
| ds.decodeMap(os, obj, x.MapValue) |
| case *pb.Object_InterfaceValue: |
| ds.decodeInterface(os, obj, x.InterfaceValue) |
| case *pb.Object_ByteArrayValue: |
| copyArray(obj, reflect.ValueOf(x.ByteArrayValue)) |
| case *pb.Object_Uint16ArrayValue: |
| // 16-bit slices are serialized as 32-bit slices. |
| // See object.proto for details. |
| s := x.Uint16ArrayValue.Values |
| t := obj.Slice(0, obj.Len()).Interface().([]uint16) |
| if len(t) != len(s) { |
| panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) |
| } |
| for i := range s { |
| t[i] = uint16(s[i]) |
| } |
| case *pb.Object_Uint32ArrayValue: |
| copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values)) |
| case *pb.Object_Uint64ArrayValue: |
| copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values)) |
| case *pb.Object_UintptrArrayValue: |
| copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0)))) |
| case *pb.Object_Int8ArrayValue: |
| copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0)))) |
| case *pb.Object_Int16ArrayValue: |
| // 16-bit slices are serialized as 32-bit slices. |
| // See object.proto for details. |
| s := x.Int16ArrayValue.Values |
| t := obj.Slice(0, obj.Len()).Interface().([]int16) |
| if len(t) != len(s) { |
| panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) |
| } |
| for i := range s { |
| t[i] = int16(s[i]) |
| } |
| case *pb.Object_Int32ArrayValue: |
| copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values)) |
| case *pb.Object_Int64ArrayValue: |
| copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values)) |
| case *pb.Object_BoolArrayValue: |
| copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values)) |
| case *pb.Object_Float64ArrayValue: |
| copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values)) |
| case *pb.Object_Float32ArrayValue: |
| copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values)) |
| default: |
| // Shoud not happen, not propagated as an error. |
| panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type())) |
| } |
| |
| ds.stats.Done() |
| ds.pop() |
| } |
| |
| func copyArray(dest reflect.Value, src reflect.Value) { |
| if dest.Len() != src.Len() { |
| panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len())) |
| } |
| reflect.Copy(dest, castSlice(src, dest.Type().Elem())) |
| } |
| |
| // Deserialize deserializes the object state. |
| // |
| // This function may panic and should be run in safely(). |
| func (ds *decodeState) Deserialize(obj reflect.Value) { |
| ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()} |
| ds.outstanding = 1 // The root object. |
| |
| // Decode all objects in the stream. |
| // |
| // See above, we never process objects while we have no outstanding |
| // interests (other than the very first object). |
| for id := uint64(1); ds.outstanding > 0; id++ { |
| os := ds.lookup(id) |
| ds.stats.Start(os.obj) |
| |
| o, err := ds.readObject() |
| if err != nil { |
| panic(err) |
| } |
| |
| if os != nil { |
| // Decode the object. |
| ds.from = &os.path |
| ds.decodeObject(os, os.obj, o, "", nil) |
| ds.outstanding-- |
| } else { |
| // If an object hasn't had interest registered |
| // previously, we deferred decoding until interest is |
| // registered. |
| ds.deferred[id] = o |
| } |
| |
| ds.stats.Done() |
| } |
| |
| // Check the zero-length header at the end. |
| length, object, err := ReadHeader(ds.r) |
| if err != nil { |
| panic(err) |
| } |
| if length != 0 { |
| panic(fmt.Sprintf("expected zero-length terminal, got %d", length)) |
| } |
| if object { |
| panic("expected non-object terminal") |
| } |
| |
| // Check if we have any deferred objects. |
| if count := len(ds.deferred); count > 0 { |
| // Shoud not happen, not propagated as an error. |
| panic(fmt.Sprintf("still have %d deferred objects", count)) |
| } |
| |
| // Scan and fire all callbacks. |
| for _, os := range ds.objectsByID { |
| os.checkComplete(ds.stats) |
| } |
| |
| // Check if we have any remaining dependency cycles. |
| for _, os := range ds.objectsByID { |
| if !os.complete() { |
| // This must be the result of a dependency cycle. |
| cycle := os.findCycle() |
| var buf bytes.Buffer |
| buf.WriteString("dependency cycle: {") |
| for i, cycleOS := range cycle { |
| if i > 0 { |
| buf.WriteString(" => ") |
| } |
| buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type())) |
| } |
| buf.WriteString("}") |
| // Panic as an error; propagate to the caller. |
| panic(errors.New(string(buf.Bytes()))) |
| } |
| } |
| } |
| |
// byteReader adapts an io.Reader to io.ByteReader by reading one byte at a
// time from the wrapped reader.
type byteReader struct {
	io.Reader
}

// ReadByte implements io.ByteReader.
//
// A byte that was read is always returned with a nil error, even if the
// underlying Read also reported an error. A Read that yields neither data nor
// an error is reported as io.ErrUnexpectedEOF.
func (br byteReader) ReadByte() (byte, error) {
	var one [1]byte
	n, err := br.Reader.Read(one[:])
	switch {
	case n > 0:
		return one[0], nil
	case err != nil:
		return 0, err
	default:
		return 0, io.ErrUnexpectedEOF
	}
}
| |
| // ReadHeader reads an object header. |
| // |
| // Each object written to the statefile is prefixed with a header. See |
| // WriteHeader for more information; these functions are exported to allow |
| // non-state writes to the file to play nice with debugging tools. |
| func ReadHeader(r io.Reader) (length uint64, object bool, err error) { |
| // Read the header. |
| length, err = binary.ReadUvarint(byteReader{r}) |
| if err != nil { |
| return |
| } |
| |
| // Decode whether the object is valid. |
| object = length&0x1 != 0 |
| length = length >> 1 |
| return |
| } |
| |
| // readObject reads an object from the stream. |
| func (ds *decodeState) readObject() (*pb.Object, error) { |
| // Read the header. |
| length, object, err := ReadHeader(ds.r) |
| if err != nil { |
| return nil, err |
| } |
| if !object { |
| return nil, fmt.Errorf("invalid object header") |
| } |
| |
| // Read the object. |
| buf := make([]byte, length) |
| for done := 0; done < len(buf); { |
| n, err := ds.r.Read(buf[done:]) |
| done += n |
| if n == 0 && err != nil { |
| return nil, err |
| } |
| } |
| |
| // Unmarshal. |
| obj := new(pb.Object) |
| if err := proto.Unmarshal(buf, obj); err != nil { |
| return nil, err |
| } |
| |
| return obj, nil |
| } |