blob: 8ba072fa4deaee6783ecf5edf2f762f1feca2698 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Copyright 2022 The StableHLO Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "stablehlo/dialect/TypeInference.h"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <array>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <set>
#include <unordered_map>
#include <utility>
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Regex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/AssemblyFormat.h"
namespace mlir {
namespace hlo {
namespace {
//===----------------------------------------------------------------------===//
// Utils for quantization specific verifications
//===----------------------------------------------------------------------===//
template <typename T>
bool allQuantized(ArrayRef<Type> typeRange) {
return llvm::all_of(typeRange, [&](Type val) {
return val.cast<ShapedType>().getElementType().isa<T>();
});
}
template <typename T>
bool noneQuantized(ArrayRef<Type> typeRange) {
return llvm::all_of(typeRange, [&](Type val) {
return !val.cast<ShapedType>().getElementType().isa<T>();
});
}
} // namespace
//===----------------------------------------------------------------------===//
// Utils for shape functions.
//===----------------------------------------------------------------------===//
// Checks if the vector `nums` has duplicates.
bool isUnique(ArrayRef<int64_t> nums) {
llvm::SmallDenseSet<int64_t> set(nums.begin(), nums.end());
return set.size() == nums.size();
}
bool tensorsHaveSameElType(TypeRange types, bool ignoreFpPrecision = true) {
if (!types.empty()) {
auto tensorTy1 = types[0].cast<ShapedType>();
Type tensorEl1 = tensorTy1.getElementType();
for (auto otherTensor : llvm::drop_begin(types, 1)) {
auto tensorTy2 = otherTensor.cast<ShapedType>();
Type tensorEl2 = tensorTy2.getElementType();
if (ignoreFpPrecision && tensorEl1.isa<FloatType>() &&
tensorTy2.getElementType().isa<FloatType>())
continue;
if (tensorEl1 != tensorEl2) return false;
}
}
return true;
}
// Return true if type1 and type2 are tensors and have the same
// element-type, else return false. With float element-types, ignore comparing
// floating-point precision if ignoreFpPrecision is True.
bool tensorsHaveSameElType(Type type1, Type type2,
bool ignoreFpPrecision = true) {
return tensorsHaveSameElType({type1, type2}, ignoreFpPrecision);
}
unsigned getBitWidth(Type type) {
if (auto complexTy = type.dyn_cast<ComplexType>())
return 2 * getBitWidth(complexTy.getElementType());
if (auto quantTy = type.dyn_cast<quant::QuantizedType>())
return getBitWidth(quantTy.getStorageType());
return type.getIntOrFloatBitWidth();
}
template <typename T>
bool matchesType(Type a, Type b) {
bool matches = a.isa<T>() && b.isa<T>();
// Check that expressed type matches for quantized types
if constexpr (std::is_same<T, quant::QuantizedType>::value) {
return matches && (a.cast<quant::QuantizedType>().getExpressedType() ==
b.cast<quant::QuantizedType>().getExpressedType());
}
return matches;
}
// Returns true if the element-type of type1 can be promoted to that of type2.
// An element-type 'x' is promotatble to element-type 'y' is they have the same
// base type and bitwidth(x) <= bitwidth(y). When 'x' and 'y' are quantized
// element-types, then promotion is applied only to the 'storage_type'
// component.
bool isPromotableElementType(Type type1, Type type2,
bool ignoreFpPrecision = false) {
auto tensorTy1 = type1.dyn_cast<ShapedType>();
auto tensorTy2 = type2.dyn_cast<ShapedType>();
if (!tensorTy1 || !tensorTy2) return false;
Type tensorEl1 = tensorTy1.getElementType();
Type tensorEl2 = tensorTy2.getElementType();
bool isSameType = matchesType<IntegerType>(tensorEl1, tensorEl2) ||
matchesType<FloatType>(tensorEl1, tensorEl2) ||
matchesType<ComplexType>(tensorEl1, tensorEl2) ||
matchesType<quant::QuantizedType>(tensorEl1, tensorEl2);
if (!isSameType) return false;
if (ignoreFpPrecision && tensorEl1.isa<FloatType>()) return true;
return getBitWidth(tensorEl1) <= getBitWidth(tensorEl2);
}
// Return true if type1 and type2 are shape-compatible and have same element
// type. If 'ignoreFpPrecision' is True, then allow floats with different
// precisions while checking element-types.
bool compatibleShapeAndElementType(Type type1, Type type2,
bool ignoreFpPrecision = false) {
if (failed(verifyCompatibleShape(type1, type2))) return false;
return tensorsHaveSameElType(type1, type2, ignoreFpPrecision);
}
bool verifyCompatibleDims(int64_t dimSize1, int64_t dimSize2) {
return isDynamicDimSize(dimSize1) || isDynamicDimSize(dimSize2) ||
dimSize1 == dimSize2;
}
// Convert a 1D dense int64 attribute to a list of values.
FailureOr<SmallVector<int64_t>> convert1DAttribute(
std::optional<DenseIntElementsAttr> optionalAttr,
std::optional<Location> loc, StringRef attrName) {
if (!optionalAttr.has_value()) return SmallVector<int64_t>{};
DenseIntElementsAttr attr = *optionalAttr;
auto attrType = attr.getType().cast<RankedTensorType>();
if (attrType.getRank() != 1)
return emitOptionalError(loc, "expects the shape of ", attrName,
" attribute to be 1-D, but got {",
attrType.getShape(), "}.");
auto values = attr.getValues<int64_t>();
return SmallVector<int64_t>{values.begin(), values.end()};
}
FailureOr<SmallVector<std::pair<int64_t, int64_t>>> convertPaddingAttribute(
std::optional<DenseIntElementsAttr> optionalAttr,
std::optional<Location> loc) {
if (!optionalAttr.has_value())
return SmallVector<std::pair<int64_t, int64_t>>{};
DenseIntElementsAttr attr = *optionalAttr;
auto attrType = attr.getType().cast<RankedTensorType>();
if (attrType.getRank() != 2 || attrType.getShape()[1] != 2)
return emitOptionalError(
loc, "expects the shape of padding-attribute to be {N, 2}, but got {",
attrType.getShape(), "}.");
auto it = attr.getValues<int64_t>().begin();
SmallVector<std::pair<int64_t, int64_t>> out(attr.getNumElements() / 2);
for (auto& item : out) {
int64_t first = *it;
++it;
int64_t second = *it;
++it;
item = {first, second};
}
return out;
}
// Convert a 1D dense bool attribute to a list of values.
FailureOr<SmallVector<bool>> convertWindowReversalAttribute(
std::optional<DenseElementsAttr> optionalAttr, std::optional<Location> loc,
StringRef attrName) {
if (!optionalAttr.has_value()) return SmallVector<bool>{};
DenseElementsAttr attr = *optionalAttr;
auto attrType = attr.getType().cast<RankedTensorType>();
if (attrType.getRank() != 1)
return emitOptionalError(loc, "expects the shape of ", attrName,
" attribute to be 1-D, but got {",
attrType.getShape(), "}.");
auto values = attr.getValues<bool>();
return SmallVector<bool>{values.begin(), values.end()};
}
// If a window with the given bound in some dimension is dilated with the given
// dilation factor in that dimension, then the value returned is the bound for
// the array in that dimension after dilation.
//
// For a 1D array with 3 entries 1, 2, 3, a dilation factor of 2 yields a new
// window with values 1, x, 2, x, 3, where x indicates holes left by the
// dilation. So DilatedBound(3, 2) == 5.
int64_t dilatedBound(int64_t bound, int64_t dilation) {
assert(bound >= 0 && "The dimension to dilate must be >= 0");
if (bound == 0) return 0;
// Suppose the array has three entries 123 and the dilation factor is 4. Then
// the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
// the last expands into 4 entries, so that is (bound - 1) * dilation. Then we
// add 1 to account for the final input element.
return (bound - 1) * dilation + 1;
}
// Returns the number of valid positions of a window with the given size and
// stride within an array with the given bound. This is the bound of an output
// array with one element per valid position of the window.
//
// For example, for arguments of (bound=5, window_size=2, stride=2), the
// returned value is 2. There are valid positions at offset 0 and offset 2,
// while offset 4 is not valid since the window's last entry would be at 5,
// which is beyond the bound of 5.
int64_t stridedBound(int64_t bound, int64_t windowSize, int64_t stride) {
assert(windowSize >= 0 && "Expected window size to be >= 0");
assert(bound >= 0 && "Expected bound to be >= 0");
if (bound == 0 || windowSize > bound) return 0;
// Without considering stride, the maximum valid offset is bound -
// window_size. Taking stride into account, the valid offsets then have the
// form q * stride for q = 0, ..., Q such that q * stride <= bound -
// window_size. This implies that Q equals floor(bound - window_size /
// stride). There are Q + 1 valid values of q, yielding the formula below.
return (bound - windowSize) / stride + 1;
}
LogicalResult verifyPairwiseCompatibleShapes(TypeRange values) {
for (auto type1 : values)
for (auto type2 : values)
if (failed(verifyCompatibleShape(type1, type2))) return failure();
return success();
}
LogicalResult verifyBatchNorm(std::optional<Location> location,
ValueRange multiDimOperands,
ValueRange singleDimOperands,
int64_t featureIndex) {
// batch_norm_grad_c3
if (failed(verifyPairwiseCompatibleShapes(multiDimOperands.getTypes())))
return emitOptionalError(
location,
"expects multi-dimensional operands to have compatible shapes.");
// batch_norm_grad_c4, batch_norm_inference_c3...batch_norm_inference_c6,
// batch_norm_training_c3, batch_norm_training_c4
if (failed(verifyPairwiseCompatibleShapes(singleDimOperands.getTypes())))
return emitOptionalError(
location,
"expects single-dimensional operands to have compatible shapes.");
auto multiDimType = multiDimOperands[0].getType().cast<RankedTensorType>();
// batch_norm_grad_c1, batch_norm_inference_c1, batch_norm_training_c1
if (featureIndex >= multiDimType.getRank())
return emitOptionalError(
location,
"expects featureIndex to be smaller than the rank of "
"multi-dimensional operands; got featureIndex ",
featureIndex, ", and rank ", multiDimType.getRank(), ".");
// batch_norm_grad_c1, batch_norm_inference_c1, batch_norm_training_c1
if (featureIndex < 0)
return emitOptionalError(location, "expects featureIndex to be a ",
"non-negative number, got ", featureIndex, ".");
const int64_t featureCount = multiDimType.getDimSize(featureIndex);
const int64_t singleDimSize =
singleDimOperands[0].getType().cast<RankedTensorType>().getDimSize(0);
// batch_norm_grad_c5, batch_norm_inference_c3...batch_norm_inference_c6,
// batch_norm_training_c3, batch_norm_training_c4
if (!verifyCompatibleDims(singleDimSize, featureCount))
return emitOptionalError(
location,
"expects the size of single-dimensional operands to be compatible with "
"feature count, but the size of single-dimensional operands is ",
dimSizeToString(singleDimSize), " and the feature count is ",
dimSizeToString(featureCount), ".");
return success();
}
LogicalResult inferBatchNormOp(
std::optional<Location> location, ValueRange multiDimOperands,
ValueRange singleDimOperands, int64_t featureIndex,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes,
bool is_inference) {
if (failed(verifyBatchNorm(location, multiDimOperands, singleDimOperands,
featureIndex)))
return failure();
// Batch norm ops require operands to be ranked.
auto multiDimType = multiDimOperands[0].getType().cast<RankedTensorType>();
// batch_norm_grad_c3, batch_norm_inference_c7, batch_norm_training_c7
inferredReturnShapes.emplace_back(multiDimType.getShape(),
multiDimType.getElementType(),
multiDimType.getEncoding());
if (is_inference) return success();
SmallVector<int64_t> singleDimShape{multiDimType.getDimSize(featureIndex)};
ArrayRef<int64_t> multiDimBounds =
encodingToBounds(multiDimType.getEncoding());
SmallVector<int64_t> singleDimBounds;
if (!multiDimBounds.empty())
singleDimBounds.emplace_back(multiDimBounds[featureIndex]);
auto singleDimReturnShape = ShapedTypeComponents(
singleDimShape, multiDimType.getElementType(),
singleDimBounds.empty()
? nullptr
: boundsToEncoding(multiDimType.getEncoding(), singleDimBounds));
// batch_norm_grad_c4, batch_norm_training_c5
inferredReturnShapes.emplace_back(singleDimReturnShape);
// batch_norm_grad_c4, batch_norm_training_c6
inferredReturnShapes.emplace_back(singleDimReturnShape);
return success();
}
// Verifies various properties of window-attributes (viz., stride, padding,
// lhs_dilation and rhs_dilation) and collects all the window-attributes for
// each kernel spatial dimensions.
FailureOr<SmallVector<WindowDimension>>
verifyWindowAttributesAndInferWindowDimensions(
ArrayRef<int64_t> windowDimensions, ArrayRef<int64_t> windowStrides,
ArrayRef<std::pair<int64_t, int64_t>> padding,
ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation,
ArrayRef<bool> windowReversal, std::optional<Location> loc) {
const auto verifySize = [&](const size_t attrSize,
StringRef attrName) -> LogicalResult {
if (attrSize == 0 || attrSize == windowDimensions.size()) return success();
return emitOptionalError(
loc, "expects ", attrName,
" to have same dimension-size as size of window dimensions (",
windowDimensions.size(), "), but got: ", attrSize, ".");
};
// reduce_window_c6, select_and_scatter_c6
if (failed(verifySize(windowStrides.size(), "window-strides")))
return failure();
// reduce_window_c8
if (failed(verifySize(lhsDilation.size(), "base-dilation factors")))
return failure();
// reduce_window_c10
if (failed(verifySize(rhsDilation.size(), "window-dilation factors")))
return failure();
// reduce_window_c12
if (failed(verifySize(padding.size(), "padding-entries"))) return failure();
if (failed(verifySize(windowReversal.size(), "window-reversal")))
return failure();
SmallVector<WindowDimension> window(windowDimensions.size());
for (size_t i = 0; i < windowDimensions.size(); i++) {
WindowDimension& dim = window[i];
dim.size = windowDimensions[i];
// reduce_window_c5, select_and_scatter_c5
if (!isDynamicDimSize(dim.size) && dim.size <= 0)
return emitOptionalError(loc,
"expects window to have positive value for ", i,
"-th window dimension, but got ", dim.size, ".");
if (!windowStrides.empty()) dim.stride = windowStrides[i];
// reduce_window_c7, select_and_scatter_c7
if (dim.stride <= 0)
return emitOptionalError(
loc, "expects window to have positive stride for ", i,
"-th window dimension, but got ", dim.stride, ".");
if (!lhsDilation.empty()) dim.baseDilation = lhsDilation[i];
// reduce_window_c9
if (dim.baseDilation <= 0)
return emitOptionalError(
loc, "expects window to have positive base dilation factor for ", i,
"-th window dimension, but got ", dim.baseDilation, ".");
if (!rhsDilation.empty()) dim.windowDilation = rhsDilation[i];
// reduce_window_c11
if (dim.windowDilation <= 0)
return emitOptionalError(
loc, "expects window to have positive window dilation factor for ", i,
"-th window dimension, but got ", dim.windowDilation, ".");
if (!padding.empty()) {
dim.paddingLow = padding[i].first;
dim.paddingHigh = padding[i].second;
}
}
return window;
}
// Infer the shape of the output window.
// Foreach dimension d,
// output-window-shape[d] =
// stridedBound(padding_low + dilatedBound(base_shape[d]) +
// padding_high,
// dilatedBound(window_shape[d]))
// where (padding_low, padding_high) is the padding-pair for d.
SmallVector<int64_t> inferWindowOutputShape(ArrayRef<int64_t> baseShape,
ArrayRef<WindowDimension> window) {
assert(baseShape.size() == window.size() &&
"Size of window dimensions must match the size of base shape.");
SmallVector<int64_t> outputDimensions(window.size());
for (int64_t i = 0; i < static_cast<int64_t>(window.size()); ++i) {
if (isDynamicDimSize(baseShape[i]) || isDynamicDimSize(window[i].size)) {
outputDimensions[i] = ShapedType::kDynamic;
} else {
const auto& dim = window[i];
const int64_t dilatedBase = dilatedBound(baseShape[i], dim.baseDilation);
const int64_t paddedDilatedBase =
dim.paddingLow + dilatedBase + dim.paddingHigh;
const int64_t dilatedWindow = dilatedBound(dim.size, dim.windowDilation);
outputDimensions[i] =
stridedBound(paddedDilatedBase, dilatedWindow, dim.stride);
}
}
return outputDimensions;
}
LogicalResult verifyReplicaGroups(std::optional<Location> location,
DenseIntElementsAttr replicaGroups,
bool allGroupsMustHaveSameSize,
bool useGlobalDeviceIds,
std::optional<size_t> expectedGroupSize) {
auto replicaGroupType = replicaGroups.getType().cast<RankedTensorType>();
// all_gather_i3, all_to_all_i5
if (replicaGroupType.getRank() != 2)
return emitOptionalError(location,
"replica groups should be a rank 2 tensor");
// Revisit the following check in light of #498.
if (useGlobalDeviceIds &&
(replicaGroupType.getShape()[0] * replicaGroupType.getShape()[1] == 0))
return emitOptionalError(location,
"if `use_global_device_ids` is set, the replica "
"groups cannot be empty");
auto replicaIds = replicaGroups.getValues<int64_t>();
llvm::SmallSet<int64_t, 8> replicaIdsSeen;
for (int64_t replicaId : replicaIds) {
// Replica groups are stored in a 2D tensor. If the op supports non-uniform
// groups, null replica IDs are stored as -1.
// all_gather_c4
if (replicaId == -1) {
if (!allGroupsMustHaveSameSize) continue;
return emitOptionalError(location, "Invalid replica id -1");
}
// all_gather_c2, all_reduce_c1, all_to_all_c5
if (!replicaIdsSeen.insert(replicaId).second)
return emitOptionalError(location, "replica id #", replicaId,
" seen more than once");
}
// all_gather_c4, all_reduce_c3, all_to_all_c7
for (size_t id = 0; id < replicaIdsSeen.size(); id++)
if (!replicaIdsSeen.contains(id))
return emitOptionalError(location, "replica id #", id,
" not seen in replica groups");
// all_to_all_c8
if (allGroupsMustHaveSameSize && expectedGroupSize &&
(replicaIds.size() / replicaGroupType.getShape()[0] !=
*expectedGroupSize))
return emitOptionalError(location, "group size of replica_groups must be ",
*expectedGroupSize);
return success();
}
LogicalResult verifyReduceOpInputsAndInferShape(
std::optional<Location> location, SmallVector<ShapedType> inputTypes,
ArrayRef<int64_t> dimensions, SmallVector<int64_t>& newDimensions,
Attribute& encoding) {
// Check for unranked tensors in input operands.
uint64_t numInputs = inputTypes.size();
int64_t rankedInputIdx = -1;
for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
if (inputTypes[inputIdx].hasRank()) {
rankedInputIdx = inputIdx;
break;
}
}
bool allInputsUnranked = (rankedInputIdx == -1);
// reduce_c1
if (!allInputsUnranked) {
for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx)
if (failed(mlir::verifyCompatibleShape(inputTypes[rankedInputIdx],
inputTypes[inputIdx])))
return emitOptionalError(
location, "expects all inputs to have compatible shapes. Shape at",
" input-index ", inputIdx,
" is not compatible with shape at input-index ", rankedInputIdx);
}
DenseSet<int64_t> dimensionsToReduceSet;
for (int64_t dimension : dimensions) {
// reduce_c4
if ((!allInputsUnranked &&
dimension >= inputTypes[rankedInputIdx].getRank()) ||
dimension < 0)
return emitOptionalError(
location, "Out-of-bounds dimension ", dimension, ", expected to be ",
allInputsUnranked
? "> 0"
: "less than the input-tensor rank " +
std::to_string(inputTypes[rankedInputIdx].getRank()));
// reduce_c5
if (!dimensionsToReduceSet.insert(dimension).second)
return emitOptionalError(location,
"Duplicate reduction dimension: ", dimension);
}
if (!allInputsUnranked) {
auto rankedInput = inputTypes[rankedInputIdx].cast<RankedTensorType>();
ArrayRef<int64_t> inputBounds = encodingToBounds(rankedInput.getEncoding());
SmallVector<int64_t> newBounds;
for (int inputIdx = 0; inputIdx < rankedInput.getRank(); ++inputIdx) {
if (!dimensionsToReduceSet.count(inputIdx)) {
newDimensions.push_back(rankedInput.getDimSize(inputIdx));
if (!inputBounds.empty()) newBounds.push_back(inputBounds[inputIdx]);
}
}
// Set encoding based on the bounds only if the bounds is not empty.
encoding = nullptr;
if (!newBounds.empty())
encoding = boundsToEncoding(rankedInput.getEncoding(), newBounds);
}
return success();
}
// Returns the types of the terminator arguments of the input mlir::Block
// 'block'.
FailureOr<SmallVector<ShapedType>> getAccumulatorTypes(
std::optional<Location> loc, Region& region) {
if (region.empty()) {
return emitOptionalError(
loc, "Expects non-empty reduction block for type inference");
}
Block& block = region.front();
return llvm::to_vector(
llvm::map_range(block.getTerminator()->getOperands(),
[&](Value v) { return v.getType().cast<ShapedType>(); }));
}
LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
ArrayRef<ShapedType> inputTypes,
ArrayRef<ShapedType> initValueTypes,
ArrayRef<int64_t> allowedDimensions) {
int64_t numInputs = inputTypes.size();
// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c10
if (static_cast<int64_t>(block.getArguments().size()) != numInputs * 2)
return emitOptionalError(loc, "Reduction-region must take ", numInputs * 2,
" parameters, but takes ",
block.getArguments().size(), " parameter(s)");
// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c10
if (block.getTerminator()->getOperands().empty())
return emitOptionalError(
loc, "The reduction-region expected to return some value(s)");
// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c10
if (static_cast<int64_t>(block.getTerminator()->getOperands().size()) !=
numInputs)
return emitOptionalError(loc, "Reduction-region here must produce ",
numInputs, " tensors, but produces ",
block.getTerminator()->getOperands().size(),
" instead");
// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c10
SmallVector<ShapedType> accumulatorSubShapes;
for (Value retOperand : block.getTerminator()->getOperands()) {
auto shapedTy = retOperand.getType().dyn_cast<ShapedType>();
if (!shapedTy)
return emitOptionalError(loc,
"Reduction-region here must produce "
"tensor-typed result(s), but produces ",
retOperand.getType(), " instead");
accumulatorSubShapes.push_back(shapedTy);
}
for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
// all_reduce_c5, reduce_c2, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c10
if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
block.getArgument(inputIdx).getType()))
return emitOptionalError(
loc, "The type of reduction-region's parameter at index ", inputIdx,
" is different than the corresponding result type: ",
block.getArgument(inputIdx).getType(), " vs ",
accumulatorSubShapes[inputIdx]);
// all_reduce_c5, reduce_c2, reduce_scatter_c7, reduce_window_c13,
// scatter_c15, select_and_scatter_c3, select_and_scatter_c10
if (!compatibleShapeAndElementType(
accumulatorSubShapes[inputIdx],
block.getArgument(numInputs + inputIdx).getType(),
/*ignoreFpPrecision=*/true))
return emitOptionalError(
loc, "The type of reduction-region's parameter at index ",
numInputs + inputIdx,
" is different than the corresponding result type: ",
block.getArgument(numInputs + inputIdx).getType(), " vs ",
accumulatorSubShapes[inputIdx]);
// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// reduce_window_i2, scatter_c6, scatter_c15, select_and_scatter_c10
if (failed(verifyCompatibleShape(initValueTypes[inputIdx],
accumulatorSubShapes[inputIdx])))
return emitOptionalError(
loc, "The shape of reduction-region's result type at index ",
inputIdx, " differs from the op's corresponding init-value type: ",
accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]);
if (!isPromotableElementType(initValueTypes[inputIdx],
accumulatorSubShapes[inputIdx],
/*ignoreFpPrecision=*/true))
return emitOptionalError(
loc, "The element-type of reduction-region's result type at index ",
inputIdx,
" is expected to be promotable from the op's corresponding "
"init-value element-type: ",
accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]);
// reduce_c6, reduce_window_c3, scatter_c6, scatter_c15,
// select_and_scatter_c10
if (!isPromotableElementType(
inputTypes[inputIdx],
block.getArgument(numInputs + inputIdx).getType(),
/*ignoreFpPrecision=*/true))
return emitOptionalError(
loc, "The element-type of reduction-region's argument at index ",
numInputs + inputIdx, " is expected to be promotable from ",
inputTypes[inputIdx].getElementType(), ", but got ",
getElementTypeOrSelf(
block.getArgument(numInputs + inputIdx).getType()));
Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
auto blockArgTensorTy = blockArgType.cast<ShapedType>();
auto allInputsUnranked = llvm::none_of(
inputTypes, [&](ShapedType type) { return type.hasRank(); });
if (allInputsUnranked || !blockArgTensorTy.hasRank()) return success();
auto argShape = blockArgTensorTy.getShape();
// reduce_c6, reduce_window_c13, select_and_scatter_c10
if (argShape.size() > allowedDimensions.size())
return emitOptionalError(
loc, "The rank of reduction-region's argument at index ",
numInputs + inputIdx,
" is expected to be <= ", allowedDimensions.size(), ", got ",
argShape.size());
int64_t argShapeIdx = 0;
for (int64_t outputShapeIdx = 0;
outputShapeIdx < static_cast<int64_t>(allowedDimensions.size()) &&
argShapeIdx < static_cast<int64_t>(argShape.size());
outputShapeIdx++)
if (verifyCompatibleDims(allowedDimensions[outputShapeIdx],
argShape[argShapeIdx]))
argShapeIdx++;
// reduce_c6, reduce_window_c13
if (argShapeIdx != static_cast<int64_t>(argShape.size()))
return emitOptionalError(
loc, "The shape of reduction-region's argument at index ",
numInputs + inputIdx,
" is not compatible with that of reduce-op's input-parameter "
"at index ",
inputIdx);
}
return success();
}
LogicalResult verifyReduceWindowOpInputsAndInferWindow(
std::optional<Location> location, SmallVector<ShapedType> inputTypes,
SmallVector<ShapedType> initValueTypes, ArrayRef<int64_t> windowDimensions,
std::optional<ArrayRef<int64_t>> windowStrides,
std::optional<ArrayRef<int64_t>> baseDilations,
std::optional<ArrayRef<int64_t>> windowDilations,
std::optional<DenseIntElementsAttr> padding,
SmallVector<int64_t>& windowDims,
SmallVector<WindowDimension>& inferredWindow) {
// reduce_window_c1
if (inputTypes.empty())
return emitOptionalError(location, "requires at least 1 input value");
// Check for unranked tensors in input operands.
uint64_t numInputs = inputTypes.size();
int64_t rankedInputIdx = -1;
for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
if (inputTypes[inputIdx].hasRank()) {
rankedInputIdx = inputIdx;
break;
}
}
bool allInputsUnranked = (rankedInputIdx == -1);
// reduce_window_c2
if (!allInputsUnranked) {
for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx)
if (failed(mlir::verifyCompatibleShape(inputTypes[rankedInputIdx],
inputTypes[inputIdx])))
return emitOptionalError(
location, "expects all inputs to have compatible shapes. Shape at",
" input-index ", inputIdx,
" is not compatible with shape at input-index ", rankedInputIdx);
}
// reduce_window_c12, reduce_window_i7
auto paddingOrErr = convertPaddingAttribute(padding, location);
if (failed(paddingOrErr)) return failure();
// reduce_window_c4
for (const auto inputType : inputTypes) {
if (!inputType.hasRank()) continue;
if (inputType.getRank() != static_cast<int64_t>(windowDimensions.size()))
return emitOptionalError(
location, "expects window-dimensions size == input rank, but got ",
"window-dimensions size: ", windowDimensions.size(),
" and input: ", inputType, " with rank = ", inputType.getRank(), ".");
}
// reduce_window_c5...reduce_window_c12
auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
windowDimensions, windowStrides.value_or(SmallVector<int64_t, 0>{}),
*paddingOrErr,
/*lhsDilation=*/baseDilations.value_or(SmallVector<int64_t, 0>{}),
/*rhsDilation=*/windowDilations.value_or(SmallVector<int64_t, 0>{}),
/*windowReversal=*/std::nullopt, location);
if (failed(windowOrErr)) return failure();
windowDims.append(windowDimensions.begin(), windowDimensions.end());
inferredWindow.append(*windowOrErr);
return success();
}
// Shape function can be called directly from autogenerated `build()` function,
// which may not guarantee the added region(s) in `odsState.regions` to be
// non-empty. Need check it here to avoid a crash for the ops that need regions
// in type inference, i.e. `IfOp/CaseOp/MapOp`.
LogicalResult verifyRegionNotEmpty(std::optional<Location> location,
Region& region) {
if (region.empty())
return emitOptionalError(location, "expect non-empty region");
return success();
}
// Checks:
// P1. Same sizes for input, kernel and output spatialDims.
// P2. Spatial and non-spatial dimensions (for input,kernel, &output) should
// be unique and in range [0, num_dims), where num_dims = rank of input
// (lhs/rhs) tensors.
//
// Note that the spatial + non-spatial dimensions may not cover all the
// dimensions in the range [0,num) because of the presence of 'unknown'
// dimensions (ref. `printConvolutionDimensions()`)
LogicalResult isSpatialDimensionsValid(
Type lhsType, int64_t inputBatchDimension, int64_t inputFeatureDimension,
ArrayRef<int64_t> inputSpatialDimensions,
int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension,
ArrayRef<int64_t> kernelSpatialDimensions, int64_t outputBatchDimension,
int64_t outputFeatureDimension, ArrayRef<int64_t> outputSpatialDimensions,
std::optional<Location> location) {
uint64_t spatialDimNum = inputSpatialDimensions.size();
// P1.
if ((spatialDimNum != kernelSpatialDimensions.size()) ||
(spatialDimNum != outputSpatialDimensions.size()))
return emitOptionalError(location,
"expects the same size for input, kernel "
"and output spatial-dimensions, but got ",
spatialDimNum, ", ",
kernelSpatialDimensions.size(), ", and ",
outputSpatialDimensions.size(), " resp.");
// P2.
SmallVector<int64_t> inputDimNums(spatialDimNum + 2);
inputDimNums[0] = inputBatchDimension;
inputDimNums[1] = inputFeatureDimension;
std::copy(inputSpatialDimensions.begin(), inputSpatialDimensions.end(),
inputDimNums.begin() + 2);
SmallVector<int64_t> windowDimNums(spatialDimNum + 2);
windowDimNums[0] = kernelInputFeatureDimension;
windowDimNums[1] = kernelOutputFeatureDimension;
std::copy(kernelSpatialDimensions.begin(), kernelSpatialDimensions.end(),
windowDimNums.begin() + 2);
SmallVector<int64_t> OutputDimNums(spatialDimNum + 2);
OutputDimNums[0] = outputBatchDimension;
OutputDimNums[1] = outputFeatureDimension;
std::copy(outputSpatialDimensions.begin(), outputSpatialDimensions.end(),
OutputDimNums.begin() + 2);
auto numDims = lhsType.cast<RankedTensorType>().getRank();
const auto inRange = [numDims](int64_t i) { return 0 <= i && i < numDims; };
if (!llvm::all_of(inputDimNums, inRange) ||
!llvm::all_of(windowDimNums, inRange) ||
!llvm::all_of(OutputDimNums, inRange))
return emitOptionalError(location,
"expects input, kernel, and output "
"dimension-numbers to be in-range [0, ",
numDims, ").");
if (!isUnique(inputDimNums))
return emitOptionalError(
location, "expects input dimension-numbers to be unique, got {",
inputDimNums, "}.");
if (!isUnique(windowDimNums))
return emitOptionalError(
location, "expects kernel dimension-numbers to be unique, got {",
windowDimNums, "}.");
if (!isUnique(OutputDimNums))
return emitOptionalError(
location, "expects output dimension-numbers to be unique, got {",
OutputDimNums, "}.");
return success();
}
// Checks if the precision config has a valid size, if provided.
LogicalResult verifyPrecisionConfig(std::optional<Location> loc,
std::optional<ArrayAttr> maybeArrayAttr) {
if (!maybeArrayAttr.has_value()) return success();
auto arrayAttr = maybeArrayAttr.value();
if (!arrayAttr) return success();
return arrayAttr.size() <= 2
? success()
: emitOptionalError(loc,
"expects precision config to be empty or have "
"<= 2 elements.");
}
// Verifies the following properties:
// P1. The input, kernel, and output spatial-dimensions are valid.
// P2. Given,
// input-dimensions: b * input-spatial-dims * f
// kernel-dimensions: kernel-spatial-dims * i * o
// output-dimensions: b' * out-spatial-dims * f'
// where b = input-batch-dim
// where f = input-feature-dim
// where i = kernel-input-feature-dim
// where o = kernel-output-feature-dim
// where b' = output-batch-dim
// where f' = output-feature-dim
// Check the following properties w.r.t feature_group_count (fgc) and
// batch_group_count (bgc).
// * fgc > 0, bgc > 0 and !(fgc > 1 && bgc > 1)
// * dim(lhs, b) % bgc == 0
// * dim(lhs, f) % fgc == 0 and
// dim(lhs, f) / fgc = dim(rhs, i)
// * dim(rhs, o) (or dim(output, f')) % bgc == 0 and
// dim(rhs, o) (or dim(output, f')) % fgc == 0
// P3. Precision config is null, of size 0 or of size 2.
LogicalResult verifyConvolutionAttributes(
std::optional<Location> location, Type lhsType, Type rhsType,
int64_t inputBatchDimension, int64_t inputFeatureDimension,
ArrayRef<int64_t> inputSpatialDimensions,
int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension,
ArrayRef<int64_t> kernelSpatialDimensions, int64_t outputBatchDimension,
int64_t outputFeatureDimension, ArrayRef<int64_t> outputSpatialDimensions,
int64_t featureGroupCount, int64_t batchGroupCount,
std::optional<ArrayAttr> precisionConfig) {
// P1.
if (failed(isSpatialDimensionsValid(
lhsType, inputBatchDimension, inputFeatureDimension,
inputSpatialDimensions, kernelInputFeatureDimension,
kernelOutputFeatureDimension, kernelSpatialDimensions,
outputBatchDimension, outputFeatureDimension, outputSpatialDimensions,
location)))
return failure();
// P2.
if (featureGroupCount <= 0)
return emitOptionalError(
location, "expects feature_group_count to be a positive number, got ",
featureGroupCount, ".");
if (batchGroupCount <= 0)
return emitOptionalError(
location, "expects batch_group_count to be a positive number, got ",
batchGroupCount, ".");
if (batchGroupCount > 1 && featureGroupCount > 1)
return emitOptionalError(
location,
"expects batch_group_count and feature_group_count not to be both "
"greater than 1. Got ",
batchGroupCount, " and ", featureGroupCount, " resp.");
auto rankedLhsType = lhsType.cast<RankedTensorType>();
const int64_t inputFeatures = rankedLhsType.getShape()[inputFeatureDimension];
const int64_t inputBatch = rankedLhsType.getShape()[inputBatchDimension];
auto rankedRhsType = rhsType.cast<RankedTensorType>();
const int64_t kernelInputFeatures =
rankedRhsType.getShape()[kernelInputFeatureDimension];
const int64_t kernelOutputFeatures =
rankedRhsType.getShape()[kernelOutputFeatureDimension];
if (!isDynamicDimSize(kernelOutputFeatures)) {
if (kernelOutputFeatures % batchGroupCount != 0)
return emitOptionalError(
location, "expects output feature dimension size (",
kernelOutputFeatures,
") to be a multiple of batch_group_count. Got batch_group_count = ",
batchGroupCount, ".");
if (kernelOutputFeatures % featureGroupCount != 0)
return emitOptionalError(location,
"expects kernel output feature dimension (",
kernelOutputFeatures,
") to be divisible by feature_group_count. For "
"feature_group_count = ",
featureGroupCount, ".");
}
if (!isDynamicDimSize(inputFeatures)) {
if (inputFeatures % featureGroupCount != 0)
return emitOptionalError(location, "expects input feature dimension (",
inputFeatures,
") to be a multiple of feature_group_count. Got "
"feature_group_count = ",
featureGroupCount, ".");
if (!isDynamicDimSize(kernelInputFeatures) &&
inputFeatures / featureGroupCount != kernelInputFeatures)
return emitOptionalError(
location, "expects input feature dimension (", inputFeatures,
") / "
"feature_group_count = kernel input feature dimension (",
kernelInputFeatures,
"). Got feature_group_count = ", featureGroupCount, ".");
}
if (!isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0)
return emitOptionalError(location, "expects input batch dimension (",
inputBatch,
") to be divisible by "
"batch_group_count. Got batch_group_count = ",
batchGroupCount, ".");
// P3.
if (failed(verifyPrecisionConfig(location, precisionConfig)))
return failure();
return success();
}
LogicalResult validateScatterDimensionNumbers(
ShapedType operandType, ArrayRef<int64_t> scatterIndicesShape,
ShapedType updateType, bool operandTypeRanked,
bool scatterIndicesTypeRanked, bool updatesTypeRanked,
ArrayRef<int64_t> updateWindowDims, ArrayRef<int64_t> insertedWindowDims,
ArrayRef<int64_t> scatterDimsToOperandDims, int64_t indexVectorDim,
std::optional<Location> loc) {
// scatter_c2
if (operandTypeRanked) {
auto windowSize = updateWindowDims.size() + insertedWindowDims.size();
if (operandType.getRank() != static_cast<int64_t>(windowSize))
return emitOptionalError(loc,
"Expects rank-of operand to match "
"size-of('update_window_dims') + "
"size-of('inserted_window_dims') i.e. ",
windowSize, " but got ", operandType.getRank(),
".");
}
// scatter_c7
if (!llvm::is_sorted(updateWindowDims))
return emitOptionalError(loc,
"Expects update_window_dims to be sorted; got: [",
updateWindowDims, "].");
// scatter_c7
if (!isUnique(updateWindowDims))
return emitOptionalError(loc,
"Expects update_window_dims to not repeat; got: [",
updateWindowDims, "].");
// scatter_c8
if (updatesTypeRanked) {
for (int64_t windowDim : updateWindowDims) {
if (windowDim < 0 || windowDim >= updateType.getRank())
return emitOptionalError(
loc,
"Expects each element of update_window_dims to be in range "
"[0, "
"rank-of('updates') i.e. [0, ",
updateType.getRank(), "). got: ", windowDim, ".");
}
}
// scatter_c9
if (!llvm::is_sorted(insertedWindowDims))
return emitOptionalError(
loc, "Expects inserted_window_dims to be sorted; got: [",
insertedWindowDims, "].");
// scatter_c9
if (!isUnique(insertedWindowDims))
return emitOptionalError(
loc, "Expects inserted_window_dims to not repeat; got: [",
insertedWindowDims, "].)");
// scatter_c10
if (operandTypeRanked) {
for (int64_t insertedDim : insertedWindowDims)
if (insertedDim < 0 || insertedDim >= operandType.getRank())
return emitOptionalError(
loc,
"Expects each element of inserted_window_dims to be in range "
"[0, rank-of('operand') i.e. [0, ",
operandType.getRank(), "). got: ", insertedDim, ".");
}
// scatter_c11
if (scatterIndicesTypeRanked) {
if (indexVectorDim == static_cast<int64_t>(scatterIndicesShape.size()) &&
scatterDimsToOperandDims.size() != 1)
return emitOptionalError(
loc, "Scatter op has ", scatterDimsToOperandDims.size(),
" elements in scatter_dims_to_operand_dims and "
"the bound of dimension index_vector_dim=",
indexVectorDim,
" of scatter_indices is 1. These two numbers must be equal.");
if (!isDynamicDimSize(scatterIndicesShape[indexVectorDim]) &&
static_cast<int64_t>(scatterDimsToOperandDims.size()) !=
scatterIndicesShape[indexVectorDim])
return emitOptionalError(loc, "Scatter op has ",
scatterDimsToOperandDims.size(),
" elements in scatter_dims_to_operand_dims and "
"the bound of dimension index_vector_dim=",
indexVectorDim, " of scatter_indices is ",
scatterIndicesShape[indexVectorDim],
". These two numbers must be equal.");
}
// scatter_c12
if (!isUnique(scatterDimsToOperandDims))
return emitOptionalError(
loc, "Expects scatter_dims_to_operand_dims to not repeat; got: [",
scatterDimsToOperandDims, "].");
// scatter_c13
if (operandTypeRanked) {
for (int64_t i = 0;
i < static_cast<int64_t>(scatterDimsToOperandDims.size()); ++i) {
int64_t scatterDimToOperandDim = scatterDimsToOperandDims[i];
if (scatterDimToOperandDim < 0 ||
scatterDimToOperandDim >= operandType.getRank())
return emitOptionalError(
loc, "Invalid scatter_dims_to_operand_dims mapping; domain is [0, ",
operandType.getRank(), "), got: ", i, "->", scatterDimToOperandDim,
".");
}
}
return success();
}
static LogicalResult verifyGather(
std::optional<Location> location, ShapeAdaptor operandShape,
ShapeAdaptor startIndicesShape, ShapeAdaptor sliceSizesShape,
ArrayRef<int64_t> offsetDims, ArrayRef<int64_t> collapsedSliceDims,
ArrayRef<int64_t> startIndexMap, int64_t indexVectorDim) {
// gather_c9
if (!isUnique(startIndexMap))
return emitOptionalError(location,
"expects start_index_map to not repeat, got: [",
startIndexMap, "]");
// gather_c10
for (int64_t i = 0; i < static_cast<int64_t>(startIndexMap.size()); ++i)
if (startIndexMap[i] < 0 ||
(operandShape.hasRank() && startIndexMap[i] >= operandShape.getRank()))
return emitOptionalError(
location, "start_index_map[", i, "]: ", startIndexMap[i],
" is out of bounds for ", "operand rank ", operandShape.getRank());
if (startIndicesShape.hasRank()) {
// gather_c2
// index_vector_dim == start_indices.rank implies a trailing 1 on the shape
// of start_indices.
if (indexVectorDim > startIndicesShape.getRank() || indexVectorDim < 0)
return emitOptionalError(location, "index_vector_dim ", indexVectorDim,
" is out of bounds for start indices with rank ",
startIndicesShape.getRank());
// gather_c3
bool impliedTrailingDim = indexVectorDim == startIndicesShape.getRank();
if (impliedTrailingDim || !startIndicesShape.isDynamicDim(indexVectorDim)) {
int64_t effectiveDimSize;
if (impliedTrailingDim)
effectiveDimSize = 1;
else
effectiveDimSize = startIndicesShape.getDimSize(indexVectorDim);
if (effectiveDimSize != static_cast<int64_t>(startIndexMap.size()))
return emitOptionalError(
location, "start_index_map size (", startIndexMap.size(),
") is not equal to size of index dimension (", indexVectorDim,
") of start_indices (", effectiveDimSize, ")");
}
}
// gather_c4
if (!llvm::is_sorted(offsetDims))
return emitOptionalError(
location, "expects offset_dims to be sorted, got: [", offsetDims, "]");
if (!isUnique(offsetDims))
return emitOptionalError(
location, "expects offset_dims to not repeat, got: [", offsetDims, "]");
// gather_c6
if (!llvm::is_sorted(collapsedSliceDims))
return emitOptionalError(
location, "expects collapsed_slice_dims to be sorted, got: [",
collapsedSliceDims, "]");
if (!isUnique(collapsedSliceDims))
return emitOptionalError(
location, "expects collapsed_slice_dims to not repeat, got: [",
collapsedSliceDims, "]");
// gather_c1
int64_t impliedOperandRank = offsetDims.size() + collapsedSliceDims.size();
if (operandShape.hasRank() && operandShape.getRank() != impliedOperandRank)
return emitOptionalError(
location, "offset_dims size (", offsetDims.size(),
") plus collapse_slice_dims size (", collapsedSliceDims.size(),
") is not equal to operand rank (", operandShape.getRank(), ")");
// gather_i7
if (sliceSizesShape.hasRank() && sliceSizesShape.getRank() != 1)
return emitOptionalError(location, "slice_sizes.rank != 1 (got ",
sliceSizesShape.getRank(), ')');
if (sliceSizesShape.hasStaticShape()) {
int64_t sliceSize = sliceSizesShape.getNumElements();
// gather_c11
if (sliceSize != impliedOperandRank)
return emitOptionalError(location, "slice_sizes size (", sliceSize,
") not equal to (implied) operand rank (",
impliedOperandRank, ")");
// gather_c7
for (auto dim : collapsedSliceDims)
if (dim < 0 || dim >= sliceSize)
return emitOptionalError(location, "collapsed dimension ", dim,
" is out of bounds for slice_sizes.size (",
sliceSize, ")");
}
return success();
}
template <typename dimTy>
static void inferGatherShape(
int64_t resultRank, llvm::function_ref<dimTy(int64_t)> getStartIndicesDim,
llvm::function_ref<dimTy(int64_t)> getSliceDim,
ArrayRef<int64_t> offsetDims, ArrayRef<int64_t> collapsedSliceDims,
ArrayRef<int64_t> startIndexMap, int64_t indexVectorDim,
SmallVectorImpl<dimTy>& shape) {
// We don't necessarily know the rank of sliceSizes, but we do know that it
// can't be larger than the highest collapsed dimension. So go through those
// and populate the leading dimensions of adjustedSliceSizes. The trailing
// dimensions can just be adjusted by an offset.
const auto* maxCollapsedDimIt =
std::max_element(collapsedSliceDims.begin(), collapsedSliceDims.end());
int64_t maxCollapsedDim = -1;
if (maxCollapsedDimIt != collapsedSliceDims.end())
maxCollapsedDim = *maxCollapsedDimIt;
SmallVector<dimTy> adjustedSliceSizePrefix;
for (int dimIndex = 0; dimIndex <= maxCollapsedDim; ++dimIndex) {
if (llvm::is_contained(collapsedSliceDims, dimIndex)) continue;
adjustedSliceSizePrefix.push_back(getSliceDim(dimIndex));
}
auto getAdjustedSliceDim = [&](int64_t index) -> dimTy {
if (index < static_cast<int64_t>(adjustedSliceSizePrefix.size()))
return adjustedSliceSizePrefix[index];
return getSliceDim(index + collapsedSliceDims.size());
};
// Dimensions in the output that aren't offset dimensions are called batch
// dimensions.
SmallVector<int64_t> batchDims;
for (int dim = 0; dim < resultRank; ++dim)
if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);
for (int i = 0; i < resultRank; ++i) {
const auto* offsetDimsIt =
std::find(offsetDims.begin(), offsetDims.end(), i);
if (offsetDimsIt != offsetDims.end()) {
auto index = std::distance(offsetDims.begin(), offsetDimsIt);
shape.push_back(getAdjustedSliceDim(index));
continue;
}
auto* batchDimsIt = std::find(batchDims.begin(), batchDims.end(), i);
assert(batchDimsIt != batchDims.end());
auto index = std::distance(batchDims.begin(), batchDimsIt);
// This can never run into the special case where start_indices gets
// implicitly expanded with a trailing 1 if
// index_vector_dim = start_indices.rank because then index would equal
// index_vector_dim, which means we'd be looking at index+1, which would be
// out of bounds anyway.
if (index >= indexVectorDim) ++index;
shape.push_back(getStartIndicesDim(index));
}
}
void reifyGatherDimSizes(int64_t resultRank,
llvm::function_ref<Value(int64_t)> getStartIndicesDim,
llvm::function_ref<Value(int64_t)> getSliceDim,
ArrayRef<int64_t> offsetDims,
ArrayRef<int64_t> collapsedSliceDims,
ArrayRef<int64_t> startIndexMap,
int64_t indexVectorDim,
SmallVectorImpl<Value>& shape) {
inferGatherShape<Value>(resultRank, getStartIndicesDim, getSliceDim,
offsetDims, collapsedSliceDims, startIndexMap,
indexVectorDim, shape);
}
static LogicalResult inferGatherReturnTypeComponents(
std::optional<Location> location, ShapeAdaptor operandShape,
Value startIndices, llvm::function_ref<int64_t(int64_t)> getSliceDim,
ArrayRef<int64_t> offsetDims, ArrayRef<int64_t> collapsedSliceDims,
ArrayRef<int64_t> startIndexMap, int64_t indexVectorDim,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
Type elementType = operandShape.getElementType();
ShapeAdaptor startIndicesShape(startIndices.getType());
// We need this to determine the result rank. We could still place bounds on
// the result rank if that was something ShapedTypeComponents could express.
if (!startIndicesShape.hasRank()) {
inferredReturnShapes.push_back(elementType);
return success();
}
int64_t startIndicesRank = startIndicesShape.getRank();
// If index_vector_dim == start_indices.rank, then an implicit trailing 1 is
// appended to start_indices shape.
if (indexVectorDim == startIndicesRank) ++startIndicesRank;
int64_t resultRank = offsetDims.size() + startIndicesRank - 1;
// gather_c5
for (int64_t i = 0; i < static_cast<int64_t>(offsetDims.size()); ++i)
if (offsetDims[i] < 0 || offsetDims[i] >= resultRank)
return emitOptionalError(location, "offset_dims[", i,
"]: ", offsetDims[i], " is out of bounds for ",
"implied result rank ", resultRank);
auto getStartIndicesDim = [&](int64_t index) {
return startIndicesShape.getDimSize(index);
};
// gather_c13
SmallVector<int64_t> shape;
inferGatherShape<int64_t>(resultRank, getStartIndicesDim, getSliceDim,
offsetDims, collapsedSliceDims, startIndexMap,
indexVectorDim, shape);
// The dimension sizes of result, corresponding to offset dimensions, depend
// on attributes (like `collapsed_slice_dims` and `slice_sizes`) and hence are
// always static. Whereas, the dimension sizes of result, corresponding to
// batch dimensions, depends on input `start_indices` and could be dynamic.
// The corresponding bounds, in that case, are propagated from the
// `start_indices`.
Attribute encoding =
startIndices.getType().cast<RankedTensorType>().getEncoding();
ArrayRef<int64_t> startIndicesBounds = encodingToBounds(encoding);
SmallVector<int64_t> inferredBounds(resultRank, ShapedType::kDynamic);
if (!startIndicesBounds.empty()) {
llvm::BitVector isOffsetDim(resultRank);
for (auto offsetDim : offsetDims) isOffsetDim.set(offsetDim);
int64_t startIndicesDim = 0;
for (int resultDim = 0; resultDim < resultRank; ++resultDim) {
if (isOffsetDim.test(resultDim)) continue;
if (startIndicesDim == indexVectorDim) ++startIndicesDim;
inferredBounds[resultDim] = startIndicesBounds[startIndicesDim++];
}
}
inferredReturnShapes.emplace_back(shape, elementType,
boundsToEncoding(encoding, inferredBounds));
return success();
}
// Used by IfOp and CaseOp
LogicalResult inferConditionalOp(std::optional<Location> location,
Value operand, RegionRange branches,
SmallVectorImpl<Type>& inferredReturnTypes) {
// case_i1, if_i1
auto operandRankedTy = operand.getType().dyn_cast<RankedTensorType>();
if (operandRankedTy && operandRankedTy.getRank() != 0)
return emitOptionalError(location,
"operand should be rank 0 tensor but got rank ",
operandRankedTy.getRank());
// case_c1
if (branches.empty())
return emitOptionalError(location, "expect at least one branch");
for (auto region : branches)
if (failed(verifyRegionNotEmpty(location, *region))) return failure();
ValueTypeRange<OperandRange> branch0ResultTypes =
branches[0]->front().getTerminator()->getOperandTypes();
for (unsigned i = 0; i < branches.size(); ++i) {
Twine branchName = "branch " + Twine(i);
Region* region = branches[i];
// case_c2, if_c1
if (region->getNumArguments() != 0)
return emitOptionalError(location, branchName,
" must have 0 arguments, but found ",
region->getNumArguments());
// case_c3, if_c2
auto branchResultTypes = region->front().getTerminator()->getOperandTypes();
if (!hlo::isCompatibleForHloTypeInference(branch0ResultTypes,
branchResultTypes))
return emitOptionalError(location, "branch 0 and ", branchName,
" have mismatched return types: ",
branch0ResultTypes, " vs ", branchResultTypes);
}
// case_c4, if_c3
for (unsigned i = 0; i < branch0ResultTypes.size(); ++i) {
SmallVector<Type> inputTypes;
for (auto branch : branches)
inputTypes.push_back(
branch->front().getTerminator()->getOperandTypes()[i]);
auto inferredTypeOrErr = inferLeastSpecificType(location, inputTypes);
if (failed(inferredTypeOrErr)) return failure();
inferredReturnTypes.emplace_back(*inferredTypeOrErr);
}
return success();
}
LogicalResult verifyDimInBounds(std::optional<Location> loc, ShapedType type,
int64_t dim) {
if (dim < 0)
return emitOptionalError(
loc, "requires non-negative dimension attribute; found (", dim, ")");
if (type.hasRank() && dim >= type.getRank())
return emitOptionalError(loc, "requires dimension attribute in range [0, ",
type.getRank(), "); found (", dim, ")");
return success();
}
//===----------------------------------------------------------------------===//
// Shape functions for ops.
//===----------------------------------------------------------------------===//
LogicalResult inferAbsOp(std::optional<Location>, Value operand,
SmallVectorImpl<Type>& inferredReturnTypes) {
auto operandTy = operand.getType().cast<ShapedType>();
// abs_c2
Type elementTy = operandTy.getElementType();
if (auto complexTy = elementTy.dyn_cast<ComplexType>())
elementTy = complexTy.getElementType();
// abs_c1
inferredReturnTypes.push_back(operandTy.clone(elementTy));
return success();
}
LogicalResult inferAfterAllOp(HloDialectInterface* dialect,
std::optional<Location> location,
SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(dialect->createTokenType());
return success();
}
LogicalResult inferAllToAllOp(
std::optional<Location> location, Value operand, int64_t splitDimension,
int64_t concatDimension, int64_t splitCount,
DenseIntElementsAttr replicaGroups,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
// all_to_all_c4
if (splitCount <= 0)
return emitOptionalError(location, "AllToAll split_count must be > 0");
// all_to_all_c5, all_to_all_c7, all_to_all_i5
if (failed(verifyReplicaGroups(location, replicaGroups,
/*allGroupsMustHaveSameSize=*/true,
/*useGlobalDeviceIds=*/false, splitCount)))
return failure();
// all_to_all_c1
if (splitDimension < 0)
return emitOptionalError(location,
"AllToAll split_dimension cannot be negative");
// all_to_all_c3
if (concatDimension < 0)
return emitOptionalError(location,
"AllToAll concat_dimension cannot be negative");
Type operandType = operand.getType();
auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
if (!operandRankedType) {
inferredReturnShapes.emplace_back(
operandType.cast<ShapedType>().getElementType());
return success();
}
int64_t inputRank = operandRankedType.getRank();
// all_to_all_c1
if (splitDimension >= inputRank)
return emitOptionalError(location, "AllToAll split_dimension ",
splitDimension,
" is out-of-bounds for input rank ", inputRank);
// all_to_all_c3
if (concatDimension >= inputRank)
return emitOptionalError(location, "AllToAll concat_dimension ",
concatDimension,
" is out-of-bounds for input rank ", inputRank);
SmallVector<int64_t> resultShape(operandRankedType.getShape().begin(),
operandRankedType.getShape().end());
// all_to_all_c2
if (isStaticDimSize(resultShape[splitDimension]) &&
resultShape[splitDimension] % splitCount != 0)
return emitOptionalError(
location, "split dimension has size ", resultShape[splitDimension],
", expected to be a multiple of split_count ", splitCount);
if (isStaticDimSize(resultShape[splitDimension]))
resultShape[splitDimension] /= splitCount;
if (isStaticDimSize(resultShape[concatDimension]))
resultShape[concatDimension] *= splitCount;
SmallVector<int64_t> resultBounds =
to_vector(encodingToBounds(operandRankedType.getEncoding()));
if (!resultBounds.empty()) {
if (isStaticDimSize(resultBounds[splitDimension]))
resultBounds[splitDimension] /= splitCount;
if (isStaticDimSize(resultBounds[concatDimension]))
resultBounds[concatDimension] *= splitCount;
}
inferredReturnShapes.emplace_back(
resultShape, operandRankedType.getElementType(),
boundsToEncoding(operandRankedType.getEncoding(), resultBounds));
return success();
}
LogicalResult inferAllReduceOp(
std::optional<Location> location, ValueRange operands, Region& computation,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
TypeRange inputTypes = operands.getTypes();
SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(inputTypes, [](Type t) { return t.cast<ShapedType>(); })};
// all_reduce_c6, all_reduce_c7
auto accumulatorTypesOrErr = getAccumulatorTypes(location, computation);
if (failed(accumulatorTypesOrErr)) return failure();
for (size_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
inferredReturnShapes.emplace_back(
getSameShapeTensorType(inputArgTensorTypes[inputIdx],
(*accumulatorTypesOrErr)[0].getElementType()));
}
return success();
}
LogicalResult inferBatchNormGradOp(
std::optional<Location> location, Value operand, Value scale, Value mean,
Value variance, Value gradOutput, int64_t featureIndex,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
return inferBatchNormOp(location, {operand, gradOutput},
{scale, mean, variance}, featureIndex,
inferredReturnShapes, /*is_inference=*/false);
}
LogicalResult inferBatchNormInferenceOp(
std::optional<Location> location, Value operand, Value scale, Value offset,
Value mean, Value variance, int64_t featureIndex,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
return inferBatchNormOp(location, {operand}, {scale, offset, mean, variance},
featureIndex, inferredReturnShapes,
/*is_inference=*/true);
}
LogicalResult inferBatchNormTrainingOp(
std::optional<Location> location, Value operand, Value scale, Value offset,
int64_t featureIndex,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
return inferBatchNormOp(location, {operand}, {scale, offset}, featureIndex,
inferredReturnShapes, /*is_inference=*/false);
}
LogicalResult inferBroadcastOp(
std::optional<Location> location, Value operand,
ArrayRef<int64_t> broadcastSizes,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto operandType = operand.getType().dyn_cast<RankedTensorType>();
if (!operandType) return failure();
for (int64_t size : broadcastSizes)
if (size < 0)
return emitOptionalError(location,
"Broadcast with negative dimension size ", size);
SmallVector<int64_t> shapeValues(broadcastSizes);
llvm::append_range(shapeValues, operandType.getShape());
inferredReturnShapes.emplace_back(shapeValues, operandType.getElementType());
return success();
}
LogicalResult inferCaseOp(std::optional<Location> location, Value index,
RegionRange branches,
SmallVectorImpl<Type>& inferredReturnTypes) {
return inferConditionalOp(location, index, branches, inferredReturnTypes);
}
LogicalResult inferCholeskyOp(
std::optional<Location> location, Value a,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
Type aType = a.getType();
RankedTensorType aRankedType = aType.dyn_cast<RankedTensorType>();
if (!aRankedType) {
// cholesky_c1
inferredReturnShapes.emplace_back(
aType.cast<ShapedType>().getElementType());
return success();
}
ArrayRef<int64_t> aShape = aRankedType.getShape();
// cholesky_c2
if (aShape.size() < 2)
return emitOptionalError(
location, "argument 'a' must have rank >= 2, got shape ", aShape, ".");
// cholesky_c3
if (!verifyCompatibleDims(aShape[aShape.size() - 2],
aShape[aShape.size() - 1]))
return emitOptionalError(
location, "minor dimensions of 'a' must have equal size, got shape ",
aShape, ".");
// cholesky_c1
inferredReturnShapes.emplace_back(aRankedType.getShape(),
aRankedType.getElementType(),
aRankedType.getEncoding());
return success();
}
LogicalResult inferClampOp(
std::optional<Location> location, Value min, Value operand, Value max,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto operandType = operand.getType().cast<RankedTensorType>();
auto operandShape = operandType.getShape();
auto minType = min.getType().cast<RankedTensorType>();
// clamp_c1
auto minShape = minType.getShape();
if (failed(verifyCompatibleShape(minType, operandType)) &&
minType.getRank() != 0)
return emitOptionalError(
location, "min shape [",
llvm::make_range(minShape.begin(), minShape.end()),
"] is not scalar and is not compatible to operand shape [",
llvm::make_range(operandShape.begin(), operandShape.end()), "]");
// clamp_c2
auto maxType = max.getType().cast<RankedTensorType>();
auto maxShape = maxType.getShape();
if (failed(verifyCompatibleShape(maxType, operandType)) &&
maxType.getRank() != 0)
return emitOptionalError(
location, "max shape [",
llvm::make_range(maxShape.begin(), maxShape.end()),
"] is not scalar and is not compatible to operand shape [",
llvm::make_range(operandShape.begin(), operandShape.end()), "]");
// clamp_c4
inferredReturnShapes.emplace_back(operandType.cast<ShapedType>());
return success();
}
LogicalResult inferCompareOp(
MLIRContext* context, std::optional<Location>, Value lhs,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
// compare_c1
ShapedTypeComponents& components =
inferredReturnShapes.emplace_back(IntegerType::get(context, /*width=*/1));
auto argTy = lhs.getType().cast<ShapedType>();
// compare_c2
if (argTy.hasRank())
components =
ShapedTypeComponents(argTy.getShape(), components.getElementType());
return success();
}
LogicalResult inferComplexOp(std::optional<Location> location, Value lhs,
SmallVectorImpl<Type>& inferredReturnTypes) {
ShapedType operandType = lhs.getType().cast<ShapedType>();
// complex_c3
ComplexType elementTy = ComplexType::get(operandType.getElementType());
// complex_c2
inferredReturnTypes.push_back(getSameShapeTensorType(operandType, elementTy));
return success();
}
LogicalResult inferConcatenateOp(std::optional<Location> location,
TypeRange inputTypes, int64_t dimension,
SmallVectorImpl<Type>& inferredReturnTypes) {
// concatenate_c4
if (dimension < 0)
return emitOptionalError(location, "dimension ", dimension, " is negative");
RankedTensorType firstRankedType;
int firstRankedIndex = -1;
for (uint64_t i = 0; i < inputTypes.size(); i++) {
auto secondType = inputTypes[i].cast<ShapedType>();
if (!secondType.hasRank()) continue;
if (!firstRankedType) {
firstRankedType = secondType.cast<RankedTensorType>();
firstRankedIndex = i;
// concatenate_c4
if (firstRankedType.getRank() == 0)
return emitOptionalError(location,
"rank-0 values cannot be concatenated");
// concatenate_c4
if (dimension >= firstRankedType.getRank())
return emitOptionalError(location, "dimension ", dimension,
" is out-of-bounds for input rank ",
firstRankedType.getRank());
continue;
}
// concatenate_c2
if (firstRankedType.getRank() != secondType.getRank())
return emitOptionalError(location, "operands (", firstRankedIndex,
") and (", i, ") do not match rank");
auto firstShape = firstRankedType.getShape();
auto secondShape = secondType.getShape();
for (int d = 0; d < firstRankedType.getRank(); ++d) {
// concatenate_c2
if (d != dimension &&
!verifyCompatibleDims(firstShape[d], secondShape[d]))
return emitOptionalError(
location, "shapes of operand (", firstRankedIndex, ") and (", i,
") do not match at non-concat "
"index: (",
llvm::make_range(firstShape.begin(), firstShape.end()), ") != (",
llvm::make_range(secondShape.begin(), secondShape.end()),
") at non-concat index ", d);
}
}
// concatenate_c5
auto elementType = inputTypes[0].cast<ShapedType>().getElementType();
if (!firstRankedType) {
inferredReturnTypes.push_back(UnrankedTensorType::get(elementType));
return success();
}
// Infer the most specific (size, bound) of all dimensions of the return type
auto rank = firstRankedType.getRank();
SmallVector<int64_t> inferredSizes(rank, ShapedType::kDynamic);
SmallVector<int64_t> inferredBounds(rank, ShapedType::kDynamic);
// Note: for the concatenate dimension, 0 should be the identity element:
// Any dim size can keep unchanged when concatenated with 0
inferredSizes[dimension] = 0;
bool anyInputHaveBounds = false;
// Note: unranked input types can't be ignored, consider these input types:
// c0: (<5x?xf32>, <*xf32>) with concat dim 0 should infer <?x?xf32>
// c1: (<5x?xf32>, <*xf32>) with concat dim 1 should infer <5x?xf32>
// Instead, they should be replaced with dynamic tensors: tensor<?x...?x>
for (const auto& it : llvm::enumerate(inputTypes)) {
RankedTensorType rankedType = it.value().dyn_cast<RankedTensorType>();
SmallVector<int64_t> bounds;
if (rankedType)
bounds = to_vector(encodingToBounds(rankedType.getEncoding()));
if (!bounds.empty()) anyInputHaveBounds = true;
for (int dim = 0; dim < rank; ++dim) {
std::pair<int64_t, int64_t> inferredDimAndBound;
int64_t leftSize = inferredSizes[dim];
int64_t rightSize =
rankedType ? rankedType.getShape()[dim] : ShapedType::kDynamic;
int64_t leftBound = inferredBounds[dim];
int64_t rightBound = bounds.empty() ? ShapedType::kDynamic : bounds[dim];
if (dim == dimension) {
inferredDimAndBound = inferConcatenatedDimAndBound(
leftSize, rightSize, leftBound, rightBound);
} else {
auto inferredDimAndBoundOrErr = inferMostSpecificDimAndBound(
location, dim, leftSize, rightSize, leftBound, rightBound);
if (failed(inferredDimAndBoundOrErr)) return failure();
inferredDimAndBound = *inferredDimAndBoundOrErr;
}
inferredSizes[dim] = inferredDimAndBound.first;
inferredBounds[dim] = inferredDimAndBound.second;
}
}
// concatenate_c5, concatenate_c6
inferredReturnTypes.push_back(RankedTensorType::get(
inferredSizes, elementType,
boundsToEncoding(
firstRankedType.getEncoding(),
// Empty array as argument is an indicator to boundsToEncoding() that
// there are no bounds at all in inputs, thus sparsity attributes will
// be included in the return type
anyInputHaveBounds ? inferredBounds : llvm::ArrayRef<int64_t>({}))));
return success();
}
LogicalResult inferConstantOp(std::optional<Location>, ElementsAttr value,
SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(value.getType());
return success();
}
LogicalResult inferConvertOp(
std::optional<Location> location, Value operand,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto operandType = operand.getType().cast<ShapedType>();
// convert_c1
inferredReturnShapes.emplace_back(
operandType.hasRank() ? operandType.getShape() : ArrayRef<int64_t>{});
return success();
}
/*
* We intend to verify the following properties
* P1. Verify the input, kernel types.
* P2. Verify the convolution atributes.
* P3. Verify and collect the window atributes.
* P4. Verify precision_config attribute.
* P5. Verify the return shape.
* TODO(b/232574102): Verify the element-type of return-value.
*/
LogicalResult inferConvolutionOp(
std::optional<Location> location, Type lhsType, Type rhsType,
std::optional<ArrayRef<int64_t>> windowStrides,
std::optional<DenseIntElementsAttr> padding,
std::optional<ArrayRef<int64_t>> lhsDilation,
std::optional<ArrayRef<int64_t>> rhsDilation,
std::optional<ArrayRef<bool>> windowReversal, int64_t inputBatchDimension,
int64_t inputFeatureDimension, ArrayRef<int64_t> inputSpatialDimensions,
int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension,
ArrayRef<int64_t> kernelSpatialDimensions, int64_t outputBatchDimension,
int64_t outputFeatureDimension, ArrayRef<int64_t> outputSpatialDimensions,
int64_t featureGroupCount, int64_t batchGroupCount,
std::optional<ArrayAttr> precisionConfig,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto rankedLhsType = lhsType.dyn_cast<RankedTensorType>();
auto rankedRhsType = rhsType.dyn_cast<RankedTensorType>();
if (!rankedLhsType || !rankedRhsType) {
inferredReturnShapes.push_back({});
return success();
}
// P1.
int numDims = rankedLhsType.getRank();
if (numDims != rankedRhsType.getRank())
return emitOptionalError(location,
"expects convolution arguments to have same "
"number of dimensions. Got: ",
rankedLhsType, " and ", rankedRhsType, ".");
if (numDims < 2)
return emitOptionalError(
location,
"expects convolution arguments to have >= 2 dimensions. Got: ",
rankedLhsType, " and ", rankedRhsType, ".");
// P2.
if (failed(verifyConvolutionAttributes(
location, lhsType, rhsType, inputBatchDimension,
inputFeatureDimension, inputSpatialDimensions,
kernelInputFeatureDimension, kernelOutputFeatureDimension,
kernelSpatialDimensions, outputBatchDimension, outputFeatureDimension,
outputSpatialDimensions, featureGroupCount, batchGroupCount,
precisionConfig)))
return failure();
if ((size_t)numDims != inputSpatialDimensions.size() + 2)
return emitOptionalError(location, "expects convolution arguments to have ",
inputSpatialDimensions.size() + 2,
" dimensions. Got: ", numDims);
// P3.
SmallVector<int64_t> windowDimensions(kernelSpatialDimensions.size());
for (size_t i = 0; i < windowDimensions.size(); i++)
windowDimensions[i] = rankedRhsType.getShape()[kernelSpatialDimensions[i]];
auto paddingOrErr = convertPaddingAttribute(padding, location);
if (failed(paddingOrErr)) return failure();
// TODO: add missing tests for ConvolutionOp.
auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
windowDimensions, windowStrides.value_or(ArrayRef<int64_t>{}),
*paddingOrErr, lhsDilation.value_or(ArrayRef<int64_t>{}),
rhsDilation.value_or(ArrayRef<int64_t>{}),
windowReversal.value_or(ArrayRef<bool>{}), location);
if (failed(windowOrErr)) return failure();
// P3.
if (failed(verifyPrecisionConfig(location, precisionConfig)))
return failure();
// P5.
SmallVector<int64_t> outputDimensions(rankedLhsType.getShape().size(),
ShapedType::kDynamic);
// Infer the output spatial dimensions.
auto numSpatialDims = inputSpatialDimensions.size();
SmallVector<int64_t> inputSpatialDimVals(numSpatialDims);
for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
inputSpatialDimVals[i] =
rankedLhsType.getShape()[inputSpatialDimensions[i]];
auto windowOutputShape =
inferWindowOutputShape(inputSpatialDimVals, *windowOrErr);
for (int64_t i = 0; i < static_cast<int64_t>(windowOrErr->size()); ++i)
outputDimensions[outputSpatialDimensions[i]] = windowOutputShape[i];
// Infer the output-batch-dimension and output-feature-dimension.
const int64_t inputBatch = rankedLhsType.getShape()[inputBatchDimension];
const int64_t kernelOutputFeatures =
rankedRhsType.getShape()[kernelOutputFeatureDimension];
outputDimensions[outputBatchDimension] = isDynamicDimSize(inputBatch)
? ShapedType::kDynamic
: inputBatch / batchGroupCount;
outputDimensions[outputFeatureDimension] = kernelOutputFeatures;
inferredReturnShapes.emplace_back(outputDimensions);
return success();
}
LogicalResult inferCreateTokenOp(HloDialectInterface* dialect,
std::optional<Location> location,
SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(dialect->createTokenType());
return success();
}
LogicalResult inferDotOp(
std::optional<Location> location, RankedTensorType lhsType,
RankedTensorType rhsType, std::optional<ArrayAttr> precisionConfig,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
if (failed(verifyPrecisionConfig(location, precisionConfig)))
return failure();
SmallVector<int64_t> dimensions;
if (1 == lhsType.getRank() && 1 == rhsType.getRank() &&
// vector dot vector
verifyCompatibleDims(lhsType.getDimSize(0), rhsType.getDimSize(0))) {
} else if (2 == lhsType.getRank() && 1 == rhsType.getRank() &&
verifyCompatibleDims(lhsType.getDimSize(1),
rhsType.getDimSize(0))) {
// matrix dot vector
dimensions.push_back(lhsType.getDimSize(0));
} else if (1 == lhsType.getRank() && 2 == rhsType.getRank() &&
verifyCompatibleDims(lhsType.getDimSize(0),
rhsType.getDimSize(0))) {
// vector dot matrix
dimensions.push_back(rhsType.getDimSize(1));
} else if (2 == lhsType.getRank() && 2 == rhsType.getRank() &&
verifyCompatibleDims(lhsType.getDimSize(1),
rhsType.getDimSize(0))) {
// matrix dot matrix
dimensions.push_back(lhsType.getDimSize(0));
dimensions.push_back(rhsType.getDimSize(1));
} else {
return emitOptionalError(location,
"expected both lhs/rhs ranks to be "
"either 1 or 2");
}
inferredReturnShapes.emplace_back(dimensions);
return success();
}
LogicalResult inferDotGeneralOp(
std::optional<Location> location, Type lhsType, Type rhsType,
ArrayRef<int64_t> lhsBatchingDimensions,
ArrayRef<int64_t> rhsBatchingDimensions,
ArrayRef<int64_t> lhsContractingDimensions,
ArrayRef<int64_t> rhsContractingDimensions,
std::optional<ArrayAttr> precisionConfig,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
// dot_general_c11
if (failed(verifyPrecisionConfig(location, precisionConfig)))
return failure();
// dot_general_c1
if (lhsBatchingDimensions.size() != rhsBatchingDimensions.size())
return emitOptionalError(location,
"lhs and rhs should have the same "
"number of batching dimensions");
// dot_general_c2
if (lhsContractingDimensions.size() != rhsContractingDimensions.size())
return emitOptionalError(location,
"lhs and rhs should have the same "
"number of contracting dimensions");
llvm::SmallDenseSet<int64_t> dimSet;
auto checkDimsDistinct =
[&](ArrayRef<int64_t> batchingDims, ArrayRef<int64_t> contractingDims,
llvm::SmallDenseSet<int64_t>& dimSet, llvm::StringRef lhs,
llvm::StringRef rhs) -> LogicalResult {
auto dims = llvm::concat<const int64_t>(batchingDims, contractingDims);
for (auto dim : dims) {
auto [_, wasInserted] = dimSet.insert(dim);
if (!wasInserted)
return emitOptionalError(location, "has duplicated dimension from ",
lhs, " and ", rhs, ": ", dim);
}
return success();
};
// dot_general_c3
if (failed(checkDimsDistinct(lhsBatchingDimensions, lhsContractingDimensions,
dimSet, "lhs_batching_dimensions",
"lhs_contracting_dimensions")))
return failure();
dimSet.clear();
// dot_general_c4
if (failed(checkDimsDistinct(rhsBatchingDimensions, rhsContractingDimensions,
dimSet, "rhs_batching_dimensions",
"rhs_contracting_dimensions")))
return failure();
auto checkDimsInRange = [&](int64_t rank, ArrayRef<int64_t> dims,
llvm::StringRef dimName) -> LogicalResult {
auto inRange = [&](int64_t i) -> bool { return 0 <= i && i < rank; };
const auto* dimsNotInRange =
std::find_if_not(dims.begin(), dims.end(), inRange);
if (dimsNotInRange != dims.end())
return emitOptionalError(location, dimName, " value: ", *dimsNotInRange,
" is out of range: ", "[0, ", rank, ")");
return success();
};
auto lhsRankedType = lhsType.dyn_cast<RankedTensorType>();
if (lhsRankedType) {
// dot_general_c5
// dot_general_c6
if (failed(checkDimsInRange(lhsRankedType.getRank(), lhsBatchingDimensions,
"lhs_batching_dimensions")) ||
failed(checkDimsInRange(lhsRankedType.getRank(),
lhsContractingDimensions,
"lhs_contracting_dimensions")))
return failure();
}
auto rhsRankedType = rhsType.dyn_cast<RankedTensorType>();
if (rhsRankedType) {
// dot_general_c7
// dot_general_c8
if (failed(checkDimsInRange(rhsRankedType.getRank(), rhsBatchingDimensions,
"rhs_batching_dimensions")) ||
failed(checkDimsInRange(rhsRankedType.getRank(),
rhsContractingDimensions,
"rhs_contracting_dimensions")))
return failure();
}
if (lhsRankedType && rhsRankedType) {
// Dimension sizes must be compatible for lhs/rhs.
auto lhsShape = lhsRankedType.getShape();
auto rhsShape = rhsRankedType.getShape();
for (auto [lhs, rhs] :
llvm::zip(lhsBatchingDimensions, rhsBatchingDimensions)) {
// dot_general_c9
if (!verifyCompatibleDims(lhsShape[lhs], rhsShape[rhs]))
return emitOptionalError(location,
"batching dimension sizes must "
"match for lhs/rhs");
}
for (auto [lhs, rhs] :
llvm::zip(lhsContractingDimensions, rhsContractingDimensions)) {
// dot_general_c10
if (!verifyCompatibleDims(lhsShape[lhs], rhsShape[rhs]))
return emitOptionalError(location,
"contracting dimension sizes must "
"match for lhs/rhs");
}
}
if (!lhsRankedType || !rhsRankedType) {
inferredReturnShapes.push_back({});
return success();
}
// Infer the output dimensions of the operation.
auto lhsShape = lhsRankedType.getShape();
auto rhsShape = rhsRankedType.getShape();
SmallVector<int64_t> dimensions;
for (const int64_t lhsBatchingDim : lhsBatchingDimensions)
dimensions.push_back(lhsShape[lhsBatchingDim]);
for (int64_t i = 0; i < lhsRankedType.getRank(); i++)
if (!llvm::is_contained(lhsBatchingDimensions, i) &&
!llvm::is_contained(lhsContractingDimensions, i))
dimensions.push_back(lhsShape[i]);
for (int64_t i = 0; i < rhsRankedType.getRank(); i++)
if (!llvm::is_contained(rhsBatchingDimensions, i) &&
!llvm::is_contained(rhsContractingDimensions, i))
dimensions.push_back(rhsShape[i]);
// dot_general_c12
inferredReturnShapes.emplace_back(dimensions);
return success();
}
LogicalResult inferDynamicGatherOp(
std::optional<Location> location, Value operand, Value startIndices,
Value sliceSizes, ArrayRef<int64_t> offsetDims,
ArrayRef<int64_t> collapsedSliceDims, ArrayRef<int64_t> startIndexMap,
int64_t indexVectorDim,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
ShapeAdaptor operandShape(operand.getType());
ShapeAdaptor startIndicesShape(startIndices.getType());
ShapeAdaptor sliceSizesShape(sliceSizes.getType());
if (failed(verifyGather(location, /*operandShape=*/operandShape,
/*startIndicesShape=*/startIndicesShape,
/*sliceSizesShape=*/sliceSizesShape, offsetDims,
collapsedSliceDims, startIndexMap, indexVectorDim)))
return failure();
auto getSliceDim = [&](int64_t index) {
DenseIntElementsAttr sliceSizesAttr;
if (!matchPattern(sliceSizes, m_Constant(&sliceSizesAttr)))
return ShapedType::kDynamic;
return sliceSizesAttr.getValues<APInt>()[index].getSExtValue();
};
return inferGatherReturnTypeComponents(
location, operandShape, startIndices, getSliceDim, offsetDims,
collapsedSliceDims, startIndexMap, indexVectorDim, inferredReturnShapes);
}
LogicalResult inferDynamicSliceOp(
std::optional<Location> location, Type operandType,
TypeRange startIndicesTypes, ArrayRef<int64_t> sliceSizes,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
// dynamic_slice_c2
int numSliceSizes = sliceSizes.size();
int numStartIndices = startIndicesTypes.size();
if (numStartIndices != numSliceSizes)
return emitOptionalError(location, "has mismatched number of slice sizes (",
numSliceSizes, ") and number of start indices (",
numStartIndices, ")");
auto rankedOperandType = operandType.dyn_cast<RankedTensorType>();
if (!rankedOperandType) return failure();
// dynamic_slice_c2
if (rankedOperandType.getRank() != numStartIndices)
return emitOptionalError(
location, "has mismatched number of start indices (", numStartIndices,
") and the rank of operand (", rankedOperandType.getRank(), ")");
// dynamic_slice_c3
if (!tensorsHaveSameElType(startIndicesTypes))
return emitOptionalError(location,
"start indices must have same element type");
// dynamic_slice_c4
for (int i = 0; i < numSliceSizes; ++i) {
int64_t sliceSize = sliceSizes[i];
if (sliceSize < 0)
return emitOptionalError(
location, "has negative size index to dynamic slice: ", sliceSize);
if (!rankedOperandType.isDynamicDim(i)) {
int64_t dimSize = rankedOperandType.getDimSize(i);
if (sliceSize > dimSize)
return emitOptionalError(location, "has slice size ", sliceSize,
" greater than dimension size ", dimSize,
" in dimension ", i, " of operand");
}
}
// dynamic_slice_c5
inferredReturnShapes.emplace_back(sliceSizes,
rankedOperandType.getElementType());
return success();
}
LogicalResult inferDynamicUpdateSliceOp(
std::optional<Location> location, Value operand, Value update,
ValueRange startIndices,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto operandType = operand.getType().cast<ShapedType>();
auto updateType = update.getType().cast<ShapedType>();
// dynamic_update_slice_c3
if (updateType.hasRank() && operandType.hasRank() &&
updateType.getRank() != operandType.getRank())
return emitOptionalError(
location,
"update rank does not match operand rank: ", updateType.getRank(),
" vs ", operandType.getRank(), ".");
// dynamic_update_slice_c4
if (operandType.hasRank() &&
(int64_t)startIndices.size() != operandType.getRank())
return emitOptionalError(
location, "expects number of start_indices to match operand rank: ",
startIndices.size(), " vs ", operandType.getRank(), ".");
// dynamic_update_slice_c5
if (!tensorsHaveSameElType(startIndices.getTypes()))
return emitOptionalError(location,
"start indices must have same element type");
// dynamic_update_slice_c6
if (operandType.hasRank() && updateType.hasRank())
for (auto [index, dims] : llvm::enumerate(
llvm::zip(operandType.getShape(), updateType.getShape()))) {
auto [operandDim, updateDim] = dims;
if (isDynamicDimSize(updateDim)) continue;
if (isStaticDimSize(operandDim)) {
if (updateDim < 0 || updateDim > operandDim)
return emitOptionalError(location, "expects size at dimension ",
index, " of update to be in range [0, ",
operandDim, "]. Got: ", updateDim, ".");
} else {
if (updateDim < 0)
return emitOptionalError(
location, "expects size at dimension ", index,
" of update to be non-negative. Got: ", updateDim, ".");
}
}
// dynamic_update_slice_c1
if (operandType.hasRank())
inferredReturnShapes.emplace_back(
operandType.getShape(), operandType.getElementType(),
operandType.cast<RankedTensorType>().getEncoding());
else
inferredReturnShapes.emplace_back(operandType.getElementType());
return success();
}
// We intend to verify the following properties
// P1. 1 <= rank <= 3
// P2. Element types agree with fft_type
// P3. Operand shape dimensions agree with fft_length for the given fft_type
LogicalResult inferFftOp(
std::optional<Location> location, Value operand, bool isFftTypeRfft,
bool isFftTypeIrfft, ArrayRef<int64_t> fftLength,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
int64_t fftRank = fftLength.size();
// P1.
if (fftRank > 3 || fftRank < 1)
return emitOptionalError(location, "rank must be between 1 and 3, but got ",
fftRank, ".");
// P2. Element type agreement
// FFT : C -> C
// IFFT : C -> C
// RFFT : R -> C
// IRFFT : C -> R
auto operandType = operand.getType().cast<ShapedType>();
Type operandElementType = operandType.getElementType();
// Check the input element type and infer return element type
if (isFftTypeRfft) {
if (!operandElementType.isF32() && !operandElementType.isF64())
return emitOptionalError(
location, "RFFT requires f32 or f64 input type, but is given ",
operandElementType, ".");
} else {
if (!operandElementType.isa<ComplexType>())
return emitOptionalError(location, "FFT/IFFT/IRFFT",
" take a complex tensor as input, but is given ",
operandType, ".");
}
// Generate the output element type
Type resultElementType = operandElementType;
if (isFftTypeRfft) // RFFT : R -> C
resultElementType = ComplexType::get(resultElementType);
else if (isFftTypeIrfft) // IRFFT : C -> R
resultElementType = operandElementType.cast<ComplexType>().getElementType();
// P3. Check input shape and infer return shape
auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
if (!operandRankedType) {
inferredReturnShapes.emplace_back(resultElementType);
return success();
}
auto operandShape = operandRankedType.getShape();
if (static_cast<int64_t>(operandShape.size()) < fftRank)
return emitOptionalError(
location, "operand rank must not be less than fft rank of ", fftRank,
" for operand of type ", operandRankedType, ".");
SmallVector<int64_t> resultShape = to_vector(operandShape);
if (isFftTypeRfft) {
auto shapeBack = operandShape.take_back(fftRank);
for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) {
if (!verifyCompatibleDims(operandDim, fftDim))
return emitOptionalError(location,
"RFFT requires innermost dimensions to be "
"compatible with fft_length. Got: ",
operandShape, " but wanted ", fftLength, ".");
}
if (fftLength[fftRank - 1] != 0)
resultShape[resultShape.size() - 1] = fftLength[fftRank - 1] / 2 + 1;
}
if (isFftTypeIrfft) {
auto shapeBack = operandShape.take_back(fftRank).drop_back();
for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) {
if (!verifyCompatibleDims(operandDim, fftDim))
return emitOptionalError(location,
"IRFFT requires non-final dimensions to be "
"compatible with fft_length. Got: ",
operandShape, " but wanted ", fftLength,
", and ", operandDim, " != ", fftDim, ".");
}
if ((!verifyCompatibleDims(operandShape[operandShape.size() - 1], 0) ||
fftLength[fftRank - 1] != 0) &&
!verifyCompatibleDims(operandShape[operandShape.size() - 1],
fftLength[fftRank - 1] / 2 + 1))
return emitOptionalError(location,
"IRFFT requires innermost dimension to be "
"compatible with fft_length[-1]/2+1. Got: ",
operandShape[operandShape.size() - 1],
" but fft_length is ", fftLength, ".");
resultShape[resultShape.size() - 1] = fftLength[fftRank - 1];
}
auto resultBounds = encodingToBounds(operandRankedType.getEncoding()).vec();
if ((isFftTypeIrfft || isFftTypeRfft) && !resultBounds.empty())
resultBounds.back() = ShapedType::kDynamic;
inferredReturnShapes.emplace_back(
resultShape, resultElementType,
boundsToEncoding(operandRankedType.getEncoding(), resultBounds));
return success();
}
LogicalResult inferGatherOp(
std::optional<Location> location, Value operand, Value startIndices,
ArrayRef<int64_t> offsetDims, ArrayRef<int64_t> collapsedSliceDims,
ArrayRef<int64_t> startIndexMap, int64_t indexVectorDim,
ArrayRef<int64_t> sliceSizes,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
ShapeAdaptor operandShape(operand.getType());
ShapeAdaptor startIndicesShape(startIndices.getType());
SmallVector<int64_t, 1> ssShape{static_cast<int64_t>(sliceSizes.size())};
ShapedTypeComponents ssSTC{ssShape};
ShapeAdaptor sliceSizesShape(ssSTC);
// For some reason the getType call is necessary here
if (failed(verifyGather(location,
/*operandShape=*/operandShape,
/*startIndicesShape=*/startIndicesShape,
/*sliceSizesShape=*/sliceSizesShape, offsetDims,
collapsedSliceDims, startIndexMap, indexVectorDim)))
return failure();
// gather_c8
for (auto dim : collapsedSliceDims) {
int64_t sliceDimSize = sliceSizes[dim];
if (sliceDimSize > 1)
return emitOptionalError(location, "slice_sizes collapsed dimension ",
dim, " should <= 1 but got ", sliceDimSize);
}
// gather_c12
if (operandShape.hasRank()) {
for (const auto& it : llvm::enumerate(sliceSizes)) {
if (operandShape.isDynamicDim(it.index())) continue;
auto operandDimSize = operandShape.getDimSize(it.index());
auto sliceDimSize = it.value();
if (sliceDimSize < 0 || sliceDimSize > operandDimSize)
return emitOptionalError(location, "slice size (", sliceDimSize,
") is out of bounds for operand dimension (",
operandDimSize, ") at index ", it.index());
}
}
auto getSliceDim = [&sliceSizes](int64_t index) -> int64_t {
return sliceSizes[index];
};
return inferGatherReturnTypeComponents(
location, operandShape, startIndices, getSliceDim, offsetDims,
collapsedSliceDims, startIndexMap, indexVectorDim, inferredReturnShapes);
}
LogicalResult inferGetDimensionSizeOp(
std::optional<Location> location, Type operandType, int64_t dimension,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
// get_dimension_size_c1
if (failed(verifyDimInBounds(location, operandType.cast<ShapedType>(),
dimension)))
return failure();
inferredReturnShapes.emplace_back(
ArrayRef<int64_t>{}, IntegerType::get(operandType.getContext(), 32));
return success();
}
LogicalResult inferGetTupleElementOp(
std::optional<Location> location, Value operand, int32_t index,
SmallVectorImpl<Type>& inferredReturnTypes) {
auto operandType = operand.getType().dyn_cast<TupleType>();
if (!operandType) return failure();
// get_tuple_element_c1
if (index < 0 || index >= static_cast<int64_t>(operandType.size()))
return emitOptionalError(location, "index ", index,
" is out of bounds of operand with size ",
operandType.size());
// get_tuple_element_c2
inferredReturnTypes.push_back(operandType.getType(index));
return success();
}
LogicalResult inferImagOp(std::optional<Location> location, Value operand,
SmallVectorImpl<Type>& inferredReturnTypes) {
// imag_c2
inferredReturnTypes.push_back(
createRealType(operand.getType().cast<ShapedType>()));
return success();
}
LogicalResult inferIsFiniteOp(MLIRContext* context, std::optional<Location>,
Value x,
SmallVectorImpl<Type>& inferredReturnTypes) {
auto argTy = x.getType().cast<ShapedType>();
Builder b(context);
inferredReturnTypes.push_back(getSameShapeTensorType(argTy, b.getI1Type()));
return success();
}
LogicalResult inferIfOp(std::optional<Location> location, Value pred,
RegionRange branches,
SmallVectorImpl<Type>& inferredReturnTypes) {
return inferConditionalOp(location, pred, branches, inferredReturnTypes);
}
LogicalResult inferMapOp(
std::optional<Location> location, ValueRange inputs,
ArrayRef<int64_t> dimensions, Region& computation,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
if (failed(verifyRegionNotEmpty(location, computation))) return failure();
// map_c4
auto& computationBlock = computation.front();
auto computationArgs = computationBlock.getArguments();
if (inputs.size() != computationArgs.size())
return emitOptionalError(location,
"expects number of operands to match the arity of "
"map computation, but got: ",
inputs.size(), " and ", computationArgs.size());
// map_c4
for (const auto& indexedArg : llvm::enumerate(computationArgs)) {
auto argType = indexedArg.value().getType().dyn_cast<RankedTensorType>();
if (!argType || argType.getRank() != 0)
return emitOptionalError(
location,
"computation arguments must be 0-rank tensor, but got: arg #",
indexedArg.index(), " of type ", indexedArg.value().getType());
auto operandElemTy = inputs[indexedArg.index()]
.getType()
.cast<ShapedType>()
.getElementType();
if (argType.getElementType() != operandElemTy)
return emitOptionalError(location,
"element type of operands and computation "
"arguments must match, but got: ",
operandElemTy, " and ",
argType.getElementType());
}
// map_c4
auto computationOutputs = computationBlock.getTerminator()->getOperands();
if (computationOutputs.size() != 1)
return emitOptionalError(location,
"computation must return single output, but got: ",
computationOutputs.size());
// map_c4
auto computationOutputType =
computationOutputs[0].getType().dyn_cast<RankedTensorType>();
if (!computationOutputType || computationOutputType.getRank() != 0)
return emitOptionalError(location,
"computation must return 0-rank tensor, but got: ",
computationOutputs[0].getType());
// map_c3
for (const auto& indexedValue : llvm::enumerate(dimensions)) {
if (indexedValue.value() != static_cast<int64_t>(indexedValue.index()))
return emitOptionalError(
location,
"requires monotonically increasing dimension numbers, but got: ",
dimensions);
}
// map_c3
ArrayRef<int64_t> resultShape;
bool allInputsUnranked = true;
for (auto operand : inputs) {
auto operandType = operand.getType().cast<ShapedType>();
if (operandType.hasRank()) {
if (dimensions.size() != operandType.getShape().size())
return emitOptionalError(
location,
"applied to a subset of dimensions currently not supported: "
"operand dimensions = ",
operandType.getShape().size(),
", requested map dimensions size = ", dimensions.size());
resultShape = operandType.getShape();
allInputsUnranked = false;
}
}
// map_c4
if (allInputsUnranked)
inferredReturnShapes.emplace_back(computationOutputType.getElementType());
else
inferredReturnShapes.emplace_back(resultShape,
computationOutputType.getElementType());
return success();
}
LogicalResult inferOptimizationBarrierOp(
std::optional<Location> location, ValueRange operand,
SmallVectorImpl<Type>& inferredReturnTypes) {
// optimization_barrier_c1
for (auto inputArgType : operand.getTypes())
inferredReturnTypes.emplace_back(inputArgType);
return success();
}
LogicalResult inferOutfeedOp(HloDialectInterface* dialect,
std::optional<Location> location,
SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(dialect->createTokenType());
return success();
}
LogicalResult inferPadOp(std::optional<Location> location, Type operandType,
Type paddingValueType,
ArrayRef<int64_t> edgePaddingLow,
ArrayRef<int64_t> edgePaddingHigh,
ArrayRef<int64_t> interiorPadding,
SmallVectorImpl<Type>& inferredReturnTypes) {
auto inputType = operandType.cast<RankedTensorType>();
auto padType = paddingValueType.cast<RankedTensorType>();
// pad_i2
if (padType.getRank() != 0)
return emitOptionalError(location,
"padding value type should be a rank-0 "
"tensor, is rank ",
padType.getRank());
int64_t rank = inputType.getRank();
// pad_c2
if (static_cast<int64_t>(edgePaddingLow.size()) != rank)
return emitOptionalError(location, "edge_padding_low length (",
edgePaddingLow.size(),
") must match operand rank (", rank, ")");
auto inputShape = inputType.getShape();
SmallVector<int64_t> resultShape(rank, ShapedType::kDynamic);
ArrayRef<int64_t> inputBounds = encodingToBounds(inputType.getEncoding());
SmallVector<int64_t> resultBounds(inputBounds.size(), ShapedType::kDynamic);
for (int i = 0, e = inputShape.size(); i < e; i++) {
int64_t paddingLowVal = edgePaddingLow[i];
int64_t paddingHighVal = edgePaddingHigh[i];
int64_t paddingInteriorVal = interiorPadding[i];
// pad_c3
if (paddingInteriorVal < 0)
return emitOptionalError(
location,
"Interior padding cannot be negative: ", paddingInteriorVal);
bool isStaticDim = !isDynamicDimSize(inputShape[i]);
bool isStaticBound =
!inputBounds.empty() && !isDynamicDimSize(inputBounds[i]);
if (isStaticDim || isStaticBound) {
int64_t operandSizeOrBound = isStaticDim ? inputShape[i] : inputBounds[i];
int64_t resultSizeOrBound =
operandSizeOrBound + paddingLowVal + paddingHighVal +
std::max<int64_t>(operandSizeOrBound - 1, 0ll) * paddingInteriorVal;
// pad_c4
if (resultSizeOrBound < 0) {
auto sizeOrBound = isStaticDim ? "size" : "bound";
return emitOptionalError(location, "Padding result in negative ",
sizeOrBound, " for dimension ", i);
}
(isStaticDim ? resultShape : resultBounds)[i] = resultSizeOrBound;
}
}
// pad_c1
inferredReturnTypes.push_back(RankedTensorType::get(
resultShape, inputType.getElementType(),
boundsToEncoding(inputType.getEncoding(), resultBounds)));
return success();
}
LogicalResult inferPartitionIdOp(MLIRContext* context, std::optional<Location>,
SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(RankedTensorType::get(
/*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
return success();
}
LogicalResult inferRealOp(std::optional<Location>, Value operand,
SmallVectorImpl<Type>& inferredReturnTypes) {
// real_c2
inferredReturnTypes.push_back(
createRealType(operand.getType().cast<ShapedType>()));
return success();
}
LogicalResult inferReduceOp(
std::optional<Location> location, TypeRange inputTypes,
ArrayRef<int64_t> dimensions, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(inputTypes, [](Type t) { return t.cast<ShapedType>(); })};
SmallVector<int64_t> newDimensions;
Attribute encoding;
// reduce_c1, reduce_c4, reduce_c5, reduce_i3
if (failed(verifyReduceOpInputsAndInferShape(
location, inputArgTensorTypes, dimensions, newDimensions, encoding)))
return failure();
// reduce_c3, reduce_c7, reduce_c8
auto accumulatorTypesOrErr = getAccumulatorTypes(location, body);
if (failed(accumulatorTypesOrErr)) return failure();
for (uint64_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
ShapedType inputType = inputArgTensorTypes[inputIdx];
Type elementType = (*accumulatorTypesOrErr)[inputIdx].getElementType();
if (inputType.hasRank())
inferredReturnShapes.emplace_back(newDimensions, elementType, encoding);
else
inferredReturnShapes.emplace_back(elementType);
}
return success();
}
LogicalResult inferReduceWindowOp(
std::optional<Location> location, ValueRange inputs, ValueRange initValues,
ArrayRef<int64_t> windowDimensions,
std::optional<ArrayRef<int64_t>> windowStrides,
std::optional<ArrayRef<int64_t>> baseDilations,
std::optional<ArrayRef<int64_t>> windowDilations,
std::optional<DenseIntElementsAttr> padding, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputTypes{llvm::map_range(
inputs.getTypes(), [](Type t) { return t.cast<ShapedType>(); })};
SmallVector<ShapedType> initValueTypes{llvm::map_range(
initValues.getTypes(), [](Type t) { return t.cast<ShapedType>(); })};
SmallVector<int64_t> windowDims;
SmallVector<WindowDimension> inferredWindow;
// reduce_window_c1, reduce_window_c2, reduce_window_c4...reduce_window_c12,
// reduce_window_i4...reduce_window_i7
if (failed(verifyReduceWindowOpInputsAndInferWindow(
location, inputTypes, initValueTypes, windowDimensions, windowStrides,
baseDilations, windowDilations, padding, windowDims, inferredWindow)))
return failure();
// reduce_window_c1, reduce_window_c14...reduce_window_c16
auto accumulatorTypesOrErr = getAccumulatorTypes(location, body);
if (failed(accumulatorTypesOrErr)) return failure();
for (size_t i = 0; i < inputTypes.size(); ++i) {
auto inputRankedType = inputs[i].getType().dyn_cast<RankedTensorType>();
if (!inputRankedType) {
inferredReturnShapes.emplace_back(
(*accumulatorTypesOrErr)[i].getElementType());
} else {
auto resultShape =
inferWindowOutputShape(inputTypes[i].getShape(), inferredWindow);
auto inputBounds = encodingToBounds(inputRankedType.getEncoding());
if (inputBounds.empty()) {
inferredReturnShapes.emplace_back(
resultShape, (*accumulatorTypesOrErr)[i].getElementType());
} else {
auto resultBounds = inferWindowOutputShape(inputBounds, inferredWindow);
inferredReturnShapes.emplace_back(
resultShape, (*accumulatorTypesOrErr)[i].getElementType(),
boundsToEncoding(inputRankedType.getEncoding(), resultBounds));
}
}
}
return success();
}
LogicalResult inferReplicaIdOp(MLIRContext* context, std::optional<Location>,
SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(RankedTensorType::get(
/*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
return success();
}
LogicalResult inferReverseOp(
std::optional<Location> location, Type operandType,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
return hlo::inferMostSpecificTypeComponents(location, operandType,
inferredReturnShapes);
}
LogicalResult inferRngOp(
std::optional<Location> location, Value a, Value b, Value shape,
bool isRngDistributionUniform,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
if (!isRngDistributionUniform) {
auto muTy = a.getType().cast<ShapedType>().getElementType();
auto sigmaTy = b.getType().cast<ShapedType>().getElementType();
if (!muTy.isa<FloatType>() || !sigmaTy.isa<FloatType>())
return emitOptionalError(location, "mu and sigma must be floats");
}
SmallVector<int64_t> shapeVector;
auto shapeOperandType = shape.getType().cast<ShapedType>();
Type elementType = getElementTypeOrSelf(b);
// Operand `shape` (static 1D by ODS) may be a constant or not, if `shape` is:
// 1. not constant (e.g. tensor<3xi64>): infer tensor<?x?x?x>.
// 2. constant (e.g. dense<[2, 3, 5]>): infer tensor<2x3x5x>.
// Match to check whether the `shape` operand is a constant.
DenseIntElementsAttr shapeAttr;
if (!matchPattern(shape, m_Constant(&shapeAttr))) {
int size = shapeOperandType.getDimSize(0);
shapeVector.resize(size, ShapedType::kDynamic);
inferredReturnShapes.emplace_back(shapeVector, elementType);
return success();
}
// `shape` operand is a constant.
shapeVector.reserve(shapeAttr.size());
for (const APInt& dimSize : shapeAttr.getValues<APInt>())
shapeVector.push_back(dimSize.getSExtValue());
inferredReturnShapes.emplace_back(shapeVector, elementType);
return success();
}
LogicalResult inferScatterOp(std::optional<Location> location,
ValueRange inputs, Region& updateComputation,
SmallVectorImpl<Type>& inferredReturnTypes) {
// scatter_c16, scatter_c17
auto accumulatorTypesOrErr = getAccumulatorTypes(location, updateComputation);
if (failed(accumulatorTypesOrErr)) return failure();
for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) {
auto inputShapedTy = inputs[inputIdx].getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
inputShapedTy, (*accumulatorTypesOrErr)[inputIdx].getElementType()));
}
return success();
}
LogicalResult inferSelectOp(
std::optional<Location> location, Value pred, Value onTrue, Value onFalse,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto predType = pred.getType().cast<ShapedType>();
auto trueType = onTrue.getType().cast<ShapedType>();
auto falseType = onFalse.getType().cast<ShapedType>();
// select_c2
if (!compatibleShapeAndElementType(trueType, falseType))
return emitOptionalError(
location, "requires compatible types for non-predicate operands");
// select_c1
bool predCannotBeScalar = predType.hasRank() && predType.getRank() != 0;
if (predCannotBeScalar)
if (failed(verifyCompatibleShape(predType, trueType)))
return emitOptionalError(location,
"requires the same shape for all operands");
// select_c2
SmallVector<Type> inferredReturnTypes;
return inferMostSpecificTypeComponents(location, {trueType, falseType},
inferredReturnShapes);
}
LogicalResult inferSelectAndScatterOp(
std::optional<Location> location, Value operand, Region& scatter,
SmallVectorImpl<Type>& inferredReturnTypes) {
// select_and_scatter_c11, select_and_scatter_c12
auto accumulatorTypesOrErr = getAccumulatorTypes(location, scatter);
if (failed(accumulatorTypesOrErr)) return failure();
auto operandShapedTy = operand.getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
operandShapedTy, (*accumulatorTypesOrErr)[0].getElementType()));
return success();
}
LogicalResult inferSendOp(HloDialectInterface* dialect,
std::optional<Location> location,
bool isDeviceToDevice, bool isDeviceToHost,
bool isHostTransfer,
SmallVectorImpl<Type>& inferredReturnTypes) {
// send_c1_i4
if (!isHostTransfer && !isDeviceToDevice)
return emitOptionalError(location,
"channel_type should be DEVICE_TO_DEVICE when "
"is_host_transfer is false");
// send_c1_i4
if (isHostTransfer && !isDeviceToHost)
return emitOptionalError(location,
"channel_type should be DEVICE_TO_HOST when "
"is_host_transfer is true");
inferredReturnTypes.push_back(dialect->createTokenType());
return success();
}
LogicalResult inferSetDimensionSizeOp(
HloDialectInterface* dialect, std::optional<Location> location,
Type operandType, Value size, int64_t dimension,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto sizeType = size.getType().dyn_cast<RankedTensorType>();
if (sizeType && sizeType.getRank() != 0)
return emitOptionalError(location, "size operand should be of rank-0");
if (failed(verifyDimInBounds(location, operandType.cast<ShapedType>(),
dimension)))
return failure();
auto inputType = operandType.dyn_cast<RankedTensorType>();
if (!inputType) {
inferredReturnShapes.emplace_back(
operandType.cast<ShapedType>().getElementType());
return success();
}
int64_t rank = inputType.getRank();
if (dimension < 0 || dimension >= rank)
return emitOptionalError(location, "expects dimension to be in range [0, ",
rank, "); got: [", dimension, "].");
auto shape = llvm::to_vector<4>(inputType.getShape());
llvm::SmallVector<int64_t, 4> bounds(rank, ShapedType::kDynamic);
ArrayRef<int64_t> inputBounds = encodingToBounds(inputType.getEncoding());
if (!inputBounds.empty()) bounds = llvm::to_vector<4>(inputBounds);
if (!hlo::isDynamicDimSize(shape[dimension]))
bounds[dimension] = shape[dimension];
shape[dimension] = ShapedType::kDynamic;
DenseIntElementsAttr sizeAttr;
if (matchPattern(size, m_Constant(&sizeAttr))) {
int64_t splat =
sizeAttr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
if (splat == bounds[dimension]) {
shape[dimension] = splat;
bounds[dimension] = ShapedType::kDynamic;
}
}
if (llvm::all_of(bounds, [&](auto b) { return isDynamicDimSize(b); }))
inferredReturnShapes.emplace_back(shape, inputType.getElementType());
else
inferredReturnShapes.emplace_back(shape, inputType.getElementType(),
dialect->createTypeExtensions(bounds));
return success();
}
LogicalResult inferSliceOp(std::optional<Location> location, Type operandType,
ArrayRef<int64_t> startIndices,
ArrayRef<int64_t> limitIndices,
ArrayRef<int64_t> strides,
SmallVectorImpl<Type>& inferredReturnTypes) {
auto rankedTy = operandType.dyn_cast<RankedTensorType>();
if (!rankedTy) {
// The operand type is unranked, so the best we can infer for the result
// type is an unranked tensor with the same element type as the operand
// type.
inferredReturnTypes.assign({operandType});
return success();
}
// slice_c2
int64_t rank = rankedTy.getRank();
if (static_cast<int64_t>(startIndices.size()) != rank)
return emitOptionalError(
location, "the number of elements in start_indices (",
startIndices.size(), ") does not match the rank of the operand (", rank,
")");
ArrayRef<int64_t> inputBounds = encodingToBounds(rankedTy.getEncoding());
SmallVector<int64_t> shape(rank, ShapedType::kDynamic);
SmallVector<int64_t> resultBounds(inputBounds.size(), ShapedType::kDynamic);
for (int64_t i = 0, e = rank; i != e; i++) {
// slice_c3
if (startIndices[i] < 0)
return emitOptionalError(location, "negative start index ",
startIndices[i], " in dimension ", i);
bool isStaticDim = !isDynamicDimSize(rankedTy.getDimSize(i));
bool isStaticBound =
!inputBounds.empty() && !isDynamicDimSize(inputBounds[i]);
if (isStaticDim || isStaticBound) {
int64_t operandSizeOrBound =
isStaticDim ? rankedTy.getDimSize(i) : inputBounds[i];
StringRef sizeOrBound = isStaticDim ? "size" : "bound";
// slice_c3
if (limitIndices[i] > operandSizeOrBound)
return emitOptionalError(location, "limit index ", limitIndices[i],
" is larger than dimension ", sizeOrBound, " ",
operandSizeOrBound, " in dimension ", i);
}
// slice_c3
if (startIndices[i] > limitIndices[i])
return emitOptionalError(location, "start index ", startIndices[i],
" is larger than limit index ", limitIndices[i],
" in dimension ", i);
// slice_c4
if (strides[i] <= 0)
return emitOptionalError(location, "stride must be positive but got ",
strides[i], " in dimension ", i);
// slice_c5
shape[i] = static_cast<int64_t>(
llvm::divideCeil(limitIndices[i] - startIndices[i], strides[i]));
}
// slice_c1
inferredReturnTypes.push_back(RankedTensorType::get(
shape, rankedTy.getElementType(),
boundsToEncoding(rankedTy.getEncoding(), resultBounds)));
return success();
}
LogicalResult inferSortOp(
std::optional<Location>, ValueRange inputs,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
// sort_c2
for (auto resultType : inputs.getTypes()) {
auto rankedResult = resultType.dyn_cast<RankedTensorType>();
if (rankedResult)
inferredReturnShapes.emplace_back(rankedResult.getShape(),
rankedResult.getElementType(),
rankedResult.getEncoding());
else
inferredReturnShapes.emplace_back(resultType.cast<ShapedType>());
}
return success();
}
LogicalResult inferTopKOp(
std::optional<Location> location, Value operand, int64_t k,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
Builder builder(operand.getContext());
auto operandType = operand.getType().dyn_cast<RankedTensorType>();
if (!operandType) {
inferredReturnShapes.emplace_back(
operand.getType().cast<ShapedType>().getElementType());
inferredReturnShapes.emplace_back(builder.getI32Type());
return success();
}
if (operandType.getRank() < 1)
return emitOptionalError(location, "operand's rank must be at least 1");
auto operandLastDim = operandType.getRank() - 1;
if (!operandType.isDynamicDim(operandLastDim) &&
operandType.getDimSize(operandLastDim) < k)
return emitOptionalError(location,
"operand's last dimension must be at least ", k);
SmallVector<int64_t> resultShape(operandType.getShape());
resultShape[operandLastDim] = k;
SmallVector<int64_t> resultBounds(
encodingToBounds(operandType.getEncoding()));
if (!resultBounds.empty())
resultBounds[operandLastDim] = ShapedType::kDynamic;
inferredReturnShapes.emplace_back(
resultShape, operandType.getElementType(),
hlo::boundsToEncoding(operandType.getEncoding(), resultBounds));
inferredReturnShapes.emplace_back(
resultShape, builder.getI32Type(),
hlo::boundsToEncoding(operandType.getEncoding(), resultBounds));
return success();
}
LogicalResult inferTransposeOp(std::optional<Location> loc, Value operand,
ArrayRef<int64_t> permutation,
SmallVectorImpl<Type>& inferredReturnTypes) {
auto type = operand.getType();
auto rankedTy = type.dyn_cast<RankedTensorType>();
if (!rankedTy) {
inferredReturnTypes.emplace_back(type);
return success();
}
int64_t rank = rankedTy.getRank();
if (static_cast<int64_t>(permutation.size()) != rank)
return emitOptionalError(loc, "TransposeOp operand rank ", rank,
" does not match permutation size ",
permutation.size());
std::vector<int64_t> range(rank);
std::iota(range.begin(), range.end(), 0);
if (!std::is_permutation(range.begin(), range.end(), permutation.begin()))
return emitOptionalError(loc,
"attribute permutation must be a permutation"
" of [",
range, "] but got ", permutation);
ArrayRef<int64_t> inputBounds = encodingToBounds(rankedTy.getEncoding());
SmallVector<int64_t> resultShape;
SmallVector<int64_t> resultBounds;
ArrayRef<int64_t> inputShape = rankedTy.getShape();
for (int64_t dim : permutation) {
resultShape.push_back(inputShape[dim]);
if (!inputBounds.empty()) resultBounds.push_back(inputBounds[dim]);
}
inferredReturnTypes.push_back(RankedTensorType::get(
resultShape, rankedTy.getElementType(),
boundsToEncoding(rankedTy.getEncoding(), resultBounds)));
return success();
}
LogicalResult inferTriangularSolveOp(
std::optional<Location> location, Value a, Value b, bool leftSide,
bool isTransposeAInvalid,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
// ODS enforces that a and b are of same element type: float or complex.
auto elementType = a.getType().cast<ShapedType>().getElementType();
auto aType = a.getType().dyn_cast<RankedTensorType>();
if (!aType) {
inferredReturnShapes.emplace_back(elementType);
return success();
}
auto aRank = aType.getRank();
if (aRank < 2)
return emitOptionalError(
location, "operand 'a' must have rank >= 2, but got ", aType);
if (!verifyCompatibleDims(aType.getDimSize(aRank - 2),
aType.getDimSize(aRank - 1)))
return emitOptionalError(location,
"two minor dimensions of operand 'a' must ",
"be compatible, but got ", aType);
auto bType = b.getType().dyn_cast<RankedTensorType>();
if (!bType) {
inferredReturnShapes.emplace_back(elementType);
return success();
}
auto bRank = bType.getRank();
if (aRank != bRank)
return emitOptionalError(location,
"operands must have equal rank, but got ", aType,
" and ", bType);
if (!verifyCompatibleDims(aType.getDimSize(aRank - 1),
bType.getDimSize(bRank - (leftSide ? 2 : 1))))
return emitOptionalError(location,
"shared dimension of operands 'a' and 'b' must ",
"be compatible, but got ", aType, " and ", bType);
auto aBatchDims = aType.getShape().drop_back(2);
auto bBatchDims = bType.getShape().drop_back(2);
if (failed(verifyCompatibleShape(aBatchDims, bBatchDims)))
return emitOptionalError(location, "batch dimensions of the operands must ",
"be compatible, but got ", aType, " and ", bType);
if (isTransposeAInvalid)
return emitOptionalError(
location, "Invalid transpose option value for triangular solve");
inferredReturnShapes.emplace_back(bType.getShape(), bType.getElementType(),
bType.getEncoding());
return success();
}
LogicalResult inferTupleOp(MLIRContext* context, std::optional<Location>,
ValueRange val,
SmallVectorImpl<Type>& inferredReturnTypes) {
// tuple_c1
inferredReturnTypes.push_back(TupleType::get(context, val.getTypes()));
return success();
}
LogicalResult inferUniformDequantizeOp(
std::optional<Location> location, Value operand,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto operandType = operand.getType().cast<ShapedType>();
// Trait HLO_QuantizedIntTensor in ODS guarantees QuantizedType;
auto quantType = operandType.getElementType().cast<quant::QuantizedType>();
auto shape = operandType.cast<ShapedType>().getShape();
// uniform_dequantize_c1, uniform_dequantize_c2
inferredReturnShapes.emplace_back(shape, quantType.getExpressedType());
return success();
}
LogicalResult inferUniformQuantizeOp(
std::optional<Location> location, Value operand,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto operandType = operand.getType().cast<ShapedType>();
// uniform_quantize_c1
inferredReturnShapes.emplace_back(
operandType.hasRank() ? operandType.getShape() : ArrayRef<int64_t>{});
return success();
}
LogicalResult inferWhileOp(std::optional<Location>, ValueRange operand,
SmallVectorImpl<Type>& inferredReturnTypes) {
// while_c3
for (const auto& resultType : operand.getType())
inferredReturnTypes.push_back(resultType);
return success();
}
//===----------------------------------------------------------------------===//
// Verifiers for ops.
//===----------------------------------------------------------------------===//
LogicalResult verifyAllGatherOp(std::optional<Location> location, Value operand,
int64_t allGatherDim,
DenseIntElementsAttr replicaGroups,
int64_t channelId, bool useGlobalDeviceIds,
Value result) {
auto operandType = operand.getType().dyn_cast<RankedTensorType>();
auto resultType = result.getType().dyn_cast<RankedTensorType>();
// all_gather_c1
if (allGatherDim < 0)
return emitOptionalError(location, "all_gather_dim cannot be negative");
if (operandType) {
// all_gather_c1
if (allGatherDim >= operandType.getRank())
return emitOptionalError(
location, "all_gather_dim must be a valid index of operand");
// TODO(#1745): Sync verification of AllGather with HLO.
if (operandType.getDimSize(allGatherDim) == 0)
return emitOptionalError(
location,
"dimension size of operand at 'all_gather_dim' cannot be zero");
}
// all_gather_i3, all_gather_c2, all_gather_c4
if (failed(verifyReplicaGroups(location, replicaGroups,
/*allGroupsMustHaveSameSize=*/true,
useGlobalDeviceIds,
/*expectedGroupSize=*/std::nullopt)))
return failure();
// all_gather_c5
if (useGlobalDeviceIds && channelId < 0)
return emitOptionalError(
location,
"channel_id cannot be negative when useGlobalDeviceIds is set");
// all_gather_c6
if (operandType && resultType) {
if (resultType.getRank() != operandType.getRank())
return emitOptionalError(location,
"operand and result must have the same rank");
for (int64_t i = 0; i < operandType.getRank(); i++) {
if (i == allGatherDim) continue;
// all_gather_c6
if (!verifyCompatibleDims(resultType.getDimSize(i),
operandType.getDimSize(i)))
return emitOptionalError(
location,
"operand and result should have the same shape except for the "
"dimension size at 'all_gather_dim'");
}
if (operandType.isDynamicDim(allGatherDim) ||
resultType.isDynamicDim(allGatherDim))
return success();
// all_gather_c6
if ((resultType.getDimSize(allGatherDim) %
operandType.getDimSize(allGatherDim)) != 0)
return emitOptionalError(
location, "result gather dimension has size ",
resultType.getDimSize(allGatherDim),
", expected to be a multiple of operand gather dimension size ",
operandType.getDimSize(allGatherDim));
}
return success();
}
LogicalResult verifyAllReduceOp(std::optional<Location> location, Value operand,
DenseIntElementsAttr replicaGroups,
int64_t channelId, bool useGlobalDeviceIds,
Region& computation) {
// TODO(#498): AllReduceOp does not have rank-2 replicaGroups.
// all_reduce_c1...all_reduce_c3
if (failed(verifyReplicaGroups(location, replicaGroups,
/*allGroupsMustHaveSameSize=*/false,
useGlobalDeviceIds,
/*expectedGroupSize=*/std::nullopt)))
return failure();
// all_reduce_c4
if (useGlobalDeviceIds && channelId <= 0)
return emitOptionalError(
location,
"channel_id must be positive when useGlobalDeviceIds is set but got: ",
channelId);
auto operandType = operand.getType().cast<ShapedType>();
// all_reduce_c5
if (failed(verifyReducerShape(
location, computation.front(), {operandType},
{RankedTensorType::get({}, operandType.getElementType())},
/*allowedDimensions=*/{})))
return failure();
return success();
}
LogicalResult verifyBitcastConvertOp(std::optional<Location> location,
Value operand, Value result) {
auto operandShapedType = operand.getType().cast<ShapedType>();
auto targetShapedType = result.getType().cast<ShapedType>();
// bitcast_convert_c2
auto targetElt = targetShapedType.getElementType();
auto operandElt = operandShapedType.getElementType();
if (targetElt.isa<ComplexType>() != operandElt.isa<ComplexType>())
return emitOptionalError(
location, "cannot convert between real and complex types, but got: ",
operandShapedType, " and ", targetShapedType);
auto targetEltBitWidth = getBitWidth(targetElt);
auto operandEltBitWidth = getBitWidth(operandElt);
auto operandType = operandShapedType.dyn_cast<RankedTensorType>();
auto targetType = targetShapedType.dyn_cast<RankedTensorType>();
if (!operandType || !targetType) return success();
auto targetShape = targetType.getShape();
auto operandShape = operandType.getShape();
ArrayRef<int64_t> smallerEltShape, biggerEltShape;
if (operandEltBitWidth < targetEltBitWidth) {
smallerEltShape = operandShape;
biggerEltShape = targetShape;
} else {
smallerEltShape = targetShape;
biggerEltShape = operandShape;
}
ArrayRef<int64_t> smallerEltPrefix;
auto smallerEltBitWidth = std::min(targetEltBitWidth, operandEltBitWidth);
auto biggerEltBitWidth = std::max(targetEltBitWidth, operandEltBitWidth);
// bitcast_convert_c1
if (operandEltBitWidth != targetEltBitWidth) {
if (smallerEltShape.size() != biggerEltShape.size() + 1) {
return emitOptionalError(
location, "rank of smaller element type (", smallerEltShape.size(),
") should be 1 more than rank of larger element type (",
biggerEltShape.size(), "), but ", smallerEltShape.size(),
" != ", biggerEltShape.size(), " + 1.");
}
smallerEltPrefix = smallerEltShape.drop_back();
if (!isDynamicDimSize(smallerEltShape.back()) &&
smallerEltShape.back() * smallerEltBitWidth != biggerEltBitWidth) {
return emitOptionalError(
location, "requires compatible bit widths. ", "Got: ", operandType,
" and ", targetType, ", but ", smallerEltBitWidth, " * ",
smallerEltShape.back(), " != ", biggerEltBitWidth, ".");
}
} else {
smallerEltPrefix = smallerEltShape;
}
for (auto it : llvm::zip(smallerEltPrefix, biggerEltShape)) {
auto targetDim = std::get<0>(it);
auto operandDim = std::get<1>(it);
// bitcast_convert_c1
if (!verifyCompatibleDims(targetDim, operandDim))
return emitOptionalError(location,
"operand and result shapes must match except "
"for the innermost dimension of the shape with "
"the smaller element type. Got: ",
operandType, " and ", targetType, ".");
}
return success();
}
LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
Value operand,
ArrayRef<int64_t> broadcastDimensions,
Value result) {
auto operandType = operand.getType().dyn_cast<RankedTensorType>();
if (!operandType) {
// The following verification checks all depend on knowing the rank of
// the operand. Bail out now if we don't know the rank of the operand.
return success();
}
// broadcast_in_dim_c2
auto dimensionsSize = broadcastDimensions.size();
auto operandRank = operandType.getRank();
if (static_cast<int64_t>(dimensionsSize) != operandRank)
return emitOptionalError(location, "broadcast_dimensions size (",
dimensionsSize, ") does not match operand rank (",
operandRank, ")");
// broadcast_in_dim_c4
if (!isUnique(broadcastDimensions))
return emitOptionalError(location,
"broadcast_dimensions should not have duplicates");
auto resultType = result.getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank();
for (size_t i = 0; i != dimensionsSize; ++i) {
auto dimIndex = broadcastDimensions[i];
// broadcast_in_dim_c3
if (dimIndex < 0 || dimIndex >= resultRank)
return emitOptionalError(location,
"broadcast_dimensions contains invalid value ",
dimIndex, " for result with rank ", resultRank);
if (!operandType.isDynamicDim(i)) {
auto dimSize = operandType.getDimSize(i);
auto resultDimSize = resultType.getDimSize(dimIndex);
// broadcast_in_dim_c5
if (dimSize != 1 && dimSize != resultDimSize)
return emitOptionalError(
location, "size of operand dimension ", i, " (", dimSize,
") is not equal to 1 or size of result dimension ", dimIndex, " (",
resultDimSize, ")");
}
}
return success();
}
LogicalResult verifyCollectiveBroadcastOp(std::optional<Location> location,
DenseIntElementsAttr replicaGroups) {
// collective_permute_i2
auto replicaGroupType = replicaGroups.getType().cast<RankedTensorType>();
if (replicaGroupType.getRank() != 2)
return emitOptionalError(
location, "replica groups should be a rank 2 tensor,",
"but instead it is of rank ", replicaGroupType.getRank());
auto replicaIds = replicaGroups.getValues<int64_t>();
llvm::SmallSet<int64_t, 8> replicaIdsSeen;
for (int64_t replicaId : replicaIds) {
// collective_broadcast_c2
// We only check that is is not negative, as it is impossible
// to statically know `num_replicas` or `num_partitions`
if (replicaId < 0)
return emitOptionalError(
location, "replica_groups values must be positive, but was given ",
replicaId);
// collective_broadcast_c1
if (!replicaIdsSeen.insert(replicaId).second)
return emitOptionalError(location, "replica id #", replicaId,
" seen more than once");
}
return success();
}
LogicalResult verifyCollectivePermuteOp(
std::optional<Location> location, DenseIntElementsAttr sourceTargetPairs) {
auto type = sourceTargetPairs.getType().dyn_cast<RankedTensorType>();
// collective_permute_i2
if (type.getRank() != 2)
return emitOptionalError(location,
"expect source_target_pairs attribute to be of "
"rank 2, but got rank ",
type.getRank());
// collective_permute_c1
if (type.getShape()[1] != 2)
return emitOptionalError(
location,
"expect source_target_pairs attribute of shape (N, 2), but got (",
type.getShape(), ")");
llvm::DenseSet<int64_t> sources;
llvm::DenseSet<int64_t> targets;
for (auto i = sourceTargetPairs.begin(), e = sourceTargetPairs.end(); i != e;
++i) {
auto val = (*i).getSExtValue();
// collective_permute_c4
if (val < 0)
return emitOptionalError(
location, "replica ids in source_target_pairs must be >= 0.");
if (i.getIndex() % 2 == 0) {
bool isUnique = sources.insert(val).second;
// collective_permute_c2
if (!isUnique)
return emitOptionalError(location, "duplicate sources not allowed.");
} else {
bool isUnique = targets.insert(val).second;
// collective_permute_c3
if (!isUnique)
return emitOptionalError(location, "duplicate targets not allowed.");
}
}
return success();
}
LogicalResult verifyCompositeOp(std::optional<Location> loc, Operation* op,
StringRef name, StringRef decomposition,
SymbolTableCollection& symbolTable) {
// composite_c1
auto nameRegexString = "^[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+$";
llvm::Regex nameRegex(nameRegexString);
if (!nameRegex.match(name))
return emitOptionalError(loc,
"name must be a valid namespaced op name, i.e. it "
"must match the following regular expression: ",
nameRegexString, " e.g. \"my_namespace.my_op\"");
// composite_c2
auto decomp = symbolTable.lookupNearestSymbolFrom<mlir::func::FuncOp>(
op, StringAttr::get(op->getContext(), decomposition));
if (!decomp) {
return emitOptionalError(loc, "'", decomposition,
"' does not reference a valid function");
}
auto decompFunType = decomp.getFunctionType();
// composite_c3
auto types = op->getOperandTypes();
auto decompTypes = decompFunType.getInputs();
if (types.size() != decompTypes.size()) {
return emitOptionalError(loc, "has ", types.size(),
" operand(s), but decomposition has ",
decompTypes.size());
}
for (size_t i = 0; i < types.size(); i++) {
if (types[i] != decompTypes[i]) {
return emitOptionalError(loc, "operand at index ", i, " has type ",
types[i], ", but decomposition has type ",
decompTypes[i]);
}
}
// composite_c4
auto resTypes = op->getResultTypes();
auto decompResTypes = decompFunType.getResults();
if (resTypes.size() != decompResTypes.size()) {
return emitOptionalError(loc, "has ", resTypes.size(),
" result(s), but decomposition has ",
decompResTypes.size());
}
for (size_t i = 0; i < resTypes.size(); i++) {
if (resTypes[i] != decompResTypes[i]) {
return emitOptionalError(loc, "result at index ", i, " has type ",
resTypes[i], ", but decomposition has type ",
decompResTypes[i]);
}
}
return success();
}
LogicalResult verifyConvolutionOp(
std::optional<Location> location, Type lhsType, Type rhsType,
std::optional<ArrayRef<int64_t>> windowStrides,
std::optional<DenseIntElementsAttr> padding,
std::optional<ArrayRef<int64_t>> lhsDilation,
std::optional<ArrayRef<int64_t>> rhsDilation,
std::optional<ArrayRef<bool>> windowReversal, int64_t inputBatchDimension,
int64_t inputFeatureDimension, ArrayRef<int64_t> inputSpatialDimensions,
int64_t kernelInputFeatureDimension, int64_t kernelOutputFeatureDimension,
ArrayRef<int64_t> kernelSpatialDimensions, int64_t outputBatchDimension,
int64_t outputFeatureDimension, ArrayRef<int64_t> outputSpatialDimensions,
int64_t featureGroupCount, int64_t batchGroupCount,
std::optional<ArrayAttr> precisionConfig, Type resultType) {
SmallVector<ShapedTypeComponents> inferredReturnShapes;
if (failed(inferConvolutionOp(
location, lhsType, rhsType, windowStrides, padding, lhsDilation,
rhsDilation, windowReversal, inputBatchDimension,
inputFeatureDimension, inputSpatialDimensions,
kernelInputFeatureDimension, kernelOutputFeatureDimension,
kernelSpatialDimensions, outputBatchDimension, outputFeatureDimension,
outputSpatialDimensions, featureGroupCount, batchGroupCount,
precisionConfig, inferredReturnShapes)))
return failure();
auto inferredShape = inferredReturnShapes[0];
auto shapedResultType = resultType.cast<ShapedType>();
if (inferredShape.hasRank() && shapedResultType.hasRank() &&
failed(verifyCompatibleShape(inferredShape.getDims(),
shapedResultType.getShape())))
return emitOptionalError(location, "inferred shape '",
dimSizesToString(inferredShape.getDims()), "' ",
"is incompatible with return type of operation ",
shapedResultType, "");
llvm::SmallVector<Type, 3> typeEntries{lhsType, rhsType, resultType};
if (noneQuantized<quant::QuantizedType>(typeEntries)) return success();
// convolution_c28
if (!allQuantized<quant::QuantizedType>(typeEntries)) {
return emitOptionalError(location,
"not all of operands and result are quantized");
}
auto lhsQType =
getElementTypeOrSelf(lhsType).dyn_cast<quant::QuantizedType>();
auto rhsQType =
getElementTypeOrSelf(rhsType).dyn_cast<quant::QuantizedType>();
auto resultQType =
getElementTypeOrSelf(resultType).dyn_cast<quant::QuantizedType>();
// convolution_c29
if (lhsQType.getStorageType() != rhsQType.getStorageType())
return emitOptionalError(location, "mismatched operand storage types ",
lhsQType.getStorageType(), " and ",
rhsQType.getStorageType());
// convolution_c30
auto expressedType = lhsQType.getExpressedType();
if (expressedType != rhsQType.getExpressedType() ||
expressedType != resultQType.getExpressedType())
return emitOptionalError(location,
"mismatched operands and result expressed types");
llvm::SmallVector<Type, 2> typeEntriesPerAxis{rhsType, resultType};
if (noneQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis))
return success();
// convolution_c31
if (!allQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis)) {
return emitOptionalError(location,
"rhs and result are of mixed per_tensor and "
"per_axis quantized tensor type ",
rhsType, " and ", resultType);
}
auto rhsQPAType = rhsQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
auto resultQPAType =
resultQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
// convolution_c32
if (rhsQPAType &&
rhsQPAType.getQuantizedDimension() != kernelOutputFeatureDimension)
return emitOptionalError(
location, "mismatched kernel_output_feature_dimension ",
kernelOutputFeatureDimension, " and rhs quantized dimension ",
rhsQPAType.getQuantizedDimension());
// convolution_c33
if (resultQPAType &&
resultQPAType.getQuantizedDimension() != outputFeatureDimension)
return emitOptionalError(location, "mismatched output_feature_dimension ",
outputFeatureDimension,
" and result quantized dimension ",
resultQPAType.getQuantizedDimension());
return success();
}
LogicalResult verifyDotOp(std::optional<Location> location,
RankedTensorType lhsType, RankedTensorType rhsType,
std::optional<ArrayAttr> precisionConfig,
Value result) {
SmallVector<ShapedTypeComponents> inferredReturnShapes;
if (failed(inferDotOp(location, lhsType, rhsType, precisionConfig,
inferredReturnShapes)))
return failure();
auto inferredShape = inferredReturnShapes[0];
auto resultType = result.getType().cast<ShapedType>();
if (inferredShape.hasRank() && resultType.hasRank() &&
failed(verifyCompatibleShape(inferredShape.getDims(),
resultType.getShape())))
return emitOptionalError(
location, "inferred shape '", dimSizesToString(inferredShape.getDims()),
"' ", "is incompatible with return type of operation ", resultType, "");
return success();
}
LogicalResult verifyDotGeneralOp(std::optional<Location> location, Value lhs,
Value rhs,
ArrayRef<int64_t> lhsBatchingDimensions,
ArrayRef<int64_t> rhsBatchingDimensions,
ArrayRef<int64_t> lhsContractingDimensions,
ArrayRef<int64_t> rhsContractingDimensions,
std::optional<ArrayAttr> precisionConfig,
Value result) {
SmallVector<ShapedTypeComponents> inferredReturnShapes;
if (failed(inferDotGeneralOp(
location, lhs.getType(), rhs.getType(), lhsBatchingDimensions,
rhsBatchingDimensions, lhsContractingDimensions,
rhsContractingDimensions, precisionConfig, inferredReturnShapes)))
return failure();
auto inferredShape = inferredReturnShapes[0];
auto resultType = result.getType().cast<ShapedType>();
if (inferredShape.hasRank() && resultType.hasRank() &&
failed(verifyCompatibleShape(inferredShape.getDims(),
resultType.getShape())))
return emitOptionalError(
location, "inferred shape '", dimSizesToString(inferredShape.getDims()),
"' ", "is incompatible with return type of operation ", resultType, "");
return success();
}
LogicalResult verifyDynamicBroadcastInDimOp(
std::optional<Location> location, Value operand, Value outputDimensions,
ArrayRef<int64_t> broadcastDimensions,
std::optional<ArrayRef<int64_t>> knownExpandingDimensions,
std::optional<ArrayRef<int64_t>> knownNonexpandingDimensions,
Value result) {
auto operandType = operand.getType().dyn_cast<RankedTensorType>();
auto resultType = result.getType().cast<RankedTensorType>();
auto outputDimensionsType =
outputDimensions.getType().cast<RankedTensorType>();
auto outputDimensionsSize = outputDimensionsType.getDimSize(0);
auto resultRank = resultType.getRank();
// Verify broadcast_dimensions.
auto bcastDimensions = broadcastDimensions;
int64_t bcastDimensionsSize = bcastDimensions.size();
if (operandType) {
auto operandRank = operandType.getRank();
if (bcastDimensionsSize != operandRank)
return emitOptionalError(
location, "broadcast_dimensions size (", bcastDimensionsSize,
") does not match operand rank (", operandRank, ")");
if (resultRank < operandRank)
return emitOptionalError(location, "result rank (", resultRank,
") is less than operand rank (", operandRank,
")");
for (int i = 0; i != bcastDimensionsSize; ++i) {
auto dimIndex = bcastDimensions[i];
if (dimIndex < 0 || dimIndex >= resultRank)
return emitOptionalError(
location, "broadcast_dimensions contains invalid value ", dimIndex,
" for result with rank ", resultRank);
auto dimSize = operandType.getDimSize(i);
auto resultDimSize = resultType.getDimSize(dimIndex);
// Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so
// we add a manual check for this.
if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize)))
return emitOptionalError(location, "size of operand dimension ", i,
" (", dimSize,
") is not compatible "
"with size of result dimension ",
dimIndex, " (", resultDimSize, ")");
}
}
if (outputDimensionsSize != resultRank)
return emitOptionalError(location, "result rank (", resultRank,
") is not equal to number of output dimensions (",
outputDimensionsSize, ")");
// Verify that the known expanding and non-expanding dimensions are a subset
// of the operand's dimensions.
int64_t numKnownExpansionBehavior = 0;
DenseSet<int64_t> knownExpansionBehavior;
auto collectExpansionBehaviorDims =
[&](const std::optional<ArrayRef<int64_t>>& attr) {
if (!attr) return;
for (const auto& i : attr.value()) {
numKnownExpansionBehavior++;
knownExpansionBehavior.insert(i);
}
};
collectExpansionBehaviorDims(knownExpandingDimensions);
collectExpansionBehaviorDims(knownNonexpandingDimensions);
if (knownExpansionBehavior.size() != numKnownExpansionBehavior)
return emitOptionalError(
location,
"duplicate expansion hint for at least one operand dimension");
for (int64_t i : knownExpansionBehavior)
if (operandType && (i < 0 || i >= operandType.getRank()))
return emitOptionalError(location, "hint for expanding dimension ", i,
" does not refer to a "
"valid operand dimension");
if (!isCompatibleForHloTypeInference(outputDimensions, resultType))
return emitOptionalError(
location,
"output_dimensions are incompatible with return type of operation ",
resultType);
return success();
}
LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
Value outputShape, int64_t iotaDimension,
Value result) {
auto shape = result.getType().cast<ShapedType>();
if (!isCompatibleForHloTypeInference(outputShape, shape))
return emitOptionalError(
location, "output_shape is incompatible with return type of operation ",
result.getType());
if (iotaDimension >= shape.getRank() || iotaDimension < 0)
return emitOptionalError(
location,
"iota dimension cannot go beyond the output rank or be negative.");
return success();
}
LogicalResult verifyDynamicPadOp(std::optional<Location> location,
Value operand, Value paddingValue,
Value edgePaddingLow, Value edgePaddingHigh,
Value interiorPadding, Value result) {
auto inputType = operand.getType().dyn_cast<RankedTensorType>();
// If operand is unranked, there is very little to verify statically.
if (!inputType) return success();
int inputRank = inputType.getRank();
auto padType = paddingValue.getType().cast<RankedTensorType>();
if (padType.getRank() != 0)
return emitOptionalError(location, "padding value type should be a rank-0");
auto paddingLowType = edgePaddingLow.getType().cast<RankedTensorType>();
if (paddingLowType.getNumElements() != inputRank)
return emitOptionalError(location, "edge_padding_low length(",
paddingLowType.getNumElements(),
") must match operand rank(", inputRank, ").");
auto paddingHighType = edgePaddingHigh.getType().cast<RankedTensorType>();
if (paddingHighType.getNumElements() != inputRank)
return emitOptionalError(location, "edge_padding_high length(",
paddingHighType.getNumElements(),
") must match operand rank(", inputRank, ").");
auto interiorPaddingType = interiorPadding.getType().cast<RankedTensorType>();
if (interiorPaddingType.getNumElements() != inputRank)
return emitOptionalError(location, "edge_padding_interior length(",
interiorPaddingType.getNumElements(),
") must match operand rank(", inputRank, ").");
auto outputType = result.getType().dyn_cast<RankedTensorType>();
// If result is unranked, there is very little to verify statically.
if (!outputType) return success();
int outputRank = outputType.getRank();
if (inputRank != outputRank)
return emitOptionalError(location, "operand rank(", inputRank,
") must match result(", outputRank, ").");
return success();
}
LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
Value outputShape, Value result) {
auto resultType = result.getType().cast<ShapedType>();
auto outputShapeType = outputShape.getType().cast<ShapedType>();
if (outputShapeType.getDimSize(0) != resultType.getRank())
return emitOptionalError(location,
"output should have a rank equal to the number of "
"elements in output_shape");
if (!isCompatibleForHloTypeInference(outputShape, resultType))
return emitOptionalError(
location, "output_shape is incompatible with return type of operation ",
resultType);
return success();
}
LogicalResult verifyInfeedOp(HloDialectInterface* dialect,
std::optional<Location> location,
std::optional<ArrayAttr> layout,
ValueRange results) {
auto resultTypes = results.getType();
// infeed_c1
if (resultTypes.empty())
return emitOptionalError(
location, "result is expected to be at least of size 1, but got ",
resultTypes.size());
// infeed_c2
for (auto resultType : results.drop_back().getTypes())
if (!resultType.isa<TensorType>())
return emitOptionalError(
location,
"all elements of result types, except the last element, are expected "
"to be of tensor type, but got ",
resultType);
// infeed_c3
if (!dialect->isTokenType(results.back().getType()))
return emitOptionalError(location,
"last element of result types is expected to "
"be of token type, but got ",
results.back().getType());
if (!layout.has_value()) return success();
if (!layout.value())
return emitOptionalError(location,
"layout-attribute expected to be of array-type.");
if (layout.value().size() != resultTypes.size() - 1)
return emitOptionalError(location, "layout-attribute size must be ",
resultTypes.size() - 1,
" (which is the number of "
"op-results - 1 (for token result)), but got ",
layout.value().size());
for (auto childLayout : layout.value()) {
mlir::ArrayAttr childLayoutArr = childLayout.dyn_cast<mlir::ArrayAttr>();
if (!childLayoutArr)
return emitOptionalError(location,
"layout-attribute expected to have "
"elements of type array, but got ",
childLayout);
for (auto i : childLayoutArr) {
mlir::IntegerAttr attr = i.dyn_cast<mlir::IntegerAttr>();
if (!attr)
return emitOptionalError(location,
"layout-attribute's leaf elements are "
"expected to be of type integer, but got ",
i);
}
}
return success();
}
LogicalResult verifyIotaOp(std::optional<Location> location,
int64_t iotaDimension, Value result) {
auto shape = result.getType().cast<ShapedType>();
if (!shape.hasRank()) return success();
if (shape.getRank() == 0)
return emitOptionalError(location, "does not support scalars.");
if (iotaDimension >= shape.getRank() || iotaDimension < 0)
return emitOptionalError(
location,
"iota dimension cannot go beyond the output rank or be negative.");
return success();
}
// Verifies that operand rank matches start_indices/limit_indices/strides size
LogicalResult verifyRealDynamicSliceOp(std::optional<Location> location,
Value operand, Value startIndices,
Value limitIndices, Value strides) {
auto inputType = operand.getType().dyn_cast<RankedTensorType>();
// If operand is unranked, there is very little to verify statically.
if (!inputType) return success();
int inputRank = inputType.getRank();
auto startType = startIndices.getType().cast<RankedTensorType>();
auto limitType = limitIndices.getType().cast<RankedTensorType>();
auto stridesType = strides.getType().cast<RankedTensorType>();
if (inputRank != startType.getNumElements())
return emitOptionalError(
location, "has mismatched number of operand rank (", inputRank,
") and start_indices size (", startType.getNumElements(), ")");
if (inputRank != limitType.getNumElements())
return emitOptionalError(
location, "has mismatched number of operand rank (", inputRank,
") and limit_indices size (", limitType.getNumElements(), ")");
if (inputRank != stridesType.getNumElements())
return emitOptionalError(
location, "has mismatched number of operand rank (", inputRank,
") and strides size (", stridesType.getNumElements(), ")");
return success();
}
LogicalResult verifyRecvOp(HloDialectInterface* dialect,
std::optional<Location> location,
bool isDeviceToDevice, bool isHostToDevice,
bool isHostTransfer, ValueRange results) {
// recv_c1_i3
if (!isHostTransfer && !isDeviceToDevice)
return emitOptionalError(location,
"channel_type should be DEVICE_TO_DEVICE when "
"is_host_transfer is false");
// recv_c1_i3
if (isHostTransfer && !isHostToDevice)
return emitOptionalError(location,
"channel_type should be HOST_TO_DEVICE when "
"is_host_transfer is true");
// recv_c2
if (results.empty())
return emitOptionalError(
location, "result is expected to be at least of size 1, but got ",
results.size());
// recv_c3
for (auto resultType : results.drop_back().getTypes())
if (!resultType.isa<TensorType>())
return emitOptionalError(
location,
"everything but the last element of result types is expected to be "
"of tensor type, but got ",
resultType);
// recv_c4
if (!dialect->isTokenType(results.back().getType()))
return emitOptionalError(location,
"last element of result types is expected to "
"be of token type, but got ",
results.back().getType());
return success();
}
LogicalResult verifyReduceOp(std::optional<Location> location,
ValueRange inputs, ValueRange initValues,
ArrayRef<int64_t> dimensions, Region& body) {
SmallVector<ShapedType> inputTypes{llvm::map_range(
inputs.getTypes(), [](Type t) { return t.cast<ShapedType>(); })};
SmallVector<ShapedType> initValueTypes{llvm::map_range(
initValues.getTypes(), [](Type t) { return t.cast<ShapedType>(); })};
SmallVector<int64_t> newDimensions;
Attribute encoding;
// reduce_c1, reduce_c4, reduce_c5, reduce_i3
if (failed(verifyReduceOpInputsAndInferShape(location, inputTypes, dimensions,
newDimensions, encoding)))
return failure();
// reduce_c2, reduce_c6
if (failed(verifyReducerShape(location, body.front(), inputTypes,
initValueTypes, newDimensions)))
return failure();
return success();
}
LogicalResult verifyReducePrecisionOp(std::optional<Location> location,
int32_t exponentBits,
int32_t mantissaBits) {
// reduce_precision_c2
if (exponentBits < 1)
return emitOptionalError(location, "exponent_bits must be at least 1.");
// reduce_precision_c3
if (mantissaBits < 0)
return emitOptionalError(location, "mantissa_bits must be at least 0.");
return success();
}
LogicalResult verifyReduceScatterOp(std::optional<Location> location,
Value operand, int64_t scatterDimension,
DenseIntElementsAttr replicaGroups,
int64_t channelId, bool useGlobalDeviceIds,
Region& computation, Value result) {
if (failed(verifyReplicaGroups(location, replicaGroups,
/*allGroupsMustHaveSameSize=*/true,
useGlobalDeviceIds,
/*expectedGroupSize=*/std::nullopt)))
return failure();
auto operandType = operand.getType().cast<ShapedType>();
// reduce_scatter_c7
if (failed(verifyReducerShape(
location, computation.front(), {operandType},
{RankedTensorType::get({}, operandType.getElementType())},
/*allowedDimensions=*/{})))
return failure();
auto resultType = result.getType().cast<ShapedType>();
if (!operandType.hasRank() || !resultType.hasRank()) return success();
// reduce_scatter_c8
if (operandType.getRank() != resultType.getRank())
return emitOptionalError(location,
"operand and result should have same rank");
// reduce_scatter_c2
if (scatterDimension < 0)
return emitOptionalError(location, "expects scatter_dimension >= 0");
// reduce_scatter_c2
if (scatterDimension >= operandType.getRank())
return emitOptionalError(
location, "scatter dim should be less than operand/result rank");
// reduce_scatter_c6
if (useGlobalDeviceIds && channelId <= 0)
return emitOptionalError(
location,
"channel_id must be positive when useGlobalDeviceIds is set but got: ",
channelId);
if (operandType.isDynamicDim(scatterDimension) ||
resultType.isDynamicDim(scatterDimension))
return success();
auto operandScatterDimSize = operandType.getDimSize(scatterDimension);
auto resultScatterDimSize = resultType.getDimSize(scatterDimension);
// TODO(#1746): Sync verification of ReduceScatter with HLO.
if (resultScatterDimSize == 0)
return emitOptionalError(
location, "result dimension size at scatter_dimension cannot be zero");
// TODO(#1746): Sync verification of ReduceScatter with HLO.
if (operandScatterDimSize == 0)
return emitOptionalError(
location, "operand dimension size at scatter_dimension cannot be zero");
// reduce_scatter_c8
if (isStaticDimSize(operandScatterDimSize) &&
isStaticDimSize(resultScatterDimSize) &&
operandScatterDimSize % resultScatterDimSize != 0)
return emitOptionalError(
location, "operand scatter dimension has size ", operandScatterDimSize,
", expected to be a multiple of result scatter dimension size ",
resultScatterDimSize);
// reduce_scatter_c8
for (auto index : llvm::seq<int64_t>(0, operandType.getRank())) {
if (index == scatterDimension) continue;
if (!verifyCompatibleDims(operandType.getDimSize(index),
resultType.getDimSize(index)))
return emitOptionalError(
location, "non scatter dimensions should be same for operand (",
operandType.getDimSize(index), ") and result (",
resultType.getDimSize(index), ")");
}
// reduce_scatter_c9
auto accumulatorTypesOrErr = getAccumulatorTypes(location, computation);
if (failed(accumulatorTypesOrErr)) return failure();
if (resultType.getElementType() !=
(*accumulatorTypesOrErr)[0].getElementType()) {
return emitOptionalError(location, "result element-type is expected to be ",
(*accumulatorTypesOrErr)[0].getElementType(),
", but got ", resultType.getElementType());
}
return success();
}
LogicalResult verifyReduceWindowOp(
std::optional<Location> location, ValueRange inputs, ValueRange initValues,
ArrayRef<int64_t> windowDimensions,
std::optional<ArrayRef<int64_t>> windowStrides,
std::optional<ArrayRef<int64_t>> baseDilations,
std::optional<ArrayRef<int64_t>> windowDilations,
std::optional<DenseIntElementsAttr> padding, Region& body) {
SmallVector<ShapedType> inputTypes{llvm::map_range(
inputs.getTypes(), [](Type t) { return t.cast<ShapedType>(); })};
SmallVector<ShapedType> initValueTypes{llvm::map_range(
initValues.getTypes(), [](Type t) { return t.cast<ShapedType>(); })};
SmallVector<int64_t> windowDims;
SmallVector<WindowDimension> inferredWindow;
// reduce_window_c1, reduce_window_c2, reduce_window_c4...reduce_window_c12,
// reduce_window_i4...reduce_window_i7
if (failed(verifyReduceWindowOpInputsAndInferWindow(
location, inputTypes, initValueTypes, windowDimensions, windowStrides,
baseDilations, windowDilations, padding, windowDims, inferredWindow)))
return failure();
// reduce_window_c3, reduce_window_c13, reduce_window_i2
if (failed(verifyReducerShape(location, body.front(), inputTypes,
initValueTypes, windowDims)))
return failure();
return success();
}
LogicalResult verifyReshapeOp(std::optional<Location> location, Value operand,
Value result) {
// If the operand type is dynamically shaped there is nothing to verify.
auto operandTy = operand.getType().dyn_cast<RankedTensorType>();
if (!operandTy || !operandTy.hasStaticShape()) return success();
// If the operand type is statically shaped (not required) the number of
// elements must match that of the result type.
auto resultTy = result.getType().cast<RankedTensorType>();
assert(resultTy && resultTy.hasStaticShape() &&
"result type must be statically shaped");
int64_t numResultElements = resultTy.getNumElements();
int64_t numOperandElements = operandTy.getNumElements();
if (numResultElements != numOperandElements)
return emitOptionalError(location, "number of output elements (",
numResultElements,
") doesn't match expected number of elements (",
numOperandElements, ")");
return success();
}
LogicalResult verifyReverseOp(std::optional<Location> location, Value operand,
ArrayRef<int64_t> dimensions) {
llvm::SmallDenseSet<int64_t> uniqueDims(dimensions.begin(), dimensions.end());
// reverse_c2
if (uniqueDims.size() != dimensions.size())
return emitOptionalError(location,
"dimensions should be unique. Got: ", dimensions);
auto operandTy = operand.getType().dyn_cast<RankedTensorType>();
for (int64_t dim : dimensions) {
// reverse_c3
if (dim < 0)
return emitOptionalError(
location,
"all dimensions should be non-negative. Got dimension: ", dim, ".");
if (operandTy && dim >= operandTy.getRank())
return emitOptionalError(
location, "all dimensions should be between [0, ",
operandTy.getRank(), "). Got dimension: ", dim, ".");
}
return success();
}
LogicalResult verifyRngBitGeneratorOp(std::optional<Location> location,
Value initialState, Value outputState) {
auto initialShape = initialState.getType().dyn_cast<RankedTensorType>();
auto outputShape = outputState.getType().dyn_cast<RankedTensorType>();
if (failed(verifyCompatibleShape(initialShape.getShape(),
outputShape.getShape())))
return emitOptionalError(
location,
"output state shape must be compatible with initial state shape. Got: ",
initialShape, " and ", outputShape);
return success();
}
LogicalResult verifyScatterOp(std::optional<Location> location,
ValueRange inputs, Value scatterIndices,
ValueRange updates,
ArrayRef<int64_t> updateWindowDims,
ArrayRef<int64_t> insertedWindowDims,
ArrayRef<int64_t> scatterDimsToOperandDims,
int64_t indexVectorDim,
Region& updateComputation) {
// Get the first operand and update, since variadic Scatter is not yet
// implemented
auto numOperands = inputs.size();
auto scatterIndicesType = scatterIndices.getType().cast<ShapedType>();
SmallVector<ShapedType, 1> operandTypes = llvm::to_vector(llvm::map_range(
inputs.getTypes(), [](Type type) { return type.cast<ShapedType>(); }));
SmallVector<ShapedType, 1> updatesTypes = llvm::to_vector(llvm::map_range(
updates.getTypes(), [](Type type) { return type.cast<ShapedType>(); }));
bool scatterIndicesTypeRanked = scatterIndicesType.isa<RankedTensorType>();
// scatter_c1
for (auto operandType : operandTypes)
if (failed(verifyCompatibleShape(operandTypes[0].getShape(),
operandType.getShape())))
return emitOptionalError(location,
"Not all inputs have compatible shapes.");
// scatter_c3
for (auto updateType : updatesTypes)
if (failed(verifyCompatibleShape(updatesTypes[0].getShape(),
updateType.getShape())))
return emitOptionalError(location,
"Not all updates have compatible shapes.");
// scatter_c14
if (scatterIndicesTypeRanked) {
if (indexVectorDim > scatterIndicesType.getRank() || indexVectorDim < 0)
return emitOptionalError(
location,
"expects scatter index leaf dimension to be within [0, "
"rank(scatter_indices) + 1. rank(scatter_indices) is ",
scatterIndicesType.getRank(), " and scatter index leaf dimension is ",
indexVectorDim, ".");
}
SmallVector<ShapedType> inputTypes, initValueTypes;
for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
inputTypes.push_back(operandTypes[i]);
initValueTypes.push_back(
RankedTensorType::get({}, updatesTypes[i].getElementType()));
}
// scatter_c6, scatter_c15
if (failed(verifyReducerShape(location, updateComputation.front(), inputTypes,
initValueTypes,
/*allowedDimensions=*/{})))
return failure();
// rank-of('updates[i]') == size-of('update_window_dims') +
// rank-of('scatter_indices') - 1, where 'scatter_indices' is expanded by a
// trailing 1 dimension if 'index_vector_dim' == rank-of('scatter_indices')
// for all values of `i`.
SmallVector<int64_t> expandedScatterIndicesShape;
if (scatterIndicesTypeRanked) {
expandedScatterIndicesShape =
llvm::to_vector(scatterIndicesType.getShape());
if (static_cast<int64_t>(expandedScatterIndicesShape.size()) ==
indexVectorDim)
expandedScatterIndicesShape.push_back(1);
}
// scatter_c4
for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
if (scatterIndicesTypeRanked && updatesTypes[i].isa<RankedTensorType>()) {
int64_t expectedUpdatesRank =
expandedScatterIndicesShape.size() - 1 + updateWindowDims.size();
if (updatesTypes[i].getRank() != expectedUpdatesRank)
return emitOptionalError(
location, "expects updates tensor must be of rank ",
expectedUpdatesRank,
" ( == rank-of('scatter_indices') - 1 + "
"size-of('update_window_dims'), where 'scatter_indices' is "
"expanded by a trailing 1 dimension if 'index_vector_dim' == "
"rank-of('scatter_indices')), but got ",
updatesTypes[i].getRank(), ".");
}
}
// scatter_c2, scatter_c7...scatter_c13
for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
if (failed(validateScatterDimensionNumbers(
operandTypes[i], expandedScatterIndicesShape, updatesTypes[i],
operandTypes[i].isa<RankedTensorType>(), scatterIndicesTypeRanked,
updatesTypes[i].isa<RankedTensorType>(), updateWindowDims,
insertedWindowDims, scatterDimsToOperandDims, indexVectorDim,
location)))
return failure();
}
for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
if (updatesTypes[i].isa<RankedTensorType>()) {
auto updatesShape = updatesTypes[i].getShape();
if (operandTypes[i].isa<RankedTensorType>()) {
auto operandShape = operandTypes[i].getShape();
int64_t insertedDimsSeen = 0;
SmallVector<int64_t> maxUpdateSliceSizes;
const auto dimensionsSize = operandTypes[i].getRank();
maxUpdateSliceSizes.reserve(dimensionsSize);
for (int i = 0; i < dimensionsSize; ++i) {
if (insertedDimsSeen <
static_cast<int64_t>(insertedWindowDims.size()) &&
insertedWindowDims[insertedDimsSeen] == i)
++insertedDimsSeen;
else
maxUpdateSliceSizes.push_back(operandShape[i]);
}
for (int64_t i = 0; i < static_cast<int64_t>(updateWindowDims.size());
++i) {
auto updateWindowDim = updateWindowDims[i];
if (isDynamicDimSize(updatesShape[updateWindowDim]) ||
isDynamicDimSize(maxUpdateSliceSizes[i]))
continue;
// scatter_c4
if (updatesShape[updateWindowDim] > maxUpdateSliceSizes[i]) {
return emitOptionalError(
location,
"expects bounds of the window dimensions of updates to not "
"exceed the bounds of the corresponding dimensions of operand. "
"For dimension ",
updateWindowDim, ", updates bound is ",
updatesShape[updateWindowDim], ", operand bound is ",
maxUpdateSliceSizes[i], ".");
}
}
}
if (scatterIndicesTypeRanked) {
int64_t scatterDimsSeen = 0;
for (int64_t i = 0; i < static_cast<int64_t>(updatesShape.size());
++i) {
bool isUpdateWindowDim = std::binary_search(
updateWindowDims.begin(), updateWindowDims.end(), i);
if (isUpdateWindowDim) continue;
if (scatterDimsSeen == indexVectorDim) ++scatterDimsSeen;
// scatter_c4
if (!verifyCompatibleDims(
updatesShape[i],
expandedScatterIndicesShape[scatterDimsSeen]))
return emitOptionalError(
location,
"expects bounds of the scatter dimensions of updates to be "
"same as the bounds of the corresponding dimensions of scatter "
"indices. For scatter dimension ",
i, ", updates bound is ", updatesShape[i],
" , scatter_indices bound is ",
expandedScatterIndicesShape[scatterDimsSeen], ".");
++scatterDimsSeen;
}
}
}
}
return success();
}
// We intend to verify the following properties:
// P1. Check if the select function has a proper shape of (T,T) -> PRED, where
// T is a 0-D tensor with element-type same as 'operand' element-type.
// P2. Verify scatter-computation type.
// P3. size-of(window_dimension) == rank-of(input),
// where input is an element of 'inputs'.
// P4. Verify and collect the window attributes.
// P5. Check if the result type of window operation matches the source type.
LogicalResult verifySelectAndScatterOp(
std::optional<Location> location, Value operand, Value source,
Value initValue, std::optional<ArrayRef<int64_t>> windowDimensionsOpt,
std::optional<ArrayRef<int64_t>> windowStridesOpt,
std::optional<DenseIntElementsAttr> padding, Region& select,
Region& scatter) {
auto operandType = operand.getType().cast<ShapedType>();
auto initValueType = initValue.getType().cast<ShapedType>();
auto sourceType = source.getType().cast<ShapedType>();
Block& selectBlock = select.front();
// select_and_scatter_c9
if (selectBlock.getArguments().size() != 2)
return emitOptionalError(
location, "expects the select-region to take 2 parameters, but takes ",
selectBlock.getArguments().size());
Type expectedSelectArgType =
RankedTensorType::get({}, operandType.getElementType());
for (const auto& selectArgIt : llvm::enumerate(selectBlock.getArguments()))
// select_and_scatter_c9
if (!compatibleShapeAndElementType(expectedSelectArgType,
selectArgIt.value().getType(),
/*ignoreFpPrecision=*/true))
return emitOptionalError(
location, "expects the type of select-region's parameter at index ",
selectArgIt.index(), " to be ", expectedSelectArgType, ", but got ",
selectArgIt.value().getType());
auto selectResult = selectBlock.getTerminator()->getOperands();
// select_and_scatter_c9
if (selectResult.size() != 1)
return emitOptionalError(
location, "expects select-region to return single value, but got: ",
selectResult.size());
auto selectResultType = selectResult[0].getType().dyn_cast<ShapedType>();
// select_and_scatter_c9
if (!selectResultType || !selectResultType.getElementType().isInteger(1) ||
(selectResultType.hasRank() &&
selectResultType.cast<RankedTensorType>().getRank() != 0))
return emitOptionalError(
location,
"expects the return-type of select-region to be tensor<i1>, but got: ",
selectResult[0].getType());
// select_and_scatter_c10
if (failed(verifyReducerShape(
location, scatter.front(),
{RankedTensorType::get({}, sourceType.getElementType())},
{initValueType},
/*allowedDimensions=*/{})))
return failure();
auto windowDims = windowDimensionsOpt.value_or(SmallVector<int64_t>{});
if (operandType.hasRank()) {
// select_and_scatter_c4
if (operandType.getRank() != static_cast<int64_t>(windowDims.size()))
return emitOptionalError(
location,
"expects window-dimensions size == operand rank, but got "
"window-dimensions size: ",
windowDims.size(), " and operand-type: ", operandType,
" with rank = ", operandType.getRank(), ".");
}
auto windowStrides = windowStridesOpt.value_or(SmallVector<int64_t>{});
// select_and_scatter_c8, select_and_scatter_i6
auto paddingOrErr = convertPaddingAttribute(padding, location);
if (failed(paddingOrErr)) return failure();
// select_and_scatter_c5, select_and_scatter_c7
auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
windowDims, windowStrides, *paddingOrErr,
/*lhsDilation=*/{}, /*rhsDilation=*/{}, /*windowReversal*/ {}, location);
if (failed(windowOrErr)) return failure();
ShapedType windowResultType;
if (!operandType.hasRank())
windowResultType = UnrankedTensorType::get(operandType.getElementType());
else
windowResultType = RankedTensorType::get(
inferWindowOutputShape(operandType.getShape(), *windowOrErr),
operandType.getElementType());
// select_and_scatter_c1, select_and_scatter_c2
if (!compatibleShapeAndElementType(windowResultType, sourceType,
/*ignoreFpPrecision=*/true))
return emitOptionalError(location, "expects source-type to be ",
windowResultType, ", but got", sourceType);
return success();
}
LogicalResult verifySortOp(std::optional<Location> location, ValueRange inputs,
int64_t dimension, Region& comparator) {
auto operandTypes = inputs.getTypes();
for (auto operandType : operandTypes) {
auto operandShapedType = operandType.cast<ShapedType>();
if (operandShapedType.hasRank()) {
int64_t cmpDim = dimension;
int64_t rank = operandShapedType.getRank();
// sort_c4
if (cmpDim < -rank || cmpDim >= rank)
return emitOptionalError(
location, "dimension attribute value must be in range [-", rank,
", ", rank, "), but found ", cmpDim);
else
break; // ODS SameOperandsAndResultShape asserts inputs have same shape
}
}
Block& block = comparator.front();
// sort_c5
size_t numOperands = operandTypes.size();
if (block.getNumArguments() != 2 * numOperands)
return emitOptionalError(location, "comparator block should have ",
2 * numOperands, " arguments");
// sort_c5
for (const auto& indexedOperandType : llvm::enumerate(operandTypes)) {
int index = indexedOperandType.index();
Type elementType =
indexedOperandType.value().cast<ShapedType>().getElementType();
Type shapedType = RankedTensorType::get({}, elementType);
for (int i : {2 * index, 2 * index + 1}) {
Type argType = block.getArgument(i).getType();
if (argType != shapedType)
return emitOptionalError(location, "comparator block argument #", i,
" should be of type ", shapedType, " but got ",
argType);
}
}
// sort_c5
auto comparatorResult = block.getTerminator()->getOperands();
if (comparatorResult.size() != 1)
return emitOptionalError(location,
"comparator must return single output but got ",
comparatorResult.size());
// sort_c5
auto comparatorResultType = comparatorResult[0].getType().cast<ShapedType>();
if ((comparatorResultType.hasRank() && comparatorResultType.getRank() != 0) ||
!comparatorResultType.getElementType().isInteger(1))
return emitOptionalError(location,
"comparator must return tensor<i1> but got ",
comparatorResult[0].getType());
return success();
}
LogicalResult verifyWhileOp(std::optional<Location> location,
ValueRange operand, Region& cond, Region& body) {
auto operandTypes = operand.getTypes();
auto condArgsTypes = cond.front().getArgumentTypes();
auto bodyArgsTypes = body.front().getArgumentTypes();
// while_c1
if (!isCompatibleForHloTypeInference(operandTypes, condArgsTypes))
return emitOptionalError(location,
"expect operands to be compatible with condition "
"block arguments but got ",
operandTypes, " vs ", condArgsTypes);
// while_c2
if (!isCompatibleForHloTypeInference(operandTypes, bodyArgsTypes))
return emitOptionalError(
location,
"expect operands to be compatible with body block arguments but got ",
operandTypes, " vs ", bodyArgsTypes);
// while_c2
auto bodyReturnTypes = body.front().getTerminator()->getOperandTypes();
if (!isCompatibleForHloTypeInference(operandTypes, bodyReturnTypes))
return emitOptionalError(location,
"expect operands to be compatible with body block "
"return types but got ",
operandTypes, " vs ", bodyReturnTypes);
// while_c1
auto condReturnTypes = cond.front().back().getOperandTypes();
if (condReturnTypes.size() != 1)
return emitOptionalError(
location, "expect condition body returns a single value but got ",
condReturnTypes.size());
// while_c1
auto operandType = condReturnTypes[0].cast<ShapedType>();
if ((operandType.hasRank() && operandType.getRank() != 0) ||
!operandType.getElementType().isInteger(1))
return emitOptionalError(
location,
"expect condition block return a zero-ranked tensor of i1 but got ",
condReturnTypes[0]);
return success();
}
} // end namespace hlo
} // end namespace mlir