Instrument: Add version 2 of record formats (#2630)

New version has additional word in stage-specific section. Also
some changes in content for tesselation and compute shaders. Either
version can be invoked at pass creation. This is done to ease integration
and updating of validation layers. Version 1 is deprecated and eventually
will go away.

Also sneaking in fix to version 1 compute shaders.
diff --git a/include/spirv-tools/instrument.hpp b/include/spirv-tools/instrument.hpp
index 9711b92..4d37e3f 100644
--- a/include/spirv-tools/instrument.hpp
+++ b/include/spirv-tools/instrument.hpp
@@ -34,6 +34,11 @@
 // generated by InstrumentPass::GenDebugStreamWrite. This method is utilized
 // by InstBindlessCheckPass.
 //
+// kInst2* values support version 2 of the output record format. These should
+// be used if available and version 2 is enabled. Version 1 is DEPRECATED.
+// Specifically, version 1 uses two words for the stage-specific section of
+// the output record; version 2 uses three words.
+//
 // The first member of the debug output buffer contains the next available word
 // in the data stream to be written. Shaders will atomically read and update
 // this value so as not to overwrite each others records. This value must be
@@ -70,38 +75,58 @@
 
 // Stage-specific Stream Record Offsets
 //
-// Each stage will contain different values in the next two words of the record
-// used to identify which instantiation of the shader generated the validation
-// error.
+// Each stage will contain different values in the next set of words of the
+// record used to identify which instantiation of the shader generated the
+// validation error.
 //
 // Vertex Shader Output Record Offsets
 static const int kInstVertOutVertexIndex = kInstCommonOutCnt;
 static const int kInstVertOutInstanceIndex = kInstCommonOutCnt + 1;
+static const int kInstVertOutUnused = kInstCommonOutCnt + 2;
 
 // Frag Shader Output Record Offsets
 static const int kInstFragOutFragCoordX = kInstCommonOutCnt;
 static const int kInstFragOutFragCoordY = kInstCommonOutCnt + 1;
+static const int kInstFragOutUnused = kInstCommonOutCnt + 2;
 
 // Compute Shader Output Record Offsets
+static const int kInstCompOutGlobalInvocationIdX = kInstCommonOutCnt;
+static const int kInstCompOutGlobalInvocationIdY = kInstCommonOutCnt + 1;
+static const int kInstCompOutGlobalInvocationIdZ = kInstCommonOutCnt + 2;
+
+// Compute Shader Output Record Offsets - Version 1 (DEPRECATED)
 static const int kInstCompOutGlobalInvocationId = kInstCommonOutCnt;
 static const int kInstCompOutUnused = kInstCommonOutCnt + 1;
 
-// Tessellation Shader Output Record Offsets
+// Tessellation Control Shader Output Record Offsets
+static const int kInstTessCtlOutInvocationId = kInstCommonOutCnt;
+static const int kInstTessCtlOutPrimitiveId = kInstCommonOutCnt + 1;
+static const int kInstTessCtlOutUnused = kInstCommonOutCnt + 2;
+
+// Tessellation Eval Shader Output Record Offsets
+static const int kInstTessEvalOutPrimitiveId = kInstCommonOutCnt;
+static const int kInstTessEvalOutTessCoordU = kInstCommonOutCnt + 1;
+static const int kInstTessEvalOutTessCoordV = kInstCommonOutCnt + 2;
+
+// Tessellation Shader Output Record Offsets - Version 1 (DEPRECATED)
 static const int kInstTessOutInvocationId = kInstCommonOutCnt;
 static const int kInstTessOutUnused = kInstCommonOutCnt + 1;
 
 // Geometry Shader Output Record Offsets
 static const int kInstGeomOutPrimitiveId = kInstCommonOutCnt;
 static const int kInstGeomOutInvocationId = kInstCommonOutCnt + 1;
+static const int kInstGeomOutUnused = kInstCommonOutCnt + 2;
 
 // Size of Common and Stage-specific Members
 static const int kInstStageOutCnt = kInstCommonOutCnt + 2;
+static const int kInst2StageOutCnt = kInstCommonOutCnt + 3;
 
-// Validation Error Code
+// Validation Error Code Offset
 //
 // This identifies the validation error. It also helps to identify
 // how many words follow in the record and their meaning.
 static const int kInstValidationOutError = kInstStageOutCnt;
+static const int kInst2ValidationOutError = kInst2StageOutCnt;
 
 // Validation-specific Output Record Offsets
 //
@@ -114,11 +139,19 @@
 static const int kInstBindlessBoundsOutDescBound = kInstStageOutCnt + 2;
 static const int kInstBindlessBoundsOutCnt = kInstStageOutCnt + 3;
 
+static const int kInst2BindlessBoundsOutDescIndex = kInst2StageOutCnt + 1;
+static const int kInst2BindlessBoundsOutDescBound = kInst2StageOutCnt + 2;
+static const int kInst2BindlessBoundsOutCnt = kInst2StageOutCnt + 3;
+
 // A bindless uninitialized error will output the index.
 static const int kInstBindlessUninitOutDescIndex = kInstStageOutCnt + 1;
 static const int kInstBindlessUninitOutUnused = kInstStageOutCnt + 2;
 static const int kInstBindlessUninitOutCnt = kInstStageOutCnt + 3;
 
+static const int kInst2BindlessUninitOutDescIndex = kInst2StageOutCnt + 1;
+static const int kInst2BindlessUninitOutUnused = kInst2StageOutCnt + 2;
+static const int kInst2BindlessUninitOutCnt = kInst2StageOutCnt + 3;
+
 // DEPRECATED
 static const int kInstBindlessOutDescIndex = kInstStageOutCnt + 1;
 static const int kInstBindlessOutDescBound = kInstStageOutCnt + 2;
@@ -126,6 +159,7 @@
 
 // Maximum Output Record Member Count
 static const int kInstMaxOutCnt = kInstStageOutCnt + 3;
+static const int kInst2MaxOutCnt = kInst2StageOutCnt + 3;
 
 // Validation Error Codes
 //
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index 313c992..1d8d6e0 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -738,9 +738,10 @@
 // |input_length_enable| controls instrumentation of runtime descriptor array
 // references, and |input_init_enable| controls instrumentation of descriptor
 // initialization checking, both of which require input buffer support.
+// |version| specifies the buffer record format.
 Optimizer::PassToken CreateInstBindlessCheckPass(
     uint32_t desc_set, uint32_t shader_id, bool input_length_enable = false,
-    bool input_init_enable = false);
+    bool input_init_enable = false, uint32_t version = 1);
 
 // Create a pass to upgrade to the VulkanKHR memory model.
 // This pass upgrades the Logical GLSL450 memory model to Logical VulkanKHR.
