// blob: 00cdb13feb29bb63004716ba5ddc61ee9cabd880 [file] [log] [blame]
//===- Arith.h - Arith dialect ------------------------------------*- C++-*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_ARITH_IR_ARITH_H_
#define MLIR_DIALECT_ARITH_IR_ARITH_H_
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/StringExtras.h"
//===----------------------------------------------------------------------===//
// ArithDialect
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/ArithOpsDialect.h.inc"
//===----------------------------------------------------------------------===//
// Arith Dialect Enum Attributes
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/ArithOpsEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.h.inc"
//===----------------------------------------------------------------------===//
// Arith Interfaces
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.h.inc"
//===----------------------------------------------------------------------===//
// Arith Dialect Operations
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Arith/IR/ArithOps.h.inc"
namespace mlir {
namespace arith {
/// View of an `arith.constant` op whose value attribute is read as an
/// `IntegerAttr`, i.e. a constant that yields an integer.
class ConstantIntOp : public arith::ConstantOp {
public:
  // Reuse every constructor of the wrapped constant op.
  using arith::ConstantOp::ConstantOp;

  /// Reports the TypeID of the underlying `ConstantOp` rather than a distinct
  /// TypeID for this wrapper class.
  static ::mlir::TypeID resolveTypeID() {
    return ::mlir::TypeID::get<arith::ConstantOp>();
  }

  /// Build a constant int op holding `value` as an integer that is `width`
  /// bits wide.
  static void build(OpBuilder &builder, OperationState &result, int64_t value,
                    unsigned width);

  /// Build a constant int op holding `value` with the given `type`; `type`
  /// must be an integer type.
  static void build(OpBuilder &builder, OperationState &result, int64_t value,
                    Type type);

  /// Casts the op's value attribute to `IntegerAttr` and returns its
  /// `int64_t` payload.
  int64_t value() {
    auto intAttr = cast<IntegerAttr>(arith::ConstantOp::getValue());
    return intAttr.getInt();
  }

  /// Predicate over a generic `Operation *`; defined out of line.
  /// NOTE(review): presumably the hook used by isa/dyn_cast to recognize
  /// integer-valued constants -- confirm against the definition.
  static bool classof(Operation *op);
};
/// View of an `arith.constant` op whose value attribute is read as a
/// `FloatAttr`, i.e. a constant that yields a floating point value.
class ConstantFloatOp : public arith::ConstantOp {
public:
  // Reuse every constructor of the wrapped constant op.
  using arith::ConstantOp::ConstantOp;

  /// Reports the TypeID of the underlying `ConstantOp` rather than a distinct
  /// TypeID for this wrapper class.
  static ::mlir::TypeID resolveTypeID() {
    return ::mlir::TypeID::get<arith::ConstantOp>();
  }

  /// Build a constant float op holding `value` with the given float `type`.
  static void build(OpBuilder &builder, OperationState &result,
                    const APFloat &value, FloatType type);

  /// Casts the op's value attribute to `FloatAttr` and returns its `APFloat`
  /// payload by value.
  APFloat value() {
    auto floatAttr = cast<FloatAttr>(arith::ConstantOp::getValue());
    return floatAttr.getValue();
  }

  /// Predicate over a generic `Operation *`; defined out of line.
  /// NOTE(review): presumably the hook used by isa/dyn_cast to recognize
  /// float-valued constants -- confirm against the definition.
  static bool classof(Operation *op);
};
/// View of an `arith.constant` op that yields an integer of index type; the
/// value attribute is read as an `IntegerAttr`.
class ConstantIndexOp : public arith::ConstantOp {
public:
  // Reuse every constructor of the wrapped constant op.
  using arith::ConstantOp::ConstantOp;

  /// Reports the TypeID of the underlying `ConstantOp` rather than a distinct
  /// TypeID for this wrapper class.
  static ::mlir::TypeID resolveTypeID() {
    return ::mlir::TypeID::get<arith::ConstantOp>();
  }

  /// Build a constant op holding `value` that produces an index.
  static void build(OpBuilder &builder, OperationState &result, int64_t value);

  /// Casts the op's value attribute to `IntegerAttr` and returns its
  /// `int64_t` payload.
  int64_t value() {
    auto indexAttr = cast<IntegerAttr>(arith::ConstantOp::getValue());
    return indexAttr.getInt();
  }

  /// Predicate over a generic `Operation *`; defined out of line.
  /// NOTE(review): presumably the hook used by isa/dyn_cast to recognize
  /// index-typed constants -- confirm against the definition.
  static bool classof(Operation *op);
};
} // namespace arith
} // namespace mlir
//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
namespace mlir {
namespace arith {
/// Compute `lhs` `pred` `rhs` on two `APInt` values, where `pred` (the
/// `predicate` argument) is one of the known integer comparison predicates
/// of `arith::CmpIPredicate`. Returns the boolean outcome of the comparison.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs,
const APInt &rhs);
/// Compute `lhs` `pred` `rhs` on two `APFloat` values, where `pred` (the
/// `predicate` argument) is one of the known floating point comparison
/// predicates of `arith::CmpFPredicate`. Returns the boolean outcome of the
/// comparison.
bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
const APFloat &rhs);
/// Returns the identity value attribute associated with an AtomicRMWKind op.
/// `useOnlyFiniteValue` defines whether the identity value should steer away
/// from infinity representations or anything that is not a proper finite
/// number; it defaults to `false`.
/// E.g., the identity value for maxf is in theory `-Inf`, but if we want to
/// stay in the finite range, it would be `BiggestRepresentableNegativeFloat`.
/// The purpose of this boolean is to offer constants that will play nicely
/// with fast-math related optimizations.
/// NOTE(review): `resultType` is presumably the type of the returned
/// attribute -- confirm against the definition.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc,
bool useOnlyFiniteValue = false);
/// Return the identity numeric value associated with the given op. Return
/// std::nullopt if there is no known neutral element.
/// If `op` has `FastMathFlags::ninf`, only finite values will be used
/// as the neutral element.
std::optional<TypedAttr> getNeutralElement(Operation *op);
/// Returns the identity value associated with an AtomicRMWKind op, as a
/// `Value` rather than an attribute.
/// \see getIdentityValueAttr for a description of what `useOnlyFiniteValue`
/// does; it defaults to `false` here as well.
/// NOTE(review): presumably `builder` and `loc` are used to materialize the
/// identity as an op producing the returned `Value` -- confirm against the
/// definition.
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
Location loc, bool useOnlyFiniteValue = false);
/// Returns the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
/// NOTE(review): presumably the reduction is emitted through `builder` at
/// `loc` and its result is what gets returned -- confirm against the
/// definition.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
Value lhs, Value rhs);
/// Maps an integer comparison predicate to another integer comparison
/// predicate.
/// NOTE(review): presumably returns the logical negation of `pred` (so that
/// `a <inverted> b` holds exactly when `a <pred> b` does not), as opposed to
/// the operand-swapped predicate -- confirm against the definition.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
} // namespace arith
} // namespace mlir
#endif // MLIR_DIALECT_ARITH_IR_ARITH_H_