Add adjust branch weights transformation (#3336)
In this PR, the classes that represent the adjust branch weights
transformation and fuzzer pass were implemented. This transformation
adjusts the branch weights of a OpBranchConditional instruction.
diff --git a/source/fuzz/CMakeLists.txt b/source/fuzz/CMakeLists.txt
index ce6d3a6..0d4e4ea 100644
--- a/source/fuzz/CMakeLists.txt
+++ b/source/fuzz/CMakeLists.txt
@@ -50,6 +50,7 @@
fuzzer_pass_add_no_contraction_decorations.h
fuzzer_pass_add_stores.h
fuzzer_pass_add_useful_constructs.h
+ fuzzer_pass_adjust_branch_weights.h
fuzzer_pass_adjust_function_controls.h
fuzzer_pass_adjust_loop_controls.h
fuzzer_pass_adjust_memory_operands_masks.h
@@ -98,6 +99,7 @@
transformation_add_type_pointer.h
transformation_add_type_struct.h
transformation_add_type_vector.h
+ transformation_adjust_branch_weights.h
transformation_composite_construct.h
transformation_composite_extract.h
transformation_compute_data_synonym_fact_closure.h
@@ -145,6 +147,7 @@
fuzzer_pass_add_no_contraction_decorations.cpp
fuzzer_pass_add_stores.cpp
fuzzer_pass_add_useful_constructs.cpp
+ fuzzer_pass_adjust_branch_weights.cpp
fuzzer_pass_adjust_function_controls.cpp
fuzzer_pass_adjust_loop_controls.cpp
fuzzer_pass_adjust_memory_operands_masks.cpp
@@ -192,6 +195,7 @@
transformation_add_type_pointer.cpp
transformation_add_type_struct.cpp
transformation_add_type_vector.cpp
+ transformation_adjust_branch_weights.cpp
transformation_composite_construct.cpp
transformation_composite_extract.cpp
transformation_compute_data_synonym_fact_closure.cpp
diff --git a/source/fuzz/fuzzer.cpp b/source/fuzz/fuzzer.cpp
index 6524c21..d073254 100644
--- a/source/fuzz/fuzzer.cpp
+++ b/source/fuzz/fuzzer.cpp
@@ -34,6 +34,7 @@
#include "source/fuzz/fuzzer_pass_add_no_contraction_decorations.h"
#include "source/fuzz/fuzzer_pass_add_stores.h"
#include "source/fuzz/fuzzer_pass_add_useful_constructs.h"
+#include "source/fuzz/fuzzer_pass_adjust_branch_weights.h"
#include "source/fuzz/fuzzer_pass_adjust_function_controls.h"
#include "source/fuzz/fuzzer_pass_adjust_loop_controls.h"
#include "source/fuzz/fuzzer_pass_adjust_selection_controls.h"
@@ -281,6 +282,9 @@
// Now apply some passes that it does not make sense to apply repeatedly,
// as they do not unlock other passes.
std::vector<std::unique_ptr<FuzzerPass>> final_passes;
+ MaybeAddPass<FuzzerPassAdjustBranchWeights>(
+ &final_passes, ir_context.get(), &transformation_context, &fuzzer_context,
+ transformation_sequence_out);
MaybeAddPass<FuzzerPassAdjustFunctionControls>(
&final_passes, ir_context.get(), &transformation_context, &fuzzer_context,
transformation_sequence_out);
diff --git a/source/fuzz/fuzzer_context.cpp b/source/fuzz/fuzzer_context.cpp
index 94032ef..1779709 100644
--- a/source/fuzz/fuzzer_context.cpp
+++ b/source/fuzz/fuzzer_context.cpp
@@ -40,6 +40,7 @@
5, 70};
const std::pair<uint32_t, uint32_t> kChanceOfAddingStore = {5, 50};
const std::pair<uint32_t, uint32_t> kChanceOfAddingVectorType = {20, 70};
+const std::pair<uint32_t, uint32_t> kChanceOfAdjustingBranchWeights = {20, 90};
const std::pair<uint32_t, uint32_t> kChanceOfAdjustingFunctionControl = {20,
70};
const std::pair<uint32_t, uint32_t> kChanceOfAdjustingLoopControl = {20, 90};
@@ -124,6 +125,8 @@
chance_of_adding_store_ = ChooseBetweenMinAndMax(kChanceOfAddingStore);
chance_of_adding_vector_type_ =
ChooseBetweenMinAndMax(kChanceOfAddingVectorType);
+ chance_of_adjusting_branch_weights_ =
+ ChooseBetweenMinAndMax(kChanceOfAdjustingBranchWeights);
chance_of_adjusting_function_control_ =
ChooseBetweenMinAndMax(kChanceOfAdjustingFunctionControl);
chance_of_adjusting_loop_control_ =
diff --git a/source/fuzz/fuzzer_context.h b/source/fuzz/fuzzer_context.h
index 5899235..dd19d9a 100644
--- a/source/fuzz/fuzzer_context.h
+++ b/source/fuzz/fuzzer_context.h
@@ -136,6 +136,9 @@
uint32_t GetChanceOfAddingVectorType() {
return chance_of_adding_vector_type_;
}
+ uint32_t GetChanceOfAdjustingBranchWeights() {
+ return chance_of_adjusting_branch_weights_;
+ }
uint32_t GetChanceOfAdjustingFunctionControl() {
return chance_of_adjusting_function_control_;
}
@@ -201,6 +204,18 @@
uint32_t GetRandomLoopLimit() {
return random_generator_->RandomUint32(max_loop_limit_);
}
+ std::pair<uint32_t, uint32_t> GetRandomBranchWeights() {
+ std::pair<uint32_t, uint32_t> branch_weights = {0, 0};
+
+ while (branch_weights.first == 0 && branch_weights.second == 0) {
+ // Using INT32_MAX to do not overflow UINT32_MAX when the branch weights
+ // are added together.
+ branch_weights.first = random_generator_->RandomUint32(INT32_MAX);
+ branch_weights.second = random_generator_->RandomUint32(INT32_MAX);
+ }
+
+ return branch_weights;
+ }
uint32_t GetRandomSizeForNewArray() {
// Ensure that the array size is non-zero.
return random_generator_->RandomUint32(max_new_array_size_limit_ - 1) + 1;
@@ -231,6 +246,7 @@
uint32_t chance_of_adding_no_contraction_decoration_;
uint32_t chance_of_adding_store_;
uint32_t chance_of_adding_vector_type_;
+ uint32_t chance_of_adjusting_branch_weights_;
uint32_t chance_of_adjusting_function_control_;
uint32_t chance_of_adjusting_loop_control_;
uint32_t chance_of_adjusting_memory_operands_mask_;
diff --git a/source/fuzz/fuzzer_pass_adjust_branch_weights.cpp b/source/fuzz/fuzzer_pass_adjust_branch_weights.cpp
new file mode 100644
index 0000000..1d6d434
--- /dev/null
+++ b/source/fuzz/fuzzer_pass_adjust_branch_weights.cpp
@@ -0,0 +1,48 @@
+// Copyright (c) 2020 André Perez Maselco
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/fuzzer_pass_adjust_branch_weights.h"
+
+#include "source/fuzz/fuzzer_util.h"
+#include "source/fuzz/instruction_descriptor.h"
+#include "source/fuzz/transformation_adjust_branch_weights.h"
+
+namespace spvtools {
+namespace fuzz {
+
+FuzzerPassAdjustBranchWeights::FuzzerPassAdjustBranchWeights(
+ opt::IRContext* ir_context, TransformationContext* transformation_context,
+ FuzzerContext* fuzzer_context,
+ protobufs::TransformationSequence* transformations)
+ : FuzzerPass(ir_context, transformation_context, fuzzer_context,
+ transformations) {}
+
+FuzzerPassAdjustBranchWeights::~FuzzerPassAdjustBranchWeights() = default;
+
+void FuzzerPassAdjustBranchWeights::Apply() {
+ // For all OpBranchConditional instructions,
+ // randomly applies the transformation.
+ GetIRContext()->module()->ForEachInst([this](opt::Instruction* instruction) {
+ if (instruction->opcode() == SpvOpBranchConditional &&
+ GetFuzzerContext()->ChoosePercentage(
+ GetFuzzerContext()->GetChanceOfAdjustingBranchWeights())) {
+ ApplyTransformation(TransformationAdjustBranchWeights(
+ MakeInstructionDescriptor(GetIRContext(), instruction),
+ GetFuzzerContext()->GetRandomBranchWeights()));
+ }
+ });
+}
+
+} // namespace fuzz
+} // namespace spvtools
diff --git a/source/fuzz/fuzzer_pass_adjust_branch_weights.h b/source/fuzz/fuzzer_pass_adjust_branch_weights.h
new file mode 100644
index 0000000..5b2b33f
--- /dev/null
+++ b/source/fuzz/fuzzer_pass_adjust_branch_weights.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2020 André Perez Maselco
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_FUZZ_FUZZER_PASS_ADJUST_BRANCH_WEIGHTS_H_
+#define SOURCE_FUZZ_FUZZER_PASS_ADJUST_BRANCH_WEIGHTS_H_
+
+#include "source/fuzz/fuzzer_pass.h"
+
+namespace spvtools {
+namespace fuzz {
+
+// This fuzzer pass searches for branch conditional instructions
+// and randomly chooses which of these instructions will have their weights
+// adjusted.
+class FuzzerPassAdjustBranchWeights : public FuzzerPass {
+ public:
+ FuzzerPassAdjustBranchWeights(
+ opt::IRContext* ir_context, TransformationContext* transformation_context,
+ FuzzerContext* fuzzer_context,
+ protobufs::TransformationSequence* transformations);
+
+ ~FuzzerPassAdjustBranchWeights();
+
+ void Apply() override;
+};
+
+} // namespace fuzz
+} // namespace spvtools
+
+#endif // SOURCE_FUZZ_FUZZER_PASS_ADJUST_BRANCH_WEIGHTS_H_
diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto
index 68460a9..775b2ad 100644
--- a/source/fuzz/protobufs/spvtoolsfuzz.proto
+++ b/source/fuzz/protobufs/spvtoolsfuzz.proto
@@ -374,6 +374,7 @@
TransformationToggleAccessChainInstruction toggle_access_chain_instruction = 43;
TransformationAddConstantNull add_constant_null = 44;
TransformationComputeDataSynonymFactClosure compute_data_synonym_fact_closure = 45;
+ TransformationAdjustBranchWeights adjust_branch_weights = 46;
// Add additional option using the next available number.
}
}
@@ -742,6 +743,19 @@
}
+message TransformationAdjustBranchWeights {
+
+ // A transformation that adjusts the branch weights
+ // of a branch conditional instruction.
+
+ // A descriptor for a branch conditional instruction.
+ InstructionDescriptor instruction_descriptor = 1;
+
+ // Branch weights of a branch conditional instruction.
+ UInt32Pair branch_weights = 2;
+
+}
+
message TransformationCompositeConstruct {
// A transformation that introduces an OpCompositeConstruct instruction to
diff --git a/source/fuzz/transformation.cpp b/source/fuzz/transformation.cpp
index c8391e1..8b84169 100644
--- a/source/fuzz/transformation.cpp
+++ b/source/fuzz/transformation.cpp
@@ -39,6 +39,7 @@
#include "source/fuzz/transformation_add_type_pointer.h"
#include "source/fuzz/transformation_add_type_struct.h"
#include "source/fuzz/transformation_add_type_vector.h"
+#include "source/fuzz/transformation_adjust_branch_weights.h"
#include "source/fuzz/transformation_composite_construct.h"
#include "source/fuzz/transformation_composite_extract.h"
#include "source/fuzz/transformation_compute_data_synonym_fact_closure.h"
@@ -129,6 +130,9 @@
return MakeUnique<TransformationAddTypeStruct>(message.add_type_struct());
case protobufs::Transformation::TransformationCase::kAddTypeVector:
return MakeUnique<TransformationAddTypeVector>(message.add_type_vector());
+ case protobufs::Transformation::TransformationCase::kAdjustBranchWeights:
+ return MakeUnique<TransformationAdjustBranchWeights>(
+ message.adjust_branch_weights());
case protobufs::Transformation::TransformationCase::kCompositeConstruct:
return MakeUnique<TransformationCompositeConstruct>(
message.composite_construct());
diff --git a/source/fuzz/transformation_adjust_branch_weights.cpp b/source/fuzz/transformation_adjust_branch_weights.cpp
new file mode 100644
index 0000000..ed68134
--- /dev/null
+++ b/source/fuzz/transformation_adjust_branch_weights.cpp
@@ -0,0 +1,97 @@
+// Copyright (c) 2020 André Perez Maselco
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/transformation_adjust_branch_weights.h"
+
+#include "source/fuzz/fuzzer_util.h"
+#include "source/fuzz/instruction_descriptor.h"
+
+namespace spvtools {
+namespace fuzz {
+
+namespace {
+
+const uint32_t kBranchWeightForTrueLabelIndex = 3;
+const uint32_t kBranchWeightForFalseLabelIndex = 4;
+
+} // namespace
+
+TransformationAdjustBranchWeights::TransformationAdjustBranchWeights(
+ const spvtools::fuzz::protobufs::TransformationAdjustBranchWeights& message)
+ : message_(message) {}
+
+TransformationAdjustBranchWeights::TransformationAdjustBranchWeights(
+ const protobufs::InstructionDescriptor& instruction_descriptor,
+ const std::pair<uint32_t, uint32_t>& branch_weights) {
+ *message_.mutable_instruction_descriptor() = instruction_descriptor;
+ message_.mutable_branch_weights()->set_first(branch_weights.first);
+ message_.mutable_branch_weights()->set_second(branch_weights.second);
+}
+
+bool TransformationAdjustBranchWeights::IsApplicable(
+ opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
+ auto instruction =
+ FindInstruction(message_.instruction_descriptor(), ir_context);
+ if (instruction == nullptr) {
+ return false;
+ }
+
+ SpvOp opcode = static_cast<SpvOp>(
+ message_.instruction_descriptor().target_instruction_opcode());
+
+ assert(instruction->opcode() == opcode &&
+ "The located instruction must have the same opcode as in the "
+ "descriptor.");
+
+ // Must be an OpBranchConditional instruction.
+ if (opcode != SpvOpBranchConditional) {
+ return false;
+ }
+
+ assert((message_.branch_weights().first() != 0 ||
+ message_.branch_weights().second() != 0) &&
+ "At least one weight must be non-zero.");
+
+ assert(message_.branch_weights().first() <=
+ UINT32_MAX - message_.branch_weights().second() &&
+ "The sum of the two weights must not be greater than UINT32_MAX.");
+
+ return true;
+}
+
+void TransformationAdjustBranchWeights::Apply(
+ opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
+ auto instruction =
+ FindInstruction(message_.instruction_descriptor(), ir_context);
+ if (instruction->HasBranchWeights()) {
+ instruction->SetOperand(kBranchWeightForTrueLabelIndex,
+ {message_.branch_weights().first()});
+ instruction->SetOperand(kBranchWeightForFalseLabelIndex,
+ {message_.branch_weights().second()});
+ } else {
+ instruction->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER,
+ {message_.branch_weights().first()}});
+ instruction->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER,
+ {message_.branch_weights().second()}});
+ }
+}
+
+protobufs::Transformation TransformationAdjustBranchWeights::ToMessage() const {
+ protobufs::Transformation result;
+ *result.mutable_adjust_branch_weights() = message_;
+ return result;
+}
+
+} // namespace fuzz
+} // namespace spvtools
diff --git a/source/fuzz/transformation_adjust_branch_weights.h b/source/fuzz/transformation_adjust_branch_weights.h
new file mode 100644
index 0000000..638b0a9
--- /dev/null
+++ b/source/fuzz/transformation_adjust_branch_weights.h
@@ -0,0 +1,57 @@
+// Copyright (c) 2020 André Perez Maselco
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_FUZZ_TRANSFORMATION_ADJUST_BRANCH_WEIGHTS_H_
+#define SOURCE_FUZZ_TRANSFORMATION_ADJUST_BRANCH_WEIGHTS_H_
+
+#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
+#include "source/fuzz/transformation.h"
+#include "source/fuzz/transformation_context.h"
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace fuzz {
+
+class TransformationAdjustBranchWeights : public Transformation {
+ public:
+ explicit TransformationAdjustBranchWeights(
+ const protobufs::TransformationAdjustBranchWeights& message);
+
+ TransformationAdjustBranchWeights(
+ const protobufs::InstructionDescriptor& instruction_descriptor,
+ const std::pair<uint32_t, uint32_t>& branch_weights);
+
+ // - |message_.instruction_descriptor| must identify an existing
+ // branch conditional instruction
+ // - At least one of |branch_weights| must be non-zero and
+ // the two weights must not overflow a 32-bit unsigned integer when added
+ // together
+ bool IsApplicable(
+ opt::IRContext* ir_context,
+ const TransformationContext& transformation_context) const override;
+
+ // Adjust the branch weights of a branch conditional instruction.
+ void Apply(opt::IRContext* ir_context,
+ TransformationContext* transformation_context) const override;
+
+ protobufs::Transformation ToMessage() const override;
+
+ private:
+ protobufs::TransformationAdjustBranchWeights message_;
+};
+
+} // namespace fuzz
+} // namespace spvtools
+
+#endif // SOURCE_FUZZ_TRANSFORMATION_ADJUST_BRANCH_WEIGHTS_H_
diff --git a/source/opt/instruction.cpp b/source/opt/instruction.cpp
index 7052d3e..19f6fff 100644
--- a/source/opt/instruction.cpp
+++ b/source/opt/instruction.cpp
@@ -38,6 +38,10 @@
const uint32_t kDebugScopeNumWords = 7;
const uint32_t kDebugScopeNumWordsWithoutInlinedAt = 6;
const uint32_t kDebugNoScopeNumWords = 5;
+
+// Number of operands of an OpBranchConditional instruction
+// with weights.
+const uint32_t kOpBranchConditionalWithWeightsNumOperands = 5;
} // namespace
Instruction::Instruction(IRContext* c)
@@ -166,6 +170,15 @@
return size;
}
+bool Instruction::HasBranchWeights() const {
+ if (opcode_ == SpvOpBranchConditional &&
+ NumOperands() == kOpBranchConditionalWithWeightsNumOperands) {
+ return true;
+ }
+
+ return false;
+}
+
void Instruction::ToBinaryWithoutAttachedDebugInsts(
std::vector<uint32_t>* binary) const {
const uint32_t num_words = 1 + NumOperandWords();
diff --git a/source/opt/instruction.h b/source/opt/instruction.h
index aa29c5e..a758df4 100644
--- a/source/opt/instruction.h
+++ b/source/opt/instruction.h
@@ -366,6 +366,10 @@
inline bool WhileEachInOperand(
const std::function<bool(const uint32_t*)>& f) const;
+ // Returns true if it's an OpBranchConditional instruction
+ // with branch weights.
+ bool HasBranchWeights() const;
+
// Returns true if any operands can be labels
inline bool HasLabels() const;
diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt
index 8c5d512..afa168d 100644
--- a/test/fuzz/CMakeLists.txt
+++ b/test/fuzz/CMakeLists.txt
@@ -47,6 +47,7 @@
transformation_add_type_pointer_test.cpp
transformation_add_type_struct_test.cpp
transformation_add_type_vector_test.cpp
+ transformation_adjust_branch_weights_test.cpp
transformation_composite_construct_test.cpp
transformation_composite_extract_test.cpp
transformation_compute_data_synonym_fact_closure_test.cpp
diff --git a/test/fuzz/transformation_adjust_branch_weights_test.cpp b/test/fuzz/transformation_adjust_branch_weights_test.cpp
new file mode 100644
index 0000000..7f8ba31
--- /dev/null
+++ b/test/fuzz/transformation_adjust_branch_weights_test.cpp
@@ -0,0 +1,349 @@
+// Copyright (c) 2020 André Perez Maselco
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/fuzz/transformation_adjust_branch_weights.h"
+#include "source/fuzz/instruction_descriptor.h"
+#include "test/fuzz/fuzz_test_util.h"
+
+namespace spvtools {
+namespace fuzz {
+namespace {
+
+TEST(TransformationAdjustBranchWeightsTest, IsApplicableTest) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main" %51 %27
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %4 "main"
+ OpName %25 "buf"
+ OpMemberName %25 0 "value"
+ OpName %27 ""
+ OpName %51 "color"
+ OpMemberDecorate %25 0 Offset 0
+ OpDecorate %25 Block
+ OpDecorate %27 DescriptorSet 0
+ OpDecorate %27 Binding 0
+ OpDecorate %51 Location 0
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeFloat 32
+ %7 = OpTypeVector %6 4
+ %150 = OpTypeVector %6 2
+ %10 = OpConstant %6 0.300000012
+ %11 = OpConstant %6 0.400000006
+ %12 = OpConstant %6 0.5
+ %13 = OpConstant %6 1
+ %14 = OpConstantComposite %7 %10 %11 %12 %13
+ %15 = OpTypeInt 32 1
+ %18 = OpConstant %15 0
+ %25 = OpTypeStruct %6
+ %26 = OpTypePointer Uniform %25
+ %27 = OpVariable %26 Uniform
+ %28 = OpTypePointer Uniform %6
+ %32 = OpTypeBool
+ %103 = OpConstantTrue %32
+ %34 = OpConstant %6 0.100000001
+ %48 = OpConstant %15 1
+ %50 = OpTypePointer Output %7
+ %51 = OpVariable %50 Output
+ %100 = OpTypePointer Function %6
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %101 = OpVariable %100 Function
+ %102 = OpVariable %100 Function
+ OpBranch %19
+ %19 = OpLabel
+ %60 = OpPhi %7 %14 %5 %58 %20
+ %59 = OpPhi %15 %18 %5 %49 %20
+ %29 = OpAccessChain %28 %27 %18
+ %30 = OpLoad %6 %29
+ %31 = OpConvertFToS %15 %30
+ %33 = OpSLessThan %32 %59 %31
+ OpLoopMerge %21 %20 None
+ OpBranchConditional %33 %20 %21 1 2
+ %20 = OpLabel
+ %39 = OpCompositeExtract %6 %60 0
+ %40 = OpFAdd %6 %39 %34
+ %55 = OpCompositeInsert %7 %40 %60 0
+ %44 = OpCompositeExtract %6 %60 1
+ %45 = OpFSub %6 %44 %34
+ %58 = OpCompositeInsert %7 %45 %55 1
+ %49 = OpIAdd %15 %59 %48
+ OpBranch %19
+ %21 = OpLabel
+ OpStore %51 %60
+ OpSelectionMerge %105 None
+ OpBranchConditional %103 %104 %105
+ %104 = OpLabel
+ OpBranch %105
+ %105 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ FactManager fact_manager;
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(&fact_manager,
+ validator_options);
+
+ // Tests OpBranchConditional instruction with weigths.
+ auto instruction_descriptor =
+ MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
+ auto transformation =
+ TransformationAdjustBranchWeights(instruction_descriptor, {0, 1});
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+
+ // Tests the two branch weights equal to 0.
+ instruction_descriptor =
+ MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
+ transformation =
+ TransformationAdjustBranchWeights(instruction_descriptor, {0, 0});
+#ifndef NDEBUG
+ ASSERT_DEATH(
+ transformation.IsApplicable(context.get(), transformation_context),
+ "At least one weight must be non-zero");
+#endif
+
+ // Tests 32-bit unsigned integer overflow.
+ instruction_descriptor =
+ MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
+ transformation = TransformationAdjustBranchWeights(instruction_descriptor,
+ {UINT32_MAX, 0});
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+
+ instruction_descriptor =
+ MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
+ transformation = TransformationAdjustBranchWeights(instruction_descriptor,
+ {1, UINT32_MAX});
+#ifndef NDEBUG
+ ASSERT_DEATH(
+ transformation.IsApplicable(context.get(), transformation_context),
+ "The sum of the two weights must not be greater than UINT32_MAX");
+#endif
+
+ // Tests OpBranchConditional instruction with no weights.
+ instruction_descriptor =
+ MakeInstructionDescriptor(21, SpvOpBranchConditional, 0);
+ transformation =
+ TransformationAdjustBranchWeights(instruction_descriptor, {0, 1});
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+
+ // Tests non-OpBranchConditional instructions.
+ instruction_descriptor = MakeInstructionDescriptor(2, SpvOpTypeVoid, 0);
+ transformation =
+ TransformationAdjustBranchWeights(instruction_descriptor, {5, 6});
+ ASSERT_FALSE(
+ transformation.IsApplicable(context.get(), transformation_context));
+
+ instruction_descriptor = MakeInstructionDescriptor(20, SpvOpLabel, 0);
+ transformation =
+ TransformationAdjustBranchWeights(instruction_descriptor, {1, 2});
+ ASSERT_FALSE(
+ transformation.IsApplicable(context.get(), transformation_context));
+
+ instruction_descriptor = MakeInstructionDescriptor(49, SpvOpIAdd, 0);
+ transformation =
+ TransformationAdjustBranchWeights(instruction_descriptor, {1, 2});
+ ASSERT_FALSE(
+ transformation.IsApplicable(context.get(), transformation_context));
+}
+
+TEST(TransformationAdjustBranchWeightsTest, ApplyTest) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main" %51 %27
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %4 "main"
+ OpName %25 "buf"
+ OpMemberName %25 0 "value"
+ OpName %27 ""
+ OpName %51 "color"
+ OpMemberDecorate %25 0 Offset 0
+ OpDecorate %25 Block
+ OpDecorate %27 DescriptorSet 0
+ OpDecorate %27 Binding 0
+ OpDecorate %51 Location 0
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeFloat 32
+ %7 = OpTypeVector %6 4
+ %150 = OpTypeVector %6 2
+ %10 = OpConstant %6 0.300000012
+ %11 = OpConstant %6 0.400000006
+ %12 = OpConstant %6 0.5
+ %13 = OpConstant %6 1
+ %14 = OpConstantComposite %7 %10 %11 %12 %13
+ %15 = OpTypeInt 32 1
+ %18 = OpConstant %15 0
+ %25 = OpTypeStruct %6
+ %26 = OpTypePointer Uniform %25
+ %27 = OpVariable %26 Uniform
+ %28 = OpTypePointer Uniform %6
+ %32 = OpTypeBool
+ %103 = OpConstantTrue %32
+ %34 = OpConstant %6 0.100000001
+ %48 = OpConstant %15 1
+ %50 = OpTypePointer Output %7
+ %51 = OpVariable %50 Output
+ %100 = OpTypePointer Function %6
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %101 = OpVariable %100 Function
+ %102 = OpVariable %100 Function
+ OpBranch %19
+ %19 = OpLabel
+ %60 = OpPhi %7 %14 %5 %58 %20
+ %59 = OpPhi %15 %18 %5 %49 %20
+ %29 = OpAccessChain %28 %27 %18
+ %30 = OpLoad %6 %29
+ %31 = OpConvertFToS %15 %30
+ %33 = OpSLessThan %32 %59 %31
+ OpLoopMerge %21 %20 None
+ OpBranchConditional %33 %20 %21 1 2
+ %20 = OpLabel
+ %39 = OpCompositeExtract %6 %60 0
+ %40 = OpFAdd %6 %39 %34
+ %55 = OpCompositeInsert %7 %40 %60 0
+ %44 = OpCompositeExtract %6 %60 1
+ %45 = OpFSub %6 %44 %34
+ %58 = OpCompositeInsert %7 %45 %55 1
+ %49 = OpIAdd %15 %59 %48
+ OpBranch %19
+ %21 = OpLabel
+ OpStore %51 %60
+ OpSelectionMerge %105 None
+ OpBranchConditional %103 %104 %105
+ %104 = OpLabel
+ OpBranch %105
+ %105 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ ASSERT_TRUE(IsValid(env, context.get()));
+
+ FactManager fact_manager;
+ spvtools::ValidatorOptions validator_options;
+ TransformationContext transformation_context(&fact_manager,
+ validator_options);
+
+ auto instruction_descriptor =
+ MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
+ auto transformation =
+ TransformationAdjustBranchWeights(instruction_descriptor, {5, 6});
+ transformation.Apply(context.get(), &transformation_context);
+
+ instruction_descriptor =
+ MakeInstructionDescriptor(21, SpvOpBranchConditional, 0);
+ transformation =
+ TransformationAdjustBranchWeights(instruction_descriptor, {7, 8});
+ transformation.Apply(context.get(), &transformation_context);
+
+ std::string variant_shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main" %51 %27
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 310
+ OpName %4 "main"
+ OpName %25 "buf"
+ OpMemberName %25 0 "value"
+ OpName %27 ""
+ OpName %51 "color"
+ OpMemberDecorate %25 0 Offset 0
+ OpDecorate %25 Block
+ OpDecorate %27 DescriptorSet 0
+ OpDecorate %27 Binding 0
+ OpDecorate %51 Location 0
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeFloat 32
+ %7 = OpTypeVector %6 4
+ %150 = OpTypeVector %6 2
+ %10 = OpConstant %6 0.300000012
+ %11 = OpConstant %6 0.400000006
+ %12 = OpConstant %6 0.5
+ %13 = OpConstant %6 1
+ %14 = OpConstantComposite %7 %10 %11 %12 %13
+ %15 = OpTypeInt 32 1
+ %18 = OpConstant %15 0
+ %25 = OpTypeStruct %6
+ %26 = OpTypePointer Uniform %25
+ %27 = OpVariable %26 Uniform
+ %28 = OpTypePointer Uniform %6
+ %32 = OpTypeBool
+ %103 = OpConstantTrue %32
+ %34 = OpConstant %6 0.100000001
+ %48 = OpConstant %15 1
+ %50 = OpTypePointer Output %7
+ %51 = OpVariable %50 Output
+ %100 = OpTypePointer Function %6
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %101 = OpVariable %100 Function
+ %102 = OpVariable %100 Function
+ OpBranch %19
+ %19 = OpLabel
+ %60 = OpPhi %7 %14 %5 %58 %20
+ %59 = OpPhi %15 %18 %5 %49 %20
+ %29 = OpAccessChain %28 %27 %18
+ %30 = OpLoad %6 %29
+ %31 = OpConvertFToS %15 %30
+ %33 = OpSLessThan %32 %59 %31
+ OpLoopMerge %21 %20 None
+ OpBranchConditional %33 %20 %21 5 6
+ %20 = OpLabel
+ %39 = OpCompositeExtract %6 %60 0
+ %40 = OpFAdd %6 %39 %34
+ %55 = OpCompositeInsert %7 %40 %60 0
+ %44 = OpCompositeExtract %6 %60 1
+ %45 = OpFSub %6 %44 %34
+ %58 = OpCompositeInsert %7 %45 %55 1
+ %49 = OpIAdd %15 %59 %48
+ OpBranch %19
+ %21 = OpLabel
+ OpStore %51 %60
+ OpSelectionMerge %105 None
+ OpBranchConditional %103 %104 %105 7 8
+ %104 = OpLabel
+ OpBranch %105
+ %105 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
+}
+
+} // namespace
+} // namespace fuzz
+} // namespace spvtools