Handle overflowing id in merge return
If the ids overflow while the ir_builder is creating an integer constant, GetUintConstantId dereferences a null pointer. This is happening from inside merge return.
Propagate the error up to the callers, and make sure it is handled appropriately so that the pass fails cleanly instead of crashing.
diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h
index fe5feff..4433cf0 100644
--- a/source/opt/ir_builder.h
+++ b/source/opt/ir_builder.h
@@ -359,8 +359,9 @@
return AddInstruction(std::move(select));
}
- // Adds a signed int32 constant to the binary.
- // The |value| parameter is the constant value to be added.
+ // Returns a pointer to the definition of a signed 32-bit integer constant
+ // with the given value. Returns |nullptr| if the constant does not exist and
+ // cannot be created.
Instruction* GetSintConstant(int32_t value) {
return GetIntConstant<int32_t>(value, true);
}
@@ -381,21 +382,24 @@
GetContext()->TakeNextId(), ops));
return AddInstruction(std::move(construct));
}
- // Adds an unsigned int32 constant to the binary.
- // The |value| parameter is the constant value to be added.
+
+ // Returns a pointer to the definition of an unsigned 32-bit integer constant
+ // with the given value. Returns |nullptr| if the constant does not exist and
+ // cannot be created.
Instruction* GetUintConstant(uint32_t value) {
return GetIntConstant<uint32_t>(value, false);
}
uint32_t GetUintConstantId(uint32_t value) {
Instruction* uint_inst = GetUintConstant(value);
- return uint_inst->result_id();
+ return (uint_inst != nullptr ? uint_inst->result_id() : 0);
}
// Adds either a signed or unsigned 32 bit integer constant to the binary
- // depedning on the |sign|. If |sign| is true then the value is added as a
+ // depending on the |sign|. If |sign| is true then the value is added as a
// signed constant otherwise as an unsigned constant. If |sign| is false the
- // value must not be a negative number.
+ // value must not be a negative number. Returns |nullptr| if the constant
+ // does not exist and cannot be created.
template <typename T>
Instruction* GetIntConstant(T value, bool sign) {
// Assert that we are not trying to store a negative number in an unsigned
@@ -411,6 +415,10 @@
uint32_t type_id =
GetContext()->get_type_mgr()->GetTypeInstruction(&int_type);
+ if (type_id == 0) {
+ return nullptr;
+ }
+
// Get the memory managed type so that it is safe to be stored by
// GetConstant.
analysis::Type* rebuilt_type =
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index f160104..a962a7c 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -111,7 +111,9 @@
}
RecordImmediateDominators(function);
- AddSingleCaseSwitchAroundFunction();
+ if (!AddSingleCaseSwitchAroundFunction()) {
+ return false;
+ }
std::list<BasicBlock*> order;
cfg()->ComputeStructuredOrder(function, &*function->begin(), &order);
@@ -770,7 +772,7 @@
list->insert(pos, new_element);
}
-void MergeReturnPass::AddSingleCaseSwitchAroundFunction() {
+bool MergeReturnPass::AddSingleCaseSwitchAroundFunction() {
CreateReturnBlock();
CreateReturn(final_return_block_);
@@ -778,7 +780,10 @@
cfg()->RegisterBlock(final_return_block_);
}
- CreateSingleCaseSwitch(final_return_block_);
+ if (!CreateSingleCaseSwitch(final_return_block_)) {
+ return false;
+ }
+ return true;
}
BasicBlock* MergeReturnPass::CreateContinueTarget(uint32_t header_label_id) {
@@ -813,7 +818,7 @@
return new_block;
}
-void MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
+bool MergeReturnPass::CreateSingleCaseSwitch(BasicBlock* merge_target) {
// Insert the switch before any code is run. We have to split the entry
// block to make sure the OpVariable instructions remain in the entry block.
BasicBlock* start_block = &*function_->begin();
@@ -830,13 +835,17 @@
context(), start_block,
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- builder.AddSwitch(builder.GetUintConstantId(0u), old_block->id(), {},
- merge_target->id());
+ uint32_t const_zero_id = builder.GetUintConstantId(0u);
+ if (const_zero_id == 0) {
+ return false;
+ }
+ builder.AddSwitch(const_zero_id, old_block->id(), {}, merge_target->id());
if (context()->AreAnalysesValid(IRContext::kAnalysisCFG)) {
cfg()->RegisterBlock(old_block);
cfg()->AddEdges(start_block);
}
+ return true;
}
bool MergeReturnPass::HasNontrivialUnreachableBlocks(Function* function) {
diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h
index 06a3e7b..4096ce7 100644
--- a/source/opt/merge_return_pass.h
+++ b/source/opt/merge_return_pass.h
@@ -277,7 +277,7 @@
// current function where the switch and case value are both zero and the
// default is the merge block. Returns after the switch is executed. Sets
// |final_return_block_|.
- void AddSingleCaseSwitchAroundFunction();
+ bool AddSingleCaseSwitchAroundFunction();
// Creates a new basic block that branches to |header_label_id|. Returns the
// new basic block. The block will be the second last basic block in the
@@ -286,7 +286,7 @@
// Creates a one case switch around the executable code of the function with
// |merge_target| as the merge node.
- void CreateSingleCaseSwitch(BasicBlock* merge_target);
+ bool CreateSingleCaseSwitch(BasicBlock* merge_target);
// Returns true if |function| has an unreachable block that is not a continue
// target that simply branches back to the header, or a merge block containing
diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp
index fd97efa..21960d1 100644
--- a/test/opt/pass_merge_return_test.cpp
+++ b/test/opt/pass_merge_return_test.cpp
@@ -2567,6 +2567,39 @@
SinglePassRunAndMatch<MergeReturnPass>(before, true);
}
+TEST_F(MergeReturnPassTest, OverflowTest1) {
+ const std::string text =
+ R"(
+; CHECK: OpReturn
+; CHECK-NOT: OpReturn
+; CHECK: OpFunctionEnd
+ OpCapability ClipDistance
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main"
+ OpExecutionMode %2 OriginUpperLeft
+ %void = OpTypeVoid
+ %6 = OpTypeFunction %void
+ %2 = OpFunction %void None %6
+ %4194303 = OpLabel
+ OpBranch %18
+ %18 = OpLabel
+ OpLoopMerge %19 %20 None
+ OpBranch %21
+ %21 = OpLabel
+ OpReturn
+ %20 = OpLabel
+ OpBranch %18
+ %19 = OpLabel
+ OpUnreachable
+ OpFunctionEnd
+)";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ auto result =
+ SinglePassRunToBinary<MergeReturnPass>(text, /* skip_nop = */ true);
+ EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
+}
+
} // namespace
} // namespace opt
} // namespace spvtools