RFC: Float8E4M3FNUZ and Float8E5M2FNUZ

Status: Approved
Initial version: 3/21/2023
Last updated: 4/5/2023
Discussion thread: GitHub

Summary

Graphcore, AMD, and Qualcomm have proposed two new FP8 types, Float8E4M3FNUZ and Float8E5M2FNUZ[^1]. These types are implemented in commercially available hardware[^2] and have been added to the MLIR builtin types[^4] and LLVM APFloat[^5].

These two types appear similar to the existing types Float8E4M3FN and Float8E5M2[^3], but differ in important ways.

Details

Both Float8E4M3FNUZ and Float8E5M2FNUZ differ from typical floating point types in their support for NaN, infinities, negative zero, and exponent bias. The suffix "FNUZ" is derived from these differences: F for "finite" (no infinities), N for a special NaN encoding, and UZ for "unsigned zero" (no negative zero). I propose keeping this naming scheme in StableHLO, matching LLVM/MLIR.

Dropping infinities and negative zero frees up encodings, which makes an additional exponent value available. To keep a similar dynamic range to an IEEE-like FP8 type, the exponent bias is one greater than would be expected given the number of exponent bits (8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ).
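For example, the bit pattern 0bS0001000 (stored exponent 1, zero mantissa) decodes in Float8E4M3FNUZ as $(-1)^S \times 1.0 \times 2^{1-8} = (-1)^S \times 2^{-7}$, whereas the same pattern in Float8E4M3FN, with its bias of 7, decodes as $(-1)^S \times 2^{-6}$.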

Float8E4M3FNUZ

8-bit floating point with a 3-bit mantissa.

An 8-bit floating point type with 1 sign bit, a 4-bit exponent, and a 3-bit mantissa. This is not a standard type as defined by IEEE-754, but it follows similar conventions, with the exceptions that there are no infinity values, no negative zero, and only one NaN representation. This type has the following characteristics:

  • bit encoding: S1E4M3 - 0bSEEEEMMM
  • exponent bias: 8
  • infinities: Not supported
  • NaNs: supported, with the sign bit set to 1 and all exponent and mantissa bits set to 0 (0b10000000)
  • denormals: supported when all exponent bits are 0
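To make these rules concrete, here is a minimal Python sketch of the decoding logic (the function decode_e4m3fnuz is hypothetical, written for illustration rather than taken from any library):

def decode_e4m3fnuz(byte: int) -> float:
    """Decode a raw Float8E4M3FNUZ byte (0bSEEEEMMM) into a Python float."""
    if byte == 0b10000000:  # the single NaN: sign bit 1, all other bits 0
        return float("nan")
    sign = -1.0 if byte >> 7 else 1.0
    exponent = (byte >> 3) & 0b1111  # 4 exponent bits
    mantissa = byte & 0b111          # 3 mantissa bits
    if exponent == 0:                # denormal: no implicit leading 1
        return sign * (mantissa / 8.0) * 2.0 ** (1 - 8)
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 8)  # bias is 8

assert decode_e4m3fnuz(0b01111111) == 240.0     # max normal (see table below)
assert decode_e4m3fnuz(0b00001000) == 2.0**-7   # min normal
assert decode_e4m3fnuz(0b00000001) == 2.0**-10  # min subnormal: 0.125 * 2**-7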

Comparison of Float8E4M3FN and Float8E4M3FNUZ

| | Float8E4M3FN | Float8E4M3FNUZ |
|---|---|---|
| Bias | 7 | 8 |
| Min Normal Value | 0bS0001000 = $(-1)^S \times 1.0 \times 2^{-6}$ | 0bS0001000 = $(-1)^S \times 1.0 \times 2^{-7}$ |
| Max Normal Value | 0bS1111110 = $(-1)^S \times 1.75 \times 2^{8}$ = 448 | 0bS1111111 = $(-1)^S \times 1.875 \times 2^{7}$ = 240 |
| Min Subnormal Value | 0bS0000001 = $(-1)^S \times 0.125 \times 2^{-6}$ | 0bS0000001 = $(-1)^S \times 0.125 \times 2^{-7}$ |
| Max Subnormal Value | 0bS0000111 = $(-1)^S \times 0.875 \times 2^{-6}$ | 0bS0000111 = $(-1)^S \times 0.875 \times 2^{-7}$ |
| NaN | 0bS1111111 | 0b10000000 |
| Infinity | N/A | N/A |
| -0 | 0b10000000 | N/A |

Float8E5M2FNUZ

8-bit floating point with a 2-bit mantissa.

An 8-bit floating point type with 1 sign bit, a 5-bit exponent, and a 2-bit mantissa. This is not a standard type as defined by IEEE-754, but it follows similar conventions, with the exceptions that there are no infinity values, no negative zero, and only one NaN representation. This type has the following characteristics:

  • bit encoding: S1E5M2 - 0bSEEEEEMM
  • exponent bias: 16
  • infinities: Not supported
  • NaNs: supported, with the sign bit set to 1 and all exponent and mantissa bits set to 0 (0b10000000)
  • denormals: supported when all exponent bits are 0
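These characteristics can also be checked programmatically. The sketch below assumes the ml_dtypes Python package (the source of the float8_e5m2fnuz NumPy dtype used later in this RFC) and its NumPy-style finfo; exact attribute availability may vary by version:

import ml_dtypes

fi = ml_dtypes.finfo(ml_dtypes.float8_e5m2fnuz)
print(fi.max)                 # 57344.0: max normal value, per the table below
print(fi.tiny)                # 2**-15: min normal value
print(fi.smallest_subnormal)  # 2**-17, i.e. 0.25 * 2**-15: min subnormal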

Comparison of Float8E5M2 and Float8E5M2FNUZ

| | Float8E5M2 | Float8E5M2FNUZ |
|---|---|---|
| Bias | 15 | 16 |
| Min Normal Value | 0bS0000100 = $(-1)^S \times 1.0 \times 2^{-14}$ | 0bS0000100 = $(-1)^S \times 1.0 \times 2^{-15}$ |
| Max Normal Value | 0bS1111011 = $(-1)^S \times 1.75 \times 2^{15}$ = 57344 | 0bS1111111 = $(-1)^S \times 1.75 \times 2^{15}$ = 57344 |
| Min Subnormal Value | 0bS0000001 = $(-1)^S \times 0.25 \times 2^{-14}$ | 0bS0000001 = $(-1)^S \times 0.25 \times 2^{-15}$ |
| Max Subnormal Value | 0bS0000011 = $(-1)^S \times 0.75 \times 2^{-14}$ | 0bS0000011 = $(-1)^S \times 0.75 \times 2^{-15}$ |
| NaN | 0bS11111MM, where MM is non-zero | 0b10000000 |
| Infinity | 0bS1111100 | N/A |
| -0 | 0b10000000 | N/A |

Changes in StableHLO

I propose adding these types to StableHLO, similar to the previously introduced FP8 types (see the FP8 RFC), with some differences.

StableHLO Interpreter

To provide a reference implementation, I intend to add support for Float8E4M3FNUZ and Float8E5M2FNUZ to the StableHLO interpreter. This will be useful for testing other backends and validating new implementations. Support will be achieved in one of two ways:

  1. Map directly to the appropriate APFloat operation.
  2. Cast up to an appropriate higher-precision type, use that type's implementation, and cast back down, as sketched below.
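As an illustration of the second strategy, this hedged sketch (the helper name is mine, and it uses NumPy plus ml_dtypes rather than the interpreter's actual internals) emulates an element-wise add:

import ml_dtypes
import numpy as np

def emulated_add(x, y):
    # Cast up to float32, compute there, then cast back down; the final
    # astype rounds to the nearest representable Float8E5M2FNUZ value.
    result = x.astype(np.float32) + y.astype(np.float32)
    return result.astype(ml_dtypes.float8_e5m2fnuz)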

Float8E4M3FNUZ and Float8E5M2FNUZ Arithmetic

I intend for Float8E4M3FNUZ and Float8E5M2FNUZ to be types that support the appropriate arithmetic operations, like any other floating point type. Platforms without hardware support for these types may either throw an error and reject the program, or cast up to a supported higher-precision type, compute the answer, and cast back down.

This is a simple approach that aligns with user expectations for a floating point data type, and it is the approach taken for BFloat16. It also gives backends the freedom to exploit any hardware support.

Here's an example of a real JAX program (logging the MLIR) computing a simple dot product in Float8E5M2FNUZ. Note the answer is slightly "wrong", as expected at this precision: the exact dot product of the integers 0 through 15 with themselves is 1240, the rounded inputs yield 1252, and 1280 is the nearest representable Float8E5M2FNUZ value.

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.arange(16, dtype=jnp.float8_e5m2fnuz)
module @jit_iota {
  func.func public @main() -> tensor<16xf8E5M2FNUZ> {
    %0 = stablehlo.iota dim = 0 : tensor<16xf8E5M2FNUZ>
    return %0 : tensor<16xf8E5M2FNUZ>
  }
}
>>> x
Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 10, 12, 12, 12, 14, 16], dtype=float8_e5m2fnuz)
>>> x @ x
module @jit_matmul {
  func.func public @main(%arg0: tensor<16xf8E5M2FNUZ> {mhlo.sharding = ""}, %arg1: tensor<16xf8E5M2FNUZ> {mhlo.sharding = ""}) -> tensor<f8E5M2FNUZ> {
    %0 = "stablehlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<16xf8E5M2FNUZ>, tensor<16xf8E5M2FNUZ>) -> tensor<f8E5M2FNUZ>
    return %0 : tensor<f8E5M2FNUZ>
  }
}
Array(1280, dtype=float8_e5m2fnuz)
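The rounding is easy to reproduce outside JAX; this hedged sketch uses plain NumPy with ml_dtypes (variable names are mine):

import ml_dtypes
import numpy as np

# Round the inputs to FP8 as JAX did, then take the exact dot product in float32.
x = np.arange(16, dtype=np.float32).astype(ml_dtypes.float8_e5m2fnuz)
exact = np.dot(x.astype(np.float32), x.astype(np.float32))
print(exact)                                    # 1252.0, from the rounded inputs
print(exact.astype(ml_dtypes.float8_e5m2fnuz))  # 1280, the nearest FP8 value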

Scaling

At the StableHLO level, we won't impose any scaling requirements on users. Given that arithmetic operations can be supported directly, we leave the choice of scaling to the user. I believe this is the correct approach, given that FP8 applications are an active area of research.
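For illustration only, a user-level scaling scheme might look like the hedged sketch below (the helper names are mine, not a StableHLO or JAX API): tensors are scaled into the FP8 dynamic range before the downcast and un-scaled after the upcast.

import jax.numpy as jnp

FP8_MAX = 57344.0  # max finite Float8E5M2FNUZ value, per the table above

def to_fp8_scaled(x):
    # Per-tensor scale chosen so the largest magnitude maps to FP8_MAX.
    scale = jnp.max(jnp.abs(x)) / FP8_MAX
    return (x / scale).astype(jnp.float8_e5m2fnuz), scale

def from_fp8_scaled(x8, scale):
    return x8.astype(jnp.float32) * scale  # upcast, then undo the scaling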

Graphcore's IPU supports hardware scaling by biasing the hardware's interpretation of the exponent by ±32 at runtime[^2]. This is a backend-specific peephole optimisation that doesn't impact StableHLO. Other backends may similarly optimise their implementations of these types.

Testing

Building on the StableHLO interpreter, I intend to introduce tests for all operations that accept Float8E4M3FNUZ and Float8E5M2FNUZ inputs. At a minimum, this will mean adding cases to the interpret_X.mlir family of tests.

Not Included

Given any new data type, we could also consider derived types. One example would be hypothetical complex number types whose real and imaginary components are constructed from Float8E4M3FNUZ or Float8E5M2FNUZ. I don't exclude the possibility of adding these in the future, but they are not being proposed here.

[^1]: 8-bit Numerical Formats for Deep Neural Networks by Noune et al.
[^2]: Graphcore Tile Vertex ISA 1.3.1 IPU21
[^3]: FP8 Formats for Deep Learning by Micikevicius et al.
[^4]: Add Float8E5M2FNUZ and Float8E4M3FNUZ types to MLIR
[^5]: [llvm][APFloat] Add NaN-in-negative-zero formats by AMD and GraphCore