spirv-fuzz: Handle OpPhi during constant obfuscation (#3640)
Fixes #3639.
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 15d1057..f9564bc 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -1317,6 +1317,31 @@
return result;
}
+opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
+ uint32_t block_id,
+ SpvOp opcode) {
+ // CFG::block uses std::map::at which throws an exception when |block_id| is
+ // invalid. The error message is unhelpful, though. Thus, we test that
+ // |block_id| is valid here.
+ const auto* label_inst = ir_context->get_def_use_mgr()->GetDef(block_id);
+ (void)label_inst; // Make compilers happy in release mode.
+ assert(label_inst && label_inst->opcode() == SpvOpLabel &&
+ "|block_id| is invalid");
+
+ auto* block = ir_context->cfg()->block(block_id);
+ auto it = block->rbegin();
+ assert(it != block->rend() && "Basic block can't be empty");
+
+ if (block->GetMergeInst()) {
+ ++it;
+ assert(it != block->rend() &&
+ "|block| must have at least two instructions:"
+ "terminator and a merge instruction");
+ }
+
+ return CanInsertOpcodeBeforeInstruction(opcode, &*it) ? &*it : nullptr;
+}
+
} // namespace fuzzerutil
} // namespace fuzz
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index 6058022..cb01ca7 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -487,6 +487,12 @@
google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>
MapToRepeatedUInt32Pair(const std::map<uint32_t, uint32_t>& data);
+// Returns the last instruction in |block_id| before which an instruction with
+// opcode |opcode| can be inserted, or nullptr if there is no such instruction.
+opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
+ uint32_t block_id,
+ SpvOp opcode);
+
} // namespace fuzzerutil
} // namespace fuzz
diff --git a/source/fuzz/transformation_replace_constant_with_uniform.cpp b/source/fuzz/transformation_replace_constant_with_uniform.cpp
index a8f9495..8de7201 100644
--- a/source/fuzz/transformation_replace_constant_with_uniform.cpp
+++ b/source/fuzz/transformation_replace_constant_with_uniform.cpp
@@ -90,6 +90,40 @@
operands_for_load);
}
+opt::Instruction*
+TransformationReplaceConstantWithUniform::GetInsertBeforeInstruction(
+ opt::IRContext* ir_context) const {
+ auto* result =
+ FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
+ if (!result) {
+ return nullptr;
+ }
+
+ // The use might be in an OpPhi instruction.
+ if (result->opcode() == SpvOpPhi) {
+ // OpPhi instructions must be the first instructions in a block. Thus, we
+ // can't insert above the OpPhi instruction. Given the predecessor block
+ // that corresponds to the id use, get the last instruction in that block
+ // above which we can insert OpAccessChain and OpLoad.
+ return fuzzerutil::GetLastInsertBeforeInstruction(
+ ir_context,
+ result->GetSingleWordInOperand(
+ message_.id_use_descriptor().in_operand_index() + 1),
+ SpvOpLoad);
+ }
+
+ // The only operand that we could've replaced in the OpBranchConditional is
+ // the condition id. But that operand has a boolean type and uniform variables
+ // can't store booleans (see the spec on OpTypeBool). Thus, |result| can't be
+ // an OpBranchConditional.
+ assert(result->opcode() != SpvOpBranchConditional &&
+ "OpBranchConditional has no operands to replace");
+
+ assert(fuzzerutil::CanInsertOpcodeBeforeInstruction(SpvOpLoad, result) &&
+ "We should be able to insert OpLoad and OpAccessChain at this point");
+ return result;
+}
+
bool TransformationReplaceConstantWithUniform::IsApplicable(
opt::IRContext* ir_context,
const TransformationContext& transformation_context) const {
@@ -188,6 +222,12 @@
}
}
+ // Once all checks are completed, we should be able to safely insert
+ // OpAccessChain and OpLoad into the module.
+ assert(GetInsertBeforeInstruction(ir_context) &&
+ "There must exist an instruction that we can use to insert "
+ "OpAccessChain and OpLoad above");
+
return true;
}
@@ -195,7 +235,7 @@
spvtools::opt::IRContext* ir_context,
TransformationContext* /*unused*/) const {
// Get the instruction that contains the id use we wish to replace.
- auto instruction_containing_constant_use =
+ auto* instruction_containing_constant_use =
FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
assert(instruction_containing_constant_use &&
"Precondition requires that the id use can be found.");
@@ -210,12 +250,17 @@
->GetDef(message_.id_use_descriptor().id_of_interest())
->type_id();
+ // Get an instruction that will be used to insert OpAccessChain and OpLoad.
+ auto* insert_before_inst = GetInsertBeforeInstruction(ir_context);
+ assert(insert_before_inst &&
+ "There must exist an insertion point for OpAccessChain and OpLoad");
+
// Add an access chain instruction to target the uniform element.
- instruction_containing_constant_use->InsertBefore(
+ insert_before_inst->InsertBefore(
MakeAccessChainInstruction(ir_context, constant_type_id));
// Add a load from this access chain.
- instruction_containing_constant_use->InsertBefore(
+ insert_before_inst->InsertBefore(
MakeLoadInstruction(ir_context, constant_type_id));
// Adjust the instruction containing the usage of the constant so that this
diff --git a/source/fuzz/transformation_replace_constant_with_uniform.h b/source/fuzz/transformation_replace_constant_with_uniform.h
index b72407c..c507c32 100644
--- a/source/fuzz/transformation_replace_constant_with_uniform.h
+++ b/source/fuzz/transformation_replace_constant_with_uniform.h
@@ -84,6 +84,11 @@
std::unique_ptr<opt::Instruction> MakeLoadInstruction(
spvtools::opt::IRContext* ir_context, uint32_t constant_type_id) const;
+ // OpAccessChain and OpLoad will be inserted above the instruction returned
+ // by this function. Returns nullptr if no such instruction is present.
+ opt::Instruction* GetInsertBeforeInstruction(
+ opt::IRContext* ir_context) const;
+
protobufs::TransformationReplaceConstantWithUniform message_;
};
diff --git a/test/fuzz/transformation_replace_constant_with_uniform_test.cpp b/test/fuzz/transformation_replace_constant_with_uniform_test.cpp
index 8cbba46..79757b3 100644
--- a/test/fuzz/transformation_replace_constant_with_uniform_test.cpp
+++ b/test/fuzz/transformation_replace_constant_with_uniform_test.cpp
@@ -13,6 +13,7 @@
// limitations under the License.
#include "source/fuzz/transformation_replace_constant_with_uniform.h"
+
#include "source/fuzz/instruction_descriptor.h"
#include "source/fuzz/uniform_buffer_element_descriptor.h"
#include "test/fuzz/fuzz_test_util.h"
@@ -1548,6 +1549,111 @@
.IsApplicable(context.get(), transformation_context));
}
+TEST(TransformationReplaceConstantWithUniformTest, ReplaceOpPhiOperand) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 320
+ OpDecorate %32 DescriptorSet 0
+ OpDecorate %32 Binding 0
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 1
+ %7 = OpConstant %6 2
+ %13 = OpConstant %6 4
+ %21 = OpConstant %6 1
+ %34 = OpConstant %6 0
+ %10 = OpTypeBool
+ %30 = OpTypeStruct %6
+ %31 = OpTypePointer Uniform %30
+ %32 = OpVariable %31 Uniform
+ %33 = OpTypePointer Uniform %6
+ %4 = OpFunction %2 None %3
+ %11 = OpLabel
+ OpBranch %5
+ %5 = OpLabel
+ %23 = OpPhi %6 %7 %11 %20 %15
+ %9 = OpSLessThan %10 %23 %13
+ OpLoopMerge %8 %15 None
+ OpBranchConditional %9 %15 %8
+ %15 = OpLabel
+ %20 = OpIAdd %6 %23 %21
+ OpBranch %5
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_3;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ FactManager fact_manager;
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(&fact_manager,
+ validator_options);
+
+ auto int_descriptor = MakeUniformBufferElementDescriptor(0, 0, {0});
+
+ ASSERT_TRUE(
+ AddFactHelper(&transformation_context, context.get(), 2, int_descriptor));
+
+ {
+ TransformationReplaceConstantWithUniform transformation(
+ MakeIdUseDescriptor(7, MakeInstructionDescriptor(23, SpvOpPhi, 0), 0),
+ int_descriptor, 50, 51);
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+ transformation.Apply(context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+ }
+
+ std::string after_transformation = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 320
+ OpDecorate %32 DescriptorSet 0
+ OpDecorate %32 Binding 0
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 1
+ %7 = OpConstant %6 2
+ %13 = OpConstant %6 4
+ %21 = OpConstant %6 1
+ %34 = OpConstant %6 0
+ %10 = OpTypeBool
+ %30 = OpTypeStruct %6
+ %31 = OpTypePointer Uniform %30
+ %32 = OpVariable %31 Uniform
+ %33 = OpTypePointer Uniform %6
+ %4 = OpFunction %2 None %3
+ %11 = OpLabel
+ %50 = OpAccessChain %33 %32 %34
+ %51 = OpLoad %6 %50
+ OpBranch %5
+ %5 = OpLabel
+ %23 = OpPhi %6 %51 %11 %20 %15
+ %9 = OpSLessThan %10 %23 %13
+ OpLoopMerge %8 %15 None
+ OpBranchConditional %9 %15 %8
+ %15 = OpLabel
+ %20 = OpIAdd %6 %23 %21
+ OpBranch %5
+ %8 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
} // namespace
} // namespace fuzz
} // namespace spvtools