Replace uses of SPV_AMD_shader_trinary_minmax extension (#2835)
Part of #2814
diff --git a/source/opt/amd_ext_to_khr.cpp b/source/opt/amd_ext_to_khr.cpp
index 1cb5ba5..3a49858 100644
--- a/source/opt/amd_ext_to_khr.cpp
+++ b/source/opt/amd_ext_to_khr.cpp
@@ -24,22 +24,117 @@
namespace {
-enum ExtOpcodes {
+enum AmdShaderBallotExtOpcodes {
AmdShaderBallotSwizzleInvocationsAMD = 1,
AmdShaderBallotSwizzleInvocationsMaskedAMD = 2,
AmdShaderBallotWriteInvocationAMD = 3,
AmdShaderBallotMbcntAMD = 4
};
+enum AmdShaderTrinaryMinMaxExtOpCodes {
+ FMin3AMD = 1,
+ UMin3AMD = 2,
+ SMin3AMD = 3,
+ FMax3AMD = 4,
+ UMax3AMD = 5,
+ SMax3AMD = 6,
+ FMid3AMD = 7,
+ UMid3AMD = 8,
+ SMid3AMD = 9
+};
+
analysis::Type* GetUIntType(IRContext* ctx) {
analysis::Integer int_type(32, false);
return ctx->get_type_mgr()->GetRegisteredType(&int_type);
}
+// Returns a folding rule that replaces |op(a,b,c)| by |op(op(a,b),c)|, where
+// |op| is either min or max. |opcode| is the binary opcode in the GLSLstd450
+// extended instruction set that corresponds to the trinary instruction being
+// replaced.
+template <GLSLstd450 opcode>
+bool ReplaceTrinaryMinMax(IRContext* ctx, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ uint32_t glsl405_ext_inst_id =
+ ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+ if (glsl405_ext_inst_id == 0) {
+ ctx->AddExtInstImport("GLSL.std.450");
+ glsl405_ext_inst_id =
+ ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+ }
+
+ InstructionBuilder ir_builder(
+ ctx, inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+ uint32_t op1 = inst->GetSingleWordInOperand(2);
+ uint32_t op2 = inst->GetSingleWordInOperand(3);
+ uint32_t op3 = inst->GetSingleWordInOperand(4);
+
+ Instruction* temp = ir_builder.AddNaryExtendedInstruction(
+ inst->type_id(), glsl405_ext_inst_id, opcode, {op1, op2});
+
+ Instruction::OperandList new_operands;
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
+ new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
+ {static_cast<uint32_t>(opcode)}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {temp->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {op3}});
+
+ inst->SetInOperands(std::move(new_operands));
+ ctx->UpdateDefUse(inst);
+ return true;
+}
+
+// Returns a folding rule that replaces |mid(a,b,c)| by |clamp(a, min(b,c),
+// max(b,c))|. The three parameters are the opcodes that correspond to the min,
+// max, and clamp operations for the type of the instruction being replaced.
+template <GLSLstd450 min_opcode, GLSLstd450 max_opcode, GLSLstd450 clamp_opcode>
+bool ReplaceTrinaryMid(IRContext* ctx, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ uint32_t glsl405_ext_inst_id =
+ ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+ if (glsl405_ext_inst_id == 0) {
+ ctx->AddExtInstImport("GLSL.std.450");
+ glsl405_ext_inst_id =
+ ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+ }
+
+ InstructionBuilder ir_builder(
+ ctx, inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+ uint32_t op1 = inst->GetSingleWordInOperand(2);
+ uint32_t op2 = inst->GetSingleWordInOperand(3);
+ uint32_t op3 = inst->GetSingleWordInOperand(4);
+
+ Instruction* min = ir_builder.AddNaryExtendedInstruction(
+ inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(min_opcode),
+ {op2, op3});
+ Instruction* max = ir_builder.AddNaryExtendedInstruction(
+ inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(max_opcode),
+ {op2, op3});
+
+ Instruction::OperandList new_operands;
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
+ new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
+ {static_cast<uint32_t>(clamp_opcode)}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {op1}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {min->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {max->result_id()}});
+
+ inst->SetInOperands(std::move(new_operands));
+ ctx->UpdateDefUse(inst);
+ return true;
+}
+
// Returns a folding rule that will replace the opcode with |opcode| and add
// the capabilities required. The folding rule assumes it is folding an
// OpGroup*NonUniformAMD instruction from the SPV_AMD_shader_ballot extension.
-FoldingRule ReplaceGroupNonuniformOperationOpCode(SpvOp new_opcode) {
+template <SpvOp new_opcode>
+bool ReplaceGroupNonuniformOperationOpCode(
+ IRContext* ctx, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
switch (new_opcode) {
case SpvOpGroupNonUniformIAdd:
case SpvOpGroupNonUniformFAdd:
@@ -56,27 +151,24 @@
"Should be replacing with a group non uniform arithmetic operation.");
}
- return [new_opcode](IRContext* ctx, Instruction* inst,
- const std::vector<const analysis::Constant*>&) {
- switch (inst->opcode()) {
- case SpvOpGroupIAddNonUniformAMD:
- case SpvOpGroupFAddNonUniformAMD:
- case SpvOpGroupUMinNonUniformAMD:
- case SpvOpGroupSMinNonUniformAMD:
- case SpvOpGroupFMinNonUniformAMD:
- case SpvOpGroupUMaxNonUniformAMD:
- case SpvOpGroupSMaxNonUniformAMD:
- case SpvOpGroupFMaxNonUniformAMD:
- break;
- default:
- assert(false &&
- "Should be replacing a group non uniform arithmetic operation.");
- }
+ switch (inst->opcode()) {
+ case SpvOpGroupIAddNonUniformAMD:
+ case SpvOpGroupFAddNonUniformAMD:
+ case SpvOpGroupUMinNonUniformAMD:
+ case SpvOpGroupSMinNonUniformAMD:
+ case SpvOpGroupFMinNonUniformAMD:
+ case SpvOpGroupUMaxNonUniformAMD:
+ case SpvOpGroupSMaxNonUniformAMD:
+ case SpvOpGroupFMaxNonUniformAMD:
+ break;
+ default:
+ assert(false &&
+ "Should be replacing a group non uniform arithmetic operation.");
+ }
- ctx->AddCapability(SpvCapabilityGroupNonUniformArithmetic);
- inst->SetOpcode(new_opcode);
- return true;
- };
+ ctx->AddCapability(SpvCapabilityGroupNonUniformArithmetic);
+ inst->SetOpcode(new_opcode);
+ return true;
}
// Returns a folding rule that will replace the SwizzleInvocationsAMD extended
@@ -112,84 +204,82 @@
// clang-format on
//
// Also adding the capabilities and builtins that are needed.
-FoldingRule ReplaceSwizzleInvocations() {
- return [](IRContext* ctx, Instruction* inst,
- const std::vector<const analysis::Constant*>&) {
- analysis::TypeManager* type_mgr = ctx->get_type_mgr();
- analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
+bool ReplaceSwizzleInvocations(IRContext* ctx, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ analysis::TypeManager* type_mgr = ctx->get_type_mgr();
+ analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
- ctx->AddExtension("SPV_KHR_shader_ballot");
- ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
- ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
+ ctx->AddExtension("SPV_KHR_shader_ballot");
+ ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
+ ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
- InstructionBuilder ir_builder(
- ctx, inst,
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ InstructionBuilder ir_builder(
+ ctx, inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- uint32_t data_id = inst->GetSingleWordInOperand(2);
- uint32_t offset_id = inst->GetSingleWordInOperand(3);
+ uint32_t data_id = inst->GetSingleWordInOperand(2);
+ uint32_t offset_id = inst->GetSingleWordInOperand(3);
- // Get the subgroup invocation id.
- uint32_t var_id =
- ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
- assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
- Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
- Instruction* var_ptr_type =
- ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
- uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
+ // Get the subgroup invocation id.
+ uint32_t var_id =
+ ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+ assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+ Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+ Instruction* var_ptr_type =
+ ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+ uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
- Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
+ Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
- uint32_t quad_mask = ir_builder.GetUintConstantId(3);
+ uint32_t quad_mask = ir_builder.GetUintConstantId(3);
- // This gives the offset in the group of 4 of this invocation.
- Instruction* quad_idx = ir_builder.AddBinaryOp(
- uint_type_id, SpvOpBitwiseAnd, id->result_id(), quad_mask);
+ // This gives the offset in the group of 4 of this invocation.
+ Instruction* quad_idx = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseAnd,
+ id->result_id(), quad_mask);
- // Get the invocation id of the first invocation in the group of 4.
- Instruction* quad_ldr = ir_builder.AddBinaryOp(
- uint_type_id, SpvOpBitwiseXor, id->result_id(), quad_idx->result_id());
+ // Get the invocation id of the first invocation in the group of 4.
+ Instruction* quad_ldr = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpBitwiseXor, id->result_id(), quad_idx->result_id());
- // Get the offset of the target invocation from the offset vector.
- Instruction* my_offset =
- ir_builder.AddBinaryOp(uint_type_id, SpvOpVectorExtractDynamic,
- offset_id, quad_idx->result_id());
+ // Get the offset of the target invocation from the offset vector.
+ Instruction* my_offset =
+ ir_builder.AddBinaryOp(uint_type_id, SpvOpVectorExtractDynamic, offset_id,
+ quad_idx->result_id());
- // Determine the index of the invocation to read from.
- Instruction* target_inv = ir_builder.AddBinaryOp(
- uint_type_id, SpvOpIAdd, quad_ldr->result_id(), my_offset->result_id());
+ // Determine the index of the invocation to read from.
+ Instruction* target_inv = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpIAdd, quad_ldr->result_id(), my_offset->result_id());
- // Do the group operations
- uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
- uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
- const auto* ballot_value_const = const_mgr->GetConstant(
- type_mgr->GetUIntVectorType(4),
- {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
- Instruction* ballot_value =
- const_mgr->GetDefiningInstruction(ballot_value_const);
- Instruction* is_active = ir_builder.AddNaryOp(
- type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
- {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
- Instruction* shuffle = ir_builder.AddNaryOp(
- inst->type_id(), SpvOpGroupNonUniformShuffle,
- {subgroup_scope, data_id, target_inv->result_id()});
+ // Do the group operations
+ uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
+ uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
+ const auto* ballot_value_const = const_mgr->GetConstant(
+ type_mgr->GetUIntVectorType(4),
+ {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
+ Instruction* ballot_value =
+ const_mgr->GetDefiningInstruction(ballot_value_const);
+ Instruction* is_active = ir_builder.AddNaryOp(
+ type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
+ {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
+ Instruction* shuffle =
+ ir_builder.AddNaryOp(inst->type_id(), SpvOpGroupNonUniformShuffle,
+ {subgroup_scope, data_id, target_inv->result_id()});
- // Create the null constant to use in the select.
- const auto* null = const_mgr->GetConstant(
- type_mgr->GetType(inst->type_id()), std::vector<uint32_t>());
- Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
+ // Create the null constant to use in the select.
+ const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
+ std::vector<uint32_t>());
+ Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
- // Build the select.
- inst->SetOpcode(SpvOpSelect);
- Instruction::OperandList new_operands;
- new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
- new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
- new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
+ // Build the select.
+ inst->SetOpcode(SpvOpSelect);
+ Instruction::OperandList new_operands;
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
- inst->SetInOperands(std::move(new_operands));
- ctx->UpdateDefUse(inst);
- return true;
- };
+ inst->SetInOperands(std::move(new_operands));
+ ctx->UpdateDefUse(inst);
+ return true;
}
// Returns a folding rule that will replace the SwizzleInvocationsMaskedAMD
@@ -225,89 +315,87 @@
// clang-format on
//
// Also adding the capabilities and builtins that are needed.
-FoldingRule ReplaceSwizzleInvocationsMasked() {
- return [](IRContext* ctx, Instruction* inst,
- const std::vector<const analysis::Constant*>&) {
- analysis::TypeManager* type_mgr = ctx->get_type_mgr();
- analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
- analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
+bool ReplaceSwizzleInvocationsMasked(
+ IRContext* ctx, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ analysis::TypeManager* type_mgr = ctx->get_type_mgr();
+ analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
+ analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
- // ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
- ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
- ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
+ // ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
+ ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
+ ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
- InstructionBuilder ir_builder(
- ctx, inst,
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ InstructionBuilder ir_builder(
+ ctx, inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- // Get the operands to inst, and the components of the mask
- uint32_t data_id = inst->GetSingleWordInOperand(2);
+ // Get the operands to inst, and the components of the mask
+ uint32_t data_id = inst->GetSingleWordInOperand(2);
- Instruction* mask_inst =
- def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
- assert(mask_inst->opcode() == SpvOpConstantComposite &&
- "The mask is suppose to be a vector constant.");
- assert(mask_inst->NumInOperands() == 3 &&
- "The mask is suppose to have 3 components.");
+ Instruction* mask_inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
+ assert(mask_inst->opcode() == SpvOpConstantComposite &&
+ "The mask is suppose to be a vector constant.");
+ assert(mask_inst->NumInOperands() == 3 &&
+ "The mask is suppose to have 3 components.");
- uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
- uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
- uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);
+ uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
+ uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
+ uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);
- // Get the subgroup invocation id.
- uint32_t var_id =
- ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
- ctx->AddExtension("SPV_KHR_shader_ballot");
- assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
- Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
- Instruction* var_ptr_type =
- ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
- uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
+ // Get the subgroup invocation id.
+ uint32_t var_id =
+ ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+ ctx->AddExtension("SPV_KHR_shader_ballot");
+ assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+ Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+ Instruction* var_ptr_type =
+ ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+ uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
- Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
+ Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
- // Do the bitwise operations.
- uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
- Instruction* and_mask = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseOr,
- uint_x, mask_extended);
- Instruction* and_result = ir_builder.AddBinaryOp(
- uint_type_id, SpvOpBitwiseAnd, id->result_id(), and_mask->result_id());
- Instruction* or_result = ir_builder.AddBinaryOp(
- uint_type_id, SpvOpBitwiseOr, and_result->result_id(), uint_y);
- Instruction* target_inv = ir_builder.AddBinaryOp(
- uint_type_id, SpvOpBitwiseXor, or_result->result_id(), uint_z);
+ // Do the bitwise operations.
+ uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
+ Instruction* and_mask = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseOr,
+ uint_x, mask_extended);
+ Instruction* and_result = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpBitwiseAnd, id->result_id(), and_mask->result_id());
+ Instruction* or_result = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpBitwiseOr, and_result->result_id(), uint_y);
+ Instruction* target_inv = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpBitwiseXor, or_result->result_id(), uint_z);
- // Do the group operations
- uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
- uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
- const auto* ballot_value_const = const_mgr->GetConstant(
- type_mgr->GetUIntVectorType(4),
- {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
- Instruction* ballot_value =
- const_mgr->GetDefiningInstruction(ballot_value_const);
- Instruction* is_active = ir_builder.AddNaryOp(
- type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
- {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
- Instruction* shuffle = ir_builder.AddNaryOp(
- inst->type_id(), SpvOpGroupNonUniformShuffle,
- {subgroup_scope, data_id, target_inv->result_id()});
+ // Do the group operations
+ uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
+ uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
+ const auto* ballot_value_const = const_mgr->GetConstant(
+ type_mgr->GetUIntVectorType(4),
+ {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
+ Instruction* ballot_value =
+ const_mgr->GetDefiningInstruction(ballot_value_const);
+ Instruction* is_active = ir_builder.AddNaryOp(
+ type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
+ {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
+ Instruction* shuffle =
+ ir_builder.AddNaryOp(inst->type_id(), SpvOpGroupNonUniformShuffle,
+ {subgroup_scope, data_id, target_inv->result_id()});
- // Create the null constant to use in the select.
- const auto* null = const_mgr->GetConstant(
- type_mgr->GetType(inst->type_id()), std::vector<uint32_t>());
- Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
+ // Create the null constant to use in the select.
+ const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
+ std::vector<uint32_t>());
+ Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
- // Build the select.
- inst->SetOpcode(SpvOpSelect);
- Instruction::OperandList new_operands;
- new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
- new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
- new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
+ // Build the select.
+ inst->SetOpcode(SpvOpSelect);
+ Instruction::OperandList new_operands;
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
- inst->SetInOperands(std::move(new_operands));
- ctx->UpdateDefUse(inst);
- return true;
- };
+ inst->SetInOperands(std::move(new_operands));
+ ctx->UpdateDefUse(inst);
+ return true;
}
// Returns a folding rule that will replace the WriteInvocationAMD extended
@@ -326,40 +414,38 @@
// %result = OpSelect %type %cmp %write_value %input_value
//
// Also adding the capabilities and builtins that are needed.
-FoldingRule ReplaceWriteInvocation() {
- return [](IRContext* ctx, Instruction* inst,
- const std::vector<const analysis::Constant*>&) {
- uint32_t var_id =
- ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
- ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
- ctx->AddExtension("SPV_KHR_shader_ballot");
- assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
- Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
- Instruction* var_ptr_type =
- ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+bool ReplaceWriteInvocation(IRContext* ctx, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ uint32_t var_id =
+ ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+ ctx->AddCapability(SpvCapabilitySubgroupBallotKHR);
+ ctx->AddExtension("SPV_KHR_shader_ballot");
+ assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+ Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+ Instruction* var_ptr_type =
+ ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
- InstructionBuilder ir_builder(
- ctx, inst,
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- Instruction* t =
- ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id);
- analysis::Bool bool_type;
- uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type);
- Instruction* cmp =
- ir_builder.AddBinaryOp(bool_type_id, SpvOpIEqual, t->result_id(),
- inst->GetSingleWordInOperand(4));
+ InstructionBuilder ir_builder(
+ ctx, inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ Instruction* t =
+ ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id);
+ analysis::Bool bool_type;
+ uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type);
+ Instruction* cmp =
+ ir_builder.AddBinaryOp(bool_type_id, SpvOpIEqual, t->result_id(),
+ inst->GetSingleWordInOperand(4));
- // Build a select.
- inst->SetOpcode(SpvOpSelect);
- Instruction::OperandList new_operands;
- new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}});
- new_operands.push_back(inst->GetInOperand(3));
- new_operands.push_back(inst->GetInOperand(2));
+ // Build a select.
+ inst->SetOpcode(SpvOpSelect);
+ Instruction::OperandList new_operands;
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}});
+ new_operands.push_back(inst->GetInOperand(3));
+ new_operands.push_back(inst->GetInOperand(2));
- inst->SetInOperands(std::move(new_operands));
- ctx->UpdateDefUse(inst);
- return true;
- };
+ inst->SetInOperands(std::move(new_operands));
+ ctx->UpdateDefUse(inst);
+ return true;
}
// Returns a folding rule that will replace the MbcntAMD extended instruction in
@@ -384,51 +470,49 @@
// %result = OpBitCount %uint %and
//
// Also adding the capabilities and builtins that are needed.
-FoldingRule ReplaceMbcnt() {
- return [](IRContext* context, Instruction* inst,
- const std::vector<const analysis::Constant*>&) {
- analysis::TypeManager* type_mgr = context->get_type_mgr();
- analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
+bool ReplaceMbcnt(IRContext* context, Instruction* inst,
+ const std::vector<const analysis::Constant*>&) {
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
+ analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
- uint32_t var_id = context->GetBuiltinInputVarId(SpvBuiltInSubgroupLtMask);
- assert(var_id != 0 && "Could not get SubgroupLtMask variable.");
- context->AddCapability(SpvCapabilityGroupNonUniformBallot);
- Instruction* var_inst = def_use_mgr->GetDef(var_id);
- Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id());
- Instruction* var_type =
- def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1));
- assert(var_type->opcode() == SpvOpTypeVector &&
- "Variable is suppose to be a vector of 4 ints");
+ uint32_t var_id = context->GetBuiltinInputVarId(SpvBuiltInSubgroupLtMask);
+ assert(var_id != 0 && "Could not get SubgroupLtMask variable.");
+ context->AddCapability(SpvCapabilityGroupNonUniformBallot);
+ Instruction* var_inst = def_use_mgr->GetDef(var_id);
+ Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id());
+ Instruction* var_type =
+ def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1));
+ assert(var_type->opcode() == SpvOpTypeVector &&
+ "Variable is suppose to be a vector of 4 ints");
- // Get the type for the shuffle.
- analysis::Vector temp_type(GetUIntType(context), 2);
- const analysis::Type* shuffle_type =
- context->get_type_mgr()->GetRegisteredType(&temp_type);
- uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type);
+ // Get the type for the shuffle.
+ analysis::Vector temp_type(GetUIntType(context), 2);
+ const analysis::Type* shuffle_type =
+ context->get_type_mgr()->GetRegisteredType(&temp_type);
+ uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type);
- uint32_t mask_id = inst->GetSingleWordInOperand(2);
- Instruction* mask_inst = def_use_mgr->GetDef(mask_id);
+ uint32_t mask_id = inst->GetSingleWordInOperand(2);
+ Instruction* mask_inst = def_use_mgr->GetDef(mask_id);
- // Testing with amd's shader compiler shows that a 64-bit mask is expected.
- assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr);
- assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64);
+ // Testing with amd's shader compiler shows that a 64-bit mask is expected.
+ assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr);
+ assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64);
- InstructionBuilder ir_builder(
- context, inst,
- IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id);
- Instruction* shuffle = ir_builder.AddVectorShuffle(
- shuffle_type_id, load->result_id(), load->result_id(), {0, 1});
- Instruction* bitcast = ir_builder.AddUnaryOp(
- mask_inst->type_id(), SpvOpBitcast, shuffle->result_id());
- Instruction* t = ir_builder.AddBinaryOp(
- mask_inst->type_id(), SpvOpBitwiseAnd, bitcast->result_id(), mask_id);
+ InstructionBuilder ir_builder(
+ context, inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+ Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id);
+ Instruction* shuffle = ir_builder.AddVectorShuffle(
+ shuffle_type_id, load->result_id(), load->result_id(), {0, 1});
+ Instruction* bitcast = ir_builder.AddUnaryOp(
+ mask_inst->type_id(), SpvOpBitcast, shuffle->result_id());
+ Instruction* t = ir_builder.AddBinaryOp(mask_inst->type_id(), SpvOpBitwiseAnd,
+ bitcast->result_id(), mask_id);
- inst->SetOpcode(SpvOpBitCount);
- inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
- context->UpdateDefUse(inst);
- return true;
- };
+ inst->SetOpcode(SpvOpBitCount);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
+ context->UpdateDefUse(inst);
+ return true;
}
class AmdExtFoldingRules : public FoldingRules {
@@ -438,33 +522,59 @@
protected:
virtual void AddFoldingRules() override {
rules_[SpvOpGroupIAddNonUniformAMD].push_back(
- ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformIAdd));
+ ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformIAdd>);
rules_[SpvOpGroupFAddNonUniformAMD].push_back(
- ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFAdd));
+ ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFAdd>);
rules_[SpvOpGroupUMinNonUniformAMD].push_back(
- ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformUMin));
+ ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformUMin>);
rules_[SpvOpGroupSMinNonUniformAMD].push_back(
- ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformSMin));
+ ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformSMin>);
rules_[SpvOpGroupFMinNonUniformAMD].push_back(
- ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFMin));
+ ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFMin>);
rules_[SpvOpGroupUMaxNonUniformAMD].push_back(
- ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformUMax));
+ ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformUMax>);
rules_[SpvOpGroupSMaxNonUniformAMD].push_back(
- ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformSMax));
+ ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformSMax>);
rules_[SpvOpGroupFMaxNonUniformAMD].push_back(
- ReplaceGroupNonuniformOperationOpCode(SpvOpGroupNonUniformFMax));
+ ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFMax>);
uint32_t extension_id =
context()->module()->GetExtInstImportId("SPV_AMD_shader_ballot");
- ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}].push_back(
- ReplaceSwizzleInvocations());
- ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
- .push_back(ReplaceSwizzleInvocationsMasked());
- ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
- ReplaceWriteInvocation());
- ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
- ReplaceMbcnt());
+ if (extension_id != 0) {
+ ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}]
+ .push_back(ReplaceSwizzleInvocations);
+ ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
+ .push_back(ReplaceSwizzleInvocationsMasked);
+ ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
+ ReplaceWriteInvocation);
+ ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
+ ReplaceMbcnt);
+ }
+
+ extension_id = context()->module()->GetExtInstImportId(
+ "SPV_AMD_shader_trinary_minmax");
+
+ if (extension_id != 0) {
+ ext_rules_[{extension_id, FMin3AMD}].push_back(
+ ReplaceTrinaryMinMax<GLSLstd450FMin>);
+ ext_rules_[{extension_id, UMin3AMD}].push_back(
+ ReplaceTrinaryMinMax<GLSLstd450UMin>);
+ ext_rules_[{extension_id, SMin3AMD}].push_back(
+ ReplaceTrinaryMinMax<GLSLstd450SMin>);
+ ext_rules_[{extension_id, FMax3AMD}].push_back(
+ ReplaceTrinaryMinMax<GLSLstd450FMax>);
+ ext_rules_[{extension_id, UMax3AMD}].push_back(
+ ReplaceTrinaryMinMax<GLSLstd450UMax>);
+ ext_rules_[{extension_id, SMax3AMD}].push_back(
+ ReplaceTrinaryMinMax<GLSLstd450SMax>);
+ ext_rules_[{extension_id, FMid3AMD}].push_back(
+ ReplaceTrinaryMid<GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp>);
+ ext_rules_[{extension_id, UMid3AMD}].push_back(
+ ReplaceTrinaryMid<GLSLstd450UMin, GLSLstd450UMax, GLSLstd450UClamp>);
+ ext_rules_[{extension_id, SMid3AMD}].push_back(
+ ReplaceTrinaryMid<GLSLstd450SMin, GLSLstd450SMax, GLSLstd450SClamp>);
+ }
}
};
@@ -500,9 +610,14 @@
std::vector<Instruction*> to_be_killed;
for (Instruction& inst : context()->module()->extensions()) {
if (inst.opcode() == SpvOpExtension) {
- if (!strcmp("SPV_AMD_shader_ballot",
- reinterpret_cast<const char*>(
- &(inst.GetInOperand(0).words[0])))) {
+ if (strcmp("SPV_AMD_shader_ballot",
+ reinterpret_cast<const char*>(
+ &(inst.GetInOperand(0).words[0]))) == 0) {
+ to_be_killed.push_back(&inst);
+ }
+ if (strcmp("SPV_AMD_shader_trinary_minmax",
+ reinterpret_cast<const char*>(
+ &(inst.GetInOperand(0).words[0]))) == 0) {
to_be_killed.push_back(&inst);
}
}
@@ -510,9 +625,14 @@
for (Instruction& inst : context()->ext_inst_imports()) {
if (inst.opcode() == SpvOpExtInstImport) {
- if (!strcmp("SPV_AMD_shader_ballot",
- reinterpret_cast<const char*>(
- &(inst.GetInOperand(0).words[0])))) {
+ if (strcmp("SPV_AMD_shader_ballot",
+ reinterpret_cast<const char*>(
+ &(inst.GetInOperand(0).words[0]))) == 0) {
+ to_be_killed.push_back(&inst);
+ }
+ if (strcmp("SPV_AMD_shader_trinary_minmax",
+ reinterpret_cast<const char*>(
+ &(inst.GetInOperand(0).words[0]))) == 0) {
to_be_killed.push_back(&inst);
}
}
diff --git a/source/opt/feature_manager.h b/source/opt/feature_manager.h
index 761a208..2fe3291 100644
--- a/source/opt/feature_manager.h
+++ b/source/opt/feature_manager.h
@@ -57,6 +57,9 @@
// Add the extension |ext| to the feature manager.
void AddExtension(Instruction* ext);
+ // Analyzes |module| and records imported extended instruction sets.
+ void AddExtInstImportIds(Module* module);
+
private:
// Analyzes |module| and records enabled extensions.
void AddExtensions(Module* module);
@@ -64,9 +67,6 @@
// Analyzes |module| and records enabled capabilities.
void AddCapabilities(Module* module);
- // Analyzes |module| and records imported external instruction sets.
- void AddExtInstImportIds(Module* module);
-
// Auxiliary object for querying SPIR-V grammar facts.
const AssemblyGrammar& grammar_;
diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h
index b9cb26a..6720e89 100644
--- a/source/opt/ir_builder.h
+++ b/source/opt/ir_builder.h
@@ -492,6 +492,27 @@
return AddInstruction(std::move(new_inst));
}
+  // Builds an OpExtInst of type |result_type| that invokes |instruction| from
+  // the extended instruction set imported as |set|, passing every id in
+  // |ext_operands| as an id operand, and inserts it with AddInstruction().
+  // Returns nullptr when TakeNextId() cannot supply a result id.
+  Instruction* AddNaryExtendedInstruction(
+      uint32_t result_type, uint32_t set, uint32_t instruction,
+      const std::vector<uint32_t>& ext_operands) {
+    // Building the operand list has no side effects, so the id can be taken
+    // (and its failure handled) up front.
+    const uint32_t new_id = GetContext()->TakeNextId();
+    if (new_id == 0) return nullptr;
+
+    // In-operands of OpExtInst: the set, the instruction number, then the
+    // instruction's own arguments.
+    std::vector<Operand> inst_operands;
+    inst_operands.push_back({SPV_OPERAND_TYPE_ID, {set}});
+    inst_operands.push_back(
+        {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, {instruction}});
+    for (const uint32_t operand_id : ext_operands) {
+      inst_operands.push_back({SPV_OPERAND_TYPE_ID, {operand_id}});
+    }
+
+    return AddInstruction(std::unique_ptr<Instruction>(new Instruction(
+        GetContext(), SpvOpExtInst, result_type, new_id, inst_operands)));
+  }
+
// Inserts the new instruction before the insertion point.
Instruction* AddInstruction(std::unique_ptr<Instruction>&& insn) {
Instruction* insn_ptr = &*insert_before_.InsertBefore(std::move(insn));
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index e297fb1..3bbf180 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -199,6 +199,7 @@
inline void AddExtension(const std::string& ext_name);
inline void AddExtension(std::unique_ptr<Instruction>&& e);
// Appends an extended instruction set instruction to this module.
+ inline void AddExtInstImport(const std::string& name);
inline void AddExtInstImport(std::unique_ptr<Instruction>&& e);
// Set the memory model for this module.
inline void SetMemoryModel(std::unique_ptr<Instruction>&& m);
@@ -971,9 +972,26 @@
module()->AddExtension(std::move(e));
}
+// Creates an OpExtInstImport for the extended instruction set |name| under a
+// freshly taken result id and hands it to the
+// AddExtInstImport(std::unique_ptr<Instruction>&&) overload. Nothing is added
+// to the module when no result id can be taken.
+void IRContext::AddExtInstImport(const std::string& name) {
+  // TakeNextId() yields 0 when no id is available (the same convention
+  // InstructionBuilder::AddNaryExtendedInstruction checks for). 0 is not a
+  // valid SPIR-V result id, so bail out rather than add a broken import.
+  const uint32_t import_id = TakeNextId();
+  if (import_id == 0) {
+    return;
+  }
+
+  const auto num_chars = name.size();
+  // Compute num words, accommodate the terminating null character.
+  const auto num_words = (num_chars + 1 + 3) / 4;
+  // Zero-filled, so the terminator and any padding bytes are already in place
+  // once the characters are copied over the front.
+  std::vector<uint32_t> ext_words(num_words, 0u);
+  std::memcpy(ext_words.data(), name.data(), num_chars);
+  AddExtInstImport(std::unique_ptr<Instruction>(
+      new Instruction(this, SpvOpExtInstImport, 0u, import_id,
+                      {{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}})));
+}
+
void IRContext::AddExtInstImport(std::unique_ptr<Instruction>&& e) {
AddCombinatorsForExtension(e.get());
+ if (AreAnalysesValid(kAnalysisDefUse)) {  // Only extend def-use info that is currently valid.
+ get_def_use_mgr()->AnalyzeInstDefUse(e.get());  // Uses |e| before ownership moves to the module below.
+ }
module()->AddExtInstImport(std::move(e));
+ if (feature_mgr_ != nullptr) {  // Re-scan the module's imports only if a feature manager exists.
+ feature_mgr_->AddExtInstImportIds(module());
+ }
}
void IRContext::SetMemoryModel(std::unique_ptr<Instruction>&& m) {
diff --git a/test/opt/amd_ext_to_khr.cpp b/test/opt/amd_ext_to_khr.cpp
index 7a6d4b4..cdf168a 100644
--- a/test/opt/amd_ext_to_khr.cpp
+++ b/test/opt/amd_ext_to_khr.cpp
@@ -233,6 +233,7 @@
SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
}
+
TEST_F(AmdExtToKhrTest, ReplaceWriteInvocationAMD) {
const std::string text = R"(
; CHECK: OpCapability Shader
@@ -269,6 +270,438 @@
SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
}
+TEST_F(AmdExtToKhrTest, ReplaceFMin3AMD) {  // Expects FMin(FMin(x, y), z) and no AMD extension/import.
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeFloat 32
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] FMin [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] FMin [[temp]] [[z]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_trinary_minmax"
+ %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %float
+ %8 = OpUndef %float
+ %9 = OpUndef %float
+ %10 = OpExtInst %float %ext FMin3AMD %7 %8 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceSMin3AMD) {  // Expects SMin(SMin(x, y), z) and no AMD extension/import.
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 1
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] SMin [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] SMin [[temp]] [[z]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_trinary_minmax"
+ %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %int = OpTypeInt 32 1
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %int
+ %8 = OpUndef %int
+ %9 = OpUndef %int
+ %10 = OpExtInst %int %ext SMin3AMD %7 %8 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceUMin3AMD) {  // Expects UMin(UMin(x, y), z) and no AMD extension/import.
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 0
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] UMin [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] UMin [[temp]] [[z]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_trinary_minmax"
+ %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %int = OpTypeInt 32 1
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %uint
+ %8 = OpUndef %uint
+ %9 = OpUndef %uint
+ %10 = OpExtInst %uint %ext UMin3AMD %7 %8 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceFMax3AMD) {  // Expects FMax(FMax(x, y), z) and no AMD extension/import.
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeFloat 32
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] FMax [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] FMax [[temp]] [[z]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_trinary_minmax"
+ %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %float
+ %8 = OpUndef %float
+ %9 = OpUndef %float
+ %10 = OpExtInst %float %ext FMax3AMD %7 %8 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceSMax3AMD) {  // Expects SMax(SMax(x, y), z) and no AMD extension/import.
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 1
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] SMax [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] SMax [[temp]] [[z]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_trinary_minmax"
+ %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %int = OpTypeInt 32 1
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %int
+ %8 = OpUndef %int
+ %9 = OpUndef %int
+ %10 = OpExtInst %int %ext SMax3AMD %7 %8 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceUMax3AMD) {  // Expects UMax(UMax(x, y), z) and no AMD extension/import.
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 0
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[temp]] [[z]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_trinary_minmax"
+ %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %int = OpTypeInt 32 1
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %uint
+ %8 = OpUndef %uint
+ %9 = OpUndef %uint
+ %10 = OpExtInst %uint %ext UMax3AMD %7 %8 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceVecUMax3AMD) {  // Expects UMax(UMax(x, y), z) on vectors and no AMD extension/import.
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeVector
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[temp:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[x]] [[y]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[temp]] [[z]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_trinary_minmax"
+ %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %vec = OpTypeVector %uint 4
+ %int = OpTypeInt 32 1
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %vec
+ %8 = OpUndef %vec
+ %9 = OpUndef %vec
+ %10 = OpExtInst %vec %ext UMax3AMD %7 %8 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceFMid3AMD) {  // Expects FClamp(x, FMin(y, z), FMax(y, z)) and no AMD extension/import.
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeFloat 32
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[min:%\w+]] = OpExtInst [[type]] [[ext]] FMin [[y]] [[z]]
+; CHECK-NEXT: [[max:%\w+]] = OpExtInst [[type]] [[ext]] FMax [[y]] [[z]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] FClamp [[x]] [[min]] [[max]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_trinary_minmax"
+ %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %float
+ %8 = OpUndef %float
+ %9 = OpUndef %float
+ %10 = OpExtInst %float %ext FMid3AMD %7 %8 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceSMid3AMD) {  // Expects SClamp(x, SMin(y, z), SMax(y, z)) and no AMD extension/import.
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 1
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[min:%\w+]] = OpExtInst [[type]] [[ext]] SMin [[y]] [[z]]
+; CHECK-NEXT: [[max:%\w+]] = OpExtInst [[type]] [[ext]] SMax [[y]] [[z]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] SClamp [[x]] [[min]] [[max]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_trinary_minmax"
+ %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %int = OpTypeInt 32 1
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %int
+ %8 = OpUndef %int
+ %9 = OpUndef %int
+ %10 = OpExtInst %int %ext SMid3AMD %7 %8 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceUMid3AMD) {  // Expects UClamp(x, UMin(y, z), UMax(y, z)) and no AMD extension/import.
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeInt 32 0
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[min:%\w+]] = OpExtInst [[type]] [[ext]] UMin [[y]] [[z]]
+; CHECK-NEXT: [[max:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[y]] [[z]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] UClamp [[x]] [[min]] [[max]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_trinary_minmax"
+ %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %int = OpTypeInt 32 1
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %uint
+ %8 = OpUndef %uint
+ %9 = OpUndef %uint
+ %10 = OpExtInst %uint %ext UMid3AMD %7 %8 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
+TEST_F(AmdExtToKhrTest, ReplaceVecUMid3AMD) {  // Expects UClamp(x, UMin(y, z), UMax(y, z)) on vectors and no AMD extension/import.
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_trinary_minmax"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+; CHECK: [[ext:%\w+]] = OpExtInstImport "GLSL.std.450"
+; CHECK: [[type:%\w+]] = OpTypeVector
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[x:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[y:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[z:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[min:%\w+]] = OpExtInst [[type]] [[ext]] UMin [[y]] [[z]]
+; CHECK-NEXT: [[max:%\w+]] = OpExtInst [[type]] [[ext]] UMax [[y]] [[z]]
+; CHECK-NEXT: [[result:%\w+]] = OpExtInst [[type]] [[ext]] UClamp [[x]] [[min]] [[max]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_trinary_minmax"
+ %ext = OpExtInstImport "SPV_AMD_shader_trinary_minmax"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %vec = OpTypeVector %uint 3
+ %int = OpTypeInt 32 1
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %7 = OpUndef %vec
+ %8 = OpUndef %vec
+ %9 = OpUndef %vec
+ %10 = OpExtInst %vec %ext UMid3AMD %7 %8 %9
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
+
TEST_F(AmdExtToKhrTest, SetVersion) {
const std::string text = R"(
OpCapability Shader