Update Offset to ConstOffset bitmask if operand is constant. (#3024)

Update Offset to ConstOffset bitmask if operand is constant.

Fixes #3005
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index eea7238..de740ca 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -36,6 +36,45 @@
 const uint32_t kFMixAIdInIdx = 4;
 const uint32_t kStoreObjectInIdx = 1;
 
+// Some image instructions may contain an "image operands" argument.
+// Returns the operand index for the "image operands".
+// Returns -1 if the instruction does not have image operands.
+int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
+  const auto opcode = inst->opcode();
+  switch (opcode) {
+    case SpvOpImageSampleImplicitLod:
+    case SpvOpImageSampleExplicitLod:
+    case SpvOpImageSampleProjImplicitLod:
+    case SpvOpImageSampleProjExplicitLod:
+    case SpvOpImageFetch:
+    case SpvOpImageRead:
+    case SpvOpImageSparseSampleImplicitLod:
+    case SpvOpImageSparseSampleExplicitLod:
+    case SpvOpImageSparseSampleProjImplicitLod:
+    case SpvOpImageSparseSampleProjExplicitLod:
+    case SpvOpImageSparseFetch:
+    case SpvOpImageSparseRead:
+      return inst->NumOperands() > 4 ? 2 : -1;
+    case SpvOpImageSampleDrefImplicitLod:
+    case SpvOpImageSampleDrefExplicitLod:
+    case SpvOpImageSampleProjDrefImplicitLod:
+    case SpvOpImageSampleProjDrefExplicitLod:
+    case SpvOpImageGather:
+    case SpvOpImageDrefGather:
+    case SpvOpImageSparseSampleDrefImplicitLod:
+    case SpvOpImageSparseSampleDrefExplicitLod:
+    case SpvOpImageSparseSampleProjDrefImplicitLod:
+    case SpvOpImageSparseSampleProjDrefExplicitLod:
+    case SpvOpImageSparseGather:
+    case SpvOpImageSparseDrefGather:
+      return inst->NumOperands() > 5 ? 3 : -1;
+    case SpvOpImageWrite:
+      return inst->NumOperands() > 3 ? 3 : -1;
+    default:
+      return -1;
+  }
+}
+
 // Returns the element width of |type|.
 uint32_t ElementWidth(const analysis::Type* type) {
   if (const analysis::Vector* vec_type = type->AsVector()) {
@@ -2316,6 +2355,64 @@
   };
 }
 
+// If an image instruction's operand is a constant, updates the image operand
+// flag from Offset to ConstOffset.
+FoldingRule UpdateImageOperands() {
+  return [](IRContext*, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants) {
+    const auto opcode = inst->opcode();
+    (void)opcode;
+    assert((opcode == SpvOpImageSampleImplicitLod ||
+            opcode == SpvOpImageSampleExplicitLod ||
+            opcode == SpvOpImageSampleDrefImplicitLod ||
+            opcode == SpvOpImageSampleDrefExplicitLod ||
+            opcode == SpvOpImageSampleProjImplicitLod ||
+            opcode == SpvOpImageSampleProjExplicitLod ||
+            opcode == SpvOpImageSampleProjDrefImplicitLod ||
+            opcode == SpvOpImageSampleProjDrefExplicitLod ||
+            opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
+            opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
+            opcode == SpvOpImageWrite ||
+            opcode == SpvOpImageSparseSampleImplicitLod ||
+            opcode == SpvOpImageSparseSampleExplicitLod ||
+            opcode == SpvOpImageSparseSampleDrefImplicitLod ||
+            opcode == SpvOpImageSparseSampleDrefExplicitLod ||
+            opcode == SpvOpImageSparseSampleProjImplicitLod ||
+            opcode == SpvOpImageSparseSampleProjExplicitLod ||
+            opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
+            opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
+            opcode == SpvOpImageSparseFetch ||
+            opcode == SpvOpImageSparseGather ||
+            opcode == SpvOpImageSparseDrefGather ||
+            opcode == SpvOpImageSparseRead) &&
+           "Wrong opcode.  Should be an image instruction.");
+
+    int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
+    if (operand_index >= 0) {
+      auto image_operands = inst->GetSingleWordInOperand(operand_index);
+      if (image_operands & SpvImageOperandsOffsetMask) {
+        uint32_t offset_operand_index = operand_index + 1;
+        if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
+        if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
+        if (image_operands & SpvImageOperandsGradMask)
+          offset_operand_index += 2;
+        assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
+               "Offset and ConstOffset may not be used together");
+        if (offset_operand_index < inst->NumOperands()) {
+          if (constants[offset_operand_index]) {
+            image_operands = image_operands | SpvImageOperandsConstOffsetMask;
+            image_operands = image_operands & ~SpvImageOperandsOffsetMask;
+            inst->SetInOperand(operand_index, {image_operands});
+            return true;
+          }
+        }
+      }
+    }
+
+    return false;
+  };
+}
+
 }  // namespace
 
 void FoldingRules::AddFoldingRules() {
@@ -2392,6 +2489,38 @@
 
   rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
 
+  rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
+  rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
+  rules_[SpvOpImageGather].push_back(UpdateImageOperands());
+  rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
+  rules_[SpvOpImageRead].push_back(UpdateImageOperands());
+  rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
+      UpdateImageOperands());
+  rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
+      UpdateImageOperands());
+  rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
+      UpdateImageOperands());
+  rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
+      UpdateImageOperands());
+  rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
+      UpdateImageOperands());
+  rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
+      UpdateImageOperands());
+  rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
+  rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
+
   FeatureManager* feature_manager = context_->get_feature_mgr();
   // Add rules for GLSLstd450
   uint32_t ext_inst_glslstd450_id =
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index a6e75d8..68a7d18 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -7058,6 +7058,96 @@
                                  1, false)
 ));
 
+std::string ImageOperandsTestBody(const std::string& image_instruction) {
+  std::string body = R"(
+               OpCapability Shader
+               OpCapability ImageGatherExtended
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+               OpDecorate %Texture DescriptorSet 0
+               OpDecorate %Texture Binding 0
+        %int = OpTypeInt 32 1
+     %int_n1 = OpConstant %int -1
+          %5 = OpConstant %int 0
+      %float = OpTypeFloat 32
+    %float_0 = OpConstant %float 0
+%type_2d_image = OpTypeImage %float 2D 2 0 0 1 Unknown
+%type_sampled_image = OpTypeSampledImage %type_2d_image
+%type_sampler = OpTypeSampler
+%_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler
+%_ptr_UniformConstant_type_2d_image = OpTypePointer UniformConstant %type_2d_image
+   %_ptr_int = OpTypePointer Function %int
+      %v2int = OpTypeVector %int 2
+         %10 = OpTypeVector %float 4
+       %void = OpTypeVoid
+         %22 = OpTypeFunction %void
+    %v2float = OpTypeVector %float 2
+      %v3int = OpTypeVector %int 3
+    %Texture = OpVariable %_ptr_UniformConstant_type_2d_image UniformConstant
+   %gSampler = OpVariable %_ptr_UniformConstant_type_sampler UniformConstant
+        %101 = OpConstantComposite %v2int %int_n1 %int_n1
+         %20 = OpConstantComposite %v2float %float_0 %float_0
+       %main = OpFunction %void None %22
+         %23 = OpLabel
+        %var = OpVariable %_ptr_int Function
+         %88 = OpLoad %type_2d_image %Texture
+        %val = OpLoad %int %var
+    %sampler = OpLoad %type_sampler %gSampler
+         %26 = OpSampledImage %type_sampled_image %88 %sampler
+)" + image_instruction + R"(
+               OpReturn
+               OpFunctionEnd
+)";
+
+  return body;
+}
+
+INSTANTIATE_TEST_SUITE_P(ImageOperandsBitmaskFoldingTest, MatchingInstructionWithNoResultFoldingTest,
+::testing::Values(
+    // Test case 0: OpImageFetch without Offset
+    InstructionFoldingCase<bool>(ImageOperandsTestBody(
+        "%89 = OpImageFetch %10 %88 %101 Lod %5 \n")
+        , 89, false),
+    // Test case 1: OpImageFetch with non-const offset
+    InstructionFoldingCase<bool>(ImageOperandsTestBody(
+        "%89 = OpImageFetch %10 %88 %101 Lod|Offset %5 %val \n")
+        , 89, false),
+    // Test case 2: OpImageFetch with Lod and Offset
+    InstructionFoldingCase<bool>(ImageOperandsTestBody(
+      "         %89 = OpImageFetch %10 %88 %101 Lod|Offset %5 %101      \n"
+      "; CHECK: %89 = OpImageFetch %10 %88 %101 Lod|ConstOffset %5 %101 \n")
+      , 89, true),
+    // Test case 3: OpImageFetch with Bias and Offset
+    InstructionFoldingCase<bool>(ImageOperandsTestBody(
+      "         %89 = OpImageFetch %10 %88 %101 Bias|Offset %5 %101      \n"
+      "; CHECK: %89 = OpImageFetch %10 %88 %101 Bias|ConstOffset %5 %101 \n")
+      , 89, true),
+    // Test case 4: OpImageFetch with Grad and Offset.
+    // Grad adds 2 operands to the instruction.
+    InstructionFoldingCase<bool>(ImageOperandsTestBody(
+      "         %89 = OpImageFetch %10 %88 %101 Grad|Offset %5 %5 %101      \n"
+      "; CHECK: %89 = OpImageFetch %10 %88 %101 Grad|ConstOffset %5 %5 %101 \n")
+      , 89, true),
+    // Test case 5: OpImageFetch with Offset and MinLod.
+    // This is an example of a case where the bitmask bit-offset is larger than
+    // that of the Offset.
+    InstructionFoldingCase<bool>(ImageOperandsTestBody(
+      "         %89 = OpImageFetch %10 %88 %101 Offset|MinLod %101 %5      \n"
+      "; CHECK: %89 = OpImageFetch %10 %88 %101 ConstOffset|MinLod %101 %5 \n")
+      , 89, true),
+    // Test case 6: OpImageGather with constant Offset
+    InstructionFoldingCase<bool>(ImageOperandsTestBody(
+      "         %89 = OpImageGather %10 %26 %20 %5 Offset %101      \n"
+      "; CHECK: %89 = OpImageGather %10 %26 %20 %5 ConstOffset %101 \n")
+      , 89, true),
+    // Test case 7: OpImageWrite with constant Offset
+    InstructionFoldingCase<bool>(ImageOperandsTestBody(
+      "         OpImageWrite %88 %5 %101 Offset %101      \n"
+      "; CHECK: OpImageWrite %88 %5 %101 ConstOffset %101 \n")
+      , 0 /* No result-id */, true)
+));
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools