Replace SwizzleInvocationsAMD extended instruction. (#2823)
Part of #2814
diff --git a/source/opt/amd_ext_to_khr.cpp b/source/opt/amd_ext_to_khr.cpp
index 0d84ef3..1cb5ba5 100644
--- a/source/opt/amd_ext_to_khr.cpp
+++ b/source/opt/amd_ext_to_khr.cpp
@@ -79,16 +79,121 @@
};
}
-FoldingRule NotImplementedYet() {
- return [](IRContext*, Instruction*,
+// Returns a folding rule that will replace the SwizzleInvocationsAMD extended
+// instruction in the SPV_AMD_shader_ballot extension.
+//
+// The instruction
+//
+// %offset = OpConstantComposite %v4uint %x %y %z %w
+// %result = OpExtInst %type %1 SwizzleInvocationsAMD %data %offset
+//
+// is replaced with
+//
+// potentially new constants and types
+//
+// clang-format off
+// %uint_max = OpConstant %uint 0xFFFFFFFF
+// %v4uint = OpTypeVector %uint 4
+// %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
+// %null = OpConstantNull %type
+// clang-format on
+//
+// and the following code in the function body
+//
+// clang-format off
+// %id = OpLoad %uint %SubgroupLocalInvocationId
+// %quad_idx = OpBitwiseAnd %uint %id %uint_3
+// %quad_ldr = OpBitwiseXor %uint %id %quad_idx
+// %my_offset = OpVectorExtractDynamic %uint %offset %quad_idx
+// %target_inv = OpIAdd %uint %quad_ldr %my_offset
+// %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
+// %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
+// %result = OpSelect %type %is_active %shuffle %null
+// clang-format on
+//
+// Also adding the capabilities and builtins that are needed.
+FoldingRule ReplaceSwizzleInvocations() {
+ return [](IRContext* ctx, Instruction* inst,
const std::vector<const analysis::Constant*>&) {
- assert(false && "Replacement not implemented yet.");
- return false;
+ analysis::TypeManager* type_mgr = ctx->get_type_mgr();
+ analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
+
+ ctx->AddExtension("SPV_KHR_shader_ballot");
+ ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
+ ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);
+
+ InstructionBuilder ir_builder(
+ ctx, inst,
+ IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+ uint32_t data_id = inst->GetSingleWordInOperand(2);
+ uint32_t offset_id = inst->GetSingleWordInOperand(3);
+
+ // Get the subgroup invocation id.
+ uint32_t var_id =
+ ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
+ assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
+ Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
+ Instruction* var_ptr_type =
+ ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
+ uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
+
+ Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
+
+ uint32_t quad_mask = ir_builder.GetUintConstantId(3);
+
+ // This gives this invocation's index (0-3) within its group of 4.
+ Instruction* quad_idx = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpBitwiseAnd, id->result_id(), quad_mask);
+
+ // Get the invocation id of the first invocation in the group of 4.
+ Instruction* quad_ldr = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpBitwiseXor, id->result_id(), quad_idx->result_id());
+
+ // Get the offset of the target invocation from the offset vector.
+ Instruction* my_offset =
+ ir_builder.AddBinaryOp(uint_type_id, SpvOpVectorExtractDynamic,
+ offset_id, quad_idx->result_id());
+
+ // Determine the index of the invocation to read from.
+ Instruction* target_inv = ir_builder.AddBinaryOp(
+ uint_type_id, SpvOpIAdd, quad_ldr->result_id(), my_offset->result_id());
+
+ // Do the group operations: check the target is active, then read its data.
+ uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
+ uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
+ const auto* ballot_value_const = const_mgr->GetConstant(
+ type_mgr->GetUIntVectorType(4),
+ {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
+ Instruction* ballot_value =
+ const_mgr->GetDefiningInstruction(ballot_value_const);
+ Instruction* is_active = ir_builder.AddNaryOp(
+ type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
+ {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
+ Instruction* shuffle = ir_builder.AddNaryOp(
+ inst->type_id(), SpvOpGroupNonUniformShuffle,
+ {subgroup_scope, data_id, target_inv->result_id()});
+
+ // Create the null constant to use in the select.
+ const auto* null = const_mgr->GetConstant(
+ type_mgr->GetType(inst->type_id()), std::vector<uint32_t>());
+ Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
+
+ // Build the select.
+ inst->SetOpcode(SpvOpSelect);
+ Instruction::OperandList new_operands;
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
+
+ inst->SetInOperands(std::move(new_operands));
+ ctx->UpdateDefUse(inst);
+ return true;
};
}
-// Returns a folding rule that will replace the WriteInvocationAMD extended
-// instruction in the SPV_AMD_shader_ballot extension.
+// Returns a folding rule that will replace the SwizzleInvocationsMaskedAMD
+// extended instruction in the SPV_AMD_shader_ballot extension.
//
// The instruction
//
@@ -120,7 +225,7 @@
// clang-format on
//
// Also adding the capabilities and builtins that are needed.
-FoldingRule FoldSwizzleInvocationsMasked() {
+FoldingRule ReplaceSwizzleInvocationsMasked() {
return [](IRContext* ctx, Instruction* inst,
const std::vector<const analysis::Constant*>&) {
analysis::TypeManager* type_mgr = ctx->get_type_mgr();
@@ -353,9 +458,9 @@
context()->module()->GetExtInstImportId("SPV_AMD_shader_ballot");
ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}].push_back(
- NotImplementedYet());
+ ReplaceSwizzleInvocations());
ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
- .push_back(FoldSwizzleInvocationsMasked());
+ .push_back(ReplaceSwizzleInvocationsMasked());
ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
ReplaceWriteInvocation());
ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
diff --git a/test/opt/amd_ext_to_khr.cpp b/test/opt/amd_ext_to_khr.cpp
index b260761..7a6d4b4 100644
--- a/test/opt/amd_ext_to_khr.cpp
+++ b/test/opt/amd_ext_to_khr.cpp
@@ -135,6 +135,54 @@
SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
}
+TEST_F(AmdExtToKhrTest, ReplaceSwizzleInvocationsAMD) {
+ const std::string text = R"(
+; CHECK: OpCapability Shader
+; CHECK-NOT: OpExtension "SPV_AMD_shader_ballot"
+; CHECK-NOT: OpExtInstImport "SPV_AMD_shader_ballot"
+; CHECK: OpDecorate [[var:%\w+]] BuiltIn SubgroupLocalInvocationId
+; CHECK: [[subgroup:%\w+]] = OpConstant %uint 3
+; CHECK: [[offset:%\w+]] = OpConstantComposite %v4uint
+; CHECK: [[var]] = OpVariable %_ptr_Input_uint Input
+; CHECK: [[uint_max:%\w+]] = OpConstant %uint 4294967295
+; CHECK: [[ballot_value:%\w+]] = OpConstantComposite %v4uint [[uint_max]] [[uint_max]] [[uint_max]] [[uint_max]]
+; CHECK: [[null:%\w+]] = OpConstantNull [[type:%\w+]]
+; CHECK: OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[data:%\w+]] = OpUndef [[type]]
+; CHECK-NEXT: [[id:%\w+]] = OpLoad %uint [[var]]
+; CHECK-NEXT: [[quad_idx:%\w+]] = OpBitwiseAnd %uint [[id]] %uint_3
+; CHECK-NEXT: [[quad_ldr:%\w+]] = OpBitwiseXor %uint [[id]] [[quad_idx]]
+; CHECK-NEXT: [[my_offset:%\w+]] = OpVectorExtractDynamic %uint [[offset]] [[quad_idx]]
+; CHECK-NEXT: [[target_inv:%\w+]] = OpIAdd %uint [[quad_ldr]] [[my_offset]]
+; CHECK-NEXT: [[is_active:%\w+]] = OpGroupNonUniformBallotBitExtract %bool [[subgroup]] [[ballot_value]] [[target_inv]]
+; CHECK-NEXT: [[shuffle:%\w+]] = OpGroupNonUniformShuffle [[type]] [[subgroup]] [[data]] [[target_inv]]
+; CHECK-NEXT: [[result:%\w+]] = OpSelect [[type]] [[is_active]] [[shuffle]] [[null]]
+ OpCapability Shader
+ OpExtension "SPV_AMD_shader_ballot"
+ %ext = OpExtInstImport "SPV_AMD_shader_ballot"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %1 "func"
+ OpExecutionMode %1 OriginUpperLeft
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %uint_x = OpConstant %uint 1
+ %uint_y = OpConstant %uint 2
+ %uint_z = OpConstant %uint 3
+ %uint_w = OpConstant %uint 0
+ %v4uint = OpTypeVector %uint 4
+ %offset = OpConstantComposite %v4uint %uint_x %uint_y %uint_z %uint_x
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ %data = OpUndef %uint
+ %9 = OpExtInst %uint %ext SwizzleInvocationsAMD %data %offset
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<AmdExtensionToKhrPass>(text, true);
+}
TEST_F(AmdExtToKhrTest, ReplaceSwizzleInvocationsMaskedAMD) {
const std::string text = R"(
; CHECK: OpCapability Shader