spirv-fuzz: fuzzerutil::MaybeGetConstant* #3487

Part of #3428.
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 80dff2d..ef6e41c 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -112,19 +112,6 @@
   return true;
 }
 
-uint32_t MaybeGetBoolConstantId(opt::IRContext* context, bool value) {
-  opt::analysis::Bool bool_type;
-  auto registered_bool_type =
-      context->get_type_mgr()->GetRegisteredType(&bool_type);
-  if (!registered_bool_type) {
-    return 0;
-  }
-  opt::analysis::BoolConstant bool_constant(registered_bool_type->AsBool(),
-                                            value);
-  return context->get_constant_mgr()->FindDeclaredConstant(
-      &bool_constant, context->get_type_mgr()->GetId(&bool_type));
-}
-
 void AddUnreachableEdgeAndUpdateOpPhis(
     opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
     bool condition_value,
@@ -135,7 +122,7 @@
          "Precondition on terminator of bb_from is not satisfied");
 
   // Get the id of the boolean constant to be used as the condition.
-  uint32_t bool_id = MaybeGetBoolConstantId(context, condition_value);
+  uint32_t bool_id = MaybeGetBoolConstant(context, condition_value);
   assert(
       bool_id &&
       "Precondition that condition value must be available is not satisfied");
@@ -771,6 +758,199 @@
   return ir_context->get_type_mgr()->GetId(&type);
 }
 
+uint32_t MaybeGetZeroConstant(opt::IRContext* ir_context,
+                              uint32_t scalar_or_composite_type_id) {
+  const auto* type =
+      ir_context->get_type_mgr()->GetType(scalar_or_composite_type_id);
+  assert(type && "|scalar_or_composite_type_id| is invalid");
+
+  switch (type->kind()) {
+    case opt::analysis::Type::kBool:
+      return MaybeGetBoolConstant(ir_context, false);
+    case opt::analysis::Type::kFloat:
+    case opt::analysis::Type::kInteger: {
+      std::vector<uint32_t> words = {0};
+      if ((type->AsInteger() && type->AsInteger()->width() > 32) ||
+          (type->AsFloat() && type->AsFloat()->width() > 32)) {
+        words.push_back(0);
+      }
+
+      return MaybeGetScalarConstant(ir_context, words,
+                                    scalar_or_composite_type_id);
+    }
+    case opt::analysis::Type::kStruct: {
+      std::vector<uint32_t> component_ids;
+      for (const auto* component_type : type->AsStruct()->element_types()) {
+        auto component_type_id =
+            ir_context->get_type_mgr()->GetId(component_type);
+        assert(component_type_id && "Component type is invalid");
+
+        auto component_id = MaybeGetZeroConstant(ir_context, component_type_id);
+        if (component_id == 0) {
+          return 0;
+        }
+
+        component_ids.push_back(component_id);
+      }
+
+      return MaybeGetCompositeConstant(ir_context, component_ids,
+                                       scalar_or_composite_type_id);
+    }
+    case opt::analysis::Type::kMatrix:
+    case opt::analysis::Type::kVector: {
+      const auto* component_type = type->AsVector()
+                                       ? type->AsVector()->element_type()
+                                       : type->AsMatrix()->element_type();
+      auto component_type_id =
+          ir_context->get_type_mgr()->GetId(component_type);
+      assert(component_type_id && "Component type is invalid");
+
+      if (auto component_id =
+              MaybeGetZeroConstant(ir_context, component_type_id)) {
+        auto component_count = type->AsVector()
+                                   ? type->AsVector()->element_count()
+                                   : type->AsMatrix()->element_count();
+        return MaybeGetCompositeConstant(
+            ir_context, std::vector<uint32_t>(component_count, component_id),
+            scalar_or_composite_type_id);
+      }
+
+      return 0;
+    }
+    case opt::analysis::Type::kArray: {
+      auto component_type_id =
+          ir_context->get_type_mgr()->GetId(type->AsArray()->element_type());
+      assert(component_type_id && "Component type is invalid");
+
+      if (auto component_id =
+              MaybeGetZeroConstant(ir_context, component_type_id)) {
+        auto type_id = ir_context->get_type_mgr()->GetId(type);
+        assert(type_id && "|type| is invalid");
+
+        const auto* type_inst = ir_context->get_def_use_mgr()->GetDef(type_id);
+        assert(type_inst && "Array's type id is invalid");
+
+        return MaybeGetCompositeConstant(
+            ir_context,
+            std::vector<uint32_t>(GetArraySize(*type_inst, ir_context),
+                                  component_id),
+            scalar_or_composite_type_id);
+      }
+
+      return 0;
+    }
+    default:
+      assert(false && "Type is not supported");
+      return 0;
+  }
+}
+
+uint32_t MaybeGetScalarConstant(opt::IRContext* ir_context,
+                                const std::vector<uint32_t>& words,
+                                uint32_t scalar_type_id) {
+  const auto* type = ir_context->get_type_mgr()->GetType(scalar_type_id);
+  assert(type && "|scalar_type_id| is invalid");
+
+  if (const auto* int_type = type->AsInteger()) {
+    return MaybeGetIntegerConstant(ir_context, words, int_type->width(),
+                                   int_type->IsSigned());
+  } else if (const auto* float_type = type->AsFloat()) {
+    return MaybeGetFloatConstant(ir_context, words, float_type->width());
+  } else {
+    assert(type->AsBool() && words.size() == 1 &&
+           "|scalar_type_id| doesn't represent a scalar type");
+    return MaybeGetBoolConstant(ir_context, words[0]);
+  }
+}
+
+uint32_t MaybeGetCompositeConstant(opt::IRContext* ir_context,
+                                   const std::vector<uint32_t>& component_ids,
+                                   uint32_t composite_type_id) {
+  std::vector<const opt::analysis::Constant*> constants;
+  for (auto id : component_ids) {
+    const auto* component_constant =
+        ir_context->get_constant_mgr()->FindDeclaredConstant(id);
+    assert(component_constant && "|id| is invalid");
+
+    constants.push_back(component_constant);
+  }
+
+  const auto* type = ir_context->get_type_mgr()->GetType(composite_type_id);
+  assert(type && "|composite_type_id| is invalid");
+
+  std::unique_ptr<opt::analysis::Constant> composite_constant;
+  switch (type->kind()) {
+    case opt::analysis::Type::kStruct:
+      composite_constant = MakeUnique<opt::analysis::StructConstant>(
+          type->AsStruct(), std::move(constants));
+      break;
+    case opt::analysis::Type::kVector:
+      composite_constant = MakeUnique<opt::analysis::VectorConstant>(
+          type->AsVector(), std::move(constants));
+      break;
+    case opt::analysis::Type::kMatrix:
+      composite_constant = MakeUnique<opt::analysis::MatrixConstant>(
+          type->AsMatrix(), std::move(constants));
+      break;
+    case opt::analysis::Type::kArray:
+      composite_constant = MakeUnique<opt::analysis::ArrayConstant>(
+          type->AsArray(), std::move(constants));
+      break;
+    default:
+      assert(false &&
+             "|composite_type_id| is not a result id of a composite type");
+      return 0;
+  }
+
+  return ir_context->get_constant_mgr()->FindDeclaredConstant(
+      composite_constant.get(), composite_type_id);
+}
+
+uint32_t MaybeGetIntegerConstant(opt::IRContext* ir_context,
+                                 const std::vector<uint32_t>& words,
+                                 uint32_t width, bool is_signed) {
+  auto type_id = MaybeGetIntegerType(ir_context, width, is_signed);
+  if (!type_id) {
+    return 0;
+  }
+
+  const auto* type = ir_context->get_type_mgr()->GetType(type_id);
+  assert(type && "|type_id| is invalid");
+
+  opt::analysis::IntConstant constant(type->AsInteger(), words);
+  return ir_context->get_constant_mgr()->FindDeclaredConstant(&constant,
+                                                              type_id);
+}
+
+uint32_t MaybeGetFloatConstant(opt::IRContext* ir_context,
+                               const std::vector<uint32_t>& words,
+                               uint32_t width) {
+  auto type_id = MaybeGetFloatType(ir_context, width);
+  if (!type_id) {
+    return 0;
+  }
+
+  const auto* type = ir_context->get_type_mgr()->GetType(type_id);
+  assert(type && "|type_id| is invalid");
+
+  opt::analysis::FloatConstant constant(type->AsFloat(), words);
+  return ir_context->get_constant_mgr()->FindDeclaredConstant(&constant,
+                                                              type_id);
+}
+
+uint32_t MaybeGetBoolConstant(opt::IRContext* context, bool value) {
+  opt::analysis::Bool bool_type;
+  auto registered_bool_type =
+      context->get_type_mgr()->GetRegisteredType(&bool_type);
+  if (!registered_bool_type) {
+    return 0;
+  }
+  opt::analysis::BoolConstant bool_constant(registered_bool_type->AsBool(),
+                                            value);
+  return context->get_constant_mgr()->FindDeclaredConstant(
+      &bool_constant, context->get_type_mgr()->GetId(&bool_type));
+}
+
 void AddIntegerType(opt::IRContext* ir_context, uint32_t result_id,
                     uint32_t width, bool is_signed) {
   ir_context->module()->AddType(MakeUnique<opt::Instruction>(
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index 5b471bd..efe0d0c 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -53,10 +53,6 @@
     opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
     const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids);
 
-// Returns the id of a boolean constant with value |value| if it exists in the
-// module, or 0 otherwise.
-uint32_t MaybeGetBoolConstantId(opt::IRContext* context, bool value);
-
 // Requires that a boolean constant with value |condition_value| is available,
 // that PhiIdsOkForNewEdge(context, bb_from, bb_to, phi_ids) holds, and that
 // bb_from ends with "OpBranch %some_block".  Turns OpBranch into
@@ -306,6 +302,48 @@
 uint32_t MaybeGetStructType(opt::IRContext* ir_context,
                             const std::vector<uint32_t>& component_type_ids);
 
+// Recursive definition is the following:
+// - if |scalar_or_composite_type_id| is a result id of a scalar type - returns
+//   a result id of the following constants (depending on the type): int -> 0,
+//   float -> 0.0, bool -> false.
+// - otherwise, returns a result id of an OpConstantComposite instruction.
+//   Every component of the composite constant is looked up by calling this
+//   function with the type id of that component.
+// Returns 0 if no such instruction is present in the module.
+uint32_t MaybeGetZeroConstant(opt::IRContext* ir_context,
+                              uint32_t scalar_or_composite_type_id);
+
+// Returns the result id of an OpConstant instruction. |scalar_type_id| must be
+// a result id of a scalar type (i.e. int, float or bool). Returns 0 if no such
+// instruction is present in the module.
+uint32_t MaybeGetScalarConstant(opt::IRContext* ir_context,
+                                const std::vector<uint32_t>& words,
+                                uint32_t scalar_type_id);
+
+// Returns the result id of an OpConstantComposite instruction.
+// |composite_type_id| must be a result id of a composite type (i.e. vector,
+// matrix, struct or array). Returns 0 if no such instruction is present in the
+// module.
+uint32_t MaybeGetCompositeConstant(opt::IRContext* ir_context,
+                                   const std::vector<uint32_t>& component_ids,
+                                   uint32_t composite_type_id);
+
+// Returns the result id of an OpConstant instruction of integral type.
+// Returns 0 if no such instruction or type is present in the module.
+uint32_t MaybeGetIntegerConstant(opt::IRContext* ir_context,
+                                 const std::vector<uint32_t>& words,
+                                 uint32_t width, bool is_signed);
+
+// Returns the result id of an OpConstant instruction of floating-point type.
+// Returns 0 if no such instruction or type is present in the module.
+uint32_t MaybeGetFloatConstant(opt::IRContext* ir_context,
+                               const std::vector<uint32_t>& words,
+                               uint32_t width);
+
+// Returns the id of a boolean constant with value |value| if it exists in the
+// module, or 0 otherwise.
+uint32_t MaybeGetBoolConstant(opt::IRContext* context, bool value);
+
 // Creates a new OpTypeInt instruction in the module. Updates module's id bound
 // to accommodate for |result_id|.
 void AddIntegerType(opt::IRContext* ir_context, uint32_t result_id,
diff --git a/source/fuzz/transformation_add_dead_block.cpp b/source/fuzz/transformation_add_dead_block.cpp
index b246c3f..0bbda5a 100644
--- a/source/fuzz/transformation_add_dead_block.cpp
+++ b/source/fuzz/transformation_add_dead_block.cpp
@@ -40,8 +40,8 @@
 
   // First, we check that a constant with the same value as
   // |message_.condition_value| is present.
-  if (!fuzzerutil::MaybeGetBoolConstantId(ir_context,
-                                          message_.condition_value())) {
+  if (!fuzzerutil::MaybeGetBoolConstant(ir_context,
+                                        message_.condition_value())) {
     // The required constant is not present, so the transformation cannot be
     // applied.
     return false;
@@ -92,8 +92,8 @@
       existing_block->terminator()->GetSingleWordInOperand(0);
 
   // Get the id of the boolean value that will be used as the branch condition.
-  auto bool_id = fuzzerutil::MaybeGetBoolConstantId(ir_context,
-                                                    message_.condition_value());
+  auto bool_id =
+      fuzzerutil::MaybeGetBoolConstant(ir_context, message_.condition_value());
 
   // Make a new block that unconditionally branches to the original successor
   // block.
diff --git a/source/fuzz/transformation_add_dead_break.cpp b/source/fuzz/transformation_add_dead_break.cpp
index db9de7d..44c9aba 100644
--- a/source/fuzz/transformation_add_dead_break.cpp
+++ b/source/fuzz/transformation_add_dead_break.cpp
@@ -112,8 +112,8 @@
     const TransformationContext& transformation_context) const {
   // First, we check that a constant with the same value as
   // |message_.break_condition_value| is present.
-  if (!fuzzerutil::MaybeGetBoolConstantId(ir_context,
-                                          message_.break_condition_value())) {
+  if (!fuzzerutil::MaybeGetBoolConstant(ir_context,
+                                        message_.break_condition_value())) {
     // The required constant is not present, so the transformation cannot be
     // applied.
     return false;
diff --git a/source/fuzz/transformation_add_dead_continue.cpp b/source/fuzz/transformation_add_dead_continue.cpp
index 1fc6d67..1328b1e 100644
--- a/source/fuzz/transformation_add_dead_continue.cpp
+++ b/source/fuzz/transformation_add_dead_continue.cpp
@@ -38,8 +38,8 @@
     const TransformationContext& transformation_context) const {
   // First, we check that a constant with the same value as
   // |message_.continue_condition_value| is present.
-  if (!fuzzerutil::MaybeGetBoolConstantId(
-          ir_context, message_.continue_condition_value())) {
+  if (!fuzzerutil::MaybeGetBoolConstant(ir_context,
+                                        message_.continue_condition_value())) {
     // The required constant is not present, so the transformation cannot be
     // applied.
     return false;