GraphiteDawn: cache single texture bind group

Bug: b/260368758
Change-Id: Iafc8c340b37feb8cb184f9da9878f9190b3213bc
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/780956
Reviewed-by: Michael Ludwig <michaelludwig@google.com>
Commit-Queue: Quyen Le <lehoangquyen@chromium.org>
diff --git a/src/gpu/graphite/dawn/DawnCommandBuffer.cpp b/src/gpu/graphite/dawn/DawnCommandBuffer.cpp
index f6cb2ba..4d6c69e 100644
--- a/src/gpu/graphite/dawn/DawnCommandBuffer.cpp
+++ b/src/gpu/graphite/dawn/DawnCommandBuffer.cpp
@@ -284,6 +284,8 @@
         return false;
     }
 
+    this->trackResource(msaaLoadTexture);
+
     // Creating intermediate render pass (copy from resolve texture -> MSAA load texture)
     RenderPassDesc intermediateRenderPassDesc = {};
     intermediateRenderPassDesc.fColorAttachment.fLoadOp = LoadOp::kDiscard;
@@ -529,37 +531,49 @@
     SkASSERT(fActiveRenderPassEncoder);
     SkASSERT(fActiveGraphicsPipeline);
 
-    // TODO: optimize for single texture.
-    std::vector<wgpu::BindGroupEntry> entries(2 * command.fNumTexSamplers);
+    wgpu::BindGroup bindGroup;
+    if (command.fNumTexSamplers == 1) {
+        // Optimize for single texture.
+        SkASSERT(fActiveGraphicsPipeline->numTexturesAndSamplers() == 2);
 
-    for (int i = 0; i < command.fNumTexSamplers; ++i) {
         const auto* texture =
-                static_cast<const DawnTexture*>(drawPass.getTexture(command.fTextureIndices[i]));
+                static_cast<const DawnTexture*>(drawPass.getTexture(command.fTextureIndices[0]));
         const auto* sampler =
-                static_cast<const DawnSampler*>(drawPass.getSampler(command.fSamplerIndices[i]));
-        auto& wgpuTextureView = texture->sampleTextureView();
-        auto& wgpuSampler = sampler->dawnSampler();
+                static_cast<const DawnSampler*>(drawPass.getSampler(command.fSamplerIndices[0]));
 
-        // Assuming shader generator assigns binding slot to sampler then texture,
-        // then the next sampler and texture, and so on, we need to use
-        // 2 * i as base binding index of the sampler and texture.
-        // TODO: https://b.corp.google.com/issues/259457090:
-        // Better configurable way of assigning samplers and textures' bindings.
-        entries[2 * i].binding = 2 * i;
-        entries[2 * i].sampler = wgpuSampler;
+        bindGroup = fResourceProvider->findOrCreateSingleTextureSamplerBindGroup(sampler, texture);
+    } else {
+        std::vector<wgpu::BindGroupEntry> entries(2 * command.fNumTexSamplers);
 
-        entries[2 * i + 1].binding = 2 * i + 1;
-        entries[2 * i + 1].textureView = wgpuTextureView;
+        for (int i = 0; i < command.fNumTexSamplers; ++i) {
+            const auto* texture = static_cast<const DawnTexture*>(
+                    drawPass.getTexture(command.fTextureIndices[i]));
+            const auto* sampler = static_cast<const DawnSampler*>(
+                    drawPass.getSampler(command.fSamplerIndices[i]));
+            auto& wgpuTextureView = texture->sampleTextureView();
+            auto& wgpuSampler = sampler->dawnSampler();
+
+            // Assuming shader generator assigns binding slot to sampler then texture,
+            // then the next sampler and texture, and so on, we need to use
+            // 2 * i as base binding index of the sampler and texture.
+            // TODO: https://b.corp.google.com/issues/259457090:
+            // Better configurable way of assigning samplers and textures' bindings.
+            entries[2 * i].binding = 2 * i;
+            entries[2 * i].sampler = wgpuSampler;
+
+            entries[2 * i + 1].binding = 2 * i + 1;
+            entries[2 * i + 1].textureView = wgpuTextureView;
+        }
+
+        wgpu::BindGroupDescriptor desc;
+        const auto& groupLayouts = fActiveGraphicsPipeline->dawnGroupLayouts();
+        desc.layout = groupLayouts[DawnGraphicsPipeline::kTextureBindGroupIndex];
+        desc.entryCount = entries.size();
+        desc.entries = entries.data();
+
+        bindGroup = fSharedContext->device().CreateBindGroup(&desc);
     }
 
-    wgpu::BindGroupDescriptor desc;
-    const auto& groupLayouts = fActiveGraphicsPipeline->dawnGroupLayouts();
-    desc.layout = groupLayouts[DawnGraphicsPipeline::kTextureBindGroupIndex];
-    desc.entryCount = entries.size();
-    desc.entries = entries.data();
-
-    auto bindGroup = fSharedContext->device().CreateBindGroup(&desc);
-
     fActiveRenderPassEncoder.SetBindGroup(DawnGraphicsPipeline::kTextureBindGroupIndex, bindGroup);
 }
 
diff --git a/src/gpu/graphite/dawn/DawnGraphicsPipeline.cpp b/src/gpu/graphite/dawn/DawnGraphicsPipeline.cpp
index 9bf00b4..d2f219f 100644
--- a/src/gpu/graphite/dawn/DawnGraphicsPipeline.cpp
+++ b/src/gpu/graphite/dawn/DawnGraphicsPipeline.cpp
@@ -386,27 +386,33 @@
 
         bool hasFragmentSamplers = hasFragmentSkSL && numTexturesAndSamplers > 0;
         if (hasFragmentSamplers) {
-            std::vector<wgpu::BindGroupLayoutEntry> entries(numTexturesAndSamplers);
-            for (int i = 0; i < numTexturesAndSamplers;) {
-                entries[i].binding = static_cast<uint32_t>(i);
-                entries[i].visibility = wgpu::ShaderStage::Fragment;
-                entries[i].sampler.type = wgpu::SamplerBindingType::Filtering;
-                ++i;
-                entries[i].binding = i;
-                entries[i].visibility = wgpu::ShaderStage::Fragment;
-                entries[i].texture.sampleType = wgpu::TextureSampleType::Float;
-                entries[i].texture.viewDimension = wgpu::TextureViewDimension::e2D;
-                entries[i].texture.multisampled = false;
-                ++i;
-            }
+            if (numTexturesAndSamplers == 2) {
+                // Common case: single texture + sampler.
+                groupLayouts[1] =
+                        resourceProvider->getOrCreateSingleTextureSamplerBindGroupLayout();
+            } else {
+                std::vector<wgpu::BindGroupLayoutEntry> entries(numTexturesAndSamplers);
+                for (int i = 0; i < numTexturesAndSamplers;) {
+                    entries[i].binding = static_cast<uint32_t>(i);
+                    entries[i].visibility = wgpu::ShaderStage::Fragment;
+                    entries[i].sampler.type = wgpu::SamplerBindingType::Filtering;
+                    ++i;
+                    entries[i].binding = i;
+                    entries[i].visibility = wgpu::ShaderStage::Fragment;
+                    entries[i].texture.sampleType = wgpu::TextureSampleType::Float;
+                    entries[i].texture.viewDimension = wgpu::TextureViewDimension::e2D;
+                    entries[i].texture.multisampled = false;
+                    ++i;
+                }
 
-            wgpu::BindGroupLayoutDescriptor groupLayoutDesc;
+                wgpu::BindGroupLayoutDescriptor groupLayoutDesc;
 #if defined(SK_DEBUG)
-            groupLayoutDesc.label = step->name();
+                groupLayoutDesc.label = step->name();
 #endif
-            groupLayoutDesc.entryCount = entries.size();
-            groupLayoutDesc.entries = entries.data();
-            groupLayouts[1] = device.CreateBindGroupLayout(&groupLayoutDesc);
+                groupLayoutDesc.entryCount = entries.size();
+                groupLayoutDesc.entries = entries.data();
+                groupLayouts[1] = device.CreateBindGroupLayout(&groupLayoutDesc);
+            }
             if (!groupLayouts[1]) {
                 return {};
             }
@@ -554,7 +560,8 @@
                                      step->primitiveType(),
                                      depthStencilSettings.fStencilReferenceValue,
                                      /*hasStepUniforms=*/!step->uniforms().empty(),
-                                     /*hasPaintUniforms=*/fsSkSLInfo.fNumPaintUniforms > 0));
+                                     /*hasPaintUniforms=*/fsSkSLInfo.fNumPaintUniforms > 0,
+                                     numTexturesAndSamplers));
 }
 
 DawnGraphicsPipeline::DawnGraphicsPipeline(const skgpu::graphite::SharedContext* sharedContext,
@@ -564,14 +571,16 @@
                                            PrimitiveType primitiveType,
                                            uint32_t refValue,
                                            bool hasStepUniforms,
-                                           bool hasPaintUniforms)
+                                           bool hasPaintUniforms,
+                                           int numFragmentTexturesAndSamplers)
         : GraphicsPipeline(sharedContext, pipelineInfo)
         , fAsyncPipelineCreation(std::move(asyncCreationInfo))
         , fGroupLayouts(std::move(groupLayouts))
         , fPrimitiveType(primitiveType)
         , fStencilReferenceValue(refValue)
         , fHasStepUniforms(hasStepUniforms)
-        , fHasPaintUniforms(hasPaintUniforms) {}
+        , fHasPaintUniforms(hasPaintUniforms)
+        , fNumFragmentTexturesAndSamplers(numFragmentTexturesAndSamplers) {}
 
 void DawnGraphicsPipeline::freeGpuData() {
     fAsyncPipelineCreation = nullptr;
diff --git a/src/gpu/graphite/dawn/DawnGraphicsPipeline.h b/src/gpu/graphite/dawn/DawnGraphicsPipeline.h
index b7fb93c..1bdc1d4 100644
--- a/src/gpu/graphite/dawn/DawnGraphicsPipeline.h
+++ b/src/gpu/graphite/dawn/DawnGraphicsPipeline.h
@@ -65,6 +65,7 @@
     PrimitiveType primitiveType() const { return fPrimitiveType; }
     bool hasStepUniforms() const { return fHasStepUniforms; }
     bool hasPaintUniforms() const { return fHasPaintUniforms; }
+    int numTexturesAndSamplers() const { return fNumFragmentTexturesAndSamplers; }
     const wgpu::RenderPipeline& dawnRenderPipeline() const;
 
     using BindGroupLayouts = std::array<wgpu::BindGroupLayout, kBindGroupCount>;
@@ -80,7 +81,8 @@
                          PrimitiveType primitiveType,
                          uint32_t refValue,
                          bool hasStepUniforms,
-                         bool hasPaintUniforms);
+                         bool hasPaintUniforms,
+                         int numFragmentTexturesAndSamplers);
 
     void freeGpuData() override;
 
@@ -90,6 +92,7 @@
     const uint32_t fStencilReferenceValue;
     const bool fHasStepUniforms;
     const bool fHasPaintUniforms;
+    const int fNumFragmentTexturesAndSamplers;
 };
 
 } // namespace skgpu::graphite
diff --git a/src/gpu/graphite/dawn/DawnResourceProvider.cpp b/src/gpu/graphite/dawn/DawnResourceProvider.cpp
index d410654..1cc1d75 100644
--- a/src/gpu/graphite/dawn/DawnResourceProvider.cpp
+++ b/src/gpu/graphite/dawn/DawnResourceProvider.cpp
@@ -24,7 +24,8 @@
 namespace {
 
 constexpr int kBufferBindingSizeAlignment = 16;
-constexpr int kMaxNumberOfCachedBindGroups = 32;
+constexpr int kMaxNumberOfCachedBufferBindGroups = 32;
+constexpr int kMaxNumberOfCachedTextureBindGroups = 256;
 
 wgpu::ShaderModule create_shader_module(const wgpu::Device& device, const char* source) {
     wgpu::ShaderModuleWGSLDescriptor wgslDesc;
@@ -124,6 +125,25 @@
 
     return uniqueKey;
 }
+
+// Builds the cache key identifying a (sampler, texture) pair for single texture + sampler bind
+// groups: entry 0 is the sampler's unique ID and entry 1 is the texture's unique ID.
+UniqueKey make_texture_bind_group_key(const DawnSampler* sampler, const DawnTexture* texture) {
+    // Function-local static: the domain is generated once and reused for every key built here.
+    static const UniqueKey::Domain kDomain = UniqueKey::GenerateDomain();
+    constexpr int kNumKeyEntries = 2;
+
+    UniqueKey key;
+    {
+        // Scope the builder so that it is destroyed before the key is handed back.
+        UniqueKey::Builder keyBuilder(
+                &key, kDomain, kNumKeyEntries, "GraphicsPipelineSingleTextureSamplerBindGroup");
+        keyBuilder[0] = sampler->uniqueID().asUInt();
+        keyBuilder[1] = texture->uniqueID().asUInt();
+        keyBuilder.finish();
+    }
+    return key;
+}
 }  // namespace
 
 DawnResourceProvider::DawnResourceProvider(SharedContext* sharedContext,
@@ -131,7 +151,8 @@
                                            uint32_t recorderID,
                                            size_t resourceBudget)
         : ResourceProvider(sharedContext, singleOwner, recorderID, resourceBudget)
-        , fUniformBufferBindGroupCache(kMaxNumberOfCachedBindGroups) {}
+        , fUniformBufferBindGroupCache(kMaxNumberOfCachedBufferBindGroups)
+        , fSingleTextureSamplerBindGroups(kMaxNumberOfCachedTextureBindGroups) {}
 
 DawnResourceProvider::~DawnResourceProvider() = default;
 
@@ -337,6 +358,37 @@
     return fUniformBuffersBindGroupLayout;
 }
 
+// Returns the bind group layout used by pipelines that bind a single texture + sampler pair:
+// binding 0 is a filtering sampler and binding 1 is a non-multisampled 2D float texture, both
+// visible to the fragment stage. The layout is created on the first call and cached afterwards.
+const wgpu::BindGroupLayout&
+DawnResourceProvider::getOrCreateSingleTextureSamplerBindGroupLayout() {
+    if (!fSingleTextureSamplerBindGroupLayout) {
+        std::array<wgpu::BindGroupLayoutEntry, 2> layoutEntries;
+        wgpu::BindGroupLayoutEntry& samplerEntry = layoutEntries[0];
+        samplerEntry.binding = 0;
+        samplerEntry.visibility = wgpu::ShaderStage::Fragment;
+        samplerEntry.sampler.type = wgpu::SamplerBindingType::Filtering;
+
+        wgpu::BindGroupLayoutEntry& textureEntry = layoutEntries[1];
+        textureEntry.binding = 1;
+        textureEntry.visibility = wgpu::ShaderStage::Fragment;
+        textureEntry.texture.sampleType = wgpu::TextureSampleType::Float;
+        textureEntry.texture.viewDimension = wgpu::TextureViewDimension::e2D;
+        textureEntry.texture.multisampled = false;
+
+        wgpu::BindGroupLayoutDescriptor layoutDesc;
+#if defined(SK_DEBUG)
+        layoutDesc.label = "Single texture + sampler bind group layout";
+#endif
+        layoutDesc.entryCount = layoutEntries.size();
+        layoutDesc.entries = layoutEntries.data();
+        fSingleTextureSamplerBindGroupLayout =
+                this->dawnSharedContext()->device().CreateBindGroupLayout(&layoutDesc);
+    }
+    return fSingleTextureSamplerBindGroupLayout;
+}
+
 const wgpu::Buffer& DawnResourceProvider::getOrCreateNullBuffer() {
     if (!fNullBuffer) {
         wgpu::BufferDescriptor desc;
@@ -399,4 +451,31 @@
     return *fUniformBufferBindGroupCache.insert(key, bindGroup);
 }
 
+// Returns the cached bind group for this (sampler, texture) pair, creating and caching it on a
+// miss. Binding 0 is the sampler and binding 1 is the texture's sample view. The returned
+// reference points into the LRU cache; NOTE(review): it presumably dangles once the entry is
+// evicted, so callers should copy the bind group instead of keeping the reference.
+const wgpu::BindGroup& DawnResourceProvider::findOrCreateSingleTextureSamplerBindGroup(
+        const DawnSampler* sampler, const DawnTexture* texture) {
+    const UniqueKey key = make_texture_bind_group_key(sampler, texture);
+    if (const wgpu::BindGroup* cachedBindGroup = fSingleTextureSamplerBindGroups.find(key)) {
+        return *cachedBindGroup;
+    }
+
+    // Cache miss: build the bind group against the shared single texture + sampler layout.
+    std::array<wgpu::BindGroupEntry, 2> entries;
+    entries[0].binding = 0;
+    entries[0].sampler = sampler->dawnSampler();
+    entries[1].binding = 1;
+    entries[1].textureView = texture->sampleTextureView();
+
+    wgpu::BindGroupDescriptor desc;
+    desc.layout = this->getOrCreateSingleTextureSamplerBindGroupLayout();
+    desc.entryCount = entries.size();
+    desc.entries = entries.data();
+
+    const auto& device = this->dawnSharedContext()->device();
+    return *fSingleTextureSamplerBindGroups.insert(key, device.CreateBindGroup(&desc));
+}
+
 } // namespace skgpu::graphite
diff --git a/src/gpu/graphite/dawn/DawnResourceProvider.h b/src/gpu/graphite/dawn/DawnResourceProvider.h
index b2be5ce..3a86a5e 100644
--- a/src/gpu/graphite/dawn/DawnResourceProvider.h
+++ b/src/gpu/graphite/dawn/DawnResourceProvider.h
@@ -15,6 +15,7 @@
 namespace skgpu::graphite {
 
 class DawnGraphicsPipeline;
+class DawnSampler;
 class DawnSharedContext;
 class DawnTexture;
 class DawnBuffer;
@@ -37,6 +38,7 @@
     sk_sp<DawnBuffer> findOrCreateDawnBuffer(size_t size, BufferType type, AccessPattern);
 
     const wgpu::BindGroupLayout& getOrCreateUniformBuffersBindGroupLayout();
+    const wgpu::BindGroupLayout& getOrCreateSingleTextureSamplerBindGroupLayout();
 
     // Find the cached bind group or create a new one based on the bound buffers and their
     // binding sizes (boundBuffersAndSizes) for these uniforms (in order):
@@ -46,6 +48,10 @@
     const wgpu::BindGroup& findOrCreateUniformBuffersBindGroup(
             const std::array<std::pair<const DawnBuffer*, uint32_t>, 3>& boundBuffersAndSizes);
 
+    // Find or create a bind group containing the given sampler & texture.
+    const wgpu::BindGroup& findOrCreateSingleTextureSamplerBindGroup(const DawnSampler* sampler,
+                                                                     const DawnTexture* texture);
+
 private:
     sk_sp<GraphicsPipeline> createGraphicsPipeline(const RuntimeEffectDictionary*,
                                                    const GraphicsPipelineDesc&,
@@ -69,6 +75,7 @@
     skia_private::THashMap<uint64_t, wgpu::RenderPipeline> fBlitWithDrawPipelines;
 
     wgpu::BindGroupLayout fUniformBuffersBindGroupLayout;
+    wgpu::BindGroupLayout fSingleTextureSamplerBindGroupLayout;
 
     wgpu::Buffer fNullBuffer;
 
@@ -76,8 +83,10 @@
         uint32_t operator()(const skgpu::UniqueKey& key) const { return key.hash(); }
     };
 
-    using BufferBindGroupCache = SkLRUCache<UniqueKey, wgpu::BindGroup, UniqueKeyHash>;
-    BufferBindGroupCache fUniformBufferBindGroupCache;
+    using BindGroupCache = SkLRUCache<UniqueKey, wgpu::BindGroup, UniqueKeyHash>;
+
+    BindGroupCache fUniformBufferBindGroupCache;
+    BindGroupCache fSingleTextureSamplerBindGroups;
 };
 
 } // namespace skgpu::graphite
diff --git a/src/gpu/graphite/dawn/DawnTexture.cpp b/src/gpu/graphite/dawn/DawnTexture.cpp
index ed87408..12f0fc3 100644
--- a/src/gpu/graphite/dawn/DawnTexture.cpp
+++ b/src/gpu/graphite/dawn/DawnTexture.cpp
@@ -184,6 +184,12 @@
 }
 
 void DawnTexture::freeGpuData() {
+    if (this->ownership() != Ownership::kWrapped && fTexture) {
+        // Destroy the texture even if it is still referenced by other BindGroups or views.
+        // Graphite should already guarantee that all command buffers using this texture (indirectly
+        // via BindGroups or views) have already completed.
+        fTexture.Destroy();
+    }
     fTexture = nullptr;
     fSampleTextureView = nullptr;
     fRenderTextureView = nullptr;