Instrument: Be sure Float16 capability on when generating float16 null (#2831)
diff --git a/source/opt/inst_bindless_check_pass.cpp b/source/opt/inst_bindless_check_pass.cpp
index b283354..4587343 100644
--- a/source/opt/inst_bindless_check_pass.cpp
+++ b/source/opt/inst_bindless_check_pass.cpp
@@ -296,7 +296,7 @@
// reference.
if (new_ref_id != 0) {
Instruction* phi_inst = builder.AddPhi(
- ref_type_id, {new_ref_id, valid_blk_id, builder.GetNullId(ref_type_id),
+ ref_type_id, {new_ref_id, valid_blk_id, GetNullId(ref_type_id),
last_invalid_blk_id});
context()->ReplaceAllUsesWith(ref->ref_inst->result_id(),
phi_inst->result_id());
diff --git a/source/opt/inst_buff_addr_check_pass.cpp b/source/opt/inst_buff_addr_check_pass.cpp
index 03221ef..ef29ce5 100644
--- a/source/opt/inst_buff_addr_check_pass.cpp
+++ b/source/opt/inst_buff_addr_check_pass.cpp
@@ -108,8 +108,8 @@
// reference.
if (new_ref_id != 0) {
Instruction* phi_inst = builder.AddPhi(
- ref_type_id, {new_ref_id, valid_blk_id, builder.GetNullId(ref_type_id),
- invalid_blk_id});
+ ref_type_id,
+ {new_ref_id, valid_blk_id, GetNullId(ref_type_id), invalid_blk_id});
context()->ReplaceAllUsesWith(ref_inst->result_id(), phi_inst->result_id());
}
new_blocks->push_back(std::move(new_blk_ptr));
diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h
index d3875c4..b9cb26a 100644
--- a/source/opt/ir_builder.h
+++ b/source/opt/ir_builder.h
@@ -358,16 +358,6 @@
return uint_inst->result_id();
}
- uint32_t GetNullId(uint32_t type_id) {
- analysis::TypeManager* type_mgr = GetContext()->get_type_mgr();
- analysis::ConstantManager* const_mgr = GetContext()->get_constant_mgr();
- const analysis::Type* type = type_mgr->GetType(type_id);
- const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
- Instruction* null_inst =
- const_mgr->GetDefiningInstruction(null_const, type_id);
- return null_inst->result_id();
- }
-
// Adds either a signed or unsigned 32 bit integer constant to the binary
// depedning on the |sign|. If |sign| is true then the value is added as a
// signed constant otherwise as an unsigned constant. If |sign| is false the
diff --git a/source/opt/pass.cpp b/source/opt/pass.cpp
index 72d7cea..09b78af 100644
--- a/source/opt/pass.cpp
+++ b/source/opt/pass.cpp
@@ -73,6 +73,17 @@
return ty_inst->GetSingleWordInOperand(0) == width;
}
+uint32_t Pass::GetNullId(uint32_t type_id) {
+ if (IsFloat(type_id, 16)) context()->AddCapability(SpvCapabilityFloat16);
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+ analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
+ const analysis::Type* type = type_mgr->GetType(type_id);
+ const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
+ Instruction* null_inst =
+ const_mgr->GetDefiningInstruction(null_const, type_id);
+ return null_inst->result_id();
+}
+
uint32_t Pass::GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id,
Instruction* insertion_position) {
analysis::TypeManager* type_mgr = context()->get_type_mgr();
diff --git a/source/opt/pass.h b/source/opt/pass.h
index 356e94d..a8c9c4b 100644
--- a/source/opt/pass.h
+++ b/source/opt/pass.h
@@ -116,6 +116,9 @@
// float and |width|
bool IsFloat(uint32_t ty_id, uint32_t width);
+ // Return the id of OpConstantNull of type |type_id|. Create if necessary.
+ uint32_t GetNullId(uint32_t type_id);
+
protected:
// Constructs a new pass.
//