Add fix storage class code. (#2434)

This pass tries to fix validation error due to a mismatch of storage classes
in instructions.  There is no guarantee that all such error will be fixed,
and it is possible that in fixing these errors, it could lead to other
errors.

Fixes #2430.
diff --git a/Android.mk b/Android.mk
index 4479ca8..bfcaddc 100644
--- a/Android.mk
+++ b/Android.mk
@@ -101,6 +101,7 @@
 		source/opt/eliminate_dead_functions_util.cpp \
 		source/opt/eliminate_dead_members_pass.cpp \
 		source/opt/feature_manager.cpp \
+		source/opt/fix_storage_class.cpp \
 		source/opt/flatten_decoration_pass.cpp \
 		source/opt/fold.cpp \
 		source/opt/folding_rules.cpp \
diff --git a/BUILD.gn b/BUILD.gn
index 5aa1af8..8367a39 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -502,6 +502,8 @@
     "source/opt/eliminate_dead_members_pass.h",
     "source/opt/feature_manager.cpp",
     "source/opt/feature_manager.h",
+    "source/opt/fix_storage_class.cpp",
+    "source/opt/fix_storage_class.h",
     "source/opt/flatten_decoration_pass.cpp",
     "source/opt/flatten_decoration_pass.h",
     "source/opt/fold.cpp",
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index 4e92bb0..846ad70 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -753,6 +753,11 @@
 // variables to the first value stored in them, if that is a constant.
 Optimizer::PassToken CreateGenerateWebGPUInitializersPass();
 
+// Create a pass to fix incorrect storage classes.  In order to make code
+// generation simpler, DXC may generate code where the storage classes do not
+// match up correctly.  This pass will fix the errors that it can.
+Optimizer::PassToken CreateFixStorageClassPass();
+
 }  // namespace spvtools
 
 #endif  // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 9eff861..3e59904 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -40,6 +40,7 @@
   eliminate_dead_functions_util.h
   eliminate_dead_members_pass.h
   feature_manager.h
+  fix_storage_class.h
   flatten_decoration_pass.h
   fold.h
   folding_rules.h
@@ -138,6 +139,7 @@
   eliminate_dead_functions_util.cpp
   eliminate_dead_members_pass.cpp
   feature_manager.cpp
+  fix_storage_class.cpp
   flatten_decoration_pass.cpp
   fold.cpp
   folding_rules.cpp
diff --git a/source/opt/fix_storage_class.cpp b/source/opt/fix_storage_class.cpp
new file mode 100644
index 0000000..6808cfd
--- /dev/null
+++ b/source/opt/fix_storage_class.cpp
@@ -0,0 +1,133 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fix_storage_class.h"
+
+#include "source/opt/instruction.h"
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace opt {
+
+Pass::Status FixStorageClass::Process() {
+  bool modified = false;
+
+  get_module()->ForEachInst([this, &modified](Instruction* inst) {
+    if (inst->opcode() == SpvOpVariable) {
+      std::vector<Instruction*> uses;
+      get_def_use_mgr()->ForEachUser(
+          inst, [&uses](Instruction* use) { uses.push_back(use); });
+      for (Instruction* use : uses) {
+        modified |= PropagateStorageClass(
+            use, static_cast<SpvStorageClass>(inst->GetSingleWordInOperand(0)));
+      }
+    }
+  });
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+bool FixStorageClass::PropagateStorageClass(Instruction* inst,
+                                            SpvStorageClass storage_class) {
+  if (!IsPointerResultType(inst)) {
+    return false;
+  }
+
+  if (IsPointerToStorageClass(inst, storage_class)) {
+    return false;
+  }
+
+  switch (inst->opcode()) {
+    case SpvOpAccessChain:
+    case SpvOpPtrAccessChain:
+    case SpvOpInBoundsAccessChain:
+    case SpvOpCopyObject:
+    case SpvOpPhi:
+    case SpvOpSelect:
+      FixInstruction(inst, storage_class);
+      return true;
+    case SpvOpFunctionCall:
+      // We cannot be sure of the actual connection between the storage class
+      // of the parameter and the storage class of the result, so we should not
+      // do anything.  If the result type needs to be fixed, the function call
+      // should be inlined.
+      return false;
+    case SpvOpImageTexelPointer:
+    case SpvOpLoad:
+    case SpvOpStore:
+    case SpvOpCopyMemory:
+    case SpvOpCopyMemorySized:
+    case SpvOpVariable:
+      // Nothing to change for these opcode.  The result type is the same
+      // regardless of the storage class of the operand.
+      return false;
+    default:
+      assert(false &&
+             "Not expecting instruction to have a pointer result type.");
+      return false;
+  }
+}
+
+void FixStorageClass::FixInstruction(Instruction* inst,
+                                     SpvStorageClass storage_class) {
+  assert(IsPointerResultType(inst) &&
+         "The result type of the instruction must be a pointer.");
+
+  ChangeResultStorageClass(inst, storage_class);
+
+  std::vector<Instruction*> uses;
+  get_def_use_mgr()->ForEachUser(
+      inst, [&uses](Instruction* use) { uses.push_back(use); });
+  for (Instruction* use : uses) {
+    PropagateStorageClass(use, storage_class);
+  }
+}
+
+void FixStorageClass::ChangeResultStorageClass(
+    Instruction* inst, SpvStorageClass storage_class) const {
+  analysis::TypeManager* type_mgr = context()->get_type_mgr();
+  Instruction* result_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
+  assert(result_type_inst->opcode() == SpvOpTypePointer);
+  uint32_t pointee_type_id = result_type_inst->GetSingleWordInOperand(1);
+  uint32_t new_result_type_id =
+      type_mgr->FindPointerToType(pointee_type_id, storage_class);
+  inst->SetResultType(new_result_type_id);
+  context()->UpdateDefUse(inst);
+}
+
+bool FixStorageClass::IsPointerResultType(Instruction* inst) {
+  if (inst->type_id() == 0) {
+    return false;
+  }
+  const analysis::Type* ret_type =
+      context()->get_type_mgr()->GetType(inst->type_id());
+  return ret_type->AsPointer() != nullptr;
+}
+
+bool FixStorageClass::IsPointerToStorageClass(Instruction* inst,
+                                              SpvStorageClass storage_class) {
+  analysis::TypeManager* type_mgr = context()->get_type_mgr();
+  analysis::Type* pType = type_mgr->GetType(inst->type_id());
+  const analysis::Pointer* result_type = pType->AsPointer();
+
+  if (result_type == nullptr) {
+    return false;
+  }
+
+  return (result_type->storage_class() == storage_class);
+}
+
+// namespace opt
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/fix_storage_class.h b/source/opt/fix_storage_class.h
new file mode 100644
index 0000000..c496db6
--- /dev/null
+++ b/source/opt/fix_storage_class.h
@@ -0,0 +1,75 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_FIX_STORAGE_CLASS_H_
+#define SOURCE_OPT_FIX_STORAGE_CLASS_H_
+
+#include <unordered_map>
+
+#include "source/opt/ir_context.h"
+#include "source/opt/module.h"
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// This pass tries to fix validation error due to a mismatch of storage classes
+// in instructions.  There is no guarantee that all such error will be fixed,
+// and it is possible that in fixing these errors, it could lead to other
+// errors.
+class FixStorageClass : public Pass {
+ public:
+  const char* name() const override { return "fix-storage-class"; }
+  Status Process() override;
+
+  // Return the mask of preserved Analyses.
+  IRContext::Analysis GetPreservedAnalyses() override {
+    return IRContext::kAnalysisDefUse |
+           IRContext::kAnalysisInstrToBlockMapping |
+           IRContext::kAnalysisCombinators | IRContext::kAnalysisCFG |
+           IRContext::kAnalysisDominatorAnalysis |
+           IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap |
+           IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
+  }
+
+ private:
+  // Changes the storage class of the result of |inst| to |storage_class| in
+  // appropriate, and propagates the change to the users of |inst| as well.
+  // Returns true of any changes were made.
+  bool PropagateStorageClass(Instruction* inst, SpvStorageClass storage_class);
+
+  // Changes the storage class of the result of |inst| to |storage_class|.
+  // Is it assumed that the result type of |inst| is a pointer type.
+  // Propagates the change to the users of |inst| as well.
+  // Returns true of any changes were made.
+  void FixInstruction(Instruction* inst, SpvStorageClass storage_class);
+
+  // Changes the storage class of the result of |inst| to |storage_class|.  The
+  // result type of |inst| must be a pointer.
+  void ChangeResultStorageClass(Instruction* inst,
+                                SpvStorageClass storage_class) const;
+
+  // Returns true if the result type of |inst| is a pointer.
+  bool IsPointerResultType(Instruction* inst);
+
+  // Returns true if the result of |inst| is a pointer to storage class
+  // |storage_class|.
+  bool IsPointerToStorageClass(Instruction* inst,
+                               SpvStorageClass storage_class);
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_FIX_STORAGE_CLASS_H_
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index c6e48e6..b9e73ab 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -21,7 +21,6 @@
 #include <vector>
 
 #include <source/spirv_optimizer_options.h>
-#include "code_sink.h"
 #include "source/opt/build_module.h"
 #include "source/opt/log.h"
 #include "source/opt/pass_manager.h"
@@ -116,6 +115,10 @@
           // Make private variable function scope
           .RegisterPass(CreateEliminateDeadFunctionsPass())
           .RegisterPass(CreatePrivateToLocalPass())
+          // Fix up the storage classes that DXC may have purposely generated
+          // incorrectly.  All functions are inlined, and a lot of dead code has
+          // been removed.
+          .RegisterPass(CreateFixStorageClassPass())
           // Propagate the value stored to the loads in very simple cases.
           .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
           .RegisterPass(CreateLocalSingleStoreElimPass())
@@ -451,6 +454,8 @@
     RegisterPass(CreateCCPPass());
   } else if (pass_name == "code-sink") {
     RegisterPass(CreateCodeSinkingPass());
+  } else if (pass_name == "fix-storage-class") {
+    RegisterPass(CreateFixStorageClassPass());
   } else if (pass_name == "O") {
     RegisterPerformancePasses();
   } else if (pass_name == "Os") {
@@ -834,4 +839,9 @@
       MakeUnique<opt::GenerateWebGPUInitializersPass>());
 }
 
+Optimizer::PassToken CreateFixStorageClassPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::FixStorageClass>());
+}
+
 }  // namespace spvtools
diff --git a/source/opt/passes.h b/source/opt/passes.h
index d80f4ac..232e16c 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -32,6 +32,7 @@
 #include "source/opt/eliminate_dead_constant_pass.h"
 #include "source/opt/eliminate_dead_functions_pass.h"
 #include "source/opt/eliminate_dead_members_pass.h"
+#include "source/opt/fix_storage_class.h"
 #include "source/opt/flatten_decoration_pass.h"
 #include "source/opt/fold_spec_constant_op_and_composite_pass.h"
 #include "source/opt/freeze_spec_constant_value_pass.h"
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index adc78be..9e30f20 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -36,6 +36,7 @@
        eliminate_dead_functions_test.cpp
        eliminate_dead_member_test.cpp
        feature_manager_test.cpp
+       fix_storage_class_test.cpp
        flatten_decoration_test.cpp
        fold_spec_const_op_composite_test.cpp
        fold_test.cpp
diff --git a/test/opt/fix_storage_class_test.cpp b/test/opt/fix_storage_class_test.cpp
new file mode 100644
index 0000000..1cb0c80
--- /dev/null
+++ b/test/opt/fix_storage_class_test.cpp
@@ -0,0 +1,445 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "test/opt/assembly_builder.h"
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using FixStorageClassTest = PassTest<::testing::Test>;
+
+TEST_F(FixStorageClassTest, FixAccessChain) {
+  const std::string text = R"(
+; CHECK: OpAccessChain %_ptr_Workgroup_float
+; CHECK: OpAccessChain %_ptr_Uniform_float
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "testMain" %gl_GlobalInvocationID %gl_LocalInvocationID %gl_WorkGroupID
+               OpExecutionMode %1 LocalSize 8 8 1
+               OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+               OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+               OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
+               OpDecorate %8 DescriptorSet 0
+               OpDecorate %8 Binding 0
+               OpDecorate %_runtimearr_float ArrayStride 4
+               OpMemberDecorate %_struct_7 0 Offset 0
+               OpDecorate %_struct_7 BufferBlock
+        %int = OpTypeInt 32 1
+      %int_0 = OpConstant %int 0
+      %float = OpTypeFloat 32
+    %float_2 = OpConstant %float 2
+       %uint = OpTypeInt 32 0
+    %uint_10 = OpConstant %uint 10
+%_arr_float_uint_10 = OpTypeArray %float %uint_10
+%ptr = OpTypePointer Function %_arr_float_uint_10
+%_arr__arr_float_uint_10_uint_10 = OpTypeArray %_arr_float_uint_10 %uint_10
+  %_struct_5 = OpTypeStruct %_arr__arr_float_uint_10_uint_10
+%_ptr_Workgroup__struct_5 = OpTypePointer Workgroup %_struct_5
+%_runtimearr_float = OpTypeRuntimeArray %float
+  %_struct_7 = OpTypeStruct %_runtimearr_float
+%_ptr_Uniform__struct_7 = OpTypePointer Uniform %_struct_7
+     %v3uint = OpTypeVector %uint 3
+%_ptr_Input_v3uint = OpTypePointer Input %v3uint
+       %void = OpTypeVoid
+         %30 = OpTypeFunction %void
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+          %6 = OpVariable %_ptr_Workgroup__struct_5 Workgroup
+          %8 = OpVariable %_ptr_Uniform__struct_7 Uniform
+%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
+          %1 = OpFunction %void None %30
+         %38 = OpLabel
+         %44 = OpLoad %v3uint %gl_LocalInvocationID
+         %50 = OpAccessChain %_ptr_Function_float %6 %int_0 %int_0 %int_0
+         %51 = OpLoad %float %50
+         %52 = OpFMul %float %float_2 %51
+               OpStore %50 %52
+         %55 = OpLoad %float %50
+         %59 = OpCompositeExtract %uint %44 0
+         %60 = OpAccessChain %_ptr_Uniform_float %8 %int_0 %59
+               OpStore %60 %55
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FixStorageClass>(text, false);
+}
+
+TEST_F(FixStorageClassTest, FixLinkedAccessChain) {
+  const std::string text = R"(
+; CHECK: OpAccessChain %_ptr_Workgroup__arr_float_uint_10
+; CHECK: OpAccessChain %_ptr_Workgroup_float
+; CHECK: OpAccessChain %_ptr_Uniform_float
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "testMain" %gl_GlobalInvocationID %gl_LocalInvocationID %gl_WorkGroupID
+               OpExecutionMode %1 LocalSize 8 8 1
+               OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+               OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+               OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
+               OpDecorate %5 DescriptorSet 0
+               OpDecorate %5 Binding 0
+               OpDecorate %_runtimearr_float ArrayStride 4
+               OpMemberDecorate %_struct_7 0 Offset 0
+               OpDecorate %_struct_7 BufferBlock
+        %int = OpTypeInt 32 1
+      %int_0 = OpConstant %int 0
+      %float = OpTypeFloat 32
+    %float_2 = OpConstant %float 2
+       %uint = OpTypeInt 32 0
+    %uint_10 = OpConstant %uint 10
+%_arr_float_uint_10 = OpTypeArray %float %uint_10
+%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10
+%_ptr = OpTypePointer Function %_arr_float_uint_10
+%_arr__arr_float_uint_10_uint_10 = OpTypeArray %_arr_float_uint_10 %uint_10
+ %_struct_17 = OpTypeStruct %_arr__arr_float_uint_10_uint_10
+%_ptr_Workgroup__struct_17 = OpTypePointer Workgroup %_struct_17
+%_runtimearr_float = OpTypeRuntimeArray %float
+  %_struct_7 = OpTypeStruct %_runtimearr_float
+%_ptr_Uniform__struct_7 = OpTypePointer Uniform %_struct_7
+     %v3uint = OpTypeVector %uint 3
+%_ptr_Input_v3uint = OpTypePointer Input %v3uint
+       %void = OpTypeVoid
+         %23 = OpTypeFunction %void
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+         %27 = OpVariable %_ptr_Workgroup__struct_17 Workgroup
+          %5 = OpVariable %_ptr_Uniform__struct_7 Uniform
+%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
+          %1 = OpFunction %void None %23
+         %28 = OpLabel
+         %29 = OpLoad %v3uint %gl_LocalInvocationID
+         %30 = OpAccessChain %_ptr_Function__arr_float_uint_10 %27 %int_0 %int_0
+         %31 = OpAccessChain %_ptr_Function_float %30 %int_0
+         %32 = OpLoad %float %31
+         %33 = OpFMul %float %float_2 %32
+               OpStore %31 %33
+         %34 = OpLoad %float %31
+         %35 = OpCompositeExtract %uint %29 0
+         %36 = OpAccessChain %_ptr_Uniform_float %5 %int_0 %35
+               OpStore %36 %34
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FixStorageClass>(text, false);
+}
+
+TEST_F(FixStorageClassTest, FixCopyObject) {
+  const std::string text = R"(
+; CHECK: OpCopyObject %_ptr_Workgroup_float
+; CHECK: OpAccessChain %_ptr_Workgroup_float
+; CHECK: OpAccessChain %_ptr_Uniform_float
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "testMain" %gl_GlobalInvocationID %gl_LocalInvocationID %gl_WorkGroupID
+               OpExecutionMode %1 LocalSize 8 8 1
+               OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+               OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+               OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
+               OpDecorate %8 DescriptorSet 0
+               OpDecorate %8 Binding 0
+               OpDecorate %_runtimearr_float ArrayStride 4
+               OpMemberDecorate %_struct_7 0 Offset 0
+               OpDecorate %_struct_7 BufferBlock
+        %int = OpTypeInt 32 1
+      %int_0 = OpConstant %int 0
+      %float = OpTypeFloat 32
+    %float_2 = OpConstant %float 2
+       %uint = OpTypeInt 32 0
+    %uint_10 = OpConstant %uint 10
+%_arr_float_uint_10 = OpTypeArray %float %uint_10
+%ptr = OpTypePointer Function %_arr_float_uint_10
+%_arr__arr_float_uint_10_uint_10 = OpTypeArray %_arr_float_uint_10 %uint_10
+  %_struct_5 = OpTypeStruct %_arr__arr_float_uint_10_uint_10
+%_ptr_Workgroup__struct_5 = OpTypePointer Workgroup %_struct_5
+%_runtimearr_float = OpTypeRuntimeArray %float
+  %_struct_7 = OpTypeStruct %_runtimearr_float
+%_ptr_Uniform__struct_7 = OpTypePointer Uniform %_struct_7
+     %v3uint = OpTypeVector %uint 3
+%_ptr_Input_v3uint = OpTypePointer Input %v3uint
+       %void = OpTypeVoid
+         %30 = OpTypeFunction %void
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+          %6 = OpVariable %_ptr_Workgroup__struct_5 Workgroup
+          %8 = OpVariable %_ptr_Uniform__struct_7 Uniform
+%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
+          %1 = OpFunction %void None %30
+         %38 = OpLabel
+         %44 = OpLoad %v3uint %gl_LocalInvocationID
+         %cp = OpCopyObject %_ptr_Function_float %6
+         %50 = OpAccessChain %_ptr_Function_float %cp %int_0 %int_0 %int_0
+         %51 = OpLoad %float %50
+         %52 = OpFMul %float %float_2 %51
+               OpStore %50 %52
+         %55 = OpLoad %float %50
+         %59 = OpCompositeExtract %uint %44 0
+         %60 = OpAccessChain %_ptr_Uniform_float %8 %int_0 %59
+               OpStore %60 %55
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FixStorageClass>(text, false);
+}
+
+TEST_F(FixStorageClassTest, FixPhiInSelMerge) {
+  const std::string text = R"(
+; CHECK: OpPhi %_ptr_Workgroup_float
+; CHECK: OpAccessChain %_ptr_Workgroup_float
+; CHECK: OpAccessChain %_ptr_Uniform_float
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "testMain" %gl_GlobalInvocationID %gl_LocalInvocationID %gl_WorkGroupID
+               OpExecutionMode %1 LocalSize 8 8 1
+               OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+               OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+               OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
+               OpDecorate %5 DescriptorSet 0
+               OpDecorate %5 Binding 0
+               OpDecorate %_runtimearr_float ArrayStride 4
+               OpMemberDecorate %_struct_7 0 Offset 0
+               OpDecorate %_struct_7 BufferBlock
+       %bool = OpTypeBool
+       %true = OpConstantTrue %bool
+        %int = OpTypeInt 32 1
+      %int_0 = OpConstant %int 0
+      %float = OpTypeFloat 32
+    %float_2 = OpConstant %float 2
+       %uint = OpTypeInt 32 0
+    %uint_10 = OpConstant %uint 10
+%_arr_float_uint_10 = OpTypeArray %float %uint_10
+%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10
+%_arr__arr_float_uint_10_uint_10 = OpTypeArray %_arr_float_uint_10 %uint_10
+ %_struct_19 = OpTypeStruct %_arr__arr_float_uint_10_uint_10
+%_ptr_Workgroup__struct_19 = OpTypePointer Workgroup %_struct_19
+%_runtimearr_float = OpTypeRuntimeArray %float
+  %_struct_7 = OpTypeStruct %_runtimearr_float
+%_ptr_Uniform__struct_7 = OpTypePointer Uniform %_struct_7
+     %v3uint = OpTypeVector %uint 3
+%_ptr_Input_v3uint = OpTypePointer Input %v3uint
+       %void = OpTypeVoid
+         %25 = OpTypeFunction %void
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+         %28 = OpVariable %_ptr_Workgroup__struct_19 Workgroup
+         %29 = OpVariable %_ptr_Workgroup__struct_19 Workgroup
+          %5 = OpVariable %_ptr_Uniform__struct_7 Uniform
+%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
+          %1 = OpFunction %void None %25
+         %30 = OpLabel
+               OpSelectionMerge %31 None
+               OpBranchConditional %true %32 %31
+         %32 = OpLabel
+               OpBranch %31
+         %31 = OpLabel
+         %33 = OpPhi %_ptr_Function_float %28 %30 %29 %32
+         %34 = OpLoad %v3uint %gl_LocalInvocationID
+         %35 = OpAccessChain %_ptr_Function_float %33 %int_0 %int_0 %int_0
+         %36 = OpLoad %float %35
+         %37 = OpFMul %float %float_2 %36
+               OpStore %35 %37
+         %38 = OpLoad %float %35
+         %39 = OpCompositeExtract %uint %34 0
+         %40 = OpAccessChain %_ptr_Uniform_float %5 %int_0 %39
+               OpStore %40 %38
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FixStorageClass>(text, false);
+}
+
+TEST_F(FixStorageClassTest, FixPhiInLoop) {
+  const std::string text = R"(
+; CHECK: OpPhi %_ptr_Workgroup_float
+; CHECK: OpAccessChain %_ptr_Workgroup_float
+; CHECK: OpAccessChain %_ptr_Uniform_float
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "testMain" %gl_GlobalInvocationID %gl_LocalInvocationID %gl_WorkGroupID
+               OpExecutionMode %1 LocalSize 8 8 1
+               OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+               OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+               OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
+               OpDecorate %5 DescriptorSet 0
+               OpDecorate %5 Binding 0
+               OpDecorate %_runtimearr_float ArrayStride 4
+               OpMemberDecorate %_struct_7 0 Offset 0
+               OpDecorate %_struct_7 BufferBlock
+       %bool = OpTypeBool
+       %true = OpConstantTrue %bool
+        %int = OpTypeInt 32 1
+      %int_0 = OpConstant %int 0
+      %float = OpTypeFloat 32
+    %float_2 = OpConstant %float 2
+       %uint = OpTypeInt 32 0
+    %uint_10 = OpConstant %uint 10
+%_arr_float_uint_10 = OpTypeArray %float %uint_10
+%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10
+%_arr__arr_float_uint_10_uint_10 = OpTypeArray %_arr_float_uint_10 %uint_10
+ %_struct_19 = OpTypeStruct %_arr__arr_float_uint_10_uint_10
+%_ptr_Workgroup__struct_19 = OpTypePointer Workgroup %_struct_19
+%_runtimearr_float = OpTypeRuntimeArray %float
+  %_struct_7 = OpTypeStruct %_runtimearr_float
+%_ptr_Uniform__struct_7 = OpTypePointer Uniform %_struct_7
+     %v3uint = OpTypeVector %uint 3
+%_ptr_Input_v3uint = OpTypePointer Input %v3uint
+       %void = OpTypeVoid
+         %25 = OpTypeFunction %void
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+         %28 = OpVariable %_ptr_Workgroup__struct_19 Workgroup
+         %29 = OpVariable %_ptr_Workgroup__struct_19 Workgroup
+          %5 = OpVariable %_ptr_Uniform__struct_7 Uniform
+%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
+          %1 = OpFunction %void None %25
+         %30 = OpLabel
+               OpSelectionMerge %31 None
+               OpBranchConditional %true %32 %31
+         %32 = OpLabel
+               OpBranch %31
+         %31 = OpLabel
+         %33 = OpPhi %_ptr_Function_float %28 %30 %29 %32
+         %34 = OpLoad %v3uint %gl_LocalInvocationID
+         %35 = OpAccessChain %_ptr_Function_float %33 %int_0 %int_0 %int_0
+         %36 = OpLoad %float %35
+         %37 = OpFMul %float %float_2 %36
+               OpStore %35 %37
+         %38 = OpLoad %float %35
+         %39 = OpCompositeExtract %uint %34 0
+         %40 = OpAccessChain %_ptr_Uniform_float %5 %int_0 %39
+               OpStore %40 %38
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FixStorageClass>(text, false);
+}
+
+TEST_F(FixStorageClassTest, DontChangeFunctionCalls) {
+  const std::string text = R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %1 "testMain"
+OpExecutionMode %1 LocalSize 8 8 1
+OpDecorate %2 DescriptorSet 0
+OpDecorate %2 Binding 0
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%_ptr_Workgroup_int = OpTypePointer Workgroup %int
+%_ptr_Uniform_int = OpTypePointer Uniform %int
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%9 = OpTypeFunction %_ptr_Uniform_int %_ptr_Function_int
+%10 = OpVariable %_ptr_Workgroup_int Workgroup
+%2 = OpVariable %_ptr_Uniform_int Uniform
+%1 = OpFunction %void None %8
+%11 = OpLabel
+%12 = OpFunctionCall %_ptr_Uniform_int %13 %10
+OpReturn
+OpFunctionEnd
+%13 = OpFunction %_ptr_Uniform_int None %9
+%14 = OpFunctionParameter %_ptr_Function_int
+%15 = OpLabel
+OpReturnValue %2
+OpFunctionEnd
+)";
+
+  SinglePassRunAndCheck<FixStorageClass>(text, text, false, false);
+}
+
+TEST_F(FixStorageClassTest, FixSelect) {
+  const std::string text = R"(
+; CHECK: OpSelect %_ptr_Workgroup_float
+; CHECK: OpAccessChain %_ptr_Workgroup_float
+; CHECK: OpAccessChain %_ptr_Uniform_float
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "testMain" %gl_GlobalInvocationID %gl_LocalInvocationID %gl_WorkGroupID
+               OpExecutionMode %1 LocalSize 8 8 1
+               OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
+               OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+               OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
+               OpDecorate %5 DescriptorSet 0
+               OpDecorate %5 Binding 0
+               OpDecorate %_runtimearr_float ArrayStride 4
+               OpMemberDecorate %_struct_7 0 Offset 0
+               OpDecorate %_struct_7 BufferBlock
+       %bool = OpTypeBool
+       %true = OpConstantTrue %bool
+        %int = OpTypeInt 32 1
+      %int_0 = OpConstant %int 0
+      %float = OpTypeFloat 32
+    %float_2 = OpConstant %float 2
+       %uint = OpTypeInt 32 0
+    %uint_10 = OpConstant %uint 10
+%_arr_float_uint_10 = OpTypeArray %float %uint_10
+%_ptr_Function__arr_float_uint_10 = OpTypePointer Function %_arr_float_uint_10
+%_arr__arr_float_uint_10_uint_10 = OpTypeArray %_arr_float_uint_10 %uint_10
+ %_struct_19 = OpTypeStruct %_arr__arr_float_uint_10_uint_10
+%_ptr_Workgroup__struct_19 = OpTypePointer Workgroup %_struct_19
+%_runtimearr_float = OpTypeRuntimeArray %float
+  %_struct_7 = OpTypeStruct %_runtimearr_float
+%_ptr_Uniform__struct_7 = OpTypePointer Uniform %_struct_7
+     %v3uint = OpTypeVector %uint 3
+%_ptr_Input_v3uint = OpTypePointer Input %v3uint
+       %void = OpTypeVoid
+         %25 = OpTypeFunction %void
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+         %28 = OpVariable %_ptr_Workgroup__struct_19 Workgroup
+         %29 = OpVariable %_ptr_Workgroup__struct_19 Workgroup
+          %5 = OpVariable %_ptr_Uniform__struct_7 Uniform
+%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input
+%gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
+          %1 = OpFunction %void None %25
+         %30 = OpLabel
+         %33 = OpSelect %_ptr_Function_float %true %28 %29
+         %34 = OpLoad %v3uint %gl_LocalInvocationID
+         %35 = OpAccessChain %_ptr_Function_float %33 %int_0 %int_0 %int_0
+         %36 = OpLoad %float %35
+         %37 = OpFMul %float %float_2 %36
+               OpStore %35 %37
+         %38 = OpLoad %float %35
+         %39 = OpCompositeExtract %uint %34 0
+         %40 = OpAccessChain %_ptr_Uniform_float %5 %int_0 %39
+               OpStore %40 %38
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FixStorageClass>(text, false);
+}
+
+}  // namespace
+}  // namespace opt
+}  // namespace spvtools
diff --git a/test/tools/opt/flags.py b/test/tools/opt/flags.py
index 69462fc..411b000 100644
--- a/test/tools/opt/flags.py
+++ b/test/tools/opt/flags.py
@@ -227,6 +227,7 @@
       'inline-entry-points-exhaustive',
       'eliminate-dead-functions',
       'private-to-local',
+      'fix-storage-class',
       'eliminate-local-single-block',
       'eliminate-local-single-store',
       'eliminate-dead-code-aggressive',