blob: 359d7b2279639acb4c5e74930b72aebf24bdcb06 [file] [log] [blame] [edit]
//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the Arith dialect to the EmitC
// dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Conversion Patterns
//===----------------------------------------------------------------------===//
namespace {
/// Lowers `arith.constant` to `emitc.constant`.
///
/// The constant's value attribute is forwarded unchanged; only the result type
/// goes through the type converter.
class ArithConstantOpConversionPattern
    : public OpConversionPattern<arith::ConstantOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp arithConst,
                  arith::ConstantOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type convertedType =
        this->getTypeConverter()->convertType(arithConst.getType());
    if (convertedType) {
      rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, convertedType,
                                                     adaptor.getValue());
      return success();
    }
    return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
  }
};
/// Return the counterpart of \p ty with the requested signedness.
///
/// - Integer types keep their bitwidth. Signless integers count as "not
///   unsigned", so they are only rewritten (to unsigned) when \p needsUnsigned
///   is set.
/// - Pointer-wide types: `size_t` is the only one treated as unsigned. Any
///   other pointer-wide type becomes `size_t` when \p needsUnsigned is set.
///   `size_t` becomes `ptrdiff_t` when a signed type is requested.
/// - Types that already have the requested signedness, and all non-integral
///   types, are returned unchanged.
Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
  if (isa<IntegerType>(ty)) {
    if (ty.isUnsignedInteger() == needsUnsigned)
      return ty;
    IntegerType::SignednessSemantics wanted =
        needsUnsigned ? IntegerType::SignednessSemantics::Unsigned
                      : IntegerType::SignednessSemantics::Signed;
    return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
                            wanted);
  }

  if (!emitc::isPointerWideType(ty))
    return ty;

  bool isUnsignedPointerWide = isa<emitc::SizeTType>(ty);
  if (isUnsignedPointerWide == needsUnsigned)
    return ty;
  if (needsUnsigned)
    return emitc::SizeTType::get(ty.getContext());
  return emitc::PtrDiffTType::get(ty.getContext());
}
/// Cast \p val to type \p ty.
///
/// The cast is built with `createOrFold`, so it is only left in the IR when it
/// cannot be folded. The intent is that no cast is materialized when \p val
/// already has type \p ty. This relies on `emitc.cast` folding identity casts;
/// NOTE(review): confirm against the CastOp folder.
Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
  Location castLoc = val.getLoc();
  Value casted = rewriter.createOrFold<emitc::CastOp>(castLoc, ty, val);
  return casted;
}
/// Lowers scalar `arith.cmpf` to EmitC.
///
/// - `false`/`true` become a boolean constant.
/// - `ord`/`uno` become a pure NaN check.
/// - Every other predicate becomes a plain C comparison combined with a NaN
///   check. Ordered (O*) predicates AND it with "both operands are not NaN".
///   Unordered (U*) predicates OR it with "some operand is NaN".
class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();

    if (!isa<FloatType>(rhs.getType()))
      return rewriter.notifyMatchFailure(loc,
                                         "cmpf currently only supported on "
                                         "floats, not tensors/vectors thereof");

    arith::CmpFPredicate arithPred = op.getPredicate();

    // Predicates that need no C comparison operator at all.
    switch (arithPred) {
    case arith::CmpFPredicate::AlwaysFalse:
    case arith::CmpFPredicate::AlwaysTrue: {
      bool value = arithPred == arith::CmpFPredicate::AlwaysTrue;
      Value constant = rewriter.create<emitc::ConstantOp>(
          loc, rewriter.getI1Type(), rewriter.getBoolAttr(value));
      rewriter.replaceOp(op, constant);
      return success();
    }
    case arith::CmpFPredicate::ORD:
      // Ordered: none of the operands is NaN.
      rewriter.replaceOp(op, createCheckIsOrdered(rewriter, loc, lhs, rhs));
      return success();
    case arith::CmpFPredicate::UNO:
      // Unordered: at least one operand is NaN.
      rewriter.replaceOp(op, createCheckIsUnordered(rewriter, loc, lhs, rhs));
      return success();
    default:
      break;
    }

    // Compare the values naively...
    Value cmpResult = rewriter.create<emitc::CmpOp>(
        loc, op.getType(), toEmitCPred(arithPred), lhs, rhs);

    // ...then adjust the result for the predicate's NaN semantics.
    if (isUnorderedPredicate(arithPred)) {
      Value anyNaN = createCheckIsUnordered(rewriter, loc, lhs, rhs);
      rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(), anyNaN,
                                                      cmpResult);
    } else {
      Value noNaN = createCheckIsOrdered(rewriter, loc, lhs, rhs);
      rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(), noNaN,
                                                       cmpResult);
    }
    return success();
  }

