| // Copyright 2020 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package rust |
| |
| import ( |
| "bytes" |
| "fmt" |
| "sort" |
| "strings" |
| "text/template" |
| |
| fidl "go.fuchsia.dev/fuchsia/tools/fidl/lib/fidlgen" |
| "go.fuchsia.dev/fuchsia/tools/fidl/measure-tape/src/measurer" |
| "go.fuchsia.dev/fuchsia/tools/fidl/measure-tape/src/utils" |
| ) |
| |
| func WriteRs(buf *bytes.Buffer, |
| m *measurer.Measurer, |
| targetMt *measurer.MeasuringTape, |
| allMethods map[measurer.MethodID]*measurer.Method) { |
| |
| if err := topOfRs.Execute(buf, newTmplParams(m, targetMt)); err != nil { |
| panic(err) |
| } |
| |
| cb := codeBuffer{buf: buf} |
| utils.ForAllMethodsInOrder(allMethods, func(m *measurer.Method) { |
| buf.WriteString("\n") |
| cb.writeMethod(m) |
| }) |
| } |
| |
// codeBuffer accumulates generated Rust code. Each writef call prefixes its
// output with the current nesting depth's worth of indentation.
type codeBuffer struct {
	level int           // current nesting depth; adjusted by indent
	buf   *bytes.Buffer // destination of all generated code
}
| |
| func (cb *codeBuffer) writeMethod(m *measurer.Method) { |
| traitName, methodName := toTraitAndMethodName(m.ID.Kind) |
| cb.writef("impl %s for %s {\n", traitName, toTypeName(m.ID.TargetType)) |
| cb.indent(func() { |
| // TODO(fxbug.dev/51366): With improved locals handling, we could |
| // conditionally define the alias below. Of course, this would be |
| // superseded by fxbug.dev/51368 but both should happen. |
| cb.writef("#[inline]\n") |
| cb.writef("#[allow(unused_variables)]\n") |
| cb.writef("fn %s(&self, size_agg: &mut SizeAgg) {\n", methodName) |
| cb.indent(func() { |
| // TODO(fxbug.dev/51368): Variable naming should be defered to printing. |
| // Here, we should bind m.Arg to the name 'self' therefore avoiding |
| // this local. |
| cb.writef("let %s = self;\n", formatExpr{m.Arg}) |
| cb.writeBlock(m.Body) |
| }) |
| cb.writef("}\n") |
| }) |
| cb.writef("}\n") |
| } |
| |
| func toTraitAndMethodName(kind measurer.MethodKind) (string, string) { |
| switch kind { |
| case measurer.Measure: |
| return "Measurable", "measure" |
| case measurer.MeasureOutOfLine: |
| return "MeasurableOutOfLine", "measure_out_of_line" |
| case measurer.MeasureHandles: |
| return "MeasurableHandles", "measure_handles" |
| default: |
| panic(fmt.Sprintf("should not be reachable for method kind %v", kind)) |
| } |
| } |
| |
| func (cb *codeBuffer) writeBlock(b *measurer.Block) { |
| b.ForAllStatements(func(stmt *measurer.Statement) { |
| stmt.Visit(cb) |
| }) |
| } |
| |
// Compile-time check that *codeBuffer implements measurer.StatementFormatter.
var _ measurer.StatementFormatter = (*codeBuffer)(nil)
| |
| func (cb *codeBuffer) CaseMaxOut() { |
| cb.writef("size_agg.maxed_out = true;\n") |
| } |
| |
| func (cb *codeBuffer) CaseAddNumBytes(val measurer.Expression) { |
| cb.writef("size_agg.add_num_bytes(%s);\n", formatExpr{val}) |
| } |
| |
| func (cb *codeBuffer) CaseAddNumHandles(val measurer.Expression) { |
| cb.writef("size_agg.add_num_handles(%s);\n", formatExpr{val}) |
| } |
| |
| func (cb *codeBuffer) CaseInvoke(id measurer.MethodID, val measurer.Expression) { |
| _, methodName := toTraitAndMethodName(id.Kind) |
| cb.writef("%s.%s(size_agg);\n", formatExpr{val}, methodName) |
| } |
| |
| func (cb *codeBuffer) CaseGuard(cond measurer.Expression, body *measurer.Block) { |
| // TODO(fxbug.dev/51613): Improve guard statement. |
| cb.writef("match %s {\n", formatExpr{cond}) |
| cb.indent(func() { |
| cb.writef("Some(_) => {\n") |
| cb.indent(func() { |
| cb.writeBlock(body) |
| }) |
| cb.writef("}\n") |
| cb.writef("_ => {}\n") |
| }) |
| cb.writef("}\n") |
| } |
| |
| func (cb *codeBuffer) CaseIterate(local, val measurer.Expression, body *measurer.Block) { |
| var iter string |
| if kind := val.AssertKind(measurer.String, measurer.Vector, measurer.Array); kind == measurer.Array || kind == measurer.Vector { |
| iter = ".iter()" |
| } |
| cb.writef("for %s in %s%s%s {\n", formatExpr{local}, formatExpr{val}, maybeUnwrap(val), iter) |
| cb.indent(func() { |
| cb.writeBlock(body) |
| }) |
| cb.writef("}\n") |
| } |
| |
| func (cb *codeBuffer) CaseSelectVariant( |
| val measurer.Expression, |
| targetType fidl.Name, |
| variants map[string]measurer.LocalWithBlock) { |
| |
| cb.writef("match %s {\n", formatExpr{val}) |
| cb.indent(func() { |
| utils.ForAllVariantsInOrder(variants, func(member string, localWithBlock measurer.LocalWithBlock) { |
| if member != measurer.UnknownVariant { |
| cb.writef("%s::%s(%s) => {\n", |
| toTypeName(targetType), fidl.ToUpperCamelCase(member), |
| formatExpr{localWithBlock.Local}) |
| } else { |
| cb.writef("%sUnknown!() => {\n", toTypeName(targetType)) |
| } |
| cb.indent(func() { |
| cb.writeBlock(localWithBlock.Body) |
| }) |
| cb.writef("}\n") |
| }) |
| }) |
| cb.writef("}\n") |
| } |
| |
| func (cb *codeBuffer) CaseDeclareMaxOrdinal(local measurer.Expression) { |
| cb.writef("let mut %s: usize = 0;\n", formatExpr{local}) |
| } |
| |
| func (cb *codeBuffer) CaseSetMaxOrdinal(local, ordinal measurer.Expression) { |
| cb.writef("%s = %s;\n", formatExpr{local}, formatExpr{ordinal}) |
| } |
| |
| func (cb *codeBuffer) writef(format string, a ...interface{}) { |
| for i := 0; i < cb.level; i++ { |
| cb.buf.WriteString(indent) |
| } |
| cb.buf.WriteString(fmt.Sprintf(format, a...)) |
| } |
| |
// indent is the text writef emits once per nesting level (four spaces).
const indent = "    "
| |
| func (cb *codeBuffer) indent(fn func()) { |
| cb.level++ |
| fn() |
| cb.level-- |
| } |
| |
// formatExpr renders a measurer.Expression as Rust source text. Wrapping the
// expression gives it a String method, so it can be passed straight to
// writef/Sprintf with %s. formatExpr is also the ExpressionFormatter whose
// Case* methods the expression calls back into.
type formatExpr struct {
	measurer.Expression
}
| |
| func (val formatExpr) String() string { |
| return val.Fmt(val) |
| } |
| |
// Compile-time check that formatExpr implements measurer.ExpressionFormatter.
var _ measurer.ExpressionFormatter = formatExpr{}
| |
| func (formatExpr) CaseNum(num int) string { |
| return fmt.Sprintf("%d", num) |
| } |
| |
// CaseLocal renders a local variable as its bare name. The tape kind is
// ignored.
func (formatExpr) CaseLocal(name string, _ measurer.TapeKind) string {
	return name
}
| |
| func (formatExpr) CaseMemberOf(val measurer.Expression, member string, _ measurer.TapeKind, nullable bool) string { |
| var maybeUnwrap string |
| if nullable { |
| maybeUnwrap = ".as_ref()" |
| } else if kind := val.AssertKind(measurer.Struct, measurer.Union, measurer.Table); kind == measurer.Table { |
| maybeUnwrap = ".as_ref().unwrap()" |
| } |
| return fmt.Sprintf("%s.%s%s", formatExpr{val}, member, maybeUnwrap) |
| } |
| |
| func (formatExpr) CaseFidlAlign(val measurer.Expression) string { |
| return fmt.Sprintf("round_up_to_align(%s, 8)", formatExpr{val}) |
| } |
| |
| func (formatExpr) CaseLength(val measurer.Expression) string { |
| return fmt.Sprintf("%s%s.len()", formatExpr{val}, maybeUnwrap(val)) |
| } |
| |
| func (formatExpr) CaseHasMember(val measurer.Expression, member string) string { |
| return fmt.Sprintf("%s.%s", formatExpr{val}, member) |
| } |
| |
| func (formatExpr) CaseMult(lhs, rhs measurer.Expression) string { |
| return fmt.Sprintf("%s * %s", formatExpr{lhs}, formatExpr{rhs}) |
| } |
| |
| func maybeUnwrap(val measurer.Expression) string { |
| if val.Nullable() { |
| return ".unwrap()" |
| } |
| return "" |
| } |
| |
| func toCrateName(libraryName fidl.LibraryName) string { |
| return fmt.Sprintf("fidl_%s", strings.Join(libraryName.Parts(), "_")) |
| } |
| |
| func toTypeName(declName fidl.Name) string { |
| return fmt.Sprintf("%s::%s", |
| toCrateName(declName.LibraryName()), |
| fidl.ToUpperCamelCase(declName.DeclarationName())) |
| } |
| |
// tmplParams holds the values substituted into the topOfRs template.
type tmplParams struct {
	Uses       []string // crate names for the `use` block; newTmplParams sorts them
	TargetType string   // Rust type accepted by the generated measure function
}
| |
| func newTmplParams(m *measurer.Measurer, targetMt *measurer.MeasuringTape) tmplParams { |
| var uses []string |
| for _, libraryName := range m.RootLibraries() { |
| uses = append(uses, toCrateName(libraryName)) |
| } |
| sort.Strings(uses) |
| |
| return tmplParams{ |
| Uses: uses, |
| TargetType: toTypeName(targetMt.Name()), |
| } |
| } |
| |
// topOfRs is the fixed preamble of every generated Rust file and is executed
// with a tmplParams.
//
// Each entry of .Uses becomes one line of the `use` block. The {{- ... }}
// actions trim the newline before them, so the crates stay one per line.
// .TargetType is the parameter type of the public measure function.
//
// The preamble also defines:
//   - the Size result;
//   - the SizeAgg accumulator, which reports the zircon channel maximums once
//     maxed_out is set;
//   - the Measurable, MeasurableOutOfLine and MeasurableHandles traits, whose
//     impls WriteRs appends after this preamble.
//
// template.Must panics during package initialization if the template fails
// to parse.
var topOfRs = template.Must(template.New("tmpls").Parse(
	`// WARNING: This file is machine generated by measure-tape.

#![allow(unused_imports)]
use {
    fidl::encoding::round_up_to_align,
{{- range .Uses }}
    {{ . }},
{{- end }}
    fuchsia_zircon as zx,
};

#[derive(Debug, Eq, PartialEq)]
pub struct Size {
    pub num_bytes: usize,
    pub num_handles: usize,
}

#[inline]
pub fn measure(value: &{{ .TargetType }}) -> Size {
    let mut size_agg = SizeAgg { maxed_out: false, num_bytes: 0, num_handles: 0 };
    value.measure(&mut size_agg);
    return size_agg.to_size();
}

struct SizeAgg {
    maxed_out: bool,
    num_bytes: usize,
    num_handles: usize,
}

impl SizeAgg {
    #[inline(always)]
    fn add_num_bytes(&mut self, num_bytes: usize) {
        self.num_bytes += num_bytes;
    }

    #[inline(always)]
    #[allow(dead_code)]
    fn add_num_handles(&mut self, num_handles: usize) {
        self.num_handles += num_handles;
    }

    #[inline(always)]
    fn to_size(&self) -> Size {
        if self.maxed_out {
            return Size {
                num_bytes: zx::sys::ZX_CHANNEL_MAX_MSG_BYTES as usize,
                num_handles: zx::sys::ZX_CHANNEL_MAX_MSG_HANDLES as usize,
            };
        }
        return Size { num_bytes: self.num_bytes, num_handles: self.num_handles };
    }
}

trait Measurable {
    fn measure(&self, size_agg: &mut SizeAgg);
}

trait MeasurableOutOfLine {
    fn measure_out_of_line(&self, size_agg: &mut SizeAgg);
}

trait MeasurableHandles {
    fn measure_handles(&self, size_agg: &mut SizeAgg);
}
`))