// Copyright 2022 OpenXLA Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mlir/Dialect/Tosa/IR/TosaOps.td"
#include "stablehlo/dialect/StablehloOps.td"
// Helper functions.
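// Builds a tosa.const holding a splat of ones with the given tensor type;
// used below to materialize the "+ 1" / "- 1" terms of decomposed ops.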
Rewrite onesLike(op: Op, type: Type) -> Op [{
  auto elementType = type.cast<mlir::TensorType>().getElementType();
  llvm::SmallVector<mlir::Attribute, 4> outputValue;

  if (elementType.isF16() || elementType.isF32() || elementType.isBF16()) {
    outputValue.push_back(rewriter.getFloatAttr(elementType, 1));
  } else {
    outputValue.push_back(rewriter.getIntegerAttr(elementType, 1));
  }

  return rewriter.create<mlir::tosa::ConstOp>(
      op->getLoc(), type,
      mlir::DenseElementsAttr::get(
          llvm::cast<mlir::ShapedType>(type), outputValue));
}];
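
// Builds a tosa.const holding a splat of +infinity with the given tensor
// type; only valid for floating-point element types.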
Rewrite positiveFloatInfinityLike(op: Op, type: Type) -> Op [{
  auto elementType = type.cast<mlir::TensorType>().getElementType();
  const llvm::fltSemantics& semantic =
      elementType.cast<mlir::FloatType>().getFloatSemantics();

  llvm::SmallVector<mlir::Attribute, 4> outputValue;
  outputValue.push_back(rewriter.getFloatAttr(
      elementType, llvm::APFloat::getInf(semantic, false)));

  return rewriter.create<mlir::tosa::ConstOp>(
      op->getLoc(), type,
      mlir::DenseElementsAttr::get(
          llvm::cast<mlir::ShapedType>(type), outputValue));
}];

// Nullary ops.
Pattern =>
  replace op<stablehlo.constant> {value = input: Attr<_: Tosa_Tensor>}
     with op<tosa.const> {value = input};

// Unary ops.
Pattern =>
  replace op<stablehlo.abs>(input : Value<_: Tosa_Tensor>)
     with op<tosa.abs>(input);

Pattern =>
  replace op<stablehlo.ceil>(input : Value<_: Tosa_Tensor>)
     with op<tosa.ceil>(input);

Pattern =>
  replace op<stablehlo.convert>(input : Value<_: Tosa_Tensor>)
     with op<tosa.cast>(input);

Pattern =>
  replace op<stablehlo.exponential>(input : Value<_: Tosa_Tensor>)
     with op<tosa.exp>(input);
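
// stablehlo.exponential_minus_one has no direct TOSA equivalent, so it is
// decomposed as exp(x) - 1 using a splat constant of ones.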
Pattern {
  let root = op<stablehlo.exponential_minus_one>
                (input : Value<inputType: Tosa_Tensor>);
  rewrite root with {
    let ones = onesLike(root, inputType);
    let expResult = op<tosa.exp>(input) -> (inputType);
    let expMinusOneResult = op<tosa.sub>(expResult, ones) -> (inputType);
    replace root with expMinusOneResult;
  };
}

Pattern =>
  replace op<stablehlo.floor>(input : Value<_: Tosa_Tensor>)
     with op<tosa.floor>(input);
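
// stablehlo.is_finite is decomposed as logical_not(equal(abs(x), +inf)).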
Pattern {
  let root = op<stablehlo.is_finite>(input : Value<inputType: Tosa_Tensor>);
  rewrite root with {
    let positiveInfinity = positiveFloatInfinityLike(root, inputType);
    let inputAbs = op<tosa.abs>(input) -> (inputType);
    let equalsResult = op<tosa.equal>(positiveInfinity, inputAbs);
    let notEqualsResult = op<tosa.logical_not>(equalsResult);
    replace root with notEqualsResult;
  };
}

Pattern =>
  replace op<stablehlo.log>(input : Value<_: Tosa_Tensor>)
     with op<tosa.log>(input);
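
// stablehlo.log_plus_one is decomposed as log(x + 1) using a splat constant
// of ones.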
Pattern {
  let root = op<stablehlo.log_plus_one>(input : Value<inputType: Tosa_Tensor>);
  rewrite root with {
    let ones = onesLike(root, inputType);
    let addResult = op<tosa.add>(input, ones) -> (inputType);
    let logPlusOneResult = op<tosa.log>(addResult) -> (inputType);
    replace root with logPlusOneResult;
  };
}

Pattern =>
  replace op<stablehlo.negate>(input : Value<_: Tosa_Tensor>)
     with op<tosa.negate>(input);

Pattern =>
  replace op<stablehlo.tanh>(input : Value<_: Tosa_Tensor>)
     with op<tosa.tanh>(input);

// Binary ops.
Pattern =>
  replace op<stablehlo.add>(input0 : Value<_: Tosa_Tensor>,
                            input1 : Value<_: Tosa_Tensor>)
     with op<tosa.add>(input0, input1);

Pattern =>
  replace op<stablehlo.and>(input0 : Value<_: Tosa_Tensor>,
                            input1 : Value<_: Tosa_Tensor>)
     with op<tosa.bitwise_and>(input0, input1);
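
// tosa.div is an integer divide, so only the i32 case is handled here.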
Pattern =>
  replace op<stablehlo.divide>(input0 : Value<_: Tosa_Int32Tensor>,
                               input1 : Value<_: Tosa_Int32Tensor>)
     with op<tosa.div>(input0, input1);

Pattern =>
  replace op<stablehlo.maximum>(input0 : Value<_: Tosa_Tensor>,
                                input1 : Value<_: Tosa_Tensor>)
     with op<tosa.maximum>(input0, input1);

Pattern =>
  replace op<stablehlo.minimum>(input0 : Value<_: Tosa_Tensor>,
                                input1 : Value<_: Tosa_Tensor>)
     with op<tosa.minimum>(input0, input1);
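
// tosa.mul carries a shift attribute used for fixed-point scaling of integer
// products; a shift of 0 leaves the result unscaled.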
Pattern =>
  replace op<stablehlo.multiply>(input0 : Value<_: Tosa_Tensor>,
                                 input1 : Value<_: Tosa_Tensor>)
     with op<tosa.mul>(input0, input1) {shift = attr<"0 : i32">};

Pattern =>
  replace op<stablehlo.or>(input0 : Value<_: Tosa_Tensor>,
                           input1 : Value<_: Tosa_Tensor>)
     with op<tosa.bitwise_or>(input0, input1);

Pattern =>
  replace op<stablehlo.power>(input0 : Value<_: Tosa_Tensor>,
                              input1 : Value<_: Tosa_Tensor>)
     with op<tosa.pow>(input0, input1);

Pattern =>
  replace op<stablehlo.shift_left>(input0 : Value<_: Tosa_Tensor>,
                                   input1 : Value<_: Tosa_Tensor>)
     with op<tosa.logical_left_shift>(input0, input1);

Pattern =>
  replace op<stablehlo.shift_right_logical>(input0 : Value<_: Tosa_Tensor>,
                                            input1 : Value<_: Tosa_Tensor>)
     with op<tosa.logical_right_shift>(input0, input1);

Pattern =>
  replace op<stablehlo.subtract>(input0 : Value<_: Tosa_Tensor>,
                                 input1 : Value<_: Tosa_Tensor>)
     with op<tosa.sub>(input0, input1);

Pattern =>
  replace op<stablehlo.xor>(input0 : Value<_: Tosa_Tensor>,
                            input1 : Value<_: Tosa_Tensor>)
     with op<tosa.bitwise_xor>(input0, input1);

// Ternary ops.
Pattern =>
  replace op<stablehlo.select>(input0 : Value<_: Tosa_Tensor>,
                               input1 : Value<_: Tosa_Tensor>,
                               input2 : Value<_: Tosa_Tensor>)
     with op<tosa.select>(input0, input1, input2);