blob: 699a80d72e63abd3a01592499d8d98a4c34a3095 [file] [log] [blame]
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package rust
import (
"fmt"
"strings"
"go.fuchsia.dev/fuchsia/tools/fidl/gidl/lib/ir"
"go.fuchsia.dev/fuchsia/tools/fidl/gidl/lib/mixer"
"go.fuchsia.dev/fuchsia/tools/fidl/lib/fidlgen"
)
// buildEqualityCheck returns the text accumulated by an equalityCheckBuilder
// after visiting expectedValue: generated assertions comparing the expression
// actualExpr against expectedValue, whose type is described by decl.
func buildEqualityCheck(actualExpr string, expectedValue ir.Value, decl mixer.Declaration) string {
	builder := new(equalityCheckBuilder)
	builder.visit(actualExpr, expectedValue, decl)
	return builder.String()
}
// canAssertEq reports whether value can be tested with a single `assert_eq!`.
//
// nil and primitive values qualify. Raw floats, handles and unknown data never
// do. A list qualifies when every element does; a record qualifies when none
// of its fields has an unknown key and every field value qualifies.
//
// Panics if value has a type that is not handled here.
func canAssertEq(value ir.Value) bool {
	switch v := value.(type) {
	case nil, string, bool, int64, uint64, float64:
		return true
	case ir.RawFloat, ir.AnyHandle, ir.UnknownData:
		return false
	case []ir.Value:
		for i := range v {
			if !canAssertEq(v[i]) {
				return false
			}
		}
		return true
	case ir.Record:
		for _, f := range v.Fields {
			if f.Key.IsUnknown() {
				return false
			}
			if !canAssertEq(f.Value) {
				return false
			}
		}
		return true
	default:
		panic(fmt.Sprintf("unhandled type: %T", v))
	}
}
// equalityCheckBuilder accumulates generated text in an embedded
// strings.Builder.
type equalityCheckBuilder struct {
	strings.Builder
}

// write formats its arguments as fmt.Sprintf would and appends the result,
// followed by a newline, to the builder.
func (b *equalityCheckBuilder) write(format string, args ...interface{}) {
	line := fmt.Sprintf(format, args...)
	b.WriteString(line)
	b.WriteByte('\n')
}
// visit appends Rust statements asserting that the Rust expression expr
// matches the expected value, whose type is described by decl.
//
// Values accepted by canAssertEq are compared with a single assert_eq!
// against the expression returned by the package-level visit function.
// Everything else is compared piecewise: raw floats by bit pattern, handles by
// koid (plus object type and rights for restricted handles), lists element by
// element, and records field by field.
//
// Panics if the value or declaration type is not handled, or if a union has an
// unknown variant that carries data.
func (b *equalityCheckBuilder) visit(expr string, value ir.Value, decl mixer.Declaration) {
	if canAssertEq(value) {
		b.write("assert_eq!(%s, %s);", expr, visit(value, decl))
		return
	}
	if decl.IsNullable() {
		// Unwrap the Option<...>.
		expr = fmt.Sprintf("%s.as_ref().unwrap()", expr)
		if _, ok := value.(ir.Record); ok {
			// Unwrap again for Option<Box<...>>.
			expr = fmt.Sprintf("%s.as_ref()", expr)
		}
	}
	switch value := value.(type) {
	case ir.RawFloat:
		// The trailing semicolon is required: without it the generated Rust
		// does not parse when another statement follows this assertion.
		b.write("assert_eq!(%s.to_bits(), 0x%x);", expr, value)
	case ir.Handle:
		b.write("assert_eq!(%s.basic_info().unwrap().koid.raw_koid(), _handle_koids[%d]);", expr, value)
	case ir.RestrictedHandle:
		b.write(`
match %s.basic_info() {
Ok(info) => {
assert_eq!(info.koid.raw_koid(), _handle_koids[%d]);
assert_eq!(info.object_type, %s);
assert_eq!(info.rights, Rights::from_bits(%d).unwrap());
},
Err(e) => panic!("handle basic_info failed: {}", e),
}
`, expr, value.Handle, objectTypeConst(value.Type), value.Rights)
	case []ir.Value:
		elemDecl := decl.(mixer.ListDeclaration).Elem()
		for i, elem := range value {
			b.visit(fmt.Sprintf("%s[%d]", expr, i), elem, elemDecl)
		}
	case ir.Record:
		switch decl := decl.(type) {
		case *mixer.StructDecl:
			for _, field := range value.Fields {
				b.visit(fmt.Sprintf("%s.%s", expr, field.Key.Name), field.Value, decl.Field(field.Key.Name))
			}
		case *mixer.TableDecl:
			// Each field is unwrapped with as_ref().unwrap() before comparing.
			for _, field := range value.Fields {
				b.visit(fmt.Sprintf("%s.%s.as_ref().unwrap()", expr, field.Key.Name), field.Value, decl.Field(field.Key.Name))
			}
		case *mixer.UnionDecl:
			// A record without fields is accepted by canAssertEq and handled
			// by the early return above, so Fields is non-empty here.
			field := value.Fields[0]
			if field.Key.IsUnknown() {
				if field.Value != nil {
					panic(fmt.Sprintf("union %s: unknown ordinal %d: Rust cannot construct union with unknown bytes/handles",
						decl.Name(), field.Key.UnknownOrdinal))
				}
				b.write("assert!(%s.is_unknown());", expr)
				b.write("assert_eq!(%s.ordinal(), %d);", expr, field.Key.UnknownOrdinal)
			} else {
				variant := fidlgen.ToUpperCamelCase(field.Key.Name)
				b.write(`
match &%s {
%s::%s(x) => {
`, expr, declName(decl), variant)
				b.visit("x", field.Value, decl.Field(field.Key.Name))
				b.write("}")
				// The fallback arm is emitted only when the union declares
				// more than one field.
				if len(decl.FieldNames()) > 1 {
					// Name the same variant as the arm above in the message.
					b.write(`other => panic!("expected %s::%s, got {:?}", other)`, declName(decl), variant)
				}
				b.write("}")
			}
		default:
			panic(fmt.Sprintf("unhandled decl type: %T", decl))
		}
	default:
		panic(fmt.Sprintf("unhandled value type: %T", value))
	}
}
// objectTypeConst returns the Rust constant expression naming objectType.
// Only event, channel and socket object types are supported; any other value
// panics.
func objectTypeConst(objectType fidlgen.ObjectType) string {
	var rustConst string
	switch objectType {
	case fidlgen.ObjectTypeEvent:
		rustConst = "fidl::ObjectType::EVENT"
	case fidlgen.ObjectTypeChannel:
		rustConst = "fidl::ObjectType::CHANNEL"
	case fidlgen.ObjectTypeSocket:
		rustConst = "fidl::ObjectType::SOCKET"
	default:
		panic(fmt.Sprintf("unsupported object type: %d", objectType))
	}
	return rustConst
}