Merge pull request #1787 from KhronosGroup/fix-1786
MSL: Work around compiler crashes when using threadgroup bool.
diff --git a/reference/opt/shaders-msl/comp/threadgroup-boolean-workaround.comp b/reference/opt/shaders-msl/comp/threadgroup-boolean-workaround.comp
new file mode 100644
index 0000000..8b80929
--- /dev/null
+++ b/reference/opt/shaders-msl/comp/threadgroup-boolean-workaround.comp
@@ -0,0 +1,20 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct SSBO
+{
+ float4 values[1];
+};
+
+constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u);
+
+kernel void main0(device SSBO& _23 [[buffer(0)]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
+{
+ threadgroup short4 foo[4];
+ foo[gl_LocalInvocationIndex] = short4((isunordered(_23.values[gl_GlobalInvocationID.x], float4(10.0)) || _23.values[gl_GlobalInvocationID.x] != float4(10.0)));
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ _23.values[gl_GlobalInvocationID.x] = select(float4(40.0), float4(30.0), bool4(foo[gl_LocalInvocationIndex ^ 3u]));
+}
+
diff --git a/reference/shaders-msl/comp/threadgroup-boolean-workaround.comp b/reference/shaders-msl/comp/threadgroup-boolean-workaround.comp
new file mode 100644
index 0000000..d01b135
--- /dev/null
+++ b/reference/shaders-msl/comp/threadgroup-boolean-workaround.comp
@@ -0,0 +1,28 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct SSBO
+{
+ float4 values[1];
+};
+
+constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u);
+
+static inline __attribute__((always_inline))
+void in_function(threadgroup short4 (&foo)[4], thread uint& gl_LocalInvocationIndex, device SSBO& v_23, thread uint3& gl_GlobalInvocationID)
+{
+ foo[gl_LocalInvocationIndex] = short4((isunordered(v_23.values[gl_GlobalInvocationID.x], float4(10.0)) || v_23.values[gl_GlobalInvocationID.x] != float4(10.0)));
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ v_23.values[gl_GlobalInvocationID.x] = select(float4(40.0), float4(30.0), bool4(foo[gl_LocalInvocationIndex ^ 3u]));
+}
+
+kernel void main0(device SSBO& v_23 [[buffer(0)]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
+{
+ threadgroup short4 foo[4];
+ in_function(foo, gl_LocalInvocationIndex, v_23, gl_GlobalInvocationID);
+}
+
diff --git a/shaders-msl/comp/threadgroup-boolean-workaround.comp b/shaders-msl/comp/threadgroup-boolean-workaround.comp
new file mode 100644
index 0000000..8dce77a
--- /dev/null
+++ b/shaders-msl/comp/threadgroup-boolean-workaround.comp
@@ -0,0 +1,21 @@
+#version 450
+layout(local_size_x = 4) in;
+
+shared bvec4 foo[4];
+
+layout(binding = 0) buffer SSBO
+{
+ vec4 values[];
+};
+
+void in_function()
+{
+ foo[gl_LocalInvocationIndex] = notEqual(values[gl_GlobalInvocationID.x], vec4(10.0));
+ barrier();
+ values[gl_GlobalInvocationID.x] = mix(vec4(40.0), vec4(30.0), foo[gl_LocalInvocationIndex ^ 3]);
+}
+
+void main()
+{
+ in_function();
+}
diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp
index aa891c1..0a00947 100644
--- a/spirv_glsl.cpp
+++ b/spirv_glsl.cpp
@@ -9891,7 +9891,7 @@
convert_non_uniform_expression(lhs, lhs_expression);
// We might need to cast in order to store to a builtin.
- cast_to_builtin_store(lhs_expression, rhs, expression_type(rhs_expression));
+ cast_to_variable_store(lhs_expression, rhs, expression_type(rhs_expression));
// Tries to optimize assignments like "<lhs> = <lhs> op expr".
// While this is purely cosmetic, this is important for legacy ESSL where loop
@@ -10056,7 +10056,7 @@
expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
// We might need to cast in order to load from a builtin.
- cast_from_builtin_load(ptr, expr, type);
+ cast_from_variable_load(ptr, expr, type);
// We might be trying to load a gl_Position[N], where we should be
// doing float4[](gl_in[i].gl_Position, ...) instead.
@@ -15385,7 +15385,7 @@
}
}
-void CompilerGLSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
+void CompilerGLSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
{
// We will handle array cases elsewhere.
if (!expr_type.array.empty())
@@ -15444,7 +15444,7 @@
expr = bitcast_expression(expr_type, expected_type, expr);
}
-void CompilerGLSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
+void CompilerGLSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
{
auto *var = maybe_get_backing_variable(target_id);
if (var)
diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp
index d9594d3..52839d6 100644
--- a/spirv_glsl.hpp
+++ b/spirv_glsl.hpp
@@ -903,8 +903,8 @@
// Builtins in GLSL are always specific signedness, but the SPIR-V can declare them
// as either unsigned or signed.
// Sometimes we will need to automatically perform casts on load and store to make this work.
- virtual void cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type);
- virtual void cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type);
+ virtual void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type);
+ virtual void cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type);
void unroll_array_from_complex_load(uint32_t target_id, uint32_t source_id, std::string &expr);
bool unroll_array_to_complex_store(uint32_t target_id, uint32_t source_id);
void convert_non_uniform_expression(std::string &expr, uint32_t ptr_id);
diff --git a/spirv_msl.cpp b/spirv_msl.cpp
index 43cb6cd..ceed1c6 100644
--- a/spirv_msl.cpp
+++ b/spirv_msl.cpp
@@ -13351,8 +13351,23 @@
// Scalars
case SPIRType::Boolean:
- type_name = "bool";
+ {
+ auto *var = maybe_get_backing_variable(id);
+ if (var && var->basevariable)
+ var = &get<SPIRVariable>(var->basevariable);
+
+ // Need to special-case threadgroup booleans. They are supposed to be logical
+ // storage, but MSL compilers will sometimes crash if you use threadgroup bool.
+ // Work around this by declaring them as 16-bit integers (short) and converting on loads and stores.
+ // FIXME: We have no sane way of working around this problem if a struct member is boolean
+ // and that struct is used as a threadgroup variable.
+ if ((var && var->storage == StorageClassWorkgroup) || type.storage == StorageClassWorkgroup)
+ type_name = "short";
+ else
+ type_name = "bool";
break;
+ }
+
case SPIRType::Char:
case SPIRType::SByte:
type_name = "char";
@@ -15413,12 +15428,16 @@
constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
}
-void CompilerMSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
+void CompilerMSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
{
auto *var = maybe_get_backing_variable(source_id);
if (var)
source_id = var->self;
+ // Type fixups for workgroup variables if they are booleans.
+ if (var && var->storage == StorageClassWorkgroup && expr_type.basetype == SPIRType::Boolean)
+ expr = join(type_to_glsl(expr_type), "(", expr, ")");
+
// Only interested in standalone builtin variables.
if (!has_decoration(source_id, DecorationBuiltIn))
return;
@@ -15505,12 +15524,20 @@
}
}
-void CompilerMSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
+void CompilerMSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
{
auto *var = maybe_get_backing_variable(target_id);
if (var)
target_id = var->self;
+ // Type fixups for workgroup variables if they are booleans.
+ if (var && var->storage == StorageClassWorkgroup && expr_type.basetype == SPIRType::Boolean)
+ {
+ auto short_type = expr_type;
+ short_type.basetype = SPIRType::Short;
+ expr = join(type_to_glsl(short_type), "(", expr, ")");
+ }
+
// Only interested in standalone builtin variables.
if (!has_decoration(target_id, DecorationBuiltIn))
return;
diff --git a/spirv_msl.hpp b/spirv_msl.hpp
index ce43cd9..d1d2ef3 100644
--- a/spirv_msl.hpp
+++ b/spirv_msl.hpp
@@ -960,8 +960,8 @@
bool does_shader_write_sample_mask = false;
- void cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
- void cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) override;
+ void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override;
+ void cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type) override;
void emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression) override;
void analyze_sampled_image_usage();