Don't scalarize spec constant sized arrays
Fixes #1952
* Prevent scalarization of arrays that are sized by a specialization
constant
diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp
index d51dd8e..66a97bd 100644
--- a/source/opt/scalar_replacement_pass.cpp
+++ b/source/opt/scalar_replacement_pass.cpp
@@ -500,6 +500,12 @@
return len;
}
+bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
+ const Instruction* inst = get_def_use_mgr()->GetDef(id);
+ assert(inst);
+ return spvOpcodeIsSpecConstant(inst->opcode());
+}
+
Instruction* ScalarReplacementPass::GetStorageType(
const Instruction* inst) const {
assert(inst->opcode() == SpvOpVariable);
@@ -536,7 +542,12 @@
return false;
return true;
case SpvOpTypeArray:
- if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) return false;
+ if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
+ return false;
+ }
+ if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
+ return false;
+ }
return true;
// TODO(alanbaker): Develop some heuristics for when this should be
// re-enabled.
diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h
index c89bbc4..a82abf4 100644
--- a/source/opt/scalar_replacement_pass.h
+++ b/source/opt/scalar_replacement_pass.h
@@ -169,6 +169,11 @@
// |type| must be a vector or matrix type.
size_t GetNumElements(const Instruction* type) const;
+ // Returns true if |id| is a specialization constant.
+ //
+ // |id| must be registered definition.
+ bool IsSpecConstant(uint32_t id) const;
+
// Returns an id for a pointer to |id|.
uint32_t GetOrCreatePointerType(uint32_t id);
diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp
index 8415ef0..7058cc5 100644
--- a/test/opt/scalar_replacement_test.cpp
+++ b/test/opt/scalar_replacement_test.cpp
@@ -1402,6 +1402,42 @@
SinglePassRunAndMatch<ScalarReplacementPass>(text, true);
}
+TEST_F(ScalarReplacementTest, SpecConstantArray) {
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt
+; CHECK: [[spec_const:%\w+]] = OpSpecConstant [[int]] 4
+; CHECK: [[spec_op:%\w+]] = OpSpecConstantOp [[int]] IAdd [[spec_const]] [[spec_const]]
+; CHECK: [[array1:%\w+]] = OpTypeArray [[int]] [[spec_const]]
+; CHECK: [[array2:%\w+]] = OpTypeArray [[int]] [[spec_op]]
+; CHECK: [[ptr_array1:%\w+]] = OpTypePointer Function [[array1]]
+; CHECK: [[ptr_array2:%\w+]] = OpTypePointer Function [[array2]]
+; CHECK: OpLabel
+; CHECK-NEXT: OpVariable [[ptr_array1]] Function
+; CHECK-NEXT: OpVariable [[ptr_array2]] Function
+; CHECK-NOT: OpVariable
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%int = OpTypeInt 32 0
+%spec_const = OpSpecConstant %int 4
+%spec_op = OpSpecConstantOp %int IAdd %spec_const %spec_const
+%array_1 = OpTypeArray %int %spec_const
+%array_2 = OpTypeArray %int %spec_op
+%ptr_array_1_Function = OpTypePointer Function %array_1
+%ptr_array_2_Function = OpTypePointer Function %array_2
+%func = OpFunction %void None %void_fn
+%1 = OpLabel
+%var_1 = OpVariable %ptr_array_1_Function Function
+%var_2 = OpVariable %ptr_array_2_Function Function
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<ScalarReplacementPass>(text, true);
+}
+
TEST_F(ScalarReplacementTest, CreateAmbiguousNullConstant2) {
const std::string text = R"(
;