Merge pull request #236 from cdavis5e/push-descriptor

Support the VK_KHR_push_descriptor extension.
diff --git a/MoltenVK/MoltenVK/Commands/MVKCmdPipeline.h b/MoltenVK/MoltenVK/Commands/MVKCmdPipeline.h
index dfbd740..f9fd523 100644
--- a/MoltenVK/MoltenVK/Commands/MVKCmdPipeline.h
+++ b/MoltenVK/MoltenVK/Commands/MVKCmdPipeline.h
@@ -132,6 +132,35 @@
 
 
 #pragma mark -
+#pragma mark MVKCmdPushDescriptorSet
+
+/** Vulkan command that records a vkCmdPushDescriptorSetKHR() update and replays it when encoded. */
+class MVKCmdPushDescriptorSet : public MVKCommand {
+	// setContent() captures the push arguments; the writes and the descriptor info arrays they reference are copied.
+public:
+	void setContent(VkPipelineBindPoint pipelineBindPoint,
+					VkPipelineLayout layout,
+					uint32_t set,
+					uint32_t descriptorWriteCount,
+					const VkWriteDescriptorSet* pDescriptorWrites);
+	/** Pushes the recorded descriptor writes through the recorded pipeline layout. */
+	void encode(MVKCommandEncoder* cmdEncoder) override;
+	/** Constructs an instance for the specified command pool. */
+	MVKCmdPushDescriptorSet(MVKCommandTypePool<MVKCmdPushDescriptorSet>* pool);
+	/** Frees the descriptor info arrays copied by setContent(). */
+	~MVKCmdPushDescriptorSet() override;
+
+private:
+	void clearDescriptorWrites();	// Frees the copied info arrays and empties _descriptorWrites.
+	// Push arguments captured by setContent(); non-null info pointers in _descriptorWrites refer to arrays owned here.
+	VkPipelineBindPoint _pipelineBindPoint;
+	MVKPipelineLayout* _pipelineLayout;
+	std::vector<VkWriteDescriptorSet> _descriptorWrites;
+	uint32_t _set;
+};
+
+
+#pragma mark -
 #pragma mark Command creation functions
 
 /** Adds commands to the specified command buffer that insert the specified pipeline barriers. */
@@ -168,3 +197,11 @@
 						 uint32_t offset,
 						 uint32_t size,
 						 const void* pValues);
+
+/** Adds a command to the specified command buffer that pushes the specified descriptor writes into the set at index set of the pipeline layout. */
+void mvkCmdPushDescriptorSet(MVKCommandBuffer* cmdBuff,
+							 VkPipelineBindPoint pipelineBindPoint,
+							 VkPipelineLayout layout,
+							 uint32_t set,
+							 uint32_t descriptorWriteCount,
+							 const VkWriteDescriptorSet* pDescriptorWrites);
diff --git a/MoltenVK/MoltenVK/Commands/MVKCmdPipeline.mm b/MoltenVK/MoltenVK/Commands/MVKCmdPipeline.mm
index 8383d65..d21af9f 100644
--- a/MoltenVK/MoltenVK/Commands/MVKCmdPipeline.mm
+++ b/MoltenVK/MoltenVK/Commands/MVKCmdPipeline.mm
@@ -180,6 +180,64 @@
 
 
 #pragma mark -
+#pragma mark MVKCmdPushDescriptorSet
+
+void MVKCmdPushDescriptorSet::setContent(VkPipelineBindPoint pipelineBindPoint,
+                                         VkPipelineLayout layout,
+                                         uint32_t set,
+                                         uint32_t descriptorWriteCount,
+                                         const VkWriteDescriptorSet* pDescriptorWrites) {
+	_pipelineBindPoint = pipelineBindPoint;
+	_pipelineLayout = (MVKPipelineLayout*)layout;
+	_set = set;
+
+	// Add the descriptor writes. Vulkan reads only the info array matching each write's descriptorType; the
+	// other pointers are ignored and may be garbage, so null them: clearDescriptorWrites() deletes only our copies.
+	clearDescriptorWrites();	// Clear for reuse
+	_descriptorWrites.reserve(descriptorWriteCount);
+	for (uint32_t dwIdx = 0; dwIdx < descriptorWriteCount; dwIdx++) {
+		VkWriteDescriptorSet descWrite = pDescriptorWrites[dwIdx];	// Appended only once it owns its arrays.
+		const VkDescriptorType dt = descWrite.descriptorType;
+		const uint32_t cnt = descWrite.descriptorCount;
+		const bool usesImageInfo = (dt == VK_DESCRIPTOR_TYPE_SAMPLER || dt == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER ||
+									dt == VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE || dt == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE ||
+									dt == VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT);
+		const bool usesBufferInfo = (dt == VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER || dt == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER ||
+									 dt == VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC || dt == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC);
+		const bool usesTexelBufferView = (dt == VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER || dt == VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER);
+		// Make a copy of the associated data (std::copy_n returns the end of the new array, so step back to its start).
+		descWrite.pImageInfo = (usesImageInfo && descWrite.pImageInfo)
+			? std::copy_n(descWrite.pImageInfo, cnt, new VkDescriptorImageInfo[cnt]) - cnt : nullptr;
+		descWrite.pBufferInfo = (usesBufferInfo && descWrite.pBufferInfo)
+			? std::copy_n(descWrite.pBufferInfo, cnt, new VkDescriptorBufferInfo[cnt]) - cnt : nullptr;
+		descWrite.pTexelBufferView = (usesTexelBufferView && descWrite.pTexelBufferView)
+			? std::copy_n(descWrite.pTexelBufferView, cnt, new VkBufferView[cnt]) - cnt : nullptr;
+		_descriptorWrites.push_back(descWrite);
+	}
+}
+
+void MVKCmdPushDescriptorSet::encode(MVKCommandEncoder* cmdEncoder) {
+	this->_pipelineLayout->pushDescriptorSet(cmdEncoder, this->_descriptorWrites, this->_set);	// Replay the recorded push.
+}
+
+MVKCmdPushDescriptorSet::MVKCmdPushDescriptorSet(MVKCommandTypePool<MVKCmdPushDescriptorSet>* pool)
+	: MVKCommand((MVKCommandTypePool<MVKCommand>*)pool) {}	// Hand the pool to MVKCommand, cast to its generic pool type.
+
+MVKCmdPushDescriptorSet::~MVKCmdPushDescriptorSet() {
+	this->clearDescriptorWrites();	// Release the info arrays this command copied.
+}
+
+void MVKCmdPushDescriptorSet::clearDescriptorWrites() {	// Frees each write's copied info arrays, then empties the list.
+	for (const VkWriteDescriptorSet& dw : _descriptorWrites) {
+		delete[] dw.pImageInfo;	// delete[] of nullptr is a no-op, so no null checks are needed.
+		delete[] dw.pBufferInfo;
+		delete[] dw.pTexelBufferView;
+	}
+	_descriptorWrites.clear();
+}
+
+
+#pragma mark -
 #pragma mark Command creation functions
 
 void mvkCmdPipelineBarrier(MVKCommandBuffer* cmdBuff,
@@ -232,3 +290,13 @@
 	cmdBuff->addCommand(cmd);
 }
 
+void mvkCmdPushDescriptorSet(MVKCommandBuffer* cmdBuff,
+							 VkPipelineBindPoint pipelineBindPoint,
+							 VkPipelineLayout layout,
+							 uint32_t set,
+							 uint32_t descriptorWriteCount,
+							 const VkWriteDescriptorSet* pDescriptorWrites) {
+	auto* pushCmd = cmdBuff->_commandPool->_cmdPushDescriptorSetPool.acquireObject();	// From the buffer's command pool.
+	pushCmd->setContent(pipelineBindPoint, layout, set, descriptorWriteCount, pDescriptorWrites);
+	cmdBuff->addCommand(pushCmd);
+}
diff --git a/MoltenVK/MoltenVK/Commands/MVKCommandPool.h b/MoltenVK/MoltenVK/Commands/MVKCommandPool.h
index 2037bee..52db265 100644
--- a/MoltenVK/MoltenVK/Commands/MVKCommandPool.h
+++ b/MoltenVK/MoltenVK/Commands/MVKCommandPool.h
@@ -131,6 +131,8 @@
 
     MVKCommandTypePool<MVKCmdDispatchIndirect> _cmdDispatchIndirectPool;
 
+    MVKCommandTypePool<MVKCmdPushDescriptorSet> _cmdPushDescriptorSetPool;	// Supplies the commands recorded by mvkCmdPushDescriptorSet().
+
 
 #pragma mark Command resources
 
diff --git a/MoltenVK/MoltenVK/Commands/MVKCommandPool.mm b/MoltenVK/MoltenVK/Commands/MVKCommandPool.mm
index bf21b89..e12ba78 100644
--- a/MoltenVK/MoltenVK/Commands/MVKCommandPool.mm
+++ b/MoltenVK/MoltenVK/Commands/MVKCommandPool.mm
@@ -118,7 +118,8 @@
     _cmdCopyQueryPoolResultsPool(this, true),
 	_cmdPushConstantsPool(this, true),
     _cmdDispatchPool(this, true),
-    _cmdDispatchIndirectPool(this, true)
+    _cmdDispatchIndirectPool(this, true),
+    _cmdPushDescriptorSetPool(this, true)
 {}
 
 // TODO: Destroying a command pool implicitly destroys all command buffers and commands created from it.
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKDescriptorSet.h b/MoltenVK/MoltenVK/GPUObjects/MVKDescriptorSet.h
index d03ca57..4e8a4cc 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKDescriptorSet.h
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKDescriptorSet.h
@@ -78,6 +78,16 @@
               std::vector<uint32_t>& dynamicOffsets,
               uint32_t* pDynamicOffsetIndex);
 
+    /** Encodes, immediately on the command encoder, the portion of a pushed descriptor write that falls in this binding. */
+    void push(MVKCommandEncoder* cmdEncoder,
+              uint32_t& dstArrayElement,        // in/out: rebased relative to the next binding on return
+              uint32_t& descriptorCount,        // in/out: descriptors of the write still to be pushed
+              VkDescriptorType descriptorType,
+              const VkDescriptorImageInfo*& pImageInfo,       // in/out: the info cursors are advanced
+              const VkDescriptorBufferInfo*& pBufferInfo,     // past descriptors this binding consumes
+              const VkBufferView*& pTexelBufferView,
+              MVKShaderResourceBinding& dslMTLRezIdxOffsets);  // Metal resource index offsets of the owning layout
+
 	/** Populates the specified shader converter context, at the specified descriptor set binding. */
 	void populateShaderConverterContext(SPIRVToMSLConverterContext& context,
                                         MVKShaderResourceBinding& dslMTLRezIdxOffsets,
@@ -119,11 +129,20 @@
                            uint32_t* pDynamicOffsetIndex);
 
 
+	/** Encodes the specified push-descriptor writes on the command encoder immediately; does nothing unless isPushDescriptorLayout(). */
+	void pushDescriptorSet(MVKCommandEncoder* cmdEncoder,
+						   std::vector<VkWriteDescriptorSet>& descriptorWrites,
+						   MVKShaderResourceBinding& dslMTLRezIdxOffsets);
+
+
 	/** Populates the specified shader converter context, at the specified DSL index. */
 	void populateShaderConverterContext(SPIRVToMSLConverterContext& context,
                                         MVKShaderResourceBinding& dslMTLRezIdxOffsets,
                                         uint32_t dslIndex);
 
+	/** Returns true if this layout was created with VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR, i.e. it is for push descriptors only. */
+	bool isPushDescriptorLayout() const { return _isPushDescriptorLayout; }
+
 	/** Constructs an instance for the specified device. */
 	MVKDescriptorSetLayout(MVKDevice* device, const VkDescriptorSetLayoutCreateInfo* pCreateInfo);
 
@@ -135,6 +154,7 @@
 
 	std::vector<MVKDescriptorSetLayoutBinding> _bindings;
 	MVKShaderResourceBinding _mtlResourceCounts;
+	bool _isPushDescriptorLayout : 1;	// Set at construction from VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR.
 };
 
 
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKDescriptorSet.mm b/MoltenVK/MoltenVK/GPUObjects/MVKDescriptorSet.mm
index 9b514f7..d85e29c 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKDescriptorSet.mm
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKDescriptorSet.mm
@@ -173,6 +173,172 @@
     }
 }
 
+// Encodes, immediately on cmdEncoder, the part of one descriptor write that falls within this binding.
+// dstArrayElement, descriptorCount and the three info pointers are in/out cursors over that write: on
+// return they describe what is left for the next binding (dstArrayElement rebased onto it, descriptorCount
+// reduced by the elements consumed here, and each non-null info pointer advanced past them).
+// dslMTLRezIdxOffsets holds the Metal resource index offsets of the owning descriptor set layout.
+void MVKDescriptorSetLayoutBinding::push(MVKCommandEncoder* cmdEncoder,
+                                         uint32_t& dstArrayElement,
+                                         uint32_t& descriptorCount,
+                                         VkDescriptorType descriptorType,
+                                         const VkDescriptorImageInfo*& pImageInfo,
+                                         const VkDescriptorBufferInfo*& pBufferInfo,
+                                         const VkBufferView*& pTexelBufferView,
+                                         MVKShaderResourceBinding& dslMTLRezIdxOffsets) {
+    MVKMTLBufferBinding bb;
+    MVKMTLTextureBinding tb;
+    MVKMTLSamplerStateBinding sb;
+
+    // The write starts beyond this binding: rebase dstArrayElement onto the next binding and skip this one.
+    if (dstArrayElement >= _info.descriptorCount) {
+        dstArrayElement -= _info.descriptorCount;
+        return;
+    }
+
+    // This binding takes the elements from dstArrayElement to its end, capped by what the write has left.
+    // Any remainder rolls over to element 0 of the next binding, so advance the cursors by exactly that many.
+    uint32_t consumeCount = _info.descriptorCount - dstArrayElement;
+    if (consumeCount > descriptorCount) { consumeCount = descriptorCount; }
+    auto consumeDescriptors = [&]() {
+        dstArrayElement = 0;
+        descriptorCount -= consumeCount;
+        if (pImageInfo) { pImageInfo += consumeCount; }
+        if (pBufferInfo) { pBufferInfo += consumeCount; }
+        if (pTexelBufferView) { pTexelBufferView += consumeCount; }
+    };
+
+    // A descriptor type mismatch is invalid usage; consume the elements without binding anything.
+    if (descriptorType != _info.descriptorType) { consumeDescriptors(); return; }
+
+    // Establish the resource indices to use, by combining the offsets of the DSL and this DSL binding.
+    MVKShaderResourceBinding mtlIdxs = _mtlResourceIndexOffsets + dslMTLRezIdxOffsets;
+
+    for (uint32_t rezIdx = dstArrayElement;
+         rezIdx < dstArrayElement + consumeCount;
+         rezIdx++) {
+        switch (_info.descriptorType) {
+
+            case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
+            case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC:
+            case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
+            case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER: {
+                const VkDescriptorBufferInfo& bufferInfo = pBufferInfo[rezIdx - dstArrayElement];
+                MVKBuffer* buffer = (MVKBuffer*)bufferInfo.buffer;
+                bb.mtlBuffer = buffer->getMTLBuffer();
+                bb.offset = bufferInfo.offset;
+                if (_applyToVertexStage) {
+                    bb.index = mtlIdxs.vertexStage.bufferIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindVertexBuffer(bb);
+                }
+                if (_applyToFragmentStage) {
+                    bb.index = mtlIdxs.fragmentStage.bufferIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindFragmentBuffer(bb);
+                }
+                if (_applyToComputeStage) {
+                    bb.index = mtlIdxs.computeStage.bufferIndex + rezIdx;
+                    cmdEncoder->_computeResourcesState.bindBuffer(bb);
+                }
+                break;
+            }
+
+            case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
+            case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
+            case VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT: {
+                const VkDescriptorImageInfo& imageInfo = pImageInfo[rezIdx - dstArrayElement];
+                MVKImageView* imageView = (MVKImageView*)imageInfo.imageView;
+                tb.mtlTexture = imageView->getMTLTexture();
+                if (_applyToVertexStage) {
+                    tb.index = mtlIdxs.vertexStage.textureIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindVertexTexture(tb);
+                }
+                if (_applyToFragmentStage) {
+                    tb.index = mtlIdxs.fragmentStage.textureIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindFragmentTexture(tb);
+                }
+                if (_applyToComputeStage) {
+                    tb.index = mtlIdxs.computeStage.textureIndex + rezIdx;
+                    cmdEncoder->_computeResourcesState.bindTexture(tb);
+                }
+                break;
+            }
+
+            case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
+            case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER: {
+                MVKBufferView* bufferView = (MVKBufferView*)pTexelBufferView[rezIdx - dstArrayElement];
+                tb.mtlTexture = bufferView->getMTLTexture();
+                if (_applyToVertexStage) {
+                    tb.index = mtlIdxs.vertexStage.textureIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindVertexTexture(tb);
+                }
+                if (_applyToFragmentStage) {
+                    tb.index = mtlIdxs.fragmentStage.textureIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindFragmentTexture(tb);
+                }
+                if (_applyToComputeStage) {
+                    tb.index = mtlIdxs.computeStage.textureIndex + rezIdx;
+                    cmdEncoder->_computeResourcesState.bindTexture(tb);
+                }
+                break;
+            }
+
+            case VK_DESCRIPTOR_TYPE_SAMPLER: {
+                MVKSampler* sampler;
+                if (_immutableSamplers.empty())
+                    sampler = (MVKSampler*)pImageInfo[rezIdx - dstArrayElement].sampler;
+                else
+                    sampler = _immutableSamplers[rezIdx];
+                sb.mtlSamplerState = sampler->getMTLSamplerState();
+                if (_applyToVertexStage) {
+                    sb.index = mtlIdxs.vertexStage.samplerIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindVertexSamplerState(sb);
+                }
+                if (_applyToFragmentStage) {
+                    sb.index = mtlIdxs.fragmentStage.samplerIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindFragmentSamplerState(sb);
+                }
+                if (_applyToComputeStage) {
+                    sb.index = mtlIdxs.computeStage.samplerIndex + rezIdx;
+                    cmdEncoder->_computeResourcesState.bindSamplerState(sb);
+                }
+                break;
+            }
+
+            case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER: {
+                const VkDescriptorImageInfo& imageInfo = pImageInfo[rezIdx - dstArrayElement];
+                MVKImageView* imageView = (MVKImageView*)imageInfo.imageView;
+                MVKSampler* sampler = _immutableSamplers.empty() ? (MVKSampler*)imageInfo.sampler : _immutableSamplers[rezIdx];
+                tb.mtlTexture = imageView->getMTLTexture();
+                sb.mtlSamplerState = sampler->getMTLSamplerState();
+                if (_applyToVertexStage) {
+                    tb.index = mtlIdxs.vertexStage.textureIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindVertexTexture(tb);
+                    sb.index = mtlIdxs.vertexStage.samplerIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindVertexSamplerState(sb);
+                }
+                if (_applyToFragmentStage) {
+                    tb.index = mtlIdxs.fragmentStage.textureIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindFragmentTexture(tb);
+                    sb.index = mtlIdxs.fragmentStage.samplerIndex + rezIdx;
+                    cmdEncoder->_graphicsResourcesState.bindFragmentSamplerState(sb);
+                }
+                if (_applyToComputeStage) {
+                    tb.index = mtlIdxs.computeStage.textureIndex + rezIdx;
+                    cmdEncoder->_computeResourcesState.bindTexture(tb);
+                    sb.index = mtlIdxs.computeStage.samplerIndex + rezIdx;
+                    cmdEncoder->_computeResourcesState.bindSamplerState(sb);
+                }
+                break;
+            }
+
+            default:
+                break;
+        }
+    }
+
+    consumeDescriptors();
+}
+
 void MVKDescriptorSetLayoutBinding::populateShaderConverterContext(SPIRVToMSLConverterContext& context,
                                                                    MVKShaderResourceBinding& dslMTLRezIdxOffsets,
                                                                    uint32_t dslIndex) {
@@ -300,6 +466,7 @@
                                                vector<uint32_t>& dynamicOffsets,
                                                uint32_t* pDynamicOffsetIndex) {
 
+    if (_isPushDescriptorLayout) return;
     uint32_t bindCnt = (uint32_t)_bindings.size();
     for (uint32_t bindIdx = 0; bindIdx < bindCnt; bindIdx++) {
         _bindings[bindIdx].bind(cmdEncoder, descSet->_bindings[bindIdx],
@@ -308,6 +475,28 @@
     }
 }
 
+void MVKDescriptorSetLayout::pushDescriptorSet(MVKCommandEncoder* cmdEncoder,
+                                               vector<VkWriteDescriptorSet>& descriptorWrites,
+                                               MVKShaderResourceBinding& dslMTLRezIdxOffsets) {
+    if (!_isPushDescriptorLayout) return;	// Only push-descriptor layouts accept pushed writes.
+    for (const VkWriteDescriptorSet& descWrite : descriptorWrites) {
+        uint32_t bindIdx = descWrite.dstBinding;
+        uint32_t dstArrayElement = descWrite.dstArrayElement;
+        uint32_t descriptorCount = descWrite.descriptorCount;
+        const VkDescriptorImageInfo* pImageInfo = descWrite.pImageInfo;
+        const VkDescriptorBufferInfo* pBufferInfo = descWrite.pBufferInfo;
+        const VkBufferView* pTexelBufferView = descWrite.pTexelBufferView;
+        // Each push() consumes what lands in one binding and rolls the rest over to the next; stop at the
+        // last binding so an oversized (invalid) write cannot index past the end of _bindings.
+        // NOTE(review): assumes _bindings is indexed by binding number (dstBinding) - confirm.
+        for (; descriptorCount && bindIdx < _bindings.size(); bindIdx++) {
+            _bindings[bindIdx].push(cmdEncoder, dstArrayElement, descriptorCount,
+                                    descWrite.descriptorType, pImageInfo, pBufferInfo,
+                                    pTexelBufferView, dslMTLRezIdxOffsets);
+        }
+    }
+}
+
 void MVKDescriptorSetLayout::populateShaderConverterContext(SPIRVToMSLConverterContext& context,
                                                             MVKShaderResourceBinding& dslMTLRezIdxOffsets,
 															uint32_t dslIndex) {
@@ -319,6 +508,7 @@
 
 MVKDescriptorSetLayout::MVKDescriptorSetLayout(MVKDevice* device,
                                                const VkDescriptorSetLayoutCreateInfo* pCreateInfo) : MVKBaseDeviceObject(device) {
+    _isPushDescriptorLayout = (pCreateInfo->flags & VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR) != 0;
     // Create the descriptor bindings
     _bindings.reserve(pCreateInfo->bindingCount);
     for (uint32_t i = 0; i < pCreateInfo->bindingCount; i++) {
@@ -616,6 +806,7 @@
 
 	for (uint32_t dsIdx = 0; dsIdx < count; dsIdx++) {
 		MVKDescriptorSetLayout* mvkDSL = (MVKDescriptorSetLayout*)pSetLayouts[dsIdx];
+		if (mvkDSL->isPushDescriptorLayout()) continue;
 		MVKDescriptorSet* mvkDescSet = new MVKDescriptorSet(_device, mvkDSL);
 		_allocatedSets.push_front(mvkDescSet);
 		pDescriptorSets[dsIdx] = (VkDescriptorSet)mvkDescSet;
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKDevice.mm b/MoltenVK/MoltenVK/GPUObjects/MVKDevice.mm
index 89411c4..ee5889c 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKDevice.mm
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKDevice.mm
@@ -77,6 +77,20 @@
     if (properties) {
         properties->sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2_KHR;
         properties->properties = _properties;
+        auto* next = (VkStructureType*)properties->pNext;
+        while (next) {
+            switch (*next) {
+            case VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PUSH_DESCRIPTOR_PROPERTIES_KHR: {
+                auto* pushDescProps = (VkPhysicalDevicePushDescriptorPropertiesKHR*)next;
+                pushDescProps->maxPushDescriptors = _properties.limits.maxPerStageResources;
+                next = (VkStructureType*)pushDescProps->pNext;
+                break;
+            }
+            default:
+                next = *(VkStructureType**)(next+1);
+                break;
+            }
+        }
     }
 }
 
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKInstance.mm b/MoltenVK/MoltenVK/GPUObjects/MVKInstance.mm
index df5b9ff..61df8ed 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKInstance.mm
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKInstance.mm
@@ -249,6 +249,7 @@
 	ADD_PROC_ADDR(vkGetPhysicalDeviceQueueFamilyProperties2KHR);
 	ADD_PROC_ADDR(vkGetPhysicalDeviceMemoryProperties2KHR);
 	ADD_PROC_ADDR(vkGetPhysicalDeviceSparseImageFormatProperties2KHR);
+	ADD_PROC_ADDR(vkCmdPushDescriptorSetKHR);
 	ADD_PROC_ADDR(vkGetMoltenVKConfigurationMVK);
 	ADD_PROC_ADDR(vkSetMoltenVKConfigurationMVK);
     ADD_PROC_ADDR(vkGetPhysicalDeviceMetalFeaturesMVK);
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.h b/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.h
index 3772693..4cd6391 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.h
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.h
@@ -50,6 +50,11 @@
 	/** Populates the specified shader converter context. */
 	void populateShaderConverterContext(SPIRVToMSLConverterContext& context);
 
+	/** Encodes the specified push-descriptor writes for the descriptor set at index set of this layout, then sets each stage's push-constant buffer index. */
+	void pushDescriptorSet(MVKCommandEncoder* cmdEncoder,
+						   std::vector<VkWriteDescriptorSet>& descriptorWrites,
+						   uint32_t set);
+
 	/** Constructs an instance for the specified device. */
 	MVKPipelineLayout(MVKDevice* device, const VkPipelineLayoutCreateInfo* pCreateInfo);
 
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm b/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm
index 2efa06f..151b840 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm
@@ -54,6 +54,17 @@
     cmdEncoder->getPushConstants(VK_SHADER_STAGE_COMPUTE_BIT)->setMTLBufferIndex(_pushConstantsMTLResourceIndexOffsets.computeStage.bufferIndex);
 }
 
+void MVKPipelineLayout::pushDescriptorSet(MVKCommandEncoder* cmdEncoder,
+                                          vector<VkWriteDescriptorSet>& descriptorWrites,
+                                          uint32_t set) {
+    // Encode the writes via set's layout, using this pipeline layout's Metal resource index offsets for that set.
+    _descriptorSetLayouts[set].pushDescriptorSet(cmdEncoder, descriptorWrites, _dslMTLResourceIndexOffsets[set]);
+    auto& pcOffsets = _pushConstantsMTLResourceIndexOffsets;	// Then set each stage's push-constant buffer index.
+    cmdEncoder->getPushConstants(VK_SHADER_STAGE_VERTEX_BIT)->setMTLBufferIndex(pcOffsets.vertexStage.bufferIndex);
+    cmdEncoder->getPushConstants(VK_SHADER_STAGE_FRAGMENT_BIT)->setMTLBufferIndex(pcOffsets.fragmentStage.bufferIndex);
+    cmdEncoder->getPushConstants(VK_SHADER_STAGE_COMPUTE_BIT)->setMTLBufferIndex(pcOffsets.computeStage.bufferIndex);
+}
+
 void MVKPipelineLayout::populateShaderConverterContext(SPIRVToMSLConverterContext& context) {
 	context.resourceBindings.clear();
 
diff --git a/MoltenVK/MoltenVK/Loader/MVKLayers.mm b/MoltenVK/MoltenVK/Loader/MVKLayers.mm
index bc3b9e8..42f331e 100644
--- a/MoltenVK/MoltenVK/Loader/MVKLayers.mm
+++ b/MoltenVK/MoltenVK/Loader/MVKLayers.mm
@@ -107,6 +107,11 @@
     extTmplt.specVersion = VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_SPEC_VERSION;
     _extensions.push_back(extTmplt);
 
+    memset(extTmplt.extensionName, 0, sizeof(extTmplt.extensionName));
+    strcpy(extTmplt.extensionName, VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME);
+    extTmplt.specVersion = VK_KHR_PUSH_DESCRIPTOR_SPEC_VERSION;
+    _extensions.push_back(extTmplt);
+
 #if MVK_IOS
     memset(extTmplt.extensionName, 0, sizeof(extTmplt.extensionName));
 	strcpy(extTmplt.extensionName, VK_MVK_IOS_SURFACE_EXTENSION_NAME);
diff --git a/MoltenVK/MoltenVK/Vulkan/vulkan.mm b/MoltenVK/MoltenVK/Vulkan/vulkan.mm
index a9434f3..db0e6f3 100644
--- a/MoltenVK/MoltenVK/Vulkan/vulkan.mm
+++ b/MoltenVK/MoltenVK/Vulkan/vulkan.mm
@@ -1678,6 +1678,22 @@
 
 
 #pragma mark -
+#pragma mark VK_KHR_push_descriptor extension
+
+MVK_PUBLIC_SYMBOL void vkCmdPushDescriptorSetKHR(
+    VkCommandBuffer                             commandBuffer,
+    VkPipelineBindPoint                         pipelineBindPoint,
+    VkPipelineLayout                            layout,
+    uint32_t                                    set,
+    uint32_t                                    descriptorWriteCount,
+    const VkWriteDescriptorSet*                 pDescriptorWrites) {
+    // Unwrap the Vulkan handle and record the push on the MoltenVK command buffer behind it.
+    mvkCmdPushDescriptorSet(MVKCommandBuffer::getMVKCommandBuffer(commandBuffer),
+                            pipelineBindPoint, layout, set, descriptorWriteCount, pDescriptorWrites);
+}
+
+
+#pragma mark -
 #pragma mark Loader and Layer ICD interface extension
 
 #ifdef __cplusplus