Support SPIR-V containing multiple entry points.
Include name and stage when matching shaders contexts for pipelines.
Avoid deprecated SPIRV-Cross members.
Update to latest SPIRV-Cross version.
diff --git a/External/SPIRV-Cross_repo_revision b/External/SPIRV-Cross_repo_revision
index 5b72b9b..87ec2f2 100644
--- a/External/SPIRV-Cross_repo_revision
+++ b/External/SPIRV-Cross_repo_revision
@@ -1 +1 @@
-0f9cb0da0d5ab91b21a42ffc0062840fc76e81e3
+5161d5ed3b5a788c2469bb548fbb6001f98c03fa
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm b/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm
index 84af75e..7382652 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm
@@ -264,17 +264,20 @@
// Add shader stages
for (uint32_t i = 0; i < pCreateInfo->stageCount; i++) {
const VkPipelineShaderStageCreateInfo* pSS = &pCreateInfo->pStages[i];
+ shaderContext.options.entryPointName = pSS->pName;
MVKShaderModule* mvkShdrMod = (MVKShaderModule*)pSS->module;
// Vertex shader
if (mvkAreFlagsEnabled(pSS->stage, VK_SHADER_STAGE_VERTEX_BIT)) {
- plDesc.vertexFunction = mvkShdrMod->getMTLFunction(pSS, &shaderContext).mtlFunction;
+ shaderContext.options.entryPointStage = spv::ExecutionModelVertex;
+ plDesc.vertexFunction = mvkShdrMod->getMTLFunction(&shaderContext, pSS->pSpecializationInfo).mtlFunction;
}
// Fragment shader
if (mvkAreFlagsEnabled(pSS->stage, VK_SHADER_STAGE_FRAGMENT_BIT)) {
- plDesc.fragmentFunction = mvkShdrMod->getMTLFunction(pSS, &shaderContext).mtlFunction;
+ shaderContext.options.entryPointStage = spv::ExecutionModelFragment;
+ plDesc.fragmentFunction = mvkShdrMod->getMTLFunction(&shaderContext, pSS->pSpecializationInfo).mtlFunction;
}
}
@@ -426,13 +429,15 @@
if ( !mvkAreFlagsEnabled(pSS->stage, VK_SHADER_STAGE_COMPUTE_BIT) ) { return MVKMTLFunctionNull; }
SPIRVToMSLConverterContext shaderContext;
+ shaderContext.options.entryPointName = pCreateInfo->stage.pName;
+ shaderContext.options.entryPointStage = spv::ExecutionModelGLCompute;
shaderContext.options.mslVersion = _device->_pMetalFeatures->mslVersion;
MVKPipelineLayout* layout = (MVKPipelineLayout*)pCreateInfo->layout;
layout->populateShaderConverterContext(shaderContext);
MVKShaderModule* mvkShdrMod = (MVKShaderModule*)pSS->module;
- return mvkShdrMod->getMTLFunction(pSS, &shaderContext);
+ return mvkShdrMod->getMTLFunction(&shaderContext, pSS->pSpecializationInfo);
}
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h
index 1c9969a..a697a27 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h
@@ -44,8 +44,8 @@
class MVKShaderLibrary : public MVKBaseDeviceObject {
public:
- /** Returns the Metal shader function used by the specified shader state. */
- MVKMTLFunction getMTLFunction(const VkPipelineShaderStageCreateInfo* pShaderStage);
+ /** Returns the Metal shader function, possibly specialized. */
+ MVKMTLFunction getMTLFunction(const VkSpecializationInfo* pSpecializationInfo);
/** Constructs an instance from the MSL source code in the specified SPIRVToMSLConverter. */
MVKShaderLibrary(MVKDevice* device, SPIRVToMSLConverter& mslConverter);
@@ -62,7 +62,7 @@
MTLFunctionConstant* getFunctionConstant(NSArray<MTLFunctionConstant*>* mtlFCs, NSUInteger mtlFCID);
id<MTLLibrary> _mtlLibrary;
- SPIRVEntryPointsByName _entryPoints;
+ SPIRVEntryPoint _entryPoint;
};
@@ -73,9 +73,9 @@
class MVKShaderModule : public MVKBaseDeviceObject {
public:
- /** Returns the Metal shader function used by the specified shader state, or nil if it doesn't exist. */
- MVKMTLFunction getMTLFunction(const VkPipelineShaderStageCreateInfo* pShaderStage,
- SPIRVToMSLConverterContext* pContext);
+ /** Returns the Metal shader function, possibly specialized. */
+ MVKMTLFunction getMTLFunction(SPIRVToMSLConverterContext* pContext,
+ const VkSpecializationInfo* pSpecializationInfo);
MVKShaderModule(MVKDevice* device, const VkShaderModuleCreateInfo* pCreateInfo);
diff --git a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
index fe2f522..ce9d059 100644
--- a/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
+++ b/MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
@@ -39,14 +39,13 @@
return -1;
}
-MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkPipelineShaderStageCreateInfo* pShaderStage) {
+MVKMTLFunction MVKShaderLibrary::getMTLFunction(const VkSpecializationInfo* pSpecializationInfo) {
if ( !_mtlLibrary ) { return MVKMTLFunctionNull; }
// Ensure the function name is compatible with Metal (Metal does not allow main()
// as a function name), and retrieve the unspecialized Metal function with that name.
- SPIRVEntryPoint& ep = _entryPoints[pShaderStage->pName];
- NSString* mtlFuncName = @(ep.mtlFunctionName.c_str());
+ NSString* mtlFuncName = @(_entryPoint.mtlFunctionName.c_str());
uint64_t startTime = _device->getPerformanceTimestamp();
id<MTLFunction> mtlFunc = [[_mtlLibrary newFunctionWithName: mtlFuncName] autorelease];
@@ -64,16 +63,15 @@
// The Metal shader contains function constants and expects to be specialized
// Populate the Metal function constant values from the Vulkan specialization info.
MTLFunctionConstantValues* mtlFCVals = [[MTLFunctionConstantValues new] autorelease];
- const VkSpecializationInfo* pSpecInfo = pShaderStage->pSpecializationInfo;
- if (pSpecInfo) {
+ if (pSpecializationInfo) {
// Iterate through the provided Vulkan specialization entries, and populate the
// Metal function constant value that matches the Vulkan specialization constantID.
- for (uint32_t specIdx = 0; specIdx < pSpecInfo->mapEntryCount; specIdx++) {
- const VkSpecializationMapEntry* pMapEntry = &pSpecInfo->pMapEntries[specIdx];
+ for (uint32_t specIdx = 0; specIdx < pSpecializationInfo->mapEntryCount; specIdx++) {
+ const VkSpecializationMapEntry* pMapEntry = &pSpecializationInfo->pMapEntries[specIdx];
NSUInteger mtlFCIndex = pMapEntry->constantID;
MTLFunctionConstant* mtlFC = getFunctionConstant(mtlFCs, mtlFCIndex);
if (mtlFC) {
- [mtlFCVals setConstantValue: &(((char*)pSpecInfo->pData)[pMapEntry->offset])
+ [mtlFCVals setConstantValue: &(((char*)pSpecializationInfo->pData)[pMapEntry->offset])
type: mtlFC.type
atIndex: mtlFCIndex];
}
@@ -90,29 +88,28 @@
} else {
mvkNotifyErrorWithText(VK_ERROR_INITIALIZATION_FAILED, "Shader module does not contain an entry point named '%s'.", mtlFuncName.UTF8String);
}
-
- const VkSpecializationInfo* pSpecInfo = pShaderStage->pSpecializationInfo;
- if (pSpecInfo) {
+
+ if (pSpecializationInfo) {
// Get the specialization constant values for the work group size
- if (ep.workgroupSizeId.constant != 0) {
- uint32_t widthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.width);
+ if (_entryPoint.workgroupSizeId.constant != 0) {
+ uint32_t widthOffset = getOffsetForConstantId(pSpecializationInfo, _entryPoint.workgroupSizeId.width);
if (widthOffset != -1) {
- ep.workgroupSize.width = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + widthOffset);
+ _entryPoint.workgroupSize.width = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecializationInfo->pData + widthOffset);
}
-
- uint32_t heightOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.height);
+
+ uint32_t heightOffset = getOffsetForConstantId(pSpecializationInfo, _entryPoint.workgroupSizeId.height);
if (heightOffset != -1) {
- ep.workgroupSize.height = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + heightOffset);
+ _entryPoint.workgroupSize.height = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecializationInfo->pData + heightOffset);
}
-
- uint32_t depthOffset = getOffsetForConstantId(pSpecInfo, ep.workgroupSizeId.depth);
+
+ uint32_t depthOffset = getOffsetForConstantId(pSpecializationInfo, _entryPoint.workgroupSizeId.depth);
if (depthOffset != -1) {
- ep.workgroupSize.depth = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecInfo->pData + depthOffset);
+ _entryPoint.workgroupSize.depth = *reinterpret_cast<uint32_t*>((uint8_t*)pSpecializationInfo->pData + depthOffset);
}
}
}
- return { mtlFunc, MTLSizeMake(ep.workgroupSize.width, ep.workgroupSize.height, ep.workgroupSize.depth) };
+ return { mtlFunc, MTLSizeMake(_entryPoint.workgroupSize.width, _entryPoint.workgroupSize.height, _entryPoint.workgroupSize.depth) };
}
// Returns the MTLFunctionConstant with the specified ID from the specified array of function constants.
@@ -134,7 +131,7 @@
}
_device->addShaderCompilationEventPerformance(_device->_shaderCompilationPerformance.mslCompile, startTime);
- _entryPoints = mslConverter.getEntryPoints();
+ _entryPoint = mslConverter.getEntryPoint();
}
MVKShaderLibrary::MVKShaderLibrary(MVKDevice* device,
@@ -179,11 +176,11 @@
#pragma mark -
#pragma mark MVKShaderModule
-MVKMTLFunction MVKShaderModule::getMTLFunction(const VkPipelineShaderStageCreateInfo* pShaderStage,
- SPIRVToMSLConverterContext* pContext) {
+MVKMTLFunction MVKShaderModule::getMTLFunction(SPIRVToMSLConverterContext* pContext,
+ const VkSpecializationInfo* pSpecializationInfo) {
lock_guard<mutex> lock(_accessLock);
MVKShaderLibrary* mvkLib = getShaderLibrary(pContext);
- return mvkLib ? mvkLib->getMTLFunction(pShaderStage) : MVKMTLFunctionNull;
+ return mvkLib ? mvkLib->getMTLFunction(pSpecializationInfo) : MVKMTLFunctionNull;
}
MVKShaderLibrary* MVKShaderModule::getShaderLibrary(SPIRVToMSLConverterContext* pContext) {
@@ -200,7 +197,7 @@
MVKShaderLibrary* MVKShaderModule::findShaderLibrary(SPIRVToMSLConverterContext* pContext) {
for (auto& slPair : _shaderLibraries) {
if (slPair.first.matches(*pContext)) {
- (*pContext).alignUsageWith(slPair.first);
+ pContext->alignUsageWith(slPair.first);
return slPair.second;
}
}
@@ -248,8 +245,7 @@
}
case kMVKMagicNumberMSLSourceCode: { // MSL source code
uintptr_t pMSLCode = uintptr_t(pCreateInfo->pCode) + sizeof(MVKMSLSPIRVHeader);
- SPIRVEntryPointsByName entryPoints;
- _converter.setMSL((char*)pMSLCode, entryPoints);
+ _converter.setMSL((char*)pMSLCode, nullptr);
_defaultLibrary = new MVKShaderLibrary(_device, _converter);
break;
}
diff --git a/MoltenVKShaderConverter/MoltenVKGLSLToSPIRVConverter/GLSLToSPIRVConverter.cpp b/MoltenVKShaderConverter/MoltenVKGLSLToSPIRVConverter/GLSLToSPIRVConverter.cpp
index 814fe52..c80dfe2 100644
--- a/MoltenVKShaderConverter/MoltenVKGLSLToSPIRVConverter/GLSLToSPIRVConverter.cpp
+++ b/MoltenVKShaderConverter/MoltenVKGLSLToSPIRVConverter/GLSLToSPIRVConverter.cpp
@@ -43,8 +43,8 @@
MVK_PUBLIC_SYMBOL const string& GLSLToSPIRVConverter::getGLSL() { return _glsl; }
MVK_PUBLIC_SYMBOL bool GLSLToSPIRVConverter::convert(MVKShaderStage shaderStage,
- bool shouldLogGLSL,
- bool shouldLogSPIRV) {
+ bool shouldLogGLSL,
+ bool shouldLogSPIRV) {
_wasConverted = true;
_resultLog.clear();
_spirv.clear();
diff --git a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp
index 5e909c7..4211adb 100644
--- a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp
+++ b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.cpp
@@ -21,7 +21,6 @@
#include "MVKStrings.h"
#include "FileSupport.h"
#include "spirv_msl.hpp"
-#include "spirv_glsl.hpp"
#include <spirv-tools/libspirv.h>
#import <CoreFoundation/CFByteOrder.h>
@@ -34,19 +33,21 @@
// Returns whether the vector contains the value (using a matches(T&) comparison member function). */
template<class T>
-bool contains(vector<T>& vec, T& val) {
- for (T& vecVal : vec) { if (vecVal.matches(val)) { return true; } }
+bool contains(const vector<T>& vec, const T& val) {
+ for (const T& vecVal : vec) { if (vecVal.matches(val)) { return true; } }
return false;
}
-MVK_PUBLIC_SYMBOL bool SPIRVToMSLConverterOptions::matches(SPIRVToMSLConverterOptions& other) {
+MVK_PUBLIC_SYMBOL bool SPIRVToMSLConverterOptions::matches(const SPIRVToMSLConverterOptions& other) const {
+ if (entryPointStage != other.entryPointStage) { return false; }
if (mslVersion != other.mslVersion) { return false; }
if (!!shouldFlipVertexY != !!other.shouldFlipVertexY) { return false; }
if (!!isRenderingPoints != !!other.isRenderingPoints) { return false; }
+ if (entryPointName != other.entryPointName) { return false; }
return true;
}
-MVK_PUBLIC_SYMBOL bool MSLVertexAttribute::matches(MSLVertexAttribute& other) {
+MVK_PUBLIC_SYMBOL bool MSLVertexAttribute::matches(const MSLVertexAttribute& other) const {
if (location != other.location) { return false; }
if (mslBuffer != other.mslBuffer) { return false; }
if (mslOffset != other.mslOffset) { return false; }
@@ -55,7 +56,7 @@
return true;
}
-MVK_PUBLIC_SYMBOL bool MSLResourceBinding::matches(MSLResourceBinding& other) {
+MVK_PUBLIC_SYMBOL bool MSLResourceBinding::matches(const MSLResourceBinding& other) const {
if (stage != other.stage) { return false; }
if (descriptorSet != other.descriptorSet) { return false; }
if (binding != other.binding) { return false; }
@@ -66,7 +67,7 @@
}
// Check them all in case inactive VA's duplicate locations used by active VA's.
-MVK_PUBLIC_SYMBOL bool SPIRVToMSLConverterContext::isVertexAttributeLocationUsed(uint32_t location) {
+MVK_PUBLIC_SYMBOL bool SPIRVToMSLConverterContext::isVertexAttributeLocationUsed(uint32_t location) const {
for (auto& va : vertexAttributes) {
if ((va.location == location) && va.isUsedByShader) { return true; }
}
@@ -74,22 +75,22 @@
}
// Check them all in case inactive VA's duplicate buffers used by active VA's.
-MVK_PUBLIC_SYMBOL bool SPIRVToMSLConverterContext::isVertexBufferUsed(uint32_t mslBuffer) {
+MVK_PUBLIC_SYMBOL bool SPIRVToMSLConverterContext::isVertexBufferUsed(uint32_t mslBuffer) const {
for (auto& va : vertexAttributes) {
if ((va.mslBuffer == mslBuffer) && va.isUsedByShader) { return true; }
}
return false;
}
-MVK_PUBLIC_SYMBOL bool SPIRVToMSLConverterContext::matches(SPIRVToMSLConverterContext& other) {
+MVK_PUBLIC_SYMBOL bool SPIRVToMSLConverterContext::matches(const SPIRVToMSLConverterContext& other) const {
if ( !options.matches(other.options) ) { return false; }
- for (auto& va : vertexAttributes) {
+ for (const auto& va : vertexAttributes) {
if (va.isUsedByShader && !contains(other.vertexAttributes, va)) { return false; }
}
- for (auto& rb : resourceBindings) {
+ for (const auto& rb : resourceBindings) {
if (rb.isUsedByShader && !contains(other.resourceBindings, rb)) { return false; }
}
@@ -97,7 +98,7 @@
}
// Aligns the usage of the destination context to that of the source context.
-MVK_PUBLIC_SYMBOL void SPIRVToMSLConverterContext::alignUsageWith(SPIRVToMSLConverterContext& srcContext) {
+MVK_PUBLIC_SYMBOL void SPIRVToMSLConverterContext::alignUsageWith(const SPIRVToMSLConverterContext& srcContext) {
for (auto& va : vertexAttributes) {
va.isUsedByShader = false;
@@ -119,7 +120,7 @@
#pragma mark SPIRVToMSLConverter
/** Populates content extracted from the SPRI-V compiler. */
-void populateFromCompiler(spirv_cross::Compiler& compiler, SPIRVEntryPointsByName& entryPoints);
+void populateFromCompiler(spirv_cross::Compiler& compiler, SPIRVEntryPoint& entryPoint, SPIRVToMSLConverterOptions& options);
MVK_PUBLIC_SYMBOL void SPIRVToMSLConverter::setSPIRV(const vector<uint32_t>& spirv) { _spirv = spirv; }
@@ -172,9 +173,13 @@
spirv_cross::CompilerMSL mslCompiler(_spirv);
+ if (context.options.hasEntryPoint()) {
+ mslCompiler.set_entry_point(context.options.entryPointName, context.options.entryPointStage);
+ }
+
// Establish the MSL options for the compiler
// This needs to be done in two steps...for CompilerMSL and its superclass.
- auto mslOpts = mslCompiler.get_options();
+ auto mslOpts = mslCompiler.get_msl_options();
#if MVK_MACOS
mslOpts.platform = spirv_cross::CompilerMSL::Options::macOS;
@@ -186,11 +191,11 @@
mslOpts.msl_version = context.options.mslVersion;
mslOpts.enable_point_size_builtin = context.options.isRenderingPoints;
mslOpts.resolve_specialized_array_lengths = true;
- mslCompiler.set_options(mslOpts);
+ mslCompiler.set_msl_options(mslOpts);
- auto scOpts = mslCompiler.CompilerGLSL::get_options();
+ auto scOpts = mslCompiler.get_common_options();
scOpts.vertex.flip_vert_y = context.options.shouldFlipVertexY;
- mslCompiler.CompilerGLSL::set_options(scOpts);
+ mslCompiler.set_common_options(scOpts);
#ifndef SPIRV_CROSS_EXCEPTIONS_TO_ASSERTIONS
try {
@@ -210,7 +215,7 @@
#endif
// Populate content extracted from the SPRI-V compiler.
- populateFromCompiler(mslCompiler, _entryPoints);
+ populateFromCompiler(mslCompiler, _entryPoint, context.options);
// To check GLSL conversion
if (shouldLogGLSL) {
@@ -317,28 +322,32 @@
#pragma mark Support functions
-void populateFromCompiler(spirv_cross::Compiler& compiler, SPIRVEntryPointsByName& entryPoints) {
+void populateFromCompiler(spirv_cross::Compiler& compiler, SPIRVEntryPoint& entryPoint, SPIRVToMSLConverterOptions& options) {
- uint32_t minDim = 1;
- entryPoints.clear();
- for (string& epOrigName : compiler.get_entry_points()) {
- auto& spvEP = compiler.get_entry_point(epOrigName);
- auto& wgSize = spvEP.workgroup_size;
+ spirv_cross::SPIREntryPoint spvEP;
+ if (options.hasEntryPoint()) {
+ spvEP = compiler.get_entry_point(options.entryPointName, options.entryPointStage);
+ } else {
+ const auto& entryPoints = compiler.get_entry_points_and_stages();
+ if ( !entryPoints.empty() ) {
+ auto& ep = entryPoints[0];
+ spvEP = compiler.get_entry_point(ep.name, ep.execution_model);
+ }
+ }
- SPIRVEntryPoint mvkEP;
- mvkEP.mtlFunctionName = spvEP.name;
- mvkEP.workgroupSize.width = max(wgSize.x, minDim);
- mvkEP.workgroupSize.height = max(wgSize.y, minDim);
- mvkEP.workgroupSize.depth = max(wgSize.z, minDim);
+ uint32_t minDim = 1;
+ auto& wgSize = spvEP.workgroup_size;
- spirv_cross::SpecializationConstant width, height, depth;
- mvkEP.workgroupSizeId.constant = compiler.get_work_group_size_specialization_constants(width, height, depth);
- mvkEP.workgroupSizeId.width = width.constant_id;
- mvkEP.workgroupSizeId.height = height.constant_id;
- mvkEP.workgroupSizeId.depth = depth.constant_id;
+ entryPoint.mtlFunctionName = spvEP.name;
+ entryPoint.workgroupSize.width = max(wgSize.x, minDim);
+ entryPoint.workgroupSize.height = max(wgSize.y, minDim);
+ entryPoint.workgroupSize.depth = max(wgSize.z, minDim);
- entryPoints[epOrigName] = mvkEP;
- }
+ spirv_cross::SpecializationConstant width, height, depth;
+ entryPoint.workgroupSizeId.constant = compiler.get_work_group_size_specialization_constants(width, height, depth);
+ entryPoint.workgroupSizeId.width = width.constant_id;
+ entryPoint.workgroupSizeId.height = height.constant_id;
+ entryPoint.workgroupSizeId.depth = depth.constant_id;
}
MVK_PUBLIC_SYMBOL void mvk::logSPIRV(vector<uint32_t>& spirv, string& spvLog) {
diff --git a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h
index aee99df..3a5118e 100644
--- a/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h
+++ b/MoltenVKShaderConverter/MoltenVKSPIRVToMSLConverter/SPIRVToMSLConverter.h
@@ -32,6 +32,9 @@
/** Options for converting SPIR-V to Metal Shading Language */
typedef struct SPIRVToMSLConverterOptions {
+ std::string entryPointName;
+ spv::ExecutionModel entryPointStage = spv::ExecutionModelMax;
+
uint32_t mslVersion = makeMSLVersion(2);
bool shouldFlipVertexY = true;
bool isRenderingPoints = false;
@@ -40,13 +43,17 @@
* Returns whether the specified options match this one.
* It does if all corresponding elements are equal.
*/
- bool matches(SPIRVToMSLConverterOptions& other);
+ bool matches(const SPIRVToMSLConverterOptions& other) const;
+
+ bool hasEntryPoint() const {
+ return !entryPointName.empty() && entryPointStage != spv::ExecutionModelMax;
+ }
void setMSLVersion(uint32_t major, uint32_t minor = 0, uint32_t point = 0) {
mslVersion = makeMSLVersion(major, minor, point);
}
- bool supportsMSLVersion(uint32_t major, uint32_t minor = 0, uint32_t point = 0) {
+ bool supportsMSLVersion(uint32_t major, uint32_t minor = 0, uint32_t point = 0) const {
return mslVersion >= makeMSLVersion(major, minor, point);
}
@@ -74,7 +81,7 @@
* Returns whether the specified vertex attribute match this one.
* It does if all corresponding elements except isUsedByShader are equal.
*/
- bool matches(MSLVertexAttribute& other);
+ bool matches(const MSLVertexAttribute& other) const;
} MSLVertexAttribute;
@@ -100,7 +107,7 @@
* Returns whether the specified resource binding match this one.
* It does if all corresponding elements except isUsedByShader are equal.
*/
- bool matches(MSLResourceBinding& other);
+ bool matches(const MSLResourceBinding& other) const;
} MSLResourceBinding;
@@ -111,10 +118,10 @@
std::vector<MSLResourceBinding> resourceBindings;
/** Returns whether the vertex attribute at the specified location is used by the shader. */
- bool isVertexAttributeLocationUsed(uint32_t location);
+ bool isVertexAttributeLocationUsed(uint32_t location) const;
/** Returns whether the vertex buffer at the specified Metal binding index is used by the shader. */
- bool isVertexBufferUsed(uint32_t mslBuffer);
+ bool isVertexBufferUsed(uint32_t mslBuffer) const;
/**
* Returns whether this context matches the other context. It does if the respective
@@ -122,10 +129,10 @@
* can be found in the other context. Vertex attributes and resource bindings that are
* in the other context but are not used by the shader that created this context, are ignored.
*/
- bool matches(SPIRVToMSLConverterContext& other);
+ bool matches(const SPIRVToMSLConverterContext& other) const;
/** Aligns the usage of this context with that of the source context. */
- void alignUsageWith(SPIRVToMSLConverterContext& srcContext);
+ void alignUsageWith(const SPIRVToMSLConverterContext& srcContext);
} SPIRVToMSLConverterContext;
@@ -135,20 +142,22 @@
* and the number of threads in each workgroup or their specialization constant id, if the shader is a compute shader.
*/
typedef struct {
- std::string mtlFunctionName;
+ std::string mtlFunctionName = "main0";
struct {
uint32_t width = 1;
uint32_t height = 1;
uint32_t depth = 1;
} workgroupSize;
struct {
- uint32_t width, height, depth;
+ uint32_t width = 1;
+ uint32_t height = 1;
+ uint32_t depth = 1;
uint32_t constant = 0;
} workgroupSizeId;
} SPIRVEntryPoint;
/** Holds a map of entry point info, indexed by the SPIRV entry point name. */
- typedef std::unordered_map<std::string, SPIRVEntryPoint> SPIRVEntryPointsByName;
+// typedef std::unordered_map<std::string, SPIRVEntryPoint> SPIRVEntryPointsByName;
/** Special constant used in a MSLResourceBinding descriptorSet element to indicate the bindings for the push constants. */
static const uint32_t kPushConstDescSet = std::numeric_limits<uint32_t>::max();
@@ -195,8 +204,11 @@
*/
const std::string& getMSL() { return _msl; }
- /** Returns a mapping of entry point info, indexed by SPIR-V entry point name. */
- const SPIRVEntryPointsByName& getEntryPoints() { return _entryPoints; }
+ /** Returns information about the shader entry point. */
+ const SPIRVEntryPoint& getEntryPoint() { return _entryPoint; }
+
+ /** Returns a mapping of entry point info, indexed by SPIR-V entry point name. */
+// const SPIRVEntryPointsByName& getEntryPoints() { return _entryPoints; }
/**
* Returns whether the most recent conversion was successful.
@@ -212,10 +224,14 @@
const std::string& getResultLog() { return _resultLog; }
/** Sets MSL source code. This can be used when MSL is supplied directly. */
- void setMSL(const std::string& msl, const SPIRVEntryPointsByName& entryPoints) {
+ void setMSL(const std::string& msl, const SPIRVEntryPoint* pEntryPoint) {
_msl = msl;
- _entryPoints = entryPoints;
+ if (pEntryPoint) { _entryPoint = *pEntryPoint; }
}
+// void setMSL(const std::string& msl, const SPIRVEntryPointsByName& entryPoints) {
+// _msl = msl;
+// _entryPoints = entryPoints;
+// }
protected:
void logMsg(const char* logMsg);
@@ -228,7 +244,8 @@
std::vector<uint32_t> _spirv;
std::string _msl;
std::string _resultLog;
- SPIRVEntryPointsByName _entryPoints;
+ SPIRVEntryPoint _entryPoint;
+// SPIRVEntryPointsByName _entryPoints;
bool _wasConverted = false;
};
diff --git a/MoltenVKShaderConverter/MoltenVKShaderConverter.xcodeproj/project.pbxproj b/MoltenVKShaderConverter/MoltenVKShaderConverter.xcodeproj/project.pbxproj
index 69aba65..00d97ff 100644
--- a/MoltenVKShaderConverter/MoltenVKShaderConverter.xcodeproj/project.pbxproj
+++ b/MoltenVKShaderConverter/MoltenVKShaderConverter.xcodeproj/project.pbxproj
@@ -2847,7 +2847,6 @@
);
GCC_SYMBOLS_PRIVATE_EXTERN = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
- GCC_WARN_ABOUT_DEPRECATED_FUNCTIONS = NO;
GCC_WARN_ABOUT_MISSING_FIELD_INITIALIZERS = YES;
GCC_WARN_ABOUT_MISSING_PROTOTYPES = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES;
@@ -2895,7 +2894,6 @@
GCC_PREPROCESSOR_DEFINITIONS = "SPIRV_CROSS_FLT_FMT=\\\"%.6g\\\"";
GCC_SYMBOLS_PRIVATE_EXTERN = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
- GCC_WARN_ABOUT_DEPRECATED_FUNCTIONS = NO;
GCC_WARN_ABOUT_MISSING_FIELD_INITIALIZERS = YES;
GCC_WARN_ABOUT_MISSING_PROTOTYPES = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES;