spirv-fuzz: Fix flatten conditional branch transformation (#3859)
Fixes #3850.
diff --git a/source/fuzz/transformation_flatten_conditional_branch.cpp b/source/fuzz/transformation_flatten_conditional_branch.cpp
index bd814b7..9dead67 100644
--- a/source/fuzz/transformation_flatten_conditional_branch.cpp
+++ b/source/fuzz/transformation_flatten_conditional_branch.cpp
@@ -280,6 +280,12 @@
uint32_t first_block_last_branch_id =
branch_instruction->GetSingleWordInOperand(branches[0]);
+ // Record the block that will be reached if the branch condition is true.
+ // This information is needed later to determine how to rewrite OpPhi
+ // instructions as OpSelect instructions at the branch's convergence point.
+ uint32_t branch_instruction_true_block_id =
+ branch_instruction->GetSingleWordInOperand(1);
+
// The current header should unconditionally branch to the starting block in
// the first branch to be laid out, if such a branch exists (i.e. the header
// does not branch directly to the convergence block), and to the starting
@@ -333,16 +339,38 @@
// with OpSelect.
ir_context->get_instr_block(convergence_block_id)
- ->ForEachPhiInst([&condition_operand](opt::Instruction* phi_inst) {
- phi_inst->SetOpcode(SpvOpSelect);
+ ->ForEachPhiInst([branch_instruction_true_block_id, &condition_operand,
+ header_block,
+ ir_context](opt::Instruction* phi_inst) {
+ assert(phi_inst->NumInOperands() == 4 &&
+ "We are going to replace an OpPhi with an OpSelect. This "
+ "only makes sense if the block has two distinct "
+ "predecessors.");
+ // The OpPhi takes values from two distinct predecessors. One
+ // predecessor is associated with the "true" path of the conditional
+ // we are flattening, the other with the "false" path, but these
+ // predecessors can appear in either order as operands to the OpPhi
+ // instruction.
+
std::vector<opt::Operand> operands;
operands.emplace_back(condition_operand);
- // Only consider the operands referring to the instructions ids, as
- // the block labels are not necessary anymore.
- for (uint32_t i = 0; i < phi_inst->NumInOperands(); i += 2) {
- operands.emplace_back(phi_inst->GetInOperand(i));
- }
+ if (ir_context->GetDominatorAnalysis(header_block->GetParent())
+ ->Dominates(branch_instruction_true_block_id,
+ phi_inst->GetSingleWordInOperand(1))) {
+ // The "true" branch is handled first in the OpPhi's operands; we
+ // thus provide operands to OpSelect in the same order that they
+ // appear in the OpPhi.
+ operands.emplace_back(phi_inst->GetInOperand(0));
+ operands.emplace_back(phi_inst->GetInOperand(2));
+ } else {
+ // The "false" branch is handled first in the OpPhi's operands; we
+ // thus provide operands to OpSelect in reverse of the order that
+ // they appear in the OpPhi.
+ operands.emplace_back(phi_inst->GetInOperand(2));
+ operands.emplace_back(phi_inst->GetInOperand(0));
+ }
+ phi_inst->SetOpcode(SpvOpSelect);
phi_inst->SetInOperands(std::move(operands));
});
}
diff --git a/test/fuzz/transformation_flatten_conditional_branch_test.cpp b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
index f36ecbe..d77173d 100644
--- a/test/fuzz/transformation_flatten_conditional_branch_test.cpp
+++ b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
@@ -791,6 +791,274 @@
ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
}
+TEST(TransformationFlattenConditionalBranchTest, PhiToSelect1) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main"
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource ESSL 310
+ %3 = OpTypeVoid
+ %4 = OpTypeBool
+ %5 = OpConstantTrue %4
+ %10 = OpConstantFalse %4
+ %6 = OpTypeFunction %3
+ %2 = OpFunction %3 None %6
+ %7 = OpLabel
+ OpSelectionMerge %8 None
+ OpBranchConditional %5 %9 %8
+ %9 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %11 = OpPhi %4 %5 %9 %10 %7
+ OpReturn
+ OpFunctionEnd
+)";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options);
+
+ auto transformation = TransformationFlattenConditionalBranch(7, true, {});
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+ ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ std::string after_transformation = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main"
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource ESSL 310
+ %3 = OpTypeVoid
+ %4 = OpTypeBool
+ %5 = OpConstantTrue %4
+ %10 = OpConstantFalse %4
+ %6 = OpTypeFunction %3
+ %2 = OpFunction %3 None %6
+ %7 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %11 = OpSelect %4 %5 %5 %10
+ OpReturn
+ OpFunctionEnd
+)";
+ ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
+TEST(TransformationFlattenConditionalBranchTest, PhiToSelect2) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main"
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource ESSL 310
+ %3 = OpTypeVoid
+ %4 = OpTypeBool
+ %5 = OpConstantTrue %4
+ %10 = OpConstantFalse %4
+ %6 = OpTypeFunction %3
+ %2 = OpFunction %3 None %6
+ %7 = OpLabel
+ OpSelectionMerge %8 None
+ OpBranchConditional %5 %9 %8
+ %9 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %11 = OpPhi %4 %10 %7 %5 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options);
+
+ auto transformation = TransformationFlattenConditionalBranch(7, true, {});
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+ ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ std::string after_transformation = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main"
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource ESSL 310
+ %3 = OpTypeVoid
+ %4 = OpTypeBool
+ %5 = OpConstantTrue %4
+ %10 = OpConstantFalse %4
+ %6 = OpTypeFunction %3
+ %2 = OpFunction %3 None %6
+ %7 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %11 = OpSelect %4 %5 %5 %10
+ OpReturn
+ OpFunctionEnd
+)";
+ ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
+TEST(TransformationFlattenConditionalBranchTest, PhiToSelect3) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main"
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource ESSL 310
+ %3 = OpTypeVoid
+ %4 = OpTypeBool
+ %5 = OpConstantTrue %4
+ %10 = OpConstantFalse %4
+ %6 = OpTypeFunction %3
+ %2 = OpFunction %3 None %6
+ %7 = OpLabel
+ OpSelectionMerge %8 None
+ OpBranchConditional %5 %9 %12
+ %9 = OpLabel
+ OpBranch %8
+ %12 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %11 = OpPhi %4 %10 %12 %5 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options);
+
+ auto transformation = TransformationFlattenConditionalBranch(7, true, {});
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+ ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ std::string after_transformation = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main"
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource ESSL 310
+ %3 = OpTypeVoid
+ %4 = OpTypeBool
+ %5 = OpConstantTrue %4
+ %10 = OpConstantFalse %4
+ %6 = OpTypeFunction %3
+ %2 = OpFunction %3 None %6
+ %7 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ OpBranch %12
+ %12 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %11 = OpSelect %4 %5 %5 %10
+ OpReturn
+ OpFunctionEnd
+)";
+ ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
+TEST(TransformationFlattenConditionalBranchTest, PhiToSelect4) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main"
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource ESSL 310
+ %3 = OpTypeVoid
+ %4 = OpTypeBool
+ %5 = OpConstantTrue %4
+ %10 = OpConstantFalse %4
+ %6 = OpTypeFunction %3
+ %2 = OpFunction %3 None %6
+ %7 = OpLabel
+ OpSelectionMerge %8 None
+ OpBranchConditional %5 %9 %12
+ %9 = OpLabel
+ OpBranch %8
+ %12 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %11 = OpPhi %4 %5 %9 %10 %12
+ OpReturn
+ OpFunctionEnd
+)";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options);
+
+ auto transformation = TransformationFlattenConditionalBranch(7, true, {});
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+ ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ std::string after_transformation = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main"
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource ESSL 310
+ %3 = OpTypeVoid
+ %4 = OpTypeBool
+ %5 = OpConstantTrue %4
+ %10 = OpConstantFalse %4
+ %6 = OpTypeFunction %3
+ %2 = OpFunction %3 None %6
+ %7 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ OpBranch %12
+ %12 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %11 = OpSelect %4 %5 %5 %10
+ OpReturn
+ OpFunctionEnd
+)";
+ ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
} // namespace
} // namespace fuzz
} // namespace spvtools