spirv-fuzz: Do not allow creation of constants of block-decorated structs (#3903)

Fixes #3902.
diff --git a/source/fuzz/fuzzer_pass_add_composite_inserts.cpp b/source/fuzz/fuzzer_pass_add_composite_inserts.cpp
index 515407b..e58c754 100644
--- a/source/fuzz/fuzzer_pass_add_composite_inserts.cpp
+++ b/source/fuzz/fuzzer_pass_add_composite_inserts.cpp
@@ -166,9 +166,8 @@
         // this type.
         uint32_t available_object_id;
         if (available_objects.empty()) {
-          auto current_node_type =
-              GetIRContext()->get_type_mgr()->GetType(current_node_type_id);
-          if (!fuzzerutil::CanCreateConstant(*current_node_type)) {
+          if (!fuzzerutil::CanCreateConstant(GetIRContext(),
+                                             current_node_type_id)) {
             return;
           }
           available_object_id =
diff --git a/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp b/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp
index 5e79ec8..1bc0c2b 100644
--- a/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp
+++ b/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp
@@ -100,9 +100,8 @@
           // The id will be a zero constant if the type allows it, and an
           // OpUndef otherwise. We want to avoid using OpUndef, if possible, to
           // avoid undefined behaviour in the module as much as possible.
-          if (fuzzerutil::CanCreateConstant(
-                  *GetIRContext()->get_type_mgr()->GetType(
-                      instruction->type_id()))) {
+          if (fuzzerutil::CanCreateConstant(GetIRContext(),
+                                            instruction->type_id())) {
             wrapper_info.set_value_to_copy_id(
                 FindOrCreateZeroConstant(instruction->type_id(), true));
           } else {
diff --git a/source/fuzz/fuzzer_pass_replace_branches_from_dead_blocks_with_exits.cpp b/source/fuzz/fuzzer_pass_replace_branches_from_dead_blocks_with_exits.cpp
index 200aaad..e6bebea 100644
--- a/source/fuzz/fuzzer_pass_replace_branches_from_dead_blocks_with_exits.cpp
+++ b/source/fuzz/fuzzer_pass_replace_branches_from_dead_blocks_with_exits.cpp
@@ -81,7 +81,8 @@
           GetIRContext()->get_type_mgr()->GetType(function.type_id());
       if (function_return_type->AsVoid()) {
         opcodes.emplace_back(SpvOpReturn);
-      } else if (fuzzerutil::CanCreateConstant(*function_return_type)) {
+      } else if (fuzzerutil::CanCreateConstant(GetIRContext(),
+                                               function.type_id())) {
         // For simplicity we only allow OpReturnValue if the function return
         // type is a type for which we can create a constant.  This allows us a
         // zero of the given type as a default return value.
diff --git a/source/fuzz/fuzzer_pass_replace_copy_objects_with_stores_loads.cpp b/source/fuzz/fuzzer_pass_replace_copy_objects_with_stores_loads.cpp
index 51cb569..e372924 100644
--- a/source/fuzz/fuzzer_pass_replace_copy_objects_with_stores_loads.cpp
+++ b/source/fuzz/fuzzer_pass_replace_copy_objects_with_stores_loads.cpp
@@ -69,9 +69,8 @@
     // Find or create a constant to initialize the variable from. The type of
     // |instruction| must be such that the function FindOrCreateConstant can be
     // called.
-    auto instruction_type =
-        GetIRContext()->get_type_mgr()->GetType(instruction->type_id());
-    if (!fuzzerutil::CanCreateConstant(*instruction_type)) {
+    if (!fuzzerutil::CanCreateConstant(GetIRContext(),
+                                       instruction->type_id())) {
       return;
     }
     auto variable_initializer_id =
diff --git a/source/fuzz/fuzzer_pass_replace_parameter_with_global.cpp b/source/fuzz/fuzzer_pass_replace_parameter_with_global.cpp
index 8672a3b..6b3a63b 100644
--- a/source/fuzz/fuzzer_pass_replace_parameter_with_global.cpp
+++ b/source/fuzz/fuzzer_pass_replace_parameter_with_global.cpp
@@ -53,27 +53,23 @@
     // function has at least one parameter.
     if (std::none_of(params.begin(), params.end(),
                      [this](const opt::Instruction* param) {
-                       const auto* param_type =
-                           GetIRContext()->get_type_mgr()->GetType(
-                               param->type_id());
-                       assert(param_type && "Parameter has invalid type");
                        return TransformationReplaceParameterWithGlobal::
-                           IsParameterTypeSupported(*param_type);
+                           IsParameterTypeSupported(GetIRContext(),
+                                                    param->type_id());
                      })) {
       continue;
     }
 
     // Select id of a parameter to replace.
-    const opt::Instruction* replaced_param = nullptr;
-    const opt::analysis::Type* param_type = nullptr;
+    const opt::Instruction* replaced_param;
+    uint32_t param_type_id;
     do {
       replaced_param = GetFuzzerContext()->RemoveAtRandomIndex(&params);
-      param_type =
-          GetIRContext()->get_type_mgr()->GetType(replaced_param->type_id());
-      assert(param_type && "Parameter has invalid type");
+      param_type_id = replaced_param->type_id();
+      assert(param_type_id && "Parameter has invalid type");
     } while (
         !TransformationReplaceParameterWithGlobal::IsParameterTypeSupported(
-            *param_type));
+            GetIRContext(), param_type_id));
 
     assert(replaced_param && "Unable to find a parameter to replace");
 
diff --git a/source/fuzz/fuzzer_pass_replace_params_with_struct.cpp b/source/fuzz/fuzzer_pass_replace_params_with_struct.cpp
index 86d6d06..0e0610f 100644
--- a/source/fuzz/fuzzer_pass_replace_params_with_struct.cpp
+++ b/source/fuzz/fuzzer_pass_replace_params_with_struct.cpp
@@ -53,15 +53,13 @@
     std::iota(parameter_index.begin(), parameter_index.end(), 0);
 
     // Remove the indices of unsupported parameters.
-    auto new_end = std::remove_if(
-        parameter_index.begin(), parameter_index.end(),
-        [this, &params](uint32_t index) {
-          const auto* type =
-              GetIRContext()->get_type_mgr()->GetType(params[index]->type_id());
-          assert(type && "Parameter has invalid type");
-          return !TransformationReplaceParamsWithStruct::
-              IsParameterTypeSupported(*type);
-        });
+    auto new_end =
+        std::remove_if(parameter_index.begin(), parameter_index.end(),
+                       [this, &params](uint32_t index) {
+                         return !TransformationReplaceParamsWithStruct::
+                             IsParameterTypeSupported(GetIRContext(),
+                                                      params[index]->type_id());
+                       });
 
     // std::remove_if changes the vector so that removed elements are placed at
     // the end (i.e. [new_end, parameter_index.end()) is a range of removed
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 95ead39..122f21e 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -1119,22 +1119,32 @@
   }
 }
 
-bool CanCreateConstant(const opt::analysis::Type& type) {
-  switch (type.kind()) {
-    case opt::analysis::Type::kBool:
-    case opt::analysis::Type::kInteger:
-    case opt::analysis::Type::kFloat:
-    case opt::analysis::Type::kMatrix:
-    case opt::analysis::Type::kVector:
+bool CanCreateConstant(opt::IRContext* ir_context, uint32_t type_id) {
+  opt::Instruction* type_instr = ir_context->get_def_use_mgr()->GetDef(type_id);
+  assert(type_instr != nullptr && "The type must exist.");
+  assert(spvOpcodeGeneratesType(type_instr->opcode()) &&
+         "A type-generating opcode was expected.");
+  switch (type_instr->opcode()) {
+    case SpvOpTypeBool:
+    case SpvOpTypeInt:
+    case SpvOpTypeFloat:
+    case SpvOpTypeMatrix:
+    case SpvOpTypeVector:
       return true;
-    case opt::analysis::Type::kArray:
-      return CanCreateConstant(*type.AsArray()->element_type());
-    case opt::analysis::Type::kStruct:
-      return std::all_of(type.AsStruct()->element_types().begin(),
-                         type.AsStruct()->element_types().end(),
-                         [](const opt::analysis::Type* element_type) {
-                           return CanCreateConstant(*element_type);
-                         });
+    case SpvOpTypeArray:
+      return CanCreateConstant(ir_context,
+                               type_instr->GetSingleWordInOperand(0));
+    case SpvOpTypeStruct:
+      if (HasBlockOrBufferBlockDecoration(ir_context, type_id)) {
+        return false;
+      }
+      for (uint32_t index = 0; index < type_instr->NumInOperands(); index++) {
+        if (!CanCreateConstant(ir_context,
+                               type_instr->GetSingleWordInOperand(index))) {
+          return false;
+        }
+      }
+      return true;
     default:
       return false;
   }
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index 6a6efb8..569f4a6 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -399,9 +399,10 @@
     uint32_t scalar_or_composite_type_id, bool is_irrelevant);
 
 // Returns true if it is possible to create an OpConstant or an
-// OpConstantComposite instruction of |type|. That is, returns true if |type|
-// and all its constituents are either scalar or composite.
-bool CanCreateConstant(const opt::analysis::Type& type);
+// OpConstantComposite instruction of type |type_id|. That is, returns true if
+// the type associated with |type_id| and all its constituents are either scalar
+// or composite.
+bool CanCreateConstant(opt::IRContext* ir_context, uint32_t type_id);
 
 // Returns the result id of an OpConstant instruction. |scalar_type_id| must be
 // a result id of a scalar type (i.e. int, float or bool). Returns 0 if no such
diff --git a/source/fuzz/transformation_mutate_pointer.cpp b/source/fuzz/transformation_mutate_pointer.cpp
index 36c5951..fefedbd 100644
--- a/source/fuzz/transformation_mutate_pointer.cpp
+++ b/source/fuzz/transformation_mutate_pointer.cpp
@@ -139,25 +139,28 @@
     return false;
   }
 
-  const auto* type = ir_context->get_type_mgr()->GetType(inst.type_id());
-  assert(type && "|inst| has invalid type id");
-
-  const auto* pointer_type = type->AsPointer();
+  opt::Instruction* type_inst =
+      ir_context->get_def_use_mgr()->GetDef(inst.type_id());
+  assert(type_inst != nullptr && "|inst| has invalid type id");
 
   // |inst| must be a pointer.
-  if (!pointer_type) {
+  if (type_inst->opcode() != SpvOpTypePointer) {
     return false;
   }
 
   // |inst| must have a supported storage class.
-  if (pointer_type->storage_class() != SpvStorageClassFunction &&
-      pointer_type->storage_class() != SpvStorageClassPrivate &&
-      pointer_type->storage_class() != SpvStorageClassWorkgroup) {
-    return false;
+  switch (static_cast<SpvStorageClass>(type_inst->GetSingleWordInOperand(0))) {
+    case SpvStorageClassFunction:
+    case SpvStorageClassPrivate:
+    case SpvStorageClassWorkgroup:
+      break;
+    default:
+      return false;
   }
 
   // |inst|'s pointee must consist of scalars and/or composites.
-  return fuzzerutil::CanCreateConstant(*pointer_type->pointee_type());
+  return fuzzerutil::CanCreateConstant(ir_context,
+                                       type_inst->GetSingleWordInOperand(1));
 }
 
 std::unordered_set<uint32_t> TransformationMutatePointer::GetFreshIds() const {
diff --git a/source/fuzz/transformation_replace_parameter_with_global.cpp b/source/fuzz/transformation_replace_parameter_with_global.cpp
index e75a9b2..cdf7645 100644
--- a/source/fuzz/transformation_replace_parameter_with_global.cpp
+++ b/source/fuzz/transformation_replace_parameter_with_global.cpp
@@ -57,10 +57,7 @@
   // |parameter_id|.
 
   // Check that replaced parameter has valid type.
-  const auto* param_type =
-      ir_context->get_type_mgr()->GetType(param_inst->type_id());
-  assert(param_type && "Parameter has invalid type");
-  if (!IsParameterTypeSupported(*param_type)) {
+  if (!IsParameterTypeSupported(ir_context, param_inst->type_id())) {
     return false;
   }
 
@@ -198,10 +195,10 @@
 }
 
 bool TransformationReplaceParameterWithGlobal::IsParameterTypeSupported(
-    const opt::analysis::Type& type) {
+    opt::IRContext* ir_context, uint32_t param_type_id) {
   // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3403):
   //  Think about other type instructions we can add here.
-  return fuzzerutil::CanCreateConstant(type);
+  return fuzzerutil::CanCreateConstant(ir_context, param_type_id);
 }
 
 std::unordered_set<uint32_t>
diff --git a/source/fuzz/transformation_replace_parameter_with_global.h b/source/fuzz/transformation_replace_parameter_with_global.h
index a5bdc5b..c2d5f8f 100644
--- a/source/fuzz/transformation_replace_parameter_with_global.h
+++ b/source/fuzz/transformation_replace_parameter_with_global.h
@@ -57,7 +57,8 @@
 
   // Returns true if the type of the parameter is supported by this
   // transformation.
-  static bool IsParameterTypeSupported(const opt::analysis::Type& type);
+  static bool IsParameterTypeSupported(opt::IRContext* ir_context,
+                                       uint32_t param_type_id);
 
  private:
   protobufs::TransformationReplaceParameterWithGlobal message_;
diff --git a/source/fuzz/transformation_replace_params_with_struct.cpp b/source/fuzz/transformation_replace_params_with_struct.cpp
index 3f8b21b..0a135e5 100644
--- a/source/fuzz/transformation_replace_params_with_struct.cpp
+++ b/source/fuzz/transformation_replace_params_with_struct.cpp
@@ -85,10 +85,8 @@
     }
 
     // Check that the parameter with result id |id| has supported type.
-    const auto* type = ir_context->get_type_mgr()->GetType(
-        fuzzerutil::GetTypeId(ir_context, id));
-    assert(type && "Parameter has invalid type");
-    if (!IsParameterTypeSupported(*type)) {
+    if (!IsParameterTypeSupported(ir_context,
+                                  fuzzerutil::GetTypeId(ir_context, id))) {
       return false;
     }
   }
@@ -263,10 +261,10 @@
 }
 
 bool TransformationReplaceParamsWithStruct::IsParameterTypeSupported(
-    const opt::analysis::Type& param_type) {
+    opt::IRContext* ir_context, uint32_t param_type_id) {
   // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3403):
   //  Consider adding support for more types of parameters.
-  return fuzzerutil::CanCreateConstant(param_type);
+  return fuzzerutil::CanCreateConstant(ir_context, param_type_id);
 }
 
 uint32_t TransformationReplaceParamsWithStruct::MaybeGetRequiredStructType(
diff --git a/source/fuzz/transformation_replace_params_with_struct.h b/source/fuzz/transformation_replace_params_with_struct.h
index 4f88f8e..afa6b14 100644
--- a/source/fuzz/transformation_replace_params_with_struct.h
+++ b/source/fuzz/transformation_replace_params_with_struct.h
@@ -68,7 +68,8 @@
   protobufs::Transformation ToMessage() const override;
 
   // Returns true if parameter's type is supported by this transformation.
-  static bool IsParameterTypeSupported(const opt::analysis::Type& param_type);
+  static bool IsParameterTypeSupported(opt::IRContext* ir_context,
+                                       uint32_t param_type_id);
 
  private:
   // Returns a result id of the OpTypeStruct instruction required by this
diff --git a/test/fuzz/transformation_flatten_conditional_branch_test.cpp b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
index d77173d..579b696 100644
--- a/test/fuzz/transformation_flatten_conditional_branch_test.cpp
+++ b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
@@ -25,8 +25,8 @@
 protobufs::SideEffectWrapperInfo MakeSideEffectWrapperInfo(
     const protobufs::InstructionDescriptor& instruction,
     uint32_t merge_block_id, uint32_t execute_block_id,
-    uint32_t actual_result_id = 0, uint32_t alternative_block_id = 0,
-    uint32_t placeholder_result_id = 0, uint32_t value_to_copy_id = 0) {
+    uint32_t actual_result_id, uint32_t alternative_block_id,
+    uint32_t placeholder_result_id, uint32_t value_to_copy_id) {
   protobufs::SideEffectWrapperInfo result;
   *result.mutable_instruction() = instruction;
   result.set_merge_block_id(merge_block_id);
@@ -38,6 +38,13 @@
   return result;
 }
 
+protobufs::SideEffectWrapperInfo MakeSideEffectWrapperInfo(
+    const protobufs::InstructionDescriptor& instruction,
+    uint32_t merge_block_id, uint32_t execute_block_id) {
+  return MakeSideEffectWrapperInfo(instruction, merge_block_id,
+                                   execute_block_id, 0, 0, 0, 0);
+}
+
 TEST(TransformationFlattenConditionalBranchTest, Inapplicable) {
   std::string shader = R"(
                OpCapability Shader
@@ -434,19 +441,19 @@
 #endif
 
   // The map maps from an instruction to a list with not enough fresh ids.
-  ASSERT_FALSE(
-      TransformationFlattenConditionalBranch(
-          31, true,
-          {{MakeSideEffectWrapperInfo(
-              MakeInstructionDescriptor(6, SpvOpLoad, 0), 100, 101, 102, 103)}})
-          .IsApplicable(context.get(), transformation_context));
+  ASSERT_FALSE(TransformationFlattenConditionalBranch(
+                   31, true,
+                   {{MakeSideEffectWrapperInfo(
+                       MakeInstructionDescriptor(6, SpvOpLoad, 0), 100, 101,
+                       102, 103, 0, 0)}})
+                   .IsApplicable(context.get(), transformation_context));
 
   // Not all fresh ids given are distinct.
   ASSERT_FALSE(TransformationFlattenConditionalBranch(
                    31, true,
                    {{MakeSideEffectWrapperInfo(
                        MakeInstructionDescriptor(6, SpvOpLoad, 0), 100, 100,
-                       102, 103, 104)}})
+                       102, 103, 104, 0)}})
                    .IsApplicable(context.get(), transformation_context));
 
   // %48 heads a construct containing an OpSampledImage instruction.
@@ -454,7 +461,7 @@
                    48, true,
                    {{MakeSideEffectWrapperInfo(
                        MakeInstructionDescriptor(53, SpvOpLoad, 0), 100, 101,
-                       102, 103, 104)}})
+                       102, 103, 104, 0)}})
                    .IsApplicable(context.get(), transformation_context));
 
   // %0 is not a valid id.
@@ -1059,6 +1066,59 @@
   ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
 }
 
+TEST(TransformationFlattenConditionalBranchTest,
+     LoadFromBufferBlockDecoratedStruct) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+               OpMemberDecorate %11 0 Offset 0
+               OpDecorate %11 BufferBlock
+               OpDecorate %13 DescriptorSet 0
+               OpDecorate %13 Binding 0
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeBool
+          %7 = OpConstantTrue %6
+         %10 = OpTypeInt 32 1
+         %11 = OpTypeStruct %10
+         %12 = OpTypePointer Uniform %11
+         %13 = OpVariable %12 Uniform
+         %21 = OpUndef %11
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpSelectionMerge %9 None
+               OpBranchConditional %7 %8 %9
+          %8 = OpLabel
+         %20 = OpLoad %11 %13
+               OpBranch %9
+          %9 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_3;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+
+  auto transformation = TransformationFlattenConditionalBranch(
+      5, true,
+      {MakeSideEffectWrapperInfo(MakeInstructionDescriptor(20, SpvOpLoad, 0),
+                                 100, 101, 102, 103, 104, 21)});
+  ASSERT_TRUE(
+      transformation.IsApplicable(context.get(), transformation_context));
+  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+  ASSERT_TRUE(IsValid(env, context.get()));
+}
+
 }  // namespace
 }  // namespace fuzz
 }  // namespace spvtools