Merge pull request #831 from cdavis5e/force-recompile-hooks

MSL: Hoist fixup hooks in entry_point_args() out of the compile loop.
diff --git a/reference/opt/shaders-msl/comp/force-recompile-hooks.swizzle.comp b/reference/opt/shaders-msl/comp/force-recompile-hooks.swizzle.comp
new file mode 100644
index 0000000..267cc51
--- /dev/null
+++ b/reference/opt/shaders-msl/comp/force-recompile-hooks.swizzle.comp
@@ -0,0 +1,138 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct spvAux
+{
+    uint swizzleConst[1];
+};
+
+enum class spvSwizzle : uint
+{
+    none = 0,
+    zero,
+    one,
+    red,
+    green,
+    blue,
+    alpha
+};
+
+template<typename T> struct spvRemoveReference { typedef T type; };
+template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };
+template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };
+template<typename T> inline constexpr thread T&& spvForward(thread typename spvRemoveReference<T>::type& x)
+{
+    return static_cast<thread T&&>(x);
+}
+template<typename T> inline constexpr thread T&& spvForward(thread typename spvRemoveReference<T>::type&& x)
+{
+    return static_cast<thread T&&>(x);
+}
+
+template<typename T>
+inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)
+{
+    switch (s)
+    {
+        case spvSwizzle::none:
+            return c;
+        case spvSwizzle::zero:
+            return 0;
+        case spvSwizzle::one:
+            return 1;
+        case spvSwizzle::red:
+            return x.r;
+        case spvSwizzle::green:
+            return x.g;
+        case spvSwizzle::blue:
+            return x.b;
+        case spvSwizzle::alpha:
+            return x.a;
+    }
+}
+
+// Wrapper function that swizzles texture samples and fetches.
+template<typename T>
+inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)
+{
+    if (!s)
+        return x;
+    return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) & 0xFF)), spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));
+}
+
+template<typename T>
+inline T spvTextureSwizzle(T x, uint s)
+{
+    return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;
+}
+
+// Wrapper function that swizzles texture gathers.
+template<typename T, typename Tex, typename... Ts>
+inline vec<T, 4> spvGatherSwizzle(sampler s, const thread Tex& t, Ts... params, component c, uint sw) METAL_CONST_ARG(c)
+{
+    if (sw)
+    {
+        switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))
+        {
+            case spvSwizzle::none:
+                break;
+            case spvSwizzle::zero:
+                return vec<T, 4>(0, 0, 0, 0);
+            case spvSwizzle::one:
+                return vec<T, 4>(1, 1, 1, 1);
+            case spvSwizzle::red:
+                return t.gather(s, spvForward<Ts>(params)..., component::x);
+            case spvSwizzle::green:
+                return t.gather(s, spvForward<Ts>(params)..., component::y);
+            case spvSwizzle::blue:
+                return t.gather(s, spvForward<Ts>(params)..., component::z);
+            case spvSwizzle::alpha:
+                return t.gather(s, spvForward<Ts>(params)..., component::w);
+        }
+    }
+    switch (c)
+    {
+        case component::x:
+            return t.gather(s, spvForward<Ts>(params)..., component::x);
+        case component::y:
+            return t.gather(s, spvForward<Ts>(params)..., component::y);
+        case component::z:
+            return t.gather(s, spvForward<Ts>(params)..., component::z);
+        case component::w:
+            return t.gather(s, spvForward<Ts>(params)..., component::w);
+    }
+}
+
+// Wrapper function that swizzles depth texture gathers.
+template<typename T, typename Tex, typename... Ts>
+inline vec<T, 4> spvGatherCompareSwizzle(sampler s, const thread Tex& t, Ts... params, uint sw) 
+{
+    if (sw)
+    {
+        switch (spvSwizzle(sw & 0xFF))
+        {
+            case spvSwizzle::none:
+            case spvSwizzle::red:
+                break;
+            case spvSwizzle::zero:
+            case spvSwizzle::green:
+            case spvSwizzle::blue:
+            case spvSwizzle::alpha:
+                return vec<T, 4>(0, 0, 0, 0);
+            case spvSwizzle::one:
+                return vec<T, 4>(1, 1, 1, 1);
+        }
+    }
+    return t.gather_compare(s, spvForward<Ts>(params)...);
+}
+
+kernel void main0(constant spvAux& spvAuxBuffer [[buffer(0)]], texture2d<float> foo [[texture(0)]], texture2d<float, access::write> bar [[texture(1)]], sampler fooSmplr [[sampler(0)]])
+{
+    constant uint32_t& fooSwzl = spvAuxBuffer.swizzleConst[0];
+    bar.write(spvTextureSwizzle(foo.sample(fooSmplr, float2(1.0), level(0.0)), fooSwzl), uint2(int2(0)));
+}
+
diff --git a/reference/shaders-msl/comp/force-recompile-hooks.swizzle.comp b/reference/shaders-msl/comp/force-recompile-hooks.swizzle.comp
new file mode 100644
index 0000000..667819d
--- /dev/null
+++ b/reference/shaders-msl/comp/force-recompile-hooks.swizzle.comp
@@ -0,0 +1,139 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct spvAux
+{
+    uint swizzleConst[1];
+};
+
+enum class spvSwizzle : uint
+{
+    none = 0,
+    zero,
+    one,
+    red,
+    green,
+    blue,
+    alpha
+};
+
+template<typename T> struct spvRemoveReference { typedef T type; };
+template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };
+template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };
+template<typename T> inline constexpr thread T&& spvForward(thread typename spvRemoveReference<T>::type& x)
+{
+    return static_cast<thread T&&>(x);
+}
+template<typename T> inline constexpr thread T&& spvForward(thread typename spvRemoveReference<T>::type&& x)
+{
+    return static_cast<thread T&&>(x);
+}
+
+template<typename T>
+inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)
+{
+    switch (s)
+    {
+        case spvSwizzle::none:
+            return c;
+        case spvSwizzle::zero:
+            return 0;
+        case spvSwizzle::one:
+            return 1;
+        case spvSwizzle::red:
+            return x.r;
+        case spvSwizzle::green:
+            return x.g;
+        case spvSwizzle::blue:
+            return x.b;
+        case spvSwizzle::alpha:
+            return x.a;
+    }
+}
+
+// Wrapper function that swizzles texture samples and fetches.
+template<typename T>
+inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)
+{
+    if (!s)
+        return x;
+    return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) & 0xFF)), spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));
+}
+
+template<typename T>
+inline T spvTextureSwizzle(T x, uint s)
+{
+    return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;
+}
+
+// Wrapper function that swizzles texture gathers.
+template<typename T, typename Tex, typename... Ts>
+inline vec<T, 4> spvGatherSwizzle(sampler s, const thread Tex& t, Ts... params, component c, uint sw) METAL_CONST_ARG(c)
+{
+    if (sw)
+    {
+        switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))
+        {
+            case spvSwizzle::none:
+                break;
+            case spvSwizzle::zero:
+                return vec<T, 4>(0, 0, 0, 0);
+            case spvSwizzle::one:
+                return vec<T, 4>(1, 1, 1, 1);
+            case spvSwizzle::red:
+                return t.gather(s, spvForward<Ts>(params)..., component::x);
+            case spvSwizzle::green:
+                return t.gather(s, spvForward<Ts>(params)..., component::y);
+            case spvSwizzle::blue:
+                return t.gather(s, spvForward<Ts>(params)..., component::z);
+            case spvSwizzle::alpha:
+                return t.gather(s, spvForward<Ts>(params)..., component::w);
+        }
+    }
+    switch (c)
+    {
+        case component::x:
+            return t.gather(s, spvForward<Ts>(params)..., component::x);
+        case component::y:
+            return t.gather(s, spvForward<Ts>(params)..., component::y);
+        case component::z:
+            return t.gather(s, spvForward<Ts>(params)..., component::z);
+        case component::w:
+            return t.gather(s, spvForward<Ts>(params)..., component::w);
+    }
+}
+
+// Wrapper function that swizzles depth texture gathers.
+template<typename T, typename Tex, typename... Ts>
+inline vec<T, 4> spvGatherCompareSwizzle(sampler s, const thread Tex& t, Ts... params, uint sw) 
+{
+    if (sw)
+    {
+        switch (spvSwizzle(sw & 0xFF))
+        {
+            case spvSwizzle::none:
+            case spvSwizzle::red:
+                break;
+            case spvSwizzle::zero:
+            case spvSwizzle::green:
+            case spvSwizzle::blue:
+            case spvSwizzle::alpha:
+                return vec<T, 4>(0, 0, 0, 0);
+            case spvSwizzle::one:
+                return vec<T, 4>(1, 1, 1, 1);
+        }
+    }
+    return t.gather_compare(s, spvForward<Ts>(params)...);
+}
+
+kernel void main0(constant spvAux& spvAuxBuffer [[buffer(0)]], texture2d<float> foo [[texture(0)]], texture2d<float, access::write> bar [[texture(1)]], sampler fooSmplr [[sampler(0)]])
+{
+    constant uint32_t& fooSwzl = spvAuxBuffer.swizzleConst[0];
+    float4 a = spvTextureSwizzle(foo.sample(fooSmplr, float2(1.0), level(0.0)), fooSwzl);
+    bar.write(a, uint2(int2(0)));
+}
+
diff --git a/shaders-msl/comp/force-recompile-hooks.swizzle.comp b/shaders-msl/comp/force-recompile-hooks.swizzle.comp
new file mode 100644
index 0000000..2752d30
--- /dev/null
+++ b/shaders-msl/comp/force-recompile-hooks.swizzle.comp
@@ -0,0 +1,9 @@
+#version 450
+
+layout(binding = 0) uniform sampler2D foo;
+layout(binding = 1, rgba8) uniform image2D bar;
+
+void main() { // Copies a single sample of foo into bar.
+	vec4 a = texture(foo, vec2(1, 1)); // sample foo at coordinate (1, 1) into a named temporary
+	imageStore(bar, ivec2(0, 0), a); // store that value at texel (0, 0) of bar
+}
diff --git a/spirv_msl.cpp b/spirv_msl.cpp
index 87b00ea..976722a 100644
--- a/spirv_msl.cpp
+++ b/spirv_msl.cpp
@@ -447,6 +447,10 @@
 	// Mark any non-stage-in structs to be tightly packed.
 	mark_packable_structs();
 
+	// Add fixup hooks required by shader inputs and outputs. This needs to happen before
+	// the loop, so the hooks aren't added multiple times.
+	fix_up_shader_inputs_outputs();
+
 	uint32_t pass_count = 0;
 	do
 	{
@@ -4463,17 +4467,6 @@
 				resources.push_back(
 				    { &id, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype) });
 			}
-
-			if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
-			{
-				auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
-				entry_func.fixup_hooks_in.push_back([this, &var, var_id]() {
-					auto &aux_type = expression_type(aux_buffer_id);
-					statement("constant uint32_t& ", to_swizzle_expression(var_id), " = ", to_name(aux_buffer_id), ".",
-					          to_member_name(aux_type, k_aux_mbr_idx_swizzle_const), "[",
-					          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
-				});
-			}
 		}
 	});
 
@@ -4554,27 +4547,7 @@
 		// point, we get that by calling get_sample_position() on the sample ID.
 		if (var.storage == StorageClassInput && is_builtin_variable(var))
 		{
-			if (bi_type == BuiltInSamplePosition)
-			{
-				auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
-				entry_func.fixup_hooks_in.push_back([=]() {
-					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
-					          to_expression(builtin_sample_id_id), ");");
-				});
-			}
-			else if (bi_type == BuiltInHelperInvocation)
-			{
-				if (msl_options.is_ios())
-					SPIRV_CROSS_THROW("simd_is_helper_thread() is only supported on macOS.");
-				else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
-					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
-
-				auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
-				entry_func.fixup_hooks_in.push_back([=]() {
-					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_is_helper_thread();");
-				});
-			}
-			else
+			if (bi_type != BuiltInSamplePosition && bi_type != BuiltInHelperInvocation)
 			{
 				if (!ep_args.empty())
 					ep_args += ", ";
@@ -4598,6 +4571,64 @@
 	return ep_args;
 }
 
