| // Copyright 2020 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package golang |
| |
| import ( |
| "fmt" |
| "io" |
| "strconv" |
| "text/template" |
| |
| fidlir "fidl/compiler/backend/types" |
| gidlir "gidl/ir" |
| gidlmixer "gidl/mixer" |
| ) |
| |
// conformanceTmpl is the text/template used to render the generated Go
// conformance test file (package fidl_test).
//
// For each non-empty case list in its input it emits one test function
// (TestAllEncodeSuccessCases, TestAllDecodeSuccessCases,
// TestAllEncodeFailureCases, TestAllDecodeFailureCases) that builds one case
// literal per entry and calls its check method. Field values are inserted
// into the output as-is (text/template performs no escaping), so every field
// must already hold valid Go source text.
//
// template.Must panics during package initialization if the template text
// fails to parse.
| var conformanceTmpl = template.Must(template.New("conformanceTmpls").Parse(` |
| package fidl_test |
| |
| import ( |
| "reflect" |
| "testing" |
| |
| "fidl/conformance" |
| |
| "syscall/zx/fidl" |
| ) |
| |
| {{ if .EncodeSuccessCases }} |
| func TestAllEncodeSuccessCases(t *testing.T) { |
| {{ range .EncodeSuccessCases }} |
| { |
| encodeSuccessCase{ |
| name: {{ .Name }}, |
| context: {{ .Context }}, |
| input: &{{ .Value }}, |
| bytes: {{ .Bytes }}, |
| }.check(t) |
| } |
| {{ end }} |
| } |
| {{ end }} |
| |
| {{ if .DecodeSuccessCases }} |
| func TestAllDecodeSuccessCases(t *testing.T) { |
| {{ range .DecodeSuccessCases }} |
| { |
| decodeSuccessCase{ |
| name: {{ .Name }}, |
| context: {{ .Context }}, |
| input: &{{ .Value }}, |
| bytes: {{ .Bytes }}, |
| }.check(t) |
| } |
| {{ end }} |
| } |
| {{ end }} |
| |
| {{ if .EncodeFailureCases }} |
| func TestAllEncodeFailureCases(t *testing.T) { |
| {{ range .EncodeFailureCases }} |
| { |
| encodeFailureCase{ |
| name: {{ .Name }}, |
| context: {{ .Context }}, |
| input: &{{ .Value }}, |
| code: {{ .ErrorCode }}, |
| }.check(t) |
| } |
| {{ end }} |
| } |
| {{ end }} |
| |
| {{ if .DecodeFailureCases }} |
| func TestAllDecodeFailureCases(t *testing.T) { |
| {{ range .DecodeFailureCases }} |
| { |
| decodeFailureCase{ |
| name: {{ .Name }}, |
| context: {{ .Context }}, |
| valTyp: reflect.TypeOf((*{{ .ValueType }})(nil)), |
| bytes: {{ .Bytes }}, |
| code: {{ .ErrorCode }}, |
| }.check(t) |
| } |
| {{ end }} |
| } |
| {{ end }} |
| `)) |
| |
// conformanceTmplInput is the data that GenerateConformanceTests passes to
// conformanceTmpl. Each field holds the cases for one generated test
// function; the template omits a test function entirely when its slice is
// empty.
| type conformanceTmplInput struct { |
| EncodeSuccessCases []encodeSuccessCase |
| DecodeSuccessCases []decodeSuccessCase |
| EncodeFailureCases []encodeFailureCase |
| DecodeFailureCases []decodeFailureCase |
| } |
| |
// encodeSuccessCase holds the template inputs for one generated
// encode-success check. Every field is Go source text that the template
// inserts unmodified into the generated file.
type encodeSuccessCase struct {
	// Name is rendered as the case's name field.
	Name string
	// Context is rendered as the case's context field.
	Context string
	// Value is rendered, prefixed with "&", as the case's input field.
	Value string
	// Bytes is rendered as the case's bytes field.
	Bytes string
}
| |
// decodeSuccessCase holds the template inputs for one generated
// decode-success check. Every field is Go source text that the template
// inserts unmodified into the generated file.
type decodeSuccessCase struct {
	// Name is rendered as the case's name field.
	Name string
	// Context is rendered as the case's context field.
	Context string
	// Value is rendered, prefixed with "&", as the case's input field.
	Value string
	// Bytes is rendered as the case's bytes field.
	Bytes string
}
| |
// encodeFailureCase holds the template inputs for one generated
// encode-failure check. Every field is Go source text that the template
// inserts unmodified into the generated file.
type encodeFailureCase struct {
	// Name is rendered as the case's name field.
	Name string
	// Context is rendered as the case's context field.
	Context string
	// Value is rendered, prefixed with "&", as the case's input field.
	Value string
	// ErrorCode is rendered as the case's code field.
	ErrorCode string
}
| |
// decodeFailureCase holds the template inputs for one generated
// decode-failure check. Every field is Go source text that the template
// inserts unmodified into the generated file.
type decodeFailureCase struct {
	// Name is rendered as the case's name field.
	Name string
	// Context is rendered as the case's context field.
	Context string
	// ValueType is rendered inside reflect.TypeOf((*T)(nil)) as the case's
	// valTyp field.
	ValueType string
	// Bytes is rendered as the case's bytes field.
	Bytes string
	// ErrorCode is rendered as the case's code field.
	ErrorCode string
}
| |
| // GenerateConformanceTests generates Go tests. |
| func GenerateConformanceTests(wr io.Writer, gidl gidlir.All, fidl fidlir.Root) error { |
| schema := gidlmixer.BuildSchema(fidl) |
| encodeSuccessCases, err := encodeSuccessCases(gidl.EncodeSuccess, schema) |
| if err != nil { |
| return err |
| } |
| decodeSuccessCases, err := decodeSuccessCases(gidl.DecodeSuccess, schema) |
| if err != nil { |
| return err |
| } |
| encodeFailureCases, err := encodeFailureCases(gidl.EncodeFailure, schema) |
| if err != nil { |
| return err |
| } |
| decodeFailureCases, err := decodeFailureCases(gidl.DecodeFailure, schema) |
| if err != nil { |
| return err |
| } |
| input := conformanceTmplInput{ |
| EncodeSuccessCases: encodeSuccessCases, |
| DecodeSuccessCases: decodeSuccessCases, |
| EncodeFailureCases: encodeFailureCases, |
| DecodeFailureCases: decodeFailureCases, |
| } |
| return withGoFmt{conformanceTmpl}.Execute(wr, input) |
| } |
| |
// marshalerContext returns Go source text for the fidl.MarshalerContext
// literal that generated test cases use for the given wire format.
//
// Only gidlir.V1WireFormat is handled; any other wire format panics. The case
// builders in this file filter with wireFormatSupported before calling it.
| func marshalerContext(wireFormat gidlir.WireFormat) string { |
| switch wireFormat { |
| case gidlir.V1WireFormat: |
| return `fidl.MarshalerContext{ |
| DecodeUnionsFromXUnionBytes: true, |
| EncodeUnionsAsXUnionBytes: true, |
| }` |
| default: |
| panic(fmt.Sprintf("unexpected wire format %v", wireFormat)) |
| } |
| } |
| |
| func encodeSuccessCases(gidlEncodeSuccesses []gidlir.EncodeSuccess, schema gidlmixer.Schema) ([]encodeSuccessCase, error) { |
| var encodeSuccessCases []encodeSuccessCase |
| for _, encodeSuccess := range gidlEncodeSuccesses { |
| decl, err := schema.ExtractDeclaration(encodeSuccess.Value) |
| if err != nil { |
| return nil, fmt.Errorf("encode success %s: %s", encodeSuccess.Name, err) |
| } |
| if gidlir.ContainsUnknownField(encodeSuccess.Value) { |
| continue |
| } |
| value := visit(encodeSuccess.Value, decl) |
| for _, encoding := range encodeSuccess.Encodings { |
| if !wireFormatSupported(encoding.WireFormat) { |
| continue |
| } |
| encodeSuccessCases = append(encodeSuccessCases, encodeSuccessCase{ |
| Name: testCaseName(encodeSuccess.Name, encoding.WireFormat), |
| Context: marshalerContext(encoding.WireFormat), |
| Value: value, |
| Bytes: bytesBuilder(encoding.Bytes), |
| }) |
| } |
| } |
| return encodeSuccessCases, nil |
| } |
| |
| func decodeSuccessCases(gidlDecodeSuccesses []gidlir.DecodeSuccess, schema gidlmixer.Schema) ([]decodeSuccessCase, error) { |
| var decodeSuccessCases []decodeSuccessCase |
| for _, decodeSuccess := range gidlDecodeSuccesses { |
| decl, err := schema.ExtractDeclaration(decodeSuccess.Value) |
| if err != nil { |
| return nil, fmt.Errorf("decode success %s: %s", decodeSuccess.Name, err) |
| } |
| if gidlir.ContainsUnknownField(decodeSuccess.Value) { |
| continue |
| } |
| value := visit(decodeSuccess.Value, decl) |
| for _, encoding := range decodeSuccess.Encodings { |
| if !wireFormatSupported(encoding.WireFormat) { |
| continue |
| } |
| decodeSuccessCases = append(decodeSuccessCases, decodeSuccessCase{ |
| Name: testCaseName(decodeSuccess.Name, encoding.WireFormat), |
| Context: marshalerContext(encoding.WireFormat), |
| Value: value, |
| Bytes: bytesBuilder(encoding.Bytes), |
| }) |
| } |
| } |
| return decodeSuccessCases, nil |
| } |
| |
| func encodeFailureCases(gidlEncodeFailures []gidlir.EncodeFailure, schema gidlmixer.Schema) ([]encodeFailureCase, error) { |
| var encodeFailureCases []encodeFailureCase |
| for _, encodeFailure := range gidlEncodeFailures { |
| decl, err := schema.ExtractDeclarationUnsafe(encodeFailure.Value) |
| if err != nil { |
| return nil, fmt.Errorf("encode failure %s: %s", encodeFailure.Name, err) |
| } |
| if gidlir.ContainsUnknownField(encodeFailure.Value) { |
| continue |
| } |
| code, err := goErrorCode(encodeFailure.Err) |
| if err != nil { |
| return nil, fmt.Errorf("encode failure %s: %s", encodeFailure.Name, err) |
| } |
| value := visit(encodeFailure.Value, decl) |
| for _, wireFormat := range encodeFailure.WireFormats { |
| if !wireFormatSupported(wireFormat) { |
| continue |
| } |
| encodeFailureCases = append(encodeFailureCases, encodeFailureCase{ |
| Name: testCaseName(encodeFailure.Name, wireFormat), |
| Context: marshalerContext(wireFormat), |
| Value: value, |
| ErrorCode: code, |
| }) |
| } |
| } |
| return encodeFailureCases, nil |
| } |
| |
| func decodeFailureCases(gidlDecodeFailures []gidlir.DecodeFailure, schema gidlmixer.Schema) ([]decodeFailureCase, error) { |
| var decodeFailureCases []decodeFailureCase |
| for _, decodeFailure := range gidlDecodeFailures { |
| decl, err := schema.ExtractDeclarationByName(decodeFailure.Type) |
| if err != nil { |
| return nil, fmt.Errorf("decode failure %s: %s", decodeFailure.Name, err) |
| } |
| code, err := goErrorCode(decodeFailure.Err) |
| if err != nil { |
| return nil, fmt.Errorf("decode failure %s: %s", decodeFailure.Name, err) |
| } |
| valueType := declName(decl) |
| for _, encoding := range decodeFailure.Encodings { |
| if !wireFormatSupported(encoding.WireFormat) { |
| continue |
| } |
| decodeFailureCases = append(decodeFailureCases, decodeFailureCase{ |
| Name: testCaseName(decodeFailure.Name, encoding.WireFormat), |
| Context: marshalerContext(encoding.WireFormat), |
| ValueType: valueType, |
| Bytes: bytesBuilder(encoding.Bytes), |
| ErrorCode: code, |
| }) |
| } |
| } |
| return decodeFailureCases, nil |
| } |
| |
| func wireFormatSupported(wireFormat gidlir.WireFormat) bool { |
| return wireFormat == gidlir.V1WireFormat |
| } |
| |
| func testCaseName(baseName string, wireFormat gidlir.WireFormat) string { |
| return strconv.Quote(fmt.Sprintf("%s_%s", baseName, wireFormat)) |
| } |