blob: 48578843f92b6e27f49694ad731c198d6bc966fe [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package golang
import (
"fmt"
"io"
"strconv"
"text/template"
fidlir "fidl/compiler/backend/types"
gidlir "gidl/ir"
gidlmixer "gidl/mixer"
)
// conformanceTmpl renders a Go test file in package fidl_test. The file holds
// one Test* function per non-empty group of conformance cases.
//
// Every string field of the input is spliced into the output verbatim, so each
// one must already be valid Go source (a quoted name, an expression, a type).
//
// Go rejects unused imports. The import block is therefore emitted only when at
// least one case exists. "reflect" is emitted only when there are decode
// failure cases, because it is used only in TestAllDecodeFailureCases.
var conformanceTmpl = template.Must(template.New("conformanceTmpls").Parse(`
package fidl_test

{{ if or .EncodeSuccessCases .DecodeSuccessCases .EncodeFailureCases .DecodeFailureCases }}
import (
{{ if .DecodeFailureCases }}
"reflect"
{{ end }}
"testing"

"fidl/conformance"
"syscall/zx/fidl"
)
{{ end }}

{{ if .EncodeSuccessCases }}
func TestAllEncodeSuccessCases(t *testing.T) {
{{ range .EncodeSuccessCases }}
{
encodeSuccessCase{
name: {{ .Name }},
context: {{ .Context }},
input: &{{ .Value }},
bytes: {{ .Bytes }},
}.check(t)
}
{{ end }}
}
{{ end }}

{{ if .DecodeSuccessCases }}
func TestAllDecodeSuccessCases(t *testing.T) {
{{ range .DecodeSuccessCases }}
{
decodeSuccessCase{
name: {{ .Name }},
context: {{ .Context }},
input: &{{ .Value }},
bytes: {{ .Bytes }},
}.check(t)
}
{{ end }}
}
{{ end }}

{{ if .EncodeFailureCases }}
func TestAllEncodeFailureCases(t *testing.T) {
{{ range .EncodeFailureCases }}
{
encodeFailureCase{
name: {{ .Name }},
context: {{ .Context }},
input: &{{ .Value }},
code: {{ .ErrorCode }},
}.check(t)
}
{{ end }}
}
{{ end }}

{{ if .DecodeFailureCases }}
func TestAllDecodeFailureCases(t *testing.T) {
{{ range .DecodeFailureCases }}
{
decodeFailureCase{
name: {{ .Name }},
context: {{ .Context }},
valTyp: reflect.TypeOf((*{{ .ValueType }})(nil)),
bytes: {{ .Bytes }},
code: {{ .ErrorCode }},
}.check(t)
}
{{ end }}
}
{{ end }}
`))
// Data types consumed by conformanceTmpl. Every string field holds a fragment
// of Go source that the template copies verbatim into the generated test file.
type (
	// conformanceTmplInput groups every generated case by category. The
	// template skips a category entirely when its slice is empty.
	conformanceTmplInput struct {
		EncodeSuccessCases []encodeSuccessCase
		DecodeSuccessCases []decodeSuccessCase
		EncodeFailureCases []encodeFailureCase
		DecodeFailureCases []decodeFailureCase
	}

	// encodeSuccessCase: Value is expected to encode to Bytes.
	encodeSuccessCase struct {
		Name    string // quoted test name
		Context string // marshaler context expression
		Value   string // value expression (the template takes its address)
		Bytes   string // expected encoded bytes expression
	}

	// decodeSuccessCase: Bytes is expected to decode to Value.
	decodeSuccessCase struct {
		Name    string
		Context string
		Value   string
		Bytes   string
	}

	// encodeFailureCase: encoding Value is expected to fail with ErrorCode.
	encodeFailureCase struct {
		Name      string
		Context   string
		Value     string
		ErrorCode string
	}

	// decodeFailureCase: decoding Bytes as ValueType is expected to fail with
	// ErrorCode.
	decodeFailureCase struct {
		Name      string
		Context   string
		ValueType string // Go type name, wrapped in reflect.TypeOf by the template
		Bytes     string
		ErrorCode string
	}
)
// GenerateConformanceTests writes a Go conformance test file to wr. The file is
// built from the GIDL test cases in gidl, typed against the FIDL library fidl.
//
// Cases are converted in a fixed order: encode success, decode success, encode
// failure, decode failure. The first conversion error is returned unchanged and
// nothing is written. Otherwise the result of rendering conformanceTmpl through
// withGoFmt is returned.
func GenerateConformanceTests(wr io.Writer, gidl gidlir.All, fidl fidlir.Root) error {
	schema := gidlmixer.BuildSchema(fidl)

	var (
		input conformanceTmplInput
		err   error
	)
	if input.EncodeSuccessCases, err = encodeSuccessCases(gidl.EncodeSuccess, schema); err != nil {
		return err
	}
	if input.DecodeSuccessCases, err = decodeSuccessCases(gidl.DecodeSuccess, schema); err != nil {
		return err
	}
	if input.EncodeFailureCases, err = encodeFailureCases(gidl.EncodeFailure, schema); err != nil {
		return err
	}
	if input.DecodeFailureCases, err = decodeFailureCases(gidl.DecodeFailure, schema); err != nil {
		return err
	}

	return withGoFmt{conformanceTmpl}.Execute(wr, input)
}
// marshalerContext returns the Go source for the fidl.MarshalerContext literal
// that a generated test uses for wireFormat. Only the V1 wire format is
// handled. Any other value panics; the callers in this file filter with
// wireFormatSupported before calling it.
func marshalerContext(wireFormat gidlir.WireFormat) string {
	if wireFormat == gidlir.V1WireFormat {
		return `fidl.MarshalerContext{
DecodeUnionsFromXUnionBytes: true,
EncodeUnionsAsXUnionBytes: true,
}`
	}
	panic(fmt.Sprintf("unexpected wire format %v", wireFormat))
}
// encodeSuccessCases converts GIDL encode-success tests into template inputs.
// Each test yields one case per supported wire format. Tests whose value
// contains an unknown field are skipped. A declaration-extraction failure
// aborts the conversion with an error naming the offending test. The result is
// nil when no cases are produced.
func encodeSuccessCases(gidlEncodeSuccesses []gidlir.EncodeSuccess, schema gidlmixer.Schema) ([]encodeSuccessCase, error) {
	var cases []encodeSuccessCase
	for _, test := range gidlEncodeSuccesses {
		decl, err := schema.ExtractDeclaration(test.Value)
		if err != nil {
			return nil, fmt.Errorf("encode success %s: %s", test.Name, err)
		}
		if gidlir.ContainsUnknownField(test.Value) {
			continue
		}
		// The Go value expression depends only on the test, not the encoding,
		// so it is built once and shared by every emitted case.
		goValue := visit(test.Value, decl)
		for _, enc := range test.Encodings {
			if wireFormatSupported(enc.WireFormat) {
				cases = append(cases, encodeSuccessCase{
					Name:    testCaseName(test.Name, enc.WireFormat),
					Context: marshalerContext(enc.WireFormat),
					Value:   goValue,
					Bytes:   bytesBuilder(enc.Bytes),
				})
			}
		}
	}
	return cases, nil
}
// decodeSuccessCases converts GIDL decode-success tests into template inputs.
// Each test yields one case per supported wire format. Tests whose value
// contains an unknown field are skipped. A declaration-extraction failure
// aborts the conversion with an error naming the offending test. The result is
// nil when no cases are produced.
func decodeSuccessCases(gidlDecodeSuccesses []gidlir.DecodeSuccess, schema gidlmixer.Schema) ([]decodeSuccessCase, error) {
	var cases []decodeSuccessCase
	for _, test := range gidlDecodeSuccesses {
		decl, err := schema.ExtractDeclaration(test.Value)
		if err != nil {
			return nil, fmt.Errorf("decode success %s: %s", test.Name, err)
		}
		if gidlir.ContainsUnknownField(test.Value) {
			continue
		}
		// Built once; every encoding of this test expects the same Go value.
		goValue := visit(test.Value, decl)
		for _, enc := range test.Encodings {
			if wireFormatSupported(enc.WireFormat) {
				cases = append(cases, decodeSuccessCase{
					Name:    testCaseName(test.Name, enc.WireFormat),
					Context: marshalerContext(enc.WireFormat),
					Value:   goValue,
					Bytes:   bytesBuilder(enc.Bytes),
				})
			}
		}
	}
	return cases, nil
}
// encodeFailureCases converts GIDL encode-failure tests into template inputs.
// Each test yields one case per supported wire format listed in the test.
// Tests whose value contains an unknown field are skipped. A failure to
// extract the declaration or to map the expected error to a Go error code
// aborts the conversion with an error naming the offending test. The result is
// nil when no cases are produced.
func encodeFailureCases(gidlEncodeFailures []gidlir.EncodeFailure, schema gidlmixer.Schema) ([]encodeFailureCase, error) {
	var cases []encodeFailureCase
	for _, test := range gidlEncodeFailures {
		// NOTE(review): the "Unsafe" variant is presumably used because failure
		// values may not conform to their type — confirm in gidlmixer.
		decl, err := schema.ExtractDeclarationUnsafe(test.Value)
		if err != nil {
			return nil, fmt.Errorf("encode failure %s: %s", test.Name, err)
		}
		if gidlir.ContainsUnknownField(test.Value) {
			continue
		}
		code, err := goErrorCode(test.Err)
		if err != nil {
			return nil, fmt.Errorf("encode failure %s: %s", test.Name, err)
		}
		goValue := visit(test.Value, decl)
		for _, wf := range test.WireFormats {
			if wireFormatSupported(wf) {
				cases = append(cases, encodeFailureCase{
					Name:      testCaseName(test.Name, wf),
					Context:   marshalerContext(wf),
					Value:     goValue,
					ErrorCode: code,
				})
			}
		}
	}
	return cases, nil
}
// decodeFailureCases converts GIDL decode-failure tests into template inputs.
// These tests carry a type name rather than a value. The declaration is looked
// up by that name, and its Go type name is emitted for each supported wire
// format. A lookup failure or an unmappable expected error aborts the
// conversion with an error naming the offending test. The result is nil when
// no cases are produced.
func decodeFailureCases(gidlDecodeFailures []gidlir.DecodeFailure, schema gidlmixer.Schema) ([]decodeFailureCase, error) {
	var cases []decodeFailureCase
	for _, test := range gidlDecodeFailures {
		decl, err := schema.ExtractDeclarationByName(test.Type)
		if err != nil {
			return nil, fmt.Errorf("decode failure %s: %s", test.Name, err)
		}
		code, err := goErrorCode(test.Err)
		if err != nil {
			return nil, fmt.Errorf("decode failure %s: %s", test.Name, err)
		}
		typeName := declName(decl)
		for _, enc := range test.Encodings {
			if wireFormatSupported(enc.WireFormat) {
				cases = append(cases, decodeFailureCase{
					Name:      testCaseName(test.Name, enc.WireFormat),
					Context:   marshalerContext(enc.WireFormat),
					ValueType: typeName,
					Bytes:     bytesBuilder(enc.Bytes),
					ErrorCode: code,
				})
			}
		}
	}
	return cases, nil
}
// wireFormatSupported reports whether the Go backend generates tests for
// wireFormat. Currently only the V1 wire format is accepted.
func wireFormatSupported(wireFormat gidlir.WireFormat) bool {
	switch wireFormat {
	case gidlir.V1WireFormat:
		return true
	default:
		return false
	}
}
// testCaseName returns the name of a generated test case as a quoted Go string
// literal. The name is baseName and wireFormat joined by an underscore, so it
// can be spliced directly into the generated source.
func testCaseName(baseName string, wireFormat gidlir.WireFormat) string {
	unquoted := fmt.Sprintf("%s_%s", baseName, wireFormat)
	return strconv.Quote(unquoted)
}