| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| Copyright 2022 The StableHLO Authors. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef STABLEHLO_DIALECT_STABLEHLO_OPS |
| #define STABLEHLO_DIALECT_STABLEHLO_OPS |
| |
| include "mlir/Dialect/Shape/IR/ShapeBase.td" |
| include "mlir/IR/OpBase.td" |
| include "mlir/IR/SymbolInterfaces.td" |
| include "mlir/Interfaces/InferTypeOpInterface.td" |
| include "mlir/Interfaces/SideEffectInterfaces.td" |
| include "mlir/IR/OpAsmInterface.td" |
| include "stablehlo/dialect/Base.td" |
| |
| // The StableHLO dialect record. ODS-generated attribute/type printers and |
| // parsers are disabled below, so those hooks are provided by hand in C++. |
| def StableHLO_Dialect : Dialect { |
|   let name = "stablehlo"; |
|   let cppNamespace = "::mlir::stablehlo"; |
| |
|   let description = [{ |
|     StableHLO is an operation set that expresses ML computations. It was |
|     originally bootstrapped from the MHLO dialect and enhances it with |
|     additional functionality, including serialization and versioning, to be |
|     used as a portability layer between ML frameworks and ML compilers. |
|   }]; |
| |
|   // Opt out of the generated (de)serialization for attributes and types. |
|   let useDefaultAttributePrinterParser = 0; |
|   let useDefaultTypePrinterParser = 0; |
| } |
| |
| // Root class for all StableHLO ops. Subclasses in this file that override |
| // `extraClassDeclaration` usually append to `commonClassDeclaration` so the |
| // relaxed return-type compatibility hook stays in place. |
| class StableHLO_Op<string mnemonic, list<Trait> traits = []> |
|     : Op<StableHLO_Dialect, mnemonic, traits> { |
|   // C++ members shared by every StableHLO op class. |
|   string commonClassDeclaration = [{ |
|     // Relax the strict default implementation with one that allows |
|     // for StableHLO-specific differences. |
|     static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) { |
|       return mlir::hlo::isCompatibleForHloTypeInference(lhs, rhs); |
|     } |
|   }]; |
| |
|   let extraClassDeclaration = commonClassDeclaration; |
| } |
| |
| include "stablehlo/dialect/StablehloEnums.td" |
| include "stablehlo/dialect/StablehloAttrs.td" |
| include "stablehlo/dialect/StablehloTypes.td" |
| |
| // Base for ops that declare InferShapedTypeOpInterface::reifyReturnTypeShapes; |
| // the method body is supplied out of line in C++. |
| class StableHLO_ShapedInterfaceOp<string mnemonic, list<Trait> traits> |
|     : StableHLO_Op<mnemonic, |
|                    traits # [DeclareOpInterfaceMethods< |
|                                  InferShapedTypeOpInterface, |
|                                  ["reifyReturnTypeShapes"]>]> {} |
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO nullary op definitions. |
| //===----------------------------------------------------------------------===// |
| |
| // Materializes the `value` attribute as a tensor. Parsing/printing, folding |
| // and type inference are declared here and implemented in C++. |
| def StableHLO_ConstantOp : StableHLO_Op<"constant", [ |
|     ConstantLike, |
|     Pure, |
|     DeclareOpInterfaceMethods<InferTypeOpInterface>, |
|     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]> |
| ]> { |
|   let summary = "Constant operation"; |
|   let description = [{ |
|     Produces an `output` tensor from a constant `value`. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant |
| |
|     Example: |
|     ```mlir |
|     %output = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> |
|     ``` |
|   }]; |
| |
|   let arguments = (ins ElementsAttr:$value); |
|   let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor:$output); |
| |
|   let builders = [OpBuilder<(ins "Attribute":$value)>]; |
| |
|   let hasCustomAssemblyFormat = 1; |
|   let hasFolder = 1; |
| |
|   // Replaces (instead of extending) commonClassDeclaration: this op only |
|   // declares isCompatibleReturnTypes and defines it out of line. |
|   let extraClassDeclaration = [{ |
|     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); |
|   }]; |
| } |
| |
| // Fills `output` with increasing values along `iota_dimension`. |
| def StableHLO_IotaOp : StableHLO_Op<"iota", [Pure]> { |
|   let summary = "Iota operation"; |
|   let description = [{ |
|     Fills an `output` tensor with values in increasing order starting from zero |
|     along the `iota_dimension` dimension. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota |
| |
|     Example: |
|     ```mlir |
|     %output = stablehlo.iota dim = 0 : tensor<4x5xi32> |
|     ``` |
|   }]; |
| |
|   let arguments = (ins |
|     I64Attr:$iota_dimension |
|   ); |
|   let results = (outs |
|     HLO_StaticShapeIntFpComplexOrQuantizedTensor:$output |
|   ); |
| |
|   // Textual form: `dim = <iota_dimension> : <result type>`. |
|   let assemblyFormat = "`dim` `=` $iota_dimension attr-dict `:` type($output)"; |
| |
|   let hasVerifier = 1; |
| } |
| |
| // Iota whose result shape is given by the `output_shape` operand instead of |
| // being fixed in the result type. |
| def StableHLO_DynamicIotaOp : StableHLO_ShapedInterfaceOp<"dynamic_iota", [ |
|     ConditionallySpeculatable, |
|     NoMemoryEffect |
| ]> { |
|   let summary = "DynamicIota operation"; |
|   let description = [{ |
|     This operation is a work in progress, so it is not yet included in |
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/8. |
| |
|     Informally, this operation does the same thing as IotaOp except that the |
|     result shape is specified dynamically via `output_shape`: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota |
| |
|     Example: |
|     ```mlir |
|     %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<1xindex>) -> tensor<4xi32> |
|     ``` |
|   }]; |
| |
|   let arguments = (ins |
|     HLO_StaticDimensionTensor:$output_shape, |
|     I64Attr:$iota_dimension |
|   ); |
|   let results = (outs HLO_Tensor:$result); |
| |
|   let hasVerifier = 1; |
| |
|   let assemblyFormat = [{ |
|     $output_shape `,` `dim` `=` $iota_dimension attr-dict `:` functional-type(operands, results) |
|   }]; |
| |
|   // Replaces commonClassDeclaration; only the speculatability hook (required |
|   // by the ConditionallySpeculatable trait above) is declared. |
|   let extraClassDeclaration = [{ |
|     /// Interface method for ConditionallySpeculatable. |
|     mlir::Speculation::Speculatability getSpeculatability(); |
|   }]; |
| } |
| |
| // Produces a fresh token; takes no operands and infers its result type. |
| def StableHLO_CreateTokenOp : StableHLO_Op<"create_token", [ |
|     Pure, |
|     DeclareOpInterfaceMethods<InferTypeOpInterface> |
| ]> { |
|   let summary = "CreateToken operation"; |
|   let description = [{ |
|     This operation is on its way out of StableHLO, so it is not included in |
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/3. |
| |
|     Informally, this operation does the same thing as AfterAllOp with 0 inputs: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all |
| |
|     Example: |
|     ```mlir |
|     %output = stablehlo.create_token : !stablehlo.token |
|     ``` |
|   }]; |
| |
|   let results = (outs |
|     HLO_Token:$output |
|   ); |
| |
|   let assemblyFormat = "attr-dict `:` type(results)"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO unary elementwise op definitions. |
| //===----------------------------------------------------------------------===// |
| // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions |
| |
| // Common base for unary element-wise ops: one `operand`, one `result` of the |
| // same shape. ResultType defaults to OperandType; ops whose result element |
| // type differs (abs, real, imag) pass it explicitly. |
| class StableHLO_UnaryElementwiseOp<string mnemonic, list<Trait> traits, |
|                                    Type OperandType, |
|                                    Type ResultType = OperandType> |
|     : StableHLO_Op<mnemonic, traits # [ |
|         Elementwise, |
|         InferShapedTypeOpInterface, |
|         SameOperandsAndResultShape, |
|         HLO_SpeculatableIfStaticDimInOutputIsStaticInInput, |
|         NoMemoryEffect |
|       ]> { |
|   let arguments = (ins OperandType:$operand); |
|   let results = (outs ResultType:$result); |
| |
|   let extraClassDeclaration = commonClassDeclaration # [{ |
|     LogicalResult reifyReturnTypeShapes( |
|         OpBuilder& builder, ValueRange operands, |
|         SmallVectorImpl<Value>& reifiedReturnShapes) { |
|       // The result shape is derived from the single operand. |
|       return ::mlir::hlo::deriveShapeFromOperand( |
|           &builder, getOperation(), operands.front(), &reifiedReturnShapes); |
|     } |
|   }]; |
| |
|   let assemblyFormat = [{ |
|     $operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result)) |
|   }]; |
| } |
| |
| // Abs maps complex operands to real results, so operand and result element |
| // types may differ; the result type is inferred via InferTypeOpInterface. |
| def StableHLO_AbsOp : StableHLO_UnaryElementwiseOp< |
|     "abs", |
|     [DeclareOpInterfaceMethods<InferTypeOpInterface>], |
|     RankedTensorOf<[HLO_SInt, HLO_Float, HLO_Complex, HLO_QuantizedInt] /* abs_i1 */>, |
|     RankedTensorOf<[HLO_SInt, HLO_Float, HLO_QuantizedInt]>> { |
|   let summary = "Abs operation"; |
|   let description = [{ |
|     Performs element-wise abs operation on `operand` tensor and produces a |
|     `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#abs |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.abs %operand : tensor<3xi32> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise cube root; operand and result types must be compatible. |
| def StableHLO_CbrtOp : StableHLO_UnaryElementwiseOp< |
|     "cbrt", |
|     [HLO_CompatibleOperandsAndResultType /*cbrt_c1*/], |
|     HLO_FpComplexOrQuantizedIntTensor /*cbrt_i1*/> { /*cbrt_c1*/ |
|   let summary = "Cbrt operation"; |
|   let description = [{ |
|     Performs element-wise cubic root operation on `operand` tensor and produces |
|     a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cbrt |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.cbrt %operand : tensor<4xf64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise ceiling; operand and result types must be compatible. |
| def StableHLO_CeilOp : StableHLO_UnaryElementwiseOp< |
|     "ceil", |
|     [HLO_CompatibleOperandsAndResultType], |
|     HLO_FpOrQuantizedIntTensor> { |
|   let summary = "Ceil operation"; |
|   let description = [{ |
|     Performs element-wise ceil of `operand` tensor and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#ceil |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.ceil %operand : tensor<5xf32> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise element-type conversion over non-quantized tensors; shapes |
| // must match. |
| def StableHLO_ConvertOp : StableHLO_UnaryElementwiseOp< |
|     "convert", |
|     [SameOperandsAndResultShape /*convert_c1*/], |
|     HLO_NonQuantizedTensor> { /*convert_i1*/ |
|   let summary = "Convert operation"; |
|   let description = [{ |
|     Performs an element-wise conversion from one element type to another on |
|     `operand` tensor and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convert |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.convert %operand : (tensor<3xi64>) -> tensor<3xcomplex<f64>> |
|     ``` |
|   }]; |
| |
|   // Builder taking the operand and only the desired result element type. |
|   let builders = [ |
|     OpBuilder<(ins "Value":$operand, "Type":$result_element_ty)> |
|   ]; |
| } |
| |
| // Element-wise leading-zero-bit count over integer tensors. |
| def StableHLO_ClzOp : StableHLO_UnaryElementwiseOp< |
|     "count_leading_zeros", |
|     [HLO_CompatibleOperandsAndResultType /*count_leading_zeros_c1*/], |
|     HLO_IntTensor /*count_leading_zeros_i1*/> { /*count_leading_zeros_c1*/ |
|   let summary = "Clz operation"; |
|   let description = [{ |
|     Performs element-wise count of the number of leading zero bits in the |
|     `operand` tensor and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#count_leading_zeros |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.count_leading_zeros %operand : tensor<2x2xi64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise cosine; operand and result types must be compatible. |
| def StableHLO_CosineOp : StableHLO_UnaryElementwiseOp< |
|     "cosine", |
|     [HLO_CompatibleOperandsAndResultType], |
|     HLO_FpComplexOrQuantizedIntTensor> { |
|   let summary = "Cosine operation"; |
|   let description = [{ |
|     Performs element-wise cosine operation on `operand` tensor and produces a |
|     `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cosine |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.cosine %operand : tensor<2xf32> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise natural exponential. |
| def StableHLO_ExpOp : StableHLO_UnaryElementwiseOp< |
|     "exponential", |
|     [HLO_CompatibleOperandsAndResultType /*exponential_c1*/], |
|     HLO_FpComplexOrQuantizedIntTensor /*exponential_i1*/> { |
|   let summary = "Exp operation"; |
|   let description = [{ |
|     Performs element-wise exponential operation on `operand` tensor and produces |
|     a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.exponential %operand : tensor<2x2xf64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise exp(x) - 1. |
| def StableHLO_Expm1Op : StableHLO_UnaryElementwiseOp< |
|     "exponential_minus_one", |
|     [HLO_CompatibleOperandsAndResultType /*exponential_minus_one_c1*/], |
|     HLO_FpComplexOrQuantizedIntTensor /*exponential_minus_one_i1*/> { /*exponential_minus_one_c1*/ |
|   let summary = "Expm1 operation"; |
|   let description = [{ |
|     Performs element-wise exponential minus one operation on `operand` tensor |
|     and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential_minus_one |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.exponential_minus_one %operand : tensor<2xf64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise floor; operand and result types must be compatible. |
| def StableHLO_FloorOp : StableHLO_UnaryElementwiseOp< |
|     "floor", |
|     [HLO_CompatibleOperandsAndResultType], |
|     HLO_FpOrQuantizedIntTensor> { |
|   let summary = "Floor operation"; |
|   let description = [{ |
|     Performs element-wise floor of `operand` tensor and produces a `result` |
|     tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#floor |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.floor %operand : tensor<2xf32> |
|     ``` |
|   }]; |
| } |
| |
| // Extracts imaginary parts; the floating-point result type is inferred. |
| def StableHLO_ImagOp : StableHLO_UnaryElementwiseOp< |
|     "imag", |
|     [DeclareOpInterfaceMethods<InferTypeOpInterface>], |
|     HLO_FpOrComplexTensor /*imag_i1*/, |
|     HLO_FpTensor> {/*imag_c1*/ |
|   let summary = "Imag operation"; |
|   let description = [{ |
|     Extracts the imaginary part, element-wise, from the `operand` and produces a |
|     `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#imag |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.imag %operand : (tensor<2xcomplex<f32>>) -> tensor<2xf32> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise finiteness test producing a predicate tensor. |
| def StableHLO_IsFiniteOp : StableHLO_UnaryElementwiseOp< |
|     "is_finite", |
|     [DeclareOpInterfaceMethods<InferTypeOpInterface>], |
|     HLO_FpOrQuantizedIntTensor> { |
|   /*is_finite_c1*/ |
|   let summary = "IsFinite operation"; |
|   let description = [{ |
|     Performs element-wise check whether the value in `x` is finite (i.e. is |
|     neither +Inf, -Inf, nor NaN) and produces a `y` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#is_finite |
| |
|     Example: |
|     ```mlir |
|     %y = stablehlo.is_finite %x : (tensor<7xf64>) -> tensor<7xi1> |
|     ``` |
|   }]; |
| |
|   // Overrides the inherited `operand`/`result` with the names `x` and `y`. |
|   let arguments = (ins HLO_FpOrQuantizedIntTensor:$x); /*is_finite_i1*/ |
|   let results = (outs HLO_PredTensor:$y); |
| |
|   // Result element type differs from the operand's, so print a functional type. |
|   let assemblyFormat = [{ |
|     operands attr-dict `:` functional-type(operands, results) |
|   }]; |
| } |
| |
| // Element-wise logarithm. |
| def StableHLO_LogOp : StableHLO_UnaryElementwiseOp< |
|     "log", |
|     [HLO_CompatibleOperandsAndResultType /*log_c1*/], |
|     HLO_FpComplexOrQuantizedIntTensor /*log_i1*/> { |
|   let summary = "Log operation"; |
|   let description = [{ |
|     Performs element-wise logarithm operation on `operand` tensor and produces a |
|     `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.log %operand : tensor<2x2xf64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise log(1 + x). |
| def StableHLO_Log1pOp : StableHLO_UnaryElementwiseOp< |
|     "log_plus_one", |
|     [HLO_CompatibleOperandsAndResultType /*log_plus_one_c1*/], |
|     HLO_FpComplexOrQuantizedIntTensor /*log_plus_one_i1*/> { /*log_plus_one_c1*/ |
|   let summary = "Log1p operation"; |
|   let description = [{ |
|     Performs element-wise logarithm plus one operation on `operand` tensor and |
|     produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log_plus_one |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.log_plus_one %operand : tensor<5xf64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise logistic function. |
| def StableHLO_LogisticOp : StableHLO_UnaryElementwiseOp< |
|     "logistic", |
|     [HLO_CompatibleOperandsAndResultType /*logistic_c1*/], |
|     HLO_FpComplexOrQuantizedIntTensor /*logistic_i1*/> { /*logistic_c1*/ |
|   let summary = "Logistic operation"; |
|   let description = [{ |
|     Performs element-wise logistic operation on `operand` tensor and produces a |
|     `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#logistic |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.logistic %operand : tensor<2x2xf64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise NOT over boolean (pred) or integer tensors. |
| def StableHLO_NotOp : StableHLO_UnaryElementwiseOp< |
|     "not", |
|     [HLO_CompatibleOperandsAndResultType], |
|     HLO_PredOrIntTensor> { |
|   let summary = "Not operation"; |
|   let description = [{ |
|     Performs element-wise NOT of tensor `operand` and produces a `result` |
|     tensor. Depending on the element type, this is a logical NOT (for |
|     booleans) or a bitwise NOT (for integers). |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.not %operand : tensor<5x3x1xi1> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise negation; operand and result types must be compatible. |
| def StableHLO_NegOp : StableHLO_UnaryElementwiseOp< |
|     "negate", |
|     [HLO_CompatibleOperandsAndResultType], |
|     HLO_IntFpOrComplexOrQuantizedIntTensor> { |
|   let summary = "Neg operation"; |
|   let description = [{ |
|     Performs element-wise negation of `operand` tensor and produces a `result` |
|     tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#negate |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.negate %operand : tensor<2x3xi32> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise set-bit count over integer tensors. |
| def StableHLO_PopulationCountOp : StableHLO_UnaryElementwiseOp< |
|     "popcnt", |
|     [HLO_CompatibleOperandsAndResultType /*popcnt_c1*/], |
|     HLO_IntTensor /*popcnt_i1*/> { /*popcnt_c1*/ |
|   let summary = "PopulationCount operation"; |
|   let description = [{ |
|     Performs element-wise count of the number of bits set in the `operand` |
|     tensor and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.popcnt %operand : tensor<4xi64> |
|     ``` |
|   }]; |
| } |
| |
| // Extracts real parts; the floating-point result type is inferred. |
| def StableHLO_RealOp : StableHLO_UnaryElementwiseOp< |
|     "real", |
|     [DeclareOpInterfaceMethods<InferTypeOpInterface>], |
|     HLO_FpOrComplexTensor /*real_i1*/, |
|     HLO_FpTensor> {/*real_c1*/ |
|   let summary = "Real operation"; |
|   let description = [{ |
|     Extracts the real part, element-wise, from the `operand` and produces a |
|     `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#real |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.real %operand : (tensor<2xcomplex<f32>>) -> tensor<2xf32> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise round-to-nearest, ties away from zero. |
| def StableHLO_RoundOp : StableHLO_UnaryElementwiseOp< |
|     "round_nearest_afz", |
|     [HLO_CompatibleOperandsAndResultType /*round_nearest_afz_c1*/], |
|     HLO_FpOrQuantizedIntTensor /*round_nearest_afz_i1*/> { /*round_nearest_afz_c1*/ |
|   let summary = "Round operation"; |
|   let description = [{ |
|     Performs element-wise rounding towards the nearest integer, breaking ties |
|     away from zero, on the `operand` tensor and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_afz |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.round_nearest_afz %operand : tensor<5xf64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise round-to-nearest, ties to even. |
| def StableHLO_RoundNearestEvenOp : StableHLO_UnaryElementwiseOp< |
|     "round_nearest_even", |
|     [HLO_CompatibleOperandsAndResultType /*round_nearest_even_c1*/], |
|     HLO_FpOrQuantizedIntTensor /*round_nearest_even_i1*/> { /*round_nearest_even_c1*/ |
|   let summary = "RoundNearestEven operation"; |
|   let description = [{ |
|     Performs element-wise rounding towards the nearest integer, breaking ties |
|     towards the even integer, on the `operand` tensor and produces a `result` |
|     tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.round_nearest_even %operand : tensor<5xf64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise reciprocal square root. |
| def StableHLO_RsqrtOp : StableHLO_UnaryElementwiseOp< |
|     "rsqrt", |
|     [HLO_CompatibleOperandsAndResultType /* rsqrt_c1 */], |
|     HLO_FpComplexOrQuantizedIntTensor /* rsqrt_i1 */> { |
|   let summary = "Rsqrt operation"; |
|   let description = [{ |
|     Performs element-wise reciprocal square root operation on `operand` tensor |
|     and produces a `result` tensor, implementing the `rSqrt` operation from the |
|     IEEE-754 specification. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rsqrt |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.rsqrt %operand : tensor<2x2xf32> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise sign over signed-integer, float, complex or quantized tensors. |
| def StableHLO_SignOp : StableHLO_UnaryElementwiseOp< |
|     "sign", |
|     [HLO_CompatibleOperandsAndResultType /*sign_c1*/], |
|     RankedTensorOf<[HLO_SInt, HLO_Float, HLO_Complex, HLO_QuantizedInt]> /*sign_i1*/> { /*sign_c1*/ |
|   let summary = "Sign operation"; |
|   let description = [{ |
|     Returns the sign of the `operand` element-wise and produces a `result` |
|     tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.sign %operand : tensor<5xf64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise sine; operand and result types must be compatible. |
| def StableHLO_SineOp : StableHLO_UnaryElementwiseOp< |
|     "sine", |
|     [HLO_CompatibleOperandsAndResultType], |
|     HLO_FpComplexOrQuantizedIntTensor> { |
|   let summary = "Sine operation"; |
|   let description = [{ |
|     Performs element-wise sine operation on `operand` tensor and produces a |
|     `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sine |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.sine %operand : tensor<2xf32> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise square root. |
| def StableHLO_SqrtOp : StableHLO_UnaryElementwiseOp< |
|     "sqrt", |
|     [HLO_CompatibleOperandsAndResultType /* sqrt_c1 */], |
|     HLO_FpComplexOrQuantizedIntTensor /* sqrt_i1 */> { |
|   let summary = "Sqrt operation"; |
|   let description = [{ |
|     Performs element-wise square root operation on `operand` tensor and produces |
|     a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sqrt |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.sqrt %operand : tensor<2x2xf32> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise hyperbolic tangent. |
| def StableHLO_TanhOp : StableHLO_UnaryElementwiseOp< |
|     "tanh", [HLO_CompatibleOperandsAndResultType], |
|     HLO_FpComplexOrQuantizedIntTensor> { |
|   let summary = "Tanh operation"; |
|   let description = [{ |
|     Performs element-wise hyperbolic tangent operation on `operand` tensor and |
|     produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tanh |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.tanh %operand : tensor<2xf32> |
|     ``` |
|   }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO binary elementwise op definitions. |
| //===----------------------------------------------------------------------===// |
| // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations |
| |
| // Common base for binary element-wise ops with operands `lhs`/`rhs` of |
| // OperandType (HLO_Tensor by default) and a `result` of ResultType (defaults |
| // to OperandType). Subclasses that extend extraClassDeclaration build on |
| // binaryElementwiseOpCommonClassDeclaration. |
| class StableHLO_BinaryElementwiseOp<string mnemonic, list<Trait> traits, |
|                                     Type OperandType = HLO_Tensor, |
|                                     Type ResultType = OperandType> |
|     : StableHLO_Op<mnemonic, traits # [ |
|         InferShapedTypeOpInterface, |
|         SameOperandsAndResultShape, |
|         Elementwise, |
|         HLO_SpeculatableIfAllInputsStatic, |
|         NoMemoryEffect |
|       ]> { |
|   let arguments = (ins |
|     OperandType:$lhs, |
|     OperandType:$rhs |
|   ); |
|   let results = (outs ResultType:$result); |
| |
|   string binaryElementwiseOpCommonClassDeclaration = commonClassDeclaration # [{ |
|     LogicalResult reifyReturnTypeShapes( |
|         OpBuilder& builder, ValueRange operands, |
|         SmallVectorImpl<Value>& reifiedReturnShapes) { |
|       // All shapes agree (SameOperandsAndResultShape); use the first operand. |
|       return ::mlir::hlo::deriveShapeFromOperand( |
|           &builder, getOperation(), operands.front(), &reifiedReturnShapes); |
|     } |
|   }]; |
|   let extraClassDeclaration = binaryElementwiseOpCommonClassDeclaration; |
| |
|   let assemblyFormat = [{ |
|     $lhs `,` $rhs attr-dict |
|       `:` custom<SameOperandsAndResultType>(type($lhs), type($rhs), type($result)) |
|   }]; |
| } |
| |
| // Element-wise addition. Operands may be per-axis quantized tensors, and the |
| // result type is inferred from the operand types. |
| def StableHLO_AddOp : StableHLO_BinaryElementwiseOp< |
|     "add", |
|     [ |
|       HLO_Commutative, |
|       InferTypeOpInterface, |
|       DeclareOpInterfaceMethods<InferShapedTypeOpInterface, |
|                                 ["inferReturnTypeComponents"]> |
|     ], |
|     HLO_TensorOrPerAxisQuantizedTensor> { |
|   let summary = "Add operation"; |
|   let description = [{ |
|     Performs element-wise addition of two tensors `lhs` and `rhs` and produces a |
|     `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#add |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.add %lhs, %rhs : tensor<2x2xi32> |
|     ``` |
|   }]; |
| |
|   let extraClassDeclaration = binaryElementwiseOpCommonClassDeclaration # [{ |
|     static LogicalResult inferReturnTypes( |
|         MLIRContext * /*context*/, std::optional<Location> location, |
|         ValueRange operands, DictionaryAttr /*attributes*/, |
|         OpaqueProperties /*properties*/, RegionRange /*regions*/, |
|         SmallVectorImpl<Type> &inferredReturnTypes) { |
|       if (operands.empty()) { |
|         return emitOptionalError( |
|             location, |
|             "Expected non-empty operands for AddOp::inferReturnTypes"); |
|       } |
| |
|       // Delegate to the shared HLO helper over all operand types. |
|       auto mostSpecific = |
|           mlir::hlo::inferMostSpecificType(location, operands.getTypes()); |
|       if (failed(mostSpecific)) |
|         return failure(); |
| |
|       inferredReturnTypes.push_back(*mostSpecific); |
|       return success(); |
|     } |
|   }]; |
| |
|   let hasVerifier = 1; |
| } |
| |
| // Element-wise two-argument arctangent. |
| def StableHLO_Atan2Op : StableHLO_BinaryElementwiseOp<"atan2", |
|     [HLO_CompatibleOperandsAndResultType /*atan2_c1*/], |
|     HLO_FpComplexOrQuantizedIntTensor /*atan2_i1, atan2_i2*/> { /*atan2_c1*/ |
|   let summary = "Atan2 operation"; |
|   let description = [{ |
|     Performs element-wise atan2 operation on `lhs` and `rhs` tensors and |
|     produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#atan2 |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.atan2 %lhs, %rhs : tensor<3xf64> |
|     ``` |
|   }]; |
| } |
| |
| // Builds complex values from real (`lhs`) and imaginary (`rhs`) parts. |
| // HLO_SpeculatableIfAllInputsStatic, NoMemoryEffect, Elementwise and |
| // SameOperandsAndResultShape are already appended by |
| // StableHLO_BinaryElementwiseOp (mlir-tblgen ignores duplicates); |
| // SameOperandsAndResultShape is listed again only to carry its spec tag. |
| def StableHLO_ComplexOp: StableHLO_BinaryElementwiseOp<"complex", |
|     [SameOperandsElementType /*complex_c1*/, |
|      SameOperandsAndResultShape /*complex_c2*/, |
|      DeclareOpInterfaceMethods<InferTypeOpInterface> /*complex_c3*/]> { |
|   let summary = "Complex operation"; |
|   let description = [{ |
|     Performs element-wise conversion to a complex value from a pair of real and |
|     imaginary values, `lhs` and `rhs`, and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#complex |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.complex %lhs, %rhs : tensor<2xcomplex<f64>> |
|     ``` |
|   }]; |
| |
|   // Overrides the base class's default HLO_Tensor operands/result. |
|   let arguments = (ins |
|     HLO_Fp32Or64Tensor:$lhs /*complex_i1*/, |
|     HLO_Fp32Or64Tensor:$rhs /*complex_i2*/ |
|   ); |
|   let results = (outs |
|     HLO_ComplexTensor:$result |
|   ); |
| |
|   let assemblyFormat = [{ |
|     operands attr-dict |
|       `:` custom<ComplexOpType>(type($lhs), type($rhs), type($result)) |
|   }]; |
| } |
| |
| // Element-wise division of `lhs` by `rhs`. |
| def StableHLO_DivOp : StableHLO_BinaryElementwiseOp< |
|     "divide", |
|     [HLO_CompatibleOperandsAndResultType /* div_c1 */], |
|     HLO_IntFpOrComplexOrQuantizedIntTensor /* div_i1, div_i2 */> { |
|   let summary = "Div operation"; |
|   let description = [{ |
|     Performs element-wise division of dividend `lhs` and divisor `rhs` tensors |
|     and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#divide |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.divide %lhs, %rhs : tensor<4xf32> |
|     ``` |
|   }]; |
| } |
| |
| // Commutative element-wise maximum over the default HLO_Tensor operand type. |
| def StableHLO_MaxOp : StableHLO_BinaryElementwiseOp<"maximum", [ |
|     HLO_Commutative, |
|     HLO_CompatibleOperandsAndResultType |
| ]> { |
|   let summary = "Max operation"; |
|   let description = [{ |
|     Performs element-wise max operation on tensors `lhs` and `rhs` and produces |
|     a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#maximum |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.maximum %lhs, %rhs : tensor<4xf32> |
|     ``` |
|   }]; |
| } |
| |
| // Commutative element-wise minimum over the default HLO_Tensor operand type. |
| def StableHLO_MinOp : StableHLO_BinaryElementwiseOp<"minimum", [ |
|     HLO_Commutative, |
|     HLO_CompatibleOperandsAndResultType |
| ]> { |
|   let summary = "Min operation"; |
|   let description = [{ |
|     Performs element-wise min operation on tensors `lhs` and `rhs` and produces a |
|     `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#minimum |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.minimum %lhs, %rhs : tensor<4xf32> |
|     ``` |
|   }]; |
| } |
| |
| // Commutative element-wise product over the default HLO_Tensor operand type. |
| def StableHLO_MulOp : StableHLO_BinaryElementwiseOp<"multiply", [ |
|     HLO_Commutative, |
|     HLO_CompatibleOperandsAndResultType |
| ]> { |
|   let summary = "Mul operation"; |
|   let description = [{ |
|     Performs element-wise product of two tensors `lhs` and `rhs` and produces a |
|     `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#multiply |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.multiply %lhs, %rhs : tensor<2xi32> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise exponentiation `lhs` ^ `rhs`. |
| def StableHLO_PowOp : StableHLO_BinaryElementwiseOp< |
|     "power", |
|     [HLO_CompatibleOperandsAndResultType /* pow_c1 */], |
|     HLO_IntFpOrComplexOrQuantizedIntTensor /* pow_i1, pow_i2 */> { |
|   let summary = "Power operation"; |
|   let description = [{ |
|     Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and |
|     produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.power %lhs, %rhs : tensor<6xf64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise remainder of `lhs` divided by `rhs`. |
| def StableHLO_RemOp : StableHLO_BinaryElementwiseOp< |
|     "remainder", |
|     [HLO_CompatibleOperandsAndResultType /*remainder_c1*/], |
|     HLO_IntFpOrComplexOrQuantizedIntTensor /*remainder_i1, remainder_i2*/> { |
|   let summary = "Rem operation"; |
|   let description = [{ |
|     Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors |
|     and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#remainder |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.remainder %lhs, %rhs : tensor<4xi64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise left shift of integer `lhs` by `rhs` bits. |
| def StableHLO_ShiftLeftOp : StableHLO_BinaryElementwiseOp< |
|     "shift_left", |
|     [HLO_CompatibleOperandsAndResultType /*shift_left_c1*/], |
|     HLO_IntTensor /*shift_left_i1, shift_left_i2*/> { /*shift_left_c1*/ |
|   let summary = "ShiftLeft operation"; |
|   let description = [{ |
|     Performs element-wise left-shift operation on the `lhs` tensor by `rhs` |
|     number of bits and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_left |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.shift_left %lhs, %rhs : tensor<3xi64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise arithmetic right shift of integer `lhs` by `rhs` bits. |
| def StableHLO_ShiftRightArithmeticOp : StableHLO_BinaryElementwiseOp< |
|     "shift_right_arithmetic", |
|     [HLO_CompatibleOperandsAndResultType /*shift_right_arithmetic_c1*/], |
|     HLO_IntTensor /*shift_right_arithmetic_i1, shift_right_arithmetic_i2*/> { /*shift_right_arithmetic_c1*/ |
|   let summary = "ShiftRightArithmetic operation"; |
|   let description = [{ |
|     Performs element-wise arithmetic right-shift operation on the `lhs` tensor |
|     by `rhs` number of bits and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_arithmetic |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.shift_right_arithmetic %lhs, %rhs : tensor<3xi64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise logical right shift of integer `lhs` by `rhs` bits. |
| def StableHLO_ShiftRightLogicalOp : StableHLO_BinaryElementwiseOp< |
|     "shift_right_logical", |
|     [HLO_CompatibleOperandsAndResultType /*shift_right_logical_c1*/], |
|     HLO_IntTensor /*shift_right_logical_i1, shift_right_logical_i2*/> { /*shift_right_logical_c1*/ |
|   let summary = "ShiftRightLogical operation"; |
|   let description = [{ |
|     Performs element-wise logical right-shift operation on the `lhs` tensor by |
|     `rhs` number of bits and produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_logical |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.shift_right_logical %lhs, %rhs : tensor<3xi64> |
|     ``` |
|   }]; |
| } |
| |
| // Element-wise difference `lhs` - `rhs`. |
| def StableHLO_SubtractOp : StableHLO_BinaryElementwiseOp< |
|     "subtract", |
|     [HLO_CompatibleOperandsAndResultType], |
|     HLO_IntFpOrComplexOrQuantizedIntTensor> { |
|   let summary = "Subtract operation"; |
|   let description = [{ |
|     Performs element-wise subtraction of two tensors `lhs` and `rhs` and |
|     produces a `result` tensor. |
| |
|     See: |
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#subtract |
| |
|     Example: |
|     ```mlir |
|     %result = stablehlo.subtract %lhs, %rhs : tensor<2xi32> |
|     ``` |
|   }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO binary logical elementwise op definitions. |
| //===----------------------------------------------------------------------===// |
| |
// Common base for the commutative element-wise bitwise/logical binary ops
// (used below by `and`, `or` and `xor`). Both operands accept predicate or
// integer tensors.
// NOTE: "Biwise" in the class name is a typo of "Bitwise"; it is kept as-is
// because derived defs refer to the class by this exact name.
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
class StableHLO_BinaryBiwiseOrLogicalElementwiseOp<string mnemonic>
    : StableHLO_BinaryElementwiseOp<mnemonic,
          [HLO_Commutative, HLO_CompatibleOperandsAndResultType,
           HLO_SpeculatableIfAllInputsStatic, NoMemoryEffect]> {
  // Operand list for all derived ops (presumably replacing the base class
  // default — confirm against StableHLO_BinaryElementwiseOp).
  let arguments = (ins HLO_PredOrIntTensor:$lhs, HLO_PredOrIntTensor:$rhs);
}
| |
// Element-wise AND; operand types come from
// StableHLO_BinaryBiwiseOrLogicalElementwiseOp (predicate or integer tensors).
def StableHLO_AndOp : StableHLO_BinaryBiwiseOrLogicalElementwiseOp<"and"> {
  let summary = "And operation";
  let description = [{
    Performs element-wise AND of two tensors `lhs` and `rhs` and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#and

    Example:
    ```mlir
    %result = stablehlo.and %lhs, %rhs : tensor<2x2xi32>
    ```
  }];
}
| |
// Element-wise OR; operand types come from
// StableHLO_BinaryBiwiseOrLogicalElementwiseOp (predicate or integer tensors).
def StableHLO_OrOp
    : StableHLO_BinaryBiwiseOrLogicalElementwiseOp<"or"> {
  let description = [{
    Performs element-wise OR of two tensors `lhs` and `rhs` and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#or

    Example:
    ```mlir
    %result = stablehlo.or %lhs, %rhs : tensor<2xi1>
    ```
  }];
  let summary = "Or operation";
}
| |
// Element-wise XOR; operand types come from
// StableHLO_BinaryBiwiseOrLogicalElementwiseOp (predicate or integer tensors).
def StableHLO_XorOp
    : StableHLO_BinaryBiwiseOrLogicalElementwiseOp<"xor"> {
  let description = [{
    Performs element-wise XOR of two tensors `lhs` and `rhs` and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#xor

    Example:
    ```mlir
    %result = stablehlo.xor %lhs, %rhs : tensor<2xi32>
    ```
  }];
  let summary = "Xor operation";
}
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO communication op definitions. |
| //===----------------------------------------------------------------------===// |
| |
// Consumes a token and yields a variadic list of results; additional
// checks are done by the C++ verifier (hasVerifier).
def StableHLO_InfeedOp : StableHLO_Op<"infeed"> {
  let summary = "Infeed operation";
  let description = [{
    Reads data from the infeed and produces `results`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#infeed

    Example:
    ```mlir
    %results0:2 = "stablehlo.infeed"(%token) :
        (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
    ```
  }];

  let hasVerifier = 1;

  let results = (outs
    Variadic<HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken>);

  let arguments = (ins
    /*infeed_i1*/ HLO_Token:$token,
    /*infeed_i2*/ DefaultValuedStrAttr<StrAttr, "">:$infeed_config,
    // Not annotated with a spec input id.
    OptionalAttr<ArrayAttr>:$layout
  );
}
| |
// Writes its tensor inputs and yields a single token; the result type is
// inferred through InferTypeOpInterface.
def StableHLO_OutfeedOp
    : StableHLO_Op<"outfeed",
                   [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Outfeed operation";
  let description = [{
    Writes `inputs` to the outfeed and produces a `result` token.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#outfeed

    Example:
    ```mlir
    %result = "stablehlo.outfeed"(%input0, %token) :
        (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
    ```
  }];

  let results = (outs HLO_Token);

  let arguments = (ins
    /*outfeed_i1*/ Variadic<HLO_TensorOrPerAxisQuantizedTensor>:$inputs,
    /*outfeed_i2*/ HLO_Token:$token,
    /*outfeed_i3*/ DefaultValuedStrAttr<StrAttr, "">:$outfeed_config
  );
}
| |
// Sends tensor inputs over a channel and yields a single token; the result
// type is inferred through InferTypeOpInterface. `is_host_transfer` defaults
// to false.
def StableHLO_SendOp
    : StableHLO_Op<"send",
                   [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Send operation";
  let description = [{
    Sends `inputs` to a channel `channel_id` and produces a `result` token.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#send

    Example:
    ```mlir
    %result = "stablehlo.send"(%operand, %token) {
      channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
      is_host_transfer = true
    } : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
    ```
  }];

  let results = (outs HLO_Token);

  let arguments = (ins
    /*send_i1*/ Variadic<HLO_TensorOrPerAxisQuantizedTensor>:$inputs,
    /*send_i2*/ HLO_Token:$token,
    /*send_i3_i4*/ StableHLO_ChannelHandle:$channel_handle,
    /*send_i5*/ DefaultValuedOptionalAttr<BoolAttr, "false">:$is_host_transfer
  );
}
| |
// Receives values over a channel, consuming a token and yielding a variadic
// list of results; additional checks are done by the C++ verifier.
def StableHLO_RecvOp : StableHLO_Op<"recv"> {
  let summary = "Recv operation";
  let description = [{
    Receives data from a channel with `channel_id` and produces `results`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#recv

    Example:
    ```mlir
    %results:2 = "stablehlo.recv"(%token) {
      channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
      is_host_transfer = true
    } : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
    ```
  }];

  let hasVerifier = 1;

  let results = (outs
    Variadic<HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken>);

  let arguments = (ins
    /*recv_i1*/ HLO_Token:$token,
    /*recv_i2_i3*/ StableHLO_ChannelHandle:$channel_handle,
    /*recv_i4*/ DefaultValuedOptionalAttr<BoolAttr, "false">:$is_host_transfer
  );
}
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO parallelism related op definitions. |
| //===----------------------------------------------------------------------===// |
| |
// Operand-less op yielding a ui32 ranked tensor; the result type is inferred
// through InferTypeOpInterface.
def StableHLO_ReplicaIdOp
    : StableHLO_Op<"replica_id",
                   [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "ReplicaId operation";
  let description = [{
    Produces `replica_id` of the current process.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#replica_id

    Example:
    ```mlir
    %result = stablehlo.replica_id : tensor<ui32>
    ```
  }];

  let assemblyFormat = "attr-dict `:` type(results)";

  let results = (outs UI32RankedTensor);
}
| |
// Operand-less op yielding a ui32 ranked tensor; the result type is inferred
// through InferTypeOpInterface.
def StableHLO_PartitionIdOp
    : StableHLO_Op<"partition_id",
                   [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "PartitionId operation";
  let description = [{
    Produces `partition_id` of the current process.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#partition_id

    Example:
    ```mlir
    %result = stablehlo.partition_id : tensor<ui32>
    ```
  }];

  let assemblyFormat = "attr-dict `:` type(results)";

  let results = (outs UI32RankedTensor);
}
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO control flow op definitions. |
| //===----------------------------------------------------------------------===// |
| |
// Joins any number of tokens into a single result token; the result type is
// inferred through InferTypeOpInterface.
def StableHLO_AfterAllOp
    : StableHLO_Op<"after_all",
                   [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "AfterAll operation";
  let description = [{
    Ensures that the operations producing the `inputs` are executed before any
    operations that depend on `result`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all

    Example:
    ```mlir
    %result = stablehlo.after_all %input0, %input1 : !stablehlo.token
    ```
  }];

  // The custom directive prints a single shared type for inputs and result.
  let assemblyFormat = [{
    $inputs attr-dict
      `:` custom<VariadicSameOperandsAndResultType>(ref($inputs), type($inputs), type($result))
  }];

  let results = (outs HLO_Token:$result);
  let arguments = (ins /*after_all_i1*/ Variadic<HLO_Token>:$inputs);
}
| |
// The XLA client API has two separate calls for indexed and predicated
// conditionals, although both eventually map to the kConditional HLO. IfOp maps
// to the predicated-conditional use of the kConditional HLO.
// Two-branch conditional. Each branch region holds a single block whose
// terminator is `ReturnOp` (SingleBlockImplicitTerminator); result types are
// inferred through InferTypeOpInterface.
def StableHLO_IfOp : StableHLO_Op<"if", [
    RecursiveMemoryEffects,
    RecursivelySpeculatable,
    SingleBlockImplicitTerminator<"ReturnOp">,
    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "If operation";
  let description = [{
    Produces the output from executing exactly one branch from `true_branch` or
    `false_branch` depending on the value of `pred`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#if

    Example:
    ```mlir
    %result = "stablehlo.if"(%pred) ({
      "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
    }, {
      "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
    }) : (tensor<i1>) -> tensor<i32>
    ```
  }];

  let arguments = (ins
    HLO_PredTensor:$pred /*if_i1*/
  );

  let regions = (region SizedRegion<1>:$true_branch /*if_i2*/,
                        SizedRegion<1>:$false_branch /*if_i3*/);

  let results = (outs Variadic<HLO_TensorOrPerAxisQuantizedTensorOrToken>);
}
| |
// The XLA client API has two separate calls for indexed and predicated
// conditionals, although both eventually map to the kConditional HLO. CaseOp
// maps to the indexed-conditional use of the kConditional HLO.
// N-way conditional selected by an i32 `index`. Each branch is a
// single-block region terminated by `ReturnOp`; result types are inferred
// (case_c4).
def StableHLO_CaseOp : StableHLO_Op<"case", [
    RecursiveMemoryEffects,
    RecursivelySpeculatable,
    SingleBlockImplicitTerminator<"ReturnOp">,
    DeclareOpInterfaceMethods<InferTypeOpInterface> /*case_c4*/
  ]> {
  let summary = "Case operation";
  let description = [{
    Produces the output from executing exactly one `function` from `branches`
    depending on the value of `index`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case

    Example:
    ```mlir
    %result0, %result1 = "stablehlo.case"(%index) ({
      stablehlo.return %result_branch0, %result_branch0 : tensor<2xi64>, tensor<2xi64>
    }, {
      stablehlo.return %result_branch1, %result_branch1 : tensor<2xi64>, tensor<2xi64>
    }) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
    ```
  }];

  let results = (outs Variadic<HLO_TensorOrPerAxisQuantizedTensorOrToken>);

  let regions = (region
    /*case_i2*/ VariadicRegion<SizedRegion<1>>:$branches
  );

  let arguments = (ins /*case_i1*/ I32RankedTensor:$index);
}
| |
| |
// Loop op with a `cond` region and a `body` region, each a single block
// terminated by `ReturnOp`. Result types are inferred (while_c3); the verifier
// and the custom parser/printer are declared here and implemented in C++.
def StableHLO_WhileOp : StableHLO_Op<"while", [
    RecursiveMemoryEffects,
    RecursivelySpeculatable,
    SingleBlockImplicitTerminator<"ReturnOp">,
    DeclareOpInterfaceMethods<InferTypeOpInterface> /*while_c3*/,
    OpAsmOpInterface
  ]> {
  let summary = "While operation";
  let description = [{
    Produces the output from executing `body` function 0 or more times while the
    `cond` function outputs `true`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#while

    Example:
    ```mlir
    %results0, %results1 = stablehlo.while(%arg0 = %init_i, %arg1 = %init_sum) : tensor<i64>, tensor<i64>
    cond {
      %cond = stablehlo.compare LT, %arg0, %ten : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %cond : tensor<i1>
    } do {
      %new_sum = stablehlo.add %arg1, %one : tensor<i64>
      %new_i = stablehlo.add %arg0, %one : tensor<i64>
      stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
    }
    ```
  }];
  let arguments = (ins Variadic<HLO_TensorOrPerAxisQuantizedTensorOrToken>:$operand /*while_i1*/);

  let regions = (region
    SizedRegion<1>:$cond /*while_i2*/,
    SizedRegion<1>:$body /*while_i3*/
  );

  let results = (outs Variadic<HLO_TensorOrPerAxisQuantizedTensorOrToken>);

  let hasVerifier = 1;

  let extraClassDeclaration = commonClassDeclaration # [{
    // Method of OpAsmOpInterface used during custom printing to name the block
    // arguments in the nested regions. We name the entry arguments of both the
    // condition and the body regions the same way, with an `iterArg` prefix.
    // Since the two regions are side-by-side they will have the same names,
    // which allows us to print them once, share them between the two regions,
    // and still be able to parse them back.
    void getAsmBlockArgumentNames(Region &region, OpAsmSetValueNameFn setNameFn) {
      for (BlockArgument arg : region.getArguments())
        setNameFn(arg, "iterArg");
    }
  }];
  let hasCustomAssemblyFormat = 1;
}
| |
// Collective concatenation along `all_gather_dim`. Speculatability is
// computed by the C++ getSpeculatability() declared below.
def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",
    [ConditionallySpeculatable, SameOperandsAndResultElementType] /*all_gather_c6*/> {
  // Use `let` to override the inherited fields, consistent with every other
  // op definition, rather than re-declaring them as new `string` fields.
  let summary = "AllGather operation";
  let description = [{
    Within each process group in the process grid, concatenates the values of the
    `operand` tensor from each process along `all_gather_dim` and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_gather

    Example:
    ```mlir
    %result = "stablehlo.all_gather"(%operand) {
      all_gather_dim = 1 : i64,
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<2x2xi64>) -> tensor<2x4xi64>
    ```
  }];

  let arguments = (ins
    HLO_Tensor:$operand, /*all_gather_i1*/
    I64Attr:$all_gather_dim, /*all_gather_i2*/
    I64ElementsAttr:$replica_groups, /*all_gather_i3*/
    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle, /*all_gather_i4*/
    UnitAttr:$use_global_device_ids /*all_gather_i5*/
  );
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;

  let extraClassDeclaration = commonClassDeclaration # [{
    /// Interface method for ConditionallySpeculatable.
    mlir::Speculation::Speculatability getSpeculatability();
  }];
}
| |
// Collective reduction using the single-block `computation` region. Result
// type is inferred (all_reduce_c6, all_reduce_c7).
def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
    [HLO_SpeculatableIfStaticDimInOutputIsStaticInInput,
     InferTensorType /*all_reduce_c6, all_reduce_c7*/]> {
  let summary = "AllReduce operation";
  let description = [{
    Within each process group in the process grid, applies a reduction function
    `computation` to the values of the `operand` tensor from each process and
    produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_reduce

    Example:
    ```mlir
    %result = "stablehlo.all_reduce"(%operand) ({
      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
        stablehlo.return %0 : tensor<i64>
    }) {
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
      // use_global_device_ids = false
    } : (tensor<4xi64>) -> tensor<4xi64>
    ```
  }];

  let arguments = (ins
    HLO_Tensor:$operand, /*all_reduce_i1*/
    I64ElementsAttr:$replica_groups, /*all_reduce_i2*/
    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle, /*all_reduce_i3*/
    UnitAttr:$use_global_device_ids /*all_reduce_i4*/
  );
  let regions = (region SizedRegion<1>:$computation /*all_reduce_i5*/);
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;
}
| |
// Collective reduce-then-scatter along `scatter_dimension`, reducing with the
// single-block `computation` region. Speculatability is computed by the C++
// getSpeculatability() declared below.
def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter", [ConditionallySpeculatable]> {
  let summary = "ReduceScatter operation";
  let description = [{
    Within each process group in the process grid, performs reduction, using
    `computation`, over the values of the `operand` tensor from each process,
    splits the reduction result along `scatter_dimension` into parts, and
    scatters the split parts between the processes to produce the `result`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_scatter

    Example:
    ```mlir
    %result = "stablehlo.reduce_scatter"(%operand) ({
      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
        stablehlo.return %0 : tensor<i64>
    }) {
      scatter_dimension = 1 : i64,
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<2x4xi64>) -> tensor<2x2xi64>
    ```
  }];

  let arguments = (ins
    HLO_Tensor:$operand, /*reduce_scatter_i1*/
    I64Attr:$scatter_dimension, /*reduce_scatter_i2*/
    I64ElementsAttr:$replica_groups, /*reduce_scatter_i3*/
    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle, /*reduce_scatter_i4*/
    UnitAttr:$use_global_device_ids /*reduce_scatter_i5*/
  );
  let regions = (region SizedRegion<1>:$computation /*reduce_scatter_i6*/);
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;

  let extraClassDeclaration = commonClassDeclaration # [{
    /// Interface method for ConditionallySpeculatable.
    mlir::Speculation::Speculatability getSpeculatability();
  }];
}
| |
// Collective split/scatter/concat exchange. Element type is preserved and the
// result type is inferred (all_to_all_c9); speculatability is computed in C++.
def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all", [
    ConditionallySpeculatable,
    SameOperandsAndResultElementType /*all_to_all_c9*/,
    InferTensorType /*all_to_all_c9*/
  ]> {
  let summary = "AllToAll operation";
  let description = [{
    Within each process group in the process grid, splits the values of the
    `operand` tensor along `split_dimension` into parts, scatters the split parts
    between the processes, concatenates the scattered parts along `concat_dimension`
    and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_to_all

    Example:
    ```mlir
    %result = "stablehlo.all_to_all"(%operand) {
      split_dimension = 1 : i64,
      concat_dimension = 0 : i64,
      split_count = 2 : i64,
      replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
    } : (tensor<2x4xi64>) -> tensor<4x2xi64>
    ```
  }];

  let results = (outs HLO_Tensor);

  let arguments = (ins
    /*all_to_all_i1*/ HLO_Tensor:$operand,
    /*all_to_all_i2*/ I64Attr:$split_dimension,
    /*all_to_all_i3*/ I64Attr:$concat_dimension,
    /*all_to_all_i4*/ I64Attr:$split_count,
    /*all_to_all_i5*/ I64ElementsAttr:$replica_groups,
    /*all_to_all_i6*/ OptionalAttr<StableHLO_ChannelHandle>:$channel_handle
  );

  let extraClassDeclaration = commonClassDeclaration # [{
    /// Interface method for ConditionallySpeculatable.
    mlir::Speculation::Speculatability getSpeculatability();
  }];

  // Convenience builder without `channel_handle`, which is only used by the
  // SPMD partitioner.
  let builders = [
    OpBuilder<(ins
      "::mlir::Type":$result_type,
      "::mlir::Value":$operand,
      "::mlir::IntegerAttr":$split_dimension,
      "::mlir::IntegerAttr":$concat_dimension,
      "::mlir::IntegerAttr":$split_count,
      "::mlir::DenseIntElementsAttr":$replica_groups)>];
}
| |
// Reduction over `dimensions` with a single-block `body` terminated by
// `ReturnOp`. `inputs` and `init_values` must have the same count
// (SameVariadicOperandSize, reduce_c3); result types are inferred with reify
// support. Custom parser/printer and verifier are implemented in C++.
def StableHLO_ReduceOp : StableHLO_ShapedInterfaceOp<"reduce", [
    HLO_RecursivelySpeculatableIfAllInputsStatic,
    RecursiveMemoryEffects,
    SameVariadicOperandSize /*reduce_c3*/,
    InferTensorTypeWithReify /*reduce_c7, reduce_c8*/,
    SingleBlockImplicitTerminator<"ReturnOp">
  ]> { /*reduce_c7*/
  let summary = "Reduce operation";
  let description = [{
    Applies a reduction function `body` to `inputs` and `init_values` along the
    `dimensions` and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce

    Example:
    ```mlir
    %result = "stablehlo.reduce"(%input, %init_value) ({
      ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
        %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
        stablehlo.return %0 : tensor<i64>
    }) {
      dimensions = array<i64: 1>
    } : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
    ```
  }];

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;

  let results = (outs Variadic<HLO_Tensor>);

  let arguments = (ins
    /*reduce_i1*/ Variadic<HLO_Tensor>:$inputs,
    /*reduce_i2*/ Variadic<HLO_Tensor>:$init_values,
    /*reduce_i3*/ GenericDenseI64ArrayAttr:$dimensions
  );

  let regions = (region /*reduce_i4*/ SizedRegion<1>:$body);

  // Custom builder: the result types are derived from the explicitly passed
  // `element_types` (intended to match the element types of the `body`
  // arguments) rather than being supplied directly.
  let builders = [
    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$init_values,
                   "DenseI64ArrayAttr":$dimensions,
                   "TypeRange":$element_types)>,
  ];
}
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO tuple op definitions. |
| //===----------------------------------------------------------------------===// |
// Projects one element out of a tuple; the result type is inferred
// (get_tuple_element_c2).
def StableHLO_GetTupleElementOp
    : StableHLO_Op<"get_tuple_element",
                   [Pure,
                    DeclareOpInterfaceMethods<InferTypeOpInterface> /*get_tuple_element_c2*/]> {
  let summary = "GetTupleElement operation";
  let description = [{
    Extracts element at `index` position of the `operand` tuple and produces a
    `result`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#get_tuple_element

    Example:
    ```mlir
    %result = stablehlo.get_tuple_element %operand[0] : (tuple<tensor<2xf64>, tuple<tensor<i64>>>) -> tensor<2xf64>
    ```
  }];

  let assemblyFormat = [{
    $operand `[` $index `]` attr-dict `:` functional-type(operands, results)
  }];

  let results = (outs HLO_TensorOrPerAxisQuantizedTensorOrTokenOrTuple);

  let arguments = (ins
    /*get_tuple_element_i1*/ HLO_Tuple:$operand,
    /*get_tuple_element_i2*/ I32Attr:$index
  );
}
| |
// Packs its operands into a tuple; the result type is inferred (tuple_c1).
def StableHLO_TupleOp
    : StableHLO_Op<"tuple",
                   [Pure,
                    DeclareOpInterfaceMethods<InferTypeOpInterface> /*tuple_c1*/]> {
  let summary = "Tuple operation";
  let description = [{
    Produces a `result` tuple from values `val`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tuple

    Example:
    ```mlir
    %result = stablehlo.tuple %val0, %val1 : tuple<tensor<2xf64>, tuple<tensor<i64>>>
    ```
  }];

  let assemblyFormat = [{
    $val attr-dict `:` custom<TupleOpType>(type($val), type($result))
  }];

  let results = (outs HLO_Tuple:$result);
  let arguments = (ins
    /*tuple_i1*/ Variadic<HLO_TensorOrPerAxisQuantizedTensorOrTokenOrTuple>:$val);
}
| |
// Element-wise comparison producing a predicate tensor with the operands'
// shape (compare_c2); operand element types must be compatible (compare_c1).
def StableHLO_CompareOp : StableHLO_Op<"compare", [
    HLO_SpeculatableIfAllInputsStatic,
    NoMemoryEffect,
    Elementwise,
    HLO_CompatibleOperandsElementType /*compare_c1*/,
    SameOperandsAndResultShape /*compare_c2*/,
    InferTensorTypeWithReify /*compare_c1, compare_c2*/
  ]> {
  let summary = "Compare operation";
  let description = [{
    Performs element-wise comparison of `lhs` and `rhs` tensors according to
    `comparison_direction` and `compare_type`, and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#compare

    Example:
    ```mlir
    %result = stablehlo.compare LT, %lhs, %rhs, FLOAT : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
    ```
  }];

  let results = (outs HLO_PredTensor);

  let arguments = (ins
    /*compare_i1*/ HLO_Tensor:$lhs,
    /*compare_i2*/ HLO_Tensor:$rhs,
    /*compare_i3*/ StableHLO_ComparisonDirectionAttr:$comparison_direction,
    /*compare_i4*/ OptionalAttr<StableHLO_ComparisonTypeAttr>:$compare_type
  );

  // `compare_type` is optional in the textual form (printed after `rhs`).
  let assemblyFormat = [{
    $comparison_direction `,` $lhs `,` $rhs (`,` $compare_type^)?
    attr-dict `:` functional-type(operands, results)
  }];

  // Convenience builder; `compare_type` defaults to NOTYPE.
  let builders = [
    OpBuilder<(ins
      "Value":$lhs,
      "Value":$rhs,
      "::mlir::stablehlo::ComparisonDirection":$comparison_direction,
      CArg<"::mlir::stablehlo::ComparisonType",
           "::mlir::stablehlo::ComparisonType::NOTYPE">:$compare_type)>,
  ];
}
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO Slice definitions. |
| //===----------------------------------------------------------------------===// |
| |
// Static slice. `start_indices`, `limit_indices` and `strides` are required to
// have matching size() (slice_c2); the element type is preserved (slice_c1)
// and the result type is inferred.
def StableHLO_SliceOp : StableHLO_Op<"slice", [
    HLO_SpeculatableIfStaticDimInOutputIsStaticInInput,
    NoMemoryEffect,
    SameOperandsAndResultElementType /*slice_c1*/,
    AllMatchSameOperatorTrait<["start_indices", "limit_indices", "strides"],
                              "$_self.size()", "size"> /*slice_c2*/,
    DeclareOpInterfaceMethods<InferTypeOpInterface>
  ]> {
  let summary = "Slice operation";
  let description = [{
    Extracts a slice from the `operand` using statically-computed starting
    indices and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice

    Example:
    ```mlir
    %result = stablehlo.slice %operand [1:3, 4:8:2]
       : (tensor<3x8xi64>) -> tensor<2x2xi64>

    // Same in generic form: the `1:3` above is mapped to the first entry in
    // `start_indices` and `limit_indices`, while `strides` is implicitly 1.
    // The `4:8:2` above is parsed into the second entry of `start_indices`,
    // `limit_indices` and `strides` respectively.
    %result = "stablehlo.slice" (%operand) {
      start_indices = array<i64: 1, 4>,
      limit_indices = array<i64: 3, 8>,
      strides = array<i64: 1, 2>
    } : (tensor<3x8xi64>) -> tensor<2x2xi64>
    ```
  }];

  let results = (outs HLO_Tensor);

  // The three index arrays are printed together as `[start:limit:stride, ...]`.
  let assemblyFormat = [{
    $operand custom<SliceRanges>($start_indices, $limit_indices, $strides)
    attr-dict `:` functional-type(operands, results)
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    GenericDenseI64ArrayAttr:$start_indices,
    GenericDenseI64ArrayAttr:$limit_indices,
    GenericDenseI64ArrayAttr:$strides
  );
}
| |
// Slice whose start offsets are runtime scalar integer tensors and whose
// sizes are a static attribute; operand and result element types must match
// (dynamic_slice_c1).
def StableHLO_DynamicSliceOp
    : StableHLO_Op<"dynamic_slice",
                   [Pure,
                    AllElementTypesMatch<["operand", "result"]> /*dynamic_slice_c1*/,
                    InferTensorType]> {
  let summary = "DynamicSlice operation";
  let description = [{
    Extracts a slice from the `operand` using dynamically-computed starting
    indices and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_slice

    Example:
    ```mlir
    %result = stablehlo.dynamic_slice %operand, %start_indices0, %start_indices1, sizes = [2, 2]
      : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
    ```
  }];

  let assemblyFormat = [{
    $operand `,` custom<VariadicOperandWithAttribute>($start_indices)
      `sizes` `=` custom<DenseI64Array>($slice_sizes)
      attr-dict `:` functional-type(operands, results)
  }];

  let results = (outs HLO_Tensor:$result);

  let arguments = (ins
    /*dynamic_slice_i1*/ HLO_Tensor:$operand,
    /*dynamic_slice_i2*/ Variadic<HLO_ScalarIntTensor>:$start_indices,
    /*dynamic_slice_i3*/ GenericDenseI64ArrayAttr:$slice_sizes
  );
}
| |
// Writes `update` into `operand` at runtime scalar offsets; operand, update
// and result element types must match.
def StableHLO_DynamicUpdateSliceOp
    : StableHLO_Op<"dynamic_update_slice",
                   [Pure,
                    AllElementTypesMatch<["operand", "update", "result"]> /*dynamic_update_slice_c1, dynamic_update_slice_c2*/,
                    InferTensorType]> {
  let summary = "DynamicUpdateSlice operation";
  let description = [{
    Produces a `result` tensor which is equal to the `operand` tensor except
    that the slice starting at `start_indices` is updated with the values in
    `update`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_update_slice

    Example:
    ```mlir
    %result = stablehlo.dynamic_update_slice %operand, %update, %start_indices0, %start_indices1
      : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
    ```
  }];

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

  let results = (outs HLO_Tensor:$result);

  let arguments = (ins
    /*dynamic_update_slice_i1*/ HLO_Tensor:$operand,
    /*dynamic_update_slice_i2*/ HLO_Tensor:$update,
    /*dynamic_update_slice_i3*/ Variadic<HLO_ScalarIntTensor>:$start_indices
  );
}
| |
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO Other op definitions. |
| //===----------------------------------------------------------------------===// |
| |
// Batch-norm backward pass. Ranked float/quantized-int operand and
// grad_output; 1-D scale/mean/variance; three inferred results.
def StableHLO_BatchNormGradOp : StableHLO_Op<"batch_norm_grad", [
    HLO_SpeculatableIfAllInputsStatic,
    NoMemoryEffect,
    HLO_CompatibleOperandsAndResultElementType /*batch_norm_grad_c2*/,
    InferTensorType /*batch_norm_grad_c3, batch_norm_grad_c4*/
  ]> {
  let summary = "BatchNormGrad operation";
  let description = [{
    Computes gradients of several inputs of BatchNormTrainingOp backpropagating
    from `grad_output`, and produces `grad_operand`, `grad_scale` and
    `grad_offset` tensors.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#batch_norm_grad

    Example:
    ```mlir
    %grad_operand, %grad_scale, %grad_offset =
    "stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
      epsilon = 0.0 : f32,
      feature_index = 2 : i64
    } : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
         tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
    ```
  }];

  let results = (outs
    RankedTensorOf<[HLO_Float, HLO_QuantizedInt]>:$grad_operand,
    1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$grad_scale,
    1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$grad_offset
  );

  let arguments = (ins
    /*batch_norm_grad_i1*/ RankedTensorOf<[HLO_Float, HLO_QuantizedInt]>:$operand,
    /*batch_norm_grad_i2*/ 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$scale,
    /*batch_norm_grad_i3*/ 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$mean,
    /*batch_norm_grad_i4*/ 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$variance,
    /*batch_norm_grad_i5*/ RankedTensorOf<[HLO_Float, HLO_QuantizedInt]>:$grad_output,
    /*batch_norm_grad_i6*/ F32Attr:$epsilon,
    /*batch_norm_grad_i7*/ I64Attr:$feature_index
  );
}
| |
// Batch-norm inference. Ranked float/quantized-int operand, 1-D per-feature
// tensors, and a single inferred result (batch_norm_inference_c7).
def StableHLO_BatchNormInferenceOp : StableHLO_Op<"batch_norm_inference", [
    HLO_SpeculatableIfAllInputsStatic,
    NoMemoryEffect,
    HLO_CompatibleOperandsAndResultElementType /*batch_norm_inference_c2*/,
    InferTensorType /*batch_norm_inference_c7*/
  ]> {
  let summary = "BatchNormInference operation";
  let description = [{
    Normalizes the `operand` tensor across all dimensions except for the
    `feature_index` dimension and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#batch_norm_inference

    Example:
    ```mlir
    %result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
      epsilon = 0.0 : f32,
      feature_index = 2 : i64
    } : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
    ```
  }];

  let results = (outs RankedTensorOf<[HLO_Float, HLO_QuantizedInt]>:$result);

  let arguments = (ins
    /*batch_norm_inference_i1*/ RankedTensorOf<[HLO_Float, HLO_QuantizedInt]>:$operand,
    /*batch_norm_inference_i2*/ 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$scale,
    /*batch_norm_inference_i3*/ 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$offset,
    /*batch_norm_inference_i4*/ 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$mean,
    /*batch_norm_inference_i5*/ 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$variance,
    /*batch_norm_inference_i6*/ F32Attr:$epsilon,
    /*batch_norm_inference_i7*/ I64Attr:$feature_index
  );
}
| |
// Batch-norm training step. Ranked float/quantized-int operand, 1-D
// scale/offset, and three inferred results (output, batch_mean, batch_var).
def StableHLO_BatchNormTrainingOp : StableHLO_Op<"batch_norm_training", [
    HLO_SpeculatableIfAllInputsStatic,
    NoMemoryEffect,
    HLO_CompatibleOperandsAndResultElementType /*batch_norm_training_c2*/,
    InferTensorType /*batch_norm_training_c5, batch_norm_training_c6, batch_norm_training_c7*/
  ]> {
  let summary = "BatchNormTraining operation";
  let description = [{
    Computes mean and variance across batch and spatial dimensions and
    normalizes the `operand` tensor, for each feature in the `feature_index`
    dimension and produces `output`, `batch_mean` and `batch_var` tensors.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#batch_norm_training

    Example:
    ```mlir
    %output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
      epsilon = 0.0 : f32,
      feature_index = 2 : i64
    } : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
        (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
    ```
  }];

  let results = (outs
    RankedTensorOf<[HLO_Float, HLO_QuantizedInt]>:$output,
    1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$batch_mean,
    1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$batch_var
  );

  let arguments = (ins
    /*batch_norm_training_i1*/ RankedTensorOf<[HLO_Float, HLO_QuantizedInt]>:$operand,
    /*batch_norm_training_i2*/ 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$scale,
    /*batch_norm_training_i3*/ 1DTensorOf<[HLO_Float, HLO_QuantizedInt]>:$offset,
    /*batch_norm_training_i4*/ F32Attr:$epsilon,
    /*batch_norm_training_i5*/ I64Attr:$feature_index
  );
}
| |
| // BitcastConvert: reinterprets the bits of `operand` using the result type. |
| def StableHLO_BitcastConvertOp : StableHLO_ShapedInterfaceOp<"bitcast_convert", |
| [ConditionallySpeculatable, NoMemoryEffect]> { |
| let summary = "BitcastConvert operation"; |
| |
| let arguments = (ins HLO_TensorOrPerAxisQuantizedTensor:$operand /*bitcast_convert_i1*/); |
| let results = (outs HLO_TensorOrPerAxisQuantizedTensor); |
| |
| // A hand-written C++ verify() is declared for this op. |
| let hasVerifier = 1; |
| |
| let assemblyFormat = [{ |
| operands attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| // Speculatability is computed in C++ (ConditionallySpeculatable trait). |
| let extraClassDeclaration = commonClassDeclaration # [{ |
| /// Interface method for ConditionallySpeculatable. |
| mlir::Speculation::Speculatability getSpeculatability(); |
| }]; |
| |
| let description = [{ |
| Performs a bitcast operation on `operand` tensor and produces a `result` |
| tensor where the bits of the entire `operand` tensor are reinterpreted using |
| the type of the `result` tensor. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#bitcast_convert |
| |
| Example: |
| ```mlir |
| %result = stablehlo.bitcast_convert %operand : (tensor<f64>) -> tensor<4xf16> |
| ``` |
| }]; |
| } |
| |
| // Legacy Broadcast (not part of the spec). Per the example below, the |
| // `broadcast_sizes` become leading dimensions of the result. |
| def StableHLO_BroadcastOp : StableHLO_ShapedInterfaceOp<"broadcast", |
| [Pure, SameOperandsAndResultElementType, InferTensorType]> { |
| let summary = "Broadcast operation"; |
| let description = [{ |
| This operation is on its way out of StableHLO, so it is not included in |
| the StableHLO specification: https://github.com/openxla/stablehlo/issues/3. |
| |
| Informally, this operation does the same thing as XLA's Broadcast: |
| https://www.tensorflow.org/xla/operation_semantics#broadcast |
| |
| Example: |
| ```mlir |
| %result = stablehlo.broadcast %operand, sizes = [1, 2] : (tensor<3xi32>) -> tensor<1x2x3xi32> |
| ``` |
| }]; |
| |
| // NOTE(review): the operand is HLO_Tensor while the result also admits |
| // per-axis quantized tensors. With SameOperandsAndResultElementType those |
| // extra result types look unreachable; confirm which constraint is intended. |
| let results = (outs HLO_TensorOrPerAxisQuantizedTensor); |
| let arguments = (ins |
| HLO_Tensor:$operand, |
| GenericDenseI64ArrayAttr:$broadcast_sizes |
| ); |
| |
| let assemblyFormat = [{ |
| $operand `,` `sizes` `=` custom<DenseI64Array>($broadcast_sizes) |
| attr-dict `:` functional-type(operands, results) |
| }]; |
| } |
| |
| // BroadcastInDim: duplicates `operand` data into a statically shaped result. |
| def StableHLO_BroadcastInDimOp : StableHLO_Op<"broadcast_in_dim", |
| [Pure, HLO_CompatibleOperandsAndResultElementType /*broadcast_in_dim_c1*/]> { |
| let summary = "BroadcastInDim operation"; |
| |
| // The result must have a static shape. |
| let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor); |
| let arguments = (ins |
| HLO_TensorOrPerAxisQuantizedTensor:$operand /*broadcast_in_dim_i1*/, |
| GenericDenseI64ArrayAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/ |
| ); |
| |
| // A hand-written C++ verify() is declared for this op. |
| let hasVerifier = 1; |
| |
| let assemblyFormat = [{ |
| $operand `,` `dims` `=` custom<DenseI64Array>($broadcast_dimensions) |
| attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| let description = [{ |
| Expands the dimensions and/or rank of an input tensor by duplicating the |
| data in the `operand` tensor and produces a `result` tensor. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim |
| |
| Example: |
| ```mlir |
| %result = stablehlo.broadcast_in_dim %operand, dims = [2, 1] : (tensor<1x3xi32>) -> tensor<2x3x2xi32> |
| ``` |
| }]; |
| } |
| |
| // DynamicBroadcastInDim: like BroadcastInDim, but the result shape is taken |
| // from the `output_dimensions` operand. |
| def StableHLO_DynamicBroadcastInDimOp : StableHLO_ShapedInterfaceOp< |
| "dynamic_broadcast_in_dim", [Pure]> { |
| let summary = "DynamicBroadcastInDim operation"; |
| let description = [{ |
| This operation is a work in progress, so it is not yet included in |
| the StableHLO specification: https://github.com/openxla/stablehlo/issues/8. |
| |
| Informally, this operation does the same thing as BroadcastInDimOp except |
| that the result shape is specified dynamically via `output_dimensions`: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#broadcast_in_dim |
| |
| It also accepts optional attributes to express static knowledge about the |
| expanding behavior of dimensions. If not specified, all dimensions are |
| assumed to be possibly expanding. The set of dimensions that are known to |
| be expanding and the set of dimensions that are known to be non-expanding |
| must be disjoint, and both must be subsets of the operand's dimensions. |
| }]; |
| |
| // A hand-written C++ verify() is declared for this op. |
| let hasVerifier = 1; |
| |
| // The two known_* attributes are optional hints about dimension expansion. |
| let arguments = (ins |
| HLO_TensorOrPerAxisQuantizedTensor:$operand, |
| HLO_StaticDimensionTensor:$output_dimensions, |
| GenericDenseI64ArrayAttr:$broadcast_dimensions, |
| OptionalAttr<GenericDenseI64ArrayAttr>:$known_expanding_dimensions, |
| OptionalAttr<GenericDenseI64ArrayAttr>:$known_nonexpanding_dimensions |
| ); |
| let results = (outs HLO_TensorOrPerAxisQuantizedTensor); |
| |
| // Convenience builder that passes empty values for both known_* attributes. |
| let builders = [ |
| OpBuilder<(ins |
| "Type":$result_type, "Value":$operand, "Value":$output_dimensions, |
| "DenseI64ArrayAttr":$broadcast_dimensions), [{ |
| build($_builder, $_state, result_type, operand, output_dimensions, |
| broadcast_dimensions, /*known_expanding_dimensions=*/{}, |
| /*known_nonexpanding_dimensions=*/{}); |
| }]> |
| ]; |
| |
| let assemblyFormat = [{ |
| $operand `,` $output_dimensions `,` `dims` `=` custom<DenseI64Array>($broadcast_dimensions) |
| attr-dict `:` functional-type(operands, results) |
| }]; |
| } |
| |
| // Note: There is no StableHLO_CallOp because the standard call operation |
| // `mlir::func::CallOp` is used instead. An `mlir::func::CallOp` is exported |
| // directly to an HLO call instruction. |
| |
| // Cholesky: decomposes a batch of matrices; `lower` picks the triangle. |
| def StableHLO_CholeskyOp : StableHLO_Op<"cholesky", |
| [NoMemoryEffect, |
| HLO_CompatibleOperandsAndResultElementType /*cholesky_c1*/, |
| InferTensorType /*cholesky_c1*/]> { |
| let summary = "Cholesky operation"; |
| |
| // `lower` defaults to false when omitted. |
| let arguments = (ins |
| HLO_FpComplexOrQuantizedIntTensor:$a, |
| DefaultValuedOptionalAttr<BoolAttr, "false">:$lower |
| ); |
| let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result); |
| |
| // Prints a single type because operand and result types are the same. |
| let assemblyFormat = [{ |
| $a (`,` `lower` `=` $lower^)? attr-dict |
| `:` custom<SameOperandsAndResultType>(type($a), type($result)) |
| }]; |
| |
| let description = [{ |
| Computes the Cholesky decomposition of a batch of matrices. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cholesky |
| |
| Example: |
| ```mlir |
| %result = stablehlo.cholesky %a, lower = true : tensor<3x3xf64> |
| ``` |
| }]; |
| } |
| |
| // Clamp: bounds each element of `operand` by `min` and `max`. |
| def StableHLO_ClampOp : StableHLO_ShapedInterfaceOp<"clamp", |
| [HLO_SpeculatableIfAllInputsStatic, |
| NoMemoryEffect, |
| HLO_CompatibleOperandsAndResultElementType /* clamp_c3 */, |
| HLO_BroadcastingElementwise, |
| InferTensorType]> { |
| let summary = "Clamp operation"; |
| |
| // Operand order is min, operand, max -- matching the printed form. |
| let arguments = (ins |
| HLO_Tensor:$min, /*clamp_i1*/ |
| HLO_Tensor:$operand, /*clamp_c3, clamp_i2*/ |
| HLO_Tensor:$max /*clamp_i3*/ |
| ); |
| let results = (outs HLO_Tensor:$result); |
| |
| let assemblyFormat = [{ |
| $min `,` $operand `,` $max attr-dict |
| `:` custom<SameOperandsAndResultType>(type($min), type($operand), type($max), type($result)) |
| }]; |
| |
| let description = [{ |
| Clamps every element of the `operand` tensor between a minimum and maximum |
| value and produces a `result` tensor. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#clamp |
| |
| Example: |
| ```mlir |
| %result = stablehlo.clamp %min, %operand, %max : tensor<3xi32> |
| ``` |
| }]; |
| } |
| |
| // Concatenate: joins `inputs` in order along `dimension`. |
| def StableHLO_ConcatenateOp : StableHLO_ShapedInterfaceOp<"concatenate", |
| [ConditionallySpeculatable, NoMemoryEffect, |
| SameOperandsAndResultElementType /*concatenate_c1, concatenate_c3, concatenate_c5*/, |
| DeclareOpInterfaceMethods<InferTypeOpInterface>]> { |
| let summary = "Concatenate operation"; |
| |
| // Any number of tensors, plus the concatenation axis. |
| let arguments = (ins |
| Variadic<HLO_Tensor>:$inputs /*concatenate_i1*/, |
| I64Attr:$dimension /*concatenate_i2*/ |
| ); |
| let results = (outs HLO_Tensor); |
| |
| // Speculatability is computed in C++ (ConditionallySpeculatable trait). |
| let extraClassDeclaration = commonClassDeclaration # [{ |
| /// Interface method for ConditionallySpeculatable. |
| mlir::Speculation::Speculatability getSpeculatability(); |
| }]; |
| |
| let assemblyFormat = [{ |
| custom<VariadicOperandWithAttribute>($inputs) `dim` `=` $dimension |
| attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| let description = [{ |
| Concatenates a variadic number of tensors in `inputs` along `dimension` |
| dimension in the same order as the given arguments and produces a `result` |
| tensor. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#concatenate |
| |
| Example: |
| ```mlir |
| %result = stablehlo.concatenate %input0, %input1, dim = 0 : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64> |
| ``` |
| }]; |
| } |
| |
| |
| // CollectiveBroadcast: within each process group, sends `operand` from the |
| // source process to the target processes. |
| def StableHLO_CollectiveBroadcastOp: StableHLO_Op<"collective_broadcast", |
| [HLO_SpeculatableIfStaticDimInOutputIsStaticInInput, |
| SameOperandsAndResultElementType /*collective_broadcast_c3*/, |
| DeclareOpInterfaceMethods<InferTypeOpInterface>, |
| DeclareOpInterfaceMethods<InferShapedTypeOpInterface, ["inferReturnTypeComponents"]> |
| ]> { |
| let summary = "CollectiveBroadcast operation"; |
| |
| let arguments = (ins |
| HLO_Tensor:$operand, /*collective_broadcast_i1*/ |
| I64ElementsAttr:$replica_groups, /*collective_broadcast_i2*/ |
| OptionalAttr<StableHLO_ChannelHandle>:$channel_handle /*collective_broadcast_i3*/ |
| ); |
| let results = (outs HLO_Tensor); |
| |
| // A hand-written C++ verify() is declared for this op. |
| let hasVerifier = 1; |
| |
| // Extra builder without channel_handle: that attribute only matters to the |
| // SPMD partitioner, so most callers can omit it. |
| let builders = [ |
| OpBuilder<(ins |
| "::mlir::Type":$result_type, "::mlir::Value":$operand, |
| "::mlir::DenseIntElementsAttr":$replica_groups)>]; |
| |
| let description = [{ |
| Within each process group in the process grid, send the value of the |
| `operand` tensor from the source process to the target processes and produce a |
| `result` tensor. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_broadcast |
| |
| Example: |
| ```mlir |
| %result = "stablehlo.collective_broadcast"(%operand) { |
| replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, |
| channel_handle = #stablehlo.channel_handle<handle = 0, type = 0> |
| } : (tensor<1x2xi64>) -> tensor<1x2xi64> |
| ``` |
| }]; |
| } |
| |
| // CollectivePermute: moves `operand` between processes according to |
| // `source_target_pairs`. |
| def StableHLO_CollectivePermuteOp: StableHLO_Op<"collective_permute", |
| [HLO_SpeculatableIfStaticDimInOutputIsStaticInInput, |
| SameOperandsAndResultElementType, /*collective_permute_c5*/ |
| DeclareOpInterfaceMethods<InferTypeOpInterface>, |
| DeclareOpInterfaceMethods<InferShapedTypeOpInterface, ["inferReturnTypeComponents"]> |
| ]> { |
| let summary = "CollectivePermute operation"; |
| |
| let arguments = (ins |
| HLO_Tensor:$operand, /*collective_permute_i1*/ |
| I64ElementsAttr:$source_target_pairs, /*collective_permute_i2*/ |
| OptionalAttr<StableHLO_ChannelHandle>:$channel_handle /*collective_permute_i3*/ |
| ); |
| let results = (outs HLO_Tensor); |
| |
| // A hand-written C++ verify() is declared for this op. |
| let hasVerifier = 1; |
| |
| // Extra builder without channel_handle: that attribute only matters to the |
| // SPMD partitioner, so most callers can omit it. |
| let builders = [ |
| OpBuilder<(ins |
| "::mlir::Type":$result_type, "::mlir::Value":$operand, |
| "::mlir::DenseIntElementsAttr":$source_target_pairs)>]; |
| |
| let description = [{ |
| Within each process group in the process grid, sends the value of the |
| `operand` tensor from the source process to the target process and produces |
| a `result` tensor. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective_permute |
| |
| Example: |
| ```mlir |
| %result = "stablehlo.collective_permute"(%operand) { |
| source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>, |
| channel_handle = #stablehlo.channel_handle<handle = 0, type = 0> |
| } : (tensor<2x2xi64>) -> tensor<2x2xi64> |
| ``` |
| }]; |
| } |
| |
| // Composite: a named op whose meaning is given by the function referenced in |
| // `decomposition` (a symbol use, hence SymbolUserOpInterface). |
| def StableHLO_CompositeOp : StableHLO_Op<"composite", |
| [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> { |
| let summary = "Composite operation"; |
| |
| // `composite_attributes` defaults to {} and `version` defaults to 0. |
| let arguments = (ins |
| Variadic<HLO_TensorOrPerAxisQuantizedTensorOrTokenOrTuple>:$inputs, /*composite_i1*/ |
| StrAttr:$name, /*composite_i2*/ |
| DefaultValuedOptionalAttr<DictionaryAttr, "{}">:$composite_attributes, /*composite_i3*/ |
| FlatSymbolRefAttr:$decomposition, /*composite_i4*/ |
| DefaultValuedOptionalAttr<I32Attr, "0">:$version /*composite_i5*/ |
| ); |
| let results = (outs Variadic<HLO_TensorOrPerAxisQuantizedTensorOrTokenOrTuple>); |
| |
| let assemblyFormat = [{ |
| $name $inputs attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| let description = [{ |
| Encapsulates an operation made up (composed) of other StableHLO operations, |
| taking `inputs` and `composite_attributes` and producing `results`. The |
| semantics of the op are implemented by the `decomposition` attribute. The |
| `composite` op can be replaced with its decomposition without changing program |
| semantics. In cases where inlining the decomposition does not provide the same |
| op semantics, prefer using `custom_call`. |
| |
| The `version` field (defaults to `0`) is used to denote when a composite's |
| semantics change. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#composite |
| |
| Example: |
| ```mlir |
| %results = stablehlo.composite "my.op" %input0, %input1 { |
| composite_attributes = { |
| my_attribute = "my_value" |
| }, |
| decomposition = @my_op, |
| version = 1 : i32 |
| } : (tensor<f32>, tensor<f32>) -> tensor<f32> |
| ``` |
| }]; |
| } |
| |
| // Convolution: dot products between windows of `lhs` and slices of `rhs`. |
| def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", |
| [ConditionallySpeculatable, NoMemoryEffect]> { |
| let summary = "Convolution operation"; |
| let description = [{ |
| Computes dot products between windows of `lhs` and slices of `rhs` and |
| produces `result`. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution |
| |
| Example: |
| ```mlir |
| %result = stablehlo.convolution(%lhs, %rhs) |
| dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], |
| window = { |
| stride = [4, 4], |
| pad = [[0, 0], [0, 0]], |
| lhs_dilate = [2, 2], |
| rhs_dilate = [1, 1], |
| reverse = [0, 0] |
| } { |
| feature_group_count = 1 : i64, |
| batch_group_count = 1 : i64, |
| precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>] |
| } : |
| (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64> |
| ``` |
| }]; |
| let arguments = !con( |
| (ins |
| HLO_Tensor:$lhs, /*convolution_i1*/ |
| HLO_TensorOrPerAxisQuantizedTensor:$rhs), /*convolution_i2*/ |
| StableHLO_ConvolutionAttributes.attributes /*convolution_i3, convolution_i4, |
| convolution_i5, convolution_i6, convolution_i7, convolution_i8, |
| convolution_i9, convolution_i10, convolution_i11, convolution_i12, |
| convolution_i13, convolution_i14, convolution_i15, convolution_i16, |
| convolution_i17, convolution_i18, convolution_i19*/ |
| ); |
| |
| let results = (outs HLO_TensorOrPerAxisQuantizedTensor); |
| let hasVerifier = 1; |
| |
| let assemblyFormat = [{ |
| `(`operands`)` |
| `dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,` |
| `window` `=` `{` custom<WindowAttributes>($window_strides, $padding, |
| $lhs_dilation, $rhs_dilation, |
| $window_reversal) `}` |
| attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| // Keep every extra declaration in this single `let`. In a TableGen record |
| // body a later `let` of the same field overrides an earlier one, so splitting |
| // these into two `let extraClassDeclaration` statements silently drops one. |
| let extraClassDeclaration = commonClassDeclaration # [{ |
| /// Returns true if `window_reversal` is present and at least one of its |
| /// entries is true. |
| bool hasWindowReversal() { |
| auto reversal = getWindowReversal(); |
| return reversal.has_value() && llvm::any_of(reversal.value(), [](bool v) { return v; }); |
| } |
| |
| /// Interface method for ConditionallySpeculatable. |
| mlir::Speculation::Speculatability getSpeculatability(); |
| }]; |
| } |
| |
| // Legacy CrossReplicaSum (not in the spec); see the description for the |
| // equivalent AllReduce form. |
| def StableHLO_CrossReplicaSumOp : StableHLO_Op<"cross-replica-sum", |
| [Pure, HLO_CompatibleOperandsAndResultType]> { |
| let summary = "CrossReplicaSum operation"; |
| |
| let results = (outs HLO_Tensor); |
| let arguments = (ins |
| HLO_Tensor:$operand, |
| I64ElementsAttr:$replica_groups |
| ); |
| |
| let description = [{ |
| This operation is on its way out of StableHLO, so it is not included in |
| the StableHLO specification: https://github.com/openxla/stablehlo/issues/3. |
| |
| Informally, this operation does the same thing as AllReduceOp with |
| `channel_id = 0`, `use_global_device_ids = false` and `computation` |
| implementing addition: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#all_reduce |
| |
| Example: |
| ```mlir |
| %result = "stablehlo.cross-replica-sum"(%operand) { |
| replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> |
| } : (tensor<4xf32>) -> tensor<4xf32> |
| ``` |
| }]; |
| } |
| |
| // CustomCall: an implementation-defined operation named by |
| // `call_target_name`. Its memory effects come from a C++ |
| // MemoryEffectsOpInterface implementation rather than a trait. |
| def StableHLO_CustomCallOp: StableHLO_Op<"custom_call", |
| [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { |
| let summary = "CustomCall operation"; |
| |
| let results = (outs Variadic<HLO_CustomCallValue>); |
| |
| // Everything other than `inputs` and `call_target_name` is optional or |
| // defaulted. |
| let arguments = (ins |
| Variadic<HLO_CustomCallValue>:$inputs, |
| StrAttr:$call_target_name, |
| DefaultValuedOptionalAttr<BoolAttr, "false">:$has_side_effect, |
| DefaultValuedStrAttr<StrAttr, "">:$backend_config, |
| // TODO(b/189822916): Remove this field when all clients are migrated to |
| // the status-returning API. |
| DefaultValuedOptionalAttr< |
| StableHLO_CustomCallApiVersionAttr, |
| "::mlir::stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL">: |
| $api_version, |
| DefaultValuedOptionalAttr<StableHLO_FlatSymbolRefArrayAttr, "{}">:$called_computations, |
| OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$operand_layouts, |
| OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$result_layouts, |
| DefaultValuedOptionalAttr< |
| TypedArrayAttrBase< |
| StableHLO_OutputOperandAlias, |
| "Aliasing attribute for outputs and operands of CustomCall">, |
| "{}">:$output_operand_aliases |
| ); |
| |
| // A hand-written C++ verify() is declared for this op. |
| let hasVerifier = 1; |
| |
| let assemblyFormat = [{ |
| custom<CustomCallTarget>($call_target_name) `(` $inputs `)` |
| attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| let description = [{ |
| Encapsulates an implementation-defined operation `call_target_name` that |
| takes `inputs` and `called_computations` and produces `results`. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call |
| |
| Example: |
| ```mlir |
| %results = stablehlo.custom_call @foo(%input0) { |
| backend_config = "bar", |
| called_computations = [@foo] |
| } : (tensor<f64>) -> tensor<f64> |
| ``` |
| }]; |
| } |
| |
| // Legacy Dot (not in the spec); DotGeneral is the specified form. |
| def StableHLO_DotOp: StableHLO_Op<"dot", [Pure]> { |
| let summary = "Dot operation"; |
| |
| let arguments = (ins |
| HLO_Tensor:$lhs, |
| HLO_TensorOrPerAxisQuantizedTensor:$rhs, |
| StableHLO_PrecisionConfigAttr:$precision_config |
| ); |
| let results = (outs HLO_TensorOrPerAxisQuantizedTensor); |
| |
| // A hand-written C++ verify() is declared for this op. |
| let hasVerifier = 1; |
| |
| // The empty `` literal suppresses the space before the precision config. |
| let assemblyFormat = [{ |
| $lhs `,` $rhs `` custom<PrecisionConfig>($precision_config) attr-dict |
| `:` functional-type(operands, results) |
| }]; |
| |
| let description = [{ |
| This operation is on its way out of StableHLO, so it is not included in |
| the StableHLO specification: https://github.com/openxla/stablehlo/issues/3. |
| |
| Informally, this operation does the same thing as XLA's Dot: |
| https://www.tensorflow.org/xla/operation_semantics#dot |
| |
| Example: |
| ```mlir |
| %0 = stablehlo.dot %arg0, %arg1 : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<1x1xi32> |
| ``` |
| }]; |
| } |
| |
| // DotGeneral: dot products over batching/contracting dimensions of lhs/rhs. |
| def StableHLO_DotGeneralOp: StableHLO_ShapedInterfaceOp<"dot_general", |
| [ConditionallySpeculatable, NoMemoryEffect]> { |
| let summary = "DotGeneral operation"; |
| |
| let arguments = (ins |
| HLO_Tensor:$lhs /*dot_general_i1*/, |
| HLO_TensorOrPerAxisQuantizedTensor:$rhs /*dot_general_i2*/, |
| StableHLO_DotDimensionNumbers:$dot_dimension_numbers /*dot_general_i3, dot_general_i4, dot_general_i5, dot_general_i6*/, |
| StableHLO_PrecisionConfigAttr:$precision_config /*dot_general_i7*/ |
| ); |
| let results = (outs HLO_TensorOrPerAxisQuantizedTensor); |
| |
| // A hand-written C++ verify() is declared for this op. |
| let hasVerifier = 1; |
| |
| // Speculatability is computed in C++ (ConditionallySpeculatable trait). |
| let extraClassDeclaration = commonClassDeclaration # [{ |
| /// Interface method for ConditionallySpeculatable. |
| mlir::Speculation::Speculatability getSpeculatability(); |
| }]; |
| |
| // The empty `` literal suppresses the space before the precision config. |
| let assemblyFormat = [{ |
| $lhs `,` $rhs `,` custom<DotDimensionNumbers>($dot_dimension_numbers) `` |
| custom<PrecisionConfig>($precision_config) attr-dict |
| `:` functional-type(operands, results) |
| }]; |
| |
| let description = [{ |
| Computes dot products between slices of `lhs` and slices of `rhs` and |
| produces a `result` tensor. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general |
| |
| Example: |
| ```mlir |
| %result = stablehlo.dot_general %lhs, %rhs, |
| batching_dims = [0] x [0], |
| contracting_dims = [2] x [1], |
| precision = [DEFAULT, DEFAULT] |
| : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64> |
| ``` |
| }]; |
| } |
| |
| // Legacy binary Einsum (not in the spec), driven by an equation string. |
| def StableHLO_EinsumOp: StableHLO_Op<"einsum", [Pure]> { |
| let summary = "Einsum operation"; |
| |
| let results = (outs HLO_Tensor); |
| let arguments = (ins |
| HLO_Tensor:$lhs, |
| HLO_Tensor:$rhs, |
| StrAttr:$einsum_config |
| ); |
| |
| let assemblyFormat = [{ |
| $lhs `,` $rhs `,` `config` `=` $einsum_config |
| attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| let description = [{ |
| This operation is on its way out of StableHLO, so it is not included in |
| the StableHLO specification: https://github.com/openxla/stablehlo/issues/3. |
| |
| Informally, this operation does the same thing as TF's einsum: |
| https://www.tensorflow.org/api_docs/python/tf/einsum |
| |
| Example: |
| ```mlir |
| %result = "stablehlo.einsum"(%lhs, %rhs) { |
| einsum_config = "ab,bc->ac" |
| } : (tensor<4x16xf32>, tensor<16x4xf32>) -> tensor<4x4xf32> |
| ``` |
| }]; |
| } |
| |
| // Legacy single-operand Einsum (not in the spec). |
| def StableHLO_UnaryEinsumOp: StableHLO_Op<"unary_einsum", [Pure]> { |
| let summary = "UnaryEinsum operation"; |
| |
| let results = (outs HLO_Tensor); |
| let arguments = (ins |
| HLO_Tensor:$operand, |
| StrAttr:$einsum_config |
| ); |
| |
| let assemblyFormat = [{ |
| $operand `,` `config` `=` $einsum_config |
| attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| let description = [{ |
| This operation is on its way out of StableHLO, so it is not included in |
| the StableHLO specification: https://github.com/openxla/stablehlo/issues/3. |
| |
| Informally, this operation does the same thing as TF's einsum: |
| https://www.tensorflow.org/api_docs/python/tf/einsum |
| |
| Example: |
| ```mlir |
| %result = "stablehlo.unary_einsum"(%operand) { |
| einsum_config = "ab->a" |
| } : (tensor<4x16xf32>) -> tensor<4xf32> |
| ``` |
| }]; |
| } |
| |
| // Fft: forward/inverse Fourier transforms selected by `fft_type`. |
| def StableHLO_FftOp: StableHLO_Op<"fft", |
| [ConditionallySpeculatable, NoMemoryEffect, InferTensorType]> { |
| let summary = "Fft operation"; |
| |
| let arguments = (ins |
| HLO_FpOrComplexTensor:$operand, |
| StableHLO_FftTypeAttr:$fft_type, |
| GenericDenseI64ArrayAttr:$fft_length |
| ); |
| let results = (outs HLO_FpOrComplexTensor); |
| |
| // Speculatability is computed in C++ (ConditionallySpeculatable trait). |
| let extraClassDeclaration = commonClassDeclaration # [{ |
| /// Interface method for ConditionallySpeculatable. |
| mlir::Speculation::Speculatability getSpeculatability(); |
| }]; |
| |
| let assemblyFormat = [{ |
| $operand `,` `type` `=` $fft_type `,` `length` `=` custom<DenseI64Array>($fft_length) |
| attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| let description = [{ |
| Performs the forward and inverse Fourier transforms for real and complex |
| inputs/outputs. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#fft |
| |
| Example: |
| ```mlir |
| %result = stablehlo.fft %operand, type = FFT, length = [4] : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>> |
| ``` |
| }]; |
| } |
| |
| // Gather: collects slices of `operand` at offsets given by `start_indices`. |
| // Uses the generic assembly format (no assemblyFormat is declared). |
| def StableHLO_GatherOp: StableHLO_Op<"gather", |
| [ConditionallySpeculatable, NoMemoryEffect, |
| InferTensorTypeWithReify /*gather_c13*/, |
| AllElementTypesMatch<["operand", "result"]> /*gather_c14*/]> { |
| let summary = "Gather operation"; |
| |
| let results = (outs HLO_Tensor:$result); |
| let arguments = (ins |
| HLO_Tensor:$operand /*gather_i1*/, |
| HLO_IntTensor:$start_indices /*gather_i2*/, |
| StableHLO_GatherDimensionNumbers:$dimension_numbers /*gather_i3, gather_i4, gather_i5, gather_i6*/, |
| GenericDenseI64ArrayAttr:$slice_sizes /*gather_i7*/, |
| DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted /*gather_i8*/ |
| ); |
| |
| // Speculatability is computed in C++ (ConditionallySpeculatable trait). |
| let extraClassDeclaration = commonClassDeclaration # [{ |
| /// Interface method for ConditionallySpeculatable. |
| mlir::Speculation::Speculatability getSpeculatability(); |
| }]; |
| |
| let description = [{ |
| Gathers slices from `operand` tensor from offsets specified in |
| `start_indices` and produces a `result` tensor. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather |
| |
| Example: |
| ```mlir |
| %result = "stablehlo.gather"(%operand, %start_indices) { |
| dimension_numbers = #stablehlo.gather< |
| offset_dims = [2, 3], |
| collapsed_slice_dims = [0], |
| start_index_map = [1, 0], |
| index_vector_dim = 2>, |
| slice_sizes = array<i64: 1, 2, 2>, |
| indices_are_sorted = false |
| } : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32> |
| ``` |
| }]; |
| } |
| |
| // GetDimensionSize: returns the size of one dimension of `operand` as a |
| // ranked i32 tensor. |
| def StableHLO_GetDimensionSizeOp: StableHLO_Op<"get_dimension_size", |
| [Pure, InferTensorType]> { |
| let summary = "GetDimensionSize operation"; |
| |
| // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the |
| // XLA semantics is available. This limitation is because of the current XLA |
| // implementation. |
| let results = (outs I32RankedTensor); |
| let arguments = (ins |
| HLO_TensorOrPerAxisQuantizedTensor:$operand, /*get_dimension_size_i1*/ |
| I64Attr:$dimension /*get_dimension_size_i2*/ |
| ); |
| |
| let assemblyFormat = [{ |
| $operand `,` `dim` `=` $dimension |
| attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| let description = [{ |
| Produces the size of the given `dimension` of the `operand`. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#get_dimension_size |
| |
| Example: |
| ```mlir |
| %result = stablehlo.get_dimension_size %operand, dim = 1 : (tensor<2x3xi64>) -> tensor<i32> |
| ``` |
| }]; |
| } |
| |
| // Map: applies the single-block `computation` region elementwise over |
| // `inputs`. The region is terminated by ReturnOp. |
| def StableHLO_MapOp: StableHLO_ShapedInterfaceOp<"map", |
| [HLO_RecursivelySpeculatableIfAllInputsStatic, RecursiveMemoryEffects, |
| SameOperandsAndResultShape /*map_c1, map_c2*/, |
| SingleBlockImplicitTerminator<"ReturnOp">, InferTensorTypeWithReify]> { |
| let summary = "Map operation"; |
| |
| let arguments = (ins |
| Variadic<HLO_Tensor>:$inputs /*map_i1*/, |
| GenericDenseI64ArrayAttr:$dimensions /*map_i2*/ |
| ); |
| let results = (outs HLO_Tensor); |
| let regions = (region SizedRegion<1>:$computation /*map_i3*/); |
| |
| // Speculatability is computed in C++ (ConditionallySpeculatable interface). |
| let extraClassDeclaration = commonClassDeclaration # [{ |
| /// Interface method for ConditionallySpeculatable. |
| mlir::Speculation::Speculatability getSpeculatability(); |
| }]; |
| |
| let description = [{ |
| Applies a map function `computation` to `inputs` along the `dimensions` and |
| produces a `result` tensor. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#map |
| |
| Example: |
| ```mlir |
| %result = "stablehlo.map"(%input0, %input1) ({ |
| ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>): |
| %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64> |
| stablehlo.return %0 : tensor<i64> |
| }) { |
| dimensions = array<i64: 0, 1> |
| } : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64> |
| ``` |
| }]; |
| } |
| |
| // Reshape: reinterprets `operand` with the (static) shape of the result type. |
| def StableHLO_ReshapeOp: StableHLO_Op<"reshape", |
| [ConditionallySpeculatable, NoMemoryEffect, |
| HLO_CompatibleOperandsAndResultElementType]> { |
| let summary = "Reshape operation"; |
| |
| let arguments = (ins HLO_TensorOrPerAxisQuantizedTensor:$operand); |
| // The target shape comes from the result type, so it must be static. |
| let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor); |
| |
| // A hand-written C++ verify() is declared for this op. |
| let hasVerifier = 1; |
| |
| let assemblyFormat = [{ |
| operands attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| // Speculatability is computed in C++ (ConditionallySpeculatable trait). |
| let extraClassDeclaration = commonClassDeclaration # [{ |
| /// Interface method for ConditionallySpeculatable. |
| mlir::Speculation::Speculatability getSpeculatability(); |
| }]; |
| |
| let description = [{ |
| Performs reshape of `operand` tensor to a `result` tensor. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape |
| |
| Example: |
| ```mlir |
| %result = stablehlo.reshape %operand : (tensor<2xf32>) -> tensor<1x2xf32> |
| ``` |
| }]; |
| } |
| |
| // DynamicReshape: like Reshape, but the target shape is the `output_shape` |
| // operand. |
| def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", |
| [ConditionallySpeculatable, NoMemoryEffect]> { |
| let summary = "DynamicReshape operation"; |
| |
| let arguments = (ins |
| HLO_Tensor:$operand, |
| HLO_StaticDimensionTensor:$output_shape |
| ); |
| let results = (outs HLO_Tensor:$result); |
| |
| // A hand-written C++ verify() is declared for this op. |
| let hasVerifier = 1; |
| |
| let assemblyFormat = [{ |
| operands attr-dict `:` functional-type(operands, results) |
| }]; |
| |
| // Speculatability is computed in C++ (ConditionallySpeculatable trait). |
| let extraClassDeclaration = commonClassDeclaration # [{ |
| /// Interface method for ConditionallySpeculatable. |
| mlir::Speculation::Speculatability getSpeculatability(); |
| }]; |
| |
| let description = [{ |
| This operation is a work in progress, so it is not yet included in |
| the StableHLO specification: https://github.com/openxla/stablehlo/issues/8. |
| |
| Informally, this operation does the same thing as ReshapeOp except that the |
| result shape is specified dynamically via `output_shape`: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reshape |
| |
| Example: |
| ```mlir |
| %0 = stablehlo.dynamic_reshape %arg0, %shape : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32> |
| ``` |
| }]; |
| } |
| |
| // Scatter: returns copies of `inputs` in which the slices addressed by |
| // `scatter_indices` are combined with `updates` via `update_computation`. |
| def StableHLO_ScatterOp: StableHLO_Op<"scatter", |
| [ConditionallySpeculatable, RecursiveMemoryEffects, |
| SameVariadicOperandSize /*scatter_c5*/, |
| DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c16, |
| scatter_c17*/]> { |
| let summary = "Scatter operation"; |
| |
| // `inputs` and `updates` are both variadic; SameVariadicOperandSize makes |
| // them the same length. |
| let arguments = (ins |
| Variadic<HLO_Tensor>:$inputs, /*scatter_i1*/ |
| RankedTensorOf<[AnyInteger, Index]>:$scatter_indices, /*scatter_i2*/ |
| Variadic<HLO_Tensor>:$updates, /*scatter_i3*/ |
| StableHLO_ScatterDimensionNumbers:$scatter_dimension_numbers, /*scatter_i4...scatter_i7*/ |
| DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted, /*scatter_i8*/ |
| DefaultValuedOptionalAttr<BoolAttr, "false">:$unique_indices /*scatter_i9*/ |
| ); |
| let results = (outs Variadic<HLO_Tensor>); |
| let regions = (region SizedRegion<1>:$update_computation /*scatter_i10*/); |
| |
| // A hand-written C++ verify() is declared for this op. |
| let hasVerifier = 1; |
| |
| // Speculatability is computed in C++ (ConditionallySpeculatable trait). |
| let extraClassDeclaration = commonClassDeclaration # [{ |
| /// Interface method for ConditionallySpeculatable. |
| mlir::Speculation::Speculatability getSpeculatability(); |
| }]; |
| |
| let description = [{ |
| Produces `results` tensors which are equal to `inputs` tensors except that |
| several slices specified by `scatter_indices` are updated with the values |
| `updates` using `update_computation`. |
| |
| See: |
| https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter |
| |
| Example: |
| ```mlir |
| %result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({ |
| ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>): |
| %0 = stablehlo.add %arg0, %arg1 : tensor<i64> |
| stablehlo.return %0 : tensor<i64> |
| }) { |
| scatter_dimension_numbers = #stablehlo.scatter< |
| update_window_dims = [2, 3], |
| inserted_window_dims = [0], |
| scatter_dims_to_operand_dims = [1, 0], |
| index_vector_dim = 2>, |
| indices_are_sorted = false, |
| unique_indices = false |
| } : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64> |
| ``` |
| }]; |
| } |
| |
| // select: element-wise choice between `on_true` and `on_false`, driven by the
| // predicate tensor `pred`.
| def StableHLO_SelectOp: StableHLO_Op<"select", [
|     HLO_SpeculatableIfAllInputsStatic,
|     NoMemoryEffect,
|     HLO_BroadcastingElementwise,
|     InferTensorTypeWithReify]> {
|   let summary = "Select operation";
|   let description = [{
|     Produces a `result` tensor where each element is selected from `on_true` or
|     `on_false` tensor based on the value of the corresponding element of `pred`.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#select
| 
|     Example:
|     ```mlir
|     %result = stablehlo.select %pred, %on_true, %on_false : tensor<2x2xi1>, tensor<2x2xi32>
|     ```
|   }];
| 
|   // Spec inputs: select_i1 = pred, select_i2 = on_true, select_i3 = on_false.
|   let arguments = (ins
|     HLO_PredTensor:$pred,
|     HLO_Tensor:$on_true,
|     HLO_Tensor:$on_false
|   );
|   let results = (outs HLO_Tensor:$result);
| 
|   // All four types are printed/parsed through the `SelectOpType` custom
|   // directive.
|   let assemblyFormat = [{
|     operands attr-dict `:`
|       custom<SelectOpType>(type($pred), type($on_true), type($on_false), type($result))
|   }];
| }
| |
| // select_and_scatter: uses the `select` region over windows of `operand` and
| // the `scatter` region to combine `source` values into a `result` tensor.
| def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter",
|       [HLO_RecursivelySpeculatableIfStaticDimInOutputIsStaticInInput,
|        DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11,
|        select_and_scatter_c12*/, RecursiveMemoryEffects]> {
|   let summary = "SelectAndScatter operation";
|   let description = [{
|     Scatters the values from the `source` tensor using `scatter` based on the
|     outcome of `reduce_window` of the `operand` tensor using `select` and
|     produces a `result` tensor.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#select_and_scatter
| 
|     Example:
|     ```mlir
|     %result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
|       ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
|         %0 = stablehlo.compare GE, %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
|         stablehlo.return %0 : tensor<i1>
|     }, {
|       ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
|         %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
|         stablehlo.return %0 : tensor<i64>
|     }) {
|       window_dimensions = array<i64: 3, 1>,
|       window_strides = array<i64: 2, 1>,
|       padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
|     } : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
|     ```
|   }];
|   let arguments = (ins
|     HLO_Tensor:$operand, /*select_and_scatter_i1*/
|     HLO_Tensor:$source, /*select_and_scatter_i2*/
|     HLO_Tensor:$init_value, /*select_and_scatter_i3*/
|     // Window sizes and strides are dense i64 arrays, written `array<i64: ...>`.
|     OptionalAttr<GenericDenseI64ArrayAttr>:$window_dimensions, /*select_and_scatter_i4*/
|     OptionalAttr<GenericDenseI64ArrayAttr>:$window_strides, /*select_and_scatter_i5*/
|     // `padding` is an i64 elements attribute, written `dense<...> : tensor<...xi64>`.
|     OptionalAttr<I64ElementsAttr>:$padding /*select_and_scatter_i6*/
|   );
| 
|   let regions = (region
|     SizedRegion<1>:$select, /*select_and_scatter_i7*/
|     SizedRegion<1>:$scatter /*select_and_scatter_i8*/
|   );
| 
|   let results = (outs HLO_Tensor);
| 
|   // Further constraints are checked in the C++ `verify()` method.
|   let hasVerifier = 1;
| }
| |
| // set_dimension_size: not yet part of the StableHLO spec (openxla/stablehlo#8);
| // informally mirrors XLA's SetDimensionSize.
| def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size", [
|     ConditionallySpeculatable,
|     NoMemoryEffect,
|     InferTensorType]> {
|   let summary = "SetDimensionSize operation";
|   let description = [{
|     This operation is a work in progress, so it is not yet included in
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
| 
|     Informally, this operation does the same thing as XLA's SetDimensionSize:
|     https://www.tensorflow.org/xla/operation_semantics#setdimensionsize
| 
|     Example:
|     ```mlir
|     %0 = stablehlo.set_dimension_size %arg0, %arg1, dim = 1 : (tensor<4x2xf32>, tensor<i32>) -> tensor<4x2xf32>
|     ```
|   }];
| 
|   // `size` is a ranked i32 tensor; `dimension` is an i64 attribute printed
|   // as `dim = N`.
|   let arguments = (ins
|     HLO_Tensor:$operand,
|     I32RankedTensor:$size,
|     I64Attr:$dimension
|   );
|   let results = (outs HLO_Tensor);
| 
|   // getSpeculatability() is implemented in C++ for ConditionallySpeculatable.
|   let extraClassDeclaration = commonClassDeclaration # [{
|     /// Interface method for ConditionallySpeculatable.
|     mlir::Speculation::Speculatability getSpeculatability();
|   }];
| 
|   let assemblyFormat = [{
|     $operand `,` $size `,` `dim` `=` $dimension attr-dict
|       `:` functional-type(operands, results)
|   }];
| }
| |
| // sort: sorts the variadic `inputs` together along `dimension` according to
| // the custom `comparator` region, producing one tensor per input.
| def StableHLO_SortOp : StableHLO_Op<"sort",
|       [HLO_RecursivelySpeculatableIfAllInputsStatic, RecursiveMemoryEffects,
|        SameOperandsAndResultShape /*sort_c1, sort_c3*/,
|        InferTensorType /*sort_c2*/]> {
|   let summary = "Sort operation";
|   let description = [{
|     Sorts a variadic number of tensors in `inputs` together, according to a
|     custom `comparator`, along the given `dimension` and produces a variadic
|     number of tensors as `results`.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sort
| 
|     Example:
|     ```mlir
|     %result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
|       ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
|         %predicate = stablehlo.compare GT, %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
|         stablehlo.return %predicate : tensor<i1>
|     }) {
|       dimension = 0 : i64,
|       is_stable = true
|     } : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
|     ```
|   }];
|   let arguments = (ins
|     Variadic<HLO_Tensor>:$inputs /*sort_i1*/,
|     // Omitted attributes default to -1 / false.
|     DefaultValuedOptionalAttr<I64Attr, "-1">:$dimension /*sort_i2*/,
|     DefaultValuedOptionalAttr<BoolAttr, "false">:$is_stable /*sort_i3*/
|   );
| 
|   let results = (outs Variadic<HLO_Tensor>);
| 
|   let regions = (region SizedRegion<1>:$comparator /*sort_i4*/);
| 
|   // Convenience builder; its C++ defaults match the attribute defaults above.
|   let builders = [
|     OpBuilder<(ins "ValueRange":$inputs, CArg<"int64_t", "-1">:$dimension,
|       CArg<"bool", "false">:$is_stable)>];
| 
|   // Further constraints are checked in the C++ `verify()` method.
|   let hasVerifier = 1;
| }
| |
| // reverse: reverses `operand` along every axis listed in `dimensions`; the
| // element type is preserved.
| def StableHLO_ReverseOp: StableHLO_ShapedInterfaceOp<"reverse", [
|     HLO_SpeculatableIfStaticDimInOutputIsStaticInInput,
|     NoMemoryEffect,
|     SameOperandsAndResultElementType /*reverse_c1*/,
|     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|   let summary = "Reverse operation";
|   let description = [{
|     Reverses the order of elements in the `operand` along the specified
|     `dimensions` and produces a `result` tensor.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reverse
| 
|     Example:
|     ```mlir
|     %result = stablehlo.reverse %operand, dims = [1] : tensor<3x2xi32>
|     ```
|   }];
| 
|   let arguments = (ins
|     HLO_Tensor:$operand,
|     GenericDenseI64ArrayAttr:$dimensions
|   );
|   let results = (outs HLO_Tensor:$result);
| 
|   // `dimensions` prints as a bare list (`dims = [...]`); operand and result
|   // share one printed type.
|   let assemblyFormat = [{
|     $operand `,` `dims` `=` custom<DenseI64Array>($dimensions)
|       attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
|   }];
| 
|   // Further constraints are checked in the C++ `verify()` method.
|   let hasVerifier = 1;
| }
| |
| // pad: expands `operand` with `padding_value` at the low/high edges and
| // between elements, with static per-dimension padding amounts.
| def StableHLO_PadOp: StableHLO_ShapedInterfaceOp<"pad", [
|     HLO_SpeculatableIfStaticDimInOutputIsStaticInInput,
|     NoMemoryEffect,
|     SameOperandsAndResultElementType /*pad_c1*/,
|     // pad_c2, pad_i3, pad_i4, pad_i5: the three padding arrays must all have
|     // the same size.
|     AllMatchSameOperatorTrait<["edge_padding_low", "edge_padding_high",
|         "interior_padding"], "$_self.size()", "size">,
|     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|   let summary = "Pad operation";
|   let description = [{
|     Expands `operand` by padding around the tensor as well as between the
|     elements of the tensor with the given `padding_value`.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad
| 
|     Example:
|     ```mlir
|     %0 = stablehlo.pad %arg0, %arg1, low = [0, 1], high = [2, 1], interior = [1, 2]
|       : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
|     ```
|   }];
| 
|   let arguments = (ins
|     /*pad_i1*/ HLO_Tensor:$operand,
|     /*pad_i2*/ HLO_Tensor:$padding_value,
|     /*pad_i3*/ GenericDenseI64ArrayAttr:$edge_padding_low,
|     /*pad_i4*/ GenericDenseI64ArrayAttr:$edge_padding_high,
|     /*pad_i5*/ GenericDenseI64ArrayAttr:$interior_padding
|   );
|   let results = (outs HLO_Tensor);
| 
|   // Each padding array prints as a bare list after its keyword.
|   let assemblyFormat = [{
|     $operand `,` $padding_value `,`
|     `low` `=` custom<DenseI64Array>($edge_padding_low) `,`
|     `high` `=` custom<DenseI64Array>($edge_padding_high) `,`
|     `interior` `=` custom<DenseI64Array>($interior_padding)
|     attr-dict `:` functional-type(operands, results)
|   }];
| }
| |
| // trace: on its way out of StableHLO (openxla/stablehlo#604); its semantics
| // are intentionally left undocumented.
| def StableHLO_TraceOp: StableHLO_Op<"trace"> {
|   let summary = "Trace operation";
|   let description = [{
|     This operation is on its way out of StableHLO, so it is not included in
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/604.
| 
|     It is not used by JAX, PyTorch or TensorFlow, so it looks like we should've
|     classified it as "Private to XLA" and not included it in StableHLO in the
|     first place. With that in mind, its semantics will not be documented here.
| 
|     Example:
|     ```mlir
|     stablehlo.trace %arg0, "In test code." : tensor<5x1x5xi32>
|     ```
|   }];
| 
|   // One tensor operand plus a string tag; the op declares no results.
|   let arguments = (ins HLO_Tensor:$operand, StrAttr:$tag);
| 
|   let assemblyFormat = "$operand `,` $tag attr-dict `:` type($operand)";
| }
| |
| // transpose: permutes the dimensions of `operand` according to `permutation`.
| def StableHLO_TransposeOp: StableHLO_ShapedInterfaceOp<"transpose", [
|     ConditionallySpeculatable,
|     NoMemoryEffect,
|     HLO_CompatibleOperandsAndResultElementType /*transpose_c1*/,
|     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|   let summary = "Transpose operation";
|   let description = [{
|     Permutes the dimensions of `operand` tensor using `permutation` and produces
|     a `result` tensor.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose
| 
|     Example:
|     ```mlir
|     %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<1x2x3xi32>) -> tensor<3x2x1xi32>
|     ```
|   }];
| 
|   // Both operand and result may be per-axis quantized tensors.
|   let arguments = (ins
|     HLO_TensorOrPerAxisQuantizedTensor:$operand,
|     GenericDenseI64ArrayAttr:$permutation
|   );
|   let results = (outs HLO_TensorOrPerAxisQuantizedTensor:$result);
| 
|   let assemblyFormat = [{
|     $operand `,` `dims` `=` custom<DenseI64Array>($permutation)
|       attr-dict `:` functional-type(operands, results)
|   }];
| 
|   // Further constraints are checked in the C++ `verify()` method.
|   let hasVerifier = 1;
| 
|   let extraClassDeclaration = commonClassDeclaration # [{
|     /// Interface method for ConditionallySpeculatable.
|     mlir::Speculation::Speculatability getSpeculatability();
|   }];
| }
| |
| // triangular_solve: batched solves of linear systems whose coefficient
| // matrices are lower or upper triangular; float or complex element types.
| def StableHLO_TriangularSolveOp: StableHLO_Op<"triangular_solve", [
|     NoMemoryEffect,
|     ConditionallySpeculatable,
|     HLO_CompatibleOperandsAndResultElementType,
|     InferTensorType]> {
|   let summary = "TriangularSolve operation";
|   let description = [{
|     Solves batches of systems of linear equations with lower or upper triangular
|     coefficient matrices.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#triangular_solve
| 
|     Example:
|     ```mlir
|     %result = "stablehlo.triangular_solve"(%a, %b) {
|       left_side = true,
|       lower = true,
|       unit_diagonal = false,
|       transpose_a = #stablehlo<transpose NO_TRANSPOSE>
|     } : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
|     ```
|   }];
| 
|   let arguments = (ins
|     HLO_FpOrComplexTensor:$a,
|     HLO_FpOrComplexTensor:$b,
|     // Boolean flags and the transpose mode for `a`; see the spec link above.
|     BoolAttr:$left_side,
|     BoolAttr:$lower,
|     BoolAttr:$unit_diagonal,
|     StableHLO_TransposeAttr:$transpose_a
|   );
|   let results = (outs HLO_FpOrComplexTensor);
| 
|   // getSpeculatability() is implemented in C++ for ConditionallySpeculatable.
|   let extraClassDeclaration = commonClassDeclaration # [{
|     /// Interface method for ConditionallySpeculatable.
|     mlir::Speculation::Speculatability getSpeculatability();
|   }];
| }
| |
| // reduce_window: applies the `body` reduction to windows of `inputs` seeded
| // with `init_values`, producing `results`.
| def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window",
|     [HLO_RecursivelySpeculatableIfAllInputsStatic, RecursiveMemoryEffects,
|      SameVariadicOperandSize /*reduce_window_c1*/,
|      SingleBlockImplicitTerminator<"ReturnOp">,
|      /*reduce_window_c1, reduce_window_c14, reduce_window_c15,
|        reduce_window_c16*/
|      InferTensorType]> {
|   let summary = "ReduceWindow operation";
|   let description = [{
|     Applies a reduction function `body` to windows of `inputs` and `init_values`
|     and produces `results`.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window
| 
|     Example:
|     ```mlir
|     %result = "stablehlo.reduce_window"(%input, %init_value) ({
|       ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
|         %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
|         stablehlo.return %0 : tensor<i64>
|     }) {
|       window_dimensions = array<i64: 2, 1>,
|       window_strides = array<i64: 4, 1>,
|       base_dilations = array<i64: 2, 1>,
|       window_dilations = array<i64: 3, 1>,
|       padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
|     } : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
|     ```
|   }];
| 
|   let arguments = (ins
|     Variadic<HLO_Tensor>:$inputs /*reduce_window_i1*/,
|     Variadic<HLO_Tensor>:$init_values /*reduce_window_i2*/,
|     GenericDenseI64ArrayAttr:$window_dimensions /*reduce_window_i3*/,
|     // Optional attributes: missing strides/dilations mean 1 in every operand
|     // dimension, and missing padding means 0 (low and high) in every dimension.
|     OptionalAttr<GenericDenseI64ArrayAttr>:$window_strides /*reduce_window_i4*/,
|     OptionalAttr<GenericDenseI64ArrayAttr>:$base_dilations /*reduce_window_i5*/,
|     OptionalAttr<GenericDenseI64ArrayAttr>:$window_dilations /*reduce_window_i6*/,
|     OptionalAttr<I64ElementsAttr>:$padding /*reduce_window_i7*/
|   );
| 
|   let regions = (region SizedRegion<1>:$body /*reduce_window_i8*/);
| 
|   let results = (outs Variadic<HLO_Tensor>);
| 
|   // 1) Single-input/single-result form: wraps the lone type/values into
|   //    ranges and forwards to the generated variadic builder.
|   // 2) Variadic form whose reduction body is filled in by `bodyBuilder`
|   //    (defined in C++).
|   let builders = [
|     OpBuilder<(ins "Type":$result_type, "Value":$operand,
|       "Value":$init_value,
|       "DenseI64ArrayAttr":$window_dimensions,
|       "DenseI64ArrayAttr":$window_strides,
|       "DenseI64ArrayAttr":$base_dilations,
|       "DenseI64ArrayAttr":$window_dilations,
|       "DenseIntElementsAttr":$padding),
|     [{
|       build($_builder, $_state, TypeRange(result_type), ValueRange(operand),
|             ValueRange(init_value), window_dimensions, window_strides,
|             base_dilations, window_dilations, padding);
|     }]>,
|     OpBuilder<(ins "ValueRange":$operands,
|       "ValueRange":$init_values,
|       "DenseI64ArrayAttr":$window_dimensions,
|       "DenseI64ArrayAttr":$window_strides,
|       "DenseI64ArrayAttr":$base_dilations,
|       "DenseI64ArrayAttr":$window_dilations,
|       "DenseIntElementsAttr":$padding,
|       "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder
|     )>,
|   ];
| 
|   // Further constraints are checked in the C++ `verify()` method.
|   let hasVerifier = 1;
| 
|   // TODO(hinsu): Implement custom printer and parser.
| }
| |
| // return: terminator for regions of StableHLO ops (not yet in the spec; see
| // openxla/stablehlo#425).
| def StableHLO_ReturnOp : StableHLO_Op<"return", [Pure, Terminator]> {
|   let summary = "Return operation";
|   let description = [{
|     This operation is a work in progress, so it is not yet included in
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/425.
| 
|     Informally, this operation serves as a terminator for regions defined by
|     the StableHLO ops. Non-StableHLO ops, e.g. `func.func`, have their own
|     terminators, e.g. `func.return`.
| 
|     Example:
|     ```mlir
|     %result = "stablehlo.reduce"(%input, %init_value) ({
|       ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|         %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|         "stablehlo.return"(%0) : (tensor<i32>) -> ()
|     }) {
|       dimensions = array<i64: 1>
|     } : (tensor<1x6xi32>, tensor<i32>) -> tensor<1xi32>
|     ```
|   }];
| 
|   // The values yielded from the region are this op's operands (named
|   // `results`); the op itself declares no results.
|   let arguments = (ins
|     Variadic<HLO_TensorOrPerAxisQuantizedTensorOrToken>:$results
|   );
| 
|   let assemblyFormat = "$results attr-dict (`:` type($results)^)?";
| }
| |
| // torch_index_select: PyTorch-style index_select extended with leading batch
| // dimensions; on its way out of StableHLO (openxla/stablehlo#3).
| def StableHLO_TorchIndexSelectOp : StableHLO_Op<"torch_index_select",
|     [Pure]> {
|   let summary = "TorchIndexSelect operation";
|   let description = [{
|     This operation is on its way out of StableHLO, so it is not included in
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/3.
| 
|     Informally, this operation does the same thing as PyTorch's index_select,
|     augmented with support for batch dimensions:
|     https://pytorch.org/docs/stable/generated/torch.index_select.html.
| 
|     The `batch_dims` attribute specifies the number of major batch dimensions
|     (0 or more) that act like a multidimensional loop over both the operand and
|     the index.
| 
|     Example:
|     ```mlir
|     %result = "stablehlo.torch_index_select"(%operand, %index) {
|       dim = 2 : i64,
|       batch_dims = 1 : i64
|     } : (tensor<8x128x3072x64xf32>, tensor<8x16x1024xi32>) -> tensor<8x128x16x1024x64xf32>
|     ```
|   }];
| 
|   let arguments = (ins
|     HLO_Tensor:$operand,
|     HLO_Tensor:$index,
|     // Presumably the `operand` axis selected along, as in torch.index_select.
|     I64Attr:$dim,
|     // Count of major batch dimensions shared by `operand` and `index`.
|     I64Attr:$batch_dims
|   );
|   let results = (outs HLO_Tensor);
| }
| |
| // optimization_barrier: identity on its operands that orders their producers
| // before the consumers of its results and blocks code motion across it.
| def StableHLO_OptimizationBarrierOp : StableHLO_Op<"optimization_barrier", [
|     Pure,
|     HLO_PairwiseSameOperandAndResultType,
|     DeclareOpInterfaceMethods<InferTypeOpInterface> /*optimization_barrier_c1*/]> {
|   let summary = "OptimizationBarrier operation";
|   let description = [{
|     Ensures that the operations that produce the `operand` are executed before any
|     operations that depend on the `result` and prevents compiler transformations
|     from moving operations across the barrier. Other than that, the operation is
|     an identity, i.e. `result` = `operand`.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#optimization_barrier
| 
|     Example:
|     ```mlir
|     %result0, %result1 = stablehlo.optimization_barrier %operand0, %operand1 : tensor<f32>, tensor<f32>
|     ```
|   }];
| 
|   let arguments = (ins Variadic<HLO_TensorOrToken>:$operand);
|   let results = (outs Variadic<HLO_TensorOrToken>:$result);
| 
|   // `attr-dict` precedes `$operand` because an optional group cannot be
|   // anchored on a custom directive. The operand list is variadic, so an empty
|   // list is printed as `()`; otherwise parsing would be ambiguous:
|   //   stablehlo.optimization_barrier
|   //   %1 = stablehlo.add ...
|   // With no lookahead (and newlines ignored), the parser could not tell
|   // whether `%1` is an operand of the barrier or starts the next operation.
|   let assemblyFormat = [{
|     attr-dict ($operand^ `:` custom<PairwiseOpType>(type($operand), type($result))):(`(` `)`)?
|   }];
| }
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO RNG operations. |
| //===----------------------------------------------------------------------===// |
| |
| // rng: produces a random `result` of shape `shape` using `rng_distribution`;
| // `a`, `b` and `result` share one element type.
| def StableHLO_RngOp : StableHLO_Op<"rng", [
|     InferTensorTypeWithReify,
|     AllElementTypesMatch<["a", "b", "result"]>]> {
|   let summary = "Rng operation";
|   let description = [{
|     Generates random numbers using the `rng_distribution` algorithm and produces
|     a `result` tensor of a given shape `shape`.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rng
| 
|     Example:
|     ```mlir
|     %result = stablehlo.rng %a, %b, %shape, distribution = NORMAL : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
|     ```
|   }];
| 
|   let arguments = (ins
|     // `a` and `b` are 0-d (scalar) tensors of pred, integer or float type.
|     0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$a,
|     0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$b,
|     HLO_StaticDimensionTensor:$shape,
|     StableHLO_RngDistributionAttr:$rng_distribution
|   );
|   let results = (outs HLO_PredIntOrFpRankedTensor:$result);
| 
|   let assemblyFormat = [{
|     $a `,` $b `,` $shape `,` `distribution` `=` $rng_distribution
|       attr-dict `:` functional-type(operands, results)
|   }];
| }
| |
| // rng_bit_generator: advances `initial_state` with `rng_algorithm`, returning
| // the new state plus a statically shaped tensor of random bits.
| def StableHLO_RngBitGeneratorOp : StableHLO_Op<"rng_bit_generator", [
|     HLO_SpeculatableIfStaticDimInOutputIsStaticInInput,
|     NoMemoryEffect]> {
|   let summary = "RngBitGenerator operation";
|   let description = [{
|     Returns an `output` filled with uniform random data and an updated output
|     state `output_state` given an initial state `initial_state` using the
|     pseudorandom number generator algorithm `rng_algorithm`.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rng_bit_generator
| 
|     Example:
|     ```mlir
|     %output_state, %output = stablehlo.rng_bit_generator %initial_state, algorithm = THREE_FRY : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
|     ```
|   }];
| 
|   // Note: the algorithm attribute is declared before the state operand.
|   let arguments = (ins
|     StableHLO_RngAlgorithmAttr:$rng_algorithm,
|     HLO_IntOrFpTensor:$initial_state
|   );
|   // `output` must have a static shape.
|   let results = (outs
|     HLO_IntOrFpTensor:$output_state,
|     HLO_StaticShapeIntOrFpTensor:$output
|   );
| 
|   let assemblyFormat = [{
|     $initial_state `,` `algorithm` `=` $rng_algorithm attr-dict
|       `:` functional-type(operands, results)
|   }];
| 
|   // Further constraints are checked in the C++ `verify()` method.
|   let hasVerifier = 1;
| }
| |
| //===----------------------------------------------------------------------===// |
| // StableHLO Quantize operation. |
| //===----------------------------------------------------------------------===// |
| |
| // TODO(b/230662142): Implement unknown scales/zero_point cases. |
| // uniform_quantize: float or quantized `operand` -> quantized `result`, using
| // the quantization parameters carried by the result type.
| def StableHLO_UniformQuantizeOp : StableHLO_UnaryElementwiseOp<
|     "uniform_quantize", [],
|     /*uniform_quantize_i1*/
|     TensorOf<[HLO_Float, HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>,
|     /*uniform_quantize_c1*/
|     TensorOf<[HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>> {
|   let summary = "UniformQuantize operation";
|   let description = [{
|     Performs element-wise conversion of floating-point tensor or quantized
|     tensor `operand` to a quantized tensor `result` according to the
|     quantization parameters defined by the `result` type.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_quantize
| 
|     Example:
|     ```mlir
|     %result = stablehlo.uniform_quantize %operand : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
|     ```
|   }];
| }
| |
| // uniform_dequantize: quantized `operand` -> floating-point `result`, using
| // the quantization parameters carried by the operand type.
| def StableHLO_UniformDequantizeOp : StableHLO_UnaryElementwiseOp<
|     "uniform_dequantize", [InferTensorType],
|     /*uniform_dequantize_i1*/
|     TensorOf<[HLO_QuantizedInt, HLO_PerAxisQuantizedInt]>,
|     /*uniform_dequantize_c1, uniform_dequantize_c2*/
|     HLO_FpTensor> {
|   let summary = "UniformDequantize operation";
|   let description = [{
|     Performs element-wise conversion of quantized tensor `operand` to a
|     floating-point tensor `result` according to the quantization parameters
|     defined by the `operand` type.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_dequantize
| 
|     Example:
|     ```mlir
|     %result = stablehlo.uniform_dequantize %operand : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
|     ```
|   }];
| }
| |
| // reduce_precision: element-wise round trip of `operand` through a float
| // format with `exponent_bits` / `mantissa_bits`; result type equals operand's.
| def StableHLO_ReducePrecisionOp : StableHLO_Op<"reduce_precision", [
|     HLO_SpeculatableIfStaticDimInOutputIsStaticInInput,
|     NoMemoryEffect,
|     Elementwise,
|     HLO_CompatibleOperandsAndResultType /*reduce_precision_c1*/]> {
|   let summary = "ReducePrecision operation";
|   let description = [{
|     Performs element-wise conversion of `operand` to another floating-point type
|     that uses `exponent_bits` and `mantissa_bits` and back to the original
|     floating-point type and produces an `output` tensor.
| 
|     See:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_precision
| 
|     Example:
|     ```mlir
|     %output = stablehlo.reduce_precision %operand, format = e5m10 : tensor<6xf64>
|     ```
|   }];
| 
|   let arguments = (ins
|     /*reduce_precision_i1*/ HLO_FpOrQuantizedIntTensor:$operand,
|     /*reduce_precision_i2*/ I32Attr:$exponent_bits,
|     /*reduce_precision_i3*/ I32Attr:$mantissa_bits
|   );
|   let results = (outs HLO_FpOrQuantizedIntTensor:$output);
| 
|   // Both bit counts print together as `format = eXmY`.
|   let assemblyFormat = [{
|     $operand `,` `format` `=` custom<ExponentMantissa>($exponent_bits, $mantissa_bits)
|       attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($output))
|   }];
| 
|   // Further constraints are checked in the C++ `verify()` method.
|   let hasVerifier = 1;
| }
| |
| // real_dynamic_slice: like `slice`, but start/limit/stride values arrive as
| // runtime tensors of identical type (not yet in the spec; openxla/stablehlo#8).
| def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp<"real_dynamic_slice", [
|     Pure,
|     AllElementTypesMatch<["operand", "result"]>,
|     AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
|   let summary = "RealDynamicSlice operation";
|   let description = [{
|     This operation is a work in progress, so it is not yet included in
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
| 
|     Informally, this operation does the same thing as SliceOp except
|     that `start_indices`, `limit_indices` and `strides` are specified dynamically:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#slice
| 
|     Example:
|     ```mlir
|     %result = stablehlo.real_dynamic_slice %operand,
|                 %start_indices, %limit_indices, %strides
|            : (tensor<256x?xf32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<256x?xf32>
|     ```
|   }];
| 
|   let arguments = (ins
|     HLO_Tensor:$operand,
|     HLO_DimensionTensor:$start_indices,
|     HLO_DimensionTensor:$limit_indices,
|     HLO_DimensionTensor:$strides
|   );
|   let results = (outs HLO_Tensor:$result);
| 
|   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
| 
|   // Further constraints are checked in the C++ `verify()` method.
|   let hasVerifier = 1;
| }
| |
| // dynamic_pad: like `pad`, but the low/high/interior padding amounts are
| // runtime tensors of identical type (not yet in the spec; openxla/stablehlo#8).
| def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad",
|       [Pure, AllElementTypesMatch<["operand", "padding_value", "result"]>,
|       AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
|   let summary = "DynamicPad operation";
|   // The only `description` for this op; a second `let description` would
|   // silently replace this one.
|   let description = [{
|     This operation is a work in progress, so it is not yet included in
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
| 
|     Informally, this operation does the same thing as PadOp except
|     that `edge_padding_low`, `edge_padding_high` and `interior_padding` are
|     specified dynamically:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad
| 
|     Example:
|     ```mlir
|     %result = stablehlo.dynamic_pad %operand, %padding_value,
|                 %edge_padding_low, %edge_padding_high, %interior_padding
|       : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
|     ```
|   }];
|   let arguments = (ins
|     HLO_Tensor:$operand,
|     HLO_Tensor:$padding_value,
|     HLO_DimensionTensor:$edge_padding_low,
|     HLO_DimensionTensor:$edge_padding_high,
|     HLO_DimensionTensor:$interior_padding
|   );
|   let results = (outs HLO_Tensor:$result);
| 
|   // Further constraints are checked in the C++ `verify()` method.
|   let hasVerifier = 1;
| 
|   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
| }
| |
| // dynamic_gather: like `gather`, but `slice_sizes` is a runtime integer tensor
| // (not yet in the spec; openxla/stablehlo#8).
| def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather", [
|     InferTensorTypeWithReify,
|     Pure]> {
|   let summary = "DynamicGather operation";
|   let description = [{
|     This operation is a work in progress, so it is not yet included in
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
| 
|     Informally, this operation does the same thing as GatherOp except
|     that `slice_sizes` are specified dynamically:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather
| 
|     Example:
|     ```mlir
|     %result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slice_sizes) {
|       dimension_numbers = #stablehlo.gather<
|         offset_dims = [2, 3],
|         collapsed_slice_dims = [0],
|         start_index_map = [0, 2],
|         index_vector_dim = 2>,
|       indices_are_sorted = false
|     } : (tensor<3x4x2xi32>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi32>
|     ```
|   }];
| 
|   let arguments = (ins
|     HLO_Tensor:$operand,
|     HLO_IntTensor:$start_indices,
|     HLO_IntTensor:$slice_sizes,
|     StableHLO_GatherDimensionNumbers:$dimension_numbers,
|     // Defaults to false when omitted.
|     DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted
|   );
|   let results = (outs HLO_Tensor);
| }
| |
| // dynamic_conv: like `convolution`, but padding is supplied at runtime through
| // the `d_padding` tensor (not yet in the spec; openxla/stablehlo#8).
| def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
|     [Pure]> {
|   let summary = "DynamicConv operation";
|   let description = [{
|     This operation is a work in progress, so it is not yet included in
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
| 
|     Informally, this operation does the same thing as ConvolutionOp except
|     that `padding` is specified dynamically via `d_padding`:
|     https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution
| 
|     Example:
|     ```mlir
|     %result = "stablehlo.dynamic_conv"(%lhs, %rhs, %d_padding) {
|       window_strides = array<i64: 4, 4>,
|       lhs_dilation = array<i64: 2, 2>,
|       rhs_dilation = array<i64: 1, 1>,
|       window_reversal = array<i1: false, false>,
|       dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
|       feature_group_count = 1 : i64,
|       batch_group_count = 1 : i64,
|       precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
|     } : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>, tensor<2x2xi64>) -> tensor<1x2x2x1xi32>
|     ```
|   }];
| 
|   // Three tensor operands followed by the shared convolution attributes.
|   let arguments = !con(
|     (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, HLO_Tensor:$d_padding),
|     StableHLO_ConvolutionAttributes.attributes);
|   let results = (outs HLO_Tensor);
| }
| |
| // compute_reshape_shape: resolves a TF-style reshape shape (which may contain
| // -1) against an element count; result shape equals `dynamic_shape`'s.
| def StableHLO_ComputeReshapeShapeOp : StableHLO_Op<"compute_reshape_shape", [
|     Pure,
|     AllShapesMatch<["dynamic_shape", "result"]>]> {
|   let summary = "ComputeReshapeShape operation";
|   let description = [{
|     This operation is a work in progress, so it is not yet included in
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
| 
|     Informally, this operation computes an output_shape for DynamicReshapeOp
|     from the `num_elements` number of elements in an operand of DynamicReshapeOp
|     and the `dynamic_shape` shape provided to TF's reshape:
|     https://www.tensorflow.org/api_docs/python/tf/reshape
| 
|     For example, for `num_elements = 12` and `dynamic_shape = [2, -1]`,
|     the `result` is going to be `[2, 6]`. If operands are not valid (e.g. if
|     dimensions do not evenly divide the number of elements, or if there are
|     multiple -1 values in dimensions), this leads to undefined behavior.
| 
|     Example:
|     ```mlir
|     %result = stablehlo.compute_reshape_shape %num_elements, %dynamic_shape
|            : (index, tensor<2xi32>) -> tensor<2xi32>
|     ```
|   }];
| 
|   // `num_elements` is an index scalar; shapes are 1-D integer/index tensors.
|   let arguments = (ins
|     Index:$num_elements,
|     1DTensorOf<[AnyInteger, Index]>:$dynamic_shape
|   );
|   let results = (outs 1DTensorOf<[AnyInteger, Index]>:$result);
| 
|   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
| }
| |
| // cstr_reshapable: yields a `!shape.witness` for the constraint that
| // compute_reshape_shape would succeed on the same operands.
| def StableHLO_CstrReshapableOp : StableHLO_Op<"cstr_reshapable",
|     [Pure]> {
|   let summary = "CstrReshapable operation";
|   let description = [{
|     This operation is a work in progress, so it is not yet included in
|     the StableHLO specification: https://github.com/openxla/stablehlo/issues/8.
| 
|     Informally, this operation creates a witness on the constraint that
|     ComputeReshapeShape would succeed with the provided operands.
| 
|     Example:
|     ```mlir
|     %result = stablehlo.cstr_reshapable %num_elements, %dynamic_shape
|            : (index, tensor<3xi32>) -> !shape.witness
|     ```
|   }];
| 
|   // Operands mirror those of compute_reshape_shape.
|   let arguments = (ins
|     Index:$num_elements,
|     1DTensorOf<[AnyInteger, Index]>:$dynamic_shape
|   );
|   let results = (outs Shape_WitnessType:$result);
| 
|   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
| }
| |
| #endif // STABLEHLO_DIALECT_STABLEHLO_OPS |