Merge pull request #1787 from KhronosGroup/fix-1786

MSL: Work around compiler crashes when using threadgroup bool.
diff --git a/reference/opt/shaders-msl/comp/threadgroup-boolean-workaround.comp b/reference/opt/shaders-msl/comp/threadgroup-boolean-workaround.comp
new file mode 100644
index 0000000..8b80929
--- /dev/null
+++ b/reference/opt/shaders-msl/comp/threadgroup-boolean-workaround.comp
@@ -0,0 +1,20 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct SSBO
+{
+    float4 values[1];
+};
+
+constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u);
+
+kernel void main0(device SSBO& _23 [[buffer(0)]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
+{
+    threadgroup short4 foo[4];
+    foo[gl_LocalInvocationIndex] = short4((isunordered(_23.values[gl_GlobalInvocationID.x], float4(10.0)) || _23.values[gl_GlobalInvocationID.x] != float4(10.0)));
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    _23.values[gl_GlobalInvocationID.x] = select(float4(40.0), float4(30.0), bool4(foo[gl_LocalInvocationIndex ^ 3u]));
+}
+
diff --git a/reference/shaders-msl/comp/threadgroup-boolean-workaround.comp b/reference/shaders-msl/comp/threadgroup-boolean-workaround.comp
new file mode 100644
index 0000000..d01b135
--- /dev/null
+++ b/reference/shaders-msl/comp/threadgroup-boolean-workaround.comp
@@ -0,0 +1,28 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct SSBO
+{
+    float4 values[1];
+};
+
+constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u);
+
+static inline __attribute__((always_inline))
+void in_function(threadgroup short4 (&foo)[4], thread uint& gl_LocalInvocationIndex, device SSBO& v_23, thread uint3& gl_GlobalInvocationID)
+{
+    foo[gl_LocalInvocationIndex] = short4((isunordered(v_23.values[gl_GlobalInvocationID.x], float4(10.0)) || v_23.values[gl_GlobalInvocationID.x] != float4(10.0)));
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    v_23.values[gl_GlobalInvocationID.x] = select(float4(40.0), float4(30.0), bool4(foo[gl_LocalInvocationIndex ^ 3u]));
+}
+
+kernel void main0(device SSBO& v_23 [[buffer(0)]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
+{
+    threadgroup short4 foo[4];
+    in_function(foo, gl_LocalInvocationIndex, v_23, gl_GlobalInvocationID);
+}
+
diff --git a/shaders-msl/comp/threadgroup-boolean-workaround.comp b/shaders-msl/comp/threadgroup-boolean-workaround.comp
new file mode 100644
index 0000000..8dce77a
--- /dev/null
+++ b/shaders-msl/comp/threadgroup-boolean-workaround.comp
@@ -0,0 +1,21 @@
+#version 450
+layout(local_size_x = 4) in;
+
+shared bvec4 foo[4];
+
+layout(binding = 0) buffer SSBO
+{
+	vec4 values[];
+};
+
+void in_function()
+{
+	foo[gl_LocalInvocationIndex] = notEqual(values[gl_GlobalInvocationID.x], vec4(10.0));
+	barrier();
+	values[gl_GlobalInvocationID.x] = mix(vec4(40.0), vec4(30.0), foo[gl_LocalInvocationIndex ^ 3]);
+}
+
+void main()
+{
+	in_function();
+}
diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp
index aa891c1..0a00947 100644
--- a/spirv_glsl.cpp
+++ b/spirv_glsl.cpp
@@ -9891,7 +9891,7 @@
 				convert_non_uniform_expression(lhs, lhs_expression);
 
 			// We might need to cast in order to store to a builtin.
-			cast_to_builtin_store(lhs_expression, rhs, expression_type(rhs_expression));
+			cast_to_variable_store(lhs_expression, rhs, expression_type(rhs_expression));
 
 			// Tries to optimize assignments like "<lhs> = <lhs> op expr".
 			// While this is purely cosmetic, this is important for legacy ESSL where loop
@@ -10056,7 +10056,7 @@
 			expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
 
 		// We might need to cast in order to load from a builtin.
-		cast_from_builtin_load(ptr, expr, type);
+		cast_from_variable_load(ptr, expr, type);
 
 		// We might be trying to load a gl_Position[N], where we should be
 		// doing float4[](gl_in[i].gl_Position, ...) instead.
@@ -15385,7 +15385,7 @@
 	}
 }
 
-void CompilerGLSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
+void CompilerGLSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
 {
 	// We will handle array cases elsewhere.
 	if (!expr_type.array.empty())
@@ -15444,7 +15444,7 @@
 		expr = bitcast_expression(expr_type, expected_type, expr);
 }
 
-void CompilerGLSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
+void CompilerGLSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
 {
 	auto *var = maybe_get_backing_variable(target_id);
 	if (var)
diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp
index d9594d3..52839d6 100644
--- a/spirv_glsl.hpp
+++ b/spirv_glsl.hpp
@@ -903,8 +903,8 @@
 	// Builtins in GLSL are always specific signedness, but the SPIR-V can declare them
 	// as either unsigned or signed.
 	// Sometimes we will need to automatically perform casts on load and store to make this work.
-	virtual void cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type);
-	virtual void cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type);
+	virtual void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type);
+	virtual void cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type);
 	void unroll_array_from_complex_load(uint32_t target_id, uint32_t source_id, std::string &expr);
 	bool unroll_array_to_complex_store(uint32_t target_id, uint32_t source_id);
 	void convert_non_uniform_expression(std::string &expr, uint32_t ptr_id);
diff --git a/spirv_msl.cpp b/spirv_msl.cpp
index 43cb6cd..ceed1c6 100644
--- a/spirv_msl.cpp
+++ b/spirv_msl.cpp
@@ -13351,8 +13351,23 @@
 
 	// Scalars
 	case SPIRType::Boolean:
-		type_name = "bool";
+	{
+		auto *var = maybe_get_backing_variable(id);
+		if (var && var->basevariable)
+			var = &get<SPIRVariable>(var->basevariable);
+
+		// Need to special-case threadgroup booleans. They are supposed to be logical
+		// storage, but MSL compilers will sometimes crash if you use threadgroup bool.
+		// Work around this by using 16-bit types instead, and fix up loads and stores of this data.
+		// FIXME: There is no sane way to work around this problem if a struct member is boolean
+		// and that struct is used as a threadgroup variable.
+		if ((var && var->storage == StorageClassWorkgroup) || type.storage == StorageClassWorkgroup)
+			type_name = "short";
+		else
+			type_name = "bool";
 		break;
+	}
+
 	case SPIRType::Char:
 	case SPIRType::SByte:
 		type_name = "char";
@@ -15413,12 +15428,16 @@
 	constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
 }
 
-void CompilerMSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
+void CompilerMSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
 {
 	auto *var = maybe_get_backing_variable(source_id);
 	if (var)
 		source_id = var->self;
 
+	// Type fixups for workgroup variables if they are booleans.
+	if (var && var->storage == StorageClassWorkgroup && expr_type.basetype == SPIRType::Boolean)
+		expr = join(type_to_glsl(expr_type), "(", expr, ")");
+
 	// Only interested in standalone builtin variables.
 	if (!has_decoration(source_id, DecorationBuiltIn))
 		return;
@@ -15505,12 +15524,20 @@
 	}
 }
 
-void CompilerMSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
+void CompilerMSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
 {
 	auto *var = maybe_get_backing_variable(target_id);
 	if (var)
 		target_id = var->self;
 
+	// Type fixups for workgroup variables if they are booleans.
+	if (var && var->storage == StorageClassWorkgroup && expr_type.basetype == SPIRType::Boolean)
+	{
+		auto short_type = expr_type;
+		short_type.basetype = SPIRType::Short;
+		expr = join(type_to_glsl(short_type), "(", expr, ")");
+	}
+
 	// Only interested in standalone builtin variables.
 	if (!has_decoration(target_id, DecorationBuiltIn))
 		return;
diff --git a/spirv_msl.hpp b/spirv_msl.hpp
index ce43cd9..d1d2ef3 100644
--- a/spirv_msl.hpp
+++ b/spirv_msl.hpp
@@ -960,8 +960,8 @@
 
 	bool does_shader_write_sample_mask = false;
 
-	void cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
-	void cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) override;
+	void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
+	void cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) override;
 	void emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression) override;
 
 	void analyze_sampled_image_usage();