Instrument: Debug Printf support (#3215)

Create a pass to instrument OpDebugPrintf instructions.  This pass replaces all OpDebugPrintf instructions with instructions to write a record containing the string id and all of the specified values into a special printf output buffer (if space allows). This pass is designed to support the printf validation in the Vulkan validation layers.

Fixes #3210
diff --git a/Android.mk b/Android.mk
index db4f43b..eec709a 100644
--- a/Android.mk
+++ b/Android.mk
@@ -119,6 +119,7 @@
 		source/opt/inline_opaque_pass.cpp \
 		source/opt/inst_bindless_check_pass.cpp \
 		source/opt/inst_buff_addr_check_pass.cpp \
+		source/opt/inst_debug_printf_pass.cpp \
 		source/opt/instruction.cpp \
 		source/opt/instruction_list.cpp \
 		source/opt/instrument_pass.cpp \
diff --git a/BUILD.gn b/BUILD.gn
index 1337059..d3107fd 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -590,6 +590,8 @@
     "source/opt/inst_bindless_check_pass.h",
     "source/opt/inst_buff_addr_check_pass.cpp",
     "source/opt/inst_buff_addr_check_pass.h",
+    "source/opt/inst_debug_printf_pass.cpp",
+    "source/opt/inst_debug_printf_pass.h",
     "source/opt/instruction.cpp",
     "source/opt/instruction.h",
     "source/opt/instruction_list.cpp",
diff --git a/include/spirv-tools/instrument.hpp b/include/spirv-tools/instrument.hpp
index 2dcb333..d3180e4 100644
--- a/include/spirv-tools/instrument.hpp
+++ b/include/spirv-tools/instrument.hpp
@@ -208,6 +208,9 @@
 // The binding for the input buffer read by InstBuffAddrCheckPass.
 static const int kDebugInputBindingBuffAddr = 2;
 
+// The binding for the output buffer written by InstDebugPrintfPass.
+static const int kDebugOutputPrintfStream = 3;
+
 // Bindless Validation Input Buffer Format
 //
 // An input buffer for bindless validation consists of a single array of
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index c31ccef..b904923 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -791,6 +791,18 @@
                                                  uint32_t shader_id,
                                                  uint32_t version = 2);
 
+// Create a pass to instrument OpDebugPrintf instructions.
+// This pass replaces all OpDebugPrintf instructions with instructions to write
+// a record containing the string id and all specified values into a special
+// printf output buffer (if space allows). This pass is designed to support
+// the printf validation in the Vulkan validation layers.
+//
+// The instrumentation will write buffers in debug descriptor set |desc_set|.
+// It will write |shader_id| in each output record to identify the shader
+// module which generated the record.
+Optimizer::PassToken CreateInstDebugPrintfPass(uint32_t desc_set,
+                                               uint32_t shader_id);
+
 // Create a pass to upgrade to the VulkanKHR memory model.
 // This pass upgrades the Logical GLSL450 memory model to Logical VulkanKHR.
 // Additionally, it modifies memory, image, atomic and barrier operations to
diff --git a/source/enum_set.h b/source/enum_set.h
index 2e7046d..d4d31e3 100644
--- a/source/enum_set.h
+++ b/source/enum_set.h
@@ -93,6 +93,10 @@
   // enum value is already in the set.
   void Add(EnumType c) { AddWord(ToWord(c)); }
 
+  // Removes the given enum value from the set.  This has no effect if the
+  // enum value is not in the set.
+  void Remove(EnumType c) { RemoveWord(ToWord(c)); }
+
   // Returns true if this enum value is in the set.
   bool Contains(EnumType c) const { return ContainsWord(ToWord(c)); }
 
@@ -141,6 +145,17 @@
     }
   }
 
+  // Removes the given enum value (as a 32-bit word) from the set.  This has no
+  // effect if the enum value is not in the set.
+  void RemoveWord(uint32_t word) {
+    if (auto new_bits = AsMask(word)) {
+      mask_ &= ~new_bits;
+    } else {
+      auto itr = Overflow().find(word);
+      if (itr != Overflow().end()) Overflow().erase(itr);
+    }
+  }
+
   // Returns true if the enum represented as a 32-bit word is in the set.
   bool ContainsWord(uint32_t word) const {
     // We shouldn't call Overflow() since this is a const method.
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 0f719cb..1428c74 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -58,6 +58,7 @@
   inline_pass.h
   inst_bindless_check_pass.h
   inst_buff_addr_check_pass.h
+  inst_debug_printf_pass.h
   instruction.h
   instruction_list.h
   instrument_pass.h
@@ -164,6 +165,7 @@
   inline_pass.cpp
   inst_bindless_check_pass.cpp
   inst_buff_addr_check_pass.cpp
+  inst_debug_printf_pass.cpp
   instruction.cpp
   instruction_list.cpp
   instrument_pass.cpp
diff --git a/source/opt/feature_manager.cpp b/source/opt/feature_manager.cpp
index 63d50b6..b4d6f1b 100644
--- a/source/opt/feature_manager.cpp
+++ b/source/opt/feature_manager.cpp
@@ -47,6 +47,11 @@
   }
 }
 
+void FeatureManager::RemoveExtension(Extension ext) {
+  if (!extensions_.Contains(ext)) return;
+  extensions_.Remove(ext);
+}
+
 void FeatureManager::AddCapability(SpvCapability cap) {
   if (capabilities_.Contains(cap)) return;
 
@@ -60,6 +65,11 @@
   }
 }
 
+void FeatureManager::RemoveCapability(SpvCapability cap) {
+  if (!capabilities_.Contains(cap)) return;
+  capabilities_.Remove(cap);
+}
+
 void FeatureManager::AddCapabilities(Module* module) {
   for (Instruction& inst : module->capabilities()) {
     AddCapability(static_cast<SpvCapability>(inst.GetSingleWordInOperand(0)));
diff --git a/source/opt/feature_manager.h b/source/opt/feature_manager.h
index 2fe3291..881d5e6 100644
--- a/source/opt/feature_manager.h
+++ b/source/opt/feature_manager.h
@@ -30,11 +30,17 @@
   // Returns true if |ext| is an enabled extension in the module.
   bool HasExtension(Extension ext) const { return extensions_.Contains(ext); }
 
+  // Removes the given |extension| from the current FeatureManager.
+  void RemoveExtension(Extension extension);
+
   // Returns true if |cap| is an enabled capability in the module.
   bool HasCapability(SpvCapability cap) const {
     return capabilities_.Contains(cap);
   }
 
+  // Removes the given |capability| from the current FeatureManager.
+  void RemoveCapability(SpvCapability capability);
+
   // Analyzes |module| and records enabled extensions and capabilities.
   void Analyze(Module* module);
 
diff --git a/source/opt/inst_debug_printf_pass.cpp b/source/opt/inst_debug_printf_pass.cpp
new file mode 100644
index 0000000..c0e6bc3
--- /dev/null
+++ b/source/opt/inst_debug_printf_pass.cpp
@@ -0,0 +1,266 @@
+// Copyright (c) 2020 The Khronos Group Inc.
+// Copyright (c) 2020 Valve Corporation
+// Copyright (c) 2020 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "inst_debug_printf_pass.h"
+
+#include "spirv/unified1/NonSemanticDebugPrintf.h"
+
+namespace spvtools {
+namespace opt {
+
+void InstDebugPrintfPass::GenOutputValues(Instruction* val_inst,
+                                          std::vector<uint32_t>* val_ids,
+                                          InstructionBuilder* builder) {
+  uint32_t val_ty_id = val_inst->type_id();
+  analysis::TypeManager* type_mgr = context()->get_type_mgr();
+  analysis::Type* val_ty = type_mgr->GetType(val_ty_id);
+  switch (val_ty->kind()) {
+    case analysis::Type::kVector: {
+      analysis::Vector* v_ty = val_ty->AsVector();
+      const analysis::Type* c_ty = v_ty->element_type();
+      uint32_t c_ty_id = type_mgr->GetId(c_ty);
+      for (uint32_t c = 0; c < v_ty->element_count(); ++c) {
+        Instruction* c_inst = builder->AddIdLiteralOp(
+            c_ty_id, SpvOpCompositeExtract, val_inst->result_id(), c);
+        GenOutputValues(c_inst, val_ids, builder);
+      }
+      return;
+    }
+    case analysis::Type::kBool: {
+      // Select between uint32 zero or one
+      uint32_t zero_id = builder->GetUintConstantId(0);
+      uint32_t one_id = builder->GetUintConstantId(1);
+      Instruction* sel_inst = builder->AddTernaryOp(
+          GetUintId(), SpvOpSelect, val_inst->result_id(), one_id, zero_id);
+      val_ids->push_back(sel_inst->result_id());
+      return;
+    }
+    case analysis::Type::kFloat: {
+      analysis::Float* f_ty = val_ty->AsFloat();
+      switch (f_ty->width()) {
+        case 16: {
+          // Convert float16 to float32 and recurse
+          Instruction* f32_inst = builder->AddUnaryOp(
+              GetFloatId(), SpvOpFConvert, val_inst->result_id());
+          GenOutputValues(f32_inst, val_ids, builder);
+          return;
+        }
+        case 64: {
+          // Bitcast float64 to uint64 and recurse
+          Instruction* ui64_inst = builder->AddUnaryOp(
+              GetUint64Id(), SpvOpBitcast, val_inst->result_id());
+          GenOutputValues(ui64_inst, val_ids, builder);
+          return;
+        }
+        case 32: {
+          // Bitcast float32 to uint32
+          Instruction* bc_inst = builder->AddUnaryOp(GetUintId(), SpvOpBitcast,
+                                                     val_inst->result_id());
+          val_ids->push_back(bc_inst->result_id());
+          return;
+        }
+        default:
+          assert(false && "unsupported float width");
+          return;
+      }
+    }
+    case analysis::Type::kInteger: {
+      analysis::Integer* i_ty = val_ty->AsInteger();
+      switch (i_ty->width()) {
+        case 64: {
+          Instruction* ui64_inst = val_inst;
+          if (i_ty->IsSigned()) {
+            // Bitcast sint64 to uint64
+            ui64_inst = builder->AddUnaryOp(GetUint64Id(), SpvOpBitcast,
+                                            val_inst->result_id());
+          }
+          // Break uint64 into 2x uint32
+          Instruction* lo_ui64_inst = builder->AddUnaryOp(
+              GetUintId(), SpvOpUConvert, ui64_inst->result_id());
+          Instruction* rshift_ui64_inst = builder->AddBinaryOp(
+              GetUint64Id(), SpvOpShiftRightLogical, ui64_inst->result_id(),
+              builder->GetUintConstantId(32));
+          Instruction* hi_ui64_inst = builder->AddUnaryOp(
+              GetUintId(), SpvOpUConvert, rshift_ui64_inst->result_id());
+          val_ids->push_back(lo_ui64_inst->result_id());
+          val_ids->push_back(hi_ui64_inst->result_id());
+          return;
+        }
+        case 8: {
+          Instruction* ui8_inst = val_inst;
+          if (i_ty->IsSigned()) {
+            // Bitcast sint8 to uint8
+            ui8_inst = builder->AddUnaryOp(GetUint8Id(), SpvOpBitcast,
+                                           val_inst->result_id());
+          }
+          // Convert uint8 to uint32
+          Instruction* ui32_inst = builder->AddUnaryOp(
+              GetUintId(), SpvOpUConvert, ui8_inst->result_id());
+          val_ids->push_back(ui32_inst->result_id());
+          return;
+        }
+        case 32: {
+          Instruction* ui32_inst = val_inst;
+          if (i_ty->IsSigned()) {
+            // Bitcast sint32 to uint32
+            ui32_inst = builder->AddUnaryOp(GetUintId(), SpvOpBitcast,
+                                            val_inst->result_id());
+          }
+          // uint32 needs no further processing
+          val_ids->push_back(ui32_inst->result_id());
+          return;
+        }
+        default:
+          // TODO(greg-lunarg): Support remaining int widths (e.g. 16-bit)
+          assert(false && "unsupported int width");
+          return;
+      }
+    }
+    default:
+      assert(false && "unsupported type");
+      return;
+  }
+}
+
+void InstDebugPrintfPass::GenOutputCode(
+    Instruction* printf_inst, uint32_t stage_idx,
+    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
+  BasicBlock* back_blk_ptr = &*new_blocks->back();
+  InstructionBuilder builder(
+      context(), back_blk_ptr,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+  // Gen debug printf record validation-specific values. The format string
+  // will have its id written. Vectors will need to be broken down into
+  // component values. float16 will need to be converted to float32. Pointer
+  // and uint64 will need to be converted to two uint32 values. float32 will
+  // need to be bitcast to uint32. int32 will need to be bitcast to uint32.
+  std::vector<uint32_t> val_ids;
+  bool is_first_operand = false;
+  printf_inst->ForEachInId(
+      [&is_first_operand, &val_ids, &builder, this](const uint32_t* iid) {
+        // skip set operand
+        if (!is_first_operand) {
+          is_first_operand = true;
+          return;
+        }
+        Instruction* opnd_inst = get_def_use_mgr()->GetDef(*iid);
+        if (opnd_inst->opcode() == SpvOpString) {
+          uint32_t string_id_id = builder.GetUintConstantId(*iid);
+          val_ids.push_back(string_id_id);
+        } else {
+          GenOutputValues(opnd_inst, &val_ids, &builder);
+        }
+      });
+  GenDebugStreamWrite(uid2offset_[printf_inst->unique_id()], stage_idx, val_ids,
+                      &builder);
+  context()->KillInst(printf_inst);
+}
+
+void InstDebugPrintfPass::GenDebugPrintfCode(
+    BasicBlock::iterator ref_inst_itr,
+    UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
+    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
+  // If not DebugPrintf OpExtInst, return.
+  Instruction* printf_inst = &*ref_inst_itr;
+  if (printf_inst->opcode() != SpvOpExtInst) return;
+  if (printf_inst->GetSingleWordInOperand(0) != ext_inst_printf_id_) return;
+  if (printf_inst->GetSingleWordInOperand(1) !=
+      NonSemanticDebugPrintfDebugPrintf)
+    return;
+  // Initialize DefUse manager before dismantling module
+  (void)get_def_use_mgr();
+  // Move original block's preceding instructions into first new block
+  std::unique_ptr<BasicBlock> new_blk_ptr;
+  MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
+  new_blocks->push_back(std::move(new_blk_ptr));
+  // Generate instructions to output printf args to printf buffer
+  GenOutputCode(printf_inst, stage_idx, new_blocks);
+  // Caller expects at least two blocks with last block containing remaining
+  // code, so end block after instrumentation, create remainder block, and
+  // branch to it
+  uint32_t rem_blk_id = TakeNextId();
+  std::unique_ptr<Instruction> rem_label(NewLabel(rem_blk_id));
+  BasicBlock* back_blk_ptr = &*new_blocks->back();
+  InstructionBuilder builder(
+      context(), back_blk_ptr,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+  (void)builder.AddBranch(rem_blk_id);
+  // Gen remainder block
+  new_blk_ptr.reset(new BasicBlock(std::move(rem_label)));
+  builder.SetInsertPoint(&*new_blk_ptr);
+  // Move original block's remaining code into remainder block and add
+  // to new blocks
+  MovePostludeCode(ref_block_itr, &*new_blk_ptr);
+  new_blocks->push_back(std::move(new_blk_ptr));
+}
+
+void InstDebugPrintfPass::InitializeInstDebugPrintf() {
+  // Initialize base class
+  InitializeInstrument();
+}
+
+Pass::Status InstDebugPrintfPass::ProcessImpl() {
+  // Perform printf instrumentation on each entry point function in module
+  InstProcessFunction pfn =
+      [this](BasicBlock::iterator ref_inst_itr,
+             UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
+             std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
+        return GenDebugPrintfCode(ref_inst_itr, ref_block_itr, stage_idx,
+                                  new_blocks);
+      };
+  (void)InstProcessEntryPointCallTree(pfn);
+  // Remove DebugPrintf OpExtInstImport instruction
+  Instruction* ext_inst_import_inst =
+      get_def_use_mgr()->GetDef(ext_inst_printf_id_);
+  context()->KillInst(ext_inst_import_inst);
+  // If no remaining non-semantic instruction sets, remove non-semantic debug
+  // info extension from module and feature manager
+  bool non_sem_set_seen = false;
+  for (auto c_itr = context()->module()->ext_inst_import_begin();
+       c_itr != context()->module()->ext_inst_import_end(); ++c_itr) {
+    const char* set_name =
+        reinterpret_cast<const char*>(&c_itr->GetInOperand(0).words[0]);
+    const char* non_sem_str = "NonSemantic.";
+    if (!strncmp(set_name, non_sem_str, strlen(non_sem_str))) {
+      non_sem_set_seen = true;
+      break;
+    }
+  }
+  if (!non_sem_set_seen) {
+    for (auto c_itr = context()->module()->extension_begin();
+         c_itr != context()->module()->extension_end(); ++c_itr) {
+      const char* ext_name =
+          reinterpret_cast<const char*>(&c_itr->GetInOperand(0).words[0]);
+      if (!strcmp(ext_name, "SPV_KHR_non_semantic_info")) {
+        context()->KillInst(&*c_itr);
+        break;
+      }
+    }
+    context()->get_feature_mgr()->RemoveExtension(kSPV_KHR_non_semantic_info);
+  }
+  return Status::SuccessWithChange;
+}
+
+Pass::Status InstDebugPrintfPass::Process() {
+  ext_inst_printf_id_ =
+      get_module()->GetExtInstImportId("NonSemantic.DebugPrintf");
+  if (ext_inst_printf_id_ == 0) return Status::SuccessWithoutChange;
+  InitializeInstDebugPrintf();
+  return ProcessImpl();
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/inst_debug_printf_pass.h b/source/opt/inst_debug_printf_pass.h
new file mode 100644
index 0000000..2968a20
--- /dev/null
+++ b/source/opt/inst_debug_printf_pass.h
@@ -0,0 +1,96 @@
+// Copyright (c) 2020 The Khronos Group Inc.
+// Copyright (c) 2020 Valve Corporation
+// Copyright (c) 2020 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_INST_DEBUG_PRINTF_PASS_H_
+#define LIBSPIRV_OPT_INST_DEBUG_PRINTF_PASS_H_
+
+#include "instrument_pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// This class/pass is designed to support the debug printf GPU-assisted layer
+// of https://github.com/KhronosGroup/Vulkan-ValidationLayers. Its internal and
+// external design may change as the layer evolves.
+class InstDebugPrintfPass : public InstrumentPass {
+ public:
+  // For test harness only
+  InstDebugPrintfPass()
+      : InstrumentPass(7, 23, kInstValidationIdDebugPrintf, 2) {}
+  // For all other interfaces
+  InstDebugPrintfPass(uint32_t desc_set, uint32_t shader_id)
+      : InstrumentPass(desc_set, shader_id, kInstValidationIdDebugPrintf, 2) {}
+
+  ~InstDebugPrintfPass() override = default;
+
+  // See optimizer.hpp for pass user documentation.
+  Status Process() override;
+
+  const char* name() const override { return "inst-printf-pass"; }
+
+ private:
+  // Generate instructions for OpDebugPrintf.
+  //
+  // If |ref_inst_itr| is an OpDebugPrintf, return in |new_blocks| the result
+  // of replacing it with buffer write instructions within its block at
+  // |ref_block_itr|.  The instructions write a record to the printf
+  // output buffer stream including |function_idx, instruction_idx, stage_idx|
+  // and removes the OpDebugPrintf. The block at |ref_block_itr| can just be
+  // replaced with the blocks in |new_blocks|. Besides the buffer writes, these
+  // blocks will comprise all instructions preceding and following
+  // |ref_inst_itr|.
+  //
+  // This function is designed to be passed to
+  // InstrumentPass::InstProcessEntryPointCallTree(), which applies the
+  // function to each instruction in a module and replaces the instruction
+  // if warranted.
+  //
+  // This instrumentation function utilizes GenDebugStreamWrite() to write its
+  // error records. The validation-specific part of the error record will
+  // consist of a uint32 which is the id of the format string plus a sequence
+  // of uint32s representing the values of the remaining operands of the
+  // DebugPrintf.
+  void GenDebugPrintfCode(BasicBlock::iterator ref_inst_itr,
+                          UptrVectorIterator<BasicBlock> ref_block_itr,
+                          uint32_t stage_idx,
+                          std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
+
+  // Generate a sequence of uint32 instructions in |builder| (if necessary)
+  // representing the value of |val_inst|, which must be a bool, an 8-, 32- or
+  // 64-bit integer, a 16-, 32- or 64-bit float, or a vector of those. Append
+  // the ids of all values to the end of |val_ids|.
+  void GenOutputValues(Instruction* val_inst, std::vector<uint32_t>* val_ids,
+                       InstructionBuilder* builder);
+
+  // Generate instructions to write a record containing the operands of
+  // |printf_inst| arguments to printf buffer, adding new code to the end of
+  // the last block in |new_blocks|. Kill OpDebugPrintf instruction.
+  void GenOutputCode(Instruction* printf_inst, uint32_t stage_idx,
+                     std::vector<std::unique_ptr<BasicBlock>>* new_blocks);
+
+  // Initialize state for instrumenting debug printf
+  void InitializeInstDebugPrintf();
+
+  // Apply GenDebugPrintfCode to every instruction in module.
+  Pass::Status ProcessImpl();
+
+  uint32_t ext_inst_printf_id_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // LIBSPIRV_OPT_INST_DEBUG_PRINTF_PASS_H_
diff --git a/source/opt/instrument_pass.cpp b/source/opt/instrument_pass.cpp
index b1a6edb..c8c6c21 100644
--- a/source/opt/instrument_pass.cpp
+++ b/source/opt/instrument_pass.cpp
@@ -380,6 +380,8 @@
       return kDebugOutputBindingStream;
     case kInstValidationIdBuffAddr:
       return kDebugOutputBindingStream;
+    case kInstValidationIdDebugPrintf:
+      return kDebugOutputPrintfStream;
     default:
       assert(false && "unexpected validation id");
   }
@@ -529,6 +531,16 @@
   return input_buffer_id_;
 }
 
+uint32_t InstrumentPass::GetFloatId() {
+  if (float_id_ == 0) {
+    analysis::TypeManager* type_mgr = context()->get_type_mgr();
+    analysis::Float float_ty(32);
+    analysis::Type* reg_float_ty = type_mgr->GetRegisteredType(&float_ty);
+    float_id_ = type_mgr->GetTypeInstruction(reg_float_ty);
+  }
+  return float_id_;
+}
+
 uint32_t InstrumentPass::GetVec4FloatId() {
   if (v4float_id_ == 0) {
     analysis::TypeManager* type_mgr = context()->get_type_mgr();
@@ -561,6 +573,16 @@
   return uint64_id_;
 }
 
+uint32_t InstrumentPass::GetUint8Id() {
+  if (uint8_id_ == 0) {
+    analysis::TypeManager* type_mgr = context()->get_type_mgr();
+    analysis::Integer uint8_ty(8, false);
+    analysis::Type* reg_uint8_ty = type_mgr->GetRegisteredType(&uint8_ty);
+    uint8_id_ = type_mgr->GetTypeInstruction(reg_uint8_ty);
+  }
+  return uint8_id_;
+}
+
 uint32_t InstrumentPass::GetVecUintId(uint32_t len) {
   analysis::TypeManager* type_mgr = context()->get_type_mgr();
   analysis::Integer uint_ty(32, false);
@@ -606,21 +628,22 @@
   // Total param count is common params plus validation-specific
   // params
   uint32_t param_cnt = kInstCommonParamCnt + val_spec_param_cnt;
-  if (output_func_id_ == 0) {
+  if (param2output_func_id_[param_cnt] == 0) {
     // Create function
-    output_func_id_ = TakeNextId();
+    param2output_func_id_[param_cnt] = TakeNextId();
     analysis::TypeManager* type_mgr = context()->get_type_mgr();
     std::vector<const analysis::Type*> param_types;
     for (uint32_t c = 0; c < param_cnt; ++c)
       param_types.push_back(type_mgr->GetType(GetUintId()));
     analysis::Function func_ty(type_mgr->GetType(GetVoidId()), param_types);
     analysis::Type* reg_func_ty = type_mgr->GetRegisteredType(&func_ty);
-    std::unique_ptr<Instruction> func_inst(new Instruction(
-        get_module()->context(), SpvOpFunction, GetVoidId(), output_func_id_,
-        {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
-          {SpvFunctionControlMaskNone}},
-         {spv_operand_type_t::SPV_OPERAND_TYPE_ID,
-          {type_mgr->GetTypeInstruction(reg_func_ty)}}}));
+    std::unique_ptr<Instruction> func_inst(
+        new Instruction(get_module()->context(), SpvOpFunction, GetVoidId(),
+                        param2output_func_id_[param_cnt],
+                        {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
+                          {SpvFunctionControlMaskNone}},
+                         {spv_operand_type_t::SPV_OPERAND_TYPE_ID,
+                          {type_mgr->GetTypeInstruction(reg_func_ty)}}}));
     get_def_use_mgr()->AnalyzeInstDefUse(&*func_inst);
     std::unique_ptr<Function> output_func =
         MakeUnique<Function>(std::move(func_inst));
@@ -709,10 +732,8 @@
     get_def_use_mgr()->AnalyzeInstDefUse(&*func_end_inst);
     output_func->SetFunctionEnd(std::move(func_end_inst));
     context()->AddFunction(std::move(output_func));
-    output_func_param_cnt_ = param_cnt;
   }
-  assert(param_cnt == output_func_param_cnt_ && "bad arg count");
-  return output_func_id_;
+  return param2output_func_id_[param_cnt];
 }
 
 uint32_t InstrumentPass::GetDirectReadFunctionId(uint32_t param_cnt) {
@@ -848,7 +869,7 @@
   std::unordered_set<uint32_t> done;
   // Don't process input and output functions
   for (auto& ifn : param2input_func_id_) done.insert(ifn.second);
-  if (output_func_id_ != 0) done.insert(output_func_id_);
+  for (auto& ofn : param2output_func_id_) done.insert(ofn.second);
   // Process all functions from roots
   while (!roots->empty()) {
     const uint32_t fi = roots->front();
@@ -926,12 +947,12 @@
   output_buffer_id_ = 0;
   output_buffer_ptr_id_ = 0;
   input_buffer_ptr_id_ = 0;
-  output_func_id_ = 0;
-  output_func_param_cnt_ = 0;
   input_buffer_id_ = 0;
+  float_id_ = 0;
   v4float_id_ = 0;
   uint_id_ = 0;
   uint64_id_ = 0;
+  uint8_id_ = 0;
   v4uint_id_ = 0;
   v3uint_id_ = 0;
   bool_id_ = 0;
@@ -944,6 +965,10 @@
   id2function_.clear();
   id2block_.clear();
 
+  // clear maps
+  param2input_func_id_.clear();
+  param2output_func_id_.clear();
+
   // Initialize function and block maps.
   for (auto& fn : *get_module()) {
     id2function_[fn.result_id()] = &fn;
diff --git a/source/opt/instrument_pass.h b/source/opt/instrument_pass.h
index 02568fb..11afdce 100644
--- a/source/opt/instrument_pass.h
+++ b/source/opt/instrument_pass.h
@@ -61,6 +61,7 @@
 // its output buffers.
 static const uint32_t kInstValidationIdBindless = 0;
 static const uint32_t kInstValidationIdBuffAddr = 1;
+static const uint32_t kInstValidationIdDebugPrintf = 2;
 
 class InstrumentPass : public Pass {
   using cbb_ptr = const BasicBlock*;
@@ -227,9 +228,12 @@
   // Return id for 32-bit unsigned type
   uint32_t GetUintId();
 
-  // Return id for 32-bit unsigned type
+  // Return id for 64-bit unsigned type
   uint32_t GetUint64Id();
 
+  // Return id for 8-bit unsigned type
+  uint32_t GetUint8Id();
+
   // Return id for 32-bit unsigned type
   uint32_t GetBoolId();
 
@@ -267,6 +271,9 @@
   // Return id for debug input buffer
   uint32_t GetInputBufferId();
 
+  // Return id for 32-bit float type
+  uint32_t GetFloatId();
+
   // Return id for v4float type
   uint32_t GetVec4FloatId();
 
@@ -383,17 +390,17 @@
   uint32_t input_buffer_ptr_id_;
 
   // id for debug output function
-  uint32_t output_func_id_;
+  std::unordered_map<uint32_t, uint32_t> param2output_func_id_;
 
   // ids for debug input functions
   std::unordered_map<uint32_t, uint32_t> param2input_func_id_;
 
-  // param count for output function
-  uint32_t output_func_param_cnt_;
-
   // id for input buffer variable
   uint32_t input_buffer_id_;
 
+  // id for 32-bit float type
+  uint32_t float_id_;
+
   // id for v4float type
   uint32_t v4float_id_;
 
@@ -406,9 +413,12 @@
   // id for 32-bit unsigned type
   uint32_t uint_id_;
 
-  // id for 32-bit unsigned type
+  // id for 64-bit unsigned type
   uint32_t uint64_id_;
 
+  // id for 8-bit unsigned type
+  uint32_t uint8_id_;
+
   // id for bool type
   uint32_t bool_id_;
 
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index 241aa75..6e271f5 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -425,6 +425,8 @@
     RegisterPass(CreateConvertRelaxedToHalfPass());
   } else if (pass_name == "relax-float-ops") {
     RegisterPass(CreateRelaxFloatOpsPass());
+  } else if (pass_name == "inst-debug-printf") {
+    RegisterPass(CreateInstDebugPrintfPass(7, 23));
   } else if (pass_name == "simplify-instructions") {
     RegisterPass(CreateSimplificationPass());
   } else if (pass_name == "ssa-rewrite") {
@@ -886,6 +888,12 @@
                                              input_init_enable, version));
 }
 
+Optimizer::PassToken CreateInstDebugPrintfPass(uint32_t desc_set,
+                                               uint32_t shader_id) {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::InstDebugPrintfPass>(desc_set, shader_id));
+}
+
 Optimizer::PassToken CreateInstBuffAddrCheckPass(uint32_t desc_set,
                                                  uint32_t shader_id,
                                                  uint32_t version) {
diff --git a/source/opt/passes.h b/source/opt/passes.h
index 1a3675c..5b4ab89 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -46,6 +46,7 @@
 #include "source/opt/inline_opaque_pass.h"
 #include "source/opt/inst_bindless_check_pass.h"
 #include "source/opt/inst_buff_addr_check_pass.h"
+#include "source/opt/inst_debug_printf_pass.h"
 #include "source/opt/legalize_vector_shuffle_pass.h"
 #include "source/opt/licm_pass.h"
 #include "source/opt/local_access_chain_convert_pass.h"
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index 327f265..3954338 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -55,6 +55,7 @@
        insert_extract_elim_test.cpp
        inst_bindless_check_test.cpp
        inst_buff_addr_check_test.cpp
+       inst_debug_printf_test.cpp
        instruction_list_test.cpp
        instruction_test.cpp
        ir_builder.cpp
diff --git a/test/opt/inst_debug_printf_test.cpp b/test/opt/inst_debug_printf_test.cpp
new file mode 100644
index 0000000..8123ffb
--- /dev/null
+++ b/test/opt/inst_debug_printf_test.cpp
@@ -0,0 +1,215 @@
+// Copyright (c) 2020 Valve Corporation
+// Copyright (c) 2020 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Debug Printf Instrumentation Tests.
+
+#include <string>
+#include <vector>
+
+#include "test/opt/assembly_builder.h"
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using InstDebugPrintfTest = PassTest<::testing::Test>;
+
+TEST_F(InstDebugPrintfTest, V4Float32) {
+  // SamplerState g_sDefault;
+  // Texture2D g_tColor;
+  //
+  // struct PS_INPUT
+  // {
+  //   float2 vBaseTexCoord : TEXCOORD0;
+  // };
+  //
+  // struct PS_OUTPUT
+  // {
+  //   float4 vDiffuse : SV_Target0;
+  // };
+  //
+  // PS_OUTPUT MainPs(PS_INPUT i)
+  // {
+  //   PS_OUTPUT o;
+  //
+  //   o.vDiffuse.rgba = g_tColor.Sample(g_sDefault, (i.vBaseTexCoord.xy).xy);
+  //   debugPrintfEXT("diffuse: %v4f", o.vDiffuse.rgba);
+  //   return o;
+  // }
+
+  const std::string defs =
+      R"(OpCapability Shader
+OpExtension "SPV_KHR_non_semantic_info"
+%1 = OpExtInstImport "NonSemantic.DebugPrintf"
+; CHECK-NOT: OpExtension "SPV_KHR_non_semantic_info"
+; CHECK-NOT: %1 = OpExtInstImport "NonSemantic.DebugPrintf"
+; CHECK: OpExtension "SPV_KHR_storage_buffer_storage_class"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "MainPs" %3 %4
+; CHECK: OpEntryPoint Fragment %2 "MainPs" %3 %4 %gl_FragCoord
+OpExecutionMode %2 OriginUpperLeft
+%5 = OpString "Color is %vn"
+)";
+
+  const std::string decorates =
+      R"(OpDecorate %6 DescriptorSet 0
+OpDecorate %6 Binding 1
+OpDecorate %7 DescriptorSet 0
+OpDecorate %7 Binding 0
+OpDecorate %3 Location 0
+OpDecorate %4 Location 0
+; CHECK: OpDecorate %_runtimearr_uint ArrayStride 4
+; CHECK: OpDecorate %_struct_47 Block
+; CHECK: OpMemberDecorate %_struct_47 0 Offset 0
+; CHECK: OpMemberDecorate %_struct_47 1 Offset 4
+; CHECK: OpDecorate %49 DescriptorSet 7
+; CHECK: OpDecorate %49 Binding 3
+; CHECK: OpDecorate %gl_FragCoord BuiltIn FragCoord
+)";
+
+  const std::string globals =
+      R"(%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v2float = OpTypeVector %float 2
+%v4float = OpTypeVector %float 4
+%13 = OpTypeImage %float 2D 0 0 0 1 Unknown
+%_ptr_UniformConstant_13 = OpTypePointer UniformConstant %13
+%6 = OpVariable %_ptr_UniformConstant_13 UniformConstant
+%15 = OpTypeSampler
+%_ptr_UniformConstant_15 = OpTypePointer UniformConstant %15
+%7 = OpVariable %_ptr_UniformConstant_15 UniformConstant
+%17 = OpTypeSampledImage %13
+%_ptr_Input_v2float = OpTypePointer Input %v2float
+%3 = OpVariable %_ptr_Input_v2float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%4 = OpVariable %_ptr_Output_v4float Output
+; CHECK: %uint = OpTypeInt 32 0
+; CHECK: %38 = OpTypeFunction %void %uint %uint %uint %uint %uint %uint
+; CHECK: %_runtimearr_uint = OpTypeRuntimeArray %uint
+; CHECK: %_struct_47 = OpTypeStruct %uint %_runtimearr_uint
+; CHECK: %_ptr_StorageBuffer__struct_47 = OpTypePointer StorageBuffer %_struct_47
+; CHECK: %49 = OpVariable %_ptr_StorageBuffer__struct_47 StorageBuffer
+; CHECK: %_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+; CHECK: %bool = OpTypeBool
+; CHECK: %_ptr_Input_v4float = OpTypePointer Input %v4float
+; CHECK: %gl_FragCoord = OpVariable %_ptr_Input_v4float Input
+; CHECK: %v4uint = OpTypeVector %uint 4
+)";
+
+  const std::string main =
+      R"(%2 = OpFunction %void None %9
+%20 = OpLabel
+%21 = OpLoad %v2float %3
+%22 = OpLoad %13 %6
+%23 = OpLoad %15 %7
+%24 = OpSampledImage %17 %22 %23
+%25 = OpImageSampleImplicitLod %v4float %24 %21
+%26 = OpExtInst %void %1 1 %5 %25
+; CHECK-NOT: %26 = OpExtInst %void %1 1 %5 %25
+; CHECK: %29 = OpCompositeExtract %float %25 0
+; CHECK: %30 = OpBitcast %uint %29
+; CHECK: %31 = OpCompositeExtract %float %25 1
+; CHECK: %32 = OpBitcast %uint %31
+; CHECK: %33 = OpCompositeExtract %float %25 2
+; CHECK: %34 = OpBitcast %uint %33
+; CHECK: %35 = OpCompositeExtract %float %25 3
+; CHECK: %36 = OpBitcast %uint %35
+; CHECK: %101 = OpFunctionCall %void %37 %uint_36 %uint_5 %30 %32 %34 %36
+; CHECK: OpBranch %102
+; CHECK: %102 = OpLabel
+OpStore %4 %25
+OpReturn
+OpFunctionEnd
+)";
+
+  const std::string output_func =
+      R"(; CHECK: %37 = OpFunction %void None %38
+; CHECK: %39 = OpFunctionParameter %uint
+; CHECK: %40 = OpFunctionParameter %uint
+; CHECK: %41 = OpFunctionParameter %uint
+; CHECK: %42 = OpFunctionParameter %uint
+; CHECK: %43 = OpFunctionParameter %uint
+; CHECK: %44 = OpFunctionParameter %uint
+; CHECK: %45 = OpLabel
+; CHECK: %52 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_0
+; CHECK: %55 = OpAtomicIAdd %uint %52 %uint_4 %uint_0 %uint_12
+; CHECK: %56 = OpIAdd %uint %55 %uint_12
+; CHECK: %57 = OpArrayLength %uint %49 1
+; CHECK: %59 = OpULessThanEqual %bool %56 %57
+; CHECK: OpSelectionMerge %60 None
+; CHECK: OpBranchConditional %59 %61 %60
+; CHECK: %61 = OpLabel
+; CHECK: %62 = OpIAdd %uint %55 %uint_0
+; CHECK: %64 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %62
+; CHECK: OpStore %64 %uint_12
+; CHECK: %66 = OpIAdd %uint %55 %uint_1
+; CHECK: %67 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %66
+; CHECK: OpStore %67 %uint_23
+; CHECK: %69 = OpIAdd %uint %55 %uint_2
+; CHECK: %70 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %69
+; CHECK: OpStore %70 %39
+; CHECK: %72 = OpIAdd %uint %55 %uint_3
+; CHECK: %73 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %72
+; CHECK: OpStore %73 %uint_4
+; CHECK: %76 = OpLoad %v4float %gl_FragCoord
+; CHECK: %78 = OpBitcast %v4uint %76
+; CHECK: %79 = OpCompositeExtract %uint %78 0
+; CHECK: %80 = OpIAdd %uint %55 %uint_4
+; CHECK: %81 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %80
+; CHECK: OpStore %81 %79
+; CHECK: %82 = OpCompositeExtract %uint %78 1
+; CHECK: %83 = OpIAdd %uint %55 %uint_5
+; CHECK: %84 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %83
+; CHECK: OpStore %84 %82
+; CHECK: %86 = OpIAdd %uint %55 %uint_7
+; CHECK: %87 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %86
+; CHECK: OpStore %87 %40
+; CHECK: %89 = OpIAdd %uint %55 %uint_8
+; CHECK: %90 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %89
+; CHECK: OpStore %90 %41
+; CHECK: %92 = OpIAdd %uint %55 %uint_9
+; CHECK: %93 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %92
+; CHECK: OpStore %93 %42
+; CHECK: %95 = OpIAdd %uint %55 %uint_10
+; CHECK: %96 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %95
+; CHECK: OpStore %96 %43
+; CHECK: %98 = OpIAdd %uint %55 %uint_11
+; CHECK: %99 = OpAccessChain %_ptr_StorageBuffer_uint %49 %uint_1 %98
+; CHECK: OpStore %99 %44
+; CHECK: OpBranch %60
+; CHECK: %60 = OpLabel
+; CHECK: OpReturn
+; CHECK: OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndMatch<InstDebugPrintfPass>(
+      defs + decorates + globals + main + output_func, true);
+}
+
+// TODO(greg-lunarg): Add tests to verify handling of these cases:
+//
+//   Compute shader
+//   Geometry shader
+//   Tessellation control shader
+//   Tessellation eval shader
+//   Vertex shader
+
+}  // namespace
+}  // namespace opt
+}  // namespace spvtools