private:
  /// Return true for the unordered (U*) comparison predicates, which must also
  /// hold when an operand is NaN.
  static bool isUnorderedPredicate(arith::CmpFPredicate pred) {
    switch (pred) {
    case arith::CmpFPredicate::UEQ:
    case arith::CmpFPredicate::UGT:
    case arith::CmpFPredicate::UGE:
    case arith::CmpFPredicate::ULT:
    case arith::CmpFPredicate::ULE:
    case arith::CmpFPredicate::UNE:
      return true;
    default:
      return false;
    }
  }

  /// Return the C comparison operator underlying an ordered or unordered
  /// comparison predicate.
  static emitc::CmpPredicate toEmitCPred(arith::CmpFPredicate pred) {
    switch (pred) {
    case arith::CmpFPredicate::OEQ:
    case arith::CmpFPredicate::UEQ:
      return emitc::CmpPredicate::eq;
    case arith::CmpFPredicate::OGT:
    case arith::CmpFPredicate::UGT:
      return emitc::CmpPredicate::gt;
    case arith::CmpFPredicate::OGE:
    case arith::CmpFPredicate::UGE:
      return emitc::CmpPredicate::ge;
    case arith::CmpFPredicate::OLT:
    case arith::CmpFPredicate::ULT:
      return emitc::CmpPredicate::lt;
    case arith::CmpFPredicate::OLE:
    case arith::CmpFPredicate::ULE:
      return emitc::CmpPredicate::le;
    case arith::CmpFPredicate::ONE:
    case arith::CmpFPredicate::UNE:
      return emitc::CmpPredicate::ne;
    default:
      break;
    }
    llvm_unreachable("cmpf predicate has no C comparison operator");
  }

  /// Return an i1 that is true iff \p operand is NaN, computed as
  /// `operand != operand`.
  Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
              Value operand) const {
    return rewriter.create<emitc::CmpOp>(loc, rewriter.getI1Type(),
                                         emitc::CmpPredicate::ne, operand,
                                         operand);
  }

  /// Return an i1 that is true iff \p operand is not NaN, computed as
  /// `operand == operand`.
  Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
                 Value operand) const {
    return rewriter.create<emitc::CmpOp>(loc, rewriter.getI1Type(),
                                         emitc::CmpPredicate::eq, operand,
                                         operand);
  }

  /// Return an i1 that is true iff \p first and \p second are unordered, i.e.
  /// at least one of them is NaN.
  Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
                               Location loc, Value first, Value second) const {
    Value firstIsNaN = isNaN(rewriter, loc, first);
    Value secondIsNaN = isNaN(rewriter, loc, second);
    return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
                                               firstIsNaN, secondIsNaN);
  }

  /// Return an i1 that is true iff \p first and \p second are ordered, i.e.
  /// neither of them is NaN.
  Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
                             Value first, Value second) const {
    Value firstIsNotNaN = isNotNaN(rewriter, loc, first);
    Value secondIsNotNaN = isNotNaN(rewriter, loc, second);
    return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
                                                firstIsNotNaN, secondIsNotNaN);
  }
};
/// Lowers `arith.cmpi` on integer and pointer-wide types to `emitc.cmp`.
///
/// The C comparison operator does not encode signedness. Both operands are
/// therefore first cast to the signed or unsigned counterpart of their type,
/// as the predicate requires.
class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Return true if \p pred compares its operands as unsigned values.
  bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
    switch (pred) {
    case arith::CmpIPredicate::ult:
    case arith::CmpIPredicate::ule:
    case arith::CmpIPredicate::ugt:
    case arith::CmpIPredicate::uge:
      return true;
    case arith::CmpIPredicate::eq:
    case arith::CmpIPredicate::ne:
    case arith::CmpIPredicate::slt:
    case arith::CmpIPredicate::sle:
    case arith::CmpIPredicate::sgt:
    case arith::CmpIPredicate::sge:
      return false;
    }
    llvm_unreachable("unknown cmpi predicate kind");
  }

  /// Map \p pred to its C comparison operator, ignoring signedness.
  emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
    switch (pred) {
    case arith::CmpIPredicate::sge:
    case arith::CmpIPredicate::uge:
      return emitc::CmpPredicate::ge;
    case arith::CmpIPredicate::sgt:
    case arith::CmpIPredicate::ugt:
      return emitc::CmpPredicate::gt;
    case arith::CmpIPredicate::sle:
    case arith::CmpIPredicate::ule:
      return emitc::CmpPredicate::le;
    case arith::CmpIPredicate::slt:
    case arith::CmpIPredicate::ult:
      return emitc::CmpPredicate::lt;
    case arith::CmpIPredicate::ne:
      return emitc::CmpPredicate::ne;
    case arith::CmpIPredicate::eq:
      return emitc::CmpPredicate::eq;
    }
    llvm_unreachable("unknown cmpi predicate kind");
  }

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = adaptor.getLhs().getType();
    bool isIntegral = operandType && (isa<IntegerType>(operandType) ||
                                      emitc::isPointerWideType(operandType));
    if (!isIntegral)
      return rewriter.notifyMatchFailure(
          op, "expected integer or size_t/ssize_t/ptrdiff_t type");

    arith::CmpIPredicate arithPred = op.getPredicate();

    // Perform the comparison in the signedness the predicate asks for.
    Type cmpType =
        adaptIntegralTypeSignedness(operandType, needsUnsignedCmp(arithPred));
    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, cmpType);
    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, cmpType);

    rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(),
                                              toEmitCPred(arithPred), lhs, rhs);
    return success();
  }
};
/// Lowers scalar `arith.negf` to `emitc.unary_minus`.
class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value operand = adaptor.getOperand();
    Type operandType = operand.getType();

    if (isa<TensorType, VectorType>(operandType))
      return rewriter.notifyMatchFailure(
          op.getLoc(),
          "negf currently only supports scalar types, not vectors or tensors");

    if (!emitc::isSupportedFloatType(operandType))
      return rewriter.notifyMatchFailure(
          op.getLoc(), "floating-point type is not supported by EmitC");

    rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, operandType, operand);
    return success();
  }
};
/// Lowers integer-to-integer casts (integer and pointer-wide types) to EmitC.
///
/// \tparam castToUnsigned if true, the cast is performed on unsigned values.
///         Otherwise it is performed on signed values, except for truncations
///         between integer types, which are always performed unsigned.
template <typename ArithOp, bool castToUnsigned>
class CastConversion : public OpConversionPattern<ArithOp> {
public:
  using OpConversionPattern<ArithOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type opReturnType = this->getTypeConverter()->convertType(op.getType());
    if (!isIntegralType(opReturnType))
      return rewriter.notifyMatchFailure(
          op, "expected integer or size_t/ssize_t/ptrdiff_t result type");

    if (adaptor.getOperands().size() != 1)
      return rewriter.notifyMatchFailure(
          op, "CastConversion only supports unary ops");

    Value input = adaptor.getIn();
    Type operandType = input.getType();
    if (!isIntegralType(operandType))
      return rewriter.notifyMatchFailure(
          op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");

    // Signed (sign-extending) casts from i1 are not supported.
    if (!castToUnsigned && operandType.isInteger(1))
      return rewriter.notifyMatchFailure(op,
                                         "operation not supported on i1 type");

    // Truncation to i1 must keep the lowest bit, but `(bool)v` means
    // `v != 0`. Emitting `(bool)(v & 1)` gives the truncating semantics.
    if (opReturnType.isInteger(1)) {
      // Constants of pointer-wide type carry an index attribute.
      Type oneAttrType = emitc::isPointerWideType(operandType)
                             ? rewriter.getIndexType()
                             : operandType;
      Value one = rewriter.create<emitc::ConstantOp>(
          loc, operandType, rewriter.getOneAttr(oneAttrType));
      Value lowBit =
          rewriter.create<emitc::BitwiseAndOp>(loc, operandType, input, one);
      rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType, lowBit);
      return success();
    }

    bool isTruncation = isa<IntegerType>(operandType) &&
                        isa<IntegerType>(opReturnType) &&
                        operandType.getIntOrFloatBitWidth() >
                            opReturnType.getIntOrFloatBitWidth();
    bool doUnsigned = castToUnsigned || isTruncation;

    // Bitwidth-preserving re-typing of source and destination to the
    // signedness the real cast is performed in. This is needed e.g. when the
    // types are signless.
    Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
    Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
    Value source = adaptValueType(input, rewriter, castSrcType);

    // The cast that may change the bitwidth, then a cast back to the expected
    // result type.
    Value resized =
        rewriter.template create<emitc::CastOp>(loc, castDestType, source);
    rewriter.replaceOp(op, adaptValueType(resized, rewriter, opReturnType));
    return success();
  }

private:
  /// Return true for non-null integer or size_t/ssize_t/ptrdiff_t types.
  static bool isIntegralType(Type type) {
    return type && (isa<IntegerType>(type) || emitc::isPointerWideType(type));
  }
};
/// CastConversion that performs its cast on unsigned values.
template <typename ArithOp>
class UnsignedCastConversion : public CastConversion<ArithOp, true> {
  using Base = CastConversion<ArithOp, true>;
  using Base::Base;
};

/// CastConversion that performs its cast on signed values (truncations
/// excepted).
template <typename ArithOp>
class SignedCastConversion : public CastConversion<ArithOp, false> {
  using Base = CastConversion<ArithOp, false>;
  using Base::Base;
};
/// Lowers an arith op one-to-one onto \p EmitCOp. The adaptor operands are
/// forwarded unchanged and only the result type is converted.
template <typename ArithOp, typename EmitCOp>
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
public:
  using OpConversionPattern<ArithOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = this->getTypeConverter()->convertType(arithOp.getType());
    if (resultType) {
      rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, resultType,
                                                    adaptor.getOperands());
      return success();
    }
    return rewriter.notifyMatchFailure(arithOp,
                                       "converting result type failed");
  }
};
/// Lowers unsigned integer binary ops (divui/remui) to \p EmitCOp.
///
/// Both operands are cast to the unsigned counterpart of the result type, and
/// the EmitC op is computed on those. The result is then cast back to the
/// converted result type.
template <class ArithOp, class EmitCOp>
class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
public:
  using OpConversionPattern<ArithOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
    if (!newRetTy)
      return rewriter.notifyMatchFailure(uiBinOp,
                                         "converting result type failed");
    if (!isa<IntegerType>(newRetTy)) {
      return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
    }
    Type unsignedType =
        adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
    if (!unsignedType)
      return rewriter.notifyMatchFailure(uiBinOp,
                                         "converting result type failed");
    // Use the adaptor's remapped operands. The original op's operands may
    // still refer to values that are being replaced during the conversion.
    Value lhsAdapted = adaptValueType(adaptor.getLhs(), rewriter, unsignedType);
    Value rhsAdapted = adaptValueType(adaptor.getRhs(), rewriter, unsignedType);
    auto newOp =
        rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
                                 ArrayRef<Value>{lhsAdapted, rhsAdapted});
    Value resultAdapted = adaptValueType(newOp, rewriter, newRetTy);
    rewriter.replaceOp(uiBinOp, resultAdapted);
    return success();
  }
};
/// Lowers wrapping integer arithmetic (addi/subi/muli) to \p EmitCOp.
///
/// arith integer ops wrap around, but signed overflow is undefined in C.
/// Unless the op carries the `nsw` flag, a signed C result type is therefore
/// computed in its unsigned counterpart and cast back.
template <typename ArithOp, typename EmitCOp>
class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
public:
  using OpConversionPattern<ArithOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Type type = this->getTypeConverter()->convertType(op.getType());
    if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
      return rewriter.notifyMatchFailure(
          op, "expected integer or size_t/ssize_t/ptrdiff_t type");
    }

    if (type.isInteger(1)) {
      // arith expects wrap-around arithmetic, which doesn't happen on `bool`.
      return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
    }

    // Signless integers are treated as signed C types. Among the pointer-wide
    // types only size_t is unsigned; ssize_t and ptrdiff_t are signed too.
    bool isSignedCType =
        type.isSignlessInteger() || type.isSignedInteger() ||
        (emitc::isPointerWideType(type) && !isa<emitc::SizeTType>(type));

    Type arithmeticType = type;
    if (isSignedCType &&
        !bitEnumContainsAll(op.getOverflowFlags(),
                            arith::IntegerOverflowFlags::nsw)) {
      // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
      // we compute in the unsigned counterpart (unsigned iN or size_t) to
      // avoid UB.
      arithmeticType =
          adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
    }

    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);

    Value arithmeticResult = rewriter.template create<EmitCOp>(
        op.getLoc(), arithmeticType, lhs, rhs);

    Value result = adaptValueType(arithmeticResult, rewriter, type);

    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Lowers scalar integer bitwise ops (andi/ori/xori) to \p EmitCOp.
///
/// i1 operands are combined directly. Wider integers are cast to their
/// unsigned counterpart, combined there, and the result is cast back.
template <typename ArithOp, typename EmitCOp>
class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
public:
  using OpConversionPattern<ArithOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = this->getTypeConverter()->convertType(op.getType());
    if (!isa_and_nonnull<IntegerType>(resultType))
      return rewriter.notifyMatchFailure(
          op,
          "expected integer type, vector/tensor support not yet implemented");

    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();

    // Bitwise ops can be performed directly on booleans.
    if (resultType.isInteger(1)) {
      rewriter.replaceOpWithNewOp<EmitCOp>(op, resultType, lhs, rhs);
      return success();
    }

    // Otherwise operate on unsigned operands, as the C standard defines the
    // bitwise operators on them.
    Type unsignedType =
        adaptIntegralTypeSignedness(resultType, /*needsUnsigned=*/true);
    Value unsignedLhs = adaptValueType(lhs, rewriter, unsignedType);
    Value unsignedRhs = adaptValueType(rhs, rewriter, unsignedType);
    Value unsignedResult = rewriter.template create<EmitCOp>(
        op.getLoc(), unsignedType, unsignedLhs, unsignedRhs);
    rewriter.replaceOp(op,
                       adaptValueType(unsignedResult, rewriter, resultType));
    return success();
  }
};
/// Lowers integer shifts to \p EmitCOp, guarded against oversized shifts.
///
/// \tparam isUnsignedOp whether the shifted value is treated as unsigned
///         (shli/shrui) or signed (shrsi).
///
/// The shift amount is always treated as unsigned. Shifting by at least the
/// bit width yields poison in arith but is undefined in C. The shift is
/// therefore wrapped in an `emitc.expression` computing
/// `amount < width ? lhs <op> amount : 0`.
template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
class ShiftOpConversion : public OpConversionPattern<ArithOp> {
public:
  using OpConversionPattern<ArithOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Type type = this->getTypeConverter()->convertType(op.getType());
    if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
      return rewriter.notifyMatchFailure(
          op, "expected integer or size_t/ssize_t/ptrdiff_t type");
    }

    if (type.isInteger(1)) {
      return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
    }

    Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);

    Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
    // Shift amount interpreted as unsigned per Arith dialect spec.
    Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
                                               /*needsUnsigned=*/true);
    Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);

    // Compute the bit width to check the shift amount against at runtime.
    Value width;
    if (emitc::isPointerWideType(type)) {
      // Pointer-wide widths are only known to the C compiler. The bit width is
      // therefore computed as `8 * sizeof(v)`, with `v` the constant 8 of type
      // rhsType.
      Value eight = rewriter.create<emitc::ConstantOp>(
          op.getLoc(), rhsType, rewriter.getIndexAttr(8));
      emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
          op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
      width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
                                            sizeOfCall.getResult(0));
    } else {
      width = rewriter.create<emitc::ConstantOp>(
          op.getLoc(), rhsType,
          rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
    }

    Value excessCheck = rewriter.create<emitc::CmpOp>(
        op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);

    // Any concrete value is a valid refinement of poison.
    Value poison = rewriter.create<emitc::ConstantOp>(
        op.getLoc(), arithmeticType,
        (isa<IntegerType>(arithmeticType)
             ? rewriter.getIntegerAttr(arithmeticType, 0)
             : rewriter.getIndexAttr(0)));

    emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
        op.getLoc(), arithmeticType, /*do_not_inline=*/false);
    {
      // Create the body block through the rewriter so the change is tracked.
      // The guard restores the insertion point on scope exit.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.createBlock(&ternary.getBodyRegion());
      Value arithmeticResult =
          rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
      Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
          op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
      rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
    }

    Value result = adaptValueType(ternary, rewriter, type);

    rewriter.replaceOp(op, result);
    return success();
  }
};
/// ShiftOpConversion that shifts the signed counterpart of the value type
/// (used for shrsi).
template <typename ArithOp, typename EmitCOp>
class SignedShiftOpConversion final
    : public ShiftOpConversion<ArithOp, EmitCOp, false> {
  using Base = ShiftOpConversion<ArithOp, EmitCOp, false>;
  using Base::Base;
};

