Apply scalar replacement on vars with Pointer decorations (#5208)

We want to be able to apply scalar replacement on variables that have
the AliasPointer and RestrictPointer decorations.

This exposed a bug that needs to be fixed as well.

Scalar replacement sometimes uses the type manager to get the type id for the
variables it is creating. The variable type is a pointer to a pointee
type. Currently, scalar replacement uses the type manager when only if
the pointee type has to be unique in the module. This is done to try to avoid the case where two type hash to the same
value in the type manager, and it returns the wrong one.

However, this check is not the correct check. Pointer types still have to be
unique in the spir-v module. However, two unique pointer types can hash
to the same value if their pointee types are isomorphic. For example,

%s1 = OpTypeStruct %int
%s2 = OpTypeStruct %int
; %p1 and %p2 will hash to the same value even though they are still
; considered "unique".
%p1 = OpTypePointer Function %s1
%p2 = OpTypePointer Function %s2
To fix this, we now use FindPointerToType, and we modified TypeManager::IsUnique to refer to the whether or not a type will hash to a unique value and say that pointers are not unique.

Fixes #5196
diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp
index ae1a2a3..38c8aec 100644
--- a/source/opt/scalar_replacement_pass.cpp
+++ b/source/opt/scalar_replacement_pass.cpp
@@ -466,9 +466,9 @@
 }
 
 void ScalarReplacementPass::CreateVariable(
-    uint32_t typeId, Instruction* varInst, uint32_t index,
+    uint32_t type_id, Instruction* var_inst, uint32_t index,
     std::vector<Instruction*>* replacements) {
-  uint32_t ptrId = GetOrCreatePointerType(typeId);
+  uint32_t ptr_id = GetOrCreatePointerType(type_id);
   uint32_t id = TakeNextId();
 
   if (id == 0) {
@@ -476,51 +476,22 @@
   }
 
   std::unique_ptr<Instruction> variable(
-      new Instruction(context(), spv::Op::OpVariable, ptrId, id,
+      new Instruction(context(), spv::Op::OpVariable, ptr_id, id,
                       std::initializer_list<Operand>{
                           {SPV_OPERAND_TYPE_STORAGE_CLASS,
                            {uint32_t(spv::StorageClass::Function)}}}));
 
-  BasicBlock* block = context()->get_instr_block(varInst);
+  BasicBlock* block = context()->get_instr_block(var_inst);
   block->begin().InsertBefore(std::move(variable));
   Instruction* inst = &*block->begin();
 
   // If varInst was initialized, make sure to initialize its replacement.
-  GetOrCreateInitialValue(varInst, index, inst);
+  GetOrCreateInitialValue(var_inst, index, inst);
   get_def_use_mgr()->AnalyzeInstDefUse(inst);
   context()->set_instr_block(inst, block);
 
-  // Copy decorations from the member to the new variable.
-  Instruction* typeInst = GetStorageType(varInst);
-  for (auto dec_inst :
-       get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
-    uint32_t decoration;
-    if (dec_inst->opcode() != spv::Op::OpMemberDecorate) {
-      continue;
-    }
-
-    if (dec_inst->GetSingleWordInOperand(1) != index) {
-      continue;
-    }
-
-    decoration = dec_inst->GetSingleWordInOperand(2u);
-    switch (spv::Decoration(decoration)) {
-      case spv::Decoration::RelaxedPrecision: {
-        std::unique_ptr<Instruction> new_dec_inst(
-            new Instruction(context(), spv::Op::OpDecorate, 0, 0, {}));
-        new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
-        for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
-          new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
-        }
-        context()->AddAnnotationInst(std::move(new_dec_inst));
-      } break;
-      default:
-        break;
-    }
-  }
-
-  // Update the DebugInfo debug information.
-  inst->UpdateDebugInfoFrom(varInst);
+  CopyDecorationsToVariable(var_inst, inst, index);
+  inst->UpdateDebugInfoFrom(var_inst);
 
   replacements->push_back(inst);
 }
@@ -529,52 +500,11 @@
   auto iter = pointee_to_pointer_.find(id);
   if (iter != pointee_to_pointer_.end()) return iter->second;
 
-  analysis::Type* pointeeTy;
-  std::unique_ptr<analysis::Pointer> pointerTy;
-  std::tie(pointeeTy, pointerTy) =
-      context()->get_type_mgr()->GetTypeAndPointerType(
-          id, spv::StorageClass::Function);
-  uint32_t ptrId = 0;
-  if (pointeeTy->IsUniqueType()) {
-    // Non-ambiguous type, just ask the type manager for an id.
-    ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
-    pointee_to_pointer_[id] = ptrId;
-    return ptrId;
-  }
-
-  // Ambiguous type. We must perform a linear search to try and find the right
-  // type.
-  for (auto global : context()->types_values()) {
-    if (global.opcode() == spv::Op::OpTypePointer &&
-        spv::StorageClass(global.GetSingleWordInOperand(0u)) ==
-            spv::StorageClass::Function &&
-        global.GetSingleWordInOperand(1u) == id) {
-      if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
-        // Only reuse a decoration-less pointer of the correct type.
-        ptrId = global.result_id();
-        break;
-      }
-    }
-  }
-
-  if (ptrId != 0) {
-    pointee_to_pointer_[id] = ptrId;
-    return ptrId;
-  }
-
-  ptrId = TakeNextId();
-  context()->AddType(MakeUnique<Instruction>(
-      context(), spv::Op::OpTypePointer, 0, ptrId,
-      std::initializer_list<Operand>{{SPV_OPERAND_TYPE_STORAGE_CLASS,
-                                      {uint32_t(spv::StorageClass::Function)}},
-                                     {SPV_OPERAND_TYPE_ID, {id}}}));
-  Instruction* ptr = &*--context()->types_values_end();
-  get_def_use_mgr()->AnalyzeInstDefUse(ptr);
-  pointee_to_pointer_[id] = ptrId;
-  // Register with the type manager if necessary.
-  context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
-
-  return ptrId;
+  analysis::TypeManager* type_mgr = context()->get_type_mgr();
+  uint32_t ptr_type_id =
+      type_mgr->FindPointerToType(id, spv::StorageClass::Function);
+  pointee_to_pointer_[id] = ptr_type_id;
+  return ptr_type_id;
 }
 
 void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
@@ -761,6 +691,8 @@
       case spv::Decoration::AlignmentId:
       case spv::Decoration::MaxByteOffset:
       case spv::Decoration::RelaxedPrecision:
+      case spv::Decoration::AliasedPointer:
+      case spv::Decoration::RestrictPointer:
         break;
       default:
         return false;
@@ -781,6 +713,8 @@
       case spv::Decoration::Alignment:
       case spv::Decoration::AlignmentId:
       case spv::Decoration::MaxByteOffset:
+      case spv::Decoration::AliasedPointer:
+      case spv::Decoration::RestrictPointer:
         break;
       default:
         return false;
@@ -1011,5 +945,69 @@
   return 0;
 }
 
+void ScalarReplacementPass::CopyDecorationsToVariable(Instruction* from,
+                                                      Instruction* to,
+                                                      uint32_t member_index) {
+  CopyPointerDecorationsToVariable(from, to);
+  CopyNecessaryMemberDecorationsToVariable(from, to, member_index);
+}
+
+void ScalarReplacementPass::CopyPointerDecorationsToVariable(Instruction* from,
+                                                             Instruction* to) {
+  // The RestrictPointer and AliasedPointer decorations are copied to all
+  // members even if the new variable does not contain a pointer. It does
+  // not hurt to do so.
+  for (auto dec_inst :
+       get_decoration_mgr()->GetDecorationsFor(from->result_id(), false)) {
+    uint32_t decoration;
+    decoration = dec_inst->GetSingleWordInOperand(1u);
+    switch (spv::Decoration(decoration)) {
+      case spv::Decoration::AliasedPointer:
+      case spv::Decoration::RestrictPointer: {
+        std::unique_ptr<Instruction> new_dec_inst(dec_inst->Clone(context()));
+        new_dec_inst->SetInOperand(0, {to->result_id()});
+        context()->AddAnnotationInst(std::move(new_dec_inst));
+      } break;
+      default:
+        break;
+    }
+  }
+}
+
+void ScalarReplacementPass::CopyNecessaryMemberDecorationsToVariable(
+    Instruction* from, Instruction* to, uint32_t member_index) {
+  Instruction* type_inst = GetStorageType(from);
+  for (auto dec_inst :
+       get_decoration_mgr()->GetDecorationsFor(type_inst->result_id(), false)) {
+    uint32_t decoration;
+    if (dec_inst->opcode() == spv::Op::OpMemberDecorate) {
+      if (dec_inst->GetSingleWordInOperand(1) != member_index) {
+        continue;
+      }
+
+      decoration = dec_inst->GetSingleWordInOperand(2u);
+      switch (spv::Decoration(decoration)) {
+        case spv::Decoration::ArrayStride:
+        case spv::Decoration::Alignment:
+        case spv::Decoration::AlignmentId:
+        case spv::Decoration::MaxByteOffset:
+        case spv::Decoration::MaxByteOffsetId:
+        case spv::Decoration::RelaxedPrecision: {
+          std::unique_ptr<Instruction> new_dec_inst(
+              new Instruction(context(), spv::Op::OpDecorate, 0, 0, {}));
+          new_dec_inst->AddOperand(
+              Operand(SPV_OPERAND_TYPE_ID, {to->result_id()}));
+          for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
+            new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
+          }
+          context()->AddAnnotationInst(std::move(new_dec_inst));
+        } break;
+        default:
+          break;
+      }
+    }
+  }
+}
+
 }  // namespace opt
 }  // namespace spvtools
diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h
index 0bcd2a4..c73ecfd 100644
--- a/source/opt/scalar_replacement_pass.h
+++ b/source/opt/scalar_replacement_pass.h
@@ -262,9 +262,26 @@
   // that we will be willing to split.
   bool IsLargerThanSizeLimit(uint64_t length) const;
 
+  // Copies all relevant decorations from `from` to `to`. This includes
+  // decorations applied to the variable, and to the members of the type.
+  // It is assumed that `to` is a variable that is intended to replace the
+  // `member_index`th member of `from`.
+  void CopyDecorationsToVariable(Instruction* from, Instruction* to,
+                                 uint32_t member_index);
+
+  // Copies pointer related decoration from `from` to `to` if they exist.
+  void CopyPointerDecorationsToVariable(Instruction* from, Instruction* to);
+
+  // Copies decorations that are needed from the `member_index` of `from` to
+  // `to, if there was one.
+  void CopyNecessaryMemberDecorationsToVariable(Instruction* from,
+                                                Instruction* to,
+                                                uint32_t member_index);
+
   // Limit on the number of members in an object that will be replaced.
   // 0 means there is no limit.
   uint32_t max_num_elements_;
+
   // This has to be big enough to fit "scalar-replacement=" followed by a
   // uint32_t number written in decimal (so 10 digits), and then a
   // terminating nul.
diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp
index 6e4c054..1b1aead 100644
--- a/source/opt/type_manager.cpp
+++ b/source/opt/type_manager.cpp
@@ -178,7 +178,7 @@
   if (iter == id_to_type_.end()) return;
 
   auto& type = iter->second;
-  if (!type->IsUniqueType(true)) {
+  if (!type->IsUniqueType()) {
     auto tIter = type_to_id_.find(type);
     if (tIter != type_to_id_.end() && tIter->second == id) {
       // |type| currently maps to |id|.
@@ -437,7 +437,7 @@
                                         spv::StorageClass storage_class) {
   Type* pointeeTy = GetType(type_id);
   Pointer pointerTy(pointeeTy, storage_class);
-  if (pointeeTy->IsUniqueType(true)) {
+  if (pointeeTy->IsUniqueType()) {
     // Non-ambiguous type. Get the pointer type through the type manager.
     return GetTypeInstruction(&pointerTy);
   }
diff --git a/source/opt/types.cpp b/source/opt/types.cpp
index 2f18362..49eec9b 100644
--- a/source/opt/types.cpp
+++ b/source/opt/types.cpp
@@ -84,10 +84,9 @@
   return CompareTwoVectors(decorations_, that->decorations_);
 }
 
-bool Type::IsUniqueType(bool allowVariablePointers) const {
+bool Type::IsUniqueType() const {
   switch (kind_) {
     case kPointer:
-      return !allowVariablePointers;
     case kStruct:
     case kArray:
     case kRuntimeArray:
diff --git a/source/opt/types.h b/source/opt/types.h
index 1f32937..26c058c 100644
--- a/source/opt/types.h
+++ b/source/opt/types.h
@@ -148,12 +148,16 @@
   // Returns a clone of |this| minus any decorations.
   std::unique_ptr<Type> RemoveDecorations() const;
 
-  // Returns true if this type must be unique.
+  // Returns true if this cannot hash to the same value as another type in the
+  // module. For example, structs are not unique types because the module could
+  // have two types
   //
-  // If variable pointers are allowed, then pointers are not required to be
-  // unique.
-  // TODO(alanbaker): Update this if variable pointers become a core feature.
-  bool IsUniqueType(bool allowVariablePointers = false) const;
+  //  %1 = OpTypeStruct %int
+  //  %2 = OpTypeStruct %int
+  //
+  // The only way to distinguish these types is the result id. The type manager
+  // will hash them to the same value.
+  bool IsUniqueType() const;
 
   bool operator==(const Type& other) const;
 
diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp
index b63a6b6..0ba285b 100644
--- a/test/opt/scalar_replacement_test.cpp
+++ b/test/opt/scalar_replacement_test.cpp
@@ -2308,6 +2308,54 @@
   SinglePassRunAndMatch<ScalarReplacementPass>(text, true);
 }
 
+TEST_F(ScalarReplacementTest, RestrictPointer) {
+  // This test makes sure that a variable with the restrict pointer decoration
+  // is replaced, and that the pointer is applied to the new variable.
+  const std::string text = R"(
+; CHECK: OpDecorate [[new_var:%\w+]] RestrictPointer
+; CHECK: [[struct_type:%\w+]] = OpTypeStruct %int
+; CHECK: [[ptr_type:%\w+]] = OpTypePointer PhysicalStorageBuffer [[struct_type]]
+; CHECK: [[dup_struct_type:%\w+]] = OpTypeStruct %int
+; CHECK: {{%\w+}} = OpTypePointer PhysicalStorageBuffer [[dup_struct_type]]
+; CHECK: [[var_type:%\w+]] = OpTypePointer Function [[ptr_type]]
+; CHECK: [[new_var]] = OpVariable [[var_type]] Function
+               OpCapability Shader
+               OpCapability PhysicalStorageBufferAddresses
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel PhysicalStorageBuffer64 GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpMemberDecorate %3 0 Offset 0
+               OpDecorate %3 Block
+               OpMemberDecorate %4 0 Offset 0
+               OpDecorate %4 Block
+               OpDecorate %5 RestrictPointer
+          %6 = OpTypeVoid
+          %7 = OpTypeFunction %6
+          %8 = OpTypeInt 32 1
+          %9 = OpConstant %8 0
+          %3 = OpTypeStruct %8
+         %10 = OpTypePointer PhysicalStorageBuffer %3
+         %11 = OpTypeStruct %10
+          %4 = OpTypeStruct %8
+         %12 = OpTypePointer PhysicalStorageBuffer %4
+         %13 = OpTypePointer Function %11
+         %14 = OpTypePointer Function %10
+         %15 = OpTypePointer Function %12
+         %16 = OpUndef %11
+          %2 = OpFunction %6 None %7
+         %17 = OpLabel
+          %5 = OpVariable %13 Function
+               OpStore %5 %16
+         %18 = OpAccessChain %14 %5 %9
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  SetTargetEnv(SPV_ENV_UNIVERSAL_1_6);
+  SinglePassRunAndMatch<ScalarReplacementPass>(text, true);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools
diff --git a/test/opt/types_test.cpp b/test/opt/types_test.cpp
index 4352b7c..4ceeb14 100644
--- a/test/opt/types_test.cpp
+++ b/test/opt/types_test.cpp
@@ -391,18 +391,13 @@
       case Type::kArray:
       case Type::kRuntimeArray:
       case Type::kStruct:
+      case Type::kPointer:
         expectation = false;
         break;
       default:
         break;
     }
-    EXPECT_EQ(t->IsUniqueType(false), expectation)
-        << "expected '" << t->str() << "' to be a "
-        << (expectation ? "" : "non-") << "unique type";
-
-    // Allowing variables pointers.
-    if (t->AsPointer()) expectation = false;
-    EXPECT_EQ(t->IsUniqueType(true), expectation)
+    EXPECT_EQ(t->IsUniqueType(), expectation)
         << "expected '" << t->str() << "' to be a "
         << (expectation ? "" : "non-") << "unique type";
   }