Add more folding for composite instructions (#4802)
* Add move folding for composite instructions
Fold chains of insert into construct
If a chain of OpCompositeInsert instruction write to every element of a
composite object, then we can replace it with an OpCompositeConstruct.
Fold a construct fed by extracts to a single extract
We already fold an OpCompositeConstruct when it is simlpy reconstructing
an object that was decomposed by a series of OpCompositeExtract
instructions. However, we do not do that if that object is an element
of a larger object.
I have updated the rule, so that if the original object is a an element
of a larger object, then the OpCompositeConstruct is replaced with a
single OpCompositeExtract from the larger object.
Fixes #4371.
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index d15ad04..ab7a20e 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -1631,6 +1631,57 @@
return true;
}
+// Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
+// OpCompositeExtract instruction, and returns the type of the final element
+// being accessed.
+const analysis::Type* GetElementType(uint32_t type_id,
+ Instruction::iterator start,
+ Instruction::iterator end,
+ const analysis::TypeManager* type_mgr) {
+ const analysis::Type* type = type_mgr->GetType(type_id);
+ for (auto index : make_range(std::move(start), std::move(end))) {
+ assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
+ index.words.size() == 1);
+ if (auto* array_type = type->AsArray()) {
+ type = array_type->element_type();
+ } else if (auto* matrix_type = type->AsMatrix()) {
+ type = matrix_type->element_type();
+ } else if (auto* struct_type = type->AsStruct()) {
+ type = struct_type->element_types()[index.words[0]];
+ } else {
+ type = nullptr;
+ }
+ }
+ return type;
+}
+
+// Returns true of |inst_1| and |inst_2| have the same indexes that will be used
+// to index into a composite object, excluding the last index. The two
+// instructions must have the same opcode, and be either OpCompositeExtract or
+// OpCompositeInsert instructions.
+bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
+ assert(inst_1->opcode() == inst_2->opcode() &&
+ "Expecting the opcodes to be the same.");
+ assert((inst_1->opcode() == SpvOpCompositeInsert ||
+ inst_1->opcode() == SpvOpCompositeExtract) &&
+ "Instructions must be OpCompositeInsert or OpCompositeExtract.");
+
+ if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
+ return false;
+ }
+
+ uint32_t first_index_position =
+ (inst_1->opcode() == SpvOpCompositeInsert ? 2 : 1);
+ for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
+ i++) {
+ if (inst_1->GetSingleWordInOperand(i) !=
+ inst_2->GetSingleWordInOperand(i)) {
+ return false;
+ }
+ }
+ return true;
+}
+
// If the OpCompositeConstruct is simply putting back together elements that
// where extracted from the same source, we can simply reuse the source.
//
@@ -1653,19 +1704,24 @@
// - extractions
// - extracting the same position they are inserting
// - all extract from the same id.
+ Instruction* first_element_inst = nullptr;
for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
const uint32_t element_id = inst->GetSingleWordInOperand(i);
Instruction* element_inst = def_use_mgr->GetDef(element_id);
+ if (first_element_inst == nullptr) {
+ first_element_inst = element_inst;
+ }
if (element_inst->opcode() != SpvOpCompositeExtract) {
return false;
}
- if (element_inst->NumInOperands() != 2) {
+ if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
return false;
}
- if (element_inst->GetSingleWordInOperand(1) != i) {
+ if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
+ 1) != i) {
return false;
}
@@ -1681,13 +1737,31 @@
// The last check it to see that the object being extracted from is the
// correct type.
Instruction* original_inst = def_use_mgr->GetDef(original_id);
- if (original_inst->type_id() != inst->type_id()) {
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
+ const analysis::Type* original_type =
+ GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
+ first_element_inst->end() - 1, type_mgr);
+
+ if (original_type == nullptr) {
return false;
}
- // Simplify by using the original object.
- inst->SetOpcode(SpvOpCopyObject);
- inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
+ if (inst->type_id() != type_mgr->GetId(original_type)) {
+ return false;
+ }
+
+ if (first_element_inst->NumInOperands() == 2) {
+ // Simplify by using the original object.
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
+ return true;
+ }
+
+ // Copies the original id and all indexes except for the last to the new
+ // extract instruction.
+ inst->SetOpcode(SpvOpCompositeExtract);
+ inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
+ first_element_inst->end() - 1));
return true;
}
@@ -1891,6 +1965,139 @@
};
}
+// Returns the number of elements in the composite type |type|. Returns 0 if
+// |type| is a scalar value.
+uint32_t GetNumberOfElements(const analysis::Type* type) {
+ if (auto* vector_type = type->AsVector()) {
+ return vector_type->element_count();
+ }
+ if (auto* matrix_type = type->AsMatrix()) {
+ return matrix_type->element_count();
+ }
+ if (auto* struct_type = type->AsStruct()) {
+ return static_cast<uint32_t>(struct_type->element_types().size());
+ }
+ if (auto* array_type = type->AsArray()) {
+ return array_type->length_info().words[0];
+ }
+ return 0;
+}
+
+// Returns a map with the set of values that were inserted into an object by
+// the chain of OpCompositeInsertInstruction starting with |inst|.
+// The map will map the index to the value inserted at that index.
+std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
+ analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
+ std::map<uint32_t, uint32_t> values_inserted;
+ Instruction* current_inst = inst;
+ while (current_inst->opcode() == SpvOpCompositeInsert) {
+ if (current_inst->NumInOperands() > inst->NumInOperands()) {
+ // This is the catch the case
+ // %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
+ // %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
+ // %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
+ // In this case we cannot do a single construct to get the matrix.
+ uint32_t partially_inserted_element_index =
+ current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
+ if (values_inserted.count(partially_inserted_element_index) == 0)
+ return {};
+ }
+ if (HaveSameIndexesExceptForLast(inst, current_inst)) {
+ values_inserted.insert(
+ {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
+ 1),
+ current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
+ }
+ current_inst = def_use_mgr->GetDef(
+ current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
+ }
+ return values_inserted;
+}
+
+// Returns true of there is an entry in |values_inserted| for every element of
+// |Type|.
+bool DoInsertedValuesCoverEntireObject(
+ const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
+ uint32_t container_size = GetNumberOfElements(type);
+ if (container_size != values_inserted.size()) {
+ return false;
+ }
+
+ if (values_inserted.rbegin()->first >= container_size) {
+ return false;
+ }
+ return true;
+}
+
+// Returns the type of the element that immediately contains the element being
+// inserted by the OpCompositeInsert instruction |inst|.
+const analysis::Type* GetContainerType(Instruction* inst) {
+ assert(inst->opcode() == SpvOpCompositeInsert);
+ analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
+ return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1,
+ type_mgr);
+}
+
+// Returns an OpCompositeConstruct instruction that build an object with
+// |type_id| out of the values in |values_inserted|. Each value will be
+// placed at the index corresponding to the value. The new instruction will
+// be placed before |insert_before|.
+Instruction* BuildCompositeConstruct(
+ uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
+ Instruction* insert_before) {
+ InstructionBuilder ir_builder(
+ insert_before->context(), insert_before,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+ std::vector<uint32_t> ids_in_order;
+ for (auto it : values_inserted) {
+ ids_in_order.push_back(it.second);
+ }
+ Instruction* construct =
+ ir_builder.AddCompositeConstruct(type_id, ids_in_order);
+ return construct;
+}
+
+// Replaces the OpCompositeInsert |inst| that inserts |construct| into the same
+// object as |inst| with final index removed. If the resulting
+// OpCompositeInsert instruction would have no remaining indexes, the
+// instruction is replaced with an OpCopyObject instead.
+void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
+ if (inst->NumInOperands() == 3) {
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
+ } else {
+ inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
+ inst->RemoveOperand(inst->NumOperands() - 1);
+ }
+}
+
+// Replaces a series of |OpCompositeInsert| instruction that cover the entire
+// object with an |OpCompositeConstruct|.
+bool CompositeInsertToCompositeConstruct(
+ IRContext* context, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ assert(inst->opcode() == SpvOpCompositeInsert &&
+ "Wrong opcode. Should be OpCompositeInsert.");
+ if (inst->NumInOperands() < 3) return false;
+
+ std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
+ const analysis::Type* container_type = GetContainerType(inst);
+ if (container_type == nullptr) {
+ return false;
+ }
+
+ if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
+ return false;
+ }
+
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
+ Instruction* construct = BuildCompositeConstruct(
+ type_mgr->GetId(container_type), values_inserted, inst);
+ InsertConstructedObject(inst, construct);
+ return true;
+}
+
FoldingRule RedundantPhi() {
// An OpPhi instruction where all values are the same or the result of the phi
// itself, can be replaced by the value itself.
@@ -2591,6 +2798,8 @@
rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
+ rules_[SpvOpCompositeInsert].push_back(CompositeInsertToCompositeConstruct);
+
rules_[SpvOpDot].push_back(DotProductDoingExtract());
rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 2ca3256..e2240b8 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -147,6 +147,7 @@
%v2double = OpTypeVector %double 2
%v2half = OpTypeVector %half 2
%v2bool = OpTypeVector %bool 2
+%m2x2int = OpTypeMatrix %v2int 2
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
%_ptr_int = OpTypePointer Function %int
%_ptr_uint = OpTypePointer Function %uint
@@ -218,7 +219,9 @@
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
%v2int_null = OpConstantNull %v2int
%102 = OpConstantComposite %v2int %103 %103
+%v4int_undef = OpUndef %v4int
%v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
+%m2x2int_undef = OpUndef %m2x2int
%struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0
%float_n1 = OpConstant %float -1
%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
@@ -6862,7 +6865,7 @@
4, true)
));
-INSTANTIATE_TEST_SUITE_P(CompositeExtractMatchingTest, MatchingInstructionFoldingTest,
+INSTANTIATE_TEST_SUITE_P(CompositeExtractOrInsertMatchingTest, MatchingInstructionFoldingTest,
::testing::Values(
// Test case 0: Extracting from result of consecutive shuffles of differing
// size.
@@ -7002,7 +7005,145 @@
"%4 = OpCompositeExtract %int %3 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
- 4, true)
+ 4, true),
+ // Test case 8: Inserting every element of a vector turns into a composite construct.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
+ "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+ "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+ "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
+ "; CHECK: %5 = OpCopyObject [[v4]] [[construct]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
+ "%3 = OpCompositeInsert %v4int %int_1 %2 1\n" +
+ "%4 = OpCompositeInsert %v4int %int_2 %3 2\n" +
+ "%5 = OpCompositeInsert %v4int %int_3 %4 3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 5, true),
+ // Test case 9: Inserting every element of a vector turns into a composite construct in a different order.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
+ "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+ "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+ "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
+ "; CHECK: %5 = OpCopyObject [[v4]] [[construct]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
+ "%4 = OpCompositeInsert %v4int %int_2 %2 2\n" +
+ "%3 = OpCompositeInsert %v4int %int_1 %4 1\n" +
+ "%5 = OpCompositeInsert %v4int %int_3 %3 3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 5, true),
+ // Test case 10: Check multiple inserts to the same position are handled correctly.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
+ "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+ "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+ "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
+ "; CHECK: %6 = OpCopyObject [[v4]] [[construct]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
+ "%3 = OpCompositeInsert %v4int %int_2 %2 2\n" +
+ "%4 = OpCompositeInsert %v4int %int_4 %3 1\n" +
+ "%5 = OpCompositeInsert %v4int %int_1 %4 1\n" +
+ "%6 = OpCompositeInsert %v4int %int_3 %5 3\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 6, true),
+ // Test case 11: The last indexes are 0 and 1, but they have different first indexes. This should not be folded.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %m2x2int %100 %m2x2int_undef 0 0\n" +
+ "%3 = OpCompositeInsert %m2x2int %int_1 %2 1 1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 3, false),
+ // Test case 12: Don't fold when there is a partial insertion.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0\n" +
+ "%3 = OpCompositeInsert %m2x2int %int_4 %2 0 0\n" +
+ "%4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 4, false),
+ // Test case 13: Insert into a column of a matrix
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK-DAG: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
+ "; CHECK-DAG: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
+ "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+// We keep this insert in the chain. DeadInsertElimPass should remove it.
+ "; CHECK: [[insert:%\\w+]] = OpCompositeInsert [[m2x2]] %100 [[m2x2_undef]] 0 0\n" +
+ "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v2]] %100 [[int1]]\n" +
+ "; CHECK: %3 = OpCompositeInsert [[m2x2]] [[construct]] [[insert]] 0\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeInsert %m2x2int %100 %m2x2int_undef 0 0\n" +
+ "%3 = OpCompositeInsert %m2x2int %int_1 %2 0 1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 3, true),
+ // Test case 14: Insert all elements of the matrix.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK-DAG: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
+ "; CHECK-DAG: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
+ "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+ "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+ "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+ "; CHECK: [[c0:%\\w+]] = OpCompositeConstruct [[v2]] %100 [[int1]]\n" +
+ "; CHECK: [[c1:%\\w+]] = OpCompositeConstruct [[v2]] [[int2]] [[int3]]\n" +
+ "; CHECK: [[matrix:%\\w+]] = OpCompositeConstruct [[m2x2]] [[c0]] [[c1]]\n" +
+ "; CHECK: %5 = OpCopyObject [[m2x2]] [[matrix]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%2 = OpCompositeConstruct %v2int %100 %int_1\n" +
+ "%3 = OpCompositeInsert %m2x2int %2 %m2x2int_undef 0\n" +
+ "%4 = OpCompositeInsert %m2x2int %int_2 %3 1 0\n" +
+ "%5 = OpCompositeInsert %m2x2int %int_3 %4 1 1\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 5, true),
+ // Test case 15: Replace construct with extract when reconstructing a member
+ // of another object.
+ InstructionFoldingCase<bool>(
+ Header() +
+ "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+ "; CHECK: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
+ "; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
+ "; CHECK: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
+ "; CHECK: %5 = OpCompositeExtract [[v2]] [[m2x2_undef]]\n" +
+ "%main = OpFunction %void None %void_func\n" +
+ "%main_lab = OpLabel\n" +
+ "%3 = OpCompositeExtract %int %m2x2int_undef 1 0\n" +
+ "%4 = OpCompositeExtract %int %m2x2int_undef 1 1\n" +
+ "%5 = OpCompositeConstruct %v2int %3 %4\n" +
+ "OpReturn\n" +
+ "OpFunctionEnd",
+ 5, true)
));
INSTANTIATE_TEST_SUITE_P(DotProductMatchingTest, MatchingInstructionFoldingTest,