MSL: Fix image load/store for short vectors.
Same fixes as for GLSL.
diff --git a/reference/opt/shaders-msl/asm/comp/buffer-write.asm.comp b/reference/opt/shaders-msl/asm/comp/buffer-write.asm.comp
index ddf9582..ab375a3 100644
--- a/reference/opt/shaders-msl/asm/comp/buffer-write.asm.comp
+++ b/reference/opt/shaders-msl/asm/comp/buffer-write.asm.comp
@@ -18,6 +18,6 @@
kernel void main0(constant cb& _6 [[buffer(7)]], texture2d<float, access::write> _buffer [[texture(0)]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]])
{
- _buffer.write(_6.value, spvTexelBufferCoord(((32u * gl_WorkGroupID.x) + gl_LocalInvocationIndex)));
+ _buffer.write(float4(_6.value), spvTexelBufferCoord(((32u * gl_WorkGroupID.x) + gl_LocalInvocationIndex)));
}
diff --git a/reference/opt/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp b/reference/opt/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
new file mode 100644
index 0000000..fb97d0d
--- /dev/null
+++ b/reference/opt/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
@@ -0,0 +1,10 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+kernel void main0(texture2d<float, access::read_write> TargetTexture [[texture(0)]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
+{
+ TargetTexture.write((TargetTexture.read(uint2(gl_WorkGroupID.xy)).xy + float2(1.0)).xyyy, uint2((gl_WorkGroupID.xy + uint2(1u))));
+}
+
diff --git a/reference/shaders-msl/asm/comp/buffer-write.asm.comp b/reference/shaders-msl/asm/comp/buffer-write.asm.comp
index ddf9582..ab375a3 100644
--- a/reference/shaders-msl/asm/comp/buffer-write.asm.comp
+++ b/reference/shaders-msl/asm/comp/buffer-write.asm.comp
@@ -18,6 +18,6 @@
kernel void main0(constant cb& _6 [[buffer(7)]], texture2d<float, access::write> _buffer [[texture(0)]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]])
{
- _buffer.write(_6.value, spvTexelBufferCoord(((32u * gl_WorkGroupID.x) + gl_LocalInvocationIndex)));
+ _buffer.write(float4(_6.value), spvTexelBufferCoord(((32u * gl_WorkGroupID.x) + gl_LocalInvocationIndex)));
}
diff --git a/reference/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp b/reference/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
new file mode 100644
index 0000000..c90faf9
--- /dev/null
+++ b/reference/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
@@ -0,0 +1,21 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+void _main(thread const uint3& id, thread texture2d<float, access::read_write> TargetTexture)
+{
+ float2 loaded = TargetTexture.read(uint2(id.xy)).xy;
+ float2 storeTemp = loaded + float2(1.0);
+ TargetTexture.write(storeTemp.xyyy, uint2((id.xy + uint2(1u))));
+}
+
+kernel void main0(texture2d<float, access::read_write> TargetTexture [[texture(0)]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
+{
+ uint3 id = gl_WorkGroupID;
+ uint3 param = id;
+ _main(param, TargetTexture);
+}
+
diff --git a/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp b/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
new file mode 100644
index 0000000..8f75929
--- /dev/null
+++ b/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
@@ -0,0 +1,75 @@
+; SPIR-V
+; Version: 1.0
+; Generator: Khronos Glslang Reference Front End; 7
+; Bound: 44
+; Schema: 0
+ OpCapability Shader
+ OpCapability StorageImageExtendedFormats
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "main" %id_1
+ OpExecutionMode %main LocalSize 1 1 1
+ OpSource HLSL 500
+ OpName %main "main"
+ OpName %_main_vu3_ "@main(vu3;"
+ OpName %id "id"
+ OpName %loaded "loaded"
+ OpName %TargetTexture "TargetTexture"
+ OpName %storeTemp "storeTemp"
+ OpName %id_0 "id"
+ OpName %id_1 "id"
+ OpName %param "param"
+ OpDecorate %TargetTexture DescriptorSet 0
+ OpDecorate %TargetTexture Binding 0
+ OpDecorate %id_1 BuiltIn WorkgroupId
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+ %v3uint = OpTypeVector %uint 3
+%_ptr_Function_v3uint = OpTypePointer Function %v3uint
+ %9 = OpTypeFunction %void %_ptr_Function_v3uint
+ %float = OpTypeFloat 32
+ %v2float = OpTypeVector %float 2
+%_ptr_Function_v2float = OpTypePointer Function %v2float
+ %17 = OpTypeImage %float 2D 0 0 0 2 Rg32f
+%_ptr_UniformConstant_17 = OpTypePointer UniformConstant %17
+%TargetTexture = OpVariable %_ptr_UniformConstant_17 UniformConstant
+ %v2uint = OpTypeVector %uint 2
+ %float_1 = OpConstant %float 1
+ %uint_1 = OpConstant %uint 1
+%_ptr_Input_v3uint = OpTypePointer Input %v3uint
+ %id_1 = OpVariable %_ptr_Input_v3uint Input
+ %main = OpFunction %void None %3
+ %5 = OpLabel
+ %id_0 = OpVariable %_ptr_Function_v3uint Function
+ %param = OpVariable %_ptr_Function_v3uint Function
+ %40 = OpLoad %v3uint %id_1
+ OpStore %id_0 %40
+ %42 = OpLoad %v3uint %id_0
+ OpStore %param %42
+ %43 = OpFunctionCall %void %_main_vu3_ %param
+ OpReturn
+ OpFunctionEnd
+ %_main_vu3_ = OpFunction %void None %9
+ %id = OpFunctionParameter %_ptr_Function_v3uint
+ %12 = OpLabel
+ %loaded = OpVariable %_ptr_Function_v2float Function
+ %storeTemp = OpVariable %_ptr_Function_v2float Function
+ %20 = OpLoad %17 %TargetTexture
+ %22 = OpLoad %v3uint %id
+ %23 = OpVectorShuffle %v2uint %22 %22 0 1
+ %24 = OpImageRead %v2float %20 %23
+ OpStore %loaded %24
+ %26 = OpLoad %v2float %loaded
+ %28 = OpCompositeConstruct %v2float %float_1 %float_1
+ %29 = OpFAdd %v2float %26 %28
+ OpStore %storeTemp %29
+ %30 = OpLoad %17 %TargetTexture
+ %31 = OpLoad %v3uint %id
+ %32 = OpVectorShuffle %v2uint %31 %31 0 1
+ %34 = OpCompositeConstruct %v2uint %uint_1 %uint_1
+ %35 = OpIAdd %v2uint %32 %34
+ %36 = OpLoad %v2float %storeTemp
+ OpImageWrite %30 %35 %36
+ OpReturn
+ OpFunctionEnd
diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp
index 82ce73c..131bc80 100644
--- a/spirv_glsl.cpp
+++ b/spirv_glsl.cpp
@@ -4212,6 +4212,10 @@
if (is_legacy() && image_is_comparison(imgtype, img))
expr += ".r";
+ // Deals with reads from MSL. We might need to downconvert to fewer components.
+ if (op == OpImageRead)
+ expr = remap_swizzle(get<SPIRType>(result_type), 4, expr);
+
emit_op(result_type, id, expr, forward);
for (auto &inherit : inherited_expressions)
inherit_expression_dependencies(id, inherit);
diff --git a/spirv_msl.cpp b/spirv_msl.cpp
index 305e3e1..87b00ea 100644
--- a/spirv_msl.cpp
+++ b/spirv_msl.cpp
@@ -2711,8 +2711,13 @@
test(bias, ImageOperandsBiasMask);
test(lod, ImageOperandsLodMask);
+ auto &texel_type = expression_type(texel_id);
+ auto store_type = texel_type;
+ store_type.vecsize = 4;
+
statement(join(
- to_expression(img_id), ".write(", to_expression(texel_id), ", ",
+ to_expression(img_id), ".write(", remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)),
+ ", ",
to_function_args(img_id, img_type, true, false, false, coord_id, 0, 0, 0, 0, lod, 0, 0, 0, 0, 0, &forward),
");"));