/// ShiftOpConversion that shifts the unsigned counterpart of the value type
/// (used for shli and shrui).
template <typename ArithOp, typename EmitCOp>
class UnsignedShiftOpConversion final
    : public ShiftOpConversion<ArithOp, EmitCOp, true> {
  using Base = ShiftOpConversion<ArithOp, EmitCOp, true>;
  using Base::Base;
};
/// Lowers `arith.select` with a scalar i1 condition to `emitc.conditional`.
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
public:
  using OpConversionPattern<arith::SelectOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = getTypeConverter()->convertType(selectOp.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(selectOp, "type conversion failed");

    bool hasScalarI1Condition = adaptor.getCondition().getType().isInteger(1);
    if (!hasScalarI1Condition)
      return rewriter.notifyMatchFailure(
          selectOp,
          "can only be converted if condition is a scalar of type i1");

    // The adaptor operands are forwarded in their original order.
    rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, resultType,
                                                      adaptor.getOperands());
    return success();
  }
};
// Floating-point to integer conversions.
template <typename CastOp>
class FtoICastOpConversion : public OpConversionPattern<CastOp> {
public:
  FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern<CastOp>(typeConverter, context) {}

  /// Lowers fptosi/fptoui to `emitc.cast`.
  ///
  /// For fptoui the value is first cast to the unsigned integer of the
  /// destination width, then to the destination type.
  LogicalResult
  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value source = adaptor.getIn();
    if (!emitc::isSupportedFloatType(source.getType()))
      return rewriter.notifyMatchFailure(castOp,
                                         "unsupported cast source type");

    Type dstType = this->getTypeConverter()->convertType(castOp.getType());
    if (!dstType)
      return rewriter.notifyMatchFailure(castOp, "type conversion failed");

    // Float-to-i1 casts are rejected: any 0 < value < 1 must truncate to 0,
    // whereas a boolean conversion would yield true.
    bool isSupportedDst =
        emitc::isSupportedIntegerType(dstType) && !dstType.isInteger(1);
    if (!isSupportedDst)
      return rewriter.notifyMatchFailure(castOp,
                                         "unsupported cast destination type");

    // Signless is interpreted as signed, so the "si" variant is one cast.
    if (!isa<arith::FPToUIOp>(castOp)) {
      rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, source);
      return success();
    }

    // "ui" variant: go through the unsigned integer of the same width.
    Type unsignedType = rewriter.getIntegerType(
        dstType.getIntOrFloatBitWidth(), /*isSigned=*/false);
    Value asUnsigned =
        rewriter.create<emitc::CastOp>(castOp.getLoc(), unsignedType, source);
    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, asUnsigned);
    return success();
  }
};
// Integer to floating-point conversions.
template <typename CastOp>
class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
public:
  ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern<CastOp>(typeConverter, context) {}

  /// Lowers sitofp/uitofp to `emitc.cast`.
  ///
  /// For uitofp the integer is first re-typed as unsigned, because signless
  /// integers are interpreted as signed. sitofp from i1 is rejected.
  LogicalResult
  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Vectors in particular are not supported
    Type operandType = adaptor.getIn().getType();
    if (!emitc::isSupportedIntegerType(operandType))
      return rewriter.notifyMatchFailure(castOp,
                                         "unsupported cast source type");

    // sitofp interprets i1 as a signed value (true is -1), but converting a C
    // `bool` to floating point yields 1.0 for true. Like the sign-extending
    // integer casts, signed casts from i1 are therefore not supported.
    if (isa<arith::SIToFPOp>(castOp) && operandType.isInteger(1))
      return rewriter.notifyMatchFailure(castOp,
                                         "operation not supported on i1 type");

    Type dstType = this->getTypeConverter()->convertType(castOp.getType());
    if (!dstType)
      return rewriter.notifyMatchFailure(castOp, "type conversion failed");

    if (!emitc::isSupportedFloatType(dstType))
      return rewriter.notifyMatchFailure(castOp,
                                         "unsupported cast destination type");

    // Convert to unsigned if it's the "ui" variant
    // Signless is interpreted as signed, so no need to cast for "si"
    Type actualOperandType = operandType;
    if (isa<arith::UIToFPOp>(castOp)) {
      actualOperandType =
          rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
                                  /*isSigned=*/false);
    }
    Value fpCastOperand = adaptor.getIn();
    if (actualOperandType != operandType) {
      fpCastOperand = rewriter.template create<emitc::CastOp>(
          castOp.getLoc(), actualOperandType, fpCastOperand);
    }
    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);

    return success();
  }
};
// Floating-point to floating-point conversions.
template <typename CastOp>
class FpCastOpConversion : public OpConversionPattern<CastOp> {
public:
  FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern<CastOp>(typeConverter, context) {}

  /// Lowers extf/truncf to a plain `emitc.cast` between supported float types.
  /// Only the default rounding mode is supported, so ops carrying an explicit
  /// rounding-mode attribute are rejected.
  LogicalResult
  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value source = adaptor.getIn();
    // Vectors in particular are not supported.
    if (!emitc::isSupportedFloatType(source.getType()))
      return rewriter.notifyMatchFailure(castOp,
                                         "unsupported cast source type");

    auto roundingModeOp = dyn_cast<arith::ArithRoundingModeInterface>(*castOp);
    if (roundingModeOp && roundingModeOp.getRoundingModeAttr())
      return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");

    Type dstType = this->getTypeConverter()->convertType(castOp.getType());
    if (!dstType)
      return rewriter.notifyMatchFailure(castOp, "type conversion failed");
    if (!emitc::isSupportedFloatType(dstType))
      return rewriter.notifyMatchFailure(castOp,
                                         "unsupported cast destination type");

    rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, source);
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
                                        RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();

  // Register the EmitC size_t-related type conversions on the shared
  // converter that all patterns below are constructed with.
  mlir::populateEmitCSizeTTypeConversions(typeConverter);

  // clang-format off
  // Constants and ops that map one-to-one onto an EmitC op.
  patterns.add<
    ArithConstantOpConversionPattern,
    ArithOpConversion<arith::AddFOp, emitc::AddOp>,
    ArithOpConversion<arith::DivFOp, emitc::DivOp>,
    ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
    ArithOpConversion<arith::MulFOp, emitc::MulOp>,
    ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
    ArithOpConversion<arith::SubFOp, emitc::SubOp>
  >(typeConverter, context);

  // Integer arithmetic, bitwise, and shift ops with signedness handling.
  patterns.add<
    BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
    BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
    IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
    IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
    IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
    BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
    BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
    BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
    UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
    SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
    UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>
  >(typeConverter, context);

  // Comparisons, negation, and selection.
  patterns.add<
    CmpFOpConversion,
    CmpIOpConversion,
    NegFOpConversion,
    SelectOpConversion
  >(typeConverter, context);

  // Casts. Truncation is guaranteed for unsigned types.
  patterns.add<
    UnsignedCastConversion<arith::TruncIOp>,
    SignedCastConversion<arith::ExtSIOp>,
    UnsignedCastConversion<arith::ExtUIOp>,
    SignedCastConversion<arith::IndexCastOp>,
    UnsignedCastConversion<arith::IndexCastUIOp>,
    ItoFCastOpConversion<arith::SIToFPOp>,
    ItoFCastOpConversion<arith::UIToFPOp>,
    FtoICastOpConversion<arith::FPToSIOp>,
    FtoICastOpConversion<arith::FPToUIOp>,
    FpCastOpConversion<arith::ExtFOp>,
    FpCastOpConversion<arith::TruncFOp>
  >(typeConverter, context);
  // clang-format on
}