+void CompilerMSL::fix_up_shader_inputs_outputs() // Queues fixup hooks on the default entry point.
+{
+	// Sampled images: if texture swizzling is enabled, queue one hook per image that binds its swizzle constant.
+	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
+		auto &type = get_variable_data_type(var);
+		// The queued hook emits a line such as: constant uint32_t& fooSwzl = spvAuxBuffer.swizzleConst[0];
+		uint32_t var_id = var.self;
+		// NOTE(review): the hook captures var by reference and is stored for later; assumes var outlives it - confirm.
+		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
+		     var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
+		    !is_hidden_variable(var))
+		{
+			if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
+			{
+				auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
+				entry_func.fixup_hooks_in.push_back([this, &var, var_id]() {
+					auto &aux_type = expression_type(aux_buffer_id);
+					statement("constant uint32_t& ", to_swizzle_expression(var_id), " = ", to_name(aux_buffer_id), ".",
+					          to_member_name(aux_type, k_aux_mbr_idx_swizzle_const), "[",
+					          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
+				});
+			}
+		}
+	});
+
+	// Builtin inputs: SamplePosition and HelperInvocation get a hook that declares them from a function call.
+	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
+		uint32_t var_id = var.self;
+		BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;
+
+		if (var.storage == StorageClassInput && is_builtin_variable(var))
+		{
+			auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
+			switch (bi_type)
+			{
+			case BuiltInSamplePosition: // declared as get_sample_position(<builtin_sample_id_id>)
+				entry_func.fixup_hooks_in.push_back([=]() {
+					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
+					          to_expression(builtin_sample_id_id), ");");
+				});
+				break;
+			case BuiltInHelperInvocation: // checked now, not in the hook: throws on iOS and on macOS below 2.1
+				if (msl_options.is_ios())
+					SPIRV_CROSS_THROW("simd_is_helper_thread() is only supported on macOS.");
+				else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
+					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
+
+				entry_func.fixup_hooks_in.push_back([=]() {
+					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_is_helper_thread();");
+				});
+				break;
+			default: // no hook is queued for any other builtin
+				break;
+			}
+		}
+	});
+}
+
 // Returns the Metal index of the resource of the specified type as used by the specified variable.
 uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype)
 {
diff --git a/spirv_msl.hpp b/spirv_msl.hpp
index 8fcb224..3ea5b4d 100644
--- a/spirv_msl.hpp
+++ b/spirv_msl.hpp
@@ -392,6 +392,7 @@
 	void emit_interface_block(uint32_t ib_var_id);
 	bool maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs);
 	void add_convert_row_major_matrix_function(uint32_t cols, uint32_t rows);
+	void fix_up_shader_inputs_outputs();
 
 	std::string func_type_decl(SPIRType &type);
 	std::string entry_point_args(bool append_comma);