Merge pull request #1745 from billhollings/location-component-vecsize

MSL: Track location component to match vecsize between shader stages.
diff --git a/reference/opt/shaders-msl/asm/comp/quantize.asm.comp b/reference/opt/shaders-msl/asm/comp/quantize.asm.comp
index 1839ec7..b7e6f91 100644
--- a/reference/opt/shaders-msl/asm/comp/quantize.asm.comp
+++ b/reference/opt/shaders-msl/asm/comp/quantize.asm.comp
@@ -1,3 +1,5 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
 #include <metal_stdlib>
 #include <simd/simd.h>
 
@@ -11,11 +13,20 @@
     float4 vec4_val;
 };
 
+template <typename F> struct SpvHalfTypeSelector;
+template <> struct SpvHalfTypeSelector<float> { public: using H = half; };
+template<uint N> struct SpvHalfTypeSelector<vec<float, N>> { using H = vec<half, N>; };
+template<typename F, typename H = typename SpvHalfTypeSelector<F>::H>
+[[clang::optnone]] F spvQuantizeToF16(F val)
+{
+    return F(H(val));
+}
+
 kernel void main0(device SSBO0& _4 [[buffer(0)]])
 {
-    _4.scalar = float(half(_4.scalar));
-    _4.vec2_val = float2(half2(_4.vec2_val));
-    _4.vec3_val = float3(half3(_4.vec3_val));
-    _4.vec4_val = float4(half4(_4.vec4_val));
+    _4.scalar = spvQuantizeToF16(_4.scalar);
+    _4.vec2_val = spvQuantizeToF16(_4.vec2_val);
+    _4.vec3_val = spvQuantizeToF16(_4.vec3_val);
+    _4.vec4_val = spvQuantizeToF16(_4.vec4_val);
 }
 
diff --git a/reference/opt/shaders-msl/vert/float-math.invariant-float-math.vert b/reference/opt/shaders-msl/vert/float-math.invariant-float-math.vert
index 05e09e2..d8f44be 100644
--- a/reference/opt/shaders-msl/vert/float-math.invariant-float-math.vert
+++ b/reference/opt/shaders-msl/vert/float-math.invariant-float-math.vert
@@ -69,13 +69,13 @@
 };
 
 template<typename T>
-T spvFMul(T l, T r)
+[[clang::optnone]] T spvFMul(T l, T r)
 {
     return fma(l, r, T(0));
 }
 
 template<typename T, int Cols, int Rows>
-vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)
+[[clang::optnone]] vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)
 {
     vec<T, Cols> res = vec<T, Cols>(0);
     for (uint i = Rows; i > 0; --i)
@@ -91,7 +91,7 @@
 }
 
 template<typename T, int Cols, int Rows>
-vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)
+[[clang::optnone]] vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)
 {
     vec<T, Rows> res = vec<T, Rows>(0);
     for (uint i = Cols; i > 0; --i)
@@ -102,7 +102,7 @@
 }
 
 template<typename T, int LCols, int LRows, int RCols, int RRows>
-matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)
+[[clang::optnone]] matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)
 {
     matrix<T, RCols, LRows> res;
     for (uint i = 0; i < RCols; i++)
diff --git a/reference/opt/shaders-msl/vert/no-contraction.vert b/reference/opt/shaders-msl/vert/no-contraction.vert
index a48731e..0b75dbc 100644
--- a/reference/opt/shaders-msl/vert/no-contraction.vert
+++ b/reference/opt/shaders-msl/vert/no-contraction.vert
@@ -18,13 +18,13 @@
 };
 
 template<typename T>
-T spvFMul(T l, T r)
+[[clang::optnone]] T spvFMul(T l, T r)
 {
     return fma(l, r, T(0));
 }
 
 template<typename T, int Cols, int Rows>
-vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)
+[[clang::optnone]] vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)
 {
     vec<T, Cols> res = vec<T, Cols>(0);
     for (uint i = Rows; i > 0; --i)
@@ -40,7 +40,7 @@
 }
 
 template<typename T, int Cols, int Rows>
-vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)
+[[clang::optnone]] vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)
 {
     vec<T, Rows> res = vec<T, Rows>(0);
     for (uint i = Cols; i > 0; --i)
@@ -51,7 +51,7 @@
 }
 
 template<typename T, int LCols, int LRows, int RCols, int RRows>
-matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)
+[[clang::optnone]] matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)
 {
     matrix<T, RCols, LRows> res;
     for (uint i = 0; i < RCols; i++)
@@ -67,13 +67,13 @@
 }
 
 template<typename T>
-T spvFAdd(T l, T r)
+[[clang::optnone]] T spvFAdd(T l, T r)
 {
     return fma(T(1), l, r);
 }
 
 template<typename T>
-T spvFSub(T l, T r)
+[[clang::optnone]] T spvFSub(T l, T r)
 {
     return fma(T(-1), r, l);
 }
diff --git a/reference/shaders-msl/asm/comp/quantize.asm.comp b/reference/shaders-msl/asm/comp/quantize.asm.comp
index 1839ec7..b7e6f91 100644
--- a/reference/shaders-msl/asm/comp/quantize.asm.comp
+++ b/reference/shaders-msl/asm/comp/quantize.asm.comp
@@ -1,3 +1,5 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
 #include <metal_stdlib>
 #include <simd/simd.h>
 
@@ -11,11 +13,20 @@
     float4 vec4_val;
 };
 
+template <typename F> struct SpvHalfTypeSelector;
+template <> struct SpvHalfTypeSelector<float> { public: using H = half; };
+template<uint N> struct SpvHalfTypeSelector<vec<float, N>> { using H = vec<half, N>; };
+template<typename F, typename H = typename SpvHalfTypeSelector<F>::H>
+[[clang::optnone]] F spvQuantizeToF16(F val)
+{
+    return F(H(val));
+}
+
 kernel void main0(device SSBO0& _4 [[buffer(0)]])
 {
-    _4.scalar = float(half(_4.scalar));
-    _4.vec2_val = float2(half2(_4.vec2_val));
-    _4.vec3_val = float3(half3(_4.vec3_val));
-    _4.vec4_val = float4(half4(_4.vec4_val));
+    _4.scalar = spvQuantizeToF16(_4.scalar);
+    _4.vec2_val = spvQuantizeToF16(_4.vec2_val);
+    _4.vec3_val = spvQuantizeToF16(_4.vec3_val);
+    _4.vec4_val = spvQuantizeToF16(_4.vec4_val);
 }
 
diff --git a/reference/shaders-msl/vert/float-math.invariant-float-math.vert b/reference/shaders-msl/vert/float-math.invariant-float-math.vert
index d603884..06844ca 100644
--- a/reference/shaders-msl/vert/float-math.invariant-float-math.vert
+++ b/reference/shaders-msl/vert/float-math.invariant-float-math.vert
@@ -69,13 +69,13 @@
 };
 
 template<typename T>
-T spvFMul(T l, T r)
+[[clang::optnone]] T spvFMul(T l, T r)
 {
     return fma(l, r, T(0));
 }
 
 template<typename T, int Cols, int Rows>
-vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)
+[[clang::optnone]] vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)
 {
     vec<T, Cols> res = vec<T, Cols>(0);
     for (uint i = Rows; i > 0; --i)
@@ -91,7 +91,7 @@
 }
 
 template<typename T, int Cols, int Rows>
-vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)
+[[clang::optnone]] vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)
 {
     vec<T, Rows> res = vec<T, Rows>(0);
     for (uint i = Cols; i > 0; --i)
@@ -102,7 +102,7 @@
 }
 
 template<typename T, int LCols, int LRows, int RCols, int RRows>
-matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)
+[[clang::optnone]] matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)
 {
     matrix<T, RCols, LRows> res;
     for (uint i = 0; i < RCols; i++)
diff --git a/reference/shaders-msl/vert/no-contraction.vert b/reference/shaders-msl/vert/no-contraction.vert
index 907d901..653dc26 100644
--- a/reference/shaders-msl/vert/no-contraction.vert
+++ b/reference/shaders-msl/vert/no-contraction.vert
@@ -18,13 +18,13 @@
 };
 
 template<typename T>
-T spvFMul(T l, T r)
+[[clang::optnone]] T spvFMul(T l, T r)
 {
     return fma(l, r, T(0));
 }
 
 template<typename T, int Cols, int Rows>
-vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)
+[[clang::optnone]] vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)
 {
     vec<T, Cols> res = vec<T, Cols>(0);
     for (uint i = Rows; i > 0; --i)
@@ -40,7 +40,7 @@
 }
 
 template<typename T, int Cols, int Rows>
-vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)
+[[clang::optnone]] vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)
 {
     vec<T, Rows> res = vec<T, Rows>(0);
     for (uint i = Cols; i > 0; --i)
@@ -51,7 +51,7 @@
 }
 
 template<typename T, int LCols, int LRows, int RCols, int RRows>
-matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)
+[[clang::optnone]] matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)
 {
     matrix<T, RCols, LRows> res;
     for (uint i = 0; i < RCols; i++)
@@ -67,13 +67,13 @@
 }
 
 template<typename T>
-T spvFAdd(T l, T r)
+[[clang::optnone]] T spvFAdd(T l, T r)
 {
     return fma(T(1), l, r);
 }
 
 template<typename T>
-T spvFSub(T l, T r)
+[[clang::optnone]] T spvFSub(T l, T r)
 {
     return fma(T(-1), r, l);
 }
diff --git a/spirv_msl.cpp b/spirv_msl.cpp
index e37b13e..f4ab0cc 100644
--- a/spirv_msl.cpp
+++ b/spirv_msl.cpp
@@ -4935,7 +4935,7 @@
 		// "fadd" intrinsic support
 		case SPVFuncImplFAdd:
 			statement("template<typename T>");
-			statement("T spvFAdd(T l, T r)");
+			statement("[[clang::optnone]] T spvFAdd(T l, T r)");
 			begin_scope();
 			statement("return fma(T(1), l, r);");
 			end_scope();
@@ -4945,7 +4945,7 @@
 		// "fsub" intrinsic support
 		case SPVFuncImplFSub:
 			statement("template<typename T>");
-			statement("T spvFSub(T l, T r)");
+			statement("[[clang::optnone]] T spvFSub(T l, T r)");
 			begin_scope();
 			statement("return fma(T(-1), r, l);");
 			end_scope();
@@ -4955,14 +4955,14 @@
 		// "fmul' intrinsic support
 		case SPVFuncImplFMul:
 			statement("template<typename T>");
-			statement("T spvFMul(T l, T r)");
+			statement("[[clang::optnone]] T spvFMul(T l, T r)");
 			begin_scope();
 			statement("return fma(l, r, T(0));");
 			end_scope();
 			statement("");
 
 			statement("template<typename T, int Cols, int Rows>");
-			statement("vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)");
+			statement("[[clang::optnone]] vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)");
 			begin_scope();
 			statement("vec<T, Cols> res = vec<T, Cols>(0);");
 			statement("for (uint i = Rows; i > 0; --i)");
@@ -4979,7 +4979,7 @@
 			statement("");
 
 			statement("template<typename T, int Cols, int Rows>");
-			statement("vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)");
+			statement("[[clang::optnone]] vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)");
 			begin_scope();
 			statement("vec<T, Rows> res = vec<T, Rows>(0);");
 			statement("for (uint i = Cols; i > 0; --i)");
@@ -4991,8 +4991,7 @@
 			statement("");
 
 			statement("template<typename T, int LCols, int LRows, int RCols, int RRows>");
-			statement(
-			    "matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)");
+			statement("[[clang::optnone]] matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)");
 			begin_scope();
 			statement("matrix<T, RCols, LRows> res;");
 			statement("for (uint i = 0; i < RCols; i++)");
@@ -5009,6 +5008,20 @@
 			statement("");
 			break;
 
+		case SPVFuncImplQuantizeToF16:
+			// Ensure fast-math is disabled to match Vulkan results.
+			// SpvHalfTypeSelector is used to match the half* template type to the float* template type.
+			statement("template <typename F> struct SpvHalfTypeSelector;");
+			statement("template <> struct SpvHalfTypeSelector<float> { public: using H = half; };");
+			statement("template<uint N> struct SpvHalfTypeSelector<vec<float, N>> { using H = vec<half, N>; };");
+			statement("template<typename F, typename H = typename SpvHalfTypeSelector<F>::H>");
+			statement("[[clang::optnone]] F spvQuantizeToF16(F val)");
+			begin_scope();
+			statement("return F(H(val));");
+			end_scope();
+			statement("");
+			break;
+
 		// Emulate texturecube_array with texture2d_array for iOS where this type is not available
 		case SPVFuncImplCubemapTo2DArrayFace:
 			statement(force_inline);
@@ -8072,28 +8085,7 @@
 		uint32_t result_type = ops[0];
 		uint32_t id = ops[1];
 		uint32_t arg = ops[2];
-
-		string exp;
-		auto &type = get<SPIRType>(result_type);
-
-		switch (type.vecsize)
-		{
-		case 1:
-			exp = join("float(half(", to_expression(arg), "))");
-			break;
-		case 2:
-			exp = join("float2(half2(", to_expression(arg), "))");
-			break;
-		case 3:
-			exp = join("float3(half3(", to_expression(arg), "))");
-			break;
-		case 4:
-			exp = join("float4(half4(", to_expression(arg), "))");
-			break;
-		default:
-			SPIRV_CROSS_THROW("Illegal argument to OpQuantizeToF16.");
-		}
-
+		string exp = join("spvQuantizeToF16(", to_expression(arg), ")");
 		emit_op(result_type, id, exp, should_forward(arg));
 		break;
 	}
@@ -15062,6 +15054,9 @@
 		}
 		break;
 
+	case OpQuantizeToF16:
+		return SPVFuncImplQuantizeToF16;
+
 	case OpTypeArray:
 	{
 		// Allow Metal to use the array<T> template to make arrays a value type
diff --git a/spirv_msl.hpp b/spirv_msl.hpp
index a2b1b55..50d0686 100644
--- a/spirv_msl.hpp
+++ b/spirv_msl.hpp
@@ -657,6 +657,7 @@
 		SPVFuncImplFMul,
 		SPVFuncImplFAdd,
 		SPVFuncImplFSub,
+		SPVFuncImplQuantizeToF16,
 		SPVFuncImplCubemapTo2DArrayFace,
 		SPVFuncImplUnsafeArray, // Allow Metal to use the array<T> template to make arrays a value type
 		SPVFuncImplInverse4x4,