spirv-fuzz: Improve support for compute shaders in donation (#3277)

(1) Runtime arrays are turned into fixed-size arrays, by turning
    OpTypeRuntimeArray into OpTypeArray and uses of OpArrayLength into
    uses of the constant used for the length of the fixed-size array.

(2) Atomic instructions are not donated, and uses of their results are
    replaced with uses of constants of the result type.
diff --git a/source/fuzz/fuzzer_pass_add_function_calls.cpp b/source/fuzz/fuzzer_pass_add_function_calls.cpp
index f666eb2..569df10 100644
--- a/source/fuzz/fuzzer_pass_add_function_calls.cpp
+++ b/source/fuzz/fuzzer_pass_add_function_calls.cpp
@@ -214,8 +214,9 @@
         result.push_back(fresh_variable_id);
 
         // Now bring the variable into existence.
-        if (type_instruction->GetSingleWordInOperand(0) ==
-            SpvStorageClassFunction) {
+        auto storage_class = static_cast<SpvStorageClass>(
+            type_instruction->GetSingleWordInOperand(0));
+        if (storage_class == SpvStorageClassFunction) {
           // Add a new zero-initialized local variable to the current
           // function, noting that its pointee value is irrelevant.
           ApplyTransformation(TransformationAddLocalVariable(
@@ -224,16 +225,19 @@
                   type_instruction->GetSingleWordInOperand(1)),
               true));
         } else {
-          assert(type_instruction->GetSingleWordInOperand(0) ==
-                     SpvStorageClassPrivate &&
-                 "Only Function and Private storage classes are "
+          assert((storage_class == SpvStorageClassPrivate ||
+                  storage_class == SpvStorageClassWorkgroup) &&
+                 "Only Function, Private and Workgroup storage classes are "
                  "supported at present.");
-          // Add a new zero-initialized global variable to the module,
-          // noting that its pointee value is irrelevant.
+          // Add a new global variable to the module, zero-initializing it if
+          // it has Private storage class, and noting that its pointee value is
+          // irrelevant.
           ApplyTransformation(TransformationAddGlobalVariable(
-              fresh_variable_id, arg_type_id,
-              FindOrCreateZeroConstant(
-                  type_instruction->GetSingleWordInOperand(1)),
+              fresh_variable_id, arg_type_id, storage_class,
+              storage_class == SpvStorageClassPrivate
+                  ? FindOrCreateZeroConstant(
+                        type_instruction->GetSingleWordInOperand(1))
+                  : 0,
               true));
         }
       } else {
diff --git a/source/fuzz/fuzzer_pass_add_global_variables.cpp b/source/fuzz/fuzzer_pass_add_global_variables.cpp
index 80708ed..4023b22 100644
--- a/source/fuzz/fuzzer_pass_add_global_variables.cpp
+++ b/source/fuzz/fuzzer_pass_add_global_variables.cpp
@@ -66,9 +66,11 @@
           available_pointers_to_basic_type[GetFuzzerContext()->RandomIndex(
               available_pointers_to_basic_type)];
     }
+    // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3274):  We could
+    //  add new variables with Workgroup storage class in compute shaders.
     ApplyTransformation(TransformationAddGlobalVariable(
         GetFuzzerContext()->GetFreshId(), pointer_type_id,
-        FindOrCreateZeroConstant(basic_type), true));
+        SpvStorageClassPrivate, FindOrCreateZeroConstant(basic_type), true));
   }
 }
 
diff --git a/source/fuzz/fuzzer_pass_donate_modules.cpp b/source/fuzz/fuzzer_pass_donate_modules.cpp
index 4ba5305..b043b7f 100644
--- a/source/fuzz/fuzzer_pass_donate_modules.cpp
+++ b/source/fuzz/fuzzer_pass_donate_modules.cpp
@@ -116,6 +116,7 @@
   switch (donor_storage_class) {
     case SpvStorageClassFunction:
     case SpvStorageClassPrivate:
+    case SpvStorageClassWorkgroup:
       // We leave these alone
       return donor_storage_class;
     case SpvStorageClassInput:
@@ -280,36 +281,51 @@
         // It is OK to have multiple structurally identical array types, so
         // we go ahead and add a remapped version of the type declared by the
         // donor.
+        uint32_t component_type_id = type_or_value.GetSingleWordInOperand(0);
         new_result_id = GetFuzzerContext()->GetFreshId();
         ApplyTransformation(TransformationAddTypeArray(
-            new_result_id,
-            original_id_to_donated_id->at(
-                type_or_value.GetSingleWordInOperand(0)),
+            new_result_id, original_id_to_donated_id->at(component_type_id),
             original_id_to_donated_id->at(
                 type_or_value.GetSingleWordInOperand(1))));
       } break;
+      case SpvOpTypeRuntimeArray: {
+        // A runtime array is allowed as the final member of an SSBO.  During
+        // donation we turn runtime arrays into fixed-size arrays.  For dead
+        // code donations this is OK because the array is never indexed into at
+        // runtime, so it does not matter what its size is.  For live-safe code,
+        // all accesses are made in-bounds, so this is also OK.
+        //
+        // The special OpArrayLength instruction, which works on runtime arrays,
+        // is rewritten to yield the fixed length that is used for the array.
+
+        uint32_t component_type_id = type_or_value.GetSingleWordInOperand(0);
+        new_result_id = GetFuzzerContext()->GetFreshId();
+        ApplyTransformation(TransformationAddTypeArray(
+            new_result_id, original_id_to_donated_id->at(component_type_id),
+            FindOrCreate32BitIntegerConstant(
+                GetFuzzerContext()->GetRandomSizeForNewArray(), false)));
+      } break;
       case SpvOpTypeStruct: {
         // Similar to SpvOpTypeArray.
-        new_result_id = GetFuzzerContext()->GetFreshId();
         std::vector<uint32_t> member_type_ids;
-        type_or_value.ForEachInId(
-            [&member_type_ids,
-             &original_id_to_donated_id](const uint32_t* component_type_id) {
-              member_type_ids.push_back(
-                  original_id_to_donated_id->at(*component_type_id));
-            });
+        for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
+          auto component_type_id = type_or_value.GetSingleWordInOperand(i);
+          member_type_ids.push_back(
+              original_id_to_donated_id->at(component_type_id));
+        }
+        new_result_id = GetFuzzerContext()->GetFreshId();
         ApplyTransformation(
             TransformationAddTypeStruct(new_result_id, member_type_ids));
       } break;
       case SpvOpTypePointer: {
         // Similar to SpvOpTypeArray.
+        uint32_t pointee_type_id = type_or_value.GetSingleWordInOperand(1);
         new_result_id = GetFuzzerContext()->GetFreshId();
         ApplyTransformation(TransformationAddTypePointer(
             new_result_id,
             AdaptStorageClass(static_cast<SpvStorageClass>(
                 type_or_value.GetSingleWordInOperand(0))),
-            original_id_to_donated_id->at(
-                type_or_value.GetSingleWordInOperand(1))));
+            original_id_to_donated_id->at(pointee_type_id)));
       } break;
       case SpvOpTypeFunction: {
         // It is not OK to have multiple function types that use identical ids
@@ -333,8 +349,10 @@
 
         std::vector<uint32_t> return_and_parameter_types;
         for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
-          return_and_parameter_types.push_back(original_id_to_donated_id->at(
-              type_or_value.GetSingleWordInOperand(i)));
+          uint32_t return_or_parameter_type =
+              type_or_value.GetSingleWordInOperand(i);
+          return_and_parameter_types.push_back(
+              original_id_to_donated_id->at(return_or_parameter_type));
         }
         uint32_t existing_function_id = fuzzerutil::FindFunctionType(
             GetIRContext(), return_and_parameter_types);
@@ -379,6 +397,10 @@
             data_words));
       } break;
       case SpvOpConstantComposite: {
+        assert(original_id_to_donated_id->count(type_or_value.type_id()) &&
+               "Composite types for which it is possible to create a constant "
+               "should have been donated.");
+
         // It is OK to have duplicate constant composite definitions, so add
         // this to the module using remapped versions of all consituent ids and
         // the result type.
@@ -387,6 +409,9 @@
         type_or_value.ForEachInId(
             [&constituent_ids,
              &original_id_to_donated_id](const uint32_t* constituent_id) {
+              assert(original_id_to_donated_id->count(*constituent_id) &&
+                     "The constants used to construct this composite should "
+                     "have been donated.");
               constituent_ids.push_back(
                   original_id_to_donated_id->at(*constituent_id));
             });
@@ -396,12 +421,6 @@
             constituent_ids));
       } break;
       case SpvOpConstantNull: {
-        if (!original_id_to_donated_id->count(type_or_value.type_id())) {
-          // We did not donate the type associated with this null constant, so
-          // we cannot donate the null constant.
-          continue;
-        }
-
         // It is fine to have multiple OpConstantNull instructions of the same
         // type, so we just add this to the recipient module.
         new_result_id = GetFuzzerContext()->GetFreshId();
@@ -413,10 +432,14 @@
         // This is a global variable that could have one of various storage
         // classes.  However, we change all global variable pointer storage
         // classes (such as Uniform, Input and Output) to private when donating
-        // pointer types.  Thus this variable's pointer type is guaranteed to
-        // have storage class private.  As a result, we simply add a Private
-        // storage class global variable, using remapped versions of the result
-        // type and initializer ids for the global variable in the donor.
+        // pointer types, with the exception of the Workgroup storage class.
+        //
+        // Thus this variable's pointer type is guaranteed to have storage class
+        // Private or Workgroup.
+        //
+        // We add a global variable with either Private or Workgroup storage
+        // class, using remapped versions of the result type and initializer ids
+        // for the global variable in the donor.
         //
         // We regard the added variable as having an irrelevant value.  This
         // means that future passes can add stores to the variable in any
@@ -426,19 +449,35 @@
         uint32_t remapped_pointer_type =
             original_id_to_donated_id->at(type_or_value.type_id());
         uint32_t initializer_id;
+        SpvStorageClass storage_class =
+            static_cast<SpvStorageClass>(type_or_value.GetSingleWordInOperand(
+                0)) == SpvStorageClassWorkgroup
+                ? SpvStorageClassWorkgroup
+                : SpvStorageClassPrivate;
         if (type_or_value.NumInOperands() == 1) {
-          // The variable did not have an initializer; initialize it to zero.
-          // This is to limit problems associated with uninitialized data.
-          initializer_id = FindOrCreateZeroConstant(
-              fuzzerutil::GetPointeeTypeIdFromPointerType(
-                  GetIRContext(), remapped_pointer_type));
+          // The variable did not have an initializer.  Initialize it to zero
+          // if it has Private storage class (to limit problems associated with
+          // uninitialized data), and leave it uninitialized if it has Workgroup
+          // storage class (as Workgroup variables cannot have initializers).
+
+          // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3275): we
+          //  could initialize Workgroup variables at the start of an entry
+          //  point, and should do so if their uninitialized nature proves
+          //  problematic.
+          initializer_id =
+              storage_class == SpvStorageClassWorkgroup
+                  ? 0
+                  : FindOrCreateZeroConstant(
+                        fuzzerutil::GetPointeeTypeIdFromPointerType(
+                            GetIRContext(), remapped_pointer_type));
         } else {
           // The variable already had an initializer; use its remapped id.
           initializer_id = original_id_to_donated_id->at(
               type_or_value.GetSingleWordInOperand(1));
         }
         ApplyTransformation(TransformationAddGlobalVariable(
-            new_result_id, remapped_pointer_type, initializer_id, true));
+            new_result_id, remapped_pointer_type, storage_class, initializer_id,
+            true));
       } break;
       case SpvOpUndef: {
         // It is fine to have multiple Undef instructions of the same type, so
@@ -493,9 +532,80 @@
     // Scan through the function, remapping each result id that it generates to
     // a fresh id.  This is necessary because functions include forward
     // references, e.g. to labels.
-    function_to_donate->ForEachInst([this, &original_id_to_donated_id](
+    function_to_donate->ForEachInst([this, donor_ir_context,
+                                     &original_id_to_donated_id](
                                         const opt::Instruction* instruction) {
-      if (instruction->result_id()) {
+      if (!instruction->result_id()) {
+        return;
+      }
+      if (IgnoreInstruction(instruction)) {
+        if (instruction->opcode() == SpvOpArrayLength) {
+          // We treat the OpArrayLength instruction specially.  In the donor
+          // shader this gets the length of a runtime array that is the final
+          // member of a struct.  During donation, we will have converted the
+          // runtime array type, and the associated struct field, into a fixed-
+          // size array.  We can then use the constant size of this fixed-sized
+          // array wherever we would have used the result of an OpArrayLength
+          // instruction.
+          uint32_t donated_variable_id = original_id_to_donated_id->at(
+              instruction->GetSingleWordInOperand(0));
+          auto donated_variable_instruction =
+              GetIRContext()->get_def_use_mgr()->GetDef(donated_variable_id);
+          auto pointer_to_struct_instruction =
+              GetIRContext()->get_def_use_mgr()->GetDef(
+                  donated_variable_instruction->type_id());
+          assert(pointer_to_struct_instruction->opcode() == SpvOpTypePointer &&
+                 "Type of variable must be pointer.");
+          auto donated_struct_type_instruction =
+              GetIRContext()->get_def_use_mgr()->GetDef(
+                  pointer_to_struct_instruction->GetSingleWordInOperand(1));
+          assert(
+              donated_struct_type_instruction->opcode() == SpvOpTypeStruct &&
+              "Pointee type of pointer used by OpArrayLength must be struct.");
+          assert(donated_struct_type_instruction->NumInOperands() ==
+                     instruction->GetSingleWordInOperand(1) + 1 &&
+                 "OpArrayLength must refer to the final member of the given "
+                 "struct.");
+          uint32_t fixed_size_array_type_id =
+              donated_struct_type_instruction->GetSingleWordInOperand(
+                  donated_struct_type_instruction->NumInOperands() - 1);
+          auto fixed_size_array_type_instruction =
+              GetIRContext()->get_def_use_mgr()->GetDef(
+                  fixed_size_array_type_id);
+          assert(fixed_size_array_type_instruction->opcode() ==
+                     SpvOpTypeArray &&
+                 "The donated array type must be fixed-size.");
+          auto array_size_id =
+              fixed_size_array_type_instruction->GetSingleWordInOperand(1);
+          original_id_to_donated_id->insert(
+              {instruction->result_id(), array_size_id});
+        } else if (instruction->type_id()) {
+          // If the ignored instruction has a basic result type then we
+          // associate its result id with a constant of that type, so that
+          // instructions that use the result id will use the constant instead.
+          // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3177):
+          //  Using this particular constant is arbitrary, so if we have a
+          //  mechanism for noting that an id use is arbitrary and could be
+          //  fuzzed we should use it here.
+          auto type_inst = donor_ir_context->get_def_use_mgr()->GetDef(
+              instruction->type_id());
+          switch (type_inst->opcode()) {
+            case SpvOpTypeArray:
+            case SpvOpTypeBool:
+            case SpvOpTypeFloat:
+            case SpvOpTypeInt:
+            case SpvOpTypeStruct:
+            case SpvOpTypeVector:
+            case SpvOpTypeMatrix:
+              original_id_to_donated_id->insert(
+                  {instruction->result_id(),
+                   FindOrCreateZeroConstant(
+                       original_id_to_donated_id->at(instruction->type_id()))});
+            default:
+              break;
+          }
+        }
+      } else {
         original_id_to_donated_id->insert(
             {instruction->result_id(), GetFuzzerContext()->GetFreshId()});
       }
@@ -505,6 +615,10 @@
     function_to_donate->ForEachInst([this, &donated_instructions,
                                      &original_id_to_donated_id](
                                         const opt::Instruction* instruction) {
+      if (IgnoreInstruction(instruction)) {
+        return;
+      }
+
       // Get the instruction's input operands into donation-ready form,
       // remapping any id uses in the process.
       opt::Instruction::OperandList input_operands;
@@ -642,9 +756,28 @@
                     GetFuzzerContext()->GetFreshId());
 
                 // Get the bound for the component being indexed into.
-                uint32_t bound =
-                    TransformationAddFunction::GetBoundForCompositeIndex(
-                        donor_ir_context, *should_be_composite_type);
+                uint32_t bound;
+                if (should_be_composite_type->opcode() ==
+                    SpvOpTypeRuntimeArray) {
+                  // The donor is indexing into a runtime array.  We do not
+                  // donate runtime arrays.  Instead, we donate a corresponding
+                  // fixed-size array for every runtime array.  We should thus
+                  // find that donor composite type's result id maps to a fixed-
+                  // size array.
+                  auto fixed_size_array_type =
+                      GetIRContext()->get_def_use_mgr()->GetDef(
+                          original_id_to_donated_id->at(
+                              should_be_composite_type->result_id()));
+                  assert(fixed_size_array_type->opcode() == SpvOpTypeArray &&
+                         "A runtime array type in the donor should have been "
+                         "replaced by a fixed-sized array in the recipient.");
+                  // The size of this fixed-size array is a suitable bound.
+                  bound = TransformationAddFunction::GetBoundForCompositeIndex(
+                      GetIRContext(), *fixed_size_array_type);
+                } else {
+                  bound = TransformationAddFunction::GetBoundForCompositeIndex(
+                      donor_ir_context, *should_be_composite_type);
+                }
                 const uint32_t index_id = inst.GetSingleWordInOperand(index);
                 auto index_inst =
                     donor_ir_context->get_def_use_mgr()->GetDef(index_id);
@@ -707,6 +840,37 @@
   }
 }
 
+bool FuzzerPassDonateModules::IgnoreInstruction(
+    const opt::Instruction* instruction) {
+  switch (instruction->opcode()) {
+    case SpvOpArrayLength:
+      // We ignore instructions that get the length of runtime arrays, because
+      // we turn all runtime arrays into fixed-size arrays.
+    case SpvOpAtomicLoad:
+    case SpvOpAtomicStore:
+    case SpvOpAtomicExchange:
+    case SpvOpAtomicCompareExchange:
+    case SpvOpAtomicCompareExchangeWeak:
+    case SpvOpAtomicIIncrement:
+    case SpvOpAtomicIDecrement:
+    case SpvOpAtomicIAdd:
+    case SpvOpAtomicISub:
+    case SpvOpAtomicSMin:
+    case SpvOpAtomicUMin:
+    case SpvOpAtomicSMax:
+    case SpvOpAtomicUMax:
+    case SpvOpAtomicAnd:
+    case SpvOpAtomicOr:
+    case SpvOpAtomicXor:
+      // We conservatively ignore all atomic instructions at present.
+      // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3276): Consider
+      //  being less conservative here.
+      return true;
+    default:
+      return false;
+  }
+}
+
 std::vector<uint32_t>
 FuzzerPassDonateModules::GetFunctionsInCallGraphTopologicalOrder(
     opt::IRContext* context) {
diff --git a/source/fuzz/fuzzer_pass_donate_modules.h b/source/fuzz/fuzzer_pass_donate_modules.h
index 9087daf..909f3bc 100644
--- a/source/fuzz/fuzzer_pass_donate_modules.h
+++ b/source/fuzz/fuzzer_pass_donate_modules.h
@@ -77,6 +77,12 @@
                        std::map<uint32_t, uint32_t>* original_id_to_donated_id,
                        bool make_livesafe);
 
+  // During donation we will have to ignore some instructions, e.g. because they
+  // use opcodes that we cannot support or because they reference the ids of
+  // instructions that have not been donated.  This function encapsulates the
+  // logic for deciding which instructions should be ignored.
+  bool IgnoreInstruction(const opt::Instruction* instruction);
+
   // Returns the ids of all functions in |context| in a topological order in
   // relation to the call graph of |context|, which is assumed to be recursion-
   // free.
diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto
index 5dc70c3..2af3a92 100644
--- a/source/fuzz/protobufs/spvtoolsfuzz.proto
+++ b/source/fuzz/protobufs/spvtoolsfuzz.proto
@@ -560,8 +560,9 @@
 
 message TransformationAddGlobalVariable {
 
-  // Adds a global variable of the given type to the module, with Private
-  // storage class and optionally with an initializer.
+  // Adds a global variable of the given type to the module, with Private or
+  // Workgroup storage class, and optionally (for the Private case) with an
+  // initializer.
 
   // Fresh id for the global variable
   uint32 fresh_id = 1;
@@ -569,13 +570,15 @@
   // The type of the global variable
   uint32 type_id = 2;
 
+  uint32 storage_class = 3;
+
   // Initial value of the variable
-  uint32 initializer_id = 3;
+  uint32 initializer_id = 4;
 
   // True if and only if the behaviour of the module should not depend on the
   // value of the variable, in which case stores to the variable can be
   // performed in an arbitrary fashion.
-  bool value_is_irrelevant = 4;
+  bool value_is_irrelevant = 5;
 
 }
 
diff --git a/source/fuzz/transformation_add_function.cpp b/source/fuzz/transformation_add_function.cpp
index 45fe342..c990f23 100644
--- a/source/fuzz/transformation_add_function.cpp
+++ b/source/fuzz/transformation_add_function.cpp
@@ -897,6 +897,11 @@
     case SpvOpTypeStruct: {
       return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
     }
+    case SpvOpTypeRuntimeArray:
+      assert(false &&
+             "GetBoundForCompositeIndex should not be invoked with an "
+             "OpTypeRuntimeArray, which does not have a static bound.");
+      return 0;
     default:
       assert(false && "Unknown composite type.");
       return 0;
@@ -909,6 +914,7 @@
   uint32_t sub_object_type_id;
   switch (composite_type_inst.opcode()) {
     case SpvOpTypeArray:
+    case SpvOpTypeRuntimeArray:
       sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
       break;
     case SpvOpTypeMatrix:
diff --git a/source/fuzz/transformation_add_global_variable.cpp b/source/fuzz/transformation_add_global_variable.cpp
index c016428..6464bfb 100644
--- a/source/fuzz/transformation_add_global_variable.cpp
+++ b/source/fuzz/transformation_add_global_variable.cpp
@@ -24,10 +24,11 @@
     : message_(message) {}
 
 TransformationAddGlobalVariable::TransformationAddGlobalVariable(
-    uint32_t fresh_id, uint32_t type_id, uint32_t initializer_id,
-    bool value_is_irrelevant) {
+    uint32_t fresh_id, uint32_t type_id, SpvStorageClass storage_class,
+    uint32_t initializer_id, bool value_is_irrelevant) {
   message_.set_fresh_id(fresh_id);
   message_.set_type_id(type_id);
+  message_.set_storage_class(storage_class);
   message_.set_initializer_id(initializer_id);
   message_.set_value_is_irrelevant(value_is_irrelevant);
 }
@@ -38,6 +39,17 @@
   if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) {
     return false;
   }
+
+  // The storage class must be Private or Workgroup.
+  auto storage_class = static_cast<SpvStorageClass>(message_.storage_class());
+  switch (storage_class) {
+    case SpvStorageClassPrivate:
+    case SpvStorageClassWorkgroup:
+      break;
+    default:
+      assert(false && "Unsupported storage class.");
+      return false;
+  }
   // The type id must correspond to a type.
   auto type = ir_context->get_type_mgr()->GetType(message_.type_id());
   if (!type) {
@@ -48,23 +60,32 @@
   if (!pointer_type) {
     return false;
   }
-  // ... with Private storage class.
-  if (pointer_type->storage_class() != SpvStorageClassPrivate) {
+  // ... with the right storage class.
+  if (pointer_type->storage_class() != storage_class) {
     return false;
   }
-  // The initializer id must be the id of a constant.  Check this with the
-  // constant manager.
-  auto constant_id = ir_context->get_constant_mgr()->GetConstantsFromIds(
-      {message_.initializer_id()});
-  if (constant_id.empty()) {
-    return false;
-  }
-  assert(constant_id.size() == 1 &&
-         "We asked for the constant associated with a single id; we should "
-         "get a single constant.");
-  // The type of the constant must match the pointee type of the pointer.
-  if (pointer_type->pointee_type() != constant_id[0]->type()) {
-    return false;
+  if (message_.initializer_id()) {
+    // An initializer is not allowed if the storage class is Workgroup.
+    if (storage_class == SpvStorageClassWorkgroup) {
+      assert(false &&
+             "By construction this transformation should not have an "
+             "initializer when Workgroup storage class is used.");
+      return false;
+    }
+    // The initializer id must be the id of a constant.  Check this with the
+    // constant manager.
+    auto constant_id = ir_context->get_constant_mgr()->GetConstantsFromIds(
+        {message_.initializer_id()});
+    if (constant_id.empty()) {
+      return false;
+    }
+    assert(constant_id.size() == 1 &&
+           "We asked for the constant associated with a single id; we should "
+           "get a single constant.");
+    // The type of the constant must match the pointee type of the pointer.
+    if (pointer_type->pointee_type() != constant_id[0]->type()) {
+      return false;
+    }
   }
   return true;
 }
@@ -74,7 +95,7 @@
     TransformationContext* transformation_context) const {
   opt::Instruction::OperandList input_operands;
   input_operands.push_back(
-      {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassPrivate}});
+      {SPV_OPERAND_TYPE_STORAGE_CLASS, {message_.storage_class()}});
   if (message_.initializer_id()) {
     input_operands.push_back(
         {SPV_OPERAND_TYPE_ID, {message_.initializer_id()}});
@@ -84,7 +105,7 @@
       input_operands));
   fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id());
 
-  if (PrivateGlobalsMustBeDeclaredInEntryPointInterfaces(ir_context)) {
+  if (GlobalVariablesMustBeDeclaredInEntryPointInterfaces(ir_context)) {
     // Conservatively add this global to the interface of every entry point in
     // the module.  This means that the global is available for other
     // transformations to use.
@@ -117,7 +138,7 @@
 }
 
 bool TransformationAddGlobalVariable::
-    PrivateGlobalsMustBeDeclaredInEntryPointInterfaces(
+    GlobalVariablesMustBeDeclaredInEntryPointInterfaces(
         opt::IRContext* ir_context) {
   // TODO(afd): We capture the universal environments for which this requirement
   //  holds.  The check should be refined on demand for other target
diff --git a/source/fuzz/transformation_add_global_variable.h b/source/fuzz/transformation_add_global_variable.h
index f28af44..289af9e 100644
--- a/source/fuzz/transformation_add_global_variable.h
+++ b/source/fuzz/transformation_add_global_variable.h
@@ -29,22 +29,26 @@
       const protobufs::TransformationAddGlobalVariable& message);
 
   TransformationAddGlobalVariable(uint32_t fresh_id, uint32_t type_id,
+                                  SpvStorageClass storage_class,
                                   uint32_t initializer_id,
                                   bool value_is_irrelevant);
 
   // - |message_.fresh_id| must be fresh
-  // - |message_.type_id| must be the id of a pointer type with Private storage
-  //   class
-  // - |message_.initializer_id| must either be 0 or the id of a constant whose
+  // - |message_.type_id| must be the id of a pointer type with the same storage
+  //   class as |message_.storage_class|
+  // - |message_.storage_class| must be Private or Workgroup
+  // - |message_.initializer_id| must be 0 if |message_.storage_class| is
+  //   Workgroup, and otherwise may either be 0 or the id of a constant whose
   //   type is the pointee type of |message_.type_id|
   bool IsApplicable(
       opt::IRContext* ir_context,
       const TransformationContext& transformation_context) const override;
 
-  // Adds a global variable with Private storage class to the module, with type
-  // |message_.type_id| and either no initializer or |message_.initializer_id|
-  // as an initializer, depending on whether |message_.initializer_id| is 0.
-  // The global variable has result id |message_.fresh_id|.
+  // Adds a global variable with storage class |message_.storage_class| to the
+  // module, with type |message_.type_id| and either no initializer or
+  // |message_.initializer_id| as an initializer, depending on whether
+  // |message_.initializer_id| is 0.  The global variable has result id
+  // |message_.fresh_id|.
   //
   // If |message_.value_is_irrelevant| holds, adds a corresponding fact to the
   // fact manager in |transformation_context|.
@@ -54,7 +58,10 @@
   protobufs::Transformation ToMessage() const override;
 
  private:
-  static bool PrivateGlobalsMustBeDeclaredInEntryPointInterfaces(
+  // Returns true if and only if the SPIR-V version being used requires that
+  // global variables accessed in the static call graph of an entry point need
+  // to be listed in that entry point's interface.
+  static bool GlobalVariablesMustBeDeclaredInEntryPointInterfaces(
       opt::IRContext* ir_context);
 
   protobufs::TransformationAddGlobalVariable message_;
diff --git a/test/fuzz/fuzzer_pass_donate_modules_test.cpp b/test/fuzz/fuzzer_pass_donate_modules_test.cpp
index 40d7d24..bbc92b9 100644
--- a/test/fuzz/fuzzer_pass_donate_modules_test.cpp
+++ b/test/fuzz/fuzzer_pass_donate_modules_test.cpp
@@ -198,8 +198,8 @@
   TransformationContext transformation_context(&fact_manager,
                                                validator_options);
 
-  auto prng = MakeUnique<PseudoRandomGenerator>(0);
-  FuzzerContext fuzzer_context(prng.get(), 100);
+  PseudoRandomGenerator prng(0);
+  FuzzerContext fuzzer_context(&prng, 100);
   protobufs::TransformationSequence transformation_sequence;
 
   FuzzerPassDonateModules fuzzer_pass(recipient_context.get(),
@@ -276,7 +276,8 @@
   TransformationContext transformation_context(&fact_manager,
                                                validator_options);
 
-  FuzzerContext fuzzer_context(MakeUnique<PseudoRandomGenerator>(0).get(), 100);
+  PseudoRandomGenerator prng(0);
+  FuzzerContext fuzzer_context(&prng, 100);
   protobufs::TransformationSequence transformation_sequence;
 
   FuzzerPassDonateModules fuzzer_pass(recipient_context.get(),
@@ -403,7 +404,8 @@
   TransformationContext transformation_context(&fact_manager,
                                                validator_options);
 
-  FuzzerContext fuzzer_context(MakeUnique<PseudoRandomGenerator>(0).get(), 100);
+  PseudoRandomGenerator prng(0);
+  FuzzerContext fuzzer_context(&prng, 100);
   protobufs::TransformationSequence transformation_sequence;
 
   FuzzerPassDonateModules fuzzer_pass(recipient_context.get(),
@@ -560,7 +562,8 @@
   TransformationContext transformation_context(&fact_manager,
                                                validator_options);
 
-  FuzzerContext fuzzer_context(MakeUnique<PseudoRandomGenerator>(0).get(), 100);
+  PseudoRandomGenerator prng(0);
+  FuzzerContext fuzzer_context(&prng, 100);
   protobufs::TransformationSequence transformation_sequence;
 
   FuzzerPassDonateModules fuzzer_pass(recipient_context.get(),
@@ -574,6 +577,367 @@
   ASSERT_TRUE(IsValid(env, recipient_context.get()));
 }
 
+TEST(FuzzerPassDonateModulesTest, DonateComputeShaderWithRuntimeArray) {
+  std::string recipient_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %4 "main"
+               OpExecutionMode %4 LocalSize 1 1 1
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  std::string donor_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %4 "main"
+               OpExecutionMode %4 LocalSize 1 1 1
+               OpSource ESSL 310
+               OpDecorate %9 ArrayStride 4
+               OpMemberDecorate %10 0 Offset 0
+               OpDecorate %10 BufferBlock
+               OpDecorate %12 DescriptorSet 0
+               OpDecorate %12 Binding 0
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Function %6
+          %9 = OpTypeRuntimeArray %6
+         %10 = OpTypeStruct %9
+         %11 = OpTypePointer Uniform %10
+         %12 = OpVariable %11 Uniform
+         %13 = OpTypeInt 32 0
+         %16 = OpConstant %6 0
+         %18 = OpConstant %6 1
+         %20 = OpTypePointer Uniform %6
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+          %8 = OpVariable %7 Function
+         %14 = OpArrayLength %13 %12 0
+         %15 = OpBitcast %6 %14
+               OpStore %8 %15
+         %17 = OpLoad %6 %8
+         %19 = OpISub %6 %17 %18
+         %21 = OpAccessChain %20 %12 %16 %19
+               OpStore %21 %16
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_3;
+  const auto consumer = nullptr;
+  const auto recipient_context =
+      BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+
+  const auto donor_context =
+      BuildModule(env, consumer, donor_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, donor_context.get()));
+
+  FactManager fact_manager;
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(&fact_manager,
+                                               validator_options);
+
+  PseudoRandomGenerator prng(0);
+  FuzzerContext fuzzer_context(&prng, 100);
+  protobufs::TransformationSequence transformation_sequence;
+
+  FuzzerPassDonateModules fuzzer_pass(recipient_context.get(),
+                                      &transformation_context, &fuzzer_context,
+                                      &transformation_sequence, {});
+
+  fuzzer_pass.DonateSingleModule(donor_context.get(), false);
+
+  // We just check that the result is valid.  Checking to what it should be
+  // exactly equal to would be very fragile.
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+}
+
+TEST(FuzzerPassDonateModulesTest, DonateComputeShaderWithRuntimeArrayLivesafe) {
+  std::string recipient_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %4 "main"
+               OpExecutionMode %4 LocalSize 1 1 1
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  std::string donor_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %4 "main"
+               OpExecutionMode %4 LocalSize 1 1 1
+               OpSource ESSL 310
+               OpDecorate %16 ArrayStride 4
+               OpMemberDecorate %17 0 Offset 0
+               OpDecorate %17 BufferBlock
+               OpDecorate %19 DescriptorSet 0
+               OpDecorate %19 Binding 0
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Function %6
+          %9 = OpConstant %6 0
+         %16 = OpTypeRuntimeArray %6
+         %17 = OpTypeStruct %16
+         %18 = OpTypePointer Uniform %17
+         %19 = OpVariable %18 Uniform
+         %20 = OpTypeInt 32 0
+         %23 = OpTypeBool
+         %26 = OpConstant %6 32
+         %27 = OpTypePointer Uniform %6
+         %30 = OpConstant %6 1
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+          %8 = OpVariable %7 Function
+               OpStore %8 %9
+               OpBranch %10
+         %10 = OpLabel
+               OpLoopMerge %12 %13 None
+               OpBranch %14
+         %14 = OpLabel
+         %15 = OpLoad %6 %8
+         %21 = OpArrayLength %20 %19 0
+         %22 = OpBitcast %6 %21
+         %24 = OpSLessThan %23 %15 %22
+               OpBranchConditional %24 %11 %12
+         %11 = OpLabel
+         %25 = OpLoad %6 %8
+         %28 = OpAccessChain %27 %19 %9 %25
+               OpStore %28 %26
+               OpBranch %13
+         %13 = OpLabel
+         %29 = OpLoad %6 %8
+         %31 = OpIAdd %6 %29 %30
+               OpStore %8 %31
+               OpBranch %10
+         %12 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_3;
+  const auto consumer = nullptr;
+  const auto recipient_context =
+      BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+
+  const auto donor_context =
+      BuildModule(env, consumer, donor_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, donor_context.get()));
+
+  FactManager fact_manager;
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(&fact_manager,
+                                               validator_options);
+
+  PseudoRandomGenerator prng(0);
+  FuzzerContext fuzzer_context(&prng, 100);
+  protobufs::TransformationSequence transformation_sequence;
+
+  FuzzerPassDonateModules fuzzer_pass(recipient_context.get(),
+                                      &transformation_context, &fuzzer_context,
+                                      &transformation_sequence, {});
+
+  fuzzer_pass.DonateSingleModule(donor_context.get(), true);
+
+  // We just check that the result is valid.  Checking to what it should be
+  // exactly equal to would be very fragile.
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+}
+
+TEST(FuzzerPassDonateModulesTest, DonateComputeShaderWithWorkgroupVariables) {
+  std::string recipient_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %4 "main"
+               OpExecutionMode %4 LocalSize 1 1 1
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  std::string donor_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %4 "main"
+               OpExecutionMode %4 LocalSize 1 1 1
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Workgroup %6
+          %8 = OpVariable %7 Workgroup
+          %9 = OpConstant %6 2
+         %10 = OpVariable %7 Workgroup
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpStore %8 %9
+         %11 = OpLoad %6 %8
+               OpStore %10 %11
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_3;
+  const auto consumer = nullptr;
+  const auto recipient_context =
+      BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+
+  const auto donor_context =
+      BuildModule(env, consumer, donor_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, donor_context.get()));
+
+  FactManager fact_manager;
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(&fact_manager,
+                                               validator_options);
+
+  PseudoRandomGenerator prng(0);
+  FuzzerContext fuzzer_context(&prng, 100);
+  protobufs::TransformationSequence transformation_sequence;
+
+  FuzzerPassDonateModules fuzzer_pass(recipient_context.get(),
+                                      &transformation_context, &fuzzer_context,
+                                      &transformation_sequence, {});
+
+  fuzzer_pass.DonateSingleModule(donor_context.get(), true);
+
+  // We just check that the result is valid.  Checking to what it should be
+  // exactly equal to would be very fragile.
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+}
+
+TEST(FuzzerPassDonateModulesTest, DonateComputeShaderWithAtomics) {
+  std::string recipient_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %4 "main"
+               OpExecutionMode %4 LocalSize 1 1 1
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  std::string donor_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %4 "main"
+               OpExecutionMode %4 LocalSize 1 1 1
+               OpSource ESSL 310
+               OpMemberDecorate %9 0 Offset 0
+               OpDecorate %9 BufferBlock
+               OpDecorate %11 DescriptorSet 0
+               OpDecorate %11 Binding 0
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 0
+          %7 = OpTypePointer Function %6
+          %9 = OpTypeStruct %6
+         %10 = OpTypePointer Uniform %9
+         %11 = OpVariable %10 Uniform
+         %12 = OpTypeInt 32 1
+         %13 = OpConstant %12 0
+         %14 = OpTypePointer Uniform %6
+         %16 = OpConstant %6 1
+         %17 = OpConstant %6 0
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+          %8 = OpVariable %7 Function
+         %15 = OpAccessChain %14 %11 %13
+         %18 = OpAtomicIAdd %6 %15 %16 %17 %16
+               OpStore %8 %18
+         %19 = OpAccessChain %14 %11 %13
+         %20 = OpLoad %6 %8
+         %21 = OpAtomicUMin %6 %19 %16 %17 %20
+               OpStore %8 %21
+         %22 = OpAccessChain %14 %11 %13
+         %23 = OpLoad %6 %8
+         %24 = OpAtomicUMax %6 %22 %16 %17 %23
+               OpStore %8 %24
+         %25 = OpAccessChain %14 %11 %13
+         %26 = OpLoad %6 %8
+         %27 = OpAtomicAnd %6 %25 %16 %17 %26
+               OpStore %8 %27
+         %28 = OpAccessChain %14 %11 %13
+         %29 = OpLoad %6 %8
+         %30 = OpAtomicOr %6 %28 %16 %17 %29
+               OpStore %8 %30
+         %31 = OpAccessChain %14 %11 %13
+         %32 = OpLoad %6 %8
+         %33 = OpAtomicXor %6 %31 %16 %17 %32
+               OpStore %8 %33
+         %34 = OpAccessChain %14 %11 %13
+         %35 = OpLoad %6 %8
+         %36 = OpAtomicExchange %6 %34 %16 %17 %35
+               OpStore %8 %36
+         %37 = OpAccessChain %14 %11 %13
+         %38 = OpLoad %6 %8
+         %39 = OpAtomicCompareExchange %6 %37 %16 %17 %17 %16 %38
+               OpStore %8 %39
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_3;
+  const auto consumer = nullptr;
+  const auto recipient_context =
+      BuildModule(env, consumer, recipient_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+
+  const auto donor_context =
+      BuildModule(env, consumer, donor_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, donor_context.get()));
+
+  FactManager fact_manager;
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(&fact_manager,
+                                               validator_options);
+
+  PseudoRandomGenerator prng(0);
+  FuzzerContext fuzzer_context(&prng, 100);
+  protobufs::TransformationSequence transformation_sequence;
+
+  FuzzerPassDonateModules fuzzer_pass(recipient_context.get(),
+                                      &transformation_context, &fuzzer_context,
+                                      &transformation_sequence, {});
+
+  fuzzer_pass.DonateSingleModule(donor_context.get(), true);
+
+  // We just check that the result is valid.  Checking to what it should be
+  // exactly equal to would be very fragile.
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+}
+
 TEST(FuzzerPassDonateModulesTest, Miscellaneous1) {
   std::string recipient_shader = R"(
                OpCapability Shader
diff --git a/test/fuzz/transformation_add_global_variable_test.cpp b/test/fuzz/transformation_add_global_variable_test.cpp
index 9b8faa4..5c74ca0 100644
--- a/test/fuzz/transformation_add_global_variable_test.cpp
+++ b/test/fuzz/transformation_add_global_variable_test.cpp
@@ -65,66 +65,82 @@
                                                validator_options);
 
   // Id already in use
-  ASSERT_FALSE(TransformationAddGlobalVariable(4, 10, 0, true)
-                   .IsApplicable(context.get(), transformation_context));
+  ASSERT_FALSE(
+      TransformationAddGlobalVariable(4, 10, SpvStorageClassPrivate, 0, true)
+          .IsApplicable(context.get(), transformation_context));
   // %1 is not a type
-  ASSERT_FALSE(TransformationAddGlobalVariable(100, 1, 0, false)
-                   .IsApplicable(context.get(), transformation_context));
+  ASSERT_FALSE(
+      TransformationAddGlobalVariable(100, 1, SpvStorageClassPrivate, 0, false)
+          .IsApplicable(context.get(), transformation_context));
 
   // %7 is not a pointer type
-  ASSERT_FALSE(TransformationAddGlobalVariable(100, 7, 0, true)
-                   .IsApplicable(context.get(), transformation_context));
+  ASSERT_FALSE(
+      TransformationAddGlobalVariable(100, 7, SpvStorageClassPrivate, 0, true)
+          .IsApplicable(context.get(), transformation_context));
 
   // %9 does not have Private storage class
-  ASSERT_FALSE(TransformationAddGlobalVariable(100, 9, 0, false)
-                   .IsApplicable(context.get(), transformation_context));
+  ASSERT_FALSE(
+      TransformationAddGlobalVariable(100, 9, SpvStorageClassPrivate, 0, false)
+          .IsApplicable(context.get(), transformation_context));
 
   // %15 does not have Private storage class
-  ASSERT_FALSE(TransformationAddGlobalVariable(100, 15, 0, true)
-                   .IsApplicable(context.get(), transformation_context));
+  ASSERT_FALSE(
+      TransformationAddGlobalVariable(100, 15, SpvStorageClassPrivate, 0, true)
+          .IsApplicable(context.get(), transformation_context));
 
   // %10 is a pointer to float, while %16 is an int constant
-  ASSERT_FALSE(TransformationAddGlobalVariable(100, 10, 16, false)
+  ASSERT_FALSE(TransformationAddGlobalVariable(100, 10, SpvStorageClassPrivate,
+                                               16, false)
                    .IsApplicable(context.get(), transformation_context));
 
   // %10 is a Private pointer to float, while %15 is a variable with type
   // Uniform float pointer
-  ASSERT_FALSE(TransformationAddGlobalVariable(100, 10, 15, true)
-                   .IsApplicable(context.get(), transformation_context));
+  ASSERT_FALSE(
+      TransformationAddGlobalVariable(100, 10, SpvStorageClassPrivate, 15, true)
+          .IsApplicable(context.get(), transformation_context));
 
   // %12 is a Private pointer to int, while %10 is a variable with type
   // Private float pointer
-  ASSERT_FALSE(TransformationAddGlobalVariable(100, 12, 10, false)
+  ASSERT_FALSE(TransformationAddGlobalVariable(100, 12, SpvStorageClassPrivate,
+                                               10, false)
                    .IsApplicable(context.get(), transformation_context));
 
   // %10 is pointer-to-float, and %14 has type pointer-to-float; that's not OK
   // since the initializer's type should be the *pointee* type.
-  ASSERT_FALSE(TransformationAddGlobalVariable(104, 10, 14, true)
-                   .IsApplicable(context.get(), transformation_context));
+  ASSERT_FALSE(
+      TransformationAddGlobalVariable(104, 10, SpvStorageClassPrivate, 14, true)
+          .IsApplicable(context.get(), transformation_context));
 
   // This would work in principle, but logical addressing does not allow
   // a pointer to a pointer.
-  ASSERT_FALSE(TransformationAddGlobalVariable(104, 17, 14, false)
+  ASSERT_FALSE(TransformationAddGlobalVariable(104, 17, SpvStorageClassPrivate,
+                                               14, false)
                    .IsApplicable(context.get(), transformation_context));
 
   TransformationAddGlobalVariable transformations[] = {
       // %100 = OpVariable %12 Private
-      TransformationAddGlobalVariable(100, 12, 16, true),
+      TransformationAddGlobalVariable(100, 12, SpvStorageClassPrivate, 16,
+                                      true),
 
       // %101 = OpVariable %10 Private
-      TransformationAddGlobalVariable(101, 10, 40, false),
+      TransformationAddGlobalVariable(101, 10, SpvStorageClassPrivate, 40,
+                                      false),
 
       // %102 = OpVariable %13 Private
-      TransformationAddGlobalVariable(102, 13, 41, true),
+      TransformationAddGlobalVariable(102, 13, SpvStorageClassPrivate, 41,
+                                      true),
 
       // %103 = OpVariable %12 Private %16
-      TransformationAddGlobalVariable(103, 12, 16, false),
+      TransformationAddGlobalVariable(103, 12, SpvStorageClassPrivate, 16,
+                                      false),
 
       // %104 = OpVariable %19 Private %21
-      TransformationAddGlobalVariable(104, 19, 21, true),
+      TransformationAddGlobalVariable(104, 19, SpvStorageClassPrivate, 21,
+                                      true),
 
       // %105 = OpVariable %19 Private %22
-      TransformationAddGlobalVariable(105, 19, 22, false)};
+      TransformationAddGlobalVariable(105, 19, SpvStorageClassPrivate, 22,
+                                      false)};
 
   for (auto& transformation : transformations) {
     ASSERT_TRUE(
@@ -239,13 +255,16 @@
 
   TransformationAddGlobalVariable transformations[] = {
       // %100 = OpVariable %12 Private
-      TransformationAddGlobalVariable(100, 12, 16, true),
+      TransformationAddGlobalVariable(100, 12, SpvStorageClassPrivate, 16,
+                                      true),
 
       // %101 = OpVariable %12 Private %16
-      TransformationAddGlobalVariable(101, 12, 16, false),
+      TransformationAddGlobalVariable(101, 12, SpvStorageClassPrivate, 16,
+                                      false),
 
       // %102 = OpVariable %19 Private %21
-      TransformationAddGlobalVariable(102, 19, 21, true)};
+      TransformationAddGlobalVariable(102, 19, SpvStorageClassPrivate, 21,
+                                      true)};
 
   for (auto& transformation : transformations) {
     ASSERT_TRUE(
@@ -301,6 +320,85 @@
   ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
 }
 
+TEST(TransformationAddGlobalVariableTest, TestAddingWorkgroupGlobals) {
+  // This checks that workgroup globals can be added to a compute shader.
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %4 "main"
+               OpExecutionMode %4 LocalSize 1 1 1
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Workgroup %6
+         %50 = OpConstant %6 2
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_4;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  FactManager fact_manager;
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(&fact_manager,
+                                               validator_options);
+
+#ifndef NDEBUG
+  ASSERT_DEATH(
+      TransformationAddGlobalVariable(8, 7, SpvStorageClassWorkgroup, 50, true)
+          .IsApplicable(context.get(), transformation_context),
+      "By construction this transformation should not have an.*initializer "
+      "when Workgroup storage class is used");
+#endif
+
+  TransformationAddGlobalVariable transformations[] = {
+      // %8 = OpVariable %7 Workgroup
+      TransformationAddGlobalVariable(8, 7, SpvStorageClassWorkgroup, 0, true),
+
+      // %10 = OpVariable %7 Workgroup
+      TransformationAddGlobalVariable(10, 7, SpvStorageClassWorkgroup, 0,
+                                      false)};
+
+  for (auto& transformation : transformations) {
+    ASSERT_TRUE(
+        transformation.IsApplicable(context.get(), transformation_context));
+    transformation.Apply(context.get(), &transformation_context);
+  }
+  ASSERT_TRUE(
+      transformation_context.GetFactManager()->PointeeValueIsIrrelevant(8));
+  ASSERT_FALSE(
+      transformation_context.GetFactManager()->PointeeValueIsIrrelevant(10));
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  std::string after_transformation = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %4 "main" %8 %10
+               OpExecutionMode %4 LocalSize 1 1 1
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 1
+          %7 = OpTypePointer Workgroup %6
+         %50 = OpConstant %6 2
+          %8 = OpVariable %7 Workgroup
+         %10 = OpVariable %7 Workgroup
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
 }  // namespace
 }  // namespace fuzz
 }  // namespace spvtools