Split EliminateDeadInputComponents into safe and unsafe versions. (#4984)

Safe version will only optimize vertex shaders. All other shaders will
succeed without change.

Change --eliminate-dead-input-components to use new safe version.

Unsafe version (allowing non-vertex shaders) currently only available
through API. Should only be used in combination with other optimizations
to keep interfaces consistent. See optimizer.hpp for more details.
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index 17a2556..41752d6 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -913,6 +913,14 @@
 // shader, then apply EliminateDeadOutputStores to this shader.
 Optimizer::PassToken CreateEliminateDeadOutputComponentsPass();
 
+// Removes unused components from composite input variables. This safe
+// version will not cause interface incompatibilities since it only changes
+// vertex shaders. The current implementation just removes trailing unused
+// components from input structs and input arrays. The pass performs best
+// after maximizing dead code removal. A subsequent dead code elimination
+// pass would be beneficial in removing newly unused component types.
+Optimizer::PassToken CreateEliminateDeadInputComponentsSafePass();
+
 // Analyzes shader and populates |live_locs| and |live_builtins|. Best results
 // will be obtained if shader has all dead code eliminated first. |live_locs|
 // and |live_builtins| are subsequently used when calling
diff --git a/source/opt/eliminate_dead_input_components_pass.cpp b/source/opt/eliminate_dead_input_components_pass.cpp
index f31b567..260aa3d 100644
--- a/source/opt/eliminate_dead_input_components_pass.cpp
+++ b/source/opt/eliminate_dead_input_components_pass.cpp
@@ -35,7 +35,11 @@
 namespace opt {
 
 Pass::Status EliminateDeadInputComponentsPass::Process() {
-  // Current functionality assumes shader capability
+  // Process non-vertex only if explicitly allowed.
+  auto stage = context()->GetStage();
+  if (stage != spv::ExecutionModel::Vertex && vertex_shader_only_)
+    return Status::SuccessWithoutChange;
+  // Current functionality assumes shader capability.
   if (!context()->get_feature_mgr()->HasCapability(spv::Capability::Shader))
     return Status::SuccessWithoutChange;
   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
@@ -62,13 +66,21 @@
     }
     const analysis::Array* arr_type = ptr_type->pointee_type()->AsArray();
     if (arr_type != nullptr) {
+      // Only process array if input of vertex shader, or output of
+      // fragment shader. Otherwise, if one shader has a runtime index and the
+      // other does not, interface incompatibility can occur.
+      if (!((ptr_type->storage_class() == spv::StorageClass::Input &&
+             stage == spv::ExecutionModel::Vertex) ||
+            (ptr_type->storage_class() == spv::StorageClass::Output &&
+             stage == spv::ExecutionModel::Fragment)))
+        continue;
       unsigned arr_len_id = arr_type->LengthId();
       Instruction* arr_len_inst = def_use_mgr->GetDef(arr_len_id);
       if (arr_len_inst->opcode() != spv::Op::OpConstant) {
         continue;
       }
       // SPIR-V requires array size is >= 1, so this works for signed or
-      // unsigned size
+      // unsigned size.
       unsigned original_max =
           arr_len_inst->GetSingleWordInOperand(kConstantValueInIdx) - 1;
       unsigned max_idx = FindMaxIndex(var, original_max);
@@ -92,7 +104,7 @@
   }
 
   // Move changed vars after their new type instruction to preserve backward
-  // referencing
+  // referencing.
   for (auto var : vars_to_move) {
     auto type_id = var->type_id();
     auto type_inst = def_use_mgr->GetDef(type_id);
diff --git a/source/opt/eliminate_dead_input_components_pass.h b/source/opt/eliminate_dead_input_components_pass.h
index 16b4545..111366e 100644
--- a/source/opt/eliminate_dead_input_components_pass.h
+++ b/source/opt/eliminate_dead_input_components_pass.h
@@ -28,13 +28,14 @@
 // See optimizer.hpp for documentation.
 class EliminateDeadInputComponentsPass : public Pass {
  public:
-  explicit EliminateDeadInputComponentsPass(bool output_instead = false)
-      : output_instead_(output_instead) {}
+  explicit EliminateDeadInputComponentsPass(bool output_instead = false,
+                                            bool vertex_shader_only = true)
+      : output_instead_(output_instead),
+        vertex_shader_only_(vertex_shader_only) {}
 
   const char* name() const override {
     return "eliminate-dead-input-components";
   }
-
   Status Process() override;
 
   // Return the mask of preserved Analyses.
@@ -61,6 +62,9 @@
 
   // Process output variables instead
   bool output_instead_;
+
+  // Only process vertex shaders
+  bool vertex_shader_only_;
 };
 
 }  // namespace opt
diff --git a/source/opt/instrument_pass.cpp b/source/opt/instrument_pass.cpp
index 88fa5e1..9318383 100644
--- a/source/opt/instrument_pass.cpp
+++ b/source/opt/instrument_pass.cpp
@@ -26,7 +26,6 @@
 static const int kInstCommonParamCnt = 1;
 
 // Indices of operands in SPIR-V instructions
-static const int kEntryPointExecutionModelInIdx = 0;
 static const int kEntryPointFunctionIdInIdx = 1;
 
 }  // anonymous namespace
@@ -1056,22 +1055,7 @@
   // one model per module. In such cases we will need
   // to clone any functions which are in the call trees of entrypoints
   // with differing execution models.
-  uint32_t ecnt = 0;
-  auto stage = spv::ExecutionModel::Max;
-  for (auto& e : get_module()->entry_points()) {
-    if (ecnt == 0)
-      stage = spv::ExecutionModel(
-          e.GetSingleWordInOperand(kEntryPointExecutionModelInIdx));
-    else if (spv::ExecutionModel(e.GetSingleWordInOperand(
-                 kEntryPointExecutionModelInIdx)) != stage) {
-      if (consumer()) {
-        std::string message = "Mixed stage shader module not supported";
-        consumer()(SPV_MSG_ERROR, 0, {0, 0, 0}, message.c_str());
-      }
-      return false;
-    }
-    ++ecnt;
-  }
+  spv::ExecutionModel stage = context()->GetStage();
   // Check for supported stages
   if (stage != spv::ExecutionModel::Vertex &&
       stage != spv::ExecutionModel::Fragment &&
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index 7583bd1..4cf3292 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -525,7 +525,7 @@
   } else if (pass_name == "remove-dont-inline") {
     RegisterPass(CreateRemoveDontInlinePass());
   } else if (pass_name == "eliminate-dead-input-components") {
-    RegisterPass(CreateEliminateDeadInputComponentsPass());
+    RegisterPass(CreateEliminateDeadInputComponentsSafePass());
   } else if (pass_name == "fix-func-call-param") {
     RegisterPass(CreateFixFuncCallArgumentsPass());
   } else if (pass_name == "convert-to-sampled-image") {
@@ -1017,6 +1017,18 @@
 
 Optimizer::PassToken CreateEliminateDeadInputComponentsPass() {
   return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::EliminateDeadInputComponentsPass>(
+          /* output_instead */ false, /* vertex_shader_only */ false));
+}
+
+Optimizer::PassToken CreateEliminateDeadOutputComponentsPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::EliminateDeadInputComponentsPass>(
+          /* output_instead */ true, /* vertex_shader_only */ false));
+}
+
+Optimizer::PassToken CreateEliminateDeadInputComponentsSafePass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
       MakeUnique<opt::EliminateDeadInputComponentsPass>());
 }
 
@@ -1034,12 +1046,6 @@
       MakeUnique<opt::EliminateDeadOutputStoresPass>(live_locs, live_builtins));
 }
 
-Optimizer::PassToken CreateEliminateDeadOutputComponentsPass() {
-  return MakeUnique<Optimizer::PassToken::Impl>(
-      MakeUnique<opt::EliminateDeadInputComponentsPass>(
-          /* output_instead */ true));
-}
-
 Optimizer::PassToken CreateConvertToSampledImagePass(
     const std::vector<opt::DescriptorSetAndBinding>&
         descriptor_set_binding_pairs) {
diff --git a/test/opt/eliminate_dead_input_components_test.cpp b/test/opt/eliminate_dead_input_components_test.cpp
index 2c2e636..667350d 100644
--- a/test/opt/eliminate_dead_input_components_test.cpp
+++ b/test/opt/eliminate_dead_input_components_test.cpp
@@ -85,7 +85,8 @@
 
   SetTargetEnv(SPV_ENV_VULKAN_1_3);
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true);
+  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, false,
+                                                          false);
 }
 
 TEST_F(ElimDeadInputComponentsTest, ElimOneConstantIndexInBounds) {
@@ -135,7 +136,8 @@
 
   SetTargetEnv(SPV_ENV_VULKAN_1_3);
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true);
+  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, false,
+                                                          false);
 }
 
 TEST_F(ElimDeadInputComponentsTest, ElimTwoConstantIndices) {
@@ -202,7 +204,8 @@
 
   SetTargetEnv(SPV_ENV_VULKAN_1_3);
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true);
+  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, false,
+                                                          false);
 }
 
 TEST_F(ElimDeadInputComponentsTest, NoElimMaxConstantIndex) {
@@ -268,7 +271,8 @@
 
   SetTargetEnv(SPV_ENV_VULKAN_1_3);
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true);
+  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, false,
+                                                          false);
 }
 
 TEST_F(ElimDeadInputComponentsTest, NoElimNonConstantIndex) {
@@ -350,7 +354,8 @@
 
   SetTargetEnv(SPV_ENV_VULKAN_1_3);
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true);
+  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, false,
+                                                          false);
 }
 
 TEST_F(ElimDeadInputComponentsTest, NoElimNonIndexedAccessChain) {
@@ -396,7 +401,8 @@
 
   SetTargetEnv(SPV_ENV_VULKAN_1_3);
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true);
+  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, false,
+                                                          false);
 }
 
 TEST_F(ElimDeadInputComponentsTest, ElimStructMember) {
@@ -460,7 +466,8 @@
 
   SetTargetEnv(SPV_ENV_VULKAN_1_3);
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true);
+  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, false,
+                                                          false);
 }
 
 TEST_F(ElimDeadInputComponentsTest, ElimOutputStructMember) {
@@ -558,7 +565,8 @@
 
   SetTargetEnv(SPV_ENV_VULKAN_1_3);
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, true);
+  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, true,
+                                                          false);
 }
 
 TEST_F(ElimDeadInputComponentsTest, ElimOutputArrayMembers) {
@@ -577,7 +585,8 @@
                OpCapability Shader
           %1 = OpExtInstImport "GLSL.std.450"
                OpMemoryModel Logical GLSL450
-               OpEntryPoint Vertex %main "main" %uv
+               OpEntryPoint Fragment %main "main" %uv
+               OpExecutionMode %main OriginUpperLeft
                OpSource GLSL 450
                OpName %main "main"
                OpName %uv "uv"
@@ -609,7 +618,72 @@
 
   SetTargetEnv(SPV_ENV_VULKAN_1_3);
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, true);
+  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, true,
+                                                          false);
+}
+
+TEST_F(ElimDeadInputComponentsTest, VertexOnly) {
+  // Should NOT eliminate uv
+  //
+  // #version 450
+  //
+  // in Vertex {
+  //   vec4 Cd;
+  //   vec2 uv;
+  // } iVert;
+  //
+  // out vec4 fragColor;
+  //
+  // void main()
+  // {
+  //   vec4 color = vec4(iVert.Cd);
+  //   fragColor = color;
+  // }
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %iVert %fragColor
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 450
+               OpName %main "main"
+               OpName %Vertex "Vertex"
+               OpMemberName %Vertex 0 "Cd"
+               OpMemberName %Vertex 1 "uv"
+               OpName %iVert "iVert"
+               OpName %fragColor "fragColor"
+               OpDecorate %Vertex Block
+               OpDecorate %iVert Location 0
+               OpDecorate %fragColor Location 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+    %v2float = OpTypeVector %float 2
+     %Vertex = OpTypeStruct %v4float %v2float
+; CHECK: %Vertex = OpTypeStruct %v4float %v2float
+%_ptr_Input_Vertex = OpTypePointer Input %Vertex
+; CHECK: %_ptr_Input_Vertex = OpTypePointer Input %Vertex
+      %iVert = OpVariable %_ptr_Input_Vertex Input
+; CHECK: %iVert = OpVariable %_ptr_Input_Vertex Input
+        %int = OpTypeInt 32 1
+      %int_0 = OpConstant %int 0
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+  %fragColor = OpVariable %_ptr_Output_v4float Output
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+         %17 = OpAccessChain %_ptr_Input_v4float %iVert %int_0
+         %18 = OpLoad %v4float %17
+               OpStore %fragColor %18
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SetTargetEnv(SPV_ENV_VULKAN_1_3);
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true, false,
+                                                          true);
 }
 
 }  // namespace