| // Copyright 2022 OpenXLA Authors. All Rights Reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "mlir/Dialect/Tosa/IR/TosaOps.td" |
| #include "stablehlo/dialect/StablehloOps.td" |
| |
| // Helper functions. |
| // Builds a tosa.const of type `type` (cast to a tensor type below) whose |
| // elements are all 1. Any floating-point element type gets a float 1.0; |
| // every other element type is treated as an integer and gets an integer 1. |
| Rewrite onesLike(op: Op, type: Type) -> Op [{ |
|   auto shapedType = llvm::cast<mlir::ShapedType>(type); |
|   mlir::Type elementType = llvm::cast<mlir::TensorType>(type).getElementType(); |
| |
|   // Test the FloatType class instead of an F16/F32/BF16 allow-list so other |
|   // float element types do not fall through to getIntegerAttr, which builds |
|   // an IntegerAttr and is not meant for float types. |
|   mlir::Attribute one; |
|   if (llvm::isa<mlir::FloatType>(elementType)) { |
|     one = rewriter.getFloatAttr(elementType, 1); |
|   } else { |
|     one = rewriter.getIntegerAttr(elementType, 1); |
|   } |
| |
|   // A single element value produces a splat of `one` across the whole shape. |
|   return rewriter.create<mlir::tosa::ConstOp>( |
|       op->getLoc(), type, |
|       mlir::DenseElementsAttr::get(shapedType, |
|                                    llvm::ArrayRef<mlir::Attribute>(one))); |
| }]; |
| |
| // Builds a tosa.const of type `type` whose elements are all +infinity. The |
| // element type of `type` is cast to mlir::FloatType, so callers must only |
| // pass tensors with a floating-point element type. |
| Rewrite positiveFloatInfinityLike(op: Op, type: Type) -> Op [{ |
|   auto elementType = llvm::cast<mlir::FloatType>( |
|       llvm::cast<mlir::TensorType>(type).getElementType()); |
|   mlir::Attribute posInf = rewriter.getFloatAttr( |
|       elementType, |
|       llvm::APFloat::getInf(elementType.getFloatSemantics(), |
|                             /*Negative=*/false)); |
| |
|   // A single element value produces a splat of +inf across the whole shape. |
|   return rewriter.create<mlir::tosa::ConstOp>( |
|       op->getLoc(), type, |
|       mlir::DenseElementsAttr::get(llvm::cast<mlir::ShapedType>(type), |
|                                    llvm::ArrayRef<mlir::Attribute>(posInf))); |
| }]; |
| |
| // Nullary ops. |
| // stablehlo.constant becomes tosa.const, forwarding its `value` attribute. |
| Pattern { |
|   replace op<stablehlo.constant> {value = payload: Attr<_: Tosa_Tensor>} |
|     with op<tosa.const> {value = payload}; |
| } |
| |
| // Unary ops. |
| // stablehlo.abs is rewritten one-to-one into tosa.abs. |
| Pattern { |
|   replace op<stablehlo.abs>(operand : Value<_: Tosa_Tensor>) |
|     with op<tosa.abs>(operand); |
| } |
| // stablehlo.ceil is rewritten one-to-one into tosa.ceil. |
| Pattern { |
|   replace op<stablehlo.ceil>(operand : Value<_: Tosa_Tensor>) |
|     with op<tosa.ceil>(operand); |
| } |
| // stablehlo.convert is rewritten into tosa.cast on the same operand. |
| Pattern { |
|   replace op<stablehlo.convert>(operand : Value<_: Tosa_Tensor>) |
|     with op<tosa.cast>(operand); |
| } |
| // stablehlo.exponential is rewritten one-to-one into tosa.exp. |
| Pattern { |
|   replace op<stablehlo.exponential>(operand : Value<_: Tosa_Tensor>) |
|     with op<tosa.exp>(operand); |
| } |
| // Expands stablehlo.exponential_minus_one into exp(x) - 1: a constant of |
| // ones, a tosa.exp, and a tosa.sub, all typed like the operand. |
| Pattern { |
|   let expm1Op = op<stablehlo.exponential_minus_one>( |
|       x : Value<xType: Tosa_Tensor>); |
|   rewrite expm1Op with { |
|     let oneConst = onesLike(expm1Op, xType); |
|     let expOfX = op<tosa.exp>(x) -> (xType); |
|     let difference = op<tosa.sub>(expOfX, oneConst) -> (xType); |
|     replace expm1Op with difference; |
|   }; |
| } |
| // stablehlo.floor is rewritten one-to-one into tosa.floor. |
| Pattern { |
|   replace op<stablehlo.floor>(operand : Value<_: Tosa_Tensor>) |
|     with op<tosa.floor>(operand); |
| } |
| // Lowers stablehlo.is_finite(x) to `+inf > |x|`. |
| // The previous form, `!(|x| == +inf)`, reported NaN as finite because the |
| // equality is false for NaN. An ordered greater-than is false for both |
| // +/-inf (|x| == +inf) and NaN, so it excludes exactly the non-finite values. |
| Pattern { |
|   let root = op<stablehlo.is_finite>(input : Value<inputType: Tosa_Tensor>) |
|       -> (resultType: Type); |
|   rewrite root with { |
|     let positiveInfinity = positiveFloatInfinityLike(root, inputType); |
|     let inputAbs = op<tosa.abs>(input) -> (inputType); |
|     // The boolean result type is taken from the matched op rather than |
|     // relying on tosa.greater to infer it. |
|     let isFinite = op<tosa.greater>(positiveInfinity, inputAbs) |
|         -> (resultType); |
|     replace root with isFinite; |
|   }; |
| } |
| // stablehlo.log is rewritten one-to-one into tosa.log. |
| Pattern { |
|   replace op<stablehlo.log>(operand : Value<_: Tosa_Tensor>) |
|     with op<tosa.log>(operand); |
| } |
| // Expands stablehlo.log_plus_one into log(x + 1): a constant of ones, a |
| // tosa.add, and a tosa.log, all typed like the operand. |
| Pattern { |
|   let log1pOp = op<stablehlo.log_plus_one>(x : Value<xType: Tosa_Tensor>); |
|   rewrite log1pOp with { |
|     let oneConst = onesLike(log1pOp, xType); |
|     let shifted = op<tosa.add>(x, oneConst) -> (xType); |
|     let logOfShifted = op<tosa.log>(shifted) -> (xType); |
|     replace log1pOp with logOfShifted; |
|   }; |
| } |
| // stablehlo.negate is rewritten one-to-one into tosa.negate. |
| Pattern { |
|   replace op<stablehlo.negate>(operand : Value<_: Tosa_Tensor>) |
|     with op<tosa.negate>(operand); |
| } |
| // stablehlo.tanh is rewritten one-to-one into tosa.tanh. |
| Pattern { |
|   replace op<stablehlo.tanh>(operand : Value<_: Tosa_Tensor>) |
|     with op<tosa.tanh>(operand); |
| } |
| |
| // Binary ops. |
| // stablehlo.add becomes tosa.add with operands in the same order. |
| Pattern { |
|   replace op<stablehlo.add>(lhs : Value<_: Tosa_Tensor>, |
|                             rhs : Value<_: Tosa_Tensor>) |
|     with op<tosa.add>(lhs, rhs); |
| } |
| // stablehlo.and becomes tosa.bitwise_and with operands in the same order. |
| Pattern { |
|   replace op<stablehlo.and>(lhs : Value<_: Tosa_Tensor>, |
|                             rhs : Value<_: Tosa_Tensor>) |
|     with op<tosa.bitwise_and>(lhs, rhs); |
| } |
| // stablehlo.divide becomes tosa.div, matched only when both operands are |
| // Tosa_Int32Tensor values. |
| Pattern { |
|   replace op<stablehlo.divide>(lhs : Value<_: Tosa_Int32Tensor>, |
|                                rhs : Value<_: Tosa_Int32Tensor>) |
|     with op<tosa.div>(lhs, rhs); |
| } |
| // stablehlo.maximum becomes tosa.maximum with operands in the same order. |
| Pattern { |
|   replace op<stablehlo.maximum>(lhs : Value<_: Tosa_Tensor>, |
|                                 rhs : Value<_: Tosa_Tensor>) |
|     with op<tosa.maximum>(lhs, rhs); |
| } |
| // stablehlo.minimum becomes tosa.minimum with operands in the same order. |
| Pattern { |
|   replace op<stablehlo.minimum>(lhs : Value<_: Tosa_Tensor>, |
|                                 rhs : Value<_: Tosa_Tensor>) |
|     with op<tosa.minimum>(lhs, rhs); |
| } |
| // stablehlo.multiply becomes tosa.mul with its `shift` attribute set to 0. |
| Pattern { |
|   replace op<stablehlo.multiply>(lhs : Value<_: Tosa_Tensor>, |
|                                  rhs : Value<_: Tosa_Tensor>) |
|     with op<tosa.mul>(lhs, rhs) {shift = attr<"0 : i32">}; |
| } |
| // stablehlo.or becomes tosa.bitwise_or with operands in the same order. |
| Pattern { |
|   replace op<stablehlo.or>(lhs : Value<_: Tosa_Tensor>, |
|                            rhs : Value<_: Tosa_Tensor>) |
|     with op<tosa.bitwise_or>(lhs, rhs); |
| } |
| // stablehlo.power becomes tosa.pow (base first, exponent second). |
| Pattern { |
|   replace op<stablehlo.power>(base : Value<_: Tosa_Tensor>, |
|                               exponent : Value<_: Tosa_Tensor>) |
|     with op<tosa.pow>(base, exponent); |
| } |
| // stablehlo.shift_left becomes tosa.logical_left_shift. |
| Pattern { |
|   replace op<stablehlo.shift_left>(value : Value<_: Tosa_Tensor>, |
|                                    amount : Value<_: Tosa_Tensor>) |
|     with op<tosa.logical_left_shift>(value, amount); |
| } |
| // stablehlo.shift_right_logical becomes tosa.logical_right_shift. |
| Pattern { |
|   replace op<stablehlo.shift_right_logical>(value : Value<_: Tosa_Tensor>, |
|                                             amount : Value<_: Tosa_Tensor>) |
|     with op<tosa.logical_right_shift>(value, amount); |
| } |
| // stablehlo.subtract becomes tosa.sub with operands in the same order. |
| Pattern { |
|   replace op<stablehlo.subtract>(lhs : Value<_: Tosa_Tensor>, |
|                                  rhs : Value<_: Tosa_Tensor>) |
|     with op<tosa.sub>(lhs, rhs); |
| } |
| // stablehlo.xor becomes tosa.bitwise_xor with operands in the same order. |
| Pattern { |
|   replace op<stablehlo.xor>(lhs : Value<_: Tosa_Tensor>, |
|                             rhs : Value<_: Tosa_Tensor>) |
|     with op<tosa.bitwise_xor>(lhs, rhs); |
| } |
| |
| // Ternary ops. |
| // stablehlo.select becomes tosa.select; all three operands are forwarded |
| // in their original order. |
| Pattern { |
|   replace op<stablehlo.select>(pred : Value<_: Tosa_Tensor>, |
|                                onTrue : Value<_: Tosa_Tensor>, |
|                                onFalse : Value<_: Tosa_Tensor>) |
|     with op<tosa.select>(pred, onTrue, onFalse); |
| } |