spirv-val: Emit an error when an OpSwitch target is not an OpLabel (#2298)
Fixes #1628.
* spirv-val: Emit an error when an OpBranch target is not an OpLabel
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index 8fe30a8..fe79dde 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -112,6 +112,19 @@
return SPV_SUCCESS;
}
+spv_result_t ValidateBranch(ValidationState_t& _, const Instruction* inst) {
+ // target operands must be OpLabel
+ const auto id = inst->GetOperandAs<uint32_t>(0);
+ const auto target = _.FindDef(id);
+ if (!target || SpvOpLabel != target->opcode()) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "'Target Label' operands for OpBranch must be the ID "
+ "of an OpLabel instruction";
+ }
+
+ return SPV_SUCCESS;
+}
+
spv_result_t ValidateBranchConditional(ValidationState_t& _,
const Instruction* inst) {
// num_operands is either 3 or 5 --- if 5, the last two need to be literal
@@ -155,6 +168,26 @@
return SPV_SUCCESS;
}
+spv_result_t ValidateSwitch(ValidationState_t& _, const Instruction* inst) {
+ const auto num_operands = inst->operands().size();
+ // At least two operands (selector, default), any more than that are
+ // literal/target.
+
+ // target operands must be OpLabel
+ for (size_t i = 2; i < num_operands; i += 2) {
+ // literal, id
+ const auto id = inst->GetOperandAs<uint32_t>(i + 1);
+ const auto target = _.FindDef(id);
+ if (!target || SpvOpLabel != target->opcode()) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "'Target Label' operands for OpSwitch must be IDs of an "
+ "OpLabel instruction";
+ }
+ }
+
+ return SPV_SUCCESS;
+}
+
spv_result_t ValidateReturnValue(ValidationState_t& _,
const Instruction* inst) {
const auto value_id = inst->GetOperandAs<uint32_t>(0);
@@ -764,12 +797,18 @@
case SpvOpPhi:
if (auto error = ValidatePhi(_, inst)) return error;
break;
+ case SpvOpBranch:
+ if (auto error = ValidateBranch(_, inst)) return error;
+ break;
case SpvOpBranchConditional:
if (auto error = ValidateBranchConditional(_, inst)) return error;
break;
case SpvOpReturnValue:
if (auto error = ValidateReturnValue(_, inst)) return error;
break;
+ case SpvOpSwitch:
+ if (auto error = ValidateSwitch(_, inst)) return error;
+ break;
default:
break;
}
diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp
index aed0a57..fae5702 100644
--- a/test/val/val_cfg_test.cpp
+++ b/test/val/val_cfg_test.cpp
@@ -493,13 +493,10 @@
str += "OpFunctionEnd\n";
CompileSuccessfully(str);
- ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
- EXPECT_THAT(
- getDiagnosticString(),
- MatchesRegex("Block\\(s\\) \\{11\\[%11\\]\\} are referenced but not "
- "defined in function .\\[%Main\\]\n %Main = OpFunction "
- "%void None %10\n"))
- << str;
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("'Target Label' operands for OpBranch must "
+ "be the ID of an OpLabel instruction"));
}
TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) {
@@ -2060,6 +2057,57 @@
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateCFG, SwitchTargetMustBeLabel) {
+ const std::string text = R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "foo"
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %void = OpTypeVoid
+ %5 = OpTypeFunction %void
+ %1 = OpFunction %void None %5
+ %6 = OpLabel
+ %7 = OpCopyObject %uint %uint_0
+ OpSelectionMerge %8 None
+ OpSwitch %uint_0 %8 0 %7
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(text);
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("'Target Label' operands for OpSwitch must "
+ "be IDs of an OpLabel instruction"));
+}
+
+TEST_F(ValidateCFG, BranchTargetMustBeLabel) {
+ const std::string text = R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "foo"
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %void = OpTypeVoid
+ %5 = OpTypeFunction %void
+ %1 = OpFunction %void None %5
+ %2 = OpLabel
+ %7 = OpCopyObject %uint %uint_0
+ OpBranch %7
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(text);
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("'Target Label' operands for OpBranch must "
+ "be the ID of an OpLabel instruction"));
+}
+
/// TODO(umar): Nested CFG constructs
} // namespace