Add descriptor array scalar replacement (#2742)
Creates a pass that will replace a descriptor array with individual variables. See #2740 for details.
Fixes #2740.
diff --git a/Android.mk b/Android.mk
index 8a507da..a6278af 100644
--- a/Android.mk
+++ b/Android.mk
@@ -95,6 +95,7 @@
source/opt/decompose_initialized_variables_pass.cpp \
source/opt/decoration_manager.cpp \
source/opt/def_use_manager.cpp \
+ source/opt/desc_sroa.cpp \
source/opt/dominator_analysis.cpp \
source/opt/dominator_tree.cpp \
source/opt/eliminate_dead_constant_pass.cpp \
diff --git a/BUILD.gn b/BUILD.gn
index 90d80e5..84b21e1 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -491,6 +491,8 @@
"source/opt/decoration_manager.h",
"source/opt/def_use_manager.cpp",
"source/opt/def_use_manager.h",
+ "source/opt/desc_sroa.cpp",
+ "source/opt/desc_sroa.h",
"source/opt/dominator_analysis.cpp",
"source/opt/dominator_analysis.h",
"source/opt/dominator_tree.cpp",
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index a52dcd0..d442b97 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -784,6 +784,17 @@
// wide.
Optimizer::PassToken CreateGraphicsRobustAccessPass();
+// Create descriptor scalar replacement pass.
+// This pass replaces every array variable |desc| that has a DescriptorSet and
+// Binding decorations with a new variable for each element of the array.
+// Suppose |desc| was bound at binding |b|. Then the variable corresponding to
+// |desc[i]| will have binding |b+i|. The descriptor set will be the same. It
+// is assumed that no other variable already has a binding that will used by one
+// of the new variables. If not, the pass will generate invalid Spir-V. All
+// accesses to |desc| must be OpAccessChain instructions with a literal index
+// for the first index.
+Optimizer::PassToken CreateDescriptorScalarReplacementPass();
+
} // namespace spvtools
#endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 278f794..2ebad51 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -33,6 +33,7 @@
decompose_initialized_variables_pass.h
decoration_manager.h
def_use_manager.h
+ desc_sroa.h
dominator_analysis.h
dominator_tree.h
eliminate_dead_constant_pass.h
@@ -134,6 +135,7 @@
decompose_initialized_variables_pass.cpp
decoration_manager.cpp
def_use_manager.cpp
+ desc_sroa.cpp
dominator_analysis.cpp
dominator_tree.cpp
eliminate_dead_constant_pass.cpp
diff --git a/source/opt/desc_sroa.cpp b/source/opt/desc_sroa.cpp
new file mode 100644
index 0000000..36256ff
--- /dev/null
+++ b/source/opt/desc_sroa.cpp
@@ -0,0 +1,255 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/desc_sroa.h"
+
+#include <source/util/string_utils.h>
+
+namespace spvtools {
+namespace opt {
+
+Pass::Status DescriptorScalarReplacement::Process() {
+ bool modified = false;
+
+ std::vector<Instruction*> vars_to_kill;
+
+ for (Instruction& var : context()->types_values()) {
+ if (IsCandidate(&var)) {
+ modified = true;
+ if (!ReplaceCandidate(&var)) {
+ return Status::Failure;
+ }
+ vars_to_kill.push_back(&var);
+ }
+ }
+
+ for (Instruction* var : vars_to_kill) {
+ context()->KillInst(var);
+ }
+
+ return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
+}
+
+bool DescriptorScalarReplacement::IsCandidate(Instruction* var) {
+ if (var->opcode() != SpvOpVariable) {
+ return false;
+ }
+
+ uint32_t ptr_type_id = var->type_id();
+ Instruction* ptr_type_inst =
+ context()->get_def_use_mgr()->GetDef(ptr_type_id);
+ if (ptr_type_inst->opcode() != SpvOpTypePointer) {
+ return false;
+ }
+
+ uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1);
+ Instruction* var_type_inst =
+ context()->get_def_use_mgr()->GetDef(var_type_id);
+ if (var_type_inst->opcode() != SpvOpTypeArray) {
+ return false;
+ }
+
+ bool has_desc_set_decoration = false;
+ context()->get_decoration_mgr()->ForEachDecoration(
+ var->result_id(), SpvDecorationDescriptorSet,
+ [&has_desc_set_decoration](const Instruction&) {
+ has_desc_set_decoration = true;
+ });
+ if (!has_desc_set_decoration) {
+ return false;
+ }
+
+ bool has_binding_decoration = false;
+ context()->get_decoration_mgr()->ForEachDecoration(
+ var->result_id(), SpvDecorationBinding,
+ [&has_binding_decoration](const Instruction&) {
+ has_binding_decoration = true;
+ });
+ if (!has_binding_decoration) {
+ return false;
+ }
+
+ return true;
+}
+
+bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
+ std::vector<Instruction*> work_list;
+ bool failed = !get_def_use_mgr()->WhileEachUser(
+ var->result_id(), [this, &work_list](Instruction* use) {
+ if (use->opcode() == SpvOpName) {
+ return true;
+ }
+
+ if (use->IsDecoration()) {
+ return true;
+ }
+
+ switch (use->opcode()) {
+ case SpvOpAccessChain:
+ case SpvOpInBoundsAccessChain:
+ work_list.push_back(use);
+ return true;
+ default:
+ context()->EmitErrorMessage(
+ "Variable cannot be replaced: invalid instruction", use);
+ return false;
+ }
+ return true;
+ });
+
+ if (failed) {
+ return false;
+ }
+
+ for (Instruction* use : work_list) {
+ if (!ReplaceAccessChain(var, use)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var,
+ Instruction* use) {
+ if (use->NumInOperands() <= 1) {
+ context()->EmitErrorMessage(
+ "Variable cannot be replaced: invalid instruction", use);
+ return false;
+ }
+
+ uint32_t idx_id = use->GetSingleWordInOperand(1);
+ const analysis::Constant* idx_const =
+ context()->get_constant_mgr()->FindDeclaredConstant(idx_id);
+ if (idx_const == nullptr) {
+ context()->EmitErrorMessage("Variable cannot be replaced: invalid index",
+ use);
+ return false;
+ }
+
+ uint32_t idx = idx_const->GetU32();
+ uint32_t replacement_var = GetReplacementVariable(var, idx);
+
+ if (use->NumInOperands() == 2) {
+ // We are not indexing into the replacement variable. We can replaces the
+ // access chain with the replacement varibale itself.
+ context()->ReplaceAllUsesWith(use->result_id(), replacement_var);
+ context()->KillInst(use);
+ return true;
+ }
+
+ // We need to build a new access chain with the replacement variable as the
+ // base address.
+ Instruction::OperandList new_operands;
+
+ // Same result id and result type.
+ new_operands.emplace_back(use->GetOperand(0));
+ new_operands.emplace_back(use->GetOperand(1));
+
+ // Use the replacement variable as the base address.
+ new_operands.push_back({SPV_OPERAND_TYPE_ID, {replacement_var}});
+
+ // Drop the first index because it is consumed by the replacment, and copy the
+ // rest.
+ for (uint32_t i = 4; i < use->NumOperands(); i++) {
+ new_operands.emplace_back(use->GetOperand(i));
+ }
+
+ use->ReplaceOperands(new_operands);
+ context()->UpdateDefUse(use);
+ return true;
+}
+
+uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var,
+ uint32_t idx) {
+ auto replacement_vars = replacement_variables_.find(var);
+ if (replacement_vars == replacement_variables_.end()) {
+ uint32_t ptr_type_id = var->type_id();
+ Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
+ assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
+ "Variable should be a pointer to an array.");
+ uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1);
+ Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id);
+ assert(arr_type_inst->opcode() == SpvOpTypeArray &&
+ "Variable should be a pointer to an array.");
+
+ uint32_t array_len_id = arr_type_inst->GetSingleWordInOperand(1);
+ const analysis::Constant* array_len_const =
+ context()->get_constant_mgr()->FindDeclaredConstant(array_len_id);
+ assert(array_len_const != nullptr && "Array length must be a constant.");
+ uint32_t array_len = array_len_const->GetU32();
+
+ replacement_vars = replacement_variables_
+ .insert({var, std::vector<uint32_t>(array_len, 0)})
+ .first;
+ }
+
+ if (replacement_vars->second[idx] == 0) {
+ replacement_vars->second[idx] = CreateReplacementVariable(var, idx);
+ }
+
+ return replacement_vars->second[idx];
+}
+
+uint32_t DescriptorScalarReplacement::CreateReplacementVariable(
+ Instruction* var, uint32_t idx) {
+ // The storage class for the new variable is the same as the original.
+ SpvStorageClass storage_class =
+ static_cast<SpvStorageClass>(var->GetSingleWordInOperand(0));
+
+ // The type for the new variable will be a pointer to type of the elements of
+ // the array.
+ uint32_t ptr_type_id = var->type_id();
+ Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
+ assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
+ "Variable should be a pointer to an array.");
+ uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1);
+ Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id);
+ assert(arr_type_inst->opcode() == SpvOpTypeArray &&
+ "Variable should be a pointer to an array.");
+ uint32_t element_type_id = arr_type_inst->GetSingleWordInOperand(0);
+
+ uint32_t ptr_element_type_id = context()->get_type_mgr()->FindPointerToType(
+ element_type_id, storage_class);
+
+ // Create the variable.
+ uint32_t id = TakeNextId();
+ std::unique_ptr<Instruction> variable(
+ new Instruction(context(), SpvOpVariable, ptr_element_type_id, id,
+ std::initializer_list<Operand>{
+ {SPV_OPERAND_TYPE_STORAGE_CLASS,
+ {static_cast<uint32_t>(storage_class)}}}));
+ context()->AddGlobalValue(std::move(variable));
+
+ // Copy all of the decorations to the new variable. The only difference is
+ // the Binding decoration needs to be adjusted.
+ for (auto old_decoration :
+ get_decoration_mgr()->GetDecorationsFor(var->result_id(), true)) {
+ assert(old_decoration->opcode() == SpvOpDecorate);
+ std::unique_ptr<Instruction> new_decoration(
+ old_decoration->Clone(context()));
+ new_decoration->SetInOperand(0, {id});
+
+ uint32_t decoration = new_decoration->GetSingleWordInOperand(1u);
+ if (decoration == SpvDecorationBinding) {
+ uint32_t new_binding = new_decoration->GetSingleWordInOperand(2) + idx;
+ new_decoration->SetInOperand(2, {new_binding});
+ }
+ context()->AddAnnotationInst(std::move(new_decoration));
+ }
+
+ return id;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/desc_sroa.h b/source/opt/desc_sroa.h
new file mode 100644
index 0000000..a95c6b5
--- /dev/null
+++ b/source/opt/desc_sroa.h
@@ -0,0 +1,84 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_DESC_SROA_H_
+#define SOURCE_OPT_DESC_SROA_H_
+
+#include <cstdio>
+#include <memory>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "source/opt/function.h"
+#include "source/opt/pass.h"
+#include "source/opt/type_manager.h"
+
+namespace spvtools {
+namespace opt {
+
+// Documented in optimizer.hpp
+class DescriptorScalarReplacement : public Pass {
+ public:
+ DescriptorScalarReplacement() {}
+
+ const char* name() const override { return "descriptor-scalar-replacement"; }
+
+ Status Process() override;
+
+ IRContext::Analysis GetPreservedAnalyses() override {
+ return IRContext::kAnalysisDefUse |
+ IRContext::kAnalysisInstrToBlockMapping |
+ IRContext::kAnalysisCombinators | IRContext::kAnalysisCFG |
+ IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
+ }
+
+ private:
+ // Returns true if |var| is an OpVariable instruction that represents a
+ // descriptor array. These are the variables that we want to replace.
+ bool IsCandidate(Instruction* var);
+
+ // Replaces all references to |var| by new variables, one for each element of
+ // the array |var|. The binding for the new variables corresponding to
+ // element i will be the binding of |var| plus i. Returns true if successful.
+ bool ReplaceCandidate(Instruction* var);
+
+ // Replaces the base address |var| in the OpAccessChain or
+ // OpInBoundsAccessChain instruction |use| by the variable that the access
+ // chain accesses. The first index in |use| must be an |OpConstant|. Returns
+ // |true| if successful.
+ bool ReplaceAccessChain(Instruction* var, Instruction* use);
+
+ // Returns the id of the variable that will be used to replace the |idx|th
+ // element of |var|. The variable is created if it has not already been
+ // created.
+ uint32_t GetReplacementVariable(Instruction* var, uint32_t idx);
+
+ // Returns the id of a new variable that can be used to replace the |idx|th
+ // element of |var|.
+ uint32_t CreateReplacementVariable(Instruction* var, uint32_t idx);
+
+ // A map from an OpVariable instruction to the set of variables that will be
+ // used to replace it. The entry |replacement_variables_[var][i]| is the id of
+ // a variable that will be used in the place of the the ith element of the
+ // array |var|. If the entry is |0|, then the variable has not been
+ // created yet.
+ std::map<Instruction*, std::vector<uint32_t>> replacement_variables_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // SOURCE_OPT_DESC_SROA_H_
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index b600f12..823c2b7 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -788,6 +788,42 @@
return modified;
}
+void IRContext::EmitErrorMessage(std::string message, Instruction* inst) {
+ if (!consumer()) {
+ return;
+ }
+
+ Instruction* line_inst = inst;
+ while (line_inst != nullptr) { // Stop at the beginning of the basic block.
+ if (!line_inst->dbg_line_insts().empty()) {
+ line_inst = &line_inst->dbg_line_insts().back();
+ if (line_inst->opcode() == SpvOpNoLine) {
+ line_inst = nullptr;
+ }
+ break;
+ }
+ line_inst = line_inst->PreviousNode();
+ }
+
+ uint32_t line_number = 0;
+ uint32_t col_number = 0;
+ char* source = nullptr;
+ if (line_inst != nullptr) {
+ Instruction* file_name =
+ get_def_use_mgr()->GetDef(line_inst->GetSingleWordInOperand(0));
+ source = reinterpret_cast<char*>(&file_name->GetInOperand(0).words[0]);
+
+ // Get the line number and column number.
+ line_number = line_inst->GetSingleWordInOperand(1);
+ col_number = line_inst->GetSingleWordInOperand(2);
+ }
+
+ message +=
+ "\n " + inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+ consumer()(SPV_MSG_ERROR, source, {line_number, col_number, 0},
+ message.c_str());
+}
+
// Gets the dominator analysis for function |f|.
DominatorAnalysis* IRContext::GetDominatorAnalysis(const Function* f) {
if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) {
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index 308f633..05df9c0 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -556,6 +556,10 @@
bool ProcessCallTreeFromRoots(ProcessFunction& pfn,
std::queue<uint32_t>* roots);
+ // Emmits a error message to the message consumer indicating the error
+ // described by |message| occurred in |inst|.
+ void EmitErrorMessage(std::string message, Instruction* inst);
+
private:
// Builds the def-use manager from scratch, even if it was already valid.
void BuildDefUseManager() {
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index 2dd1708..4cc5e97 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -315,6 +315,8 @@
RegisterPass(CreateCombineAccessChainsPass());
} else if (pass_name == "convert-local-access-chains") {
RegisterPass(CreateLocalAccessChainConvertPass());
+ } else if (pass_name == "descriptor-scalar-replacement") {
+ RegisterPass(CreateDescriptorScalarReplacementPass());
} else if (pass_name == "eliminate-dead-code-aggressive") {
RegisterPass(CreateAggressiveDCEPass());
} else if (pass_name == "propagate-line-info") {
@@ -886,4 +888,9 @@
MakeUnique<opt::GraphicsRobustAccessPass>());
}
+Optimizer::PassToken CreateDescriptorScalarReplacementPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::DescriptorScalarReplacement>());
+}
+
} // namespace spvtools
diff --git a/source/opt/passes.h b/source/opt/passes.h
index 5eddc22..86588f7 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -29,6 +29,7 @@
#include "source/opt/dead_insert_elim_pass.h"
#include "source/opt/dead_variable_elimination.h"
#include "source/opt/decompose_initialized_variables_pass.h"
+#include "source/opt/desc_sroa.h"
#include "source/opt/eliminate_dead_constant_pass.h"
#include "source/opt/eliminate_dead_functions_pass.h"
#include "source/opt/eliminate_dead_members_pass.h"
diff --git a/source/util/string_utils.h b/source/util/string_utils.h
index f1cd179..4282aa9 100644
--- a/source/util/string_utils.h
+++ b/source/util/string_utils.h
@@ -15,8 +15,10 @@
#ifndef SOURCE_UTIL_STRING_UTILS_H_
#define SOURCE_UTIL_STRING_UTILS_H_
+#include <assert.h>
#include <sstream>
#include <string>
+#include <vector>
#include "source/util/string_utils.h"
@@ -42,6 +44,48 @@
// string will be empty.
std::pair<std::string, std::string> SplitFlagArgs(const std::string& flag);
+// Encodes a string as a sequence of words, using the SPIR-V encoding.
+inline std::vector<uint32_t> MakeVector(std::string input) {
+ std::vector<uint32_t> result;
+ uint32_t word = 0;
+ size_t num_bytes = input.size();
+ // SPIR-V strings are null-terminated. The byte_index == num_bytes
+ // case is used to push the terminating null byte.
+ for (size_t byte_index = 0; byte_index <= num_bytes; byte_index++) {
+ const auto new_byte =
+ (byte_index < num_bytes ? uint8_t(input[byte_index]) : uint8_t(0));
+ word |= (new_byte << (8 * (byte_index % sizeof(uint32_t))));
+ if (3 == (byte_index % sizeof(uint32_t))) {
+ result.push_back(word);
+ word = 0;
+ }
+ }
+ // Emit a trailing partial word.
+ if ((num_bytes + 1) % sizeof(uint32_t)) {
+ result.push_back(word);
+ }
+ return result;
+}
+
+// Decode a string from a sequence of words, using the SPIR-V encoding.
+template <class VectorType>
+inline std::string MakeString(const VectorType& words) {
+ std::string result;
+
+ for (uint32_t word : words) {
+ for (int byte_index = 0; byte_index < 4; byte_index++) {
+ uint32_t extracted_word = (word >> (8 * byte_index)) & 0xFF;
+ char c = static_cast<char>(extracted_word);
+ if (c == 0) {
+ return result;
+ }
+ result += c;
+ }
+ }
+ assert(false && "Did not find terminating null for the string.");
+ return result;
+} // namespace utils
+
} // namespace utils
} // namespace spvtools
diff --git a/test/assembly_context_test.cpp b/test/assembly_context_test.cpp
index ee0bb24..c8aa06b 100644
--- a/test/assembly_context_test.cpp
+++ b/test/assembly_context_test.cpp
@@ -17,6 +17,7 @@
#include "gmock/gmock.h"
#include "source/instruction.h"
+#include "source/util/string_utils.h"
#include "test/unit_spirv.h"
namespace spvtools {
@@ -40,9 +41,8 @@
ASSERT_EQ(SPV_SUCCESS,
context.binaryEncodeString(GetParam().str.c_str(), &inst));
// We already trust MakeVector
- EXPECT_THAT(inst.words,
- Eq(Concatenate({GetParam().initial_contents,
- spvtest::MakeVector(GetParam().str)})));
+ EXPECT_THAT(inst.words, Eq(Concatenate({GetParam().initial_contents,
+ utils::MakeVector(GetParam().str)})));
}
// clang-format off
diff --git a/test/binary_parse_test.cpp b/test/binary_parse_test.cpp
index b966102..54664fc 100644
--- a/test/binary_parse_test.cpp
+++ b/test/binary_parse_test.cpp
@@ -21,6 +21,7 @@
#include "gmock/gmock.h"
#include "source/latest_version_opencl_std_header.h"
#include "source/table.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -39,7 +40,7 @@
using ::spvtest::Concatenate;
using ::spvtest::MakeInstruction;
-using ::spvtest::MakeVector;
+using utils::MakeVector;
using ::spvtest::ScopedContext;
using ::testing::_;
using ::testing::AnyOf;
diff --git a/test/comment_test.cpp b/test/comment_test.cpp
index f46b72a..49f8df6 100644
--- a/test/comment_test.cpp
+++ b/test/comment_test.cpp
@@ -15,6 +15,7 @@
#include <string>
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -23,7 +24,7 @@
using spvtest::Concatenate;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using spvtest::TextToBinaryTest;
using testing::Eq;
diff --git a/test/ext_inst.debuginfo_test.cpp b/test/ext_inst.debuginfo_test.cpp
index ec012e0..9090c24 100644
--- a/test/ext_inst.debuginfo_test.cpp
+++ b/test/ext_inst.debuginfo_test.cpp
@@ -17,6 +17,7 @@
#include "DebugInfo.h"
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -31,7 +32,7 @@
using spvtest::Concatenate;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using testing::Eq;
struct InstructionCase {
diff --git a/test/ext_inst.opencl_test.cpp b/test/ext_inst.opencl_test.cpp
index 7dd903e..7547d92 100644
--- a/test/ext_inst.opencl_test.cpp
+++ b/test/ext_inst.opencl_test.cpp
@@ -17,6 +17,7 @@
#include "gmock/gmock.h"
#include "source/latest_version_opencl_std_header.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -25,7 +26,7 @@
using spvtest::Concatenate;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using spvtest::TextToBinaryTest;
using testing::Eq;
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index 6131c9b..366a61f 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -34,6 +34,7 @@
decompose_initialized_variables_test.cpp
decoration_manager_test.cpp
def_use_test.cpp
+ desc_sroa_test.cpp
eliminate_dead_const_test.cpp
eliminate_dead_functions_test.cpp
eliminate_dead_member_test.cpp
diff --git a/test/opt/decoration_manager_test.cpp b/test/opt/decoration_manager_test.cpp
index 3eb3ef5..fcfbff0 100644
--- a/test/opt/decoration_manager_test.cpp
+++ b/test/opt/decoration_manager_test.cpp
@@ -22,6 +22,7 @@
#include "source/opt/decoration_manager.h"
#include "source/opt/ir_context.h"
#include "source/spirv_constant.h"
+#include "source/util/string_utils.h"
#include "test/unit_spirv.h"
namespace spvtools {
@@ -29,7 +30,7 @@
namespace analysis {
namespace {
-using spvtest::MakeVector;
+using utils::MakeVector;
class DecorationManagerTest : public ::testing::Test {
public:
diff --git a/test/opt/desc_sroa_test.cpp b/test/opt/desc_sroa_test.cpp
new file mode 100644
index 0000000..04ea0f7
--- /dev/null
+++ b/test/opt/desc_sroa_test.cpp
@@ -0,0 +1,209 @@
+// Copyright (c) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "test/opt/assembly_builder.h"
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using DescriptorScalarReplacementTest = PassTest<::testing::Test>;
+
+TEST_F(DescriptorScalarReplacementTest, ExpandTexture) {
+ const std::string text = R"(
+; CHECK: OpDecorate [[var1:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var1]] Binding 0
+; CHECK: OpDecorate [[var2:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var2]] Binding 1
+; CHECK: OpDecorate [[var3:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var3]] Binding 2
+; CHECK: OpDecorate [[var4:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var4]] Binding 3
+; CHECK: OpDecorate [[var5:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var5]] Binding 4
+; CHECK: [[image_type:%\w+]] = OpTypeImage
+; CHECK: [[ptr_type:%\w+]] = OpTypePointer UniformConstant [[image_type]]
+; CHECK: [[var1]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var2]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var3]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var4]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var5]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: OpLoad [[image_type]] [[var1]]
+; CHECK: OpLoad [[image_type]] [[var2]]
+; CHECK: OpLoad [[image_type]] [[var3]]
+; CHECK: OpLoad [[image_type]] [[var4]]
+; CHECK: OpLoad [[image_type]] [[var5]]
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource HLSL 600
+ OpDecorate %MyTextures DescriptorSet 0
+ OpDecorate %MyTextures Binding 0
+ %int = OpTypeInt 32 1
+ %int_0 = OpConstant %int 0
+ %int_1 = OpConstant %int 1
+ %int_2 = OpConstant %int 2
+ %int_3 = OpConstant %int 3
+ %int_4 = OpConstant %int 4
+ %uint = OpTypeInt 32 0
+ %uint_5 = OpConstant %uint 5
+ %float = OpTypeFloat 32
+%type_2d_image = OpTypeImage %float 2D 2 0 0 1 Unknown
+%_arr_type_2d_image_uint_5 = OpTypeArray %type_2d_image %uint_5
+%_ptr_UniformConstant__arr_type_2d_image_uint_5 = OpTypePointer UniformConstant %_arr_type_2d_image_uint_5
+ %v2float = OpTypeVector %float 2
+ %void = OpTypeVoid
+ %26 = OpTypeFunction %void
+%_ptr_UniformConstant_type_2d_image = OpTypePointer UniformConstant %type_2d_image
+ %MyTextures = OpVariable %_ptr_UniformConstant__arr_type_2d_image_uint_5 UniformConstant
+ %main = OpFunction %void None %26
+ %28 = OpLabel
+ %29 = OpUndef %v2float
+ %30 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_0
+ %31 = OpLoad %type_2d_image %30
+ %35 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_1
+ %36 = OpLoad %type_2d_image %35
+ %40 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_2
+ %41 = OpLoad %type_2d_image %40
+ %45 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_3
+ %46 = OpLoad %type_2d_image %45
+ %50 = OpAccessChain %_ptr_UniformConstant_type_2d_image %MyTextures %int_4
+ %51 = OpLoad %type_2d_image %50
+ OpReturn
+ OpFunctionEnd
+
+ )";
+
+ SinglePassRunAndMatch<DescriptorScalarReplacement>(text, true);
+}
+
+TEST_F(DescriptorScalarReplacementTest, ExpandSampler) {
+ const std::string text = R"(
+; CHECK: OpDecorate [[var1:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var1]] Binding 1
+; CHECK: OpDecorate [[var2:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var2]] Binding 2
+; CHECK: OpDecorate [[var3:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var3]] Binding 3
+; CHECK: [[sampler_type:%\w+]] = OpTypeSampler
+; CHECK: [[ptr_type:%\w+]] = OpTypePointer UniformConstant [[sampler_type]]
+; CHECK: [[var1]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var2]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: [[var3]] = OpVariable [[ptr_type]] UniformConstant
+; CHECK: OpLoad [[sampler_type]] [[var1]]
+; CHECK: OpLoad [[sampler_type]] [[var2]]
+; CHECK: OpLoad [[sampler_type]] [[var3]]
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource HLSL 600
+ OpDecorate %MySampler DescriptorSet 0
+ OpDecorate %MySampler Binding 1
+ %int = OpTypeInt 32 1
+ %int_0 = OpConstant %int 0
+ %int_1 = OpConstant %int 1
+ %int_2 = OpConstant %int 2
+ %uint = OpTypeInt 32 0
+ %uint_3 = OpConstant %uint 3
+%type_sampler = OpTypeSampler
+%_arr_type_sampler_uint_3 = OpTypeArray %type_sampler %uint_3
+%_ptr_UniformConstant__arr_type_sampler_uint_3 = OpTypePointer UniformConstant %_arr_type_sampler_uint_3
+ %void = OpTypeVoid
+ %26 = OpTypeFunction %void
+%_ptr_UniformConstant_type_sampler = OpTypePointer UniformConstant %type_sampler
+ %MySampler = OpVariable %_ptr_UniformConstant__arr_type_sampler_uint_3 UniformConstant
+ %main = OpFunction %void None %26
+ %28 = OpLabel
+ %31 = OpAccessChain %_ptr_UniformConstant_type_sampler %MySampler %int_0
+ %32 = OpLoad %type_sampler %31
+ %35 = OpAccessChain %_ptr_UniformConstant_type_sampler %MySampler %int_1
+ %36 = OpLoad %type_sampler %35
+ %40 = OpAccessChain %_ptr_UniformConstant_type_sampler %MySampler %int_2
+ %41 = OpLoad %type_sampler %40
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<DescriptorScalarReplacement>(text, true);
+}
+
+TEST_F(DescriptorScalarReplacementTest, ExpandSSBO) {
+ // Tests the expansion of an SSBO. Also check that an access chain with more
+ // than 1 index is correctly handled.
+ const std::string text = R"(
+; CHECK: OpDecorate [[var1:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var1]] Binding 0
+; CHECK: OpDecorate [[var2:%\w+]] DescriptorSet 0
+; CHECK: OpDecorate [[var2]] Binding 1
+; CHECK: OpTypeStruct
+; CHECK: [[struct_type:%\w+]] = OpTypeStruct
+; CHECK: [[ptr_type:%\w+]] = OpTypePointer Uniform [[struct_type]]
+; CHECK: [[var1]] = OpVariable [[ptr_type]] Uniform
+; CHECK: [[var2]] = OpVariable [[ptr_type]] Uniform
+; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[var1]] %uint_0 %uint_0 %uint_0
+; CHECK: OpLoad %v4float [[ac1]]
+; CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_v4float [[var2]] %uint_0 %uint_0 %uint_0
+; CHECK: OpLoad %v4float [[ac2]]
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource HLSL 600
+ OpDecorate %buffers DescriptorSet 0
+ OpDecorate %buffers Binding 0
+ OpMemberDecorate %S 0 Offset 0
+ OpDecorate %_runtimearr_S ArrayStride 16
+ OpMemberDecorate %type_StructuredBuffer_S 0 Offset 0
+ OpMemberDecorate %type_StructuredBuffer_S 0 NonWritable
+ OpDecorate %type_StructuredBuffer_S BufferBlock
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+ %uint_2 = OpConstant %uint 2
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+ %S = OpTypeStruct %v4float
+%_runtimearr_S = OpTypeRuntimeArray %S
+%type_StructuredBuffer_S = OpTypeStruct %_runtimearr_S
+%_arr_type_StructuredBuffer_S_uint_2 = OpTypeArray %type_StructuredBuffer_S %uint_2
+%_ptr_Uniform__arr_type_StructuredBuffer_S_uint_2 = OpTypePointer Uniform %_arr_type_StructuredBuffer_S_uint_2
+%_ptr_Uniform_type_StructuredBuffer_S = OpTypePointer Uniform %type_StructuredBuffer_S
+ %void = OpTypeVoid
+ %19 = OpTypeFunction %void
+%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float
+ %buffers = OpVariable %_ptr_Uniform__arr_type_StructuredBuffer_S_uint_2 Uniform
+ %main = OpFunction %void None %19
+ %21 = OpLabel
+ %22 = OpAccessChain %_ptr_Uniform_v4float %buffers %uint_0 %uint_0 %uint_0 %uint_0
+ %23 = OpLoad %v4float %22
+ %24 = OpAccessChain %_ptr_Uniform_type_StructuredBuffer_S %buffers %uint_1
+ %25 = OpAccessChain %_ptr_Uniform_v4float %24 %uint_0 %uint_0 %uint_0
+ %26 = OpLoad %v4float %25
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<DescriptorScalarReplacement>(text, true);
+}
+
+} // namespace
+} // namespace opt
+} // namespace spvtools
diff --git a/test/text_to_binary.annotation_test.cpp b/test/text_to_binary.annotation_test.cpp
index 69a4861..61bdf64 100644
--- a/test/text_to_binary.annotation_test.cpp
+++ b/test/text_to_binary.annotation_test.cpp
@@ -21,6 +21,7 @@
#include <vector>
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -29,7 +30,7 @@
using spvtest::EnumCase;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using spvtest::TextToBinaryTest;
using ::testing::Combine;
using ::testing::Eq;
diff --git a/test/text_to_binary.debug_test.cpp b/test/text_to_binary.debug_test.cpp
index f9a4645..39ba5c5 100644
--- a/test/text_to_binary.debug_test.cpp
+++ b/test/text_to_binary.debug_test.cpp
@@ -19,6 +19,7 @@
#include <vector>
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -26,7 +27,7 @@
namespace {
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using spvtest::TextToBinaryTest;
using ::testing::Eq;
diff --git a/test/text_to_binary.extension_test.cpp b/test/text_to_binary.extension_test.cpp
index 84552b5..9408e9a 100644
--- a/test/text_to_binary.extension_test.cpp
+++ b/test/text_to_binary.extension_test.cpp
@@ -22,6 +22,7 @@
#include "gmock/gmock.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/latest_version_opencl_std_header.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -30,7 +31,7 @@
using spvtest::Concatenate;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using spvtest::TextToBinaryTest;
using ::testing::Combine;
using ::testing::Eq;
diff --git a/test/text_to_binary.mode_setting_test.cpp b/test/text_to_binary.mode_setting_test.cpp
index d1b69dd..8ddf421 100644
--- a/test/text_to_binary.mode_setting_test.cpp
+++ b/test/text_to_binary.mode_setting_test.cpp
@@ -20,6 +20,7 @@
#include <vector>
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -28,7 +29,7 @@
using spvtest::EnumCase;
using spvtest::MakeInstruction;
-using spvtest::MakeVector;
+using utils::MakeVector;
using ::testing::Combine;
using ::testing::Eq;
using ::testing::TestWithParam;
diff --git a/test/unit_spirv.cpp b/test/unit_spirv.cpp
index 84ed87a..0854439 100644
--- a/test/unit_spirv.cpp
+++ b/test/unit_spirv.cpp
@@ -15,12 +15,13 @@
#include "test/unit_spirv.h"
#include "gmock/gmock.h"
+#include "source/util/string_utils.h"
#include "test/test_fixture.h"
namespace spvtools {
namespace {
-using spvtest::MakeVector;
+using utils::MakeVector;
using ::testing::Eq;
using Words = std::vector<uint32_t>;
diff --git a/test/unit_spirv.h b/test/unit_spirv.h
index 2244288..3264662 100644
--- a/test/unit_spirv.h
+++ b/test/unit_spirv.h
@@ -133,29 +133,6 @@
return result;
}
-// Encodes a string as a sequence of words, using the SPIR-V encoding.
-inline std::vector<uint32_t> MakeVector(std::string input) {
- std::vector<uint32_t> result;
- uint32_t word = 0;
- size_t num_bytes = input.size();
- // SPIR-V strings are null-terminated. The byte_index == num_bytes
- // case is used to push the terminating null byte.
- for (size_t byte_index = 0; byte_index <= num_bytes; byte_index++) {
- const auto new_byte =
- (byte_index < num_bytes ? uint8_t(input[byte_index]) : uint8_t(0));
- word |= (new_byte << (8 * (byte_index % sizeof(uint32_t))));
- if (3 == (byte_index % sizeof(uint32_t))) {
- result.push_back(word);
- word = 0;
- }
- }
- // Emit a trailing partial word.
- if ((num_bytes + 1) % sizeof(uint32_t)) {
- result.push_back(word);
- }
- return result;
-}
-
// A type for easily creating spv_text_t values, with an implicit conversion to
// spv_text.
struct AutoText {
diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp
index 8dcf933..c18b64c 100644
--- a/tools/opt/opt.cpp
+++ b/tools/opt/opt.cpp
@@ -147,6 +147,15 @@
around known issues with some Vulkan drivers for initialize
variables.)");
printf(R"(
+ --descriptor-scalar-replacement
+ Replaces every array variable |desc| that has a DescriptorSet
+ and Binding decorations with a new variable for each element of
+ the array. Suppose |desc| was bound at binding |b|. Then the
+ variable corresponding to |desc[i]| will have binding |b+i|.
+ The descriptor set will be the same. All accesses to |desc|
+ must be in OpAccessChain instructions with a literal index for
+ the first index.)");
+ printf(R"(
--eliminate-dead-branches
Convert conditional branches with constant condition to the
indicated unconditional brranch. Delete all resulting dead