diff --git a/source/opt/inst_bindless_check_pass.h b/source/opt/inst_bindless_check_pass.h
index 5e9921e..12384b1 100644
--- a/source/opt/inst_bindless_check_pass.h
+++ b/source/opt/inst_bindless_check_pass.h
@@ -30,13 +30,14 @@
  public:
   // For test harness only
   InstBindlessCheckPass()
-      : InstrumentPass(7, 23, kInstValidationIdBindless),
+      : InstrumentPass(7, 23, kInstValidationIdBindless, 1),
         input_length_enabled_(true),
         input_init_enabled_(true) {}
   // For all other interfaces
   InstBindlessCheckPass(uint32_t desc_set, uint32_t shader_id,
-                        bool input_length_enable, bool input_init_enable)
-      : InstrumentPass(desc_set, shader_id, kInstValidationIdBindless),
+                        bool input_length_enable, bool input_init_enable,
+                        uint32_t version)
+      : InstrumentPass(desc_set, shader_id, kInstValidationIdBindless, version),
         input_length_enabled_(input_length_enable),
         input_init_enabled_(input_init_enable) {}
 
diff --git a/source/opt/instrument_pass.cpp b/source/opt/instrument_pass.cpp
index 032cd28..6f448d7 100644
--- a/source/opt/instrument_pass.cpp
+++ b/source/opt/instrument_pass.cpp
@@ -143,23 +143,21 @@
                           element_val_inst->result_id(), builder);
 }
 
