Revert "Updates to ConvolutionOP verifier to support quantization constraints…"
This reverts commit a8f44d0becc773e0ec01703235559827bd16f68f.
diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp
index 8ba072f..48cfb4d 100644
--- a/stablehlo/dialect/TypeInference.cpp
+++ b/stablehlo/dialect/TypeInference.cpp
@@ -65,25 +65,6 @@
namespace mlir {
namespace hlo {
-namespace {
-//===----------------------------------------------------------------------===//
-// Utils for quantization specific verifications
-//===----------------------------------------------------------------------===//
-template <typename T>
-bool allQuantized(ArrayRef<Type> typeRange) {
- return llvm::all_of(typeRange, [&](Type val) {
- return val.cast<ShapedType>().getElementType().isa<T>();
- });
-}
-
-template <typename T>
-bool noneQuantized(ArrayRef<Type> typeRange) {
- return llvm::all_of(typeRange, [&](Type val) {
- return !val.cast<ShapedType>().getElementType().isa<T>();
- });
-}
-
-} // namespace
//===----------------------------------------------------------------------===//
// Utils for shape functions.
@@ -3472,61 +3453,6 @@
"is incompatible with return type of operation ",
shapedResultType, "");
- llvm::SmallVector<Type, 3> typeEntries{lhsType, rhsType, resultType};
- if (noneQuantized<quant::QuantizedType>(typeEntries)) return success();
- // convolution_c28
- if (!allQuantized<quant::QuantizedType>(typeEntries)) {
- return emitOptionalError(location,
- "not all of operands and result are quantized");
- }
-
- auto lhsQType =
- getElementTypeOrSelf(lhsType).dyn_cast<quant::QuantizedType>();
- auto rhsQType =
- getElementTypeOrSelf(rhsType).dyn_cast<quant::QuantizedType>();
- auto resultQType =
- getElementTypeOrSelf(resultType).dyn_cast<quant::QuantizedType>();
- // convolution_c29
- if (lhsQType.getStorageType() != rhsQType.getStorageType())
- return emitOptionalError(location, "mismatched operand storage types ",
- lhsQType.getStorageType(), " and ",
- rhsQType.getStorageType());
- // convolution_c30
- auto expressedType = lhsQType.getExpressedType();
- if (expressedType != rhsQType.getExpressedType() ||
- expressedType != resultQType.getExpressedType())
- return emitOptionalError(location,
- "mismatched operands and result expressed types");
-
- llvm::SmallVector<Type, 2> typeEntriesPerAxis{rhsType, resultType};
- if (noneQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis))
- return success();
- // convolution_c31
- if (!allQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis)) {
- return emitOptionalError(location,
- "rhs and result are of mixed per_tensor and "
- "per_axis quantized tensor type ",
- rhsType, " and ", resultType);
- }
-
- auto rhsQPAType = rhsQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
- auto resultQPAType =
- resultQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
- // convolution_c32
- if (rhsQPAType &&
- rhsQPAType.getQuantizedDimension() != kernelOutputFeatureDimension)
- return emitOptionalError(
- location, "mismatched kernel_output_feature_dimension ",
- kernelOutputFeatureDimension, " and rhs quantized dimension ",
- rhsQPAType.getQuantizedDimension());
- // convolution_c33
- if (resultQPAType &&
- resultQPAType.getQuantizedDimension() != outputFeatureDimension)
- return emitOptionalError(location, "mismatched output_feature_dimension ",
- outputFeatureDimension,
- " and result quantized dimension ",
- resultQPAType.getQuantizedDimension());
-
return success();
}
diff --git a/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/tests/ops_stablehlo_quantized.mlir
index 0423aff..cb54cd5 100644
--- a/stablehlo/tests/ops_stablehlo_quantized.mlir
+++ b/stablehlo/tests/ops_stablehlo_quantized.mlir
@@ -821,75 +821,3 @@
%0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor<4x!quant.uniform<si8:f32, 1.000000e+00>>) -> tensor<4xf32>
func.return %0 : tensor<4xf32>
}
-
-// -----
-
-func.func @convolution_c28(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
- // expected-error@+1 {{not all of operands and result are quantized}}
- %0 = stablehlo.convolution(%arg0, %arg1)
- dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
- window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
- {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
- (tensor<1x8x8x207xf32>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
- func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
-}
-
-// -----
-
-func.func @convolution_c29(%arg0: tensor<1x8x8x207x!quant.uniform<i16:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
- // expected-error@+1 {{mismatched operand storage types 'i16' and 'i8'}}
- %0 = stablehlo.convolution(%arg0, %arg1)
- dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
- window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
- {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
- (tensor<1x8x8x207x!quant.uniform<i16:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
- func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
-}
-
-// -----
-
-func.func @convolution_c30(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f64, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
- // expected-error@+1 {{mismatched operands and result expressed types}}
- %0 = stablehlo.convolution(%arg0, %arg1)
- dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
- window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
- {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
- (tensor<1x8x8x207x!quant.uniform<i8:f64, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
- func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
-}
-
-// -----
-
-func.func @convolution_c31(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
- // expected-error@+1 {{rhs and result are of mixed per_tensor and per_axis quantized tensor type 'tensor<3x3x207x16x!quant.uniform<i8:f32:0, {1.000000e-01:-30}>>' and 'tensor<1x8x8x16x!quant.uniform<i8:f32, 1.000000e+01:50>>'}}
- %0 = stablehlo.convolution(%arg0, %arg1)
- dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
- window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
- {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
- (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
- func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
-}
-
-// -----
-
-func.func @convolution_c32(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {0.1:-30}>> {
- // expected-error@+1 {{mismatched kernel_output_feature_dimension 3 and rhs quantized dimension 0}}
- %0 = stablehlo.convolution(%arg0, %arg1)
- dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
- window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
- {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
- (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {0.1:-30}>>
- func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32:0, {0.1:-30}>>
-}
-
-// -----
-
-func.func @convolution_c33(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32:3, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {2.0:-30}>> {
- // expected-error@+1 {{mismatched output_feature_dimension 3 and result quantized dimension 0}}
- %0 = stablehlo.convolution(%arg0, %arg1)
- dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
- window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
- {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
- (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32:3, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {2.0:-30}>>
- func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32:0, {2.0:-30}>>
-}