Merge pull request #2474 from billhollings/VK_KHR_maintenance4
Fix failure when shader specifies both LocalSizeId and workgroup size builtin.
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
index eb8af44..8090e9c 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
@@ -170,8 +170,8 @@
auto& wgSize = _shaderConversionResultInfo.entryPoint.workgroupSize;
return MVKMTLFunction(mtlFunc, _shaderConversionResultInfo, MTLSizeMake(getWorkgroupDimensionSize(wgSize.width, pSpecializationInfo),
- getWorkgroupDimensionSize(wgSize.height, pSpecializationInfo),
- getWorkgroupDimensionSize(wgSize.depth, pSpecializationInfo)));
+ getWorkgroupDimensionSize(wgSize.height, pSpecializationInfo),
+ getWorkgroupDimensionSize(wgSize.depth, pSpecializationInfo)));
}
}
}
diff --git a/MoltenVKShaderConverter/MoltenVKShaderConverter/SPIRVToMSLConverter.cpp b/MoltenVKShaderConverter/MoltenVKShaderConverter/SPIRVToMSLConverter.cpp
index ed4fe4f..92f7b55 100644
--- a/MoltenVKShaderConverter/MoltenVKShaderConverter/SPIRVToMSLConverter.cpp
+++ b/MoltenVKShaderConverter/MoltenVKShaderConverter/SPIRVToMSLConverter.cpp
@@ -502,6 +502,21 @@
log += "\n\n";
}
+// Extracts the workgroup size (x, y, z) from the WorkgroupSize builtin, LocalSizeId, or LocalSize execution mode.
+// The builtin (workgroup_size.constant) takes precedence over LocalSizeId when present; a missing LocalSizeId ID yields 0.
+static void getWorkgroupSize(Compiler* pCompiler, const SPIREntryPoint& spvEP, uint32_t& x, uint32_t& y, uint32_t& z) {
+ const auto& wgSz = spvEP.workgroup_size;
+ if (spvEP.flags.get(ExecutionModeLocalSizeId) && !wgSz.constant) {
+ x = wgSz.id_x ? pCompiler->get_constant(wgSz.id_x).scalar() : 0;
+ y = wgSz.id_y ? pCompiler->get_constant(wgSz.id_y).scalar() : 0;
+ z = wgSz.id_z ? pCompiler->get_constant(wgSz.id_z).scalar() : 0;
+ } else {
+ x = wgSz.x;
+ y = wgSz.y;
+ z = wgSz.z;
+ }
+}
+
void SPIRVToMSLConverter::populateWorkgroupDimension(SPIRVWorkgroupSizeDimension& wgDim,
uint32_t size,
SpecializationConstant& spvSpecConst) {
@@ -531,13 +546,16 @@
entryPoint.mtlFunctionName = spvEP.name;
entryPoint.supportsFastMath = !spvEP.flags.get(ExecutionModeSignedZeroInfNanPreserve);
+ uint32_t x, y, z;
+ getWorkgroupSize(pCompiler, spvEP, x, y, z);
+
SpecializationConstant widthSC, heightSC, depthSC;
pCompiler->get_work_group_size_specialization_constants(widthSC, heightSC, depthSC);
auto& wgSize = entryPoint.workgroupSize;
- populateWorkgroupDimension(wgSize.width, spvEP.workgroup_size.x, widthSC);
- populateWorkgroupDimension(wgSize.height, spvEP.workgroup_size.y, heightSC);
- populateWorkgroupDimension(wgSize.depth, spvEP.workgroup_size.z, depthSC);
+ populateWorkgroupDimension(wgSize.width, x, widthSC);
+ populateWorkgroupDimension(wgSize.height, y, heightSC);
+ populateWorkgroupDimension(wgSize.depth, z, depthSC);
}
bool SPIRVToMSLConverter::usesPhysicalStorageBufferAddressesCapability(Compiler* pCompiler) {