spirv-fuzz: Transformation to add a new function to a module (#3114) This adds a large transformation that can add a new function to a SPIR-V module. This paves the way for donation of code from one module to another.
diff --git a/source/fuzz/CMakeLists.txt b/source/fuzz/CMakeLists.txt index 1b1da9b..bc7d453 100644 --- a/source/fuzz/CMakeLists.txt +++ b/source/fuzz/CMakeLists.txt
@@ -55,6 +55,7 @@ fuzzer_util.h id_use_descriptor.h instruction_descriptor.h + instruction_message.h protobufs/spirvfuzz_protobufs.h pseudo_random_generator.h random_generator.h @@ -66,6 +67,7 @@ transformation_add_constant_scalar.h transformation_add_dead_break.h transformation_add_dead_continue.h + transformation_add_function.h transformation_add_global_undef.h transformation_add_global_variable.h transformation_add_no_contraction_decoration.h @@ -121,6 +123,7 @@ fuzzer_util.cpp id_use_descriptor.cpp instruction_descriptor.cpp + instruction_message.cpp pseudo_random_generator.cpp random_generator.cpp replayer.cpp @@ -131,6 +134,7 @@ transformation_add_constant_scalar.cpp transformation_add_dead_break.cpp transformation_add_dead_continue.cpp + transformation_add_function.cpp transformation_add_global_undef.cpp transformation_add_global_variable.cpp transformation_add_no_contraction_decoration.cpp
diff --git a/source/fuzz/instruction_message.cpp b/source/fuzz/instruction_message.cpp new file mode 100644 index 0000000..b217a21 --- /dev/null +++ b/source/fuzz/instruction_message.cpp
@@ -0,0 +1,69 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/instruction_message.h" + +#include "source/fuzz/fuzzer_util.h" + +namespace spvtools { +namespace fuzz { + +protobufs::Instruction MakeInstructionMessage( + SpvOp opcode, uint32_t result_type_id, uint32_t result_id, + const std::vector<std::pair<uint32_t, std::vector<uint32_t>>>& + input_operands) { + protobufs::Instruction result; + result.set_opcode(opcode); + result.set_result_type_id(result_type_id); + result.set_result_id(result_id); + for (auto& operand : input_operands) { + auto operand_message = result.add_input_operand(); + operand_message->set_operand_type(operand.first); + for (auto operand_word : operand.second) { + operand_message->add_operand_data(operand_word); + } + } + return result; +} + +std::unique_ptr<opt::Instruction> InstructionFromMessage( + opt::IRContext* ir_context, + const protobufs::Instruction& instruction_message) { + // First, update the module's id bound with respect to the new instruction, + // if it has a result id. + if (instruction_message.result_id()) { + fuzzerutil::UpdateModuleIdBound(ir_context, + instruction_message.result_id()); + } + // Now create a sequence of input operands from the input operand data in the + // protobuf message. + opt::Instruction::OperandList in_operands; + for (auto& operand_message : instruction_message.input_operand()) { + opt::Operand::OperandData operand_data; + for (auto& word : operand_message.operand_data()) { + operand_data.push_back(word); + } + in_operands.push_back( + {static_cast<spv_operand_type_t>(operand_message.operand_type()), + operand_data}); + } + // Create and return the instruction. + return MakeUnique<opt::Instruction>( + ir_context, static_cast<SpvOp>(instruction_message.opcode()), + instruction_message.result_type_id(), instruction_message.result_id(), + in_operands); +} + +} // namespace fuzz +} // namespace spvtools
diff --git a/source/fuzz/instruction_message.h b/source/fuzz/instruction_message.h new file mode 100644 index 0000000..ed339aa --- /dev/null +++ b/source/fuzz/instruction_message.h
@@ -0,0 +1,44 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_FUZZ_INSTRUCTION_MESSAGE_H_ +#define SOURCE_FUZZ_INSTRUCTION_MESSAGE_H_ + +#include <memory> + +#include "source/fuzz/protobufs/spirvfuzz_protobufs.h" +#include "source/opt/instruction.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace fuzz { + +// Creates an Instruction protobuf message from its component parts. +protobufs::Instruction MakeInstructionMessage( + SpvOp opcode, uint32_t result_type_id, uint32_t result_id, + const std::vector<std::pair<uint32_t, std::vector<uint32_t>>>& + input_operands); + +// Creates and returns an opt::Instruction from protobuf message +// |instruction_message|, relative to |ir_context|. In the process, the module +// id bound associated with |ir_context| is updated to be at least as large as +// the result id (if any) associated with the new instruction. +std::unique_ptr<opt::Instruction> InstructionFromMessage( + opt::IRContext* ir_context, + const protobufs::Instruction& instruction_message); + +} // namespace fuzz +} // namespace spvtools + +#endif // SOURCE_FUZZ_INSTRUCTION_MESSAGE_H_
diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto index f57486c..dbd1fb8 100644 --- a/source/fuzz/protobufs/spvtoolsfuzz.proto +++ b/source/fuzz/protobufs/spvtoolsfuzz.proto
@@ -126,6 +126,37 @@ } +message InstructionOperand { + + // Represents an operand to a SPIR-V instruction. + + // The type of the operand. + uint32 operand_type = 1; + + // The data associated with the operand. For most operands (e.g. ids, + // storage classes and literals) this will be a single word. + repeated uint32 operand_data = 2; + +} + +message Instruction { + + // Represents a SPIR-V instruction. + + // The instruction's opcode (e.g. OpLabel). + uint32 opcode = 1; + + // The id of the instruction's result type; 0 if there is no result type. + uint32 result_type_id = 2; + + // The id of the instruction's result; 0 if there is no result. + uint32 result_id = 3; + + // Zero or more input operands. + repeated InstructionOperand input_operand = 4; + +} + message FactSequence { repeated Fact fact = 1; } @@ -210,6 +241,7 @@ TransformationAddConstantComposite add_constant_composite = 30; TransformationAddGlobalVariable add_global_variable = 31; TransformationAddGlobalUndef add_global_undef = 32; + TransformationAddFunction add_function = 33; // Add additional option using the next available number. } } @@ -297,6 +329,15 @@ } +message TransformationAddFunction { + + // Adds a SPIR-V function to the module. + + // The series of instructions that comprise the function. + repeated Instruction instruction = 1; + +} + message TransformationAddGlobalUndef { // Adds an undefined value of a given type to the module at global scope.
diff --git a/source/fuzz/transformation.cpp b/source/fuzz/transformation.cpp index aa886b9..1489f85 100644 --- a/source/fuzz/transformation.cpp +++ b/source/fuzz/transformation.cpp
@@ -21,6 +21,7 @@ #include "source/fuzz/transformation_add_constant_scalar.h" #include "source/fuzz/transformation_add_dead_break.h" #include "source/fuzz/transformation_add_dead_continue.h" +#include "source/fuzz/transformation_add_function.h" #include "source/fuzz/transformation_add_global_undef.h" #include "source/fuzz/transformation_add_global_variable.h" #include "source/fuzz/transformation_add_no_contraction_decoration.h" @@ -72,6 +73,8 @@ case protobufs::Transformation::TransformationCase::kAddDeadContinue: return MakeUnique<TransformationAddDeadContinue>( message.add_dead_continue()); + case protobufs::Transformation::TransformationCase::kAddFunction: + return MakeUnique<TransformationAddFunction>(message.add_function()); case protobufs::Transformation::TransformationCase::kAddGlobalUndef: return MakeUnique<TransformationAddGlobalUndef>( message.add_global_undef());
diff --git a/source/fuzz/transformation_add_function.cpp b/source/fuzz/transformation_add_function.cpp new file mode 100644 index 0000000..5e53961 --- /dev/null +++ b/source/fuzz/transformation_add_function.cpp
@@ -0,0 +1,156 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/transformation_add_function.h" + +#include "source/fuzz/fuzzer_util.h" +#include "source/fuzz/instruction_message.h" + +namespace spvtools { +namespace fuzz { + +TransformationAddFunction::TransformationAddFunction( + const spvtools::fuzz::protobufs::TransformationAddFunction& message) + : message_(message) {} + +TransformationAddFunction::TransformationAddFunction( + const std::vector<protobufs::Instruction>& instructions) { + for (auto& instruction : instructions) { + *message_.add_instruction() = instruction; + } +} + +bool TransformationAddFunction::IsApplicable( + opt::IRContext* context, + const spvtools::fuzz::FactManager& /*unused*/) const { + // Because checking all the conditions for a function to be valid is a big + // job that the SPIR-V validator can already do, a "try it and see" approach + // is taken here. + + // We first clone the current module, so that we can try adding the new + // function without risking wrecking |context|. + auto cloned_module = fuzzerutil::CloneIRContext(context); + + // We try to add a function to the cloned module, which may fail if + // |message_.instruction| is not sufficiently well-formed. + if (!TryToAddFunction(cloned_module.get())) { + return false; + } + // Having managed to add the new function to the cloned module, we ascertain + // whether the cloned module is still valid. If it is, the transformation is + // applicable. + return fuzzerutil::IsValid(cloned_module.get()); +} + +void TransformationAddFunction::Apply( + opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { + auto success = TryToAddFunction(context); + assert(success && "The function should be successfully added."); + (void)(success); // Keep release builds happy (otherwise they may complain + // that |success| is not used). + context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); +} + +protobufs::Transformation TransformationAddFunction::ToMessage() const { + protobufs::Transformation result; + *result.mutable_add_function() = message_; + return result; +} + +bool TransformationAddFunction::TryToAddFunction( + opt::IRContext* context) const { + // This function returns false if |message_.instruction| was not well-formed + // enough to actually create a function and add it to |context|. + + // A function must have at least some instructions. + if (message_.instruction().empty()) { + return false; + } + + // A function must start with OpFunction. + auto function_begin = message_.instruction(0); + if (function_begin.opcode() != SpvOpFunction) { + return false; + } + + // Make a function, headed by the OpFunction instruction. + std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>( + InstructionFromMessage(context, function_begin)); + + // Keeps track of which instruction protobuf message we are currently + // considering. + uint32_t instruction_index = 1; + const auto num_instructions = + static_cast<uint32_t>(message_.instruction().size()); + + // Iterate through all function parameter instructions, adding parameters to + // the new function. + while (instruction_index < num_instructions && + message_.instruction(instruction_index).opcode() == + SpvOpFunctionParameter) { + new_function->AddParameter(InstructionFromMessage( + context, message_.instruction(instruction_index))); + instruction_index++; + } + + // After the parameters, there needs to be a label. + if (instruction_index == num_instructions || + message_.instruction(instruction_index).opcode() != SpvOpLabel) { + return false; + } + + // Iterate through the instructions block by block until the end of the + // function is reached. + while (instruction_index < num_instructions && + message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) { + // Invariant: we should always be at a label instruction at this point. + assert(message_.instruction(instruction_index).opcode() == SpvOpLabel); + + // Make a basic block using the label instruction, with the new function + // as its parent. + std::unique_ptr<opt::BasicBlock> block = + MakeUnique<opt::BasicBlock>(InstructionFromMessage( + context, message_.instruction(instruction_index))); + block->SetParent(new_function.get()); + + // Consider successive instructions until we hit another label or the end + // of the function, adding each such instruction to the block. + instruction_index++; + while (instruction_index < num_instructions && + message_.instruction(instruction_index).opcode() != + SpvOpFunctionEnd && + message_.instruction(instruction_index).opcode() != SpvOpLabel) { + block->AddInstruction(InstructionFromMessage( + context, message_.instruction(instruction_index))); + instruction_index++; + } + // Add the block to the new function. + new_function->AddBasicBlock(std::move(block)); + } + // Having considered all the blocks, we should be at the last instruction and + // it needs to be OpFunctionEnd. + if (instruction_index != num_instructions - 1 || + message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) { + return false; + } + // Set the function's final instruction, add the function to the module and + // report success. + new_function->SetFunctionEnd( + InstructionFromMessage(context, message_.instruction(instruction_index))); + context->AddFunction(std::move(new_function)); + return true; +} + +} // namespace fuzz +} // namespace spvtools
diff --git a/source/fuzz/transformation_add_function.h b/source/fuzz/transformation_add_function.h new file mode 100644 index 0000000..fee2732 --- /dev/null +++ b/source/fuzz/transformation_add_function.h
@@ -0,0 +1,70 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_FUZZ_TRANSFORMATION_ADD_FUNCTION_H_ +#define SOURCE_FUZZ_TRANSFORMATION_ADD_FUNCTION_H_ + +#include "source/fuzz/fact_manager.h" +#include "source/fuzz/protobufs/spirvfuzz_protobufs.h" +#include "source/fuzz/transformation.h" +#include "source/opt/ir_context.h" + +namespace spvtools { +namespace fuzz { + +class TransformationAddFunction : public Transformation { + public: + explicit TransformationAddFunction( + const protobufs::TransformationAddFunction& message); + + explicit TransformationAddFunction( + const std::vector<protobufs::Instruction>& instructions); + + // - |message_.instruction| must correspond to a sufficiently well-formed + // sequence of instructions that a function can be created from them + // - Adding the created function to the module must lead to a valid module. + bool IsApplicable(opt::IRContext* context, + const FactManager& fact_manager) const override; + + // Adds the function defined by |message_.instruction| to the module + void Apply(opt::IRContext* context, FactManager* fact_manager) const override; + + protobufs::Transformation ToMessage() const override; + + private: + // Attempts to create a function from the series of instructions in + // |message_.instruction| and add it to |context|. Returns false if this is + // not possible due to the messages not respecting the basic structure of a + // function, e.g. if there is no OpFunction instruction or no blocks; in this + // case |context| is left in an indeterminate state. + // + // Otherwise returns true. Whether |context| is valid after addition of the + // function depends on the contents of |message_.instruction|. + // + // Intended usage: + // - Perform a dry run of this method on a clone of a module, and use + // the validator to check whether the resulting module is valid. Working + // on a clone means it does not matter if the function fails to be cleanly + // added, or leads to an invalid module. + // - If the dry run succeeds, run the method on the real module of interest, + // to add the function. + bool TryToAddFunction(opt::IRContext* context) const; + + protobufs::TransformationAddFunction message_; +}; + +} // namespace fuzz +} // namespace spvtools + +#endif // SOURCE_FUZZ_TRANSFORMATION_ADD_FUNCTION_H_
diff --git a/test/fuzz/transformation_add_function_test.cpp b/test/fuzz/transformation_add_function_test.cpp new file mode 100644 index 0000000..66130be --- /dev/null +++ b/test/fuzz/transformation_add_function_test.cpp
@@ -0,0 +1,447 @@ +// Copyright (c) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/fuzz/transformation_add_function.h" +#include "source/fuzz/instruction_message.h" +#include "test/fuzz/fuzz_test_util.h" + +namespace spvtools { +namespace fuzz { +namespace { + +TEST(TransformationAddFunctionTest, BasicTest) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFloat 32 + %9 = OpTypePointer Function %8 + %10 = OpTypeFunction %8 %7 %9 + %18 = OpConstant %8 0 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %37 = OpConstant %6 1 + %42 = OpTypePointer Private %8 + %43 = OpVariable %42 Private + %47 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationAddFunction transformation1(std::vector<protobufs::Instruction>( + {MakeInstructionMessage( + SpvOpFunction, 8, 13, + {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, {SpvFunctionControlMaskNone}}, + {SPV_OPERAND_TYPE_ID, {10}}}), + MakeInstructionMessage(SpvOpFunctionParameter, 7, 11, {}), + MakeInstructionMessage(SpvOpFunctionParameter, 9, 12, {}), + MakeInstructionMessage(SpvOpLabel, 0, 14, {}), + MakeInstructionMessage( + SpvOpVariable, 9, 17, + {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}), + MakeInstructionMessage( + SpvOpVariable, 7, 19, + {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}), + MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {17}}, {SPV_OPERAND_TYPE_ID, {18}}}), + MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {19}}, {SPV_OPERAND_TYPE_ID, {20}}}), + MakeInstructionMessage(SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {21}}}), + MakeInstructionMessage(SpvOpLabel, 0, 21, {}), + MakeInstructionMessage( + SpvOpLoopMerge, 0, 0, + {{SPV_OPERAND_TYPE_ID, {23}}, + {SPV_OPERAND_TYPE_ID, {24}}, + {SPV_OPERAND_TYPE_LOOP_CONTROL, {SpvLoopControlMaskNone}}}), + MakeInstructionMessage(SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {25}}}), + MakeInstructionMessage(SpvOpLabel, 0, 25, {}), + MakeInstructionMessage(SpvOpLoad, 6, 26, {{SPV_OPERAND_TYPE_ID, {19}}}), + MakeInstructionMessage(SpvOpLoad, 6, 27, {{SPV_OPERAND_TYPE_ID, {11}}}), + MakeInstructionMessage( + SpvOpSLessThan, 28, 29, + {{SPV_OPERAND_TYPE_ID, {26}}, {SPV_OPERAND_TYPE_ID, {27}}}), + MakeInstructionMessage(SpvOpBranchConditional, 0, 0, + {{SPV_OPERAND_TYPE_ID, {29}}, + {SPV_OPERAND_TYPE_ID, {22}}, + {SPV_OPERAND_TYPE_ID, {23}}}), + MakeInstructionMessage(SpvOpLabel, 0, 22, {}), + MakeInstructionMessage(SpvOpLoad, 8, 30, {{SPV_OPERAND_TYPE_ID, {12}}}), + MakeInstructionMessage(SpvOpLoad, 6, 31, {{SPV_OPERAND_TYPE_ID, {19}}}), + MakeInstructionMessage(SpvOpConvertSToF, 8, 32, + {{SPV_OPERAND_TYPE_ID, {31}}}), + MakeInstructionMessage( + SpvOpFMul, 8, 33, + {{SPV_OPERAND_TYPE_ID, {30}}, {SPV_OPERAND_TYPE_ID, {32}}}), + MakeInstructionMessage(SpvOpLoad, 8, 34, {{SPV_OPERAND_TYPE_ID, {17}}}), + MakeInstructionMessage( + SpvOpFAdd, 8, 35, + {{SPV_OPERAND_TYPE_ID, {34}}, {SPV_OPERAND_TYPE_ID, {33}}}), + MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {17}}, {SPV_OPERAND_TYPE_ID, {35}}}), + MakeInstructionMessage(SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {24}}}), + MakeInstructionMessage(SpvOpLabel, 0, 24, {}), + MakeInstructionMessage(SpvOpLoad, 6, 36, {{SPV_OPERAND_TYPE_ID, {19}}}), + MakeInstructionMessage( + SpvOpIAdd, 6, 38, + {{SPV_OPERAND_TYPE_ID, {36}}, {SPV_OPERAND_TYPE_ID, {37}}}), + MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {19}}, {SPV_OPERAND_TYPE_ID, {38}}}), + MakeInstructionMessage(SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {21}}}), + MakeInstructionMessage(SpvOpLabel, 0, 23, {}), + MakeInstructionMessage(SpvOpLoad, 8, 39, {{SPV_OPERAND_TYPE_ID, {17}}}), + MakeInstructionMessage(SpvOpReturnValue, 0, 0, + {{SPV_OPERAND_TYPE_ID, {39}}}), + MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})})); + + ASSERT_TRUE(transformation1.IsApplicable(context.get(), fact_manager)); + transformation1.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation1 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFloat 32 + %9 = OpTypePointer Function %8 + %10 = OpTypeFunction %8 %7 %9 + %18 = OpConstant %8 0 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %37 = OpConstant %6 1 + %42 = OpTypePointer Private %8 + %43 = OpVariable %42 Private + %47 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %13 = OpFunction %8 None %10 + %11 = OpFunctionParameter %7 + %12 = OpFunctionParameter %9 + %14 = OpLabel + %17 = OpVariable %9 Function + %19 = OpVariable %7 Function + OpStore %17 %18 + OpStore %19 %20 + OpBranch %21 + %21 = OpLabel + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %26 = OpLoad %6 %19 + %27 = OpLoad %6 %11 + %29 = OpSLessThan %28 %26 %27 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + %30 = OpLoad %8 %12 + %31 = OpLoad %6 %19 + %32 = OpConvertSToF %8 %31 + %33 = OpFMul %8 %30 %32 + %34 = OpLoad %8 %17 + %35 = OpFAdd %8 %34 %33 + OpStore %17 %35 + OpBranch %24 + %24 = OpLabel + %36 = OpLoad %6 %19 + %38 = OpIAdd %6 %36 %37 + OpStore %19 %38 + OpBranch %21 + %23 = OpLabel + %39 = OpLoad %8 %17 + OpReturnValue %39 + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation1, context.get())); + + TransformationAddFunction transformation2(std::vector<protobufs::Instruction>( + {MakeInstructionMessage( + SpvOpFunction, 2, 15, + {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, {SpvFunctionControlMaskNone}}, + {SPV_OPERAND_TYPE_ID, {3}}}), + MakeInstructionMessage(SpvOpLabel, 0, 16, {}), + MakeInstructionMessage( + SpvOpVariable, 7, 44, + {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}), + MakeInstructionMessage( + SpvOpVariable, 9, 45, + {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}), + MakeInstructionMessage( + SpvOpVariable, 7, 48, + {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}), + MakeInstructionMessage( + SpvOpVariable, 9, 49, + {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}), + MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {44}}, {SPV_OPERAND_TYPE_ID, {20}}}), + MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {45}}, {SPV_OPERAND_TYPE_ID, {18}}}), + MakeInstructionMessage(SpvOpFunctionCall, 8, 46, + {{SPV_OPERAND_TYPE_ID, {13}}, + {SPV_OPERAND_TYPE_ID, {44}}, + {SPV_OPERAND_TYPE_ID, {45}}}), + MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {48}}, {SPV_OPERAND_TYPE_ID, {37}}}), + MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {49}}, {SPV_OPERAND_TYPE_ID, {47}}}), + MakeInstructionMessage(SpvOpFunctionCall, 8, 50, + {{SPV_OPERAND_TYPE_ID, {13}}, + {SPV_OPERAND_TYPE_ID, {48}}, + {SPV_OPERAND_TYPE_ID, {49}}}), + MakeInstructionMessage( + SpvOpFAdd, 8, 51, + {{SPV_OPERAND_TYPE_ID, {46}}, {SPV_OPERAND_TYPE_ID, {50}}}), + MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {43}}, {SPV_OPERAND_TYPE_ID, {51}}}), + MakeInstructionMessage(SpvOpReturn, 0, 0, {}), + MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})})); + + ASSERT_TRUE(transformation2.IsApplicable(context.get(), fact_manager)); + transformation2.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation2 = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFloat 32 + %9 = OpTypePointer Function %8 + %10 = OpTypeFunction %8 %7 %9 + %18 = OpConstant %8 0 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %37 = OpConstant %6 1 + %42 = OpTypePointer Private %8 + %43 = OpVariable %42 Private + %47 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %13 = OpFunction %8 None %10 + %11 = OpFunctionParameter %7 + %12 = OpFunctionParameter %9 + %14 = OpLabel + %17 = OpVariable %9 Function + %19 = OpVariable %7 Function + OpStore %17 %18 + OpStore %19 %20 + OpBranch %21 + %21 = OpLabel + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %26 = OpLoad %6 %19 + %27 = OpLoad %6 %11 + %29 = OpSLessThan %28 %26 %27 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + %30 = OpLoad %8 %12 + %31 = OpLoad %6 %19 + %32 = OpConvertSToF %8 %31 + %33 = OpFMul %8 %30 %32 + %34 = OpLoad %8 %17 + %35 = OpFAdd %8 %34 %33 + OpStore %17 %35 + OpBranch %24 + %24 = OpLabel + %36 = OpLoad %6 %19 + %38 = OpIAdd %6 %36 %37 + OpStore %19 %38 + OpBranch %21 + %23 = OpLabel + %39 = OpLoad %8 %17 + OpReturnValue %39 + OpFunctionEnd + %15 = OpFunction %2 None %3 + %16 = OpLabel + %44 = OpVariable %7 Function + %45 = OpVariable %9 Function + %48 = OpVariable %7 Function + %49 = OpVariable %9 Function + OpStore %44 %20 + OpStore %45 %18 + %46 = OpFunctionCall %8 %13 %44 %45 + OpStore %48 %37 + OpStore %49 %47 + %50 = OpFunctionCall %8 %13 %48 %49 + %51 = OpFAdd %8 %46 %50 + OpStore %43 %51 + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation2, context.get())); +} + +TEST(TransformationAddFunctionTest, InapplicableTransformations) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFloat 32 + %9 = OpTypePointer Function %8 + %10 = OpTypeFunction %8 %7 %9 + %18 = OpConstant %8 0 + %20 = OpConstant %6 0 + %28 = OpTypeBool + %37 = OpConstant %6 1 + %42 = OpTypePointer Private %8 + %43 = OpVariable %42 Private + %47 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %13 = OpFunction %8 None %10 + %11 = OpFunctionParameter %7 + %12 = OpFunctionParameter %9 + %14 = OpLabel + %17 = OpVariable %9 Function + %19 = OpVariable %7 Function + OpStore %17 %18 + OpStore %19 %20 + OpBranch %21 + %21 = OpLabel + OpLoopMerge %23 %24 None + OpBranch %25 + %25 = OpLabel + %26 = OpLoad %6 %19 + %27 = OpLoad %6 %11 + %29 = OpSLessThan %28 %26 %27 + OpBranchConditional %29 %22 %23 + %22 = OpLabel + %30 = OpLoad %8 %12 + %31 = OpLoad %6 %19 + %32 = OpConvertSToF %8 %31 + %33 = OpFMul %8 %30 %32 + %34 = OpLoad %8 %17 + %35 = OpFAdd %8 %34 %33 + OpStore %17 %35 + OpBranch %24 + %24 = OpLabel + %36 = OpLoad %6 %19 + %38 = OpIAdd %6 %36 %37 + OpStore %19 %38 + OpBranch %21 + %23 = OpLabel + %39 = OpLoad %8 %17 + OpReturnValue %39 + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + // No instructions + ASSERT_FALSE( + TransformationAddFunction(std::vector<protobufs::Instruction>({})) + .IsApplicable(context.get(), fact_manager)); + + // No function begin + ASSERT_FALSE( + TransformationAddFunction( + std::vector<protobufs::Instruction>( + {MakeInstructionMessage(SpvOpFunctionParameter, 7, 11, {}), + MakeInstructionMessage(SpvOpFunctionParameter, 9, 12, {}), + MakeInstructionMessage(SpvOpLabel, 0, 14, {})})) + .IsApplicable(context.get(), fact_manager)); + + // No OpLabel + ASSERT_FALSE( + TransformationAddFunction( + std::vector<protobufs::Instruction>( + {MakeInstructionMessage(SpvOpFunction, 8, 13, + {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, + {SpvFunctionControlMaskNone}}, + {SPV_OPERAND_TYPE_ID, {10}}}), + MakeInstructionMessage(SpvOpReturnValue, 0, 0, + {{SPV_OPERAND_TYPE_ID, {39}}}), + MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})})) + .IsApplicable(context.get(), fact_manager)); + + // Abrupt end of instructions + ASSERT_FALSE(TransformationAddFunction( + std::vector<protobufs::Instruction>({MakeInstructionMessage( + SpvOpFunction, 8, 13, + {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, + {SpvFunctionControlMaskNone}}, + {SPV_OPERAND_TYPE_ID, {10}}})})) + .IsApplicable(context.get(), fact_manager)); + + // No function end + ASSERT_FALSE( + TransformationAddFunction( + std::vector<protobufs::Instruction>( + {MakeInstructionMessage(SpvOpFunction, 8, 13, + {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, + {SpvFunctionControlMaskNone}}, + {SPV_OPERAND_TYPE_ID, {10}}}), + MakeInstructionMessage(SpvOpLabel, 0, 14, {}), + MakeInstructionMessage(SpvOpReturnValue, 0, 0, + {{SPV_OPERAND_TYPE_ID, {39}}})})) + .IsApplicable(context.get(), fact_manager)); +} + +} // namespace +} // namespace fuzz +} // namespace spvtools