Case validation with repeated labels (#2689)
Fixes #2686
* Update validation to handle the default case being mentioned multiple
times
* new tests
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index 5a5082f..8d8839e 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -493,41 +493,54 @@
std::map<uint32_t, uint32_t> num_fall_through_targeted;
uint32_t default_case_fall_through = 0u;
uint32_t default_target = switch_inst->GetOperandAs<uint32_t>(1u);
- std::unordered_set<uint32_t> seen;
+ bool default_appears_multiple_times = false;
+ for (uint32_t i = 3; i < switch_inst->operands().size(); i += 2) {
+ if (default_target == switch_inst->GetOperandAs<uint32_t>(i)) {
+ default_appears_multiple_times = true;
+ break;
+ }
+ }
+ std::unordered_map<uint32_t, uint32_t> seen_to_fall_through;
for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) {
uint32_t target = switch_inst->GetOperandAs<uint32_t>(i);
if (target == merge->id()) continue;
- if (!seen.insert(target).second) continue;
-
- const auto target_block = function->GetBlock(target).first;
- // OpSwitch must dominate all its case constructs.
- if (header->reachable() && target_block->reachable() &&
- !header->dominates(*target_block)) {
- return _.diag(SPV_ERROR_INVALID_CFG, header->label())
- << "Selection header " << _.getIdName(header->id())
- << " does not dominate its case construct " << _.getIdName(target);
- }
-
uint32_t case_fall_through = 0u;
- if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through,
- merge, case_targets, function)) {
- return error;
- }
-
- // Track how many time the fall through case has been targeted.
- if (case_fall_through != 0u) {
- auto where = num_fall_through_targeted.lower_bound(case_fall_through);
- if (where == num_fall_through_targeted.end() ||
- where->first != case_fall_through) {
- num_fall_through_targeted.insert(where,
- std::make_pair(case_fall_through, 1));
- } else {
- where->second++;
+ auto seen_iter = seen_to_fall_through.find(target);
+ if (seen_iter == seen_to_fall_through.end()) {
+ const auto target_block = function->GetBlock(target).first;
+ // OpSwitch must dominate all its case constructs.
+ if (header->reachable() && target_block->reachable() &&
+ !header->dominates(*target_block)) {
+ return _.diag(SPV_ERROR_INVALID_CFG, header->label())
+ << "Selection header " << _.getIdName(header->id())
+ << " does not dominate its case construct "
+ << _.getIdName(target);
}
+
+ if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through,
+ merge, case_targets, function)) {
+ return error;
+ }
+
+ // Track how many time the fall through case has been targeted.
+ if (case_fall_through != 0u) {
+ auto where = num_fall_through_targeted.lower_bound(case_fall_through);
+ if (where == num_fall_through_targeted.end() ||
+ where->first != case_fall_through) {
+ num_fall_through_targeted.insert(
+ where, std::make_pair(case_fall_through, 1));
+ } else {
+ where->second++;
+ }
+ }
+ seen_to_fall_through.insert(std::make_pair(target, case_fall_through));
+ } else {
+ case_fall_through = seen_iter->second;
}
- if (case_fall_through == default_target) {
+ if (case_fall_through == default_target &&
+ !default_appears_multiple_times) {
case_fall_through = default_case_fall_through;
}
if (case_fall_through != 0u) {
diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp
index 0a9357b..b1857cf 100644
--- a/test/val/val_cfg_test.cpp
+++ b/test/val/val_cfg_test.cpp
@@ -2623,6 +2623,134 @@
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateCFG, SwitchCaseOrderingBad1) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %default "default"
+OpName %other "other"
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%undef = OpUndef %int
+%void_fn = OpTypeFunction %void
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSelectionMerge %merge None
+OpSwitch %undef %default 0 %other 1 %default
+%default = OpLabel
+OpBranch %other
+%other = OpLabel
+OpBranch %merge
+%merge = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(text);
+ EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Case construct that targets 1[%default] has branches to the "
+ "case construct that targets 2[%other], but does not "
+ "immediately precede it in the OpSwitch's target list"));
+}
+
+TEST_F(ValidateCFG, SwitchCaseOrderingBad2) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %default "default"
+OpName %other "other"
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%undef = OpUndef %int
+%void_fn = OpTypeFunction %void
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSelectionMerge %merge None
+OpSwitch %undef %default 0 %default 1 %other
+%other = OpLabel
+OpBranch %default
+%default = OpLabel
+OpBranch %merge
+%merge = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(text);
+ EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Case construct that targets 2[%other] has branches to the "
+ "case construct that targets 1[%default], but does not "
+ "immediately precede it in the OpSwitch's target list"));
+}
+
+TEST_F(ValidateCFG, SwitchMultipleDefaultWithFallThroughGood) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %first "first"
+OpName %second "second"
+OpName %third "third"
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%undef = OpUndef %int
+%void_fn = OpTypeFunction %void
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSelectionMerge %merge None
+OpSwitch %undef %second 0 %first 1 %second 2 %third
+%first = OpLabel
+OpBranch %second
+%second = OpLabel
+OpBranch %third
+%third = OpLabel
+OpBranch %merge
+%merge = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(text);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, SwitchMultipleDefaultWithFallThroughBad) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %first "first"
+OpName %second "second"
+OpName %third "third"
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%undef = OpUndef %int
+%void_fn = OpTypeFunction %void
+%func = OpFunction %void None %void_fn
+%entry = OpLabel
+OpSelectionMerge %merge None
+OpSwitch %undef %second 0 %second 1 %first 2 %third
+%first = OpLabel
+OpBranch %second
+%second = OpLabel
+OpBranch %third
+%third = OpLabel
+OpBranch %merge
+%merge = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(text);
+ EXPECT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+}
+
TEST_F(ValidateCFG, GoodUnreachableSelection) {
const std::string text = R"(
OpCapability Shader