For Vulkan, disallow structures containing opaque types (#2546)
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index e3c7662..afc0656 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -190,6 +190,35 @@
return SPV_SUCCESS;
}
+bool ContainsOpaqueType(ValidationState_t& _, const Instruction* str) {
+ const size_t elem_type_index = 1;
+ uint32_t elem_type_id;
+ Instruction* elem_type;
+
+ if (spvOpcodeIsBaseOpaqueType(str->opcode())) {
+ return true;
+ }
+
+ switch (str->opcode()) {
+ case SpvOpTypeArray:
+ case SpvOpTypeRuntimeArray:
+ elem_type_id = str->GetOperandAs<uint32_t>(elem_type_index);
+ elem_type = _.FindDef(elem_type_id);
+ return ContainsOpaqueType(_, elem_type);
+ case SpvOpTypeStruct:
+ for (size_t member_type_index = 1;
+ member_type_index < str->operands().size(); ++member_type_index) {
+ auto member_type_id = str->GetOperandAs<uint32_t>(member_type_index);
+ auto member_type = _.FindDef(member_type_id);
+ if (ContainsOpaqueType(_, member_type)) return true;
+ }
+ break;
+ default:
+ break;
+ }
+ return false;
+}
+
spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) {
const uint32_t struct_id = inst->GetOperandAs<uint32_t>(0);
for (size_t member_type_index = 1;
@@ -289,6 +318,14 @@
if (num_builtin_members > 0) {
_.RegisterStructTypeWithBuiltInMember(struct_id);
}
+
+ if (spvIsVulkanEnv(_.context()->target_env) &&
+ !_.options()->before_hlsl_legalization && ContainsOpaqueType(_, inst)) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "In " << spvLogStringForEnv(_.context()->target_env)
+ << ", OpTypeStruct must not contain an opaque type.";
+ }
+
return SPV_SUCCESS;
}
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index 73e2ce1..8839490 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -937,6 +937,26 @@
"a type."));
}
+TEST_F(ValidateIdWithMessage, OpTypeStructOpaqueTypeBad) {
+ std::string spirv = R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Vertex %main "main"
+ %1 = OpTypeSampler
+ %2 = OpTypeStruct %1
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %main = OpFunction %void None %3
+ %5 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_0);
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("OpTypeStruct must not contain an opaque type"));
+}
+
TEST_F(ValidateIdWithMessage, OpTypePointerGood) {
std::string spirv = kGLSL450MemoryModel + R"(
%1 = OpTypeInt 32 0