Optimizer: Handle array type with OpSpecConstantOp length (#2652)
When it's an OpConstant or OpSpecConstant, then the literal
values are compared. If the OpSpecConstant also has a SpecId
decoration, then that's also compared.
Otherwise, it's an OpSpecConstantOp and we only compare the
ID of the OpSpecConstantOp instruction itself.
Fixes #2649
diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp
index 113dc80..1c27b16 100644
--- a/source/opt/type_manager.cpp
+++ b/source/opt/type_manager.cpp
@@ -510,16 +510,8 @@
}
case Type::kArray: {
const Array* array_ty = type.AsArray();
- const Type* ele_ty = array_ty->element_type();
- if (array_ty->length_spec_id() != 0u)
- rebuilt_ty =
- MakeUnique<Array>(RebuildType(*ele_ty), array_ty->LengthId(),
- array_ty->length_spec_id());
- else
- rebuilt_ty =
- MakeUnique<Array>(RebuildType(*ele_ty), array_ty->LengthId(),
- array_ty->length_constant_type(),
- array_ty->length_constant_words());
+ rebuilt_ty =
+ MakeUnique<Array>(array_ty->element_type(), array_ty->length_info());
break;
}
case Type::kRuntimeArray: {
@@ -654,28 +646,45 @@
const Instruction* length_constant_inst = id_to_constant_inst_[length_id];
assert(length_constant_inst);
- // If it is a specialised constants, retrieve its SpecId.
+ // How will we distinguish one length value from another?
+ // Determine extra words required to distinguish this array length
+ // from another.
+ std::vector<uint32_t> extra_words{Array::LengthInfo::kDefiningId};
+ // If it is a specialised constant, retrieve its SpecId.
+ // Only OpSpecConstant has a SpecId.
uint32_t spec_id = 0u;
- Type* length_type = nullptr;
- Operand::OperandData length_words;
- if (spvOpcodeIsSpecConstant(length_constant_inst->opcode())) {
+ bool has_spec_id = false;
+ if (length_constant_inst->opcode() == SpvOpSpecConstant) {
context()->get_decoration_mgr()->ForEachDecoration(
length_id, SpvDecorationSpecId,
- [&spec_id](const Instruction& decoration) {
+ [&spec_id, &has_spec_id](const Instruction& decoration) {
assert(decoration.opcode() == SpvOpDecorate);
spec_id = decoration.GetSingleWordOperand(2u);
+ has_spec_id = true;
});
- } else {
- length_type = GetType(length_constant_inst->type_id());
- length_words = length_constant_inst->GetOperand(2u).words;
}
+ const auto opcode = length_constant_inst->opcode();
+ if (has_spec_id) {
+ extra_words.push_back(spec_id);
+ }
+ if ((opcode == SpvOpConstant) || (opcode == SpvOpSpecConstant)) {
+ // Always include the literal constant words. In the spec constant
+ // case, the constant might not be overridden, so it's still
+ // significant.
+ extra_words.insert(extra_words.end(),
+ length_constant_inst->GetOperand(2).words.begin(),
+ length_constant_inst->GetOperand(2).words.end());
+ extra_words[0] = has_spec_id ? Array::LengthInfo::kConstantWithSpecId
+ : Array::LengthInfo::kConstant;
+ } else {
+ assert(extra_words[0] == Array::LengthInfo::kDefiningId);
+ extra_words.push_back(length_id);
+ }
+ assert(extra_words.size() >= 2);
+ Array::LengthInfo length_info{length_id, extra_words};
- if (spec_id != 0u)
- type = new Array(GetType(inst.GetSingleWordInOperand(0)), length_id,
- spec_id);
- else
- type = new Array(GetType(inst.GetSingleWordInOperand(0)), length_id,
- length_type, length_words);
+ type = new Array(GetType(inst.GetSingleWordInOperand(0)), length_info);
+
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {
incomplete_types_.emplace_back(inst.result_id(), type);
id_to_incomplete_type_[inst.result_id()] = type;
diff --git a/source/opt/types.cpp b/source/opt/types.cpp
index af747cf..e345b2d 100644
--- a/source/opt/types.cpp
+++ b/source/opt/types.cpp
@@ -12,14 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "source/opt/types.h"
+
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <sstream>
+#include <string>
#include <unordered_set>
-#include "source/opt/types.h"
#include "source/util/make_unique.h"
+#include "spirv/unified1/spirv.h"
namespace spvtools {
namespace opt {
@@ -383,64 +386,42 @@
image_type_->GetHashWords(words, seen);
}
-Array::Array(Type* type, uint32_t length_id, uint32_t spec_id)
- : Type(kArray),
- element_type_(type),
- length_id_(length_id),
- length_spec_id_(spec_id),
- length_constant_type_(nullptr),
- length_constant_words_() {
+Array::Array(const Type* type, const Array::LengthInfo& length_info_arg)
+ : Type(kArray), element_type_(type), length_info_(length_info_arg) {
+ assert(type != nullptr);
assert(!type->AsVoid());
- assert(spec_id != 0u);
-}
-
-Array::Array(Type* type, uint32_t length_id, const Type* constant_type,
- Operand::OperandData constant_words)
- : Type(kArray),
- element_type_(type),
- length_id_(length_id),
- length_spec_id_(0u),
- length_constant_type_(constant_type),
- length_constant_words_(constant_words) {
- assert(!type->AsVoid());
- assert(constant_type && constant_type->AsInteger());
+ // We always have a word to say which case we're in, followed
+ // by at least one more word.
+ assert(length_info_arg.words.size() >= 2);
}
bool Array::IsSameImpl(const Type* that, IsSameCache* seen) const {
const Array* at = that->AsArray();
if (!at) return false;
- bool is_same = element_type_->IsSameImpl(at->element_type_, seen) &&
- HasSameDecorations(that);
- // If it is a specialized constant
- if (length_spec_id_ != 0u) {
- // ensure they have the same SpecId
- is_same = is_same && length_spec_id_ == at->length_spec_id_;
- } else {
- // else, ensure they have the same length literal number.
- is_same =
- is_same &&
- length_constant_type_->IsSameImpl(at->length_constant_type_, seen) &&
- length_constant_words_ == at->length_constant_words_;
- }
+ bool is_same = element_type_->IsSameImpl(at->element_type_, seen);
+ is_same = is_same && HasSameDecorations(that);
+ is_same = is_same && (length_info_.words == at->length_info_.words);
return is_same;
}
std::string Array::str() const {
std::ostringstream oss;
- oss << "[" << element_type_->str() << ", id(" << length_id_ << ")]";
+ oss << "[" << element_type_->str() << ", id(" << LengthId() << "), words(";
+ const char* spacer = "";
+ for (auto w : length_info_.words) {
+ oss << spacer << w;
+ spacer = ",";
+ }
+ oss << ")]";
return oss.str();
}
void Array::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
element_type_->GetHashWords(words, seen);
- if (length_spec_id_ != 0u) {
- words->push_back(length_spec_id_);
- } else {
- length_constant_type_->GetHashWords(words, seen);
- words->insert(words->end(), length_constant_words_.begin(),
- length_constant_words_.end());
- }
+ // This should mirror the logic in IsSameImpl
+ words->insert(words->end(), length_info_.words.begin(),
+ length_info_.words.end());
}
void Array::ReplaceElementType(const Type* type) { element_type_ = type; }
@@ -575,7 +556,12 @@
return HasSameDecorations(that);
}
-std::string Pointer::str() const { return pointee_type_->str() + "*"; }
+std::string Pointer::str() const {
+ std::ostringstream os;
+ os << pointee_type_->str() << " " << static_cast<uint32_t>(storage_class_)
+ << "*";
+ return os.str();
+}
void Pointer::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
diff --git a/source/opt/types.h b/source/opt/types.h
index 381cab6..e9dcc70 100644
--- a/source/opt/types.h
+++ b/source/opt/types.h
@@ -357,19 +357,36 @@
class Array : public Type {
public:
- Array(Type* element_type, uint32_t length_id, uint32_t spec_id);
- Array(Type* element_type, uint32_t length_id, const Type* constant_type,
- Operand::OperandData constant_words);
+ // Data about the length operand, that helps us distinguish between one
+ // array length and another.
+ struct LengthInfo {
+ // The result id of the instruction defining the length.
+ const uint32_t id;
+ enum Case : uint32_t {
+ kConstant = 0,
+ kConstantWithSpecId = 1,
+ kDefiningId = 2
+ };
+ // Extra words used to distinshish one array length and another.
+ // - if OpConstant, then it's 0, then the words in the literal constant
+ // value.
+ // - if OpSpecConstant, then it's 1, then the SpecID decoration if there
+ // is one, followed by the words in the literal constant value.
+ // The spec might not be overridden, in which case we'll end up using
+ // the literal value.
+ // - Otherwise, it's an OpSpecConsant, and this 2, then the ID (again).
+ const std::vector<uint32_t> words;
+ };
+
+ // Constructs an array type with given element and length. If the length
+ // is an OpSpecConstant, then |spec_id| should be its SpecId decoration.
+ Array(const Type* element_type, const LengthInfo& length_info_arg);
Array(const Array&) = default;
std::string str() const override;
const Type* element_type() const { return element_type_; }
- uint32_t LengthId() const { return length_id_; }
- uint32_t length_spec_id() const { return length_spec_id_; }
- const Type* length_constant_type() const { return length_constant_type_; }
- Operand::OperandData length_constant_words() const {
- return length_constant_words_;
- }
+ uint32_t LengthId() const { return length_info_.id; }
+ const LengthInfo& length_info() const { return length_info_; }
Array* AsArray() override { return this; }
const Array* AsArray() const override { return this; }
@@ -383,10 +400,7 @@
bool IsSameImpl(const Type* that, IsSameCache*) const override;
const Type* element_type_;
- uint32_t length_id_;
- uint32_t length_spec_id_;
- const Type* length_constant_type_;
- Operand::OperandData length_constant_words_;
+ const LengthInfo length_info_;
};
class RuntimeArray : public Type {
diff --git a/test/opt/type_manager_test.cpp b/test/opt/type_manager_test.cpp
index c5ff9e6..267d98c 100644
--- a/test/opt/type_manager_test.cpp
+++ b/test/opt/type_manager_test.cpp
@@ -117,10 +117,10 @@
types.emplace_back(new SampledImage(image2));
// Array
- types.emplace_back(new Array(f32, 100, 1u));
- types.emplace_back(new Array(f32, 42, 2u));
+ types.emplace_back(new Array(f32, Array::LengthInfo{100, {0, 100u}}));
+ types.emplace_back(new Array(f32, Array::LengthInfo{42, {0, 42u}}));
auto* a42f32 = types.back().get();
- types.emplace_back(new Array(u64, 24, s32, {42}));
+ types.emplace_back(new Array(u64, Array::LengthInfo{24, {0, 24u}}));
// RuntimeArray
types.emplace_back(new RuntimeArray(v3f32));
@@ -171,7 +171,8 @@
TEST(TypeManager, TypeStrings) {
const std::string text = R"(
- OpTypeForwardPointer !20 !2 ; id for %p is 20, Uniform is 2
+ OpDecorate %spec_const_with_id SpecId 99
+ OpTypeForwardPointer %p Uniform
%void = OpTypeVoid
%bool = OpTypeBool
%u32 = OpTypeInt 32 0
@@ -201,48 +202,68 @@
%ps = OpTypePipeStorage
%nb = OpTypeNamedBarrier
%rtacc = OpTypeAccelerationStructureNV
+ ; Set up other kinds of OpTypeArray
+ %s64 = OpTypeInt 64 1
+ ; ID 32
+ %spec_const_without_id = OpSpecConstant %s32 44
+ %spec_const_with_id = OpSpecConstant %s32 42 ;; This is ID 1
+ %long_constant = OpConstant %s64 5000000000
+ %spec_const_op = OpSpecConstantOp %s32 IAdd %id4 %id4
+ ; ID 35
+ %arr_spec_const_without_id = OpTypeArray %s32 %spec_const_without_id
+ %arr_spec_const_with_id = OpTypeArray %s32 %spec_const_with_id
+ %arr_long_constant = OpTypeArray %s32 %long_constant
+ %arr_spec_const_op = OpTypeArray %s32 %spec_const_op
)";
std::vector<std::pair<uint32_t, std::string>> type_id_strs = {
- {1, "void"},
- {2, "bool"},
- {3, "uint32"},
- // Id 4 is used by the constant.
- {5, "sint32"},
- {6, "float64"},
- {7, "<uint32, 3>"},
- {8, "<<uint32, 3>, 3>"},
- {9, "image(sint32, 3, 0, 1, 1, 0, 3, 2)"},
- {10, "image(sint32, 3, 0, 1, 1, 0, 3, 0)"},
- {11, "sampler"},
- {12, "sampled_image(image(sint32, 3, 0, 1, 1, 0, 3, 2))"},
- {13, "sampled_image(image(sint32, 3, 0, 1, 1, 0, 3, 0))"},
- {14, "[uint32, id(4)]"},
- {15, "[float64]"},
- {16, "{uint32}"},
- {17, "{float64, sint32, <uint32, 3>}"},
- {18, "opaque('')"},
- {19, "opaque('opaque')"},
- {20, "{uint32}*"},
- {21, "(uint32, uint32) -> void"},
- {22, "event"},
- {23, "device_event"},
- {24, "reserve_id"},
- {25, "queue"},
- {26, "pipe(0)"},
- {27, "pipe_storage"},
- {28, "named_barrier"},
- {29, "accelerationStructureNV"},
+ {3, "void"},
+ {4, "bool"},
+ {5, "uint32"},
+ // Id 6 is used by the constant.
+ {7, "sint32"},
+ {8, "float64"},
+ {9, "<uint32, 3>"},
+ {10, "<<uint32, 3>, 3>"},
+ {11, "image(sint32, 3, 0, 1, 1, 0, 3, 2)"},
+ {12, "image(sint32, 3, 0, 1, 1, 0, 3, 0)"},
+ {13, "sampler"},
+ {14, "sampled_image(image(sint32, 3, 0, 1, 1, 0, 3, 2))"},
+ {15, "sampled_image(image(sint32, 3, 0, 1, 1, 0, 3, 0))"},
+ {16, "[uint32, id(6), words(0,4)]"},
+ {17, "[float64]"},
+ {18, "{uint32}"},
+ {19, "{float64, sint32, <uint32, 3>}"},
+ {20, "opaque('')"},
+ {21, "opaque('opaque')"},
+ {2, "{uint32} 2*"}, // Include storage class number
+ {22, "(uint32, uint32) -> void"},
+ {23, "event"},
+ {24, "device_event"},
+ {25, "reserve_id"},
+ {26, "queue"},
+ {27, "pipe(0)"},
+ {28, "pipe_storage"},
+ {29, "named_barrier"},
+ {30, "accelerationStructureNV"},
+ {31, "sint64"},
+ {35, "[sint32, id(32), words(0,44)]"},
+ {36, "[sint32, id(1), words(1,99,42)]"},
+ {37, "[sint32, id(33), words(0,705032704,1)]"},
+ {38, "[sint32, id(34), words(2,34)]"},
};
std::unique_ptr<IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
+ ASSERT_NE(nullptr, context.get()); // It assembled
TypeManager manager(nullptr, context.get());
EXPECT_EQ(type_id_strs.size(), manager.NumTypes());
for (const auto& p : type_id_strs) {
- EXPECT_EQ(p.second, manager.GetType(p.first)->str());
+ ASSERT_NE(nullptr, manager.GetType(p.first));
+ EXPECT_EQ(p.second, manager.GetType(p.first)->str())
+ << " id is " << p.first;
EXPECT_EQ(p.first, manager.GetId(manager.GetType(p.first)));
}
}
diff --git a/test/opt/types_test.cpp b/test/opt/types_test.cpp
index 2c0d8db..fd98806 100644
--- a/test/opt/types_test.cpp
+++ b/test/opt/types_test.cpp
@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "source/opt/types.h"
+
#include <memory>
#include <utility>
#include <vector>
#include "gtest/gtest.h"
-#include "source/opt/types.h"
#include "source/util/make_unique.h"
namespace spvtools {
@@ -46,8 +47,8 @@
std::unique_ptr<Type> image_t_;
};
-#define TestMultipleInstancesOfTheSameType(ty, ...) \
- TEST_F(SameTypeTest, MultiSame##ty) { \
+#define TestMultipleInstancesOfTheSameTypeQualified(ty, name, ...) \
+ TEST_F(SameTypeTest, MultiSame##ty##name) { \
std::vector<std::unique_ptr<Type>> types; \
for (int i = 0; i < 10; ++i) types.emplace_back(new ty(__VA_ARGS__)); \
for (size_t i = 0; i < types.size(); ++i) { \
@@ -61,6 +62,9 @@
} \
} \
}
+#define TestMultipleInstancesOfTheSameType(ty, ...) \
+ TestMultipleInstancesOfTheSameTypeQualified(ty, Simple, __VA_ARGS__)
+
TestMultipleInstancesOfTheSameType(Void);
TestMultipleInstancesOfTheSameType(Bool);
TestMultipleInstancesOfTheSameType(Integer, 32, true);
@@ -72,7 +76,23 @@
SpvAccessQualifierWriteOnly);
TestMultipleInstancesOfTheSameType(Sampler);
TestMultipleInstancesOfTheSameType(SampledImage, image_t_.get());
-TestMultipleInstancesOfTheSameType(Array, u32_t_.get(), 10, 3);
+// There are three classes of arrays, based on the kinds of length information
+// they have.
+// 1. Array length is a constant or spec constant without spec ID, with literals
+// for the constant value.
+TestMultipleInstancesOfTheSameTypeQualified(Array, LenConstant, u32_t_.get(),
+ Array::LengthInfo{42,
+ {
+ 0,
+ 9999,
+ }});
+// 2. Array length is a spec constant with a given spec id.
+TestMultipleInstancesOfTheSameTypeQualified(Array, LenSpecId, u32_t_.get(),
+ Array::LengthInfo{42, {1, 99}});
+// 3. Array length is an OpSpecConstantOp expression
+TestMultipleInstancesOfTheSameTypeQualified(Array, LenDefiningId, u32_t_.get(),
+ Array::LengthInfo{42, {2, 42}});
+
TestMultipleInstancesOfTheSameType(RuntimeArray, u32_t_.get());
TestMultipleInstancesOfTheSameType(Struct, std::vector<const Type*>{
u32_t_.get(), f64_t_.get()});
@@ -90,6 +110,7 @@
TestMultipleInstancesOfTheSameType(NamedBarrier);
TestMultipleInstancesOfTheSameType(AccelerationStructureNV);
#undef TestMultipleInstanceOfTheSameType
+#undef TestMultipleInstanceOfTheSameTypeQual
std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
// Types in this test case are only equal to themselves, nothing else.
@@ -151,10 +172,31 @@
types.emplace_back(new SampledImage(image2));
// Array
- types.emplace_back(new Array(f32, 100, 1u));
- types.emplace_back(new Array(f32, 42, 2u));
+ // Length is constant with integer bit representation of 42.
+ types.emplace_back(new Array(f32, Array::LengthInfo{99u, {0, 42u}}));
auto* a42f32 = types.back().get();
- types.emplace_back(new Array(u64, 24, s32, {42}));
+ // Differs from previous in length value only.
+ types.emplace_back(new Array(f32, Array::LengthInfo{99u, {0, 44u}}));
+ // Length is 64-bit constant integer value 42.
+ types.emplace_back(new Array(u64, Array::LengthInfo{100u, {0, 42u, 0u}}));
+ // Differs from previous in length value only.
+ types.emplace_back(new Array(u64, Array::LengthInfo{100u, {0, 44u, 0u}}));
+
+ // Length is spec constant with spec id 18 and default value 44.
+ types.emplace_back(new Array(f32, Array::LengthInfo{99u,
+ {
+ 1,
+ 18u,
+ 44u,
+ }}));
+ // Differs from previous in spec id only.
+ types.emplace_back(new Array(f32, Array::LengthInfo{99u, {1, 19u, 44u}}));
+ // Differs from previous in literal value only.
+ types.emplace_back(new Array(f32, Array::LengthInfo{99u, {1, 19u, 48u}}));
+ // Length is spec constant op with id 42.
+ types.emplace_back(new Array(f32, Array::LengthInfo{42u, {2, 42}}));
+ // Differs from previous in result id only.
+ types.emplace_back(new Array(f32, Array::LengthInfo{43u, {2, 43}}));
// RuntimeArray
types.emplace_back(new RuntimeArray(v3f32));
@@ -215,8 +257,8 @@
<< types[j]->str() << "'";
} else {
EXPECT_FALSE(types[i]->IsSame(types[j].get()))
- << "expected '" << types[i]->str() << "' is different to '"
- << types[j]->str() << "'";
+ << "entry (" << i << "," << j << ") expected '" << types[i]->str()
+ << "' is different to '" << types[j]->str() << "'";
}
}
}