Allow forward pointer to be used in types generally (#4044)
Fixes #4042
* Allow types to have forward declarations as long as that declaration
is an OpTypeForwardPointer
diff --git a/source/operand.cpp b/source/operand.cpp
index d4b64a8..6755eab 100644
--- a/source/operand.cpp
+++ b/source/operand.cpp
@@ -24,6 +24,7 @@
#include "DebugInfo.h"
#include "OpenCLDebugInfo100.h"
#include "source/macro.h"
+#include "source/opcode.h"
#include "source/spirv_constant.h"
#include "source/spirv_target_env.h"
@@ -491,6 +492,11 @@
std::function<bool(unsigned)> spvOperandCanBeForwardDeclaredFunction(
SpvOp opcode) {
std::function<bool(unsigned index)> out;
+ if (spvOpcodeGeneratesType(opcode)) {
+ // All types can use forward pointers.
+ out = [](unsigned) { return true; };
+ return out;
+ }
switch (opcode) {
case SpvOpExecutionMode:
case SpvOpExecutionModeId:
@@ -503,7 +509,6 @@
case SpvOpDecorateId:
case SpvOpDecorateStringGOOGLE:
case SpvOpMemberDecorateStringGOOGLE:
- case SpvOpTypeStruct:
case SpvOpBranch:
case SpvOpLoopMerge:
out = [](unsigned) { return true; };
diff --git a/source/val/validate_id.cpp b/source/val/validate_id.cpp
index e1a775a..2bab203 100644
--- a/source/val/validate_id.cpp
+++ b/source/val/validate_id.cpp
@@ -201,7 +201,7 @@
ret = SPV_SUCCESS;
}
} else if (can_have_forward_declared_ids(i)) {
- if (inst->opcode() == SpvOpTypeStruct &&
+ if (spvOpcodeGeneratesType(inst->opcode()) &&
!_.IsForwardPointer(operand_word)) {
ret = _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Operand " << _.getIdName(operand_word)
diff --git a/test/val/val_data_test.cpp b/test/val/val_data_test.cpp
index 30afd03..1080a59 100644
--- a/test/val/val_data_test.cpp
+++ b/test/val/val_data_test.cpp
@@ -497,7 +497,7 @@
CompileSuccessfully(str.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
- HasSubstr("ID 3[%3] has not been defined"));
+ HasSubstr("Operand 3[%3] requires a previous definition"));
}
TEST_F(ValidateData, matrix_bad_column_type) {
@@ -944,6 +944,40 @@
"OpTypeStruct %_runtimearr_uint %uint\n"));
}
+TEST_F(ValidateData, TypeForwardReference) {
+ std::string test = R"(
+OpCapability Shader
+OpCapability PhysicalStorageBufferAddresses
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpTypeForwardPointer %1 PhysicalStorageBuffer
+%2 = OpTypeStruct
+%3 = OpTypeRuntimeArray %1
+%1 = OpTypePointer PhysicalStorageBuffer %2
+)";
+
+ CompileSuccessfully(test, SPV_ENV_UNIVERSAL_1_5);
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_5));
+}
+
+TEST_F(ValidateData, TypeForwardReferenceMustBeForwardPointer) {
+ std::string test = R"(
+OpCapability Shader
+OpCapability PhysicalStorageBufferAddresses
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeStruct
+%2 = OpTypeRuntimeArray %3
+%3 = OpTypePointer PhysicalStorageBuffer %1
+)";
+
+ CompileSuccessfully(test, SPV_ENV_UNIVERSAL_1_5);
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_5));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Operand 3[%_ptr_PhysicalStorageBuffer__struct_1] "
+ "requires a previous definition"));
+}
+
} // namespace
} // namespace val
} // namespace spvtools