Fix binding number calculation in desc sroa (#4095)
When there is an array of strutured buffers, desc sroa will only split
the array, but not a struct type in the structured buffer. However,
the calcualtion of the number of binding a struct requires does not take
this into consideration. This commit will fix that.
diff --git a/source/opt/desc_sroa.cpp b/source/opt/desc_sroa.cpp
index b68549a..5e95006 100644
--- a/source/opt/desc_sroa.cpp
+++ b/source/opt/desc_sroa.cpp
@@ -63,16 +63,7 @@
// All structures with descriptor assignments must be replaced by variables,
// one for each of their members - with the exceptions of buffers.
- // Buffers are represented as structures, but we shouldn't replace a buffer
- // with its elements. All buffers have offset decorations for members of their
- // structure types.
- bool has_offset_decoration = false;
- context()->get_decoration_mgr()->ForEachDecoration(
- var_type_inst->result_id(), SpvDecorationOffset,
- [&has_offset_decoration](const Instruction&) {
- has_offset_decoration = true;
- });
- if (has_offset_decoration) {
+ if (IsTypeOfStructuredBuffer(var_type_inst)) {
return false;
}
@@ -99,6 +90,23 @@
return true;
}
+bool DescriptorScalarReplacement::IsTypeOfStructuredBuffer(
+ const Instruction* type) const {
+ if (type->opcode() != SpvOpTypeStruct) {
+ return false;
+ }
+
+ // All buffers have offset decorations for members of their structure types.
+ // This is how we distinguish it from a structure of descriptors.
+ bool has_offset_decoration = false;
+ context()->get_decoration_mgr()->ForEachDecoration(
+ type->result_id(), SpvDecorationOffset,
+ [&has_offset_decoration](const Instruction&) {
+ has_offset_decoration = true;
+ });
+ return has_offset_decoration;
+}
+
bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
std::vector<Instruction*> access_chain_work_list;
std::vector<Instruction*> load_work_list;
@@ -368,7 +376,8 @@
// The number of bindings consumed by a structure is the sum of the bindings
// used by its members.
- if (type_inst->opcode() == SpvOpTypeStruct) {
+ if (type_inst->opcode() == SpvOpTypeStruct &&
+ !IsTypeOfStructuredBuffer(type_inst)) {
uint32_t sum = 0;
for (uint32_t i = 0; i < type_inst->NumInOperands(); i++)
sum += GetNumBindingsUsedByType(type_inst->GetSingleWordInOperand(i));
diff --git a/source/opt/desc_sroa.h b/source/opt/desc_sroa.h
index c3aa0ea..cd72fd3 100644
--- a/source/opt/desc_sroa.h
+++ b/source/opt/desc_sroa.h
@@ -93,6 +93,11 @@
// bindings used by its members.
uint32_t GetNumBindingsUsedByType(uint32_t type_id);
+ // Returns true if |type| is a type that could be used for a structured buffer
+ // as opposed to a type that would be used for a structure of resource
+ // descriptors.
+ bool IsTypeOfStructuredBuffer(const Instruction* type) const;
+
// A map from an OpVariable instruction to the set of variables that will be
// used to replace it. The entry |replacement_variables_[var][i]| is the id of
// a variable that will be used in the place of the the ith element of the
diff --git a/test/opt/desc_sroa_test.cpp b/test/opt/desc_sroa_test.cpp
index cdcc9a8..b35ad47 100644
--- a/test/opt/desc_sroa_test.cpp
+++ b/test/opt/desc_sroa_test.cpp
@@ -729,6 +729,47 @@
SinglePassRunAndMatch<DescriptorScalarReplacement>(checks + shader, true);
}
+TEST_F(DescriptorScalarReplacementTest, BindingForResourceArrayOfStructs) {
+ // Check that correct binding numbers are given to an array of descriptors
+ // to structs.
+
+ const std::string shader = R"(
+; CHECK: OpDecorate {{%\w+}} Binding 0
+; CHECK: OpDecorate {{%\w+}} Binding 1
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "psmain"
+ OpExecutionMode %2 OriginUpperLeft
+ OpDecorate %5 DescriptorSet 0
+ OpDecorate %5 Binding 0
+ OpMemberDecorate %_struct_4 0 Offset 0
+ OpMemberDecorate %_struct_4 1 Offset 4
+ OpDecorate %_struct_4 Block
+ %float = OpTypeFloat 32
+ %int = OpTypeInt 32 1
+ %int_0 = OpConstant %int 0
+ %int_1 = OpConstant %int 1
+ %uint = OpTypeInt 32 0
+ %uint_2 = OpConstant %uint 2
+ %_struct_4 = OpTypeStruct %float %int
+%_arr__struct_4_uint_2 = OpTypeArray %_struct_4 %uint_2
+%_ptr_Uniform__arr__struct_4_uint_2 = OpTypePointer Uniform %_arr__struct_4_uint_2
+ %void = OpTypeVoid
+ %25 = OpTypeFunction %void
+%_ptr_Uniform_int = OpTypePointer Uniform %int
+ %5 = OpVariable %_ptr_Uniform__arr__struct_4_uint_2 Uniform
+ %2 = OpFunction %void None %25
+ %29 = OpLabel
+ %40 = OpAccessChain %_ptr_Uniform_int %5 %int_0 %int_1
+ %41 = OpAccessChain %_ptr_Uniform_int %5 %int_1 %int_1
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<DescriptorScalarReplacement>(shader, true);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools