Don't scalarize spec constant sized arrays Fixes #1952 * Prevent scalarization of arrays that are sized by a specialization constant
diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp index d51dd8e..66a97bd 100644 --- a/source/opt/scalar_replacement_pass.cpp +++ b/source/opt/scalar_replacement_pass.cpp
@@ -500,6 +500,12 @@ return len; } +bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const { + const Instruction* inst = get_def_use_mgr()->GetDef(id); + assert(inst); + return spvOpcodeIsSpecConstant(inst->opcode()); +} + Instruction* ScalarReplacementPass::GetStorageType( const Instruction* inst) const { assert(inst->opcode() == SpvOpVariable); @@ -536,7 +542,12 @@ return false; return true; case SpvOpTypeArray: - if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) return false; + if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) { + return false; + } + if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) { + return false; + } return true; // TODO(alanbaker): Develop some heuristics for when this should be // re-enabled.
diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h index c89bbc4..a82abf4 100644 --- a/source/opt/scalar_replacement_pass.h +++ b/source/opt/scalar_replacement_pass.h
@@ -169,6 +169,11 @@ // |type| must be a vector or matrix type. size_t GetNumElements(const Instruction* type) const; + // Returns true if |id| is a specialization constant. + // + // |id| must be registered definition. + bool IsSpecConstant(uint32_t id) const; + // Returns an id for a pointer to |id|. uint32_t GetOrCreatePointerType(uint32_t id);
diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp index 8415ef0..7058cc5 100644 --- a/test/opt/scalar_replacement_test.cpp +++ b/test/opt/scalar_replacement_test.cpp
@@ -1402,6 +1402,42 @@ SinglePassRunAndMatch<ScalarReplacementPass>(text, true); } +TEST_F(ScalarReplacementTest, SpecConstantArray) { + const std::string text = R"( +; CHECK: [[int:%\w+]] = OpTypeInt +; CHECK: [[spec_const:%\w+]] = OpSpecConstant [[int]] 4 +; CHECK: [[spec_op:%\w+]] = OpSpecConstantOp [[int]] IAdd [[spec_const]] [[spec_const]] +; CHECK: [[array1:%\w+]] = OpTypeArray [[int]] [[spec_const]] +; CHECK: [[array2:%\w+]] = OpTypeArray [[int]] [[spec_op]] +; CHECK: [[ptr_array1:%\w+]] = OpTypePointer Function [[array1]] +; CHECK: [[ptr_array2:%\w+]] = OpTypePointer Function [[array2]] +; CHECK: OpLabel +; CHECK-NEXT: OpVariable [[ptr_array1]] Function +; CHECK-NEXT: OpVariable [[ptr_array2]] Function +; CHECK-NOT: OpVariable +OpCapability Shader +OpCapability Linkage +OpMemoryModel Logical GLSL450 +%void = OpTypeVoid +%void_fn = OpTypeFunction %void +%int = OpTypeInt 32 0 +%spec_const = OpSpecConstant %int 4 +%spec_op = OpSpecConstantOp %int IAdd %spec_const %spec_const +%array_1 = OpTypeArray %int %spec_const +%array_2 = OpTypeArray %int %spec_op +%ptr_array_1_Function = OpTypePointer Function %array_1 +%ptr_array_2_Function = OpTypePointer Function %array_2 +%func = OpFunction %void None %void_fn +%1 = OpLabel +%var_1 = OpVariable %ptr_array_1_Function Function +%var_2 = OpVariable %ptr_array_2_Function Function +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndMatch<ScalarReplacementPass>(text, true); +} + TEST_F(ScalarReplacementTest, CreateAmbiguousNullConstant2) { const std::string text = R"( ;