blob: c9b9bb0e32f1039571292a39df80b3c19a2eee56 [file] [log] [blame]
//===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines convenience types for working with standard operations
// in the MLIR operation set.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
#include "mlir/Analysis/CallInterfaces.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc"
namespace mlir {
class AffineMap;
class Builder;
class FuncOp;
class OpBuilder;
/// Dialect object for the standard ("std") operations declared in this header.
class StandardOpsDialect : public Dialect {
public:
  /// Constructs the dialect within `context`. Defined out of line.
  ///
  /// Marked `explicit` so that an `MLIRContext *` can never be implicitly
  /// converted into a dialect instance.
  explicit StandardOpsDialect(MLIRContext *context);

  /// Returns the namespace prefix of this dialect's operations ("std").
  static StringRef getDialectNamespace() { return "std"; }

  /// Materialize a single constant operation from a given attribute value with
  /// the desired resultant type. Overrides the `Dialect` hook; defined out of
  /// line.
  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
                                 Location loc) override;
};
/// The predicate indicates the type of the comparison to perform:
/// (un)orderedness, (in)equality and less/greater than (or equal to) as
/// well as predicates that are always true or false.
enum class CmpFPredicate {
  // Sentinel marking the first usable predicate; aliases AlwaysFalse.
  FirstValidValue = 0,

  // Constant-false comparison.
  AlwaysFalse = FirstValidValue,

  // Ordered comparisons: equal, greater, greater-or-equal, less,
  // less-or-equal, not-equal.
  OEQ = 1,
  OGT = 2,
  OGE = 3,
  OLT = 4,
  OLE = 5,
  ONE = 6,

  // True when both operands are ordered.
  ORD = 7,

  // Unordered comparisons: equal, greater, greater-or-equal, less,
  // less-or-equal, not-equal.
  UEQ = 8,
  UGT = 9,
  UGE = 10,
  ULT = 11,
  ULE = 12,
  UNE = 13,

  // True when either operand is unordered.
  UNO = 14,

  // Constant-true comparison.
  AlwaysTrue = 15,

  // Count of predicates above (not itself a valid predicate).
  NumPredicates = 16
};
#define GET_OP_CLASSES
#include "mlir/Dialect/StandardOps/IR/Ops.h.inc"
/// This is a refinement of the "constant" op for the case where it is
/// returning a float value of FloatType.
///
/// %1 = "std.constant"(){value: 42.0} : bf16
///
class ConstantFloatOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// LLVM-style RTTI hook (used by isa/cast/dyn_cast); defined out of line.
  static bool classof(Operation *op);

  /// Creates a constant op whose result has float type `type` and whose
  /// payload is `value`. Defined out of line.
  static void build(Builder *builder, OperationState &result,
                    const APFloat &value, FloatType type);

  /// Returns the payload stored in the "value" attribute as an APFloat.
  APFloat getValue() {
    auto valueAttr = getAttrOfType<FloatAttr>("value");
    return valueAttr.getValue();
  }
};
/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of IntegerType.
///
/// %1 = "std.constant"(){value: 42} : i32
///
class ConstantIntOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// LLVM-style RTTI hook (used by isa/cast/dyn_cast); defined out of line.
  static bool classof(Operation *op);

  /// Creates an integer constant of bit width `width` holding `value`.
  static void build(Builder *builder, OperationState &result, int64_t value,
                    unsigned width);

  /// Creates a constant holding `value` with result type `type`; `type` must
  /// be an integer type.
  static void build(Builder *builder, OperationState &result, int64_t value,
                    Type type);

  /// Returns the payload stored in the "value" attribute as an int64_t.
  int64_t getValue() {
    auto valueAttr = getAttrOfType<IntegerAttr>("value");
    return valueAttr.getInt();
  }
};
/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value of Index type.
///
/// %1 = "std.constant"(){value: 99} : () -> index
///
class ConstantIndexOp : public ConstantOp {
public:
  using ConstantOp::ConstantOp;

  /// LLVM-style RTTI hook (used by isa/cast/dyn_cast); defined out of line.
  static bool classof(Operation *op);

  /// Creates a constant of `index` type holding `value`.
  static void build(Builder *builder, OperationState &result, int64_t value);

  /// Returns the payload stored in the "value" attribute as an int64_t.
  int64_t getValue() {
    auto valueAttr = getAttrOfType<IntegerAttr>("value");
    return valueAttr.getInt();
  }
};
// DmaStartOp starts a non-blocking DMA operation that transfers data from a
// source memref to a destination memref. The source and destination memrefs
// need not be of the same dimensionality, but need to have the same elemental
// type. The operands include the source and destination memrefs, each followed
// by its indices, the size of the data transfer in terms of the number of
// elements (of the
// elemental type of the memref), a tag memref with its indices, and optionally
// at the end, a stride and a number_of_elements_per_stride arguments. The tag
// location is used by a DmaWaitOp to check for completion. The indices of the
// source memref, destination memref, and the tag memref have the same
// restrictions as any load/store. The optional stride arguments should be of
// 'index' type, and specify a stride for the slower memory space (memory space
// with a lower memory space id), transferring chunks of
// number_of_elements_per_stride every stride until %num_elements are
// transferred. Either both or no stride arguments should be specified.
//
// For example, a DmaStartOp operation that transfers 256 elements of a memref
// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
// 1 at indices [%k, %l], would be specified as follows:
//
// %num_elements = constant 256 : index
// %idx = constant 0 : index
// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 2>
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
//   memref<40 x 128 x f32, (d0, d1) -> (d0, d1), 0>,
//   memref<2 x 1024 x f32, (d0, d1) -> (d0, d1), 1>,
//   memref<1 x i32, (d0) -> (d0), 2>
//
// If %stride and %num_elt_per_stride are specified, the DMA is expected to
// transfer %num_elt_per_stride elements every %stride elements apart from
// memory space 0 until %num_elements are transferred.
//
// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
// %num_elt_per_stride :
//
// TODO(mlir-team): add additional operands to allow source and destination
// striding, and multiple stride levels.
// TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
class DmaStartOp
    : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
  using Op::Op;

  /// Builds a DMA start operation. `stride` and `elementsPerStride` are
  /// optional; either both or neither should be provided.
  static void build(Builder *builder, OperationState &result, Value srcMemRef,
                    ValueRange srcIndices, Value destMemRef,
                    ValueRange destIndices, Value numElements, Value tagMemRef,
                    ValueRange tagIndices, Value stride = nullptr,
                    Value elementsPerStride = nullptr);

  // Operand layout decoded by the accessors below (ranks are the ranks of the
  // corresponding memref types; ranges are half-open):
  //   [0]                          source memref
  //   [1, 1 + srcRank)             source indices
  //   [dst]                        destination memref, dst = 1 + srcRank
  //   [dst + 1, dst + 1 + dstRank) destination indices
  //   [num]                        number of elements, num = dst + 1 + dstRank
  //   [tag]                        tag memref, tag = num + 1
  //   [tag + 1, tag + 1 + tagRank) tag indices
  //   optionally followed by: stride, number of elements per stride

  /// Returns the source memref value of this DMA operation.
  Value getSrcMemRef() { return getOperand(0); }

  /// Returns the rank (number of indices) of the source memref type.
  unsigned getSrcMemRefRank() {
    return getSrcMemRef().getType().cast<MemRefType>().getRank();
  }

  /// Returns the source memref indices for this DMA operation.
  operand_range getSrcIndices() {
    return {getOperation()->operand_begin() + 1,
            getOperation()->operand_begin() + 1 + getSrcMemRefRank()};
  }

  /// Returns the destination memref value of this DMA operation.
  Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }

  /// Returns the rank (number of indices) of the destination memref type.
  unsigned getDstMemRefRank() {
    return getDstMemRef().getType().cast<MemRefType>().getRank();
  }

  /// Returns the memory space of the source memref type.
  unsigned getSrcMemorySpace() {
    return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace();
  }

  /// Returns the memory space of the destination memref type.
  unsigned getDstMemorySpace() {
    return getDstMemRef().getType().cast<MemRefType>().getMemorySpace();
  }

  /// Returns the destination memref indices for this DMA operation.
  operand_range getDstIndices() {
    auto first =
        getOperation()->operand_begin() + getDstMemRefOperandIndex() + 1;
    return {first, first + getDstMemRefRank()};
  }

  /// Returns the number of elements being transferred by this DMA operation.
  Value getNumElements() { return getOperand(getNumElementsOperandIndex()); }

  /// Returns the tag memref value of this DMA operation.
  Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }

  /// Returns the rank (number of indices) of the tag memref type.
  unsigned getTagMemRefRank() {
    return getTagMemRef().getType().cast<MemRefType>().getRank();
  }

  /// Returns the tag memref indices for this DMA operation.
  operand_range getTagIndices() {
    auto first =
        getOperation()->operand_begin() + getTagMemRefOperandIndex() + 1;
    return {first, first + getTagMemRefRank()};
  }

  /// Returns true if the destination is in a faster memory space than the
  /// source, i.e. this is a DMA from a slower memory space to a faster one.
  /// Assumes that a lower memory space id denotes a slower memory space.
  bool isDestMemorySpaceFaster() {
    return getSrcMemorySpace() < getDstMemorySpace();
  }

  /// Returns true if the source is in a faster memory space than the
  /// destination, i.e. this is a DMA from a faster memory space to a slower
  /// one. Assumes that a lower memory space id denotes a slower memory space.
  bool isSrcMemorySpaceFaster() {
    return getDstMemorySpace() < getSrcMemorySpace();
  }

  /// Returns the operand position of whichever of the source or destination
  /// memref is in the faster memory space (0 for the source, otherwise the
  /// destination memref's position). Asserts if both are in the same memory
  /// space.
  unsigned getFasterMemPos() {
    assert((isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()) &&
           "DMA source and destination are in the same memory space");
    return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex();
  }

  static StringRef getOperationName() { return "std.dma_start"; }
  static ParseResult parse(OpAsmParser &parser, OperationState &result);
  void print(OpAsmPrinter &p);
  LogicalResult verify();
  LogicalResult fold(ArrayRef<Attribute> cstOperands,
                     SmallVectorImpl<OpFoldResult> &results);

  /// Returns true if the optional stride operands are present, detected by
  /// the op having more operands than the non-strided layout accounts for.
  bool isStrided() {
    return getNumOperands() !=
           getTagMemRefOperandIndex() + 1 + getTagMemRefRank();
  }

  /// Returns the stride operand (second to last), or a null Value if this DMA
  /// is not strided.
  Value getStride() {
    if (!isStrided())
      return nullptr;
    return getOperand(getNumOperands() - 2);
  }

  /// Returns the number-of-elements-per-stride operand (last), or a null Value
  /// if this DMA is not strided.
  Value getNumElementsPerStride() {
    if (!isStrided())
      return nullptr;
    return getOperand(getNumOperands() - 1);
  }

private:
  /// Operand position of the destination memref: it follows the source memref
  /// and its indices.
  unsigned getDstMemRefOperandIndex() { return 1 + getSrcMemRefRank(); }

  /// Operand position of the element count: it follows the destination memref
  /// and its indices.
  unsigned getNumElementsOperandIndex() {
    return getDstMemRefOperandIndex() + 1 + getDstMemRefRank();
  }

  /// Operand position of the tag memref: it immediately follows the element
  /// count.
  unsigned getTagMemRefOperandIndex() {
    return getNumElementsOperandIndex() + 1;
  }
};
// DmaWaitOp blocks until the completion of a DMA operation associated with the
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
// with the same restrictions as any load/store index. %num_elements is the
// number of elements associated with the DMA operation. For example:
//
// dma_start %src[%i], %dst[%k], %num_elements, %tag[%index] :
//   memref<2048 x f32, (d0) -> (d0), 0>,
//   memref<256 x f32, (d0) -> (d0), 1>,
//   memref<1 x i32, (d0) -> (d0), 2>
// ...
// ...
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
//
class DmaWaitOp
    : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
  using Op::Op;

  static StringRef getOperationName() { return "std.dma_wait"; }

  /// Builds a DMA wait on `tagMemRef[tagIndices]` for `numElements` elements.
  static void build(Builder *builder, OperationState &result, Value tagMemRef,
                    ValueRange tagIndices, Value numElements);
  static ParseResult parse(OpAsmParser &parser, OperationState &result);
  void print(OpAsmPrinter &p);
  LogicalResult fold(ArrayRef<Attribute> cstOperands,
                     SmallVectorImpl<OpFoldResult> &results);

  // Operand layout: [0] tag memref, [1, 1 + tagRank) tag indices,
  // [1 + tagRank] number of elements.

  /// Returns the tag memref associated with the DMA operation being waited on.
  Value getTagMemRef() { return getOperand(0); }

  /// Returns the rank (number of indices) of the tag memref type.
  unsigned getTagMemRefRank() {
    auto tagType = getTagMemRef().getType().cast<MemRefType>();
    return tagType.getRank();
  }

  /// Returns the tag memref indices of this wait.
  operand_range getTagIndices() {
    auto first = getOperation()->operand_begin() + 1;
    return {first, first + getTagMemRefRank()};
  }

  /// Returns the number of elements transferred by the associated DMA.
  Value getNumElements() {
    unsigned position = 1 + getTagMemRefRank();
    return getOperand(position);
  }
};
/// Prints the operands in [begin, end) as a dimension and symbol list, where
/// `numDims` is the number of dimension operands. Defined out of line.
void printDimAndSymbolList(Operation::operand_iterator begin,
                           Operation::operand_iterator end,
                           unsigned numDims, OpAsmPrinter &p);

/// Parses a dimension and symbol list, appending the parsed operands to
/// `operands`; `numDims` is an output parameter (presumably receiving the
/// dimension count). Returns a failing ParseResult (true when converted to
/// bool) if parsing failed.
ParseResult parseDimAndSymbolList(OpAsmParser &parser,
                                  SmallVectorImpl<Value> &operands,
                                  unsigned &numDims);

/// Streams a textual form of a subview `Range` to `os`. Defined out of line.
raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
} // end namespace mlir
#endif // MLIR_DIALECT_STANDARDOPS_IR_OPS_H