Add interpolate legalization pass (#4220)

This pass converts an internal form of GLSLstd450 Interpolate ops
to the externally valid form. The external form takes the lvalue
of the interpolant. The internal form can do a load of the interpolant.
The pass replaces the load with its pointer. The internal form is
generated by glslang and possibly other frontends for HLSL shaders.
The new pass is called as part of HLSL legalization after all
propagation is complete.

Also adds internal interpolate form to pre-legalization validation
diff --git a/Android.mk b/Android.mk
index 1000f42..ef1cdff 100644
--- a/Android.mk
+++ b/Android.mk
@@ -122,6 +122,7 @@
 		source/opt/instruction.cpp \
 		source/opt/instruction_list.cpp \
 		source/opt/instrument_pass.cpp \
+		source/opt/interp_fixup_pass.cpp \
 		source/opt/ir_context.cpp \
 		source/opt/ir_loader.cpp \
 		source/opt/licm_pass.cpp \
diff --git a/BUILD.gn b/BUILD.gn
index 9f07c94..0f6bad1 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -609,6 +609,8 @@
     "source/opt/instruction_list.h",
     "source/opt/instrument_pass.cpp",
     "source/opt/instrument_pass.h",
+    "source/opt/interp_fixup_pass.cpp",
+    "source/opt/interp_fixup_pass.h",
     "source/opt/ir_builder.h",
     "source/opt/ir_context.cpp",
     "source/opt/ir_context.h",
diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h
index 2891cbe..8b30dcb 100644
--- a/include/spirv-tools/libspirv.h
+++ b/include/spirv-tools/libspirv.h
@@ -588,6 +588,8 @@
 // 3) Pointers that are actaul parameters on function calls do not have to point
 //    to the same type pointed as the formal parameter.  The types just need to
 //    logically match.
+// 4) GLSLstd450 Interpolate* instructions can have a load of an interpolant
+//    for a first argument.
 SPIRV_TOOLS_EXPORT void spvValidatorOptionsSetBeforeHlslLegalization(
     spv_validator_options options, bool val);
 
diff --git a/include/spirv-tools/libspirv.hpp b/include/spirv-tools/libspirv.hpp
index e7e7fc7..0c31a18 100644
--- a/include/spirv-tools/libspirv.hpp
+++ b/include/spirv-tools/libspirv.hpp
@@ -136,6 +136,8 @@
   // 3) Pointers that are actaul parameters on function calls do not have to
   //    point to the same type pointed as the formal parameter.  The types just
   //    need to logically match.
+  // 4) GLSLstd450 Interpolate* instructions can have a load of an interpolant
+  //    for a first argument.
   void SetBeforeHlslLegalization(bool val) {
     spvValidatorOptionsSetBeforeHlslLegalization(options_, val);
   }
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index 1683d07..e8b5b69 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -838,6 +838,15 @@
 // capabilities.
 Optimizer::PassToken CreateAmdExtToKhrPass();
 
+// Replaces the internal version of GLSLstd450 InterpolateAt* extended
+// instructions with the externally valid version. The internal version allows
+// an OpLoad of the interpolant for the first argument. This pass removes the
+// OpLoad and replaces it with its pointer. glslang and possibly other
+// frontends will create the internal version for HLSL. This pass will be part
+// of HLSL legalization and should be called after interpolants have been
+// propagated into their final positions.
+Optimizer::PassToken CreateInterpolateFixupPass();
+
 }  // namespace spvtools
 
 #endif  // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 14a6bee..88d5658 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -62,6 +62,7 @@
   instruction.h
   instruction_list.h
   instrument_pass.h
+  interp_fixup_pass.h
   ir_builder.h
   ir_context.h
   ir_loader.h
@@ -165,6 +166,7 @@
   instruction.cpp
   instruction_list.cpp
   instrument_pass.cpp
+  interp_fixup_pass.cpp
   ir_context.cpp
   ir_loader.cpp
   licm_pass.cpp
diff --git a/source/opt/interp_fixup_pass.cpp b/source/opt/interp_fixup_pass.cpp
new file mode 100644
index 0000000..ad29e6a
--- /dev/null
+++ b/source/opt/interp_fixup_pass.cpp
@@ -0,0 +1,131 @@
+// Copyright (c) 2021 The Khronos Group Inc.
+// Copyright (c) 2021 Valve Corporation
+// Copyright (c) 2021 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/interp_fixup_pass.h"
+
+#include <set>
+#include <string>
+
+#include "ir_builder.h"
+#include "source/opt/ir_context.h"
+#include "type_manager.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+
+// Input Operand Indices
+static const int kSpvVariableStorageClassInIdx = 0;
+
+// Avoid unused variable warning/error on Linux
+#ifndef NDEBUG
+#define USE_ASSERT(x) assert(x)
+#else
+#define USE_ASSERT(x) ((void)(x))
+#endif
+
+// Folding rule function which attempts to replace |op(OpLoad(a),...)|
+// by |op(a,...)|, where |op| is one of the GLSLstd450 InterpolateAt*
+// instructions. Returns true if replaced, false otherwise.
+bool ReplaceInternalInterpolate(IRContext* ctx, Instruction* inst,
+                                const std::vector<const analysis::Constant*>&) {
+  uint32_t glsl450_ext_inst_id =
+      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+  assert(glsl450_ext_inst_id != 0);
+
+  uint32_t ext_opcode = inst->GetSingleWordInOperand(1);
+
+  uint32_t op1_id = inst->GetSingleWordInOperand(2);
+
+  Instruction* load_inst = ctx->get_def_use_mgr()->GetDef(op1_id);
+  if (load_inst->opcode() != SpvOpLoad) return false;
+
+  Instruction* base_inst = load_inst->GetBaseAddress();
+  USE_ASSERT(base_inst->opcode() == SpvOpVariable &&
+             base_inst->GetSingleWordInOperand(kSpvVariableStorageClassInIdx) ==
+                 SpvStorageClassInput &&
+             "unexpected interpolant in InterpolateAt*");
+
+  uint32_t ptr_id = load_inst->GetSingleWordInOperand(0);
+  uint32_t op2_id = (ext_opcode != GLSLstd450InterpolateAtCentroid)
+                        ? inst->GetSingleWordInOperand(3)
+                        : 0;
+
+  Instruction::OperandList new_operands;
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl450_ext_inst_id}});
+  new_operands.push_back(
+      {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, {ext_opcode}});
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {ptr_id}});
+  if (op2_id != 0) new_operands.push_back({SPV_OPERAND_TYPE_ID, {op2_id}});
+
+  inst->SetInOperands(std::move(new_operands));
+  ctx->UpdateDefUse(inst);
+  return true;
+}
+
+class InterpFoldingRules : public FoldingRules {
+ public:
+  explicit InterpFoldingRules(IRContext* ctx) : FoldingRules(ctx) {}
+
+ protected:
+  virtual void AddFoldingRules() override {
+    uint32_t extension_id =
+        context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+
+    if (extension_id != 0) {
+      ext_rules_[{extension_id, GLSLstd450InterpolateAtCentroid}].push_back(
+          ReplaceInternalInterpolate);
+      ext_rules_[{extension_id, GLSLstd450InterpolateAtSample}].push_back(
+          ReplaceInternalInterpolate);
+      ext_rules_[{extension_id, GLSLstd450InterpolateAtOffset}].push_back(
+          ReplaceInternalInterpolate);
+    }
+  }
+};
+
+class InterpConstFoldingRules : public ConstantFoldingRules {
+ public:
+  InterpConstFoldingRules(IRContext* ctx) : ConstantFoldingRules(ctx) {}
+
+ protected:
+  virtual void AddFoldingRules() override {}
+};
+
+}  // namespace
+
+Pass::Status InterpFixupPass::Process() {
+  bool changed = false;
+
+  // Traverse the body of the functions to replace instructions that require
+  // the extensions.
+  InstructionFolder folder(
+      context(),
+      std::unique_ptr<InterpFoldingRules>(new InterpFoldingRules(context())),
+      MakeUnique<InterpConstFoldingRules>(context()));
+  for (Function& func : *get_module()) {
+    func.ForEachInst([&changed, &folder](Instruction* inst) {
+      if (folder.FoldInstruction(inst)) {
+        changed = true;
+      }
+    });
+  }
+
+  return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/interp_fixup_pass.h b/source/opt/interp_fixup_pass.h
new file mode 100644
index 0000000..e112b65
--- /dev/null
+++ b/source/opt/interp_fixup_pass.h
@@ -0,0 +1,54 @@
+// Copyright (c) 2021 The Khronos Group Inc.
+// Copyright (c) 2021 Valve Corporation
+// Copyright (c) 2021 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_INTERP_FIXUP_H
+#define SOURCE_OPT_INTERP_FIXUP_H
+
+#include "source/opt/ir_context.h"
+#include "source/opt/module.h"
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// Replaces overloaded internal form for GLSLstd450Interpolate* instructions
+// with external form. Specifically, removes OpLoad from the first argument
+// and replaces it with the pointer for the OpLoad. glslang generates the
+// internal form. This pass is called as part of glslang HLSL legalization.
+class InterpFixupPass : public Pass {
+ public:
+  const char* name() const override { return "interp-fixup"; }
+  Status Process() override;
+
+  IRContext::Analysis GetPreservedAnalyses() override {
+    return IRContext::kAnalysisInstrToBlockMapping |
+           IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators |
+           IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis |
+           IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap |
+           IRContext::kAnalysisScalarEvolution |
+           IRContext::kAnalysisRegisterPressure |
+           IRContext::kAnalysisValueNumberTable |
+           IRContext::kAnalysisStructuredCFG |
+           IRContext::kAnalysisBuiltinVarId |
+           IRContext::kAnalysisIdToFuncMapping | IRContext::kAnalysisTypes |
+           IRContext::kAnalysisDefUse | IRContext::kAnalysisConstants;
+  }
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_INTERP_FIXUP_H
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index 909442c..a5d10c3 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -155,7 +155,8 @@
           .RegisterPass(CreateVectorDCEPass())
           .RegisterPass(CreateDeadInsertElimPass())
           .RegisterPass(CreateReduceLoadSizePass())
-          .RegisterPass(CreateAggressiveDCEPass());
+          .RegisterPass(CreateAggressiveDCEPass())
+          .RegisterPass(CreateInterpolateFixupPass());
 }
 
 Optimizer& Optimizer::RegisterPerformancePasses() {
@@ -494,6 +495,8 @@
     RegisterPass(CreateWrapOpKillPass());
   } else if (pass_name == "amd-ext-to-khr") {
     RegisterPass(CreateAmdExtToKhrPass());
+  } else if (pass_name == "interpolate-fixup") {
+    RegisterPass(CreateInterpolateFixupPass());
   } else {
     Errorf(consumer(), nullptr, {},
            "Unknown flag '--%s'. Use --help for a list of valid flags",
@@ -925,4 +928,9 @@
       MakeUnique<opt::AmdExtensionToKhrPass>());
 }
 
+Optimizer::PassToken CreateInterpolateFixupPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::InterpFixupPass>());
+}
+
 }  // namespace spvtools
diff --git a/source/opt/passes.h b/source/opt/passes.h
index 1bc94c7..bfb34af 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -46,6 +46,7 @@
 #include "source/opt/inst_bindless_check_pass.h"
 #include "source/opt/inst_buff_addr_check_pass.h"
 #include "source/opt/inst_debug_printf_pass.h"
+#include "source/opt/interp_fixup_pass.h"
 #include "source/opt/licm_pass.h"
 #include "source/opt/local_access_chain_convert_pass.h"
 #include "source/opt/local_redundancy_elimination.h"
diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp
index dc8c024..a7167fc 100644
--- a/source/val/validate_extensions.cpp
+++ b/source/val/validate_extensions.cpp
@@ -692,8 +692,8 @@
     if (extension ==
         ExtensionToString(kSPV_KHR_workgroup_memory_explicit_layout)) {
       return _.diag(SPV_ERROR_WRONG_VERSION, inst)
-          << "SPV_KHR_workgroup_memory_explicit_layout extension "
-             "requires SPIR-V version 1.4 or later.";
+             << "SPV_KHR_workgroup_memory_explicit_layout extension "
+                "requires SPIR-V version 1.4 or later.";
     }
   }
 
@@ -1372,7 +1372,16 @@
                  << "or vector type";
         }
 
-        const uint32_t interpolant_type = _.GetOperandTypeId(inst, 4);
+        // If HLSL legalization and first operand is an OpLoad, use load
+        // pointer as the interpolant lvalue. Else use interpolate first
+        // operand.
+        uint32_t interp_id = inst->GetOperandAs<uint32_t>(4);
+        auto* interp_inst = _.FindDef(interp_id);
+        uint32_t interpolant_type = (_.options()->before_hlsl_legalization &&
+                                     interp_inst->opcode() == SpvOpLoad)
+                                        ? _.GetOperandTypeId(interp_inst, 2)
+                                        : _.GetOperandTypeId(inst, 4);
+
         uint32_t interpolant_storage_class = 0;
         uint32_t interpolant_data_type = 0;
         if (!_.GetPointerTypeInfo(interpolant_type, &interpolant_data_type,
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index 79cb3fc..f65d2ff 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -57,6 +57,7 @@
        inst_debug_printf_test.cpp
        instruction_list_test.cpp
        instruction_test.cpp
+       interp_fixup_test.cpp
        ir_builder.cpp
        ir_context_test.cpp
        ir_loader_test.cpp
diff --git a/test/opt/interp_fixup_test.cpp b/test/opt/interp_fixup_test.cpp
new file mode 100644
index 0000000..a43a29c
--- /dev/null
+++ b/test/opt/interp_fixup_test.cpp
@@ -0,0 +1,172 @@
+// Copyright (c) 2021 The Khronos Group Inc.
+// Copyright (c) 2021 Valve Corporation
+// Copyright (c) 2021 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using InterpFixupTest = PassTest<::testing::Test>;
+
+using ::testing::HasSubstr;
+
+TEST_F(InterpFixupTest, FixInterpAtSample) {
+  const std::string text = R"(
+               OpCapability Shader
+               OpCapability InterpolationFunction
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %MainPs "MainPs" %i_vPositionOs %_entryPointOutput
+               OpExecutionMode %MainPs OriginUpperLeft
+               OpSource HLSL 500
+               OpName %MainPs "MainPs"
+               OpName %i_vPositionOs "i.vPositionOs"
+               OpName %_entryPointOutput "@entryPointOutput"
+               OpDecorate %i_vPositionOs Location 0
+               OpDecorate %_entryPointOutput Location 0
+       %void = OpTypeVoid
+          %6 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+    %float_0 = OpConstant %float 0
+         %10 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+     %uint_4 = OpConstant %uint 4
+       %bool = OpTypeBool
+        %int = OpTypeInt 32 1
+      %int_1 = OpConstant %int 1
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%i_vPositionOs = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%_entryPointOutput = OpVariable %_ptr_Output_v4float Output
+     %MainPs = OpFunction %void None %6
+         %19 = OpLabel
+         %20 = OpLoad %v4float %i_vPositionOs
+               OpBranch %21
+         %21 = OpLabel
+         %22 = OpPhi %v4float %10 %19 %23 %24
+         %25 = OpPhi %uint %uint_0 %19 %26 %24
+         %27 = OpULessThan %bool %25 %uint_4
+               OpLoopMerge %28 %24 None
+               OpBranchConditional %27 %24 %28
+         %24 = OpLabel
+         %29 = OpExtInst %v4float %1 InterpolateAtSample %20 %25
+;CHECK:  %29 = OpExtInst %v4float %1 InterpolateAtSample %i_vPositionOs %25
+         %30 = OpCompositeExtract %float %29 0
+         %31 = OpCompositeExtract %float %22 0
+         %32 = OpFAdd %float %31 %30
+         %23 = OpCompositeInsert %v4float %32 %22 0
+         %26 = OpIAdd %uint %25 %int_1
+               OpBranch %21
+         %28 = OpLabel
+               OpStore %_entryPointOutput %22
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<InterpFixupPass>(text, false);
+}
+
+TEST_F(InterpFixupTest, FixInterpAtCentroid) {
+  const std::string text = R"(
+               OpCapability Shader
+               OpCapability InterpolationFunction
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %MainPs "MainPs" %i_vPositionOs %_entryPointOutput
+               OpExecutionMode %MainPs OriginUpperLeft
+               OpSource HLSL 500
+               OpName %MainPs "MainPs"
+               OpName %i_vPositionOs "i.vPositionOs"
+               OpName %_entryPointOutput "@entryPointOutput"
+               OpDecorate %i_vPositionOs Location 0
+               OpDecorate %_entryPointOutput Location 0
+       %void = OpTypeVoid
+          %6 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+    %float_0 = OpConstant %float 0
+         %10 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%i_vPositionOs = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%_entryPointOutput = OpVariable %_ptr_Output_v4float Output
+     %MainPs = OpFunction %void None %6
+         %13 = OpLabel
+         %14 = OpLoad %v4float %i_vPositionOs
+         %15 = OpExtInst %v4float %1 InterpolateAtCentroid %14
+;CHECK:  %15 = OpExtInst %v4float %1 InterpolateAtCentroid %i_vPositionOs
+         %16 = OpCompositeExtract %float %15 0
+         %17 = OpCompositeInsert %v4float %16 %10 0
+               OpStore %_entryPointOutput %17
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<InterpFixupPass>(text, false);
+}
+
+TEST_F(InterpFixupTest, FixInterpAtOffset) {
+  const std::string text = R"(
+               OpCapability Shader
+               OpCapability InterpolationFunction
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %MainPs "MainPs" %i_vPositionOs %_entryPointOutput
+               OpExecutionMode %MainPs OriginUpperLeft
+               OpSource HLSL 500
+               OpName %MainPs "MainPs"
+               OpName %i_vPositionOs "i.vPositionOs"
+               OpName %_entryPointOutput "@entryPointOutput"
+               OpDecorate %i_vPositionOs Location 0
+               OpDecorate %_entryPointOutput Location 0
+       %void = OpTypeVoid
+          %6 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+    %float_0 = OpConstant %float 0
+         %10 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+    %v2float = OpTypeVector %float 2
+%float_0_0625 = OpConstant %float 0.0625
+         %13 = OpConstantComposite %v2float %float_0_0625 %float_0_0625
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%i_vPositionOs = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%_entryPointOutput = OpVariable %_ptr_Output_v4float Output
+     %MainPs = OpFunction %void None %6
+         %16 = OpLabel
+         %17 = OpLoad %v4float %i_vPositionOs
+         %18 = OpExtInst %v4float %1 InterpolateAtOffset %17 %13
+;CHECK:  %18 = OpExtInst %v4float %1 InterpolateAtOffset %i_vPositionOs %13
+         %19 = OpCompositeExtract %float %18 0
+         %20 = OpCompositeInsert %v4float %19 %10 0
+               OpStore %_entryPointOutput %20
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<InterpFixupPass>(text, false);
+}
+
+}  // namespace
+}  // namespace opt
+}  // namespace spvtools
diff --git a/test/val/val_ext_inst_test.cpp b/test/val/val_ext_inst_test.cpp
index 683a76f..b73ec34 100644
--- a/test/val/val_ext_inst_test.cpp
+++ b/test/val/val_ext_inst_test.cpp
@@ -5204,6 +5204,49 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateExtInst, GlslStd450InterpolateAtCentroidInternalSuccess) {
+  const std::string body = R"(
+%ld1  = OpLoad %f32 %f32_input
+%val1 = OpExtInst %f32 %extinst InterpolateAtCentroid %ld1
+%ld2  = OpLoad %f32vec2 %f32vec2_input
+%val2 = OpExtInst %f32vec2 %extinst InterpolateAtCentroid %ld2
+)";
+
+  CompileSuccessfully(
+      GenerateShaderCode(body, "OpCapability InterpolationFunction\n"));
+  getValidatorOptions()->before_hlsl_legalization = true;
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateExtInst, GlslStd450InterpolateAtCentroidInternalInvalidDataF32) {
+  const std::string body = R"(
+%ld1  = OpLoad %f32 %f32_input
+%val1 = OpExtInst %f32 %extinst InterpolateAtCentroid %ld1
+)";
+
+  CompileSuccessfully(
+      GenerateShaderCode(body, "OpCapability InterpolationFunction\n"));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 InterpolateAtCentroid: "
+                        "expected Interpolant to be a pointer"));
+}
+
+TEST_F(ValidateExtInst,
+       GlslStd450InterpolateAtCentroidInternalInvalidDataF32Vec2) {
+  const std::string body = R"(
+%ld2  = OpLoad %f32vec2 %f32vec2_input
+%val2 = OpExtInst %f32vec2 %extinst InterpolateAtCentroid %ld2
+)";
+
+  CompileSuccessfully(
+      GenerateShaderCode(body, "OpCapability InterpolationFunction\n"));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 InterpolateAtCentroid: "
+                        "expected Interpolant to be a pointer"));
+}
+
 TEST_F(ValidateExtInst, GlslStd450InterpolateAtCentroidNoCapability) {
   const std::string body = R"(
 %val1 = OpExtInst %f32 %extinst InterpolateAtCentroid %f32_input
@@ -5308,6 +5351,49 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleInternalSuccess) {
+  const std::string body = R"(
+%ld1  = OpLoad %f32 %f32_input
+%val1 = OpExtInst %f32 %extinst InterpolateAtSample %ld1 %u32_1
+%ld2  = OpLoad %f32vec2 %f32vec2_input
+%val2 = OpExtInst %f32vec2 %extinst InterpolateAtSample %ld2 %u32_1
+)";
+
+  CompileSuccessfully(
+      GenerateShaderCode(body, "OpCapability InterpolationFunction\n"));
+  getValidatorOptions()->before_hlsl_legalization = true;
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleInternalInvalidDataF32) {
+  const std::string body = R"(
+%ld1  = OpLoad %f32 %f32_input
+%val1 = OpExtInst %f32 %extinst InterpolateAtSample %ld1 %u32_1
+)";
+
+  CompileSuccessfully(
+      GenerateShaderCode(body, "OpCapability InterpolationFunction\n"));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 InterpolateAtSample: "
+                        "expected Interpolant to be a pointer"));
+}
+
+TEST_F(ValidateExtInst,
+       GlslStd450InterpolateAtSampleInternalInvalidDataF32Vec2) {
+  const std::string body = R"(
+%ld2  = OpLoad %f32vec2 %f32vec2_input
+%val2 = OpExtInst %f32vec2 %extinst InterpolateAtSample %ld2 %u32_1
+)";
+
+  CompileSuccessfully(
+      GenerateShaderCode(body, "OpCapability InterpolationFunction\n"));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 InterpolateAtSample: "
+                        "expected Interpolant to be a pointer"));
+}
+
 TEST_F(ValidateExtInst, GlslStd450InterpolateAtSampleNoCapability) {
   const std::string body = R"(
 %val1 = OpExtInst %f32 %extinst InterpolateAtSample %f32_input %u32_1
@@ -5438,6 +5524,48 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetInternalSuccess) {
+  const std::string body = R"(
+%ld1  = OpLoad %f32 %f32_input
+%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %ld1 %f32vec2_01
+%ld2  = OpLoad %f32vec2 %f32vec2_input
+%val2 = OpExtInst %f32vec2 %extinst InterpolateAtOffset %ld2 %f32vec2_01
+)";
+
+  CompileSuccessfully(
+      GenerateShaderCode(body, "OpCapability InterpolationFunction\n"));
+  getValidatorOptions()->before_hlsl_legalization = true;
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetInternalInvalidDataF32) {
+  const std::string body = R"(
+%ld1  = OpLoad %f32 %f32_input
+%val1 = OpExtInst %f32 %extinst InterpolateAtOffset %ld1 %f32vec2_01
+)";
+
+  CompileSuccessfully(
+      GenerateShaderCode(body, "OpCapability InterpolationFunction\n"));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 InterpolateAtOffset: "
+                        "expected Interpolant to be a pointer"));
+}
+
+TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetInternalInvalidDataF32Vec2) {
+  const std::string body = R"(
+%ld2  = OpLoad %f32vec2 %f32vec2_input
+%val2 = OpExtInst %f32vec2 %extinst InterpolateAtOffset %ld2 %f32vec2_01
+)";
+
+  CompileSuccessfully(
+      GenerateShaderCode(body, "OpCapability InterpolationFunction\n"));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("GLSL.std.450 InterpolateAtOffset: "
+                        "expected Interpolant to be a pointer"));
+}
+
 TEST_F(ValidateExtInst, GlslStd450InterpolateAtOffsetNoCapability) {
   const std::string body = R"(
 %val1 = OpExtInst %f32 %extinst InterpolateAtOffset %f32_input %f32vec2_01