+uint32_t InstrumentPass::GenVarLoad(uint32_t var_id,
+                                    InstructionBuilder* builder) {
+  Instruction* var_inst = get_def_use_mgr()->GetDef(var_id);
+  uint32_t type_id = GetPointeeTypeId(var_inst);
+  Instruction* load_inst = builder->AddUnaryOp(type_id, SpvOpLoad, var_id);
+  return load_inst->result_id();
+}
+
 void InstrumentPass::GenBuiltinOutputCode(uint32_t builtin_id,
                                           uint32_t builtin_off,
                                           uint32_t base_offset_id,
                                           InstructionBuilder* builder) {
   // Load and store builtin
-  Instruction* var_inst = get_def_use_mgr()->GetDef(builtin_id);
-  uint32_t type_id = GetPointeeTypeId(var_inst);
-  Instruction* load_inst = builder->AddUnaryOp(type_id, SpvOpLoad, builtin_id);
-  uint32_t val_id = GenUintCastCode(load_inst->result_id(), builder);
-  GenDebugOutputFieldCode(base_offset_id, builtin_off, val_id, builder);
-}
-
-void InstrumentPass::GenUintNullOutputCode(uint32_t field_off,
-                                           uint32_t base_offset_id,
-                                           InstructionBuilder* builder) {
-  GenDebugOutputFieldCode(base_offset_id, field_off,
-                          builder->GetNullId(GetUintId()), builder);
+  uint32_t load_id = GenVarLoad(builtin_id, builder);
+  GenDebugOutputFieldCode(base_offset_id, builtin_off, load_id, builder);
 }
 
 void InstrumentPass::GenStageStreamWriteCode(uint32_t stage_idx,
@@ -169,37 +167,97 @@
   switch (stage_idx) {
     case SpvExecutionModelVertex: {
       // Load and store VertexId and InstanceId
-      GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInVertexIndex),
-                           kInstVertOutVertexIndex, base_offset_id, builder);
-      GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInInstanceIndex),
-                           kInstVertOutInstanceIndex, base_offset_id, builder);
+      GenBuiltinOutputCode(
+          context()->GetBuiltinInputVarId(SpvBuiltInVertexIndex),
+          kInstVertOutVertexIndex, base_offset_id, builder);
+      GenBuiltinOutputCode(
+          context()->GetBuiltinInputVarId(SpvBuiltInInstanceIndex),
+          kInstVertOutInstanceIndex, base_offset_id, builder);
     } break;
     case SpvExecutionModelGLCompute: {
-      // Load and store GlobalInvocationId. Second word is unused; store zero.
-      GenBuiltinOutputCode(
-          context()->GetBuiltinVarId(SpvBuiltInGlobalInvocationId),
-          kInstCompOutGlobalInvocationId, base_offset_id, builder);
-      GenUintNullOutputCode(kInstCompOutUnused, base_offset_id, builder);
+      // Load and store GlobalInvocationId.
+      uint32_t load_id = GenVarLoad(
+          context()->GetBuiltinInputVarId(SpvBuiltInGlobalInvocationId),
+          builder);
+      Instruction* x_inst = builder->AddIdLiteralOp(
+          GetUintId(), SpvOpCompositeExtract, load_id, 0);
+      Instruction* y_inst = builder->AddIdLiteralOp(
+          GetUintId(), SpvOpCompositeExtract, load_id, 1);
+      Instruction* z_inst = builder->AddIdLiteralOp(
+          GetUintId(), SpvOpCompositeExtract, load_id, 2);
+      if (version_ == 1) {
+        // For version 1 format, as a stopgap, pack uvec3 into first word:
+        // x << 21 | y << 10 | z. Second word is unused. (DEPRECATED)
+        Instruction* x_shft_inst = builder->AddBinaryOp(
+            GetUintId(), SpvOpShiftLeftLogical, x_inst->result_id(),
+            builder->GetUintConstantId(21));
+        Instruction* y_shft_inst = builder->AddBinaryOp(
+            GetUintId(), SpvOpShiftLeftLogical, y_inst->result_id(),
+            builder->GetUintConstantId(10));
+        Instruction* x_or_y_inst = builder->AddBinaryOp(
+            GetUintId(), SpvOpBitwiseOr, x_shft_inst->result_id(),
+            y_shft_inst->result_id());
+        Instruction* x_or_y_or_z_inst =
+            builder->AddBinaryOp(GetUintId(), SpvOpBitwiseOr,
+                                 x_or_y_inst->result_id(), z_inst->result_id());
+        GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationId,
+                                x_or_y_or_z_inst->result_id(), builder);
+      } else {
+        // For version 2 format, write all three words
+        GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdX,
+                                x_inst->result_id(), builder);
+        GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdY,
+                                y_inst->result_id(), builder);
+        GenDebugOutputFieldCode(base_offset_id, kInstCompOutGlobalInvocationIdZ,
+                                z_inst->result_id(), builder);
+      }
     } break;
     case SpvExecutionModelGeometry: {
       // Load and store PrimitiveId and InvocationId.
-      GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInPrimitiveId),
-                           kInstGeomOutPrimitiveId, base_offset_id, builder);
-      GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInInvocationId),
-                           kInstGeomOutInvocationId, base_offset_id, builder);
+      GenBuiltinOutputCode(
+          context()->GetBuiltinInputVarId(SpvBuiltInPrimitiveId),
+          kInstGeomOutPrimitiveId, base_offset_id, builder);
+      GenBuiltinOutputCode(
+          context()->GetBuiltinInputVarId(SpvBuiltInInvocationId),
+          kInstGeomOutInvocationId, base_offset_id, builder);
     } break;
-    case SpvExecutionModelTessellationControl:
+    case SpvExecutionModelTessellationControl: {
+      // Load and store InvocationId and PrimitiveId
+      GenBuiltinOutputCode(
+          context()->GetBuiltinInputVarId(SpvBuiltInInvocationId),
+          kInstTessCtlOutInvocationId, base_offset_id, builder);
+      GenBuiltinOutputCode(
+          context()->GetBuiltinInputVarId(SpvBuiltInPrimitiveId),
+          kInstTessCtlOutPrimitiveId, base_offset_id, builder);
+    } break;
     case SpvExecutionModelTessellationEvaluation: {
-      // Load and store InvocationId. Second word is unused; store zero.
-      GenBuiltinOutputCode(context()->GetBuiltinVarId(SpvBuiltInInvocationId),
-                           kInstTessOutInvocationId, base_offset_id, builder);
-      GenUintNullOutputCode(kInstTessOutUnused, base_offset_id, builder);
+      if (version_ == 1) {
+        // For format version 1, load and store InvocationId.
+        GenBuiltinOutputCode(
+            context()->GetBuiltinInputVarId(SpvBuiltInInvocationId),
+            kInstTessOutInvocationId, base_offset_id, builder);
+      } else {
+        // For format version 2, load and store PrimitiveId and TessCoord.uv
+        GenBuiltinOutputCode(
+            context()->GetBuiltinInputVarId(SpvBuiltInPrimitiveId),
+            kInstTessEvalOutPrimitiveId, base_offset_id, builder);
+        uint32_t load_id = GenVarLoad(
+            context()->GetBuiltinInputVarId(SpvBuiltInTessCoord), builder);
+        Instruction* u_inst = builder->AddIdLiteralOp(
+            GetUintId(), SpvOpCompositeExtract, load_id, 0);
+        Instruction* v_inst = builder->AddIdLiteralOp(
+            GetUintId(), SpvOpCompositeExtract, load_id, 1);
+        GenDebugOutputFieldCode(base_offset_id, kInstTessEvalOutTessCoordU,
+                                u_inst->result_id(), builder);
+        GenDebugOutputFieldCode(base_offset_id, kInstTessEvalOutTessCoordV,
+                                v_inst->result_id(), builder);
+      }
     } break;
     case SpvExecutionModelFragment: {
       // Load FragCoord and convert to Uint
-      Instruction* frag_coord_inst =
-          builder->AddUnaryOp(GetVec4FloatId(), SpvOpLoad,
-                              context()->GetBuiltinVarId(SpvBuiltInFragCoord));
+      Instruction* frag_coord_inst = builder->AddUnaryOp(
+          GetVec4FloatId(), SpvOpLoad,
+          context()->GetBuiltinInputVarId(SpvBuiltInFragCoord));
       Instruction* uint_frag_coord_inst = builder->AddUnaryOp(
           GetVec4UintId(), SpvOpBitcast, frag_coord_inst->result_id());
       for (uint32_t u = 0; u < 2u; ++u)
@@ -591,9 +649,11 @@
     GenCommonStreamWriteCode(obuf_record_sz, param_vec[kInstCommonParamInstIdx],
                              stage_idx, obuf_curr_sz_id, &builder);
     GenStageStreamWriteCode(stage_idx, obuf_curr_sz_id, &builder);
+    uint32_t val_spec_offset =
+        (version_ == 1) ? kInstStageOutCnt : kInst2StageOutCnt;
     // Gen writes of validation specific data
     for (uint32_t i = 0; i < val_spec_param_cnt; ++i) {
-      GenDebugOutputFieldCode(obuf_curr_sz_id, kInstStageOutCnt + i,
+      GenDebugOutputFieldCode(obuf_curr_sz_id, val_spec_offset + i,
                               param_vec[kInstCommonParamCnt + i], &builder);
     }
     // Close write block and gen merge block
diff --git a/source/opt/instrument_pass.h b/source/opt/instrument_pass.h
index c4b97d6..d255698 100644
--- a/source/opt/instrument_pass.h
+++ b/source/opt/instrument_pass.h
@@ -78,16 +78,18 @@
   }
 
  protected:
-  // Create instrumentation pass which utilizes descriptor set |desc_set|
-  // for debug input and output buffers and writes |shader_id| into debug
-  // output records.
-  InstrumentPass(uint32_t desc_set, uint32_t shader_id, uint32_t validation_id)
+  // Create instrumentation pass for |validation_id| which utilizes descriptor
+  // set |desc_set| for debug input and output buffers and writes |shader_id|
+  // into debug output records with format |version|.
+  InstrumentPass(uint32_t desc_set, uint32_t shader_id, uint32_t validation_id,
+                 uint32_t version)
       : Pass(),
         desc_set_(desc_set),
         shader_id_(shader_id),
-        validation_id_(validation_id) {}
+        validation_id_(validation_id),
+        version_(version) {}
 
-  // Initialize state for instrumentation of module by |validation_id|.
+  // Initialize state for instrumentation of module.
   void InitializeInstrument();
 
   // Call |pfn| on all instructions in all functions in the call tree of the
@@ -146,6 +148,7 @@
   //     Stage
   //     Stage-specific Word 0
   //     Stage-specific Word 1
+  //     ...
   //     Validation Error Code
   //     Validation-specific Word 0
   //     Validation-specific Word 1
@@ -170,12 +173,12 @@
   // following Stage-specific words.
   //
   // The Stage-specific Words identify which invocation of the shader generated
-  // the error. Every stage will write two words, although in some cases the
-  // second word is unused and so zero is written. Vertex shaders will write
-  // the Vertex and Instance ID. Fragment shaders will write FragCoord.xy.
-  // Compute shaders will write the Global Invocation ID and zero (unused).
-  // Both tesselation shaders will write the Invocation Id and zero (unused).
-  // The geometry shader will write the Primitive ID and Invocation ID.
+  // the error. Every stage will write a fixed number of words. Vertex shaders
+  // will write the Vertex and Instance ID. Fragment shaders will write
+  // FragCoord.xy. Compute shaders will write the GlobalInvocation ID.
+  // The tesselation eval shader will write the Primitive ID and TessCoords.uv.
+  // The tesselation control shader and geometry shader will write the
+  // Primitive ID and Invocation ID.
   //
   // The Validation Error Code specifies the exact error which has occurred.
   // These are enumerated with the kInstError* static consts. This allows
@@ -291,16 +294,15 @@
                                       uint32_t component,
                                       InstructionBuilder* builder);
 
+  // Generate instructions into |builder| which will load |var_id| and return
+  // its result id.
+  uint32_t GenVarLoad(uint32_t var_id, InstructionBuilder* builder);
+
   // Generate instructions into |builder| which will load the uint |builtin_id|
   // and write it into the debug output buffer at |base_off| + |builtin_off|.
   void GenBuiltinOutputCode(uint32_t builtin_id, uint32_t builtin_off,
                             uint32_t base_off, InstructionBuilder* builder);
 
-  // Generate instructions into |builder| which will write a uint null into
-  // the debug output buffer at |base_off| + |builtin_off|.
-  void GenUintNullOutputCode(uint32_t field_off, uint32_t base_off,
-                             InstructionBuilder* builder);
-
   // Generate instructions into |builder| which will write the |stage_idx|-
   // specific members of the debug output stream at |base_off|.
   void GenStageStreamWriteCode(uint32_t stage_idx, uint32_t base_off,
@@ -376,6 +378,9 @@
   // id for void type
   uint32_t void_id_;
 
+  // Record format version
+  uint32_t version_;
+
   // boolean to remember storage buffer extension
   bool storage_buffer_ext_defined_;
 
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index 61c5425..081fdbc 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -621,7 +621,7 @@
   return &it->second;
 }
 
-uint32_t IRContext::FindBuiltinVar(uint32_t builtin) {
+uint32_t IRContext::FindBuiltinInputVar(uint32_t builtin) {
   for (auto& a : module_->annotations()) {
     if (a.opcode() != SpvOpDecorate) continue;
     if (a.GetSingleWordInOperand(kSpvDecorateDecorationInIdx) !=
@@ -631,6 +631,7 @@
     uint32_t target_id = a.GetSingleWordInOperand(kSpvDecorateTargetIdInIdx);
     Instruction* b_var = get_def_use_mgr()->GetDef(target_id);
     if (b_var->opcode() != SpvOpVariable) continue;
+    if (b_var->GetSingleWordInOperand(0) != SpvStorageClassInput) continue;
     return target_id;
   }
   return 0;
@@ -653,14 +654,14 @@
   }
 }
 
-uint32_t IRContext::GetBuiltinVarId(uint32_t builtin) {
+uint32_t IRContext::GetBuiltinInputVarId(uint32_t builtin) {
   if (!AreAnalysesValid(kAnalysisBuiltinVarId)) ResetBuiltinAnalysis();
   // If cached, return it.
   std::unordered_map<uint32_t, uint32_t>::iterator it =
       builtin_var_id_map_.find(builtin);
   if (it != builtin_var_id_map_.end()) return it->second;
   // Look for one in shader
-  uint32_t var_id = FindBuiltinVar(builtin);
+  uint32_t var_id = FindBuiltinInputVar(builtin);
   if (var_id == 0) {
     // If not found, create it
     // TODO(greg-lunarg): Add support for all builtins
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index c857c52..32d5b17 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -491,10 +491,10 @@
   uint32_t max_id_bound() const { return max_id_bound_; }
   void set_max_id_bound(uint32_t new_bound) { max_id_bound_ = new_bound; }
 
-  // Return id of variable only decorated with |builtin|, if in module.
+  // Return id of input variable only decorated with |builtin|, if in module.
   // Create variable and return its id otherwise. If builtin not currently
   // supported, return 0.
-  uint32_t GetBuiltinVarId(uint32_t builtin);
+  uint32_t GetBuiltinInputVarId(uint32_t builtin);
 
   // Returns the function whose id is |id|, if one exists.  Returns |nullptr|
   // otherwise.
@@ -657,9 +657,9 @@
   // true if the cfg is invalidated.
   bool CheckCFG();
 
-  // Return id of variable only decorated with |builtin|, if in module.
+  // Return id of input variable only decorated with |builtin|, if in module.
   // Return 0 otherwise.
-  uint32_t FindBuiltinVar(uint32_t builtin);
+  uint32_t FindBuiltinInputVar(uint32_t builtin);
 
   // Add |var_id| to all entry points in module.
   void AddVarToEntryPoints(uint32_t var_id);
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index 4c8daed..3fd6e80 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -397,7 +397,7 @@
   } else if (pass_name == "replace-invalid-opcode") {
     RegisterPass(CreateReplaceInvalidOpcodePass());
   } else if (pass_name == "inst-bindless-check") {
-    RegisterPass(CreateInstBindlessCheckPass(7, 23, true, true));
+    RegisterPass(CreateInstBindlessCheckPass(7, 23, true, true, 1));
     RegisterPass(CreateSimplificationPass());
     RegisterPass(CreateDeadBranchElimPass());
     RegisterPass(CreateBlockMergePass());
@@ -847,10 +847,12 @@
 Optimizer::PassToken CreateInstBindlessCheckPass(uint32_t desc_set,
                                                  uint32_t shader_id,
                                                  bool input_length_enable,
-                                                 bool input_init_enable) {
+                                                 bool input_init_enable,
+                                                 uint32_t version) {
   return MakeUnique<Optimizer::PassToken::Impl>(
-      MakeUnique<opt::InstBindlessCheckPass>(
-          desc_set, shader_id, input_length_enable, input_init_enable));
+      MakeUnique<opt::InstBindlessCheckPass>(desc_set, shader_id,
+                                             input_length_enable,
+                                             input_init_enable, version));
 }
 
 Optimizer::PassToken CreateCodeSinkingPass() {
diff --git a/test/opt/inst_bindless_check_test.cpp b/test/opt/inst_bindless_check_test.cpp
index 94a37cf..ea25de3 100644
--- a/test/opt/inst_bindless_check_test.cpp
+++ b/test/opt/inst_bindless_check_test.cpp
@@ -4226,6 +4226,235 @@
       true);
 }
 
+TEST_F(InstBindlessTest, SimpleV2) {
+  // Texture2D g_tColor[128];
+  //
+  // layout(push_constant) cbuffer PerViewConstantBuffer_t
+  // {
+  //   uint g_nDataIdx;
+  // };
+  //
+  // SamplerState g_sAniso;
+  //
+  // struct PS_INPUT
+  // {
+  //   float2 vTextureCoords : TEXCOORD2;
+  // };
+  //
+  // struct PS_OUTPUT
+  // {
+  //   float4 vColor : SV_Target0;
+  // };
+  //
+  // PS_OUTPUT MainPs(PS_INPUT i)
+  // {
+  //   PS_OUTPUT ps_output;
+  //   ps_output.vColor =
+  //       g_tColor[ g_nDataIdx ].Sample(g_sAniso, i.vTextureCoords.xy);
+  //   return ps_output;
+  // }
+
+  const std::string entry_before =
+      R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor
+OpExecutionMode %MainPs OriginUpperLeft
+OpSource HLSL 500
+)";
+
+  const std::string entry_after =
+      R"(OpCapability Shader
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %MainPs "MainPs" %i_vTextureCoords %_entryPointOutput_vColor %gl_FragCoord
+OpExecutionMode %MainPs OriginUpperLeft
+OpSource HLSL 500
+)";
+
+  const std::string names_annots =
+      R"(OpName %MainPs "MainPs"
+OpName %g_tColor "g_tColor"
+OpName %PerViewConstantBuffer_t "PerViewConstantBuffer_t"
+OpMemberName %PerViewConstantBuffer_t 0 "g_nDataIdx"
+OpName %_ ""
+OpName %g_sAniso "g_sAniso"
+OpName %i_vTextureCoords "i.vTextureCoords"
+OpName %_entryPointOutput_vColor "@entryPointOutput.vColor"
+OpDecorate %g_tColor DescriptorSet 3
+OpDecorate %g_tColor Binding 0
+OpMemberDecorate %PerViewConstantBuffer_t 0 Offset 0
+OpDecorate %PerViewConstantBuffer_t Block
+OpDecorate %g_sAniso DescriptorSet 0
+OpDecorate %i_vTextureCoords Location 0
+OpDecorate %_entryPointOutput_vColor Location 0
+)";
+
+  const std::string new_annots =
+      R"(OpDecorate %_runtimearr_uint ArrayStride 4
+OpDecorate %_struct_55 Block
+OpMemberDecorate %_struct_55 0 Offset 0
+OpMemberDecorate %_struct_55 1 Offset 4
+OpDecorate %57 DescriptorSet 7
+OpDecorate %57 Binding 0
+OpDecorate %gl_FragCoord BuiltIn FragCoord
+)";
+
+  const std::string consts_types_vars =
+      R"(%void = OpTypeVoid
+%10 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v2float = OpTypeVector %float 2
+%v4float = OpTypeVector %float 4
+%int = OpTypeInt 32 1
+%int_0 = OpConstant %int 0
+%16 = OpTypeImage %float 2D 0 0 0 1 Unknown
+%uint = OpTypeInt 32 0
+%uint_128 = OpConstant %uint 128
+%_arr_16_uint_128 = OpTypeArray %16 %uint_128
+%_ptr_UniformConstant__arr_16_uint_128 = OpTypePointer UniformConstant %_arr_16_uint_128
+%g_tColor = OpVariable %_ptr_UniformConstant__arr_16_uint_128 UniformConstant
+%PerViewConstantBuffer_t = OpTypeStruct %uint
+%_ptr_PushConstant_PerViewConstantBuffer_t = OpTypePointer PushConstant %PerViewConstantBuffer_t
+%_ = OpVariable %_ptr_PushConstant_PerViewConstantBuffer_t PushConstant
+%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint
+%_ptr_UniformConstant_16 = OpTypePointer UniformConstant %16
+%24 = OpTypeSampler
+%_ptr_UniformConstant_24 = OpTypePointer UniformConstant %24
+%g_sAniso = OpVariable %_ptr_UniformConstant_24 UniformConstant
+%26 = OpTypeSampledImage %16
+%_ptr_Input_v2float = OpTypePointer Input %v2float
+%i_vTextureCoords = OpVariable %_ptr_Input_v2float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%_entryPointOutput_vColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+  const std::string new_consts_types_vars =
+      R"(%uint_0 = OpConstant %uint 0
+%bool = OpTypeBool
+%48 = OpTypeFunction %void %uint %uint %uint %uint
+%_runtimearr_uint = OpTypeRuntimeArray %uint
+%_struct_55 = OpTypeStruct %uint %_runtimearr_uint
+%_ptr_StorageBuffer__struct_55 = OpTypePointer StorageBuffer %_struct_55
+%57 = OpVariable %_ptr_StorageBuffer__struct_55 StorageBuffer
+%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+%uint_9 = OpConstant %uint 9
+%uint_4 = OpConstant %uint 4
+%uint_1 = OpConstant %uint 1
+%uint_23 = OpConstant %uint 23
+%uint_2 = OpConstant %uint 2
+%uint_3 = OpConstant %uint 3
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%gl_FragCoord = OpVariable %_ptr_Input_v4float Input
+%v4uint = OpTypeVector %uint 4
+%uint_5 = OpConstant %uint 5
+%uint_7 = OpConstant %uint 7
+%uint_8 = OpConstant %uint 8
+%uint_56 = OpConstant %uint 56
+%102 = OpConstantNull %v4float
+)";
+
+  const std::string func_pt1 =
+      R"(%MainPs = OpFunction %void None %10
+%29 = OpLabel
+%30 = OpLoad %v2float %i_vTextureCoords
+%31 = OpAccessChain %_ptr_PushConstant_uint %_ %int_0
+%32 = OpLoad %uint %31
+%33 = OpAccessChain %_ptr_UniformConstant_16 %g_tColor %32
+%34 = OpLoad %16 %33
+%35 = OpLoad %24 %g_sAniso
+%36 = OpSampledImage %26 %34 %35
+)";
+
+  const std::string func_pt2_before =
+      R"(%37 = OpImageSampleImplicitLod %v4float %36 %30
+OpStore %_entryPointOutput_vColor %37
+OpReturn
+OpFunctionEnd
+)";
+
+  const std::string func_pt2_after =
+      R"(%40 = OpULessThan %bool %32 %uint_128
+OpSelectionMerge %41 None
+OpBranchConditional %40 %42 %43
+%42 = OpLabel
+%44 = OpLoad %16 %33
+%45 = OpSampledImage %26 %44 %35
+%46 = OpImageSampleImplicitLod %v4float %45 %30
+OpBranch %41
+%43 = OpLabel
+%101 = OpFunctionCall %void %47 %uint_56 %uint_0 %32 %uint_128
+OpBranch %41
+%41 = OpLabel
+%103 = OpPhi %v4float %46 %42 %102 %43
+OpStore %_entryPointOutput_vColor %103
+OpReturn
+OpFunctionEnd
+)";
+
+  const std::string output_func =
+      R"(%47 = OpFunction %void None %48
+%49 = OpFunctionParameter %uint
+%50 = OpFunctionParameter %uint
+%51 = OpFunctionParameter %uint
+%52 = OpFunctionParameter %uint
+%53 = OpLabel
+%59 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_0
+%62 = OpAtomicIAdd %uint %59 %uint_4 %uint_0 %uint_9
+%63 = OpIAdd %uint %62 %uint_9
+%64 = OpArrayLength %uint %57 1
+%65 = OpULessThanEqual %bool %63 %64
+OpSelectionMerge %66 None
+OpBranchConditional %65 %67 %66
+%67 = OpLabel
+%68 = OpIAdd %uint %62 %uint_0
+%70 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %68
+OpStore %70 %uint_9
+%72 = OpIAdd %uint %62 %uint_1
+%73 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %72
+OpStore %73 %uint_23
+%75 = OpIAdd %uint %62 %uint_2
+%76 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %75
+OpStore %76 %49
+%78 = OpIAdd %uint %62 %uint_3
+%79 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %78
+OpStore %79 %uint_4
+%82 = OpLoad %v4float %gl_FragCoord
+%84 = OpBitcast %v4uint %82
+%85 = OpCompositeExtract %uint %84 0
+%86 = OpIAdd %uint %62 %uint_4
+%87 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %86
+OpStore %87 %85
+%88 = OpCompositeExtract %uint %84 1
+%90 = OpIAdd %uint %62 %uint_5
+%91 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %90
+OpStore %91 %88
+%93 = OpIAdd %uint %62 %uint_7
+%94 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %93
+OpStore %94 %50
+%96 = OpIAdd %uint %62 %uint_8
+%97 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %96
+OpStore %97 %51
+%98 = OpIAdd %uint %62 %uint_9
+%99 = OpAccessChain %_ptr_StorageBuffer_uint %57 %uint_1 %98
+OpStore %99 %52
+OpBranch %66
+%66 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndCheck<InstBindlessCheckPass, uint32_t, uint32_t, bool, bool,
+                        uint32_t>(
+      entry_before + names_annots + consts_types_vars + func_pt1 +
+          func_pt2_before,
+      entry_after + names_annots + new_annots + consts_types_vars +
+          new_consts_types_vars + func_pt1 + func_pt2_after + output_func,
+      true, true, 7u, 23u, true, true, 2u);
+}
+
 // TODO(greg-lunarg): Add tests to verify handling of these cases:
 //
 //   Compute shader