Add folding rule to merge a vector shuffle feeding another one.
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index c3c956d..6682986 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -1540,74 +1540,73 @@
// corresponding |a| in the FMix is 0 or 1, we can extract from one of the
// operands of the FMix.
FoldingRule FMixFeedingExtract() {
- return
- [](opt::Instruction* inst,
- const std::vector<const analysis::Constant*>&) {
- assert(inst->opcode() == SpvOpCompositeExtract &&
- "Wrong opcode. Should be OpCompositeExtract.");
- opt::IRContext* context = inst->context();
- analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
- analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+ return [](opt::Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ assert(inst->opcode() == SpvOpCompositeExtract &&
+ "Wrong opcode. Should be OpCompositeExtract.");
+ opt::IRContext* context = inst->context();
+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+ analysis::ConstantManager* const_mgr = context->get_constant_mgr();
- uint32_t composite_id =
- inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
- opt::Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
+ uint32_t composite_id =
+ inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
+ opt::Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
- if (composite_inst->opcode() != SpvOpExtInst) {
- return false;
- }
+ if (composite_inst->opcode() != SpvOpExtInst) {
+ return false;
+ }
- uint32_t inst_set_id =
- inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+ uint32_t inst_set_id =
+ inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
- if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
- inst_set_id ||
- composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
- GLSLstd450FMix) {
- return false;
- }
+ if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
+ inst_set_id ||
+ composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
+ GLSLstd450FMix) {
+ return false;
+ }
- // Get the |a| for the FMix instruction.
- uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
- std::unique_ptr<opt::Instruction> a(inst->Clone(inst->context()));
- a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
- context->get_instruction_folder().FoldInstruction(a.get());
+ // Get the |a| for the FMix instruction.
+ uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
+ std::unique_ptr<opt::Instruction> a(inst->Clone(inst->context()));
+ a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
+ context->get_instruction_folder().FoldInstruction(a.get());
- if (a->opcode() != SpvOpCopyObject) {
- return false;
- }
+ if (a->opcode() != SpvOpCopyObject) {
+ return false;
+ }
- const analysis::Constant* a_const =
- const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
+ const analysis::Constant* a_const =
+ const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
- if (!a_const) {
- return false;
- }
+ if (!a_const) {
+ return false;
+ }
- bool use_x = false;
+ bool use_x = false;
- assert(a_const->type()->AsFloat());
- double element_value = a_const->GetValueAsDouble();
- if (element_value == 0.0) {
- use_x = true;
- } else if (element_value == 1.0) {
- use_x = false;
- } else {
- return false;
- }
+ assert(a_const->type()->AsFloat());
+ double element_value = a_const->GetValueAsDouble();
+ if (element_value == 0.0) {
+ use_x = true;
+ } else if (element_value == 1.0) {
+ use_x = false;
+ } else {
+ return false;
+ }
- // Get the id of the of the vector the element comes from.
- uint32_t new_vector = 0;
- if (use_x) {
- new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
- } else {
- new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
- }
+ // Get the id of the of the vector the element comes from.
+ uint32_t new_vector = 0;
+ if (use_x) {
+ new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
+ } else {
+ new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
+ }
- // Update the extract instruction.
- inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
- return true;
- };
+ // Update the extract instruction.
+ inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
+ return true;
+ };
}
FoldingRule RedundantPhi() {
@@ -2019,6 +2018,111 @@
return false;
};
}
+
+FoldingRule VectorShuffleFeedingShuffle() {
+ return [](opt::Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ assert(inst->opcode() == SpvOpVectorShuffle &&
+ "Wrong opcode. Should be OpVectorShuffle.");
+
+ IRContext* context = inst->context();
+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
+
+ Instruction* feeding_shuffle_inst =
+ def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
+ analysis::Vector* op0_type =
+ type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
+ uint32_t op0_length = op0_type->element_count();
+
+ bool feeder_is_op0 = true;
+ if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
+ feeding_shuffle_inst =
+ def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
+ feeder_is_op0 = false;
+ }
+
+ if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
+ return false;
+ }
+
+ Instruction* feeder2 =
+ def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
+ analysis::Vector* feeder_op0_type =
+ type_mgr->GetType(feeder2->type_id())->AsVector();
+ uint32_t feeder_op0_length = feeder_op0_type->element_count();
+
+ uint32_t new_feeder_id = 0;
+ std::vector<Operand> new_operands;
+ new_operands.resize(
+ 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
+ for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
+ uint32_t component_index = inst->GetSingleWordInOperand(op);
+
+ if (feeder_is_op0 == (component_index < op0_length)) {
+ // This component comes from the feeding_shuffle_inst. Update
+ // |component_index| to be the index into the operand of the feeder.
+
+ // Adjust component_index to get the index into the operands of the
+ // feeding_shuffle_inst.
+ if (component_index >= op0_length) {
+ component_index -= op0_length;
+ }
+ component_index =
+ feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
+
+ // Check if we are using a component from the first or second operand of
+ // the feeding instruction.
+ if (component_index < feeder_op0_length) {
+ if (new_feeder_id == 0) {
+ // First time through, save the id of the operand the element comes
+ // from.
+ new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
+ } else if (new_feeder_id !=
+ feeding_shuffle_inst->GetSingleWordInOperand(0)) {
+ // We need both elements of the feeding_shuffle_inst, so we cannot
+ // fold.
+ return false;
+ }
+ } else {
+ if (new_feeder_id == 0) {
+ // First time through, save the id of the operand the element comes
+ // from.
+ new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
+ } else if (new_feeder_id !=
+ feeding_shuffle_inst->GetSingleWordInOperand(1)) {
+ // We need both elements of the feeding_shuffle_inst, so we cannot
+ // fold.
+ return false;
+ }
+ component_index -= feeder_op0_length;
+ }
+
+ if (!feeder_is_op0) {
+ component_index += op0_length;
+ }
+ }
+ new_operands.push_back(
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
+ }
+
+ if (new_feeder_id == 0) {
+ new_feeder_id = inst->GetSingleWordInOperand(feeder_is_op0 ? 1 : 0);
+ }
+
+ if (feeder_is_op0) {
+ new_operands[0].words[0] = new_feeder_id;
+ new_operands[1] = inst->GetInOperand(1);
+ } else {
+ new_operands[1].words[0] = new_feeder_id;
+ new_operands[0] = inst->GetInOperand(0);
+ }
+
+ inst->SetInOperands(std::move(new_operands));
+ return true;
+ };
+}
+
} // namespace
FoldingRules::FoldingRules() {
@@ -2087,7 +2191,8 @@
rules_[SpvOpStore].push_back(StoringUndef());
rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
-}
+ rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
+}
} // namespace opt
} // namespace spvtools
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 662add1..ff4a1b5 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -5650,6 +5650,174 @@
0 /* OpStore */, true)
));
+INSTANTIATE_TEST_CASE_P(VectorShuffleMatchingTest, MatchingInstructionWithNoResultFoldingTest,
+::testing::Values(
+ // Test case 0: Basic test 1
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: OpVectorShuffle\n" +
+ "; CHECK: OpVectorShuffle {{%\\w+}} %7 %5 2 3 6 7\n" +
+ "; CHECK: OpReturn\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpVariable %_ptr_v4double Function\n" +
+ "%4 = OpVariable %_ptr_v4double Function\n" +
+ "%5 = OpLoad %v4double %2\n" +
+ "%6 = OpLoad %v4double %3\n" +
+ "%7 = OpLoad %v4double %4\n" +
+ "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" +
+ "%9 = OpVectorShuffle %v4double %7 %8 2 3 4 5\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 9, true),
+ // Test case 1: Basic test 2
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: OpVectorShuffle\n" +
+ "; CHECK: OpVectorShuffle {{%\\w+}} %6 %7 0 1 4 5\n" +
+ "; CHECK: OpReturn\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpVariable %_ptr_v4double Function\n" +
+ "%4 = OpVariable %_ptr_v4double Function\n" +
+ "%5 = OpLoad %v4double %2\n" +
+ "%6 = OpLoad %v4double %3\n" +
+ "%7 = OpLoad %v4double %4\n" +
+ "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" +
+ "%9 = OpVectorShuffle %v4double %8 %7 2 3 4 5\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 9, true),
+ // Test case 2: Basic test 3
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: OpVectorShuffle\n" +
+ "; CHECK: OpVectorShuffle {{%\\w+}} %5 %7 3 2 4 5\n" +
+ "; CHECK: OpReturn\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpVariable %_ptr_v4double Function\n" +
+ "%4 = OpVariable %_ptr_v4double Function\n" +
+ "%5 = OpLoad %v4double %2\n" +
+ "%6 = OpLoad %v4double %3\n" +
+ "%7 = OpLoad %v4double %4\n" +
+ "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" +
+ "%9 = OpVectorShuffle %v4double %8 %7 1 0 4 5\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 9, true),
+ // Test case 3: Basic test 4
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: OpVectorShuffle\n" +
+ "; CHECK: OpVectorShuffle {{%\\w+}} %7 %6 2 3 5 4\n" +
+ "; CHECK: OpReturn\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpVariable %_ptr_v4double Function\n" +
+ "%4 = OpVariable %_ptr_v4double Function\n" +
+ "%5 = OpLoad %v4double %2\n" +
+ "%6 = OpLoad %v4double %3\n" +
+ "%7 = OpLoad %v4double %4\n" +
+ "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" +
+ "%9 = OpVectorShuffle %v4double %7 %8 2 3 7 6\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 9, true),
+ // Test case 4: Don't use feeder.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: OpVectorShuffle\n" +
+ "; CHECK: OpVectorShuffle {{%\\w+}} %7 %7 2 3 0 1\n" +
+ "; CHECK: OpReturn\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpVariable %_ptr_v4double Function\n" +
+ "%4 = OpVariable %_ptr_v4double Function\n" +
+ "%5 = OpLoad %v4double %2\n" +
+ "%6 = OpLoad %v4double %3\n" +
+ "%7 = OpLoad %v4double %4\n" +
+ "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" +
+ "%9 = OpVectorShuffle %v4double %7 %8 2 3 0 1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 9, true),
+ // Test case 5: Don't fold, need both operands of the feeder.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpVariable %_ptr_v4double Function\n" +
+ "%4 = OpVariable %_ptr_v4double Function\n" +
+ "%5 = OpLoad %v4double %2\n" +
+ "%6 = OpLoad %v4double %3\n" +
+ "%7 = OpLoad %v4double %4\n" +
+ "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" +
+ "%9 = OpVectorShuffle %v4double %7 %8 2 3 7 5\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 9, false),
+ // Test case 6: Don't fold, need both operands of the feeder.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpVariable %_ptr_v4double Function\n" +
+ "%4 = OpVariable %_ptr_v4double Function\n" +
+ "%5 = OpLoad %v4double %2\n" +
+ "%6 = OpLoad %v4double %3\n" +
+ "%7 = OpLoad %v4double %4\n" +
+ "%8 = OpVectorShuffle %v4double %5 %6 2 3 4 5\n" +
+ "%9 = OpVectorShuffle %v4double %8 %7 2 0 7 5\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 9, false),
+ // Test case 7: Fold, need both operands of the feeder, but they are the same.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: OpVectorShuffle\n" +
+ "; CHECK: OpVectorShuffle {{%\\w+}} %5 %7 0 2 7 5\n" +
+ "; CHECK: OpReturn\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpVariable %_ptr_v4double Function\n" +
+ "%4 = OpVariable %_ptr_v4double Function\n" +
+ "%5 = OpLoad %v4double %2\n" +
+ "%6 = OpLoad %v4double %3\n" +
+ "%7 = OpLoad %v4double %4\n" +
+ "%8 = OpVectorShuffle %v4double %5 %5 2 3 4 5\n" +
+ "%9 = OpVectorShuffle %v4double %8 %7 2 0 7 5\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 9, true),
+ // Test case 8: Fold, need both operands of the feeder, but they are the same.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: OpVectorShuffle\n" +
+ "; CHECK: OpVectorShuffle {{%\\w+}} %7 %5 2 0 5 7\n" +
+ "; CHECK: OpReturn\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpVariable %_ptr_v4double Function\n" +
+ "%3 = OpVariable %_ptr_v4double Function\n" +
+ "%4 = OpVariable %_ptr_v4double Function\n" +
+ "%5 = OpLoad %v4double %2\n" +
+ "%6 = OpLoad %v4double %3\n" +
+ "%7 = OpLoad %v4double %4\n" +
+ "%8 = OpVectorShuffle %v4double %5 %5 2 3 4 5\n" +
+ "%9 = OpVectorShuffle %v4double %7 %8 2 0 7 5\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 9, true)
+));
#endif
} // namespace