Validate OpForwardPointer (#2156) * Validate OpForwardPointer The validator does not have a a check that OpForwardPointer is giving a forward reference to a pointer type. We add that check. https://crbug.com/910852 * Remove more specialized check. There was a check that the forward pointer is actually a poiner type, but it was only done if it was used in a struct. This was too specific. Remove it in favour of the more general check that was added. * Format * Check the storage type in OpTypeForwardPointer * Fix typo is test case epxected results.
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp index 3bbdb87..e0f2786 100644 --- a/source/val/validate_type.cpp +++ b/source/val/validate_type.cpp
@@ -191,11 +191,6 @@ << _.getIdName(member_type_id) << "."; } if (_.IsForwardPointer(member_type_id)) { - if (member_type->opcode() != SpvOpTypePointer) { - return _.diag(SPV_ERROR_INVALID_ID, inst) - << "Found a forward reference to a non-pointer " - "type in OpTypeStruct instruction."; - } // If we're dealing with a forward pointer: // Find out the type that the pointer is pointing to (must be struct) // word 3 is the <id> of the type being pointed to. @@ -296,10 +291,32 @@ return SPV_SUCCESS; } +spv_result_t ValidateTypeForwardPointer(ValidationState_t& _, + const Instruction* inst) { + const auto pointer_type_id = inst->GetOperandAs<uint32_t>(0); + const auto pointer_type_inst = _.FindDef(pointer_type_id); + if (pointer_type_inst->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Pointer type in OpTypeForwardPointer is not a pointer type."; + } + + if (inst->GetOperandAs<uint32_t>(1) != + pointer_type_inst->GetOperandAs<uint32_t>(1)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Storage class in OpTypeForwardPointer does not match the " + "pointer definition."; + } + + return SPV_SUCCESS; +} + } // namespace spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) { - if (!spvOpcodeGeneratesType(inst->opcode())) return SPV_SUCCESS; + if (!spvOpcodeGeneratesType(inst->opcode()) && + inst->opcode() != SpvOpTypeForwardPointer) { + return SPV_SUCCESS; + } if (auto error = ValidateUniqueness(_, inst)) return error; @@ -325,6 +342,9 @@ case SpvOpTypeFunction: if (auto error = ValidateTypeFunction(_, inst)) return error; break; + case SpvOpTypeForwardPointer: + if (auto error = ValidateTypeForwardPointer(_, inst)) return error; + break; default: break; }
diff --git a/test/val/val_data_test.cpp b/test/val/val_data_test.cpp index b414aaa..fcf447a 100644 --- a/test/val/val_data_test.cpp +++ b/test/val/val_data_test.cpp
@@ -574,8 +574,8 @@ CompileSuccessfully(str.c_str()); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_THAT(getDiagnosticString(), - HasSubstr("Found a forward reference to a non-pointer type in " - "OpTypeStruct instruction.")); + HasSubstr("Pointer type in OpTypeForwardPointer is not a pointer " + "type.\n OpTypeForwardPointer %float Generic\n")); } TEST_F(ValidateData, forward_ref_points_to_non_struct) {
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp index 7aec5a4..1c4fdd3 100644 --- a/test/val/val_id_test.cpp +++ b/test/val/val_id_test.cpp
@@ -6324,6 +6324,56 @@ "dominate its parent 7[%7]\n %14 = OpPhi %float %11 %10 %13 " "%7")); } + +TEST_F(ValidateIdWithMessage, OpTypeForwardPointerNotAPointerType) { + std::string spirv = R"( + OpCapability GenericPointer + OpCapability VariablePointersStorageBuffer + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginLowerLeft + OpTypeForwardPointer %2 CrossWorkgroup +%2 = OpTypeVoid +%3 = OpTypeFunction %2 +%1 = OpFunction %2 DontInline %3 +%4 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Pointer type in OpTypeForwardPointer is not a pointer " + "type.\n OpTypeForwardPointer %void CrossWorkgroup")); +} + +TEST_F(ValidateIdWithMessage, OpTypeForwardPointerWrongStorageClass) { + std::string spirv = R"( + OpCapability GenericPointer + OpCapability VariablePointersStorageBuffer + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %1 "main" + OpExecutionMode %1 OriginLowerLeft + OpTypeForwardPointer %2 CrossWorkgroup +%int = OpTypeInt 32 1 +%2 = OpTypePointer Function %int +%void = OpTypeVoid +%3 = OpTypeFunction %void +%1 = OpFunction %void None %3 +%4 = OpLabel + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Storage class in OpTypeForwardPointer does not match the " + "pointer definition.\n OpTypeForwardPointer " + "%_ptr_Function_int CrossWorkgroup")); +} } // namespace } // namespace val } // namespace spvtools