Handle overflow in wrap-opkill (#2801)
Fixes https://crbug/994203
diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h
index 5514e2d..f12dc95 100644
--- a/source/opt/ir_builder.h
+++ b/source/opt/ir_builder.h
@@ -473,9 +473,12 @@
operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
}
- std::unique_ptr<Instruction> new_inst(
- new Instruction(GetContext(), SpvOpFunctionCall, result_type,
- GetContext()->TakeNextId(), operands));
+ uint32_t result_id = GetContext()->TakeNextId();
+ if (result_id == 0) {
+ return nullptr;
+ }
+ std::unique_ptr<Instruction> new_inst(new Instruction(
+ GetContext(), SpvOpFunctionCall, result_type, result_id, operands));
return AddInstruction(std::move(new_inst));
}
diff --git a/source/opt/wrap_opkill.cpp b/source/opt/wrap_opkill.cpp
index 3504efe..d10cdd2 100644
--- a/source/opt/wrap_opkill.cpp
+++ b/source/opt/wrap_opkill.cpp
@@ -23,12 +23,19 @@
bool modified = false;
for (auto& func : *get_module()) {
- func.ForEachInst([this, &modified](Instruction* inst) {
+ bool successful = func.WhileEachInst([this, &modified](Instruction* inst) {
if (inst->opcode() == SpvOpKill) {
modified = true;
- ReplaceWithFunctionCall(inst);
+ if (!ReplaceWithFunctionCall(inst)) {
+ return false;
+ }
}
+ return true;
});
+
+ if (!successful) {
+ return Status::Failure;
+ }
}
if (opkill_function_ != nullptr) {
@@ -39,15 +46,22 @@
return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}
-void WrapOpKill::ReplaceWithFunctionCall(Instruction* inst) {
+bool WrapOpKill::ReplaceWithFunctionCall(Instruction* inst) {
assert(inst->opcode() == SpvOpKill &&
"|inst| must be an OpKill instruction.");
InstructionBuilder ir_builder(
context(), inst,
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- ir_builder.AddFunctionCall(GetVoidTypeId(), GetOpKillFuncId(), {});
+ uint32_t func_id = GetOpKillFuncId();
+ if (func_id == 0) {
+ return false;
+ }
+ if (ir_builder.AddFunctionCall(GetVoidTypeId(), func_id, {}) == nullptr) {
+ return false;
+ }
ir_builder.AddUnreachable();
context()->KillInst(inst);
+ return true;
}
uint32_t WrapOpKill::GetVoidTypeId() {
@@ -77,6 +91,9 @@
}
uint32_t opkill_func_id = TakeNextId();
+ if (opkill_func_id == 0) {
+ return 0;
+ }
// Generate the function start instruction
std::unique_ptr<Instruction> func_start(new Instruction(
@@ -91,8 +108,12 @@
opkill_function_->SetFunctionEnd(std::move(func_end));
// Create the one basic block for the function.
+ uint32_t lab_id = TakeNextId();
+ if (lab_id == 0) {
+ return 0;
+ }
std::unique_ptr<Instruction> label_inst(
- new Instruction(context(), SpvOpLabel, 0, TakeNextId(), {}));
+ new Instruction(context(), SpvOpLabel, 0, lab_id, {}));
std::unique_ptr<BasicBlock> bb(new BasicBlock(std::move(label_inst)));
// Add the OpKill to the basic block
diff --git a/source/opt/wrap_opkill.h b/source/opt/wrap_opkill.h
index 6f4699d..8b03281 100644
--- a/source/opt/wrap_opkill.h
+++ b/source/opt/wrap_opkill.h
@@ -41,8 +41,9 @@
private:
// Replaces the OpKill instruction |inst| with a function call to a function
// that contains a single instruction, which is OpKill. An OpUnreachable
- // instruction will be placed after the function call.
- void ReplaceWithFunctionCall(Instruction* inst);
+ // instruction will be placed after the function call. Return true if
+ // successful.
+ bool ReplaceWithFunctionCall(Instruction* inst);
// Returns the id of the void type.
uint32_t GetVoidTypeId();
@@ -50,8 +51,9 @@
// Returns the id of the function type for a void function with no parameters.
uint32_t GetVoidFunctionTypeId();
- // Return the id of a function that has return type void, no no parameters,
- // and contains a single instruction, which is an OpKill.
+ // Return the id of a function that has return type void, has no parameters,
+ // and contains a single instruction, which is an OpKill. Returns 0 if the
+ // function could not be generated.
uint32_t GetOpKillFuncId();
// The id of the void type. If its value is 0, then the void type has not
diff --git a/test/opt/wrap_opkill_test.cpp b/test/opt/wrap_opkill_test.cpp
index a0314ad..df1b865 100644
--- a/test/opt/wrap_opkill_test.cpp
+++ b/test/opt/wrap_opkill_test.cpp
@@ -193,6 +193,75 @@
SinglePassRunAndMatch<WrapOpKill>(text, true);
}
+TEST_F(WrapOpKillTest, IdBoundOverflow1) {
+ const std::string text = R"(
+OpCapability GeometryStreams
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpFunction %2 Pure|Const %3
+%4194302 = OpLabel
+OpKill
+OpFunctionEnd
+ )";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+ std::vector<Message> messages = {
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}};
+ SetMessageConsumer(GetTestMessageConsumer(messages));
+ auto result = SinglePassRunToBinary<WrapOpKill>(text, true);
+ EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
+}
+
+TEST_F(WrapOpKillTest, IdBoundOverflow2) {
+ const std::string text = R"(
+OpCapability GeometryStreams
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpFunction %2 Pure|Const %3
+%4194301 = OpLabel
+OpKill
+OpFunctionEnd
+ )";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+ std::vector<Message> messages = {
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}};
+ SetMessageConsumer(GetTestMessageConsumer(messages));
+ auto result = SinglePassRunToBinary<WrapOpKill>(text, true);
+ EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
+}
+
+TEST_F(WrapOpKillTest, IdBoundOverflow3) {
+ const std::string text = R"(
+OpCapability GeometryStreams
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpFunction %2 Pure|Const %3
+%4194300 = OpLabel
+OpKill
+OpFunctionEnd
+ )";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+ std::vector<Message> messages = {
+ {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}};
+ SetMessageConsumer(GetTestMessageConsumer(messages));
+ auto result = SinglePassRunToBinary<WrapOpKill>(text, true);
+ EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
+}
+
} // namespace
} // namespace opt
} // namespace spvtools