Don't move debug or decorations when folding (#2772)

Fixes #2764

* Don't replace all uses when simplifying instructions, instead only
update non-debug, non-decoration uses
  * added a test
* Add a new version of RAUW that takes a predicate to decide whether to
replace the use or not
  * used in simplification pass
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index 20309ca..b600f12 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -190,6 +190,13 @@
 }
 
 bool IRContext::ReplaceAllUsesWith(uint32_t before, uint32_t after) {
+  return ReplaceAllUsesWithPredicate(
+      before, after, [](Instruction*, uint32_t) { return true; });
+}
+
+bool IRContext::ReplaceAllUsesWithPredicate(
+    uint32_t before, uint32_t after,
+    const std::function<bool(Instruction*, uint32_t)>& predicate) {
   if (before == after) return false;
 
   // Ensure that |after| has been registered as def.
@@ -198,8 +205,10 @@
 
   std::vector<std::pair<Instruction*, uint32_t>> uses_to_update;
   get_def_use_mgr()->ForEachUse(
-      before, [&uses_to_update](Instruction* user, uint32_t index) {
-        uses_to_update.emplace_back(user, index);
+      before, [&predicate, &uses_to_update](Instruction* user, uint32_t index) {
+        if (predicate(user, index)) {
+          uses_to_update.emplace_back(user, index);
+        }
       });
 
   Instruction* prev = nullptr;
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index 37c6449..308f633 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -382,6 +382,15 @@
   // |before| and |after| must be registered definitions in the DefUseManager.
   bool ReplaceAllUsesWith(uint32_t before, uint32_t after);
 
+  // Replace all uses of |before| id with |after| id if those uses
+  // (instruction, operand pair) return true for |predicate|. Returns true if
+  // any replacement happens. This method does not kill the definition of the
+  // |before| id. If |after| is the same as |before|, does nothing and return
+  // false.
+  bool ReplaceAllUsesWithPredicate(
+      uint32_t before, uint32_t after,
+      const std::function<bool(Instruction*, uint32_t)>& predicate);
+
   // Returns true if all of the analyses that are suppose to be valid are
   // actually valid.
   bool IsConsistent();
diff --git a/source/opt/simplification_pass.cpp b/source/opt/simplification_pass.cpp
index 6ea4566..5780e5d 100644
--- a/source/opt/simplification_pass.cpp
+++ b/source/opt/simplification_pass.cpp
@@ -71,8 +71,16 @@
               }
             });
             if (inst->opcode() == SpvOpCopyObject) {
-              context()->ReplaceAllUsesWith(inst->result_id(),
-                                            inst->GetSingleWordInOperand(0));
+              context()->ReplaceAllUsesWithPredicate(
+                  inst->result_id(), inst->GetSingleWordInOperand(0),
+                  [](Instruction* user, uint32_t) {
+                    const auto opcode = user->opcode();
+                    if (!spvOpcodeIsDebug(opcode) &&
+                        !spvOpcodeIsDecoration(opcode)) {
+                      return true;
+                    }
+                    return false;
+                  });
               inst_to_kill.insert(inst);
               in_work_list.insert(inst);
             } else if (inst->opcode() == SpvOpNop) {
@@ -107,8 +115,15 @@
           });
 
       if (inst->opcode() == SpvOpCopyObject) {
-        context()->ReplaceAllUsesWith(inst->result_id(),
-                                      inst->GetSingleWordInOperand(0));
+        context()->ReplaceAllUsesWithPredicate(
+            inst->result_id(), inst->GetSingleWordInOperand(0),
+            [](Instruction* user, uint32_t) {
+              const auto opcode = user->opcode();
+              if (!spvOpcodeIsDebug(opcode) && !spvOpcodeIsDecoration(opcode)) {
+                return true;
+              }
+              return false;
+            });
         inst_to_kill.insert(inst);
         in_work_list.insert(inst);
       } else if (inst->opcode() == SpvOpNop) {
diff --git a/test/opt/simplification_test.cpp b/test/opt/simplification_test.cpp
index 4dbcfbe..1420498 100644
--- a/test/opt/simplification_test.cpp
+++ b/test/opt/simplification_test.cpp
@@ -279,6 +279,52 @@
   SinglePassRunAndCheck<SimplificationPass>(before, after, false);
 }
 
+TEST_F(SimplificationTest, DontMoveDecorations) {
+  const std::string spirv = R"(
+; CHECK-NOT: RelaxedPrecision
+; CHECK: [[sub:%\w+]] = OpFSub
+; CHECK: OpStore {{.*}} [[sub]]
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+OpDecorate %add RelaxedPrecision
+OpDecorate %block Block
+OpMemberDecorate %block 0 Offset 0
+OpMemberDecorate %block 1 Offset 4
+OpDecorate %in DescriptorSet 0
+OpDecorate %in Binding 0
+OpDecorate %out DescriptorSet 0
+OpDecorate %out Binding 1
+%void = OpTypeVoid
+%float = OpTypeFloat 32
+%void_fn = OpTypeFunction %void
+%block = OpTypeStruct %float %float
+%ptr_ssbo_block = OpTypePointer StorageBuffer %block
+%in = OpVariable %ptr_ssbo_block StorageBuffer
+%out = OpVariable %ptr_ssbo_block StorageBuffer
+%ptr_ssbo_float = OpTypePointer StorageBuffer %float
+%int = OpTypeInt 32 0
+%int_0 = OpConstant %int 0
+%int_1 = OpConstant %int 1
+%float_0 = OpConstant %float 0
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+%in_gep_0 = OpAccessChain %ptr_ssbo_float %in %int_0
+%in_gep_1 = OpAccessChain %ptr_ssbo_float %in %int_1
+%load_0 = OpLoad %float %in_gep_0
+%load_1 = OpLoad %float %in_gep_1
+%sub = OpFSub %float %load_0 %load_1
+%add = OpFAdd %float %float_0 %sub
+%out_gep_0 = OpAccessChain %ptr_ssbo_float %out %int_0
+OpStore %out_gep_0 %add
+OpReturn
+OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<SimplificationPass>(spirv, true);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools