Handles more cases of redundant selects
* Handles OpConstantNull and vector types
* vector selects (except against a null) are converted to vector
shuffles
* Added tests
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index 7e4dddb..8121931 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -1549,25 +1549,65 @@
assert(inst->NumInOperands() == 3);
assert(constants.size() == 3);
- const analysis::BoolConstant* bc =
- constants[0] ? constants[0]->AsBoolConstant() : nullptr;
uint32_t true_id = inst->GetSingleWordInOperand(1);
uint32_t false_id = inst->GetSingleWordInOperand(2);
- if (bc) {
- // Select condition is constant, result is known
- inst->SetOpcode(SpvOpCopyObject);
- inst->SetInOperands(
- {{SPV_OPERAND_TYPE_ID, {bc->value() ? true_id : false_id}}});
- return true;
- } else if (true_id == false_id) {
+ if (true_id == false_id) {
// Both results are the same, condition doesn't matter
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
return true;
- } else {
- return false;
+ } else if (constants[0]) {
+ const analysis::Type* type = constants[0]->type();
+ if (type->AsBool()) {
+ // Scalar constant value, select the corresponding value.
+ inst->SetOpcode(SpvOpCopyObject);
+ if (constants[0]->AsNullConstant() ||
+ !constants[0]->AsBoolConstant()->value()) {
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
+ } else {
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
+ }
+ return true;
+ } else {
+ assert(type->AsVector());
+ if (constants[0]->AsNullConstant()) {
+ // All values come from false id.
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
+ return true;
+ } else {
+ // Convert to a vector shuffle.
+ std::vector<ir::Operand> ops;
+ ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
+ ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
+ const analysis::VectorConstant* vector_const =
+ constants[0]->AsVectorConstant();
+ uint32_t size =
+ static_cast<uint32_t>(vector_const->GetComponents().size());
+ for (uint32_t i = 0; i != size; ++i) {
+ const analysis::Constant* component =
+ vector_const->GetComponents()[i];
+ if (component->AsNullConstant() ||
+ !component->AsBoolConstant()->value()) {
+ // Selecting from the false vector which is the second input
+ // vector to the shuffle. Offset the index by |size|.
+ ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
+ } else {
+ // Selecting from true vector which is the first input vector to
+ // the shuffle.
+ ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
+ }
+ }
+
+ inst->SetOpcode(SpvOpVectorShuffle);
+ inst->SetInOperands(std::move(ops));
+ return true;
+ }
+ }
}
+
+ return false;
};
}
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 345cfea..7d813d4 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -130,6 +130,7 @@
%101 = OpConstantTrue %bool ; Need a def with an numerical id to define id maps.
%true = OpConstantTrue %bool
%false = OpConstantFalse %bool
+%bool_null = OpConstantNull %bool
%short = OpTypeInt 16 1
%int = OpTypeInt 32 1
%long = OpTypeInt 64 1
@@ -139,6 +140,7 @@
%v4float = OpTypeVector %float 4
%v4double = OpTypeVector %double 4
%v2float = OpTypeVector %float 2
+%v2bool = OpTypeVector %bool 2
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
%_ptr_int = OpTypePointer Function %int
%_ptr_uint = OpTypePointer Function %uint
@@ -176,6 +178,9 @@
%v2int_2_3 = OpConstantComposite %v2int %int_2 %int_3
%v2int_3_2 = OpConstantComposite %v2int %int_3 %int_2
%v2int_4_4 = OpConstantComposite %v2int %int_4 %int_4
+%v2bool_null = OpConstantNull %v2bool
+%v2bool_true_false = OpConstantComposite %v2bool %true %false
+%v2bool_false_true = OpConstantComposite %v2bool %false %true
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
%v2int_null = OpConstantNull %v2int
%102 = OpConstantComposite %v2int %103 %103
@@ -2336,40 +2341,6 @@
2, 0)
));
-INSTANTIATE_TEST_CASE_P(SelectFoldingTest, GeneralInstructionFoldingTest,
-::testing::Values(
- // Test case 0: Fold select with the same values for both sides
- InstructionFoldingCase<uint32_t>(
- Header() + "%main = OpFunction %void None %void_func\n" +
- "%main_lab = OpLabel\n" +
- "%n = OpVariable %_ptr_bool Function\n" +
- "%load = OpLoad %bool %n\n" +
- "%2 = OpSelect %int %load %100 %100\n" +
- "OpReturn\n" +
- "OpFunctionEnd",
- 2, INT_0_ID),
- // Test case 1: Fold select true to left side
- InstructionFoldingCase<uint32_t>(
- Header() + "%main = OpFunction %void None %void_func\n" +
- "%main_lab = OpLabel\n" +
- "%n = OpVariable %_ptr_int Function\n" +
- "%load = OpLoad %bool %n\n" +
- "%2 = OpSelect %int %true %100 %n\n" +
- "OpReturn\n" +
- "OpFunctionEnd",
- 2, INT_0_ID),
- // Test case 2: Fold select false to right side
- InstructionFoldingCase<uint32_t>(
- Header() + "%main = OpFunction %void None %void_func\n" +
- "%main_lab = OpLabel\n" +
- "%n = OpVariable %_ptr_int Function\n" +
- "%load = OpLoad %bool %n\n" +
- "%2 = OpSelect %int %false %n %100\n" +
- "OpReturn\n" +
- "OpFunctionEnd",
- 2, INT_0_ID)
-));
-
INSTANTIATE_TEST_CASE_P(FloatRedundantFoldingTest, GeneralInstructionFoldingTest,
::testing::Values(
// Test case 0: Don't fold n + 1.0
@@ -4302,5 +4273,113 @@
"OpFunctionEnd\n",
4, true)
));
+
+INSTANTIATE_TEST_CASE_P(SelectFoldingTest, MatchingInstructionFoldingTest,
+::testing::Values(
+ // Test case 0: Fold select with the same values for both sides
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
+ "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_bool Function\n" +
+ "%load = OpLoad %bool %n\n" +
+ "%2 = OpSelect %int %load %100 %100\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, true),
+ // Test case 1: Fold select true to left side
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
+ "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_int Function\n" +
+ "%load = OpLoad %bool %n\n" +
+ "%2 = OpSelect %int %true %100 %n\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, true),
+ // Test case 2: Fold select false to right side
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
+ "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_int Function\n" +
+ "%load = OpLoad %bool %n\n" +
+ "%2 = OpSelect %int %false %n %100\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, true),
+ // Test case 3: Fold select null to right side
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[int0:%\\w+]] = OpConstant [[int]] 0\n" +
+ "; CHECK: %2 = OpCopyObject [[int]] [[int0]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_int Function\n" +
+ "%load = OpLoad %int %n\n" +
+ "%2 = OpSelect %int %bool_null %load %100\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, true),
+ // Test case 4: vector null
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK: [[v2int2_2:%\\w+]] = OpConstantComposite [[v2int]] [[int2]] [[int2]]\n" +
+ "; CHECK: %2 = OpCopyObject [[v2int]] [[v2int2_2]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%n = OpVariable %_ptr_v2int Function\n" +
+ "%load = OpLoad %v2int %n\n" +
+ "%2 = OpSelect %v2int %v2bool_null %load %v2int_2_2\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 2, true),
+ // Test case 5: vector select
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: %4 = OpVectorShuffle [[v2int]] %2 %3 0 3\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%m = OpVariable %_ptr_v2int Function\n" +
+ "%n = OpVariable %_ptr_v2int Function\n" +
+ "%2 = OpLoad %v2int %n\n" +
+ "%3 = OpLoad %v2int %n\n" +
+ "%4 = OpSelect %v2int %v2bool_true_false %2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true),
+ // Test case 6: vector select
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[v2int:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: %4 = OpVectorShuffle [[v2int]] %2 %3 2 1\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%m = OpVariable %_ptr_v2int Function\n" +
+ "%n = OpVariable %_ptr_v2int Function\n" +
+ "%2 = OpLoad %v2int %n\n" +
+ "%3 = OpLoad %v2int %n\n" +
+ "%4 = OpSelect %v2int %v2bool_false_true %2 %3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, true)
+));
#endif
} // anonymous namespace