Merge pull request #1509 from billhollings/shaderconverter-updates

MoltenVKShaderConverter updates
diff --git a/ExternalRevisions/SPIRV-Cross_repo_revision b/ExternalRevisions/SPIRV-Cross_repo_revision
index 11d4f3b..bbb99b3 100644
--- a/ExternalRevisions/SPIRV-Cross_repo_revision
+++ b/ExternalRevisions/SPIRV-Cross_repo_revision
@@ -1 +1 @@
-e9cc6403341baf0edd430a4027b074d0a06b782f
+53d94a982e1d654515b44db5391de37f85489204
diff --git a/MoltenVKShaderConverter/MoltenVKShaderConverter/SPIRVReflection.h b/MoltenVKShaderConverter/MoltenVKShaderConverter/SPIRVReflection.h
index ad4ca40..63a0491 100644
--- a/MoltenVKShaderConverter/MoltenVKShaderConverter/SPIRVReflection.h
+++ b/MoltenVKShaderConverter/MoltenVKShaderConverter/SPIRVReflection.h
@@ -173,6 +173,42 @@
 #endif
 	}
 
+	static inline uint32_t addSat(uint32_t a, uint32_t b) { return a == uint32_t(-1) ? a : a + b; }	// Adds b to a, but returns a unchanged once a is already uint32_t(-1).
+
+	/**
+	 * For each member of structType, appends one entry per location to outputs, recursing into
+	 * struct-typed members. Returns the next location after the last one consumed.
+	 */
+	template<typename Vo>
+	static inline uint32_t getShaderOutputStructMembers(const SPIRV_CROSS_NAMESPACE::CompilerReflection& reflect, Vo& outputs,
+														const SPIRV_CROSS_NAMESPACE::SPIRType* structType, spv::StorageClass storage,
+														bool patch, uint32_t loc) {
+		const auto& memberTypeIDs = structType->member_types;
+		auto builtin = spv::BuiltInMax;		// Once set by a member, carries over to the members that follow it.
+		bool active = true;					// Likewise carried over; only refreshed by members with a BuiltIn decoration.
+		for (uint32_t m = 0; m < memberTypeIDs.size(); m++) {
+			// A member-level Location restarts the running location and supplies the component;
+			// otherwise the location continues from wherever the previous member left off.
+			const bool hasLoc = reflect.has_member_decoration(structType->self, m, spv::DecorationLocation);
+			if (hasLoc) { loc = reflect.get_member_decoration(structType->self, m, spv::DecorationLocation); }
+			const uint32_t component = hasLoc ? reflect.get_member_decoration(structType->self, m, spv::DecorationComponent) : 0;
+			if ( !patch ) { patch = reflect.has_member_decoration(structType->self, m, spv::DecorationPatch); }
+			if (reflect.has_member_decoration(structType->self, m, spv::DecorationBuiltIn)) {
+				builtin = (spv::BuiltIn)reflect.get_member_decoration(structType->self, m, spv::DecorationBuiltIn);
+				active = reflect.has_active_builtin(builtin, storage);
+			}
+			const SPIRV_CROSS_NAMESPACE::SPIRType& memberType = reflect.get_type(memberTypeIDs[m]);
+			const bool isStruct = (memberType.basetype == SPIRV_CROSS_NAMESPACE::SPIRType::Struct);
+			// One pass per element: first array dimension (or 1 if not an array) times the column count.
+			uint32_t remaining = (memberType.array.empty() ? 1 : memberType.array[0]) * memberType.columns;
+			for (; remaining > 0; remaining--) {
+				if (isStruct) { loc = getShaderOutputStructMembers(reflect, outputs, &memberType, storage, patch, loc); continue; }
+				outputs.push_back({memberType.basetype, memberType.vecsize, loc, component, builtin, patch, active});
+				loc = addSat(loc, 1);
+			}
+		}
+		return loc;
+	}
+
 	/** Given a shader in SPIR-V format, returns output reflection data. */
 	template<typename Vs, typename Vo>
 	static inline bool getShaderOutputs(const Vs& spirv, spv::ExecutionModel model, const std::string& entryName,
@@ -191,7 +227,6 @@
 
 			outputs.clear();
 
-			auto addSat = [](uint32_t a, uint32_t b) { return a == uint32_t(-1) ? a : a + b; };
 			for (auto varID : reflect.get_active_interface_variables()) {
 				spv::StorageClass storage = reflect.get_storage_class(varID);
 				if (storage != spv::StorageClassOutput) { continue; }
@@ -215,47 +250,14 @@
 				if (model == spv::ExecutionModelTessellationControl && !patch)
 					type = &reflect.get_type(type->parent_type);
 
-				if (type->basetype == SPIRV_CROSS_NAMESPACE::SPIRType::Struct) {
-					uint32_t memberLoc = loc;
-					for (uint32_t idx = 0; idx < type->member_types.size(); idx++) {
-						// Each member may have a location decoration. If not, each member
-						// gets an incrementing location based the base location for the struct.
-						uint32_t memberCmp = 0;
-						if (reflect.has_member_decoration(type->self, idx, spv::DecorationLocation)) {
-							memberLoc = reflect.get_member_decoration(type->self, idx, spv::DecorationLocation);
-							memberCmp = reflect.get_member_decoration(type->self, idx, spv::DecorationComponent);
-						}
-						patch = patch || reflect.has_member_decoration(type->self, idx, spv::DecorationPatch);
-						if (reflect.has_member_decoration(type->self, idx, spv::DecorationBuiltIn)) {
-							biType = (spv::BuiltIn)reflect.get_member_decoration(type->self, idx, spv::DecorationBuiltIn);
-							isUsed = reflect.has_active_builtin(biType, storage);
-						}
-						const SPIRV_CROSS_NAMESPACE::SPIRType& memberType = reflect.get_type(type->member_types[idx]);
-						if (memberType.columns > 1) {
-							for (uint32_t i = 0; i < memberType.columns; i++) {
-								outputs.push_back({memberType.basetype, memberType.vecsize, memberLoc, memberCmp, biType, patch, isUsed});
-								memberLoc = addSat(memberLoc, 1);
-							}
-						} else if (!memberType.array.empty()) {
-							for (uint32_t i = 0; i < memberType.array[0]; i++) {
-								outputs.push_back({memberType.basetype, memberType.vecsize, memberLoc, memberCmp, biType, patch, isUsed});
-								memberLoc = addSat(memberLoc, 1);
-							}
-						} else {
-							outputs.push_back({memberType.basetype, memberType.vecsize, memberLoc, memberCmp, biType, patch, isUsed});
-							memberLoc = addSat(memberLoc, 1);
-						}
+				uint32_t elemCnt = (type->array.empty() ? 1 : type->array[0]) * type->columns;
+				for (uint32_t i = 0; i < elemCnt; i++) {
+					if (type->basetype == SPIRV_CROSS_NAMESPACE::SPIRType::Struct)
+						loc = getShaderOutputStructMembers(reflect, outputs, type, storage, patch, loc);
+					else {
+						outputs.push_back({type->basetype, type->vecsize, loc, cmp, biType, patch, isUsed});
+						loc = addSat(loc, 1);
 					}
-				} else if (type->columns > 1) {
-					for (uint32_t i = 0; i < type->columns; i++) {
-						outputs.push_back({type->basetype, type->vecsize, addSat(loc, i), cmp, biType, patch, isUsed});
-					}
-				} else if (!type->array.empty()) {
-					for (uint32_t i = 0; i < type->array[0]; i++) {
-						outputs.push_back({type->basetype, type->vecsize, addSat(loc, i), cmp, biType, patch, isUsed});
-					}
-				} else {
-					outputs.push_back({type->basetype, type->vecsize, loc, cmp, biType, patch, isUsed});
 				}
 			}
 			// Sort outputs by ascending location.
diff --git a/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.cpp b/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.cpp
index 20327d4..2643844 100644
--- a/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.cpp
+++ b/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.cpp
@@ -31,6 +31,12 @@
 // The default list of vertex file extensions.
 static const char* _defaultVertexShaderExtns = "vs vsh vert vertex";
 
+// The default list of tessellation control file extensions.
+static const char* _defaultTescShaderExtns = "tcs tcsh tesc";
+
+// The default list of tessellation evaluation file extensions.
+static const char* _defaultTeseShaderExtns = "tes tesh tese";
+
 // The default list of fragment file extensions.
 static const char* _defaultFragShaderExtns = "fs fsh frag fragment";
 
@@ -261,6 +267,8 @@
 
 MVKGLSLConversionShaderStage MoltenVKShaderConverterTool::shaderStageFromFileExtension(string& pathExtension) {
     for (auto& fx : _glslVtxFileExtns) { if (fx == pathExtension) { return kMVKGLSLConversionShaderStageVertex; } }
+	for (auto& fx : _glslTescFileExtns) { if (fx == pathExtension) { return kMVKGLSLConversionShaderStageTessControl; } }
+	for (auto& fx : _glslTeseFileExtns) { if (fx == pathExtension) { return kMVKGLSLConversionShaderStageTessEval; } }
     for (auto& fx : _glslFragFileExtns) { if (fx == pathExtension) { return kMVKGLSLConversionShaderStageFragment; } }
     for (auto& fx : _glslCompFileExtns) { if (fx == pathExtension) { return kMVKGLSLConversionShaderStageCompute; } }
 	return kMVKGLSLConversionShaderStageAuto;
@@ -268,6 +276,8 @@
 
 bool MoltenVKShaderConverterTool::isGLSLFileExtension(string& pathExtension) {
     for (auto& fx : _glslVtxFileExtns) { if (fx == pathExtension) { return true; } }
+	for (auto& fx : _glslTescFileExtns) { if (fx == pathExtension) { return true; } }
+	for (auto& fx : _glslTeseFileExtns) { if (fx == pathExtension) { return true; } }
     for (auto& fx : _glslFragFileExtns) { if (fx == pathExtension) { return true; } }
     for (auto& fx : _glslCompFileExtns) { if (fx == pathExtension) { return true; } }
 	return false;
@@ -344,6 +354,10 @@
 	log("                       (myshdr.vsh -> myshdr.metal).");
 	log("  -vx \"fileExtns\"    - List of GLSL vertex shader file extensions.");
 	log("                       May be omitted for defaults (\"vs vsh vert vertex\").");
+	log("  -tcx \"fileExtns\"   - List of GLSL tessellation control shader file extensions.");
+	log("                       May be omitted for defaults (\"tcs tcsh tesc\").");
+	log("  -tex \"fileExtns\"   - List of GLSL tessellation evaluation shader file extensions.");
+	log("                       May be omitted for defaults (\"tes tesh tese\").");
 	log("  -fx \"fileExtns\"    - List of GLSL fragment shader file extensions.");
 	log("                       May be omitted for defaults (\"fs fsh frag fragment\").");
     log("  -cx \"fileExtns\"    - List of GLSL compute shader file extensions.");
@@ -386,6 +400,8 @@
 
 MoltenVKShaderConverterTool::MoltenVKShaderConverterTool(int argc, const char* argv[]) {
 	extractTokens(_defaultVertexShaderExtns, _glslVtxFileExtns);
+	extractTokens(_defaultTescShaderExtns, _glslTescFileExtns);
+	extractTokens(_defaultTeseShaderExtns, _glslTeseFileExtns);
 	extractTokens(_defaultFragShaderExtns, _glslFragFileExtns);
     extractTokens(_defaultCompShaderExtns, _glslCompFileExtns);
 	extractTokens(_defaultSPIRVShaderExtns, _spvFileExtns);
@@ -405,7 +421,7 @@
 	_quietMode = false;
 
 	_mslVersionMajor = 2;
-	_mslVersionMinor = 2;
+	_mslVersionMinor = 4;
 	_mslVersionPatch = 0;
 	_mslPlatform = SPIRVToMSLConversionOptions().mslOptions.platform;
 
@@ -553,6 +569,24 @@
 			continue;
 		}
 
+		if (equal(arg, "-tcx", true)) {
+			int optIdx = argIdx;
+			string shdrExtnStr;
+			argIdx = optionalParam(shdrExtnStr, argIdx, argc, argv);
+			if (argIdx == optIdx || shdrExtnStr.length() == 0) { return false; }
+			extractTokens(shdrExtnStr, _glslTescFileExtns);
+			continue;
+		}
+
+		if (equal(arg, "-tex", true)) {
+			int optIdx = argIdx;
+			string shdrExtnStr;
+			argIdx = optionalParam(shdrExtnStr, argIdx, argc, argv);
+			if (argIdx == optIdx || shdrExtnStr.length() == 0) { return false; }
+			extractTokens(shdrExtnStr, _glslTeseFileExtns);
+			continue;
+		}
+
 		if (equal(arg, "-fx", true)) {
 			int optIdx = argIdx;
 			string shdrExtnStr;
diff --git a/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.h b/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.h
index 81132d7..58accd1 100644
--- a/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.h
+++ b/MoltenVKShaderConverter/MoltenVKShaderConverterTool/MoltenVKShaderConverterTool.h
@@ -98,6 +98,8 @@
 		std::string _hdrOutVarName;
 		std::string _origPathExtnSep;
 		std::vector<std::string> _glslVtxFileExtns;
+		std::vector<std::string> _glslTescFileExtns;
+		std::vector<std::string> _glslTeseFileExtns;
 		std::vector<std::string> _glslFragFileExtns;
         std::vector<std::string> _glslCompFileExtns;
 		std::vector<std::string> _spvFileExtns;
diff --git a/MoltenVKShaderConverter/MoltenVKShaderConverterTool/OSSupport.mm b/MoltenVKShaderConverter/MoltenVKShaderConverterTool/OSSupport.mm
index f0f8cb9..61691d1 100644
--- a/MoltenVKShaderConverter/MoltenVKShaderConverterTool/OSSupport.mm
+++ b/MoltenVKShaderConverter/MoltenVKShaderConverterTool/OSSupport.mm
@@ -100,12 +100,18 @@
 	}
 
 	@autoreleasepool {
+		NSArray* mtlDevs = [MTLCopyAllDevices() autorelease];
+		if (mtlDevs.count == 0) {
+			errMsg = "Could not retrieve MTLDevice to compile shader.";
+			return false;
+		}
+
 		MTLCompileOptions* mtlCompileOptions  = [[MTLCompileOptions new] autorelease];
 		mtlCompileOptions.languageVersion = mslVerEnum;
 		NSError* err = nil;
-		id<MTLLibrary> mtlLib = [[MTLCreateSystemDefaultDevice() newLibraryWithSource: @(mslSourceCode.c_str())
-																			  options: mtlCompileOptions
-																				error: &err] autorelease];
+		id<MTLLibrary> mtlLib = [[mtlDevs[0] newLibraryWithSource: @(mslSourceCode.c_str())
+														  options: mtlCompileOptions
+															error: &err] autorelease];
 		errMsg = err ? [NSString stringWithFormat: @"(Error code %li):\n%@", (long)err.code, err.localizedDescription].UTF8String : "";
 		return !!mtlLib;
 	}