Replace uses of SPV_AMD_shader_trinary_minmax extension (#2835)

Part of #2814
diff --git a/source/opt/amd_ext_to_khr.cpp b/source/opt/amd_ext_to_khr.cpp
index 1cb5ba5..3a49858 100644
--- a/source/opt/amd_ext_to_khr.cpp
+++ b/source/opt/amd_ext_to_khr.cpp
@@ -24,22 +24,117 @@
 
 namespace {
 
-enum ExtOpcodes {
+enum AmdShaderBallotExtOpcodes {
   AmdShaderBallotSwizzleInvocationsAMD = 1,
   AmdShaderBallotSwizzleInvocationsMaskedAMD = 2,
   AmdShaderBallotWriteInvocationAMD = 3,
   AmdShaderBallotMbcntAMD = 4
 };
 
+enum AmdShaderTrinaryMinMaxExtOpCodes {
+  FMin3AMD = 1,
+  UMin3AMD = 2,
+  SMin3AMD = 3,
+  FMax3AMD = 4,
+  UMax3AMD = 5,
+  SMax3AMD = 6,
+  FMid3AMD = 7,
+  UMid3AMD = 8,
+  SMid3AMD = 9
+};
+
 analysis::Type* GetUIntType(IRContext* ctx) {
   analysis::Integer int_type(32, false);
   return ctx->get_type_mgr()->GetRegisteredType(&int_type);
 }
 
+// Returns a folding rule that replaces |op(a,b,c)| by |op(op(a,b),c)|, where
+// |op| is either min or max. |opcode| is the binary opcode in the GLSLstd450
+// extended instruction set that corresponds to the trinary instruction being
+// replaced.
+template <GLSLstd450 opcode>
+bool ReplaceTrinaryMinMax(IRContext* ctx, Instruction* inst,
+                          const std::vector<const analysis::Constant*>&) {
+  uint32_t glsl405_ext_inst_id =
+      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+  if (glsl405_ext_inst_id == 0) {
+    ctx->AddExtInstImport("GLSL.std.450");
+    glsl405_ext_inst_id =
+        ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+  }
+
+  InstructionBuilder ir_builder(
+      ctx, inst,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+  uint32_t op1 = inst->GetSingleWordInOperand(2);
+  uint32_t op2 = inst->GetSingleWordInOperand(3);
+  uint32_t op3 = inst->GetSingleWordInOperand(4);
+
+  Instruction* temp = ir_builder.AddNaryExtendedInstruction(
+      inst->type_id(), glsl405_ext_inst_id, opcode, {op1, op2});
+
+  Instruction::OperandList new_operands;
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
+  new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
+                          {static_cast<uint32_t>(opcode)}});
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {temp->result_id()}});
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {op3}});
+
+  inst->SetInOperands(std::move(new_operands));
+  ctx->UpdateDefUse(inst);
+  return true;
+}
+
+// Returns a folding rule that replaces |mid(a,b,c)| by |clamp(a, min(b,c),
+// max(b,c)|. The three parameters are the opcode that correspond to the min,
+// max, and clamp operations for the type of the instruction being replaced.
+template <GLSLstd450 min_opcode, GLSLstd450 max_opcode, GLSLstd450 clamp_opcode>
+bool ReplaceTrinaryMid(IRContext* ctx, Instruction* inst,
+                       const std::vector<const analysis::Constant*>&) {
+  uint32_t glsl405_ext_inst_id =
+      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+  if (glsl405_ext_inst_id == 0) {
+    ctx->AddExtInstImport("GLSL.std.450");
+    glsl405_ext_inst_id =
+        ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+  }
+
+  InstructionBuilder ir_builder(
+      ctx, inst,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+  uint32_t op1 = inst->GetSingleWordInOperand(2);
+  uint32_t op2 = inst->GetSingleWordInOperand(3);
+  uint32_t op3 = inst->GetSingleWordInOperand(4);
+
+  Instruction* min = ir_builder.AddNaryExtendedInstruction(
+      inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(min_opcode),
+      {op2, op3});
+  Instruction* max = ir_builder.AddNaryExtendedInstruction(
+      inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(max_opcode),
+      {op2, op3});
+
+  Instruction::OperandList new_operands;
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
+  new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
+                          {static_cast<uint32_t>(clamp_opcode)}});
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {op1}});
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {min->result_id()}});
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {max->result_id()}});
+
+  inst->SetInOperands(std::move(new_operands));
+  ctx->UpdateDefUse(inst);
+  return true;
+}
+
 // Returns a folding rule that will replace the opcode with |opcode| and add
 // the capabilities required.  The folding rule assumes it is folding an
 // OpGroup*NonUniformAMD instruction from the SPV_AMD_shader_ballot extension.
-FoldingRule ReplaceGroupNonuniformOperationOpCode(SpvOp new_opcode) {
+template <SpvOp new_opcode>
+bool ReplaceGroupNonuniformOperationOpCode(
+    IRContext* ctx, Instruction* inst,
+    const std::vector<const analysis::Constant*>&) {
   switch (new_opcode) {
     case SpvOpGroupNonUniformIAdd:
     case SpvOpGroupNonUniformFAdd:
@@ -56,27 +151,24 @@
           "Should be replacing with a group non uniform arithmetic operation.");
   }
 
-  return [new_opcode](IRContext* ctx, Instruction* inst,
-                      const std::vector<const analysis::Constant*>&) {
-    switch (inst->opcode()) {
-      case SpvOpGroupIAddNonUniformAMD:
-      case SpvOpGroupFAddNonUniformAMD:
-      case SpvOpGroupUMinNonUniformAMD:
-      case SpvOpGroupSMinNonUniformAMD:
-      case SpvOpGroupFMinNonUniformAMD:
-      case SpvOpGroupUMaxNonUniformAMD:
-      case SpvOpGroupSMaxNonUniformAMD:
-      case SpvOpGroupFMaxNonUniformAMD:
-        break;
-      default:
-        assert(false &&
-               "Should be replacing a group non uniform arithmetic operation.");
-    }
+  switch (inst->opcode()) {
+    case SpvOpGroupIAddNonUniformAMD:
+    case SpvOpGroupFAddNonUniformAMD:
+    case SpvOpGroupUMinNonUniformAMD:
+    case SpvOpGroupSMinNonUniformAMD:
+    case SpvOpGroupFMinNonUniformAMD:
+    case SpvOpGroupUMaxNonUniformAMD:
+    case SpvOpGroupSMaxNonUniformAMD:
+    case SpvOpGroupFMaxNonUniformAMD:
+      break;
+    default:
+      assert(false &&
+             "Should be replacing a group non uniform arithmetic operation.");
+  }
 
-    ctx->AddCapability(SpvCapabilityGroupNonUniformArithmetic);
-    inst->SetOpcode(new_opcode);
-    return true;
-  };
+  ctx->AddCapability(SpvCapabilityGroupNonUniformArithmetic);
+  inst->SetOpcode(new_opcode);
+  return true;
 }
 
 // Returns a folding rule that will replace the SwizzleInvocationsAMD extended
@@ -112,84 +204,82 @@
 // clang-format on
 //
 // Also adding the capabilities and builtins that are needed.
-FoldingRule ReplaceSwizzleInvocations() {
-  return [](IRContext* ctx, Instruction* inst,
-            const std::vector<const analysis::Constant*>&) {
-    analysis::TypeManager* type_mgr = ctx->get_type_mgr();
-    analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
+bool ReplaceSwizzleInvocations(IRContext* ctx, Instruction* inst,
+                               const std::vector<const analysis::Constant*>&) {
+  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
+  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
 
-    ctx->AddExtension("SPV_KHR_shader_ballot");
-    ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
-    ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
+  ctx->AddExtension("SPV_KHR_shader_ballot");
+  ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
+  ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
 
-    InstructionBuilder ir_builder(
-        ctx, inst,
-        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+  InstructionBuilder ir_builder(
+      ctx, inst,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
 
-    uint32_t data_id = inst->GetSingleWordInOperand(2);
-    uint32_t offset_id = inst->GetSingleWordInOperand(3);
+  uint32_t data_id = inst->GetSingleWordInOperand(2);
+  uint32_t offset_id = inst->GetSingleWordInOperand(3);
 
-    // Get the subgroup invocation id.
-    uint32_t var_id =
-        ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
-    assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
-    Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
-    Instruction* var_ptr_type =
-        ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
-    uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
+  // Get the subgroup invocation id.
+  uint32_t var_id =
+      ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+  assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+  Instruction* var_ptr_type =
+      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+  uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
 
-    Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
+  Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
 
-    uint32_t quad_mask = ir_builder.GetUintConstantId(3);
+  uint32_t quad_mask = ir_builder.GetUintConstantId(3);
 
-    // This gives the offset in the group of 4 of this invocation.
-    Instruction* quad_idx = ir_builder.AddBinaryOp(
-        uint_type_id, SpvOpBitwiseAnd, id->result_id(), quad_mask);
+  // This gives the offset in the group of 4 of this invocation.
+  Instruction* quad_idx = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseAnd,
+                                                 id->result_id(), quad_mask);
 
-    // Get the invocation id of the first invocation in the group of 4.
-    Instruction* quad_ldr = ir_builder.AddBinaryOp(
-        uint_type_id, SpvOpBitwiseXor, id->result_id(), quad_idx->result_id());
+  // Get the invocation id of the first invocation in the group of 4.
+  Instruction* quad_ldr = ir_builder.AddBinaryOp(
+      uint_type_id, SpvOpBitwiseXor, id->result_id(), quad_idx->result_id());
 
-    // Get the offset of the target invocation from the offset vector.
-    Instruction* my_offset =
-        ir_builder.AddBinaryOp(uint_type_id, SpvOpVectorExtractDynamic,
-                               offset_id, quad_idx->result_id());
+  // Get the offset of the target invocation from the offset vector.
+  Instruction* my_offset =
+      ir_builder.AddBinaryOp(uint_type_id, SpvOpVectorExtractDynamic, offset_id,
+                             quad_idx->result_id());
 
-    // Determine the index of the invocation to read from.
-    Instruction* target_inv = ir_builder.AddBinaryOp(
-        uint_type_id, SpvOpIAdd, quad_ldr->result_id(), my_offset->result_id());
+  // Determine the index of the invocation to read from.
+  Instruction* target_inv = ir_builder.AddBinaryOp(
+      uint_type_id, SpvOpIAdd, quad_ldr->result_id(), my_offset->result_id());
 
-    // Do the group operations
-    uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
-    uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
-    const auto* ballot_value_const = const_mgr->GetConstant(
-        type_mgr->GetUIntVectorType(4),
-        {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
-    Instruction* ballot_value =
-        const_mgr->GetDefiningInstruction(ballot_value_const);
-    Instruction* is_active = ir_builder.AddNaryOp(
-        type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
-        {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
-    Instruction* shuffle = ir_builder.AddNaryOp(
-        inst->type_id(), SpvOpGroupNonUniformShuffle,
-        {subgroup_scope, data_id, target_inv->result_id()});
+  // Do the group operations
+  uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
+  uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
+  const auto* ballot_value_const = const_mgr->GetConstant(
+      type_mgr->GetUIntVectorType(4),
+      {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
+  Instruction* ballot_value =
+      const_mgr->GetDefiningInstruction(ballot_value_const);
+  Instruction* is_active = ir_builder.AddNaryOp(
+      type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
+      {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
+  Instruction* shuffle =
+      ir_builder.AddNaryOp(inst->type_id(), SpvOpGroupNonUniformShuffle,
+                           {subgroup_scope, data_id, target_inv->result_id()});
 
-    // Create the null constant to use in the select.
-    const auto* null = const_mgr->GetConstant(
-        type_mgr->GetType(inst->type_id()), std::vector<uint32_t>());
-    Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
+  // Create the null constant to use in the select.
+  const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
+                                            std::vector<uint32_t>());
+  Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
 
-    // Build the select.
-    inst->SetOpcode(SpvOpSelect);
-    Instruction::OperandList new_operands;
-    new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
-    new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
-    new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
+  // Build the select.
+  inst->SetOpcode(SpvOpSelect);
+  Instruction::OperandList new_operands;
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
 
-    inst->SetInOperands(std::move(new_operands));
-    ctx->UpdateDefUse(inst);
-    return true;
-  };
+  inst->SetInOperands(std::move(new_operands));
+  ctx->UpdateDefUse(inst);
+  return true;
 }
 
 // Returns a folding rule that will replace the SwizzleInvocationsMaskedAMD
@@ -225,89 +315,87 @@
 // clang-format on
 //
 // Also adding the capabilities and builtins that are needed.
-FoldingRule ReplaceSwizzleInvocationsMasked() {
-  return [](IRContext* ctx, Instruction* inst,
-            const std::vector<const analysis::Constant*>&) {
-    analysis::TypeManager* type_mgr = ctx->get_type_mgr();
-    analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
-    analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
+bool ReplaceSwizzleInvocationsMasked(
+    IRContext* ctx, Instruction* inst,
+    const std::vector<const analysis::Constant*>&) {
+  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
+  analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
+  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
 
-    // ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
-    ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
-    ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
+  // ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
+  ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
+  ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
 
-    InstructionBuilder ir_builder(
-        ctx, inst,
-        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+  InstructionBuilder ir_builder(
+      ctx, inst,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
 
-    // Get the operands to inst, and the components of the mask
-    uint32_t data_id = inst->GetSingleWordInOperand(2);
+  // Get the operands to inst, and the components of the mask
+  uint32_t data_id = inst->GetSingleWordInOperand(2);
 
-    Instruction* mask_inst =
-        def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
-    assert(mask_inst->opcode() == SpvOpConstantComposite &&
-           "The mask is suppose to be a vector constant.");
-    assert(mask_inst->NumInOperands() == 3 &&
-           "The mask is suppose to have 3 components.");
+  Instruction* mask_inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
+  assert(mask_inst->opcode() == SpvOpConstantComposite &&
+         "The mask is suppose to be a vector constant.");
+  assert(mask_inst->NumInOperands() == 3 &&
+         "The mask is suppose to have 3 components.");
 
-    uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
-    uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
-    uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);
+  uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
+  uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
+  uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);
 
-    // Get the subgroup invocation id.
-    uint32_t var_id =
-        ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
-    ctx->AddExtension("SPV_KHR_shader_ballot");
-    assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
-    Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
-    Instruction* var_ptr_type =
-        ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
-    uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
+  // Get the subgroup invocation id.
+  uint32_t var_id =
+      ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+  ctx->AddExtension("SPV_KHR_shader_ballot");
+  assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+  Instruction* var_ptr_type =
+      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+  uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
 
-    Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
+  Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
 
-    // Do the bitwise operations.
-    uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
-    Instruction* and_mask = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseOr,
-                                                   uint_x, mask_extended);
-    Instruction* and_result = ir_builder.AddBinaryOp(
-        uint_type_id, SpvOpBitwiseAnd, id->result_id(), and_mask->result_id());
-    Instruction* or_result = ir_builder.AddBinaryOp(
-        uint_type_id, SpvOpBitwiseOr, and_result->result_id(), uint_y);
-    Instruction* target_inv = ir_builder.AddBinaryOp(
-        uint_type_id, SpvOpBitwiseXor, or_result->result_id(), uint_z);
+  // Do the bitwise operations.
+  uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
+  Instruction* and_mask = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseOr,
+                                                 uint_x, mask_extended);
+  Instruction* and_result = ir_builder.AddBinaryOp(
+      uint_type_id, SpvOpBitwiseAnd, id->result_id(), and_mask->result_id());
+  Instruction* or_result = ir_builder.AddBinaryOp(
+      uint_type_id, SpvOpBitwiseOr, and_result->result_id(), uint_y);
+  Instruction* target_inv = ir_builder.AddBinaryOp(
+      uint_type_id, SpvOpBitwiseXor, or_result->result_id(), uint_z);
 
-    // Do the group operations
-    uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
-    uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
-    const auto* ballot_value_const = const_mgr->GetConstant(
-        type_mgr->GetUIntVectorType(4),
-        {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
-    Instruction* ballot_value =
-        const_mgr->GetDefiningInstruction(ballot_value_const);
-    Instruction* is_active = ir_builder.AddNaryOp(
-        type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
-        {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
-    Instruction* shuffle = ir_builder.AddNaryOp(
-        inst->type_id(), SpvOpGroupNonUniformShuffle,
-        {subgroup_scope, data_id, target_inv->result_id()});
+  // Do the group operations
+  uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
+  uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
+  const auto* ballot_value_const = const_mgr->GetConstant(
+      type_mgr->GetUIntVectorType(4),
+      {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
+  Instruction* ballot_value =
+      const_mgr->GetDefiningInstruction(ballot_value_const);
+  Instruction* is_active = ir_builder.AddNaryOp(
+      type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
+      {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
+  Instruction* shuffle =
+      ir_builder.AddNaryOp(inst->type_id(), SpvOpGroupNonUniformShuffle,
+                           {subgroup_scope, data_id, target_inv->result_id()});
 
-    // Create the null constant to use in the select.
-    const auto* null = const_mgr->GetConstant(
-        type_mgr->GetType(inst->type_id()), std::vector<uint32_t>());
-    Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
+  // Create the null constant to use in the select.
+  const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
+                                            std::vector<uint32_t>());
+  Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
 
-    // Build the select.
-    inst->SetOpcode(SpvOpSelect);
-    Instruction::OperandList new_operands;
-    new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
-    new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
-    new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
+  // Build the select.
+  inst->SetOpcode(SpvOpSelect);
+  Instruction::OperandList new_operands;
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
 
-    inst->SetInOperands(std::move(new_operands));
-    ctx->UpdateDefUse(inst);
-    return true;
-  };
+  inst->SetInOperands(std::move(new_operands));
+  ctx->UpdateDefUse(inst);
+  return true;
 }
 
 // Returns a folding rule that will replace the WriteInvocationAMD extended
@@ -326,40 +414,38 @@
 // %result = OpSelect %type %cmp %write_value %input_value
 //
 // Also adding the capabilities and builtins that are needed.
-FoldingRule ReplaceWriteInvocation() {
-  return [](IRContext* ctx, Instruction* inst,
-            const std::vector<const analysis::Constant*>&) {
-    uint32_t var_id =
-        ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
-    ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
-    ctx->AddExtension("SPV_KHR_shader_ballot");
-    assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
-    Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
-    Instruction* var_ptr_type =
-        ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+bool ReplaceWriteInvocation(IRContext* ctx, Instruction* inst,
+                            const std::vector<const analysis::Constant*>&) {
+  uint32_t var_id =
+      ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+  ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
+  ctx->AddExtension("SPV_KHR_shader_ballot");
+  assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+  Instruction* var_ptr_type =
+      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
 
-    InstructionBuilder ir_builder(
-        ctx, inst,
-        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
-    Instruction* t =
-        ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id);
-    analysis::Bool bool_type;
-    uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type);
-    Instruction* cmp =
-        ir_builder.AddBinaryOp(bool_type_id, SpvOpIEqual, t->result_id(),
-                               inst->GetSingleWordInOperand(4));
+  InstructionBuilder ir_builder(
+      ctx, inst,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+  Instruction* t =
+      ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id);
+  analysis::Bool bool_type;
+  uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type);
+  Instruction* cmp =
+      ir_builder.AddBinaryOp(bool_type_id, SpvOpIEqual, t->result_id(),
+                             inst->GetSingleWordInOperand(4));
 
-    // Build a select.
-    inst->SetOpcode(SpvOpSelect);
-    Instruction::OperandList new_operands;
-    new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}});
-    new_operands.push_back(inst->GetInOperand(3));
-    new_operands.push_back(inst->GetInOperand(2));
+  // Build a select.
+  inst->SetOpcode(SpvOpSelect);
+  Instruction::OperandList new_operands;
+  new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}});
+  new_operands.push_back(inst->GetInOperand(3));
+  new_operands.push_back(inst->GetInOperand(2));
 
-    inst->SetInOperands(std::move(new_operands));
-    ctx->UpdateDefUse(inst);
-    return true;
-  };
+  inst->SetInOperands(std::move(new_operands));
+  ctx->UpdateDefUse(inst);
+  return true;
 }
 
 // Returns a folding rule that will replace the MbcntAMD extended instruction in
@@ -384,51 +470,49 @@
 //  %result = OpBitCount %uint %and
 //
 // Also adding the capabilities and builtins that are needed.
-FoldingRule ReplaceMbcnt() {
-  return [](IRContext* context, Instruction* inst,
-            const std::vector<const analysis::Constant*>&) {
-    analysis::TypeManager* type_mgr = context->get_type_mgr();
-    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+bool ReplaceMbcnt(IRContext* context, Instruction* inst,
+                  const std::vector<const analysis::Constant*>&) {
+  analysis::TypeManager* type_mgr = context->get_type_mgr();
+  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
 
-    uint32_t var_id = context->GetBuiltinInputVarId(SpvBuiltInSubgroupLtMask);
-    assert(var_id != 0 && "Could not get SubgroupLtMask variable.");
-    context->AddCapability(SpvCapabilityGroupNonUniformBallot);
-    Instruction* var_inst = def_use_mgr->GetDef(var_id);
-    Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id());
-    Instruction* var_type =
-        def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1));
-    assert(var_type->opcode() == SpvOpTypeVector &&
-           "Variable is suppose to be a vector of 4 ints");
+  uint32_t var_id = context->GetBuiltinInputVarId(SpvBuiltInSubgroupLtMask);
+  assert(var_id != 0 && "Could not get SubgroupLtMask variable.");
+  context->AddCapability(SpvCapabilityGroupNonUniformBallot);
+  Instruction* var_inst = def_use_mgr->GetDef(var_id);
+  Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id());
+  Instruction* var_type =
+      def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1));
+  assert(var_type->opcode() == SpvOpTypeVector &&
+         "Variable is suppose to be a vector of 4 ints");
 
-    // Get the type for the shuffle.
-    analysis::Vector temp_type(GetUIntType(context), 2);
-    const analysis::Type* shuffle_type =
-        context->get_type_mgr()->GetRegisteredType(&temp_type);
-    uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type);
+  // Get the type for the shuffle.
+  analysis::Vector temp_type(GetUIntType(context), 2);
+  const analysis::Type* shuffle_type =
+      context->get_type_mgr()->GetRegisteredType(&temp_type);
+  uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type);
 
-    uint32_t mask_id = inst->GetSingleWordInOperand(2);
-    Instruction* mask_inst = def_use_mgr->GetDef(mask_id);
+  uint32_t mask_id = inst->GetSingleWordInOperand(2);
+  Instruction* mask_inst = def_use_mgr->GetDef(mask_id);
 
-    // Testing with amd's shader compiler shows that a 64-bit mask is expected.
-    assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr);
-    assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64);
+  // Testing with amd's shader compiler shows that a 64-bit mask is expected.
+  assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr);
+  assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64);
 
-    InstructionBuilder ir_builder(
-        context, inst,
-        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
-    Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id);
-    Instruction* shuffle = ir_builder.AddVectorShuffle(
-        shuffle_type_id, load->result_id(), load->result_id(), {0, 1});
-    Instruction* bitcast = ir_builder.AddUnaryOp(
-        mask_inst->type_id(), SpvOpBitcast, shuffle->result_id());
-    Instruction* t = ir_builder.AddBinaryOp(
-        mask_inst->type_id(), SpvOpBitwiseAnd, bitcast->result_id(), mask_id);
+  InstructionBuilder ir_builder(
+      context, inst,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+  Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id);
+  Instruction* shuffle = ir_builder.AddVectorShuffle(
+      shuffle_type_id, load->result_id(), load->result_id(), {0, 1});
+  Instruction* bitcast = ir_builder.AddUnaryOp(
+      mask_inst->type_id(), SpvOpBitcast, shuffle->result_id());
+  Instruction* t = ir_builder.AddBinaryOp(mask_inst->type_id(), SpvOpBitwiseAnd,
+                                          bitcast->result_id(), mask_id);
 
-    inst->SetOpcode(SpvOpBitCount);
-    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
-    context->UpdateDefUse(inst);
-    return true;
-  };
+  inst->SetOpcode(SpvOpBitCount);
+  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
+  context->UpdateDefUse(inst);
+  return true;
 }
 
 class AmdExtFoldingRules : public FoldingRules {
@@ -438,33 +522,59 @@
  protected:
   virtual void AddFoldingRules() override {
     rules_[SpvOpGroupIAddNonUniformAMD].push_back(
-        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformIAdd));
+        ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformIAdd>);
     rules_[SpvOpGroupFAddNonUniformAMD].push_back(
-        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFAdd));
+        ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFAdd>);
     rules_[SpvOpGroupUMinNonUniformAMD].push_back(
-        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformUMin));
+        ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformUMin>);
     rules_[SpvOpGroupSMinNonUniformAMD].push_back(
-        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformSMin));
+        ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformSMin>);
     rules_[SpvOpGroupFMinNonUniformAMD].push_back(
-        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFMin));
+        ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFMin>);
     rules_[SpvOpGroupUMaxNonUniformAMD].push_back(
-        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformUMax));
+        ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformUMax>);
     rules_[SpvOpGroupSMaxNonUniformAMD].push_back(
-        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformSMax));
+        ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformSMax>);
     rules_[SpvOpGroupFMaxNonUniformAMD].push_back(
-        ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFMax));
+        ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFMax>);
 
     uint32_t extension_id =
         context()->module()->GetExtInstImportId("SPV_AMD_shader_ballot");
 
-    ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}].push_back(
-        ReplaceSwizzleInvocations());
-    ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
-        .push_back(ReplaceSwizzleInvocationsMasked());
-    ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
-        ReplaceWriteInvocation());
-    ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
-        ReplaceMbcnt());
+    if (extension_id != 0) {
+      ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}]
+          .push_back(ReplaceSwizzleInvocations);
+      ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
+          .push_back(ReplaceSwizzleInvocationsMasked);
+      ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
+          ReplaceWriteInvocation);
+      ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
+          ReplaceMbcnt);
+    }
+
+    extension_id = context()->module()->GetExtInstImportId(
+        "SPV_AMD_shader_trinary_minmax");
+
+    if (extension_id != 0) {
+      ext_rules_[{extension_id, FMin3AMD}].push_back(
+          ReplaceTrinaryMinMax<GLSLstd450FMin>);
+      ext_rules_[{extension_id, UMin3AMD}].push_back(
+          ReplaceTrinaryMinMax<GLSLstd450UMin>);
+      ext_rules_[{extension_id, SMin3AMD}].push_back(
+          ReplaceTrinaryMinMax<GLSLstd450SMin>);
+      ext_rules_[{extension_id, FMax3AMD}].push_back(
+          ReplaceTrinaryMinMax<GLSLstd450FMax>);
+      ext_rules_[{extension_id, UMax3AMD}].push_back(
+          ReplaceTrinaryMinMax<GLSLstd450UMax>);
+      ext_rules_[{extension_id, SMax3AMD}].push_back(
+          ReplaceTrinaryMinMax<GLSLstd450SMax>);
+      ext_rules_[{extension_id, FMid3AMD}].push_back(
+          ReplaceTrinaryMid<GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp>);
+      ext_rules_[{extension_id, UMid3AMD}].push_back(
+          ReplaceTrinaryMid<GLSLstd450UMin, GLSLstd450UMax, GLSLstd450UClamp>);
+      ext_rules_[{extension_id, SMid3AMD}].push_back(
+          ReplaceTrinaryMid<GLSLstd450SMin, GLSLstd450SMax, GLSLstd450SClamp>);
+    }
   }
 };
 
@@ -500,9 +610,14 @@
   std::vector<Instruction*> to_be_killed;
   for (Instruction& inst : context()->module()->extensions()) {
     if (inst.opcode() == SpvOpExtension) {
-      if (!strcmp("SPV_AMD_shader_ballot",
-                  reinterpret_cast<const char*>(
-                      &(inst.GetInOperand(0).words[0])))) {
+      if (strcmp("SPV_AMD_shader_ballot",
+                 reinterpret_cast<const char*>(
+                     &(inst.GetInOperand(0).words[0]))) == 0) {
+        to_be_killed.push_back(&inst);
+      }
+      if (strcmp("SPV_AMD_shader_trinary_minmax",
+                 reinterpret_cast<const char*>(
+                     &(inst.GetInOperand(0).words[0]))) == 0) {
         to_be_killed.push_back(&inst);
       }
     }
@@ -510,9 +625,14 @@
 
   for (Instruction& inst : context()->ext_inst_imports()) {
     if (inst.opcode() == SpvOpExtInstImport) {
-      if (!strcmp("SPV_AMD_shader_ballot",
-                  reinterpret_cast<const char*>(
-                      &(inst.GetInOperand(0).words[0])))) {
+      if (strcmp("SPV_AMD_shader_ballot",
+                 reinterpret_cast<const char*>(
+                     &(inst.GetInOperand(0).words[0]))) == 0) {
+        to_be_killed.push_back(&inst);
+      }
+      if (strcmp("SPV_AMD_shader_trinary_minmax",
+                 reinterpret_cast<const char*>(
+                     &(inst.GetInOperand(0).words[0]))) == 0) {
         to_be_killed.push_back(&inst);
       }
     }
diff --git a/source/opt/feature_manager.h b/source/opt/feature_manager.h
index 761a208..2fe3291 100644
--- a/source/opt/feature_manager.h
+++ b/source/opt/feature_manager.h
@@ -57,6 +57,9 @@
   // Add the extension |ext| to the feature manager.
   void AddExtension(Instruction* ext);
 
+  // Analyzes |module| and records imported external instruction sets.
+  void AddExtInstImportIds(Module* module);
+
  private:
   // Analyzes |module| and records enabled extensions.
   void AddExtensions(Module* module);
@@ -64,9 +67,6 @@
   // Analyzes |module| and records enabled capabilities.
   void AddCapabilities(Module* module);
 
-  // Analyzes |module| and records imported external instruction sets.
-  void AddExtInstImportIds(Module* module);
-
   // Auxiliary object for querying SPIR-V grammar facts.
   const AssemblyGrammar& grammar_;
 
diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h
index b9cb26a..6720e89 100644
--- a/source/opt/ir_builder.h
+++ b/source/opt/ir_builder.h
@@ -492,6 +492,27 @@
     return AddInstruction(std::move(new_inst));
   }
 
+  Instruction* AddNaryExtendedInstruction(
+      uint32_t result_type, uint32_t set, uint32_t instruction,
+      const std::vector<uint32_t>& ext_operands) {
+    std::vector<Operand> operands;
+    operands.push_back({SPV_OPERAND_TYPE_ID, {set}});
+    operands.push_back(
+        {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, {instruction}});
+    for (uint32_t id : ext_operands) {
+      operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
+    }
+
+    uint32_t result_id = GetContext()->TakeNextId();
+    if (result_id == 0) {
+      return nullptr;
+    }
+
+    std::unique_ptr<Instruction> new_inst(new Instruction(
+        GetContext(), SpvOpExtInst, result_type, result_id, operands));
+    return AddInstruction(std::move(new_inst));
+  }
+
   // Inserts the new instruction before the insertion point.
   Instruction* AddInstruction(std::unique_ptr<Instruction>&& insn) {
     Instruction* insn_ptr = &*insert_before_.InsertBefore(std::move(insn));
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index e297fb1..3bbf180 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -199,6 +199,7 @@
   inline void AddExtension(const std::string& ext_name);
   inline void AddExtension(std::unique_ptr<Instruction>&& e);
   // Appends an extended instruction set instruction to this module.
+  inline void AddExtInstImport(const std::string& name);
   inline void AddExtInstImport(std::unique_ptr<Instruction>&& e);
   // Set the memory model for this module.
   inline void SetMemoryModel(std::unique_ptr<Instruction>&& m);
@@ -971,9 +972,26 @@
   module()->AddExtension(std::move(e));
 }
 
+void IRContext::AddExtInstImport(const std::string& name) {
+  const auto num_chars = name.size();
+  // Compute num words, accommodate the terminating null character.
+  const auto num_words = (num_chars + 1 + 3) / 4;
+  std::vector<uint32_t> ext_words(num_words, 0u);
+  std::memcpy(ext_words.data(), name.data(), num_chars);
+  AddExtInstImport(std::unique_ptr<Instruction>(
+      new Instruction(this, SpvOpExtInstImport, 0u, TakeNextId(),
+                      {{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}})));
+}
+
 void IRContext::AddExtInstImport(std::unique_ptr<Instruction>&& e) {
   AddCombinatorsForExtension(e.get());
+  if (AreAnalysesValid(kAnalysisDefUse)) {
+    get_def_use_mgr()->AnalyzeInstDefUse(e.get());
+  }
   module()->AddExtInstImport(std::move(e));
+  if (feature_mgr_ != nullptr) {
+    feature_mgr_->AddExtInstImportIds(module());
+  }
 }
 
 void IRContext::SetMemoryModel(std::unique_ptr<Instruction>&& m) {
diff --git a/test/opt/amd_ext_to_khr.cpp b/test/opt/amd_ext_to_khr.cpp
index 7a6d4b4..cdf168a 100644
--- a/test/opt/amd_ext_to_khr.cpp
+++ b/test/opt/amd_ext_to_khr.cpp
@@ -233,6 +233,7 @@
 
   SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
 }
+
 TEST_F(AmdExtToKhrTest, ReplaceWriteInvocationAMD) {
   const std::string text = R"(
 ; CHECK: OpCapability Shader
@@ -269,6 +270,438 @@
   SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
 }
 
+TEST_F(AmdExtToKhrTest, ReplaceFMin3AMD) {
+  const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeFloat 32
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] FMin [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] FMin [[temp]] [[z]]
+               OpCapability Shader
+               OpExtension "SPV_AMD_shader_trinary_minmax"
+        %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "func"
+               OpExecutionMode %1 OriginUpperLeft
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+      %float = OpTypeFloat 32
+     %uint_3 = OpConstant %uint 3
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+          %7 = OpUndef %float
+          %8 = OpUndef %float
+          %9 = OpUndef %float
+         %10 = OpExtInst %float %ext FMin3AMD %7 %8 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceSMin3AMD) {
+  const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 1
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] SMin [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] SMin [[temp]] [[z]]
+               OpCapability Shader
+               OpExtension "SPV_AMD_shader_trinary_minmax"
+        %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "func"
+               OpExecutionMode %1 OriginUpperLeft
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+       %int = OpTypeInt 32 1
+      %float = OpTypeFloat 32
+     %uint_3 = OpConstant %uint 3
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+          %7 = OpUndef %int
+          %8 = OpUndef %int
+          %9 = OpUndef %int
+         %10 = OpExtInst %int %ext SMin3AMD %7 %8 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceUMin3AMD) {
+  const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 0
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] UMin [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] UMin [[temp]] [[z]]
+               OpCapability Shader
+               OpExtension "SPV_AMD_shader_trinary_minmax"
+        %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "func"
+               OpExecutionMode %1 OriginUpperLeft
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+       %int = OpTypeInt 32 1
+      %float = OpTypeFloat 32
+     %uint_3 = OpConstant %uint 3
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+          %7 = OpUndef %uint
+          %8 = OpUndef %uint
+          %9 = OpUndef %uint
+         %10 = OpExtInst %uint %ext UMin3AMD %7 %8 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceFMax3AMD) {
+  const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeFloat 32
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] FMax [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] FMax [[temp]] [[z]]
+               OpCapability Shader
+               OpExtension "SPV_AMD_shader_trinary_minmax"
+        %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "func"
+               OpExecutionMode %1 OriginUpperLeft
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+      %float = OpTypeFloat 32
+     %uint_3 = OpConstant %uint 3
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+          %7 = OpUndef %float
+          %8 = OpUndef %float
+          %9 = OpUndef %float
+         %10 = OpExtInst %float %ext FMax3AMD %7 %8 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceSMax3AMD) {
+  const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 1
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] SMax [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] SMax [[temp]] [[z]]
+               OpCapability Shader
+               OpExtension "SPV_AMD_shader_trinary_minmax"
+        %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "func"
+               OpExecutionMode %1 OriginUpperLeft
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+       %int = OpTypeInt 32 1
+      %float = OpTypeFloat 32
+     %uint_3 = OpConstant %uint 3
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+          %7 = OpUndef %int
+          %8 = OpUndef %int
+          %9 = OpUndef %int
+         %10 = OpExtInst %int %ext SMax3AMD %7 %8 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceUMax3AMD) {
+  const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 0
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[temp]] [[z]]
+               OpCapability Shader
+               OpExtension "SPV_AMD_shader_trinary_minmax"
+        %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "func"
+               OpExecutionMode %1 OriginUpperLeft
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+       %int = OpTypeInt 32 1
+      %float = OpTypeFloat 32
+     %uint_3 = OpConstant %uint 3
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+          %7 = OpUndef %uint
+          %8 = OpUndef %uint
+          %9 = OpUndef %uint
+         %10 = OpExtInst %uint %ext UMax3AMD %7 %8 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceVecUMax3AMD) {
+  const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeVector
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[temp]] [[z]]
+               OpCapability Shader
+               OpExtension "SPV_AMD_shader_trinary_minmax"
+        %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "func"
+               OpExecutionMode %1 OriginUpperLeft
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+        %vec = OpTypeVector %uint 4
+       %int = OpTypeInt 32 1
+      %float = OpTypeFloat 32
+     %uint_3 = OpConstant %uint 3
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+          %7 = OpUndef %vec
+          %8 = OpUndef %vec
+          %9 = OpUndef %vec
+         %10 = OpExtInst %vec %ext UMax3AMD %7 %8 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceFMid3AMD) {
+  const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeFloat 32
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[min:%\w+]] = OpExtInst [[type]] [[ext]] FMin [[y]] [[z]]
+; CHECK-NEXT: [[max:%\w+]] = OpExtInst [[type]] [[ext]] FMax [[y]] [[z]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] FClamp [[x]] [[min]] [[max]]
+               OpCapability Shader
+               OpExtension "SPV_AMD_shader_trinary_minmax"
+        %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "func"
+               OpExecutionMode %1 OriginUpperLeft
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+      %float = OpTypeFloat 32
+     %uint_3 = OpConstant %uint 3
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+          %7 = OpUndef %float
+          %8 = OpUndef %float
+          %9 = OpUndef %float
+         %10 = OpExtInst %float %ext FMid3AMD %7 %8 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceSMid3AMD) {
+  const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 1
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[min:%\w+]] = OpExtInst [[type]] [[ext]] SMin [[y]] [[z]]
+; CHECK-NEXT: [[max:%\w+]] = OpExtInst [[type]] [[ext]] SMax [[y]] [[z]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] SClamp [[x]] [[min]] [[max]]
+               OpCapability Shader
+               OpExtension "SPV_AMD_shader_trinary_minmax"
+        %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "func"
+               OpExecutionMode %1 OriginUpperLeft
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+       %int = OpTypeInt 32 1
+      %float = OpTypeFloat 32
+     %uint_3 = OpConstant %uint 3
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+          %7 = OpUndef %int
+          %8 = OpUndef %int
+          %9 = OpUndef %int
+         %10 = OpExtInst %int %ext SMid3AMD %7 %8 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceUMid3AMD) {
+  const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 0
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[min:%\w+]] = OpExtInst [[type]] [[ext]] UMin [[y]] [[z]]
+; CHECK-NEXT: [[max:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[y]] [[z]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] UClamp [[x]] [[min]] [[max]]
+               OpCapability Shader
+               OpExtension "SPV_AMD_shader_trinary_minmax"
+        %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "func"
+               OpExecutionMode %1 OriginUpperLeft
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+       %int = OpTypeInt 32 1
+      %float = OpTypeFloat 32
+     %uint_3 = OpConstant %uint 3
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+          %7 = OpUndef %uint
+          %8 = OpUndef %uint
+          %9 = OpUndef %uint
+         %10 = OpExtInst %uint %ext UMid3AMD %7 %8 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceVecUMid3AMD) {
+  const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeVector
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[min:%\w+]] = OpExtInst [[type]] [[ext]] UMin [[y]] [[z]]
+; CHECK-NEXT: [[max:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[y]] [[z]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] UClamp [[x]] [[min]] [[max]]
+               OpCapability Shader
+               OpExtension "SPV_AMD_shader_trinary_minmax"
+        %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "func"
+               OpExecutionMode %1 OriginUpperLeft
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+       %vec = OpTypeVector %uint 3
+       %int = OpTypeInt 32 1
+      %float = OpTypeFloat 32
+     %uint_3 = OpConstant %uint 3
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+          %7 = OpUndef %vec
+          %8 = OpUndef %vec
+          %9 = OpUndef %vec
+         %10 = OpExtInst %vec %ext UMid3AMD %7 %8 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
 TEST_F(AmdExtToKhrTest, SetVersion) {
   const std::string text = R"(
                OpCapability Shader