# Interpreter Design
## Data Model
[StableHLO programs](spec.md#programs) are computations over tensors
(n-dimensional arrays), which, in the current model, are implemented using the
class `Tensor`. The underlying storage class for a `Tensor` object,
`detail::Buffer`, stores the `mlir::ShapedType` of the tensor along with a
`mlir::HeapAsmResourceBlob` object representing a mutable blob of tensor
data laid out as a contiguous byte array in
[major-to-minor order](https://www.tensorflow.org/xla/shapes).
`detail::Buffer` objects are reference-counted to simplify memory management.
Individual elements of a tensor are represented using the `Element` class, which
uses a discriminated union holding one of `APInt`, `APFloat` or
`pair<APFloat,APFloat>` for storage. The last one is used for storing elements
with complex types.
`Tensor` has the following APIs to interact with its individual elements:
- `Element Tensor::get(llvm::ArrayRef<int64_t> index)`: Extracts the
individual tensor element at multi-dimensional index `index` as an `Element`
object.
- `void Tensor::set(llvm::ArrayRef<int64_t> index, Element element)`:
Stores the `Element` object `element` into the tensor at multi-dimensional
index `index`.
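
As a rough illustration of the data model above, the following standalone sketch mimics it with standard-library types only. `ToyBuffer` and `ToyTensor` are hypothetical stand-ins for `detail::Buffer` and `Tensor`, elements are restricted to `float`, and `std::shared_ptr` plays the role of the reference counting. None of this is the actual interpreter code.

```C++
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for detail::Buffer: a shape plus a mutable,
// contiguous byte array holding the elements in major-to-minor order.
struct ToyBuffer {
  std::vector<int64_t> shape;
  std::vector<std::uint8_t> bytes;
};

// Hypothetical stand-in for Tensor, restricted to f32 elements.
class ToyTensor {
 public:
  explicit ToyTensor(std::vector<int64_t> shape)
      : buffer_(std::make_shared<ToyBuffer>()) {
    int64_t numElements = 1;
    for (int64_t dim : shape) {
      assert(dim >= 0 && "dimension sizes must be non-negative");
      numElements *= dim;
    }
    buffer_->shape = std::move(shape);
    buffer_->bytes.assign(
        static_cast<std::size_t>(numElements) * sizeof(float),
        std::uint8_t{0});
  }

  // Reads the element at a multi-dimensional index.
  float get(const std::vector<int64_t> &index) const {
    float value;
    std::memcpy(&value, buffer_->bytes.data() + byteOffset(index),
                sizeof(float));
    return value;
  }

  // Writes the element at a multi-dimensional index.
  void set(const std::vector<int64_t> &index, float value) {
    std::memcpy(buffer_->bytes.data() + byteOffset(index), &value,
                sizeof(float));
  }

 private:
  // Flattens a multi-dimensional index; the last dimension varies fastest.
  std::size_t byteOffset(const std::vector<int64_t> &index) const {
    const std::vector<int64_t> &shape = buffer_->shape;
    assert(index.size() == shape.size() && "index rank must match tensor rank");
    int64_t linear = 0;
    for (std::size_t d = 0; d < shape.size(); ++d) {
      assert(index[d] >= 0 && index[d] < shape[d] && "index out of bounds");
      linear = linear * shape[d] + index[d];
    }
    return static_cast<std::size_t>(linear) * sizeof(float);
  }

  // Copies of a ToyTensor share one buffer, which is freed with the last copy.
  std::shared_ptr<ToyBuffer> buffer_;
};
```

In this sketch, copying a `ToyTensor` copies only the shared pointer, so a write through one copy is visible through the other.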
## How the interpreter works
The entry function to the interpreter is
```C++
SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);
```
which does the following:
1. Tracks the SSA arguments of `func` and their associated runtime `Tensor`
values, provided in `args`, using a symbol table map, M.
2. For each op within `func`, in SSACFG order:
- Invokes `eval` on the op. For each SSA operand of the op, extracts its
runtime value from M and passes it as an argument to the `eval` invocation.
- Tracks the SSA result(s) of the op and the evaluated value in M.
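
The two steps above can be sketched without any MLIR dependency. In the following toy model, `ToyTensor`, `ToyOp`, `ToyFunc`, `evalOp` and `evalFunc` are all hypothetical names, SSA values are identified by strings, and only two element-wise ops are supported. It is meant to show the shape of the loop, not the real implementation.

```C++
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical, heavily simplified model: a tensor is a flat list of values
// and SSA values are identified by name.
using ToyTensor = std::vector<double>;

struct ToyOp {
  std::string name;                   // "add" or "multiply"
  std::vector<std::string> operands;  // SSA operands
  std::string result;                 // SSA result
};

struct ToyFunc {
  std::vector<std::string> args;     // SSA arguments of the function
  std::vector<ToyOp> ops;            // ops in execution order
  std::vector<std::string> results;  // SSA values returned by the function
};

// Op-level eval: implements the semantics of a single op.
ToyTensor evalOp(const ToyOp &op, const std::vector<ToyTensor> &operands) {
  assert(operands.size() == 2 && operands[0].size() == operands[1].size());
  assert(op.name == "add" || op.name == "multiply");
  ToyTensor result(operands[0].size());
  for (std::size_t i = 0; i < result.size(); ++i)
    result[i] = op.name == "add" ? operands[0][i] + operands[1][i]
                                 : operands[0][i] * operands[1][i];
  return result;
}

// Function-level eval, following steps (1) and (2).
std::vector<ToyTensor> evalFunc(const ToyFunc &func,
                                const std::vector<ToyTensor> &args) {
  assert(func.args.size() == args.size());
  // Step 1: bind the SSA arguments to their runtime values in the map M.
  std::map<std::string, ToyTensor> M;
  for (std::size_t i = 0; i < args.size(); ++i) M[func.args[i]] = args[i];

  // Step 2: evaluate each op, reading operands from M and recording results.
  for (const ToyOp &op : func.ops) {
    std::vector<ToyTensor> operands;
    for (const std::string &name : op.operands) operands.push_back(M.at(name));
    M[op.result] = evalOp(op, operands);
  }

  // Collect the values the function returns.
  std::vector<ToyTensor> results;
  for (const std::string &name : func.results) results.push_back(M.at(name));
  return results;
}
```

`M.at(name)` throws if an operand is used before it is defined, which is a cheap way for the sketch to surface a malformed program.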
The op-level `eval` mentioned in (2) is responsible for implementing the
execution semantics of the op. Following is an example for `stablehlo::AddOp`.
In the example, individual elements of the `lhs` and `rhs` tensors are pairwise
extracted as `Element` objects which are then added. The result of the addition,
an `Element` object, is stored in the final `result` tensor.
```C++
Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
Tensor result(op.getType());
for (auto it = result.index_begin(); it != result.index_end(); ++it)
result.set(*it, lhs.get(*it) + rhs.get(*it));
return result;
}
```
Overall, the design of the interpreter is optimized for readability of
implementations of `eval` functions for individual ops because it's meant to
serve as a reference implementation for StableHLO. For example, instead of
defining `eval` as a template function and parameterizing it with element types,
we encapsulate details about how different element types are handled in
`Element::operator+` etc., simplifying the implementation of `eval`.
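
To make this design choice concrete, here is a minimal sketch of how per-type dispatch can be hidden behind `operator+`. `ToyElement` and `addElements` are hypothetical names, and `std::variant` over plain C++ types stands in for the discriminated union over `APInt`, `APFloat` and `pair<APFloat,APFloat>`. The real `Element` class may differ in every detail.

```C++
#include <cassert>
#include <complex>
#include <cstdint>
#include <type_traits>
#include <variant>

// Hypothetical stand-in for Element, using plain C++ types so that the
// sketch stays self-contained.
class ToyElement {
 public:
  using Storage = std::variant<int64_t, double, std::complex<double>>;

  explicit ToyElement(Storage value) : value_(value) {}
  const Storage &value() const { return value_; }

  // All per-type dispatch lives here, so callers never branch on the type.
  ToyElement operator+(const ToyElement &other) const {
    assert(value_.index() == other.value_.index() &&
           "operands must have the same element type");
    return std::visit(
        [&](const auto &lhs) {
          using T = std::decay_t<decltype(lhs)>;
          return ToyElement(Storage(lhs + std::get<T>(other.value_)));
        },
        value_);
  }

 private:
  Storage value_;
};

// An op-level helper written once, with no template parameter for the
// element type.
ToyElement addElements(const ToyElement &lhs, const ToyElement &rhs) {
  return lhs + rhs;
}
```

The op-level code stays a single non-template function, and supporting a new element type only touches the element class.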
## Using the interpreter for constant folding
We can use the interpreter mechanism to fold operations with constant operand
values. The following code snippet sketches how folding `stablehlo::AddOp` with
floating-point typed operands could be implemented:
```C++
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
auto attrs = adaptor.getOperands();
DenseElementsAttr lhsData = dyn_cast<DenseElementsAttr>(attrs[0]);
DenseElementsAttr rhsData = dyn_cast<DenseElementsAttr>(attrs[1]);
if (!lhsData || !rhsData) return {};
auto lhs = Tensor(lhsData);
auto rhs = Tensor(rhsData);
auto result = eval(*this, lhs, rhs);
SmallVector<APFloat> values;
for (auto i = 0; i < result.getNumElements(); ++i) {
Element element = result.get(i);
values.push_back(cast<FloatAttr>(element.getValue()).getValue());
}
return DenseElementsAttr::get(result.getType(), values);
}
```
At the moment, we aren't actively working on integrating the interpreter into
constant folding because we aren't planning to implement a folder for StableHLO.
However, in the future, we are planning to leverage the interpreter for constant
folding in MHLO, at which point we'll improve the ergonomics of the code snippet
above (e.g. we could have a helper function which packs constant operands into
`Tensor` objects and unpacks `Tensor` results into `OpFoldResult`).
## Testing the StableHLO interpreter
The interpreter takes as inputs (A) a StableHLO program, and (B) data values to
be fed to the program, and generates output data values, which are matched
against the user-provided expected data values. The data values (B) are
hard-coded in the program itself using `stablehlo.constant` operations. The
interpreter evaluates the input program. The output(s) of the op under test
are verified using checks (e.g. `check.expect_eq`, `check.expect_almost_eq`), as
shown below. `check.expect_eq` and `check.expect_eq_const` check for bitwise
equality for any supported type, while `check.expect_almost_eq` and
`check.expect_almost_eq_const` check for near equality within a tolerance,
explained in testing guideline (G6), for floating-point and complex types.
```mlir
// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
%0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
%1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
%2 = stablehlo.add %0, %1 : tensor<2xui4>
check.expect_eq_const %2, dense<[15, 5]> : tensor<2xui4>
func.return
}
```
The test utility `stablehlo-translate --interpret`
([code](https://github.com/openxla/stablehlo/tree/main/stablehlo/tools/StablehloTranslateMain.cpp))
is responsible for parsing the program and interpreting each function,
including the operations that constitute it. We have a dedicated test suite for
each StableHLO op, consisting of several tests exercising various runtime
behaviors. The tests can be found in
[`stablehlo/tests/interpret`](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret).
### Testing guidelines
**(G1) Do we need to test for all the supported types for every op?**
We can use a combination of the following rules to decide:
1. While implementing an op, if there exists code in the corresponding `eval`
function to handle a particular type, then it is imperative to have test(s)
to cover that type. As an example, for the `add` op, there is dedicated code
to handle integer, boolean, floating-point, and complex types, and hence we
need one test for each category of types.
2. If a set of types is handled uniformly in the corresponding `eval` function,
then a single test for all those types should be sufficient. As an example,
for the `add` op, all the variants of integer types (`si4`, `ui4`, `si8`, `ui8`
and so on) are handled alike using `llvm::APInt` APIs, and hence we can skip
adding tests for each of those variants, and instead add a single
representative test. To avoid ambiguity in selecting the representative, we
should use the following guidelines:
- If all the uniformly handled types have the same primitive type
(i.e., if all are integer, or floating-point, or complex types), then
choose the one with the maximum bit width.
- If the uniformly handled types have a mix of primitive types, then
choose one with the following primitive type, in decreasing order of
preference: integer, floating-point, boolean, complex.
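
The selection rule above can be written down as a small comparison function. The sketch below is only illustrative: `Primitive`, `ElementType` and `pickRepresentative` are hypothetical names that do not exist in the codebase, and breaking ties by bit width in the mixed case is an assumption, since the guideline does not address it.

```C++
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical description of an element type, for illustration only.
// Enumerators are listed in decreasing order of preference.
enum class Primitive { Integer = 0, FloatingPoint = 1, Boolean = 2, Complex = 3 };

struct ElementType {
  Primitive primitive;
  unsigned bitWidth;
};

// Picks the representative among types handled uniformly by an `eval`.
ElementType pickRepresentative(const std::vector<ElementType> &types) {
  assert(!types.empty() && "need at least one candidate type");
  return *std::min_element(
      types.begin(), types.end(),
      [](const ElementType &a, const ElementType &b) {
        // Prefer the more preferred primitive type first.
        if (a.primitive != b.primitive) return a.primitive < b.primitive;
        // Within one primitive type, prefer the larger bit width.
        return a.bitWidth > b.bitWidth;
      });
}
```

For the integer variants of `add`, this picks the widest integer type among the candidates as the single representative to test.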
**(G2) How do we decide on the number of tests needed to cover an op's
behavior?**
The goal is to comprehensively cover the logic of the interpreter for the op
(i.e. all corner cases of the implementation) with a minimal number of tests.
Minimizing the number of tests is important for maintainability. The fewer tests
we have, the easier it is to review them and to make sure that they
comprehensively cover the op. As a result, we expect that most of the simpler
ops will end up having just one test. If comprehensive coverage is impractical
for some good reason, then it is fine to stop at 90% coverage or more. This will
be decided on a case-by-case basis during pull request review.
**(G3) How about adding tests for the interpreter infrastructure?**
The interpreter infrastructure is mostly straightforward and can be added to
our trust base. The only non-trivial part is how various types are packed into
and unpacked from the underlying interpreter storage. As discussed in (G1), for
each op we will test only those types which are handled differently. As a
result, it is possible that the packing/unpacking code corresponding to different
variants of integer/floating-point types might not get fully covered during
testing. To ensure full coverage, we can choose an op like `constant` that
supports all the StableHLO element types and write exhaustive tests.
**(G4) If the implementation of an op depends on other ops, should we write
tests for the latter?**
No. For example, the implementation of `batch_norm_grad` can be based on
`divide`, `subtract`, `multiply` and others. We should avoid testing the latter
ops while testing the former.
**(G5) Should we write tests to exercise the implementation-defined / undefined
behaviors?**
We should not write tests which exercise the implementation-defined or
undefined behaviors of the op. Tests exercising implementation-defined behaviors
demonstrate a local behavior of the interpreter which should not be
generalized. Tests exercising undefined behavior do not contribute towards
the understanding of the op's behavior.
**(G6) While writing tests for floating-point types, to what precision does the
expected result need to be specified in checks?**
For elementary operations (addition, subtraction, multiplication, division, and
square root), an implementation following the IEEE 754 specification is expected
to provide a rounded result within 0.5 ULP of the mathematically exact result.
Two such implementations can therefore disagree by at most 1 ULP, so we can
safely expect results of these operations to be within 1 ULP of the expected
value. However, this may not work for transcendental functions
(`sine`, `cosine`, etc.) for which the precision guarantees are
implementation-defined ([rationale](https://github.com/openxla/stablehlo/issues/96)).
The current implementation uses a "one-size-fits-all" tolerance value of 0.0001.
The following example demonstrates the above tolerance in action.
```mlir
func.func @check_tolerance() {
%0 = stablehlo.constant dense<0.2> : tensor<f32>
// The following check succeeds as %0 is almost equal to the provided
// constant modulo the tolerance, mentioned above.
check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>
// The following check fails as %0 is not bitwise equal to the provided
// constant.
check.expect_eq_const %0, dense<0.19999> : tensor<f32>
func.return
}
```
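
For readers who want to experiment with these notions outside the interpreter, here is a small standalone sketch. `almostEqual` and `ulpDistance` are hypothetical helpers, and treating the 0.0001 tolerance as an absolute difference is an assumption of this sketch rather than a statement about how `check.expect_almost_eq` is implemented.

```C++
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Near-equality with an absolute tolerance. The default mirrors the
// "one-size-fits-all" value mentioned above; the absolute comparison is an
// assumption of this sketch.
bool almostEqual(float actual, float expected, float tolerance = 0.0001f) {
  return std::fabs(actual - expected) <= tolerance;
}

// Distance in units in the last place (ULP) between two finite floats,
// computed by mapping their bit patterns onto a monotonic integer line.
int64_t ulpDistance(float a, float b) {
  assert(std::isfinite(a) && std::isfinite(b));
  auto toOrdered = [](float value) {
    std::uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));
    // Negative floats sort in reverse bit order, so mirror them below zero.
    return (bits & 0x80000000u) ? -static_cast<int64_t>(bits & 0x7fffffffu)
                                : static_cast<int64_t>(bits);
  };
  int64_t distance = toOrdered(a) - toOrdered(b);
  return distance < 0 ? -distance : distance;
}
```

With these helpers, `0.2f` and `0.19999f` are almost equal under the default tolerance even though they are many ULPs apart, which mirrors the two checks in the example above.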
This is just the first step in testing the numerical accuracy of StableHLO ops.
At the moment, this is an underspecified area of the StableHLO spec, and there is
ongoing work to figure it out ([#1156](https://github.com/openxla/stablehlo/issues/1156))
based on our experience using StableHLO in practice and on feedback from
stakeholders. As this work proceeds, we will update the infrastructure
accordingly.
**(G7) Anything about the coding style of the tests?**
1. Make sure to use the actual names of the inputs/outputs instead of defaulting
to SSA values (e.g. `%0`, `%1`).
1. Make sure the tests use the pretty-printed format, if it exists.
**(G8) Should we include the example already provided in the spec?**
Yes (for completeness of testing).