blob: 6688df57984fe1cda0a98b64f27b449a9b2afd4b [file] [log] [blame]
// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package hlcpp
import (
"fmt"
"strings"
gidlir "go.fuchsia.dev/fuchsia/tools/fidl/gidl/ir"
gidlmixer "go.fuchsia.dev/fuchsia/tools/fidl/gidl/mixer"
"go.fuchsia.dev/fuchsia/tools/fidl/lib/fidlgen"
)
// BuildEqualityCheck returns generated C++ statements asserting that the C++
// expression actualExpr holds expectedValue, interpreted according to decl.
// Expected handles are identified by koid through the C++ vector variable
// named handleKoidVectorName.
func BuildEqualityCheck(actualExpr string, expectedValue interface{}, decl gidlmixer.Declaration, handleKoidVectorName string) string {
	var checker equalityCheckBuilder
	checker.handleKoidVectorName = handleKoidVectorName
	checker.visit(actualExpr, expectedValue, decl)
	return checker.String()
}
// varSeq hands out fresh C++ local variable names of the form "f1", "f2", ...
// The zero value is ready to use; every call to next advances the counter.
type varSeq int

func (v *varSeq) next() string {
	*v++
	current := int(*v)
	return "f" + fmt.Sprint(current)
}
// equalityCheckBuilder accumulates generated C++ assertion statements. The
// embedded strings.Builder holds the generated text, and varSeq supplies fresh
// names for the C++ local variables the generated code declares.
type equalityCheckBuilder struct {
strings.Builder
varSeq varSeq
// Name of a C++ variable containing a vector of zx_koid_t values, indexed by
// GIDL handle number. This is read-only and is used for checking handle koid
// equality.
handleKoidVectorName string
}
// write appends printf-style formatted text to the generated output.
func (b *equalityCheckBuilder) write(format string, vals ...interface{}) {
	// Format straight into the embedded strings.Builder instead of building an
	// intermediate string first. strings.Builder writes never fail, so the
	// returned byte count and error carry no information.
	fmt.Fprintf(b, format, vals...)
}
// createAndAssignVar emits a C++ reference alias for the expression val and
// returns the alias name. [[maybe_unused]] keeps the compiler quiet when the
// alias ends up unused by later generated checks.
func (b *equalityCheckBuilder) createAndAssignVar(val string) string {
	alias := b.varSeq.next()
	b.WriteString("[[maybe_unused]] auto& " + alias + " = " + val + ";\n")
	return alias
}
// construct returns a C++ functional-cast expression "typename(arg)", where arg
// is produced by formatting fmtStr with args.
func (b *equalityCheckBuilder) construct(typename string, fmtStr string, args ...interface{}) string {
	ctorArg := fmt.Sprintf(fmtStr, args...)
	return typename + "(" + ctorArg + ")"
}
// emitAssert writes one assertion macro invocation, e.g.
// "ASSERT_EQ(a, b);\n", with the arguments separated by ", ".
func (b *equalityCheckBuilder) emitAssert(macro string, args ...string) {
	b.WriteString(macro + "(" + strings.Join(args, ", ") + ");\n")
}

// assertEquals emits ASSERT_EQ(actual, expected).
func (b *equalityCheckBuilder) assertEquals(actual, expected string) {
	b.emitAssert("ASSERT_EQ", actual, expected)
}

// assertStringEquals emits ASSERT_STR_EQ(actual, expected).
func (b *equalityCheckBuilder) assertStringEquals(actual, expected string) {
	b.emitAssert("ASSERT_STR_EQ", actual, expected)
}

// assertNotEquals emits ASSERT_NE(actual, expected).
func (b *equalityCheckBuilder) assertNotEquals(actual, expected string) {
	b.emitAssert("ASSERT_NE", actual, expected)
}

// assertFalse emits ASSERT_FALSE(value).
func (b *equalityCheckBuilder) assertFalse(value string) {
	b.emitAssert("ASSERT_FALSE", value)
}

// assertTrue emits ASSERT_TRUE(value).
func (b *equalityCheckBuilder) assertTrue(value string) {
	b.emitAssert("ASSERT_TRUE", value)
}

// assertNull emits ASSERT_NULL(value).
func (b *equalityCheckBuilder) assertNull(value string) {
	b.emitAssert("ASSERT_NULL", value)
}
// visit emits assertions checking that the C++ expression actualExpr holds
// expectedValue, where decl describes the FIDL type of that value. It
// dispatches on the dynamic type of expectedValue (and, where needed, on the
// concrete declaration type) and panics for any combination it cannot check.
func (b *equalityCheckBuilder) visit(actualExpr string, expectedValue interface{}, decl gidlmixer.Declaration) {
	switch expectedValue := expectedValue.(type) {
	case bool:
		b.assertEquals(actualExpr, b.construct(typeName(decl), "%t", expectedValue))
		return
	case int64, uint64, float64:
		switch decl := decl.(type) {
		case gidlmixer.PrimitiveDeclaration, *gidlmixer.EnumDecl:
			// Pass the literal as an argument, never as the format string, so its
			// text can't be misread as printf directives (and go vet stays quiet).
			b.assertEquals(actualExpr, b.construct(typeName(decl), "%s", formatPrimitive(expectedValue)))
			return
		case *gidlmixer.BitsDecl:
			b.assertEquals(actualExpr, fmt.Sprintf("static_cast<%s>(%s)", declName(decl), formatPrimitive(expectedValue)))
			return
		}
	case gidlir.RawFloat:
		// Build the expected value from its exact bit pattern (emitted as a
		// 0b... binary literal via %#b) by memcpy-ing it into a float/double.
		switch decl.(*gidlmixer.FloatDecl).Subtype() {
		case fidlgen.Float32:
			b.assertEquals(actualExpr, fmt.Sprintf("([] { uint32_t u = %#b; float f; memcpy(&f, &u, sizeof(float)); return f; })()", expectedValue))
			return
		case fidlgen.Float64:
			b.assertEquals(actualExpr, fmt.Sprintf("([] { uint64_t u = %#b; double d; memcpy(&d, &u, sizeof(double)); return d; })()", expectedValue))
			return
		}
	case string:
		// Nullable strings must be dereferenced before comparing contents.
		dereferencedActual := actualExpr
		if decl.IsNullable() {
			dereferencedActual = fmt.Sprintf("(*%s)", actualExpr)
		}
		b.assertStringEquals(dereferencedActual, escapeStr(expectedValue))
		return
	case gidlir.HandleWithRights:
		b.visitHandle(actualExpr, expectedValue, decl.(*gidlmixer.HandleDecl))
		return
	case gidlir.Record:
		switch decl := decl.(type) {
		case *gidlmixer.StructDecl:
			b.visitStruct(actualExpr, expectedValue, decl)
			return
		case *gidlmixer.TableDecl:
			b.visitTable(actualExpr, expectedValue, decl)
			return
		case *gidlmixer.UnionDecl:
			b.visitUnion(actualExpr, expectedValue, decl)
			return
		}
	case []interface{}:
		b.visitList(actualExpr, expectedValue, decl.(gidlmixer.ListDeclaration))
		return
	case nil:
		// An absent value: the check used depends on how each nullable kind is
		// represented in the generated C++.
		switch decl.(type) {
		case *gidlmixer.StringDecl:
			b.assertNull(actualExpr)
			return
		case *gidlmixer.HandleDecl:
			b.assertFalse(fmt.Sprintf("%s.is_valid()", actualExpr))
			return
		case *gidlmixer.UnionDecl:
			b.assertNull(actualExpr)
			return
		case *gidlmixer.VectorDecl:
			b.assertFalse(fmt.Sprintf("%s.has_value()", actualExpr))
			return
		case *gidlmixer.StructDecl:
			b.assertNull(actualExpr)
			return
		}
	}
	panic(fmt.Sprintf("not implemented: %T (decl: %T)", expectedValue, decl))
}
// visitHandle emits checks for a handle value. The handle value itself may
// have been replaced, so it is identified by koid instead. The actual handle's
// koid must equal entry expectedValue.Handle of the koid vector named by
// b.handleKoidVectorName. Its object type and rights must equal the expected
// ones, unless the expected value is ZX_OBJ_TYPE_NONE / ZX_RIGHT_SAME_RIGHTS
// respectively, in which case that check always passes.
func (b *equalityCheckBuilder) visitHandle(actualExpr string, expectedValue gidlir.HandleWithRights, decl *gidlmixer.HandleDecl) {
	handleVar := b.createAndAssignVar(actualExpr)
	info := b.varSeq.next() + "_info"
	b.write("\n")
	b.write("zx_info_handle_basic_t %s;\n", info)
	b.write("ASSERT_OK(zx_object_get_info(%s.get(), ZX_INFO_HANDLE_BASIC, &%s, sizeof(%s), nullptr, nullptr));\n", handleVar, info, info)
	b.write("ASSERT_EQ(%s.koid, %s[%d]);\n", info, b.handleKoidVectorName, expectedValue.Handle)
	b.write("ASSERT_TRUE(%s.type == %d || %d == ZX_OBJ_TYPE_NONE);\n", info, expectedValue.Type, expectedValue.Type)
	b.write("ASSERT_TRUE(%s.rights == %d || %d == ZX_RIGHT_SAME_RIGHTS);\n", info, expectedValue.Rights, expectedValue.Rights)
}
// visitStruct emits checks for every field listed in the expected struct
// record. Nullable structs are accessed with "->", others with ".". It panics
// if the record names a field that decl does not declare.
func (b *equalityCheckBuilder) visitStruct(actualExpr string, expectedValue gidlir.Record, decl *gidlmixer.StructDecl) {
	accessor := "."
	if decl.IsNullable() {
		accessor = "->"
	}
	structVar := b.createAndAssignVar(actualExpr)
	for _, f := range expectedValue.Fields {
		name := f.Key.Name
		memberDecl, found := decl.Field(name)
		if !found {
			panic(fmt.Sprintf("field %q not found", name))
		}
		b.visit(structVar+accessor+name, f.Value, memberDecl)
	}
}
// visitTable emits checks for a table value.
//
// Each unknown field in the expected record must be present in the actual
// table's UnknownData() map under its ordinal. Its bytes must match, and for
// resource tables its handles must match too. Each field declared by the
// table must be set exactly when the expected record lists it; set fields are
// then compared recursively.
func (b *equalityCheckBuilder) visitTable(actualExpr string, expectedValue gidlir.Record, decl *gidlmixer.TableDecl) {
	tableVar := b.createAndAssignVar(actualExpr)
	knownValues := make(map[string]gidlir.Value)
	for _, field := range expectedValue.Fields {
		if !field.Key.IsUnknown() {
			knownValues[field.Key.Name] = field.Value
			continue
		}
		unknownData := tableVar + ".UnknownData()"
		ordinal := field.Key.UnknownOrdinal
		b.assertNotEquals(fmt.Sprintf("%s.find(%d)", unknownData, ordinal), unknownData+".end()")
		entry := fmt.Sprintf("%s.at(%d)", unknownData, ordinal)
		if !decl.IsResourceType() {
			// Value tables store the raw bytes directly in the map entry.
			b.visitUnknownBytes(entry, field.Value.(gidlir.UnknownData).Bytes)
			continue
		}
		// Resource tables store a struct with separate bytes and handles.
		unknown := field.Value.(gidlir.UnknownData)
		b.visitUnknownBytes(entry+".bytes", unknown.Bytes)
		b.visitUnknownHandles(entry+".handles", unknown.Handles)
	}

	for _, name := range decl.FieldNames() {
		fieldDecl, found := decl.Field(name)
		if !found {
			panic(fmt.Sprintf("field decl %s not found", name))
		}
		hasCall := fmt.Sprintf("%s.has_%s()", tableVar, name)
		want, present := knownValues[name]
		if !present {
			b.assertFalse(hasCall)
			continue
		}
		b.assertTrue(hasCall)
		b.visit(fmt.Sprintf("%s.%s()", tableVar, name), want, fieldDecl)
	}
}
// visitUnion emits checks for a union value, which may be nullable. The
// expected record must contain exactly one field. For an unknown variant, the
// union's unknown bytes are compared, plus its unknown handles for resource
// unions. Otherwise the active tag is checked and the selected member is
// visited recursively.
func (b *equalityCheckBuilder) visitUnion(actualExpr string, expectedValue gidlir.Record, decl *gidlmixer.UnionDecl) {
	// Nullable unions are accessed through a pointer-like wrapper.
	op := "."
	if decl.IsNullable() {
		op = "->"
	}
	actualVar := b.createAndAssignVar(actualExpr)
	if len(expectedValue.Fields) != 1 {
		panic(fmt.Sprintf("union value must have exactly one field, got %d", len(expectedValue.Fields)))
	}
	field := expectedValue.Fields[0]
	if field.Key.IsUnknown() {
		// Use the same alias and accessor as the known-variant path, so that
		// nullable unions are dereferenced with "->" instead of ".".
		unknown := field.Value.(gidlir.UnknownData)
		b.visitUnknownBytes(fmt.Sprintf("(*%s%sUnknownBytes())", actualVar, op), unknown.Bytes)
		if decl.IsResourceType() {
			b.visitUnknownHandles(fmt.Sprintf("(*%s%sUnknownHandles())", actualVar, op), unknown.Handles)
		}
		return
	}
	fieldDecl, ok := decl.Field(field.Key.Name)
	if !ok {
		panic(fmt.Sprintf("field %q not found", field.Key.Name))
	}
	b.assertEquals(
		fmt.Sprintf("%s%sWhich()", actualVar, op),
		fmt.Sprintf("%s::Tag::k%s", declName(decl), fidlgen.ToUpperCamelCase(field.Key.Name)))
	actualFieldExpr := fmt.Sprintf("%s%s%s()", actualVar, op, field.Key.Name)
	b.visit(actualFieldExpr, field.Value, fieldDecl)
}
// visitList emits element-by-element checks for a list value. Vectors also get
// a size check. Other list kinds (presumably fixed-size arrays) only get
// element checks.
func (b *equalityCheckBuilder) visitList(actualExpr string, expectedValue []interface{}, decl gidlmixer.ListDeclaration) {
	listVar := b.createAndAssignVar(actualExpr)
	if _, isVector := decl.(*gidlmixer.VectorDecl); isVector {
		b.assertEquals(listVar+".size()", fmt.Sprint(len(expectedValue)))
	}
	for idx := range expectedValue {
		b.visit(fmt.Sprintf("%s[%d]", listVar, idx), expectedValue[idx], decl.Elem())
	}
}
// visitUnknownBytes emits a check that the C++ expression actualExpr equals a
// freshly named std::vector<uint8_t> initialized from expectedValue.
func (b *equalityCheckBuilder) visitUnknownBytes(actualExpr string, expectedValue []byte) {
	suffix := b.varSeq.next()
	expectedBytes := BuildBytes(expectedValue)
	b.write("\nstd::vector<uint8_t> bytes%s = %s;\nASSERT_EQ(bytes%s, %s);\n", suffix, expectedBytes, suffix, actualExpr)
}
// visitUnknownHandles emits checks that the C++ expression actualExpr (a
// sequence whose elements expose .get()) holds handles with the same koids, in
// the same order, as the raw handles built from expectedValue.
//
// In the generated code, "expected" names the handle built from expectedValue
// and "actual" names the handle taken from actualExpr. Assertions are written
// in (actual, expected) order, matching assertEquals, so that failure messages
// label each side correctly.
func (b *equalityCheckBuilder) visitUnknownHandles(actualExpr string, expectedValue []gidlir.Handle) {
	b.write(`
std::vector<zx_handle_t> handles%[1]s = %[2]s;
ASSERT_EQ(%[3]s.size(), handles%[1]s.size());
for (size_t i = 0; i < handles%[1]s.size(); ++i) {
zx_handle_t expected = handles%[1]s[i];
zx_info_handle_basic_t %[1]s_info_expected;
ASSERT_OK(zx_object_get_info(expected, ZX_INFO_HANDLE_BASIC, &%[1]s_info_expected, sizeof(%[1]s_info_expected), nullptr, nullptr));
zx_handle_t actual = %[3]s[i].get();
zx_info_handle_basic_t %[1]s_info_actual;
ASSERT_OK(zx_object_get_info(actual, ZX_INFO_HANDLE_BASIC, &%[1]s_info_actual, sizeof(%[1]s_info_actual), nullptr, nullptr));
ASSERT_EQ(%[1]s_info_actual.koid, %[1]s_info_expected.koid);
}
`, b.varSeq.next(), BuildRawHandlesFromHandleInfos(expectedValue), actualExpr)
}
// formatPrimitive renders an int64, uint64, or float64 as a C++ literal.
// Integers carry an explicit ll/ull suffix; floats use Go's shortest %g form.
// It panics for any other type.
func formatPrimitive(value interface{}) string {
	if i, ok := value.(int64); ok {
		if i == -1<<63 {
			// C++ has no negative integer literals: "-9223372036854775808" would be
			// unary minus applied to a literal that does not fit in long long,
			// so the minimum value is spelled indirectly.
			return "-9223372036854775807ll - 1"
		}
		return fmt.Sprintf("%dll", i)
	}
	if u, ok := value.(uint64); ok {
		return fmt.Sprintf("%dull", u)
	}
	if f, ok := value.(float64); ok {
		return fmt.Sprintf("%g", f)
	}
	panic("Unreachable")
}