/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Copyright 2022 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef STABLEHLO_DIALECT_BASE_H
#define STABLEHLO_DIALECT_BASE_H
#include <algorithm>
#include <optional>
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
// Include order matters
#include "stablehlo/dialect/BaseAttrInterfaces.h.inc"
namespace mlir {
namespace hlo {
// TODO(zhouxin): Pick a better name, as this is used for both sizes and
// bounds.
// Checks whether the given dimension size is dynamic.
inline static bool isDynamicDimSize(int64_t val) {
return ShapedType::isDynamic(val);
}
inline static bool isStaticDimSize(int64_t val) {
return !isDynamicDimSize(val);
}
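// For example (illustrative): given a value of type tensor<2x?xf32>,
// isDynamicDimSize returns false for dimension 0 and true for dimension 1:
//
//   auto type = cast<RankedTensorType>(value.getType());
//   bool d0 = isDynamicDimSize(type.getDimSize(0));  // false (size is 2)
//   bool d1 = isDynamicDimSize(type.getDimSize(1));  // true (kDynamic)
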
// Checks whether every position in the given array contains the given value.
bool isSplatArray(ArrayRef<int64_t> arr, int64_t val);
// Verifies that the two types have compatible shapes, including bounds, while
// allowing different element types.
LogicalResult verifyCompatibleShapeWithBounds(Type type1, Type type2);
// Returns true if the given element types are compatible for the purposes of
// HLO type inference, accounting for special properties of quantization and
// sparsity.
bool isCompatibleElementTypeForHloTypeInference(Type tp1, Type tp2);
// Returns true if the given types are compatible for the purposes of HLO type
// inference, accounting for special properties of dynamism, quantization and
// sparsity.
bool isCompatibleForHloTypeInference(Type tp1, Type tp2);
// Returns true if the given type ranges are compatible for the purposes of HLO
// type inference, accounting for special properties of dynamism, quantization
// and sparsity.
bool isCompatibleForHloTypeInference(TypeRange tp1, TypeRange tp2);
// Returns true if the given shape, expressed as a runtime value, is compatible
// with the given type for the purposes of HLO type inference.
// If we know that this runtime value is a constant, then we perform the check.
// If we don't, then we return true - because shape mismatches at runtime are
// undefined behavior.
bool isCompatibleForHloTypeInference(Value shape1, Type tp2);
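// As an illustration of the intended semantics (not an exhaustive contract;
// `builder` below is a hypothetical OpBuilder): tensor<2xf32> is compatible
// with tensor<?xf32> because the dynamic dimension can resolve to 2 at
// runtime, whereas tensor<2xf32> and tensor<3xf32> are incompatible.
//
//   auto lhs = RankedTensorType::get({2}, builder.getF32Type());
//   auto rhs =
//       RankedTensorType::get({ShapedType::kDynamic}, builder.getF32Type());
//   bool ok = isCompatibleForHloTypeInference(lhs, rhs);  // expected: true
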
// TODO(zhouxin) Move type inference related methods to TypeInference.cpp
std::pair<int64_t, int64_t> inferConcatenatedDimAndBound(int64_t leftSize,
int64_t rightSize,
int64_t leftBound,
int64_t rightBound);
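// A sketch of the expected behavior, assuming concatenation adds sizes and
// bounds, with a static size standing in as its own bound:
//
//   // Concatenating a static dimension of size 2 with a dynamic dimension
//   // bounded by 3 is expected to yield a dynamic dimension bounded by 5.
//   auto [size, bound] = inferConcatenatedDimAndBound(
//       /*leftSize=*/2, /*rightSize=*/ShapedType::kDynamic,
//       /*leftBound=*/ShapedType::kDynamic, /*rightBound=*/3);
//   // size == ShapedType::kDynamic, bound == 5 (expected)
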
FailureOr<std::pair<int64_t, int64_t>> inferMostSpecificDimAndBound(
std::optional<Location> location, int64_t dim, int64_t leftSize,
int64_t rightSize, int64_t leftBound, int64_t rightBound);
FailureOr<std::pair<int64_t, int64_t>> inferLeastSpecificDimAndBound(
std::optional<Location> location, int64_t dim, int64_t leftSize,
int64_t rightSize, int64_t leftBound, int64_t rightBound);
// Infers the single least specific return type from inputTypes, with support
// for bounds. The (size, bound) of each dimension of the return type is merged
// from the corresponding dimensions of every inputType by extracting the least
// specific one. Returns an unranked tensor if any input is unranked.
FailureOr<Type> inferLeastSpecificType(std::optional<Location> location,
TypeRange inputTypes);
// Infers the single most specific return type from inputTypes, with support
// for bounds. The (size, bound) of each dimension of the return type is merged
// from the corresponding dimensions of every inputType by extracting the most
// specific one. Returns an unranked tensor if all inputs are unranked.
FailureOr<Type> inferMostSpecificType(std::optional<Location> location,
TypeRange inputTypes);
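// For example (illustrative): merging tensor<2x?xf32> and tensor<?x3xf32>,
// inferMostSpecificType yields tensor<2x3xf32> (taking the static size in each
// dimension), while inferLeastSpecificType yields tensor<?x?xf32> (keeping a
// dimension static only when it is static in every input).
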
LogicalResult inferMostSpecificTypeComponents(
std::optional<Location> location, TypeRange inputTypes,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes);
// Matches a constant with an integer value, extracting it into int64_t.
LogicalResult matchInt(Value value, int64_t &result);
// Matches a constant tensor with integer values into a 1-dimensional vector.
// Doesn't preserve the bitness or the signedness of the underlying values,
// extracting them into int64_t.
LogicalResult matchInts(Value value, SmallVector<int64_t> &result);
// Matches a constant tensor with integer values into a 1-dimensional vector.
// Preserves the bitness and the signedness of the underlying values.
LogicalResult matchInts(Value value, SmallVector<APSInt> &result);
// Matches a constant tensor with integer values.
// Unlike the functions above, it doesn't return these values - it just checks
// that the given argument is indeed a constant tensor with integer values.
LogicalResult matchInts(Value value);
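// Typical usage (illustrative; `op`, `getOutputShape`, and the surrounding
// pattern's `rewriter` are hypothetical):
//
//   SmallVector<int64_t> shape;
//   if (failed(matchInts(op.getOutputShape(), shape)))
//     return rewriter.notifyMatchFailure(op, "expected constant output shape");
//   // `shape` now holds the constant's values, extracted as int64_t.
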
// Shape derivation function that computes the shape of the result based on an
// operand. For a 2-dimensional input tensor, this produces IR of the form
//
//   %0 = shape.shape_of %arg0 : tensor<?x?xf32> -> tensor<2xindex>
//
// and returns %0 as the shape value.
LogicalResult deriveShapeFromOperand(
OpBuilder *builder, Operation *op, Value operand,
SmallVectorImpl<Value> *reifiedReturnShapes);
// Type derivation function that returns a tensor type with the same shape but
// a new element type.
ShapedType getSameShapeTensorType(ShapedType shapedType, Type elementType);
// Takes a tensor type that may have complex elements and returns a type that
// maintains the shape, but with real numeric data types.
// Ex: tensor<4xcomplex<f32>> --> tensor<4xf32>
ShapedType createRealType(ShapedType type);
// Verify bounds expressed by HLO_BoundedAttrInterface against the provided
// type. See documentation for HLO_BoundedAttrInterface for the list of checks.
LogicalResult verifyBounds(ArrayRef<int64_t> bounds, RankedTensorType type,
function_ref<InFlightDiagnostic()> emitError);
// If an encoding attribute conforms to HLO_BoundedAttrInterface, return the
// bounds that it carries. Otherwise, return an empty ArrayRef.
ArrayRef<int64_t> encodingToBounds(Attribute encoding);
// Create an HLO_BoundedAttrInterface encoding attribute that carries the given
// bounds. Requires a prototype - an existing encoding attribute - to obtain
// the underlying dialect that knows how to create these attributes.
Attribute boundsToEncoding(Attribute prototype, ArrayRef<int64_t> bounds);
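// Illustrative round-trip between encodings and bounds (`encoding` is assumed
// to implement HLO_BoundedAttrInterface, e.g. a TypeExtensions attribute):
//
//   ArrayRef<int64_t> bounds = encodingToBounds(encoding);
//   if (!bounds.empty()) {
//     SmallVector<int64_t> newBounds(bounds.begin(), bounds.end());
//     // ... adjust newBounds ...
//     Attribute newEncoding = boundsToEncoding(encoding, newBounds);
//   }
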
// Gets refinements for return types from an indices_of_shape_operands
// attribute, with tuple types flattened (see `flattenTupleTypes` below).
// If the attribute doesn't exist, returns failure.
// If the attribute exists but is invalid with respect to the operation,
// reports an optional error and returns failure.
// If the attribute is valid but not all shape operands are constants,
// returns failure.
LogicalResult getShapeRefinements(
std::optional<Location> location, Operation *operation,
SmallVector<ShapedTypeComponents> &refinements);
// For each type in `types`, recursively flatten tuple types into `result`.
// Result is populated via in-order traversal of tuple types in `types`, i.e.:
// * Flattenings of individual types from `types` follow one another in the
// same order as `types`.
// * Same for flattenings of element types of tuple types.
void flattenTupleTypes(TypeRange types, SmallVector<Type> &result);
// Does the inverse of `flattenTupleTypes` - takes `types` and recursively
// unflattens it, creating tuple types as needed to exactly match the structure
// of `prototype`.
// Fails if the number of elements in the flattened `prototype` differs from
// the number of elements in `types`.
LogicalResult unflattenTupleTypes(TypeRange prototype, TypeRange types,
SmallVector<Type> &result);
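// For example (illustrative): flattening
//   [tuple<tensor<f32>, tuple<tensor<i32>, tensor<i32>>>, tensor<f64>]
// yields
//   [tensor<f32>, tensor<i32>, tensor<i32>, tensor<f64>],
// and unflattening that list against the original types as a prototype
// restores the nested tuple structure.
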
// Creates a tensor type from the given ShapedTypeComponents.
ShapedType createShapedType(ShapedTypeComponents components);
// This interface is implemented by both StableHLO and MHLO dialects
// and is used as the foundation for sharing verification, type inference and
// pretty-printing logic between them.
class HloDialectInterface : public DialectInterface::Base<HloDialectInterface> {
public:
HloDialectInterface(Dialect *dialect) : Base(dialect) {}
// Creates a TokenType type, specific to this dialect.
// See docs for the particular type in the corresponding dialect.
virtual Type createTokenType() const = 0;
// Checks whether the given type is the TokenType of the corresponding dialect.
virtual bool isTokenType(Type type) const = 0;
// Creates a TypeExtensions attribute, specific to this dialect.
// See docs for the particular attribute in the corresponding dialect.
virtual Attribute createTypeExtensions(ArrayRef<int64_t> bounds) const = 0;
};
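// A sketch of how this interface is typically looked up (illustrative; `type`
// is some Type in scope, and the StableHLO dialect is assumed loaded):
//
//   auto *dialect = type.getContext()->getLoadedDialect("stablehlo");
//   if (auto *hloDialect =
//           dialect ? dialect->getRegisteredInterface<HloDialectInterface>()
//                   : nullptr) {
//     Type token = hloDialect->createTokenType();
//   }
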
namespace bytecode {
// Helper methods for bytecode serialization.
// Enum reader and writer. Many attrs have a single enum type to serialize.
// These use the attribute's underlying type to get the numeric value.
// Note this may cause issues if enums use an int64_t and have a large value.
// All enums in StableHLO and CHLO currently use uint32_t.
template <typename EnumTypeAttr, typename SymbolizeFn>
EnumTypeAttr readEnumAttribute(DialectBytecodeReader &reader,
MLIRContext *context, SymbolizeFn symbolizeFn) {
uint64_t code;
if (failed(reader.readVarInt(code))) return EnumTypeAttr();
auto enumOpt = symbolizeFn(static_cast<uint32_t>(code));
if (!enumOpt.has_value()) return EnumTypeAttr();
return EnumTypeAttr::get(context, enumOpt.value());
}
template <typename EnumType, typename EnumTypeAttr>
void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) {
static_assert(
std::is_same<typename std::underlying_type<EnumType>::type,
uint32_t>::value,
"writeEnumAttribute is only implemented for uint32_t enum values");
uint32_t enumVal = static_cast<typename std::underlying_type<EnumType>::type>(
val.getValue());
writer.writeVarInt(enumVal);
}
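// Typical pairing in a dialect's bytecode interface (names are illustrative;
// `symbolizeMyEnum` stands for the ODS-generated symbolizer of `MyEnum`):
//
//   MyEnumAttr readMyEnumAttr(DialectBytecodeReader &reader) const {
//     return readEnumAttribute<MyEnumAttr>(
//         reader, getContext(),
//         [](uint32_t val) { return symbolizeMyEnum(val); });
//   }
//   void write(MyEnumAttr attr, DialectBytecodeWriter &writer) const {
//     writeEnumAttribute<MyEnum>(attr, writer);
//   }
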
} // namespace bytecode
namespace OpTrait {
template <typename ConcreteType>
class BroadcastingElementwise
: public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
template <typename ConcreteType>
class IsCommutative
: public mlir::OpTrait::TraitBase<ConcreteType, IsCommutative> {};
template <typename ConcreteType>
class PairwiseSameOperandAndResultType
: public mlir::OpTrait::TraitBase<ConcreteType,
PairwiseSameOperandAndResultType> {
public:
static LogicalResult verifyTrait(Operation *op) {
const int numOperands = op->getNumOperands();
const int numResults = op->getNumResults();
if (numOperands != numResults) {
return op->emitOpError()
<< "requires the same number of operands and results";
}
for (int idx : llvm::seq<int>(0, numOperands)) {
if (op->getOperand(idx).getType() != op->getResult(idx).getType()) {
return op->emitOpError()
<< "requires the same type for operand and result at index "
<< idx;
}
}
return success();
}
};
template <typename ConcreteType>
class CompatibleOperandsAndResultElementType
: public mlir::OpTrait::TraitBase<ConcreteType,
CompatibleOperandsAndResultElementType> {
public:
static LogicalResult verifyTrait(Operation *op) {
Type expected;
if (op->getNumResults() != 0) expected = op->getResult(0).getType();
if (op->getNumOperands() != 0) expected = op->getOperand(0).getType();
if (!expected) return failure();
auto typeMatch = [&](Type actual) {
return isCompatibleElementTypeForHloTypeInference(actual, expected);
};
auto allMatch = llvm::all_of(op->getOperandTypes(), typeMatch) &&
llvm::all_of(op->getResultTypes(), typeMatch);
if (!allMatch) {
return op->emitOpError(
"requires compatible element types for all operands and results");
}
return success(allMatch);
}
};
template <typename ConcreteType>
class CompatibleOperandsElementType
: public mlir::OpTrait::TraitBase<ConcreteType,
CompatibleOperandsElementType> {
public:
static LogicalResult verifyTrait(Operation *op) {
if (failed(mlir::OpTrait::impl::verifyAtLeastNOperands(op, 1)))
return failure();
Type expected = op->getOperand(0).getType();
auto typeMatch = [&](Type actual) {
return isCompatibleElementTypeForHloTypeInference(actual, expected);
};
auto allMatch = llvm::all_of(op->getOperandTypes(), typeMatch);
if (!allMatch) {
return op->emitOpError(
"requires compatible element types for all operands");
}
return success();
}
};
template <typename ConcreteType>
class CompatibleOperandsAndResultType
: public mlir::OpTrait::TraitBase<ConcreteType,
CompatibleOperandsAndResultType> {
public:
static LogicalResult verifyTrait(Operation *op) {
Type expected;
if (op->getNumResults() != 0) expected = op->getResult(0).getType();
if (op->getNumOperands() != 0) expected = op->getOperand(0).getType();
if (!expected) return failure();
auto typeMatch = [&](Type actual) {
return isCompatibleForHloTypeInference(actual, expected);
};
auto allMatch = llvm::all_of(op->getOperandTypes(), typeMatch) &&
llvm::all_of(op->getResultTypes(), typeMatch);
if (!allMatch) {
return op->emitOpError(
"requires compatible types for all operands and results");
}
return success(allMatch);
}
static LogicalResult inferReturnTypes(
MLIRContext * /*context*/, std::optional<Location> location,
ValueRange operands, DictionaryAttr /*attributes*/,
OpaqueProperties /*properties*/, RegionRange /*regions*/,
SmallVectorImpl<Type> &inferredReturnTypes) {
// TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
// support quantization or sparsity.
if (operands.empty())
return emitOptionalError(
location,
"Expected non-empty operands for [CompatibleOperandsAndResultType]");
auto inferredTypeOrErr =
inferMostSpecificType(location, operands.getTypes());
if (failed(inferredTypeOrErr)) return failure();
inferredReturnTypes.emplace_back(*inferredTypeOrErr);
return success();
}
// This function is not going to be called automatically.
// It needs to be paired with INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS
// (see examples in StablehloOps.cpp).
static LogicalResult inferReturnTypeComponentsFromOperands(
MLIRContext *context, std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
SmallVector<Type> inferredReturnTypes;
if (failed(inferReturnTypes(context, location, operands.getValues(),
attributes, properties, regions,
inferredReturnTypes)))
return failure();
if (inferredReturnTypes.size() != 1) return failure();
auto inferredReturnType = dyn_cast<ShapedType>(inferredReturnTypes[0]);
if (!inferredReturnType) return failure();
inferredReturnShapes.push_back(inferredReturnType);
return success();
}
};
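// Illustrative ODS-side pairing (see StablehloOps.cpp for real uses; the op
// name here is hypothetical):
//
//   INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MyElementwiseOp)
//
// This expands to an inferReturnTypeComponents overload that forwards to
// inferReturnTypeComponentsFromOperands above.
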
template <typename ConcreteType>
struct SpeculatableIfStaticDimInOutputIsStaticInInputImplTrait
: public mlir::OpTrait::TraitBase<
ConcreteType,
SpeculatableIfStaticDimInOutputIsStaticInInputImplTrait> {
// A unary elementwise op is not speculatable if a dimension of the result
// type is static while the corresponding dimension in the input type is
// dynamic. Indeed, the input dimension could differ at runtime.
// If the output dimension is dynamic, there is no expectation, so there
// cannot be a mismatch.
// If the input dimension is static, the output dimension can be inferred from
// it, so there cannot be a mismatch.
mlir::Speculation::Speculatability getSpeculatability() {
auto op = this->getOperation();
auto inputType = cast<RankedTensorType>(op->getOperand(0).getType());
auto resultType = cast<RankedTensorType>(op->getResult(0).getType());
for (size_t i : llvm::seq(resultType.getRank())) {
if (!resultType.isDynamicDim(i) && inputType.isDynamicDim(i))
return mlir::Speculation::NotSpeculatable;
}
return mlir::Speculation::Speculatable;
}
};
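// For example (illustrative): an op with operand tensor<?xf32> and result
// tensor<2xf32> is NotSpeculatable because the operand's runtime size may
// differ from 2, whereas operand tensor<2xf32> with result tensor<?xf32> is
// Speculatable.

// Same check as above, but ops carrying the trait below report
// RecursivelySpeculatable instead of Speculatable when the check passes.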
template <typename ConcreteType>
struct RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait
: public mlir::OpTrait::TraitBase<
ConcreteType,
RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait> {
mlir::Speculation::Speculatability getSpeculatability() {
auto op = this->getOperation();
auto inputType = cast<RankedTensorType>(op->getOperand(0).getType());
auto resultType = cast<RankedTensorType>(op->getResult(0).getType());
for (size_t i : llvm::seq(resultType.getRank())) {
if (!resultType.isDynamicDim(i) && inputType.isDynamicDim(i))
return mlir::Speculation::NotSpeculatable;
}
return mlir::Speculation::RecursivelySpeculatable;
}
};
template <typename ConcreteType>
struct SpeculatableIfAllInputsStaticImplTrait
: public mlir::OpTrait::TraitBase<ConcreteType,
SpeculatableIfAllInputsStaticImplTrait> {
mlir::Speculation::Speculatability getSpeculatability() {
return llvm::all_of(this->getOperation()->getOperandTypes(),
[](Type t) {
return cast<RankedTensorType>(t).hasStaticShape();
})
? mlir::Speculation::Speculatable
: mlir::Speculation::NotSpeculatable;
}
};
template <typename ConcreteType>
struct RecursivelySpeculatableIfAllInputsStaticImplTrait
: public mlir::OpTrait::TraitBase<
ConcreteType, RecursivelySpeculatableIfAllInputsStaticImplTrait> {
mlir::Speculation::Speculatability getSpeculatability() {
return llvm::all_of(this->getOperation()->getOperandTypes(),
[](Type t) {
return cast<RankedTensorType>(t).hasStaticShape();
})
? mlir::Speculation::RecursivelySpeculatable
: mlir::Speculation::NotSpeculatable;
}
};
} // namespace OpTrait
} // namespace hlo
} // namespace mlir
#endif // STABLEHLO_DIALECT_BASE_H