| // Copyright 2022 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package fuchsia_controller |
| |
| import ( |
| "bytes" |
| _ "embed" |
| "fmt" |
| "strings" |
| "text/template" |
| |
| "go.fuchsia.dev/fuchsia/tools/fidl/gidl/lib/config" |
| "go.fuchsia.dev/fuchsia/tools/fidl/gidl/lib/ir" |
| "go.fuchsia.dev/fuchsia/tools/fidl/gidl/lib/mixer" |
| "go.fuchsia.dev/fuchsia/tools/fidl/lib/fidlgen" |
| ) |
| |
| var ( |
| //go:embed conformance.tmpl |
| conformanceTmplText string |
| conformanceTmpl = template.Must(template.New("conformanceTmpl").Parse(conformanceTmplText)) |
| ) |
| |
// Data types handed to conformanceTmpl. The string fields of the case types
// hold pre-rendered snippets that the template splices into the generated
// Python test file.
type (
	// conformanceTmplInput is the top-level input to conformanceTmpl.
	conformanceTmplInput struct {
		EncodeSuccessCases []encodeSuccessCase
		DecodeSuccessCases []decodeSuccessCase
		EncodeFailureCases []encodeFailureCase
		DecodeFailureCases []decodeFailureCase
	}

	// encodeSuccessCase describes a single generated encode-success test.
	encodeSuccessCase struct {
		Name               string // snake_case test name
		Context            string // identifier of the encoding context to use
		HandleDefs         string // Python list literal creating the handles, or ""
		Handles            string // handle list, set when handle rights are not checked
		HandleDispositions string // disposition tuples, set when handle rights are checked
		Value              string // Python expression constructing the value to encode
		Bytes              string // expected encoding as a bytearray literal
	}

	// decodeSuccessCase describes a single generated decode-success test.
	decodeSuccessCase struct {
		Name               string // snake_case test name
		Context            string // identifier of the encoding context to use
		HandleDefs         string // Python list literal creating the handles, or ""
		Handles            string // handle list accompanying the encoded bytes
		HandleDispositions string // not set by decodeSuccessCases
		ValueType          string // qualified FIDL name of the decoded value's type
		Bytes              string // encoded input as a bytearray literal
		EqualityCheck      string // snippet verifying the decoded value
	}

	// encodeFailureCase and decodeFailureCase currently carry no data;
	// GenerateConformanceTests does not populate failure cases yet.
	encodeFailureCase struct{}
	decodeFailureCase struct{}
)
| |
| func GenerateConformanceTests(gidl ir.All, fidl fidlgen.Root, config config.GeneratorConfig) ([]byte, error) { |
| schema := mixer.BuildSchema(fidl) |
| encodeSuccessCases, err := encodeSuccessCases(gidl.EncodeSuccess, schema) |
| if err != nil { |
| return nil, err |
| } |
| decodeSuccessCases, err := decodeSuccessCases(gidl.DecodeSuccess, schema) |
| if err != nil { |
| return nil, err |
| } |
| var buf bytes.Buffer |
| err = conformanceTmpl.Execute(&buf, conformanceTmplInput{ |
| EncodeSuccessCases: encodeSuccessCases, |
| DecodeSuccessCases: decodeSuccessCases, |
| }) |
| return buf.Bytes(), err |
| } |
| |
| func declName(decl mixer.NamedDeclaration) string { |
| return identifierName(decl.Name()) |
| } |
| |
| func identifierName(qualifiedName string) string { |
| parts := strings.Split(qualifiedName, "/") |
| library_parts := strings.Split(parts[0], ".") |
| return fmt.Sprintf("%s.%s", strings.Join(library_parts, "_"), |
| fidlgen.ToUpperCamelCase(parts[1])) |
| } |
| |
| func encodeSuccessCases(gidlEncodeSuccesses []ir.EncodeSuccess, schema mixer.Schema) ([]encodeSuccessCase, error) { |
| var encodeSuccessCases []encodeSuccessCase |
| for _, encodeSuccess := range gidlEncodeSuccesses { |
| decl, err := schema.ExtractDeclarationEncodeSuccess(encodeSuccess.Value, encodeSuccess.HandleDefs) |
| if err != nil { |
| return nil, fmt.Errorf("encode success %s: %s", encodeSuccess.Name, err) |
| } |
| value := visit(encodeSuccess.Value, decl) |
| for _, encoding := range encodeSuccess.Encodings { |
| if !wireFormatSupported(encoding.WireFormat) { |
| continue |
| } |
| newCase := encodeSuccessCase{ |
| Name: testCaseName(encodeSuccess.Name, encoding.WireFormat), |
| Context: encodingContext(encoding.WireFormat), |
| HandleDefs: buildHandleDefs(encodeSuccess.HandleDefs), |
| Value: value, |
| Bytes: buildBytes(encoding.Bytes), |
| } |
| if len(newCase.HandleDefs) != 0 { |
| if encodeSuccess.CheckHandleRights { |
| newCase.HandleDispositions = buildRawHandleDispositions(encoding.HandleDispositions) |
| } else { |
| newCase.Handles = buildRawHandles(encoding.HandleDispositions) |
| } |
| } |
| encodeSuccessCases = append(encodeSuccessCases, newCase) |
| } |
| } |
| return encodeSuccessCases, nil |
| } |
| |
| func decodeSuccessCases(gidlDecodeSuccesses []ir.DecodeSuccess, schema mixer.Schema) ([]decodeSuccessCase, error) { |
| var decodeSuccessCases []decodeSuccessCase |
| for _, decodeSuccess := range gidlDecodeSuccesses { |
| decl, err := schema.ExtractDeclaration(decodeSuccess.Value, decodeSuccess.HandleDefs) |
| if err != nil { |
| return nil, fmt.Errorf("decode success %s: %s", decodeSuccess.Name, err) |
| } |
| equalityCheck := buildEqualityCheck(decodeSuccess.Value, decl) |
| for _, encoding := range decodeSuccess.Encodings { |
| if !wireFormatSupported(encoding.WireFormat) { |
| continue |
| } |
| decodeSuccessCases = append(decodeSuccessCases, decodeSuccessCase{ |
| Name: testCaseName(decodeSuccess.Name, encoding.WireFormat), |
| Context: encodingContext(encoding.WireFormat), |
| HandleDefs: buildHandleDefs(decodeSuccess.HandleDefs), |
| Handles: buildHandles(encoding.Handles), |
| ValueType: decl.Name(), |
| Bytes: buildBytes(encoding.Bytes), |
| EqualityCheck: equalityCheck, |
| }) |
| } |
| } |
| return decodeSuccessCases, nil |
| } |
| |
| func handleTypeName(subtype fidlgen.HandleSubtype) string { |
| switch subtype { |
| case fidlgen.HandleSubtypeNone: |
| return "Handle" |
| case fidlgen.HandleSubtypeChannel: |
| return "Channel" |
| case fidlgen.HandleSubtypeEvent: |
| return "Event" |
| default: |
| panic(fmt.Sprintf("unsupported handle subtype: %s", subtype)) |
| } |
| } |
| |
| func buildHandleDefs(defs []ir.HandleDef) string { |
| if len(defs) == 0 { |
| return "" |
| } |
| var builder strings.Builder |
| builder.WriteString("[\n") |
| for _, d := range defs { |
| builder.WriteString(fmt.Sprintf("create_handle(fuchsia_controller_py.%s),\n", handleTypeName(d.Subtype))) |
| } |
| builder.WriteString("]") |
| return builder.String() |
| } |
| |
| func buildHandles(handles []ir.Handle) string { |
| var builder strings.Builder |
| builder.WriteString("[\n") |
| for i, h := range handles { |
| builder.WriteString(fmt.Sprintf("%d,", h)) |
| if i%8 == 7 { |
| builder.WriteString("\n") |
| } |
| } |
| builder.WriteString("]") |
| return builder.String() |
| } |
| |
| func buildRawHandleDispositions(defs []ir.HandleDisposition) string { |
| if len(defs) == 0 { |
| return "" |
| } |
| var builder strings.Builder |
| builder.WriteString("[") |
| for _, d := range defs { |
| // MOVE operation at idx 0, result ZX_OK at last idx. |
| builder.WriteString(fmt.Sprintf("(0, handle_defs[%d].as_int(), %d, %d, 0),", d.Handle, d.Type, d.Rights)) |
| } |
| builder.WriteString("]") |
| return builder.String() |
| } |
| |
| func buildRawHandles(defs []ir.HandleDisposition) string { |
| if len(defs) == 0 { |
| return "" |
| } |
| var builder strings.Builder |
| builder.WriteString("[") |
| for i, d := range defs { |
| builder.WriteString(fmt.Sprintf("%d,", d.Handle)) |
| if i%8 == 7 { |
| builder.WriteString("\n") |
| } |
| } |
| builder.WriteString("]") |
| return builder.String() |
| } |
| |
| func testCaseName(baseName string, wireFormat ir.WireFormat) string { |
| return fidlgen.ToSnakeCase(fmt.Sprintf("%s_%s", baseName, wireFormat)) |
| } |
| |
| var supportedWireFormats = []ir.WireFormat{ |
| ir.V2WireFormat, |
| } |
| |
| func wireFormatSupported(wireFormat ir.WireFormat) bool { |
| for _, wf := range supportedWireFormats { |
| if wireFormat == wf { |
| return true |
| } |
| } |
| return false |
| } |
| |
| func encodingContext(wireFormat ir.WireFormat) string { |
| switch wireFormat { |
| case ir.V2WireFormat: |
| return "_V2_CONTEXT" |
| default: |
| panic(fmt.Sprintf("unexpected wire format %v", wireFormat)) |
| } |
| } |
| |
| func primitiveTypeName(subtype fidlgen.PrimitiveSubtype) string { |
| switch subtype { |
| case fidlgen.Int8, fidlgen.Uint8, fidlgen.Int16, fidlgen.Uint16, |
| fidlgen.Int32, fidlgen.Uint32, fidlgen.Int64, fidlgen.Uint64: |
| return "int" |
| case fidlgen.Float32, fidlgen.Float64: |
| return "float" |
| case fidlgen.Bool: |
| return "bool" |
| default: |
| panic(fmt.Sprintf("unexpected subtype %v", subtype)) |
| } |
| } |
| |
// formatPyBool returns the Python literal for value: "True" or "False".
func formatPyBool(value bool) string {
	literal := "False"
	if value {
		literal = "True"
	}
	return literal
}
| |
| func onStruct(value ir.Record, decl *mixer.StructDecl) string { |
| var structFields []string |
| providedKeys := make(map[string]struct{}, len(value.Fields)) |
| for _, field := range value.Fields { |
| if field.Key.IsUnknown() { |
| panic(fmt.Sprintf("unknown field not supported %+v", field.Key)) |
| } |
| providedKeys[field.Key.Name] = struct{}{} |
| fieldName := fidlgen.ToSnakeCase(field.Key.Name) |
| fieldValueStr := visit(field.Value, decl.Field(field.Key.Name)) |
| structFields = append(structFields, fmt.Sprintf("%s=%s", fieldName, fieldValueStr)) |
| } |
| for _, key := range decl.FieldNames() { |
| if _, ok := providedKeys[key]; !ok { |
| fieldName := fidlgen.ToSnakeCase(key) |
| structFields = append(structFields, fmt.Sprintf("%s=None", fieldName)) |
| } |
| } |
| valueStr := fmt.Sprintf("%s(%s)", declName(decl), strings.Join(structFields, ", ")) |
| return valueStr |
| } |
| |
| func onTable(value ir.Record, decl *mixer.TableDecl) string { |
| var tableFields []string |
| for _, field := range value.Fields { |
| if field.Key.IsUnknown() { |
| panic(fmt.Sprintf("table %s: unknown ordinal %d: Rust cannot construct tables with unknown fields", |
| decl.Name(), field.Key.UnknownOrdinal)) |
| } |
| fieldName := fidlgen.ToSnakeCase(field.Key.Name) |
| fieldValueStr := visit(field.Value, decl.Field(field.Key.Name)) |
| tableFields = append(tableFields, fmt.Sprintf("%s=%s", fieldName, fieldValueStr)) |
| } |
| tableName := declName(decl) |
| valueStr := fmt.Sprintf("%s(%s)", tableName, strings.Join(tableFields, ", ")) |
| return valueStr |
| } |
| |
| func onUnion(value ir.Record, decl *mixer.UnionDecl) string { |
| if len(value.Fields) != 1 { |
| panic(fmt.Sprintf("union has %d fields, expected 1", len(value.Fields))) |
| } |
| field := value.Fields[0] |
| var valueStr string |
| if field.Key.IsUnknown() { |
| if field.Key.UnknownOrdinal != 0 { |
| panic(fmt.Sprintf("union %s: unknown ordinal %d: Rust can only construct unknowns with the ordinal 0", |
| decl.Name(), field.Key.UnknownOrdinal)) |
| } |
| if field.Value != nil { |
| panic(fmt.Sprintf("union %s: unknown ordinal %d: Rust cannot construct union with unknown bytes/handles", |
| decl.Name(), field.Key.UnknownOrdinal)) |
| } |
| valueStr = fmt.Sprintf("%s()", declName(decl)) |
| } else { |
| fieldName := fidlgen.ToSnakeCase(field.Key.Name) |
| fieldValueStr := visit(field.Value, decl.Field(field.Key.Name)) |
| valueStr = fmt.Sprintf("%s.%s_variant(%s)", declName(decl), fieldName, fieldValueStr) |
| } |
| return valueStr |
| } |
| |
| func onList(value []ir.Value, decl mixer.ListDeclaration) string { |
| var elements []string |
| elemDecl := decl.Elem() |
| for _, item := range value { |
| elements = append(elements, visit(item, elemDecl)) |
| } |
| elementsStr := strings.Join(elements, ", ") |
| var valueStr string |
| switch decl.(type) { |
| case *mixer.ArrayDecl: |
| valueStr = fmt.Sprintf("[%s]", elementsStr) |
| case *mixer.VectorDecl: |
| valueStr = fmt.Sprintf("[%s]", elementsStr) |
| default: |
| panic(fmt.Sprintf("unexpected decl %v", decl)) |
| } |
| return valueStr |
| } |
| |
| func visit(value ir.Value, decl mixer.Declaration) string { |
| switch value := value.(type) { |
| case bool: |
| return formatPyBool(value) |
| case int64, uint64, float64: |
| switch decl := decl.(type) { |
| case mixer.PrimitiveDeclaration: |
| return fmt.Sprintf("%v", value) |
| case *mixer.BitsDecl: |
| primitive := visit(value, &decl.Underlying) |
| return fmt.Sprintf("%v", primitive) |
| case *mixer.EnumDecl: |
| primitive := visit(value, &decl.Underlying) |
| return fmt.Sprintf("%v", primitive) |
| } |
| case ir.RawFloat: |
| switch decl.(*mixer.FloatDecl).Subtype() { |
| case fidlgen.Float32: |
| return fmt.Sprintf("struct.unpack('>f', bytes.fromhex('%08x'))[0]", value) |
| case fidlgen.Float64: |
| return fmt.Sprintf("struct.unpack('>d', bytes.fromhex('%016x'))[0]", value) |
| } |
| case string: |
| return fmt.Sprintf("%q", value) |
| case nil: |
| if !decl.IsNullable() { |
| if _, ok := decl.(*mixer.HandleDecl); ok { |
| return "0" |
| } |
| panic(fmt.Sprintf("got nil for non-nullable type: %T", decl)) |
| } |
| return "None" |
| case ir.Handle: |
| return fmt.Sprintf("handle_defs[%d].as_int()", int(value)) |
| case ir.Record: |
| switch decl := decl.(type) { |
| case *mixer.StructDecl: |
| return onStruct(value, decl) |
| case *mixer.TableDecl: |
| return onTable(value, decl) |
| case *mixer.UnionDecl: |
| return onUnion(value, decl) |
| } |
| case []ir.Value: |
| switch decl := decl.(type) { |
| case *mixer.ArrayDecl: |
| return onList(value, decl) |
| case *mixer.VectorDecl: |
| return onList(value, decl) |
| } |
| } |
| panic(fmt.Sprintf("not implemented: %T", value)) |
| } |
| |
// buildBytes renders data as a Python bytearray literal of two-digit hex
// bytes, starting a new line after every eighth byte.
//
// The parameter is named data rather than bytes so it does not shadow the
// bytes package imported by this file.
func buildBytes(data []byte) string {
	var builder strings.Builder
	builder.WriteString("bytearray([\n")
	for i, b := range data {
		fmt.Fprintf(&builder, "0x%02x,", b)
		if i%8 == 7 {
			builder.WriteString("\n")
		}
	}
	builder.WriteString("\n])")
	return builder.String()
}