Validate OpForwardPointer (#2156)

* Validate OpForwardPointer

The validator does not have a a check that OpForwardPointer is giving
a forward reference to a pointer type.  We add that check.

https://crbug.com/910852

* Remove more specialized check.

There was a check that the forward pointer is actually a poiner type,
but it was only done if it was used in a struct.  This was too specific.
Remove it in favour of the more general check that was added.

* Format

* Check the storage type in OpTypeForwardPointer

* Fix typo is test case epxected results.
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index 3bbdb87..e0f2786 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -191,11 +191,6 @@
              << _.getIdName(member_type_id) << ".";
     }
     if (_.IsForwardPointer(member_type_id)) {
-      if (member_type->opcode() != SpvOpTypePointer) {
-        return _.diag(SPV_ERROR_INVALID_ID, inst)
-               << "Found a forward reference to a non-pointer "
-                  "type in OpTypeStruct instruction.";
-      }
       // If we're dealing with a forward pointer:
       // Find out the type that the pointer is pointing to (must be struct)
       // word 3 is the <id> of the type being pointed to.
@@ -296,10 +291,32 @@
   return SPV_SUCCESS;
 }
 
+spv_result_t ValidateTypeForwardPointer(ValidationState_t& _,
+                                        const Instruction* inst) {
+  const auto pointer_type_id = inst->GetOperandAs<uint32_t>(0);
+  const auto pointer_type_inst = _.FindDef(pointer_type_id);
+  if (pointer_type_inst->opcode() != SpvOpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "Pointer type in OpTypeForwardPointer is not a pointer type.";
+  }
+
+  if (inst->GetOperandAs<uint32_t>(1) !=
+      pointer_type_inst->GetOperandAs<uint32_t>(1)) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "Storage class in OpTypeForwardPointer does not match the "
+              "pointer definition.";
+  }
+
+  return SPV_SUCCESS;
+}
+
 }  // namespace
 
 spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
-  if (!spvOpcodeGeneratesType(inst->opcode())) return SPV_SUCCESS;
+  if (!spvOpcodeGeneratesType(inst->opcode()) &&
+      inst->opcode() != SpvOpTypeForwardPointer) {
+    return SPV_SUCCESS;
+  }
 
   if (auto error = ValidateUniqueness(_, inst)) return error;
 
@@ -325,6 +342,9 @@
     case SpvOpTypeFunction:
       if (auto error = ValidateTypeFunction(_, inst)) return error;
       break;
+    case SpvOpTypeForwardPointer:
+      if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
+      break;
     default:
       break;
   }
diff --git a/test/val/val_data_test.cpp b/test/val/val_data_test.cpp
index b414aaa..fcf447a 100644
--- a/test/val/val_data_test.cpp
+++ b/test/val/val_data_test.cpp
@@ -574,8 +574,8 @@
   CompileSuccessfully(str.c_str());
   ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("Found a forward reference to a non-pointer type in "
-                        "OpTypeStruct instruction."));
+              HasSubstr("Pointer type in OpTypeForwardPointer is not a pointer "
+                        "type.\n  OpTypeForwardPointer %float Generic\n"));
 }
 
 TEST_F(ValidateData, forward_ref_points_to_non_struct) {
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index 7aec5a4..1c4fdd3 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -6324,6 +6324,56 @@
                 "dominate its parent 7[%7]\n  %14 = OpPhi %float %11 %10 %13 "
                 "%7"));
 }
+
+TEST_F(ValidateIdWithMessage, OpTypeForwardPointerNotAPointerType) {
+  std::string spirv = R"(
+     OpCapability GenericPointer
+     OpCapability VariablePointersStorageBuffer
+     OpMemoryModel Logical GLSL450
+     OpEntryPoint Fragment %1 "main"
+     OpExecutionMode %1 OriginLowerLeft
+     OpTypeForwardPointer %2 CrossWorkgroup
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%1 = OpFunction %2 DontInline %3
+%4 = OpLabel
+     OpReturn
+     OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Pointer type in OpTypeForwardPointer is not a pointer "
+                        "type.\n  OpTypeForwardPointer %void CrossWorkgroup"));
+}
+
+TEST_F(ValidateIdWithMessage, OpTypeForwardPointerWrongStorageClass) {
+  std::string spirv = R"(
+     OpCapability GenericPointer
+     OpCapability VariablePointersStorageBuffer
+     OpMemoryModel Logical GLSL450
+     OpEntryPoint Fragment %1 "main"
+     OpExecutionMode %1 OriginLowerLeft
+     OpTypeForwardPointer %2 CrossWorkgroup
+%int = OpTypeInt 32 1
+%2 = OpTypePointer Function %int
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%1 = OpFunction %void None %3
+%4 = OpLabel
+     OpReturn
+     OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_3);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Storage class in OpTypeForwardPointer does not match the "
+                "pointer definition.\n  OpTypeForwardPointer "
+                "%_ptr_Function_int CrossWorkgroup"));
+}
 }  // namespace
 }  // namespace val
 }  // namespace spvtools