Added a way to set compute kernel threadgroup size when using MSL source code or MSL compiled code
diff --git a/MoltenVK/MoltenVK/API/vk_mvk_moltenvk.h b/MoltenVK/MoltenVK/API/vk_mvk_moltenvk.h
index 08bbe33..df52fd3 100644
--- a/MoltenVK/MoltenVK/API/vk_mvk_moltenvk.h
+++ b/MoltenVK/MoltenVK/API/vk_mvk_moltenvk.h
@@ -795,6 +795,18 @@
char* pVulkanVersionStringBuffer,
uint32_t vulkanVersionStringBufferLength);
+/**
+ * Sets the number of threads in a threadgroup for a compute kernel.
+ *
+ * This needs to be called if you are creating compute shader modules from MSL
+ * source code or MSL compiled code. Threadgroup size is determined automatically
+ * if you're using SPIR-V.
+ *
+ * The x, y and z values are the thread counts along the width, height and depth
+ * dimensions of the threadgroup, respectively.
+ *
+ * The shaderModule handle is used as-is, without a null check, so it must refer
+ * to a valid shader module.
+ */
+VKAPI_ATTR void VKAPI_CALL vkSetThreadgroupSizeMVK(
+ VkShaderModule shaderModule,
+ uint32_t x,
+ uint32_t y,
+ uint32_t z);
#ifdef __OBJC__
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h
index bfd4d59..d53a4ec 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h
@@ -55,6 +55,9 @@
/** Returns the Vulkan API opaque object controlling this object. */
MVKVulkanAPIObject* getVulkanAPIObject() override { return _owner->getVulkanAPIObject(); };
+ /**
+  * Sets the number of threads in a single compute kernel workgroup, per dimension.
+  *
+  * The x, y and z values are stored as the width, height and depth sizes of the
+  * workgroup size held in this library's entry point info.
+  */
+ void setWorkgroupSize(uint32_t x, uint32_t y, uint32_t z);
+
/** Constructs an instance from the specified MSL source code. */
MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const std::string& mslSourceCode,
@@ -186,7 +189,10 @@
* call to convert() function, or set directly using the setMSL() function.
*/
const SPIRVEntryPoint& getEntryPoint() { return _spvConverter.getEntryPoint(); }
-
+
+ /**
+  * Sets the number of threads in a single compute kernel workgroup, per dimension.
+  *
+  * The values are passed to the SPIR-V converter of this shader module and,
+  * if this module has a default library, to that library as well.
+  */
+ void setWorkgroupSize(uint32_t x, uint32_t y, uint32_t z);
+
/** Returns a key as a means of identifying this shader module in a pipeline cache. */
MVKShaderModuleKey getKey() { return _key; }
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
index 999b337..426670a 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
@@ -158,6 +158,12 @@
}
}
+/**
+ * Records the thread count of a compute workgroup on this library's entry point,
+ * mapping x, y and z to the width, height and depth sizes respectively.
+ */
+void MVKShaderLibrary::setWorkgroupSize(uint32_t x, uint32_t y, uint32_t z) {
+    auto& wgSize = _entryPoint.workgroupSize;
+    wgSize.width.size  = x;
+    wgSize.height.size = y;
+    wgSize.depth.size  = z;
+}
+
MVKShaderLibrary::~MVKShaderLibrary() {
[_mtlLibrary release];
}
@@ -376,6 +382,11 @@
if (_defaultLibrary) { _defaultLibrary->destroy(); }
}
+/**
+ * Passes the per-dimension workgroup thread counts to the SPIR-V converter,
+ * and then to the default library when this module has one.
+ */
+void MVKShaderModule::setWorkgroupSize(uint32_t x, uint32_t y, uint32_t z) {
+    _spvConverter.setWorkgroupSize(x, y, z);
+
+    // Nothing further to update when no default library exists.
+    if ( !_defaultLibrary ) { return; }
+
+    _defaultLibrary->setWorkgroupSize(x, y, z);
+}
+
#pragma mark -
#pragma mark MVKShaderLibraryCompiler
diff --git a/MoltenVK/MoltenVK/Vulkan/vk_mvk_moltenvk.mm b/MoltenVK/MoltenVK/Vulkan/vk_mvk_moltenvk.mm
index db81fec..778250d 100644
--- a/MoltenVK/MoltenVK/Vulkan/vk_mvk_moltenvk.mm
+++ b/MoltenVK/MoltenVK/Vulkan/vk_mvk_moltenvk.mm
@@ -23,6 +23,7 @@
#include "MVKSwapchain.h"
#include "MVKImage.h"
#include "MVKFoundation.h"
+#include "MVKShaderModule.h"
#include <string>
using namespace std;
@@ -147,3 +148,14 @@
MVKImage* mvkImg = (MVKImage*)image;
*pIOSurface = mvkImg->getIOSurface();
}
+
+/**
+ * Public entry point: treats the shader module handle as an MVKShaderModule
+ * and sets its workgroup size to the given per-dimension thread counts.
+ * The handle is not checked for null before use.
+ */
+MVK_PUBLIC_SYMBOL void vkSetThreadgroupSizeMVK(
+    VkShaderModule shaderModule,
+    uint32_t x,
+    uint32_t y,
+    uint32_t z) {
+
+    ((MVKShaderModule*)shaderModule)->setWorkgroupSize(x, y, z);
+}
+
diff --git a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h
index 61aa863..0d86ffe 100644
--- a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h
+++ b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h
@@ -291,6 +291,13 @@
/** Returns information about the shader entry point. */
const SPIRVEntryPoint& getEntryPoint() { return _entryPoint; }
+ /**
+  * Sets the number of threads in a single compute kernel workgroup, per dimension.
+  * The x, y and z values become the width, height and depth sizes of the
+  * workgroup size held in the entry point info returned by getEntryPoint().
+  */
+ void setWorkgroupSize(uint32_t x, uint32_t y, uint32_t z) {
+     auto& wgSize = _entryPoint.workgroupSize;
+     wgSize.width.size  = x;
+     wgSize.height.size = y;
+     wgSize.depth.size  = z;
+ }
+
/**
* Returns a human-readable log of the most recent conversion activity.
* This may be empty if the conversion was successful.