blob: ad3e375c326fee8adb6a2f3623b481027c64c927 [file] [log] [blame]
//===- MatchInterfaces.h - Transform Dialect Interfaces ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
#define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
#include <optional>
#include <type_traits>
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace transform {
class MatchOpInterface;
namespace detail {
/// Calls `OpTy::matchOperation` with an "absent payload" first argument.
///
/// The type of the first parameter of `OpTy::matchOperation` is inspected at
/// compile time: if it is exactly `Operation *`, a null pointer is passed;
/// for any other type `std::nullopt` is passed instead. `results` and `state`
/// are forwarded unchanged, and the callee's result is returned as is.
template <typename OpTy>
DiagnosedSilenceableFailure matchOptionalOperation(OpTy op,
                                                   TransformResults &results,
                                                   TransformState &state) {
  // Type of the first parameter of the (single, non-overloaded) member.
  using FirstArgTy = typename llvm::function_traits<
      decltype(&OpTy::matchOperation)>::template arg_t<0>;
  constexpr bool takesRawPointer = std::is_same_v<FirstArgTy, Operation *>;

  if constexpr (!takesRawPointer) {
    return op.matchOperation(std::nullopt, results, state);
  } else {
    return op.matchOperation(nullptr, results, state);
  }
}
} // namespace detail
/// Trait for matcher ops whose operand handle is associated with zero or one
/// payload operation.
///
/// The concrete op type `OpTy` must provide `getOperandHandle()` and a
/// `matchOperation` method callable with either an `Operation *` or a
/// `std::optional<Operation *>` first argument, followed by
/// `TransformResults &` and `TransformState &`. Both requirements are checked
/// at compile time when `verifyTrait` is instantiated.
template <typename OpTy>
class AtMostOneOpMatcherOpTrait
    : public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
  // Detection helpers: each alias is well-formed only if the corresponding
  // member call on `T` is well-formed.
  template <typename T>
  using has_get_operand_handle =
      decltype(std::declval<T &>().getOperandHandle());

  template <typename T>
  using has_match_operation_ptr = decltype(std::declval<T &>().matchOperation(
      std::declval<Operation *>(), std::declval<TransformResults &>(),
      std::declval<TransformState &>()));

  template <typename T>
  using has_match_operation_optional =
      decltype(std::declval<T &>().matchOperation(
          std::declval<std::optional<Operation *>>(),
          std::declval<TransformResults &>(),
          std::declval<TransformState &>()));

public:
  /// Verifies the trait on `op`: statically checks the shape of `OpTy`,
  /// asserts that `op` implements `MatchOpInterface`, and emits an error if
  /// the type of the operand handle does not implement
  /// `TransformHandleTypeInterface`.
  static LogicalResult verifyTrait(Operation *op) {
    constexpr bool hasHandleAccessor =
        llvm::is_detected<has_get_operand_handle, OpTy>::value;
    static_assert(hasHandleAccessor,
                  "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects "
                  "operation type to have the getOperandHandle() method");

    constexpr bool hasPointerMatcher =
        llvm::is_detected<has_match_operation_ptr, OpTy>::value;
    constexpr bool hasOptionalMatcher =
        llvm::is_detected<has_match_operation_optional, OpTy>::value;
    static_assert(
        hasPointerMatcher || hasOptionalMatcher,
        "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected operation "
        "type to have either the matchOperation(Operation *, TransformResults "
        "&, TransformState &) or the matchOperation(std::optional<Operation*>, "
        "TransformResults &, TransformState &) method");

    // Interface registration happens at runtime, so this can only be checked
    // dynamically.
    assert(
        isa<MatchOpInterface>(op) &&
        "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
        "operations with MatchOpInterface");

    Value handle = cast<OpTy>(op).getOperandHandle();
    if (isa<TransformHandleTypeInterface>(handle.getType()))
      return success();

    return op->emitError() << "AtMostOneOpMatcherOpTrait/"
                              "SingleOpMatchOpTrait requires the op handle "
                              "to be of TransformHandleTypeInterface";
  }

  /// Applies the matcher. Produces a definite failure if the operand handle
  /// is associated with more than one payload op. If it is associated with
  /// none, `matchOperation` is called with a null/empty first argument;
  /// otherwise it is called with the only payload op.
  DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
                                    TransformResults &results,
                                    TransformState &state) {
    auto *self = this->getOperation();
    OpTy concreteOp = cast<OpTy>(self);
    Value handle = concreteOp.getOperandHandle();
    auto payloadOps = state.getPayloadOps(handle);

    if (!llvm::hasNItemsOrLess(payloadOps, 1)) {
      return emitDefiniteFailure(self->getLoc())
             << "AtMostOneOpMatcherOpTrait requires the operand handle to "
                "point to at most one payload op";
    }

    if (payloadOps.empty())
      return detail::matchOptionalOperation(concreteOp, results, state);

    return concreteOp.matchOperation(*payloadOps.begin(), results, state);
  }

  /// Populates `effects` using the `onlyReadsHandle` helper on the operands,
  /// `producesHandle` on the results, and `onlyReadsPayload`.
  void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
    auto *self = this->getOperation();
    onlyReadsHandle(self->getOperands(), effects);
    producesHandle(self->getResults(), effects);
    onlyReadsPayload(effects);
  }
};
/// Trait for matcher ops whose operand handle must be associated with exactly
/// one payload operation. Everything except `apply` is inherited from
/// `AtMostOneOpMatcherOpTrait`.
template <typename OpTy>
class SingleOpMatcherOpTrait : public AtMostOneOpMatcherOpTrait<OpTy> {
public:
  /// Produces a definite failure unless the operand handle is associated with
  /// exactly one payload op; otherwise defers to the base trait's `apply`.
  DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
                                    TransformResults &results,
                                    TransformState &state) {
    auto *self = this->getOperation();
    Value handle = cast<OpTy>(self).getOperandHandle();
    auto payloadOps = state.getPayloadOps(handle);

    if (llvm::hasSingleElement(payloadOps)) {
      // Exactly one payload op: the base implementation performs the match.
      return AtMostOneOpMatcherOpTrait<OpTy>::apply(rewriter, results, state);
    }

    return emitDefiniteFailure(self->getLoc())
           << "SingleOpMatchOpTrait requires the operand handle to point to "
              "a single payload op";
  }
};
/// Trait for matcher ops whose operand is a value handle that must be
/// associated with exactly one payload value. The concrete op type `OpTy` is
/// expected to provide `getOperandHandle()` and
/// `matchValue(payloadValue, results, state)`.
template <typename OpTy>
class SingleValueMatcherOpTrait
    : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
public:
  /// Verifies the trait on `op`: asserts that `op` implements
  /// `MatchOpInterface` and emits an error if the type of the operand handle
  /// does not implement `TransformValueHandleTypeInterface`.
  static LogicalResult verifyTrait(Operation *op) {
    // Interface registration happens at runtime, so this can only be checked
    // dynamically.
    assert(isa<MatchOpInterface>(op) &&
           "SingleValueMatchOpTrait is only available on operations with "
           "MatchOpInterface");

    Value handle = cast<OpTy>(op).getOperandHandle();
    if (isa<TransformValueHandleTypeInterface>(handle.getType()))
      return success();

    return op->emitError() << "SingleValueMatchOpTrait requires an operand "
                              "of TransformValueHandleTypeInterface";
  }

  /// Produces a definite failure unless the operand handle is associated with
  /// exactly one payload value; otherwise forwards that value to
  /// `OpTy::matchValue` and returns its result.
  DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
                                    TransformResults &results,
                                    TransformState &state) {
    auto *self = this->getOperation();
    OpTy concreteOp = cast<OpTy>(self);
    Value handle = concreteOp.getOperandHandle();
    auto payloadValues = state.getPayloadValues(handle);

    if (llvm::hasSingleElement(payloadValues))
      return concreteOp.matchValue(*payloadValues.begin(), results, state);

    return emitDefiniteFailure(self->getLoc())
           << "SingleValueMatchOpTrait requires the value handle to point "
              "to a single payload value";
  }

  /// Populates `effects` using the `onlyReadsHandle` helper on the operands,
  /// `producesHandle` on the results, and `onlyReadsPayload`.
  void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
    auto *self = this->getOperation();
    onlyReadsHandle(self->getOperands(), effects);
    producesHandle(self->getResults(), effects);
    onlyReadsPayload(effects);
  }
};
//===----------------------------------------------------------------------===//
// Printing/parsing for positional specification matchers
//===----------------------------------------------------------------------===//
/// Parses a positional index specification for transform match operations.
/// The following forms are accepted:
///
/// - `all`: sets `isAll` and returns;
/// - comma-separated-integer-list: populates `rawDimList` with the values;
/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
/// with the values and sets `isInverted`.
///
/// `rawDimList`, `isInverted` and `isAll` are output parameters taken by
/// reference; the outcome of parsing is reported through the returned
/// `ParseResult`. Only the declaration is visible in this header.
ParseResult parseTransformMatchDims(OpAsmParser &parser,
DenseI64ArrayAttr &rawDimList,
UnitAttr &isInverted, UnitAttr &isAll);
/// Prints a positional index specification for transform match operations.
/// NOTE(review): presumably the printing counterpart of
/// `parseTransformMatchDims`, emitting one of the forms that function
/// accepts based on `rawDimList`, `isInverted` and `isAll` -- confirm against
/// the out-of-line definition.
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
UnitAttr isAll);
//===----------------------------------------------------------------------===//
// Utilities for positional specification matchers
//===----------------------------------------------------------------------===//
/// Checks if the positional specification defined is valid and reports errors
/// otherwise.
/// NOTE(review): `raw`, `inverted` and `all` presumably carry the raw
/// position list and the `except` / `all` flags of the specification, and
/// errors are presumably reported on `op` -- confirm against the out-of-line
/// definition.
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
bool inverted, bool all);
/// Populates `result` with the positional identifiers relative to `maxNumber`.
///
/// - If `isAll` is set, the result will contain all numbers from `0` to
///   `maxNumber - 1` inclusive regardless of `rawList`.
/// - Otherwise, negative values from `rawList` are interpreted as counting
///   backwards from `maxNumber`, i.e., `-1` is interpreted as
///   `maxNumber - 1`, while positive numbers remain as is.
/// - If `isInverted` is set, populates `result` with those values from the
///   `0` to `maxNumber - 1` inclusive range that don't appear in `rawList`.
///
/// If `rawList` contains values that are greater than or equal to `maxNumber`
/// or less than `-maxNumber`, produces a silenceable error at the given
/// location. `maxNumber` must be positive. If `rawList` contains duplicate
/// numbers or numbers that become duplicate after negative value remapping,
/// emits a silenceable error.
DiagnosedSilenceableFailure
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
ArrayRef<int64_t> rawList, int64_t maxNumber,
SmallVectorImpl<int64_t> &result);
} // namespace transform
} // namespace mlir
#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h.inc"
#endif // MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H