spirv-fuzz: Do not allow creation of constants of block-decorated structs (#3903)
Fixes #3902.
diff --git a/source/fuzz/fuzzer_pass_add_composite_inserts.cpp b/source/fuzz/fuzzer_pass_add_composite_inserts.cpp
index 515407b..e58c754 100644
--- a/source/fuzz/fuzzer_pass_add_composite_inserts.cpp
+++ b/source/fuzz/fuzzer_pass_add_composite_inserts.cpp
@@ -166,9 +166,8 @@
// this type.
uint32_t available_object_id;
if (available_objects.empty()) {
- auto current_node_type =
- GetIRContext()->get_type_mgr()->GetType(current_node_type_id);
- if (!fuzzerutil::CanCreateConstant(*current_node_type)) {
+ if (!fuzzerutil::CanCreateConstant(GetIRContext(),
+ current_node_type_id)) {
return;
}
available_object_id =
diff --git a/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp b/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp
index 5e79ec8..1bc0c2b 100644
--- a/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp
+++ b/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp
@@ -100,9 +100,8 @@
// The id will be a zero constant if the type allows it, and an
// OpUndef otherwise. We want to avoid using OpUndef, if possible, to
// avoid undefined behaviour in the module as much as possible.
- if (fuzzerutil::CanCreateConstant(
- *GetIRContext()->get_type_mgr()->GetType(
- instruction->type_id()))) {
+ if (fuzzerutil::CanCreateConstant(GetIRContext(),
+ instruction->type_id())) {
wrapper_info.set_value_to_copy_id(
FindOrCreateZeroConstant(instruction->type_id(), true));
} else {
diff --git a/source/fuzz/fuzzer_pass_replace_branches_from_dead_blocks_with_exits.cpp b/source/fuzz/fuzzer_pass_replace_branches_from_dead_blocks_with_exits.cpp
index 200aaad..e6bebea 100644
--- a/source/fuzz/fuzzer_pass_replace_branches_from_dead_blocks_with_exits.cpp
+++ b/source/fuzz/fuzzer_pass_replace_branches_from_dead_blocks_with_exits.cpp
@@ -81,7 +81,8 @@
GetIRContext()->get_type_mgr()->GetType(function.type_id());
if (function_return_type->AsVoid()) {
opcodes.emplace_back(SpvOpReturn);
- } else if (fuzzerutil::CanCreateConstant(*function_return_type)) {
+ } else if (fuzzerutil::CanCreateConstant(GetIRContext(),
+ function.type_id())) {
// For simplicity we only allow OpReturnValue if the function return
// type is a type for which we can create a constant. This allows us a
// zero of the given type as a default return value.
diff --git a/source/fuzz/fuzzer_pass_replace_copy_objects_with_stores_loads.cpp b/source/fuzz/fuzzer_pass_replace_copy_objects_with_stores_loads.cpp
index 51cb569..e372924 100644
--- a/source/fuzz/fuzzer_pass_replace_copy_objects_with_stores_loads.cpp
+++ b/source/fuzz/fuzzer_pass_replace_copy_objects_with_stores_loads.cpp
@@ -69,9 +69,8 @@
// Find or create a constant to initialize the variable from. The type of
// |instruction| must be such that the function FindOrCreateConstant can be
// called.
- auto instruction_type =
- GetIRContext()->get_type_mgr()->GetType(instruction->type_id());
- if (!fuzzerutil::CanCreateConstant(*instruction_type)) {
+ if (!fuzzerutil::CanCreateConstant(GetIRContext(),
+ instruction->type_id())) {
return;
}
auto variable_initializer_id =
diff --git a/source/fuzz/fuzzer_pass_replace_parameter_with_global.cpp b/source/fuzz/fuzzer_pass_replace_parameter_with_global.cpp
index 8672a3b..6b3a63b 100644
--- a/source/fuzz/fuzzer_pass_replace_parameter_with_global.cpp
+++ b/source/fuzz/fuzzer_pass_replace_parameter_with_global.cpp
@@ -53,27 +53,23 @@
// function has at least one parameter.
if (std::none_of(params.begin(), params.end(),
[this](const opt::Instruction* param) {
- const auto* param_type =
- GetIRContext()->get_type_mgr()->GetType(
- param->type_id());
- assert(param_type && "Parameter has invalid type");
return TransformationReplaceParameterWithGlobal::
- IsParameterTypeSupported(*param_type);
+ IsParameterTypeSupported(GetIRContext(),
+ param->type_id());
})) {
continue;
}
// Select id of a parameter to replace.
- const opt::Instruction* replaced_param = nullptr;
- const opt::analysis::Type* param_type = nullptr;
+ const opt::Instruction* replaced_param;
+ uint32_t param_type_id;
do {
replaced_param = GetFuzzerContext()->RemoveAtRandomIndex(¶ms);
- param_type =
- GetIRContext()->get_type_mgr()->GetType(replaced_param->type_id());
- assert(param_type && "Parameter has invalid type");
+ param_type_id = replaced_param->type_id();
+ assert(param_type_id && "Parameter has invalid type");
} while (
!TransformationReplaceParameterWithGlobal::IsParameterTypeSupported(
- *param_type));
+ GetIRContext(), param_type_id));
assert(replaced_param && "Unable to find a parameter to replace");
diff --git a/source/fuzz/fuzzer_pass_replace_params_with_struct.cpp b/source/fuzz/fuzzer_pass_replace_params_with_struct.cpp
index 86d6d06..0e0610f 100644
--- a/source/fuzz/fuzzer_pass_replace_params_with_struct.cpp
+++ b/source/fuzz/fuzzer_pass_replace_params_with_struct.cpp
@@ -53,15 +53,13 @@
std::iota(parameter_index.begin(), parameter_index.end(), 0);
// Remove the indices of unsupported parameters.
- auto new_end = std::remove_if(
- parameter_index.begin(), parameter_index.end(),
- [this, ¶ms](uint32_t index) {
- const auto* type =
- GetIRContext()->get_type_mgr()->GetType(params[index]->type_id());
- assert(type && "Parameter has invalid type");
- return !TransformationReplaceParamsWithStruct::
- IsParameterTypeSupported(*type);
- });
+ auto new_end =
+ std::remove_if(parameter_index.begin(), parameter_index.end(),
+ [this, ¶ms](uint32_t index) {
+ return !TransformationReplaceParamsWithStruct::
+ IsParameterTypeSupported(GetIRContext(),
+ params[index]->type_id());
+ });
// std::remove_if changes the vector so that removed elements are placed at
// the end (i.e. [new_end, parameter_index.end()) is a range of removed
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 95ead39..122f21e 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -1119,22 +1119,32 @@
}
}
-bool CanCreateConstant(const opt::analysis::Type& type) {
- switch (type.kind()) {
- case opt::analysis::Type::kBool:
- case opt::analysis::Type::kInteger:
- case opt::analysis::Type::kFloat:
- case opt::analysis::Type::kMatrix:
- case opt::analysis::Type::kVector:
+bool CanCreateConstant(opt::IRContext* ir_context, uint32_t type_id) {
+ opt::Instruction* type_instr = ir_context->get_def_use_mgr()->GetDef(type_id);
+ assert(type_instr != nullptr && "The type must exist.");
+ assert(spvOpcodeGeneratesType(type_instr->opcode()) &&
+ "A type-generating opcode was expected.");
+ switch (type_instr->opcode()) {
+ case SpvOpTypeBool:
+ case SpvOpTypeInt:
+ case SpvOpTypeFloat:
+ case SpvOpTypeMatrix:
+ case SpvOpTypeVector:
return true;
- case opt::analysis::Type::kArray:
- return CanCreateConstant(*type.AsArray()->element_type());
- case opt::analysis::Type::kStruct:
- return std::all_of(type.AsStruct()->element_types().begin(),
- type.AsStruct()->element_types().end(),
- [](const opt::analysis::Type* element_type) {
- return CanCreateConstant(*element_type);
- });
+ case SpvOpTypeArray:
+ return CanCreateConstant(ir_context,
+ type_instr->GetSingleWordInOperand(0));
+ case SpvOpTypeStruct:
+ if (HasBlockOrBufferBlockDecoration(ir_context, type_id)) {
+ return false;
+ }
+ for (uint32_t index = 0; index < type_instr->NumInOperands(); index++) {
+ if (!CanCreateConstant(ir_context,
+ type_instr->GetSingleWordInOperand(index))) {
+ return false;
+ }
+ }
+ return true;
default:
return false;
}
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index 6a6efb8..569f4a6 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -399,9 +399,10 @@
uint32_t scalar_or_composite_type_id, bool is_irrelevant);
// Returns true if it is possible to create an OpConstant or an
-// OpConstantComposite instruction of |type|. That is, returns true if |type|
-// and all its constituents are either scalar or composite.
-bool CanCreateConstant(const opt::analysis::Type& type);
+// OpConstantComposite instruction of type |type_id|. That is, returns true if
+// the type associated with |type_id| and all its constituents are either scalar
+// or composite.
+bool CanCreateConstant(opt::IRContext* ir_context, uint32_t type_id);
// Returns the result id of an OpConstant instruction. |scalar_type_id| must be
// a result id of a scalar type (i.e. int, float or bool). Returns 0 if no such
diff --git a/source/fuzz/transformation_mutate_pointer.cpp b/source/fuzz/transformation_mutate_pointer.cpp
index 36c5951..fefedbd 100644
--- a/source/fuzz/transformation_mutate_pointer.cpp
+++ b/source/fuzz/transformation_mutate_pointer.cpp
@@ -139,25 +139,28 @@
return false;
}
- const auto* type = ir_context->get_type_mgr()->GetType(inst.type_id());
- assert(type && "|inst| has invalid type id");
-
- const auto* pointer_type = type->AsPointer();
+ opt::Instruction* type_inst =
+ ir_context->get_def_use_mgr()->GetDef(inst.type_id());
+ assert(type_inst != nullptr && "|inst| has invalid type id");
// |inst| must be a pointer.
- if (!pointer_type) {
+ if (type_inst->opcode() != SpvOpTypePointer) {
return false;
}
// |inst| must have a supported storage class.
- if (pointer_type->storage_class() != SpvStorageClassFunction &&
- pointer_type->storage_class() != SpvStorageClassPrivate &&
- pointer_type->storage_class() != SpvStorageClassWorkgroup) {
- return false;
+ switch (static_cast<SpvStorageClass>(type_inst->GetSingleWordInOperand(0))) {
+ case SpvStorageClassFunction:
+ case SpvStorageClassPrivate:
+ case SpvStorageClassWorkgroup:
+ break;
+ default:
+ return false;
}
// |inst|'s pointee must consist of scalars and/or composites.
- return fuzzerutil::CanCreateConstant(*pointer_type->pointee_type());
+ return fuzzerutil::CanCreateConstant(ir_context,
+ type_inst->GetSingleWordInOperand(1));
}
std::unordered_set<uint32_t> TransformationMutatePointer::GetFreshIds() const {
diff --git a/source/fuzz/transformation_replace_parameter_with_global.cpp b/source/fuzz/transformation_replace_parameter_with_global.cpp
index e75a9b2..cdf7645 100644
--- a/source/fuzz/transformation_replace_parameter_with_global.cpp
+++ b/source/fuzz/transformation_replace_parameter_with_global.cpp
@@ -57,10 +57,7 @@
// |parameter_id|.
// Check that replaced parameter has valid type.
- const auto* param_type =
- ir_context->get_type_mgr()->GetType(param_inst->type_id());
- assert(param_type && "Parameter has invalid type");
- if (!IsParameterTypeSupported(*param_type)) {
+ if (!IsParameterTypeSupported(ir_context, param_inst->type_id())) {
return false;
}
@@ -198,10 +195,10 @@
}
bool TransformationReplaceParameterWithGlobal::IsParameterTypeSupported(
- const opt::analysis::Type& type) {
+ opt::IRContext* ir_context, uint32_t param_type_id) {
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3403):
// Think about other type instructions we can add here.
- return fuzzerutil::CanCreateConstant(type);
+ return fuzzerutil::CanCreateConstant(ir_context, param_type_id);
}
std::unordered_set<uint32_t>
diff --git a/source/fuzz/transformation_replace_parameter_with_global.h b/source/fuzz/transformation_replace_parameter_with_global.h
index a5bdc5b..c2d5f8f 100644
--- a/source/fuzz/transformation_replace_parameter_with_global.h
+++ b/source/fuzz/transformation_replace_parameter_with_global.h
@@ -57,7 +57,8 @@
// Returns true if the type of the parameter is supported by this
// transformation.
- static bool IsParameterTypeSupported(const opt::analysis::Type& type);
+ static bool IsParameterTypeSupported(opt::IRContext* ir_context,
+ uint32_t param_type_id);
private:
protobufs::TransformationReplaceParameterWithGlobal message_;
diff --git a/source/fuzz/transformation_replace_params_with_struct.cpp b/source/fuzz/transformation_replace_params_with_struct.cpp
index 3f8b21b..0a135e5 100644
--- a/source/fuzz/transformation_replace_params_with_struct.cpp
+++ b/source/fuzz/transformation_replace_params_with_struct.cpp
@@ -85,10 +85,8 @@
}
// Check that the parameter with result id |id| has supported type.
- const auto* type = ir_context->get_type_mgr()->GetType(
- fuzzerutil::GetTypeId(ir_context, id));
- assert(type && "Parameter has invalid type");
- if (!IsParameterTypeSupported(*type)) {
+ if (!IsParameterTypeSupported(ir_context,
+ fuzzerutil::GetTypeId(ir_context, id))) {
return false;
}
}
@@ -263,10 +261,10 @@
}
bool TransformationReplaceParamsWithStruct::IsParameterTypeSupported(
- const opt::analysis::Type& param_type) {
+ opt::IRContext* ir_context, uint32_t param_type_id) {
// TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3403):
// Consider adding support for more types of parameters.
- return fuzzerutil::CanCreateConstant(param_type);
+ return fuzzerutil::CanCreateConstant(ir_context, param_type_id);
}
uint32_t TransformationReplaceParamsWithStruct::MaybeGetRequiredStructType(
diff --git a/source/fuzz/transformation_replace_params_with_struct.h b/source/fuzz/transformation_replace_params_with_struct.h
index 4f88f8e..afa6b14 100644
--- a/source/fuzz/transformation_replace_params_with_struct.h
+++ b/source/fuzz/transformation_replace_params_with_struct.h
@@ -68,7 +68,8 @@
protobufs::Transformation ToMessage() const override;
// Returns true if parameter's type is supported by this transformation.
- static bool IsParameterTypeSupported(const opt::analysis::Type& param_type);
+ static bool IsParameterTypeSupported(opt::IRContext* ir_context,
+ uint32_t param_type_id);
private:
// Returns a result id of the OpTypeStruct instruction required by this
diff --git a/test/fuzz/transformation_flatten_conditional_branch_test.cpp b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
index d77173d..579b696 100644
--- a/test/fuzz/transformation_flatten_conditional_branch_test.cpp
+++ b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
@@ -25,8 +25,8 @@
protobufs::SideEffectWrapperInfo MakeSideEffectWrapperInfo(
const protobufs::InstructionDescriptor& instruction,
uint32_t merge_block_id, uint32_t execute_block_id,
- uint32_t actual_result_id = 0, uint32_t alternative_block_id = 0,
- uint32_t placeholder_result_id = 0, uint32_t value_to_copy_id = 0) {
+ uint32_t actual_result_id, uint32_t alternative_block_id,
+ uint32_t placeholder_result_id, uint32_t value_to_copy_id) {
protobufs::SideEffectWrapperInfo result;
*result.mutable_instruction() = instruction;
result.set_merge_block_id(merge_block_id);
@@ -38,6 +38,13 @@
return result;
}
+protobufs::SideEffectWrapperInfo MakeSideEffectWrapperInfo(
+ const protobufs::InstructionDescriptor& instruction,
+ uint32_t merge_block_id, uint32_t execute_block_id) {
+ return MakeSideEffectWrapperInfo(instruction, merge_block_id,
+ execute_block_id, 0, 0, 0, 0);
+}
+
TEST(TransformationFlattenConditionalBranchTest, Inapplicable) {
std::string shader = R"(
OpCapability Shader
@@ -434,19 +441,19 @@
#endif
// The map maps from an instruction to a list with not enough fresh ids.
- ASSERT_FALSE(
- TransformationFlattenConditionalBranch(
- 31, true,
- {{MakeSideEffectWrapperInfo(
- MakeInstructionDescriptor(6, SpvOpLoad, 0), 100, 101, 102, 103)}})
- .IsApplicable(context.get(), transformation_context));
+ ASSERT_FALSE(TransformationFlattenConditionalBranch(
+ 31, true,
+ {{MakeSideEffectWrapperInfo(
+ MakeInstructionDescriptor(6, SpvOpLoad, 0), 100, 101,
+ 102, 103, 0, 0)}})
+ .IsApplicable(context.get(), transformation_context));
// Not all fresh ids given are distinct.
ASSERT_FALSE(TransformationFlattenConditionalBranch(
31, true,
{{MakeSideEffectWrapperInfo(
MakeInstructionDescriptor(6, SpvOpLoad, 0), 100, 100,
- 102, 103, 104)}})
+ 102, 103, 104, 0)}})
.IsApplicable(context.get(), transformation_context));
// %48 heads a construct containing an OpSampledImage instruction.
@@ -454,7 +461,7 @@
48, true,
{{MakeSideEffectWrapperInfo(
MakeInstructionDescriptor(53, SpvOpLoad, 0), 100, 101,
- 102, 103, 104)}})
+ 102, 103, 104, 0)}})
.IsApplicable(context.get(), transformation_context));
// %0 is not a valid id.
@@ -1059,6 +1066,59 @@
ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
}
+TEST(TransformationFlattenConditionalBranchTest,
+ LoadFromBufferBlockDecoratedStruct) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 320
+ OpMemberDecorate %11 0 Offset 0
+ OpDecorate %11 BufferBlock
+ OpDecorate %13 DescriptorSet 0
+ OpDecorate %13 Binding 0
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeBool
+ %7 = OpConstantTrue %6
+ %10 = OpTypeInt 32 1
+ %11 = OpTypeStruct %10
+ %12 = OpTypePointer Uniform %11
+ %13 = OpVariable %12 Uniform
+ %21 = OpUndef %11
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ OpSelectionMerge %9 None
+ OpBranchConditional %7 %8 %9
+ %8 = OpLabel
+ %20 = OpLoad %11 %13
+ OpBranch %9
+ %9 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_3;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options);
+
+ auto transformation = TransformationFlattenConditionalBranch(
+ 5, true,
+ {MakeSideEffectWrapperInfo(MakeInstructionDescriptor(20, SpvOpLoad, 0),
+ 100, 101, 102, 103, 104, 21)});
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+ ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+ ASSERT_TRUE(IsValid(env, context.get()));
+}
+
} // namespace
} // namespace fuzz
} // namespace spvtools