MSL: Fix image load/store for short vectors.

Same fixes as for GLSL.
diff --git a/reference/opt/shaders-msl/asm/comp/buffer-write.asm.comp b/reference/opt/shaders-msl/asm/comp/buffer-write.asm.comp
index ddf9582..ab375a3 100644
--- a/reference/opt/shaders-msl/asm/comp/buffer-write.asm.comp
+++ b/reference/opt/shaders-msl/asm/comp/buffer-write.asm.comp
@@ -18,6 +18,6 @@
 
 kernel void main0(constant cb& _6 [[buffer(7)]], texture2d<float, access::write> _buffer [[texture(0)]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]])
 {
-    _buffer.write(_6.value, spvTexelBufferCoord(((32u * gl_WorkGroupID.x) + gl_LocalInvocationIndex)));
+    _buffer.write(float4(_6.value), spvTexelBufferCoord(((32u * gl_WorkGroupID.x) + gl_LocalInvocationIndex)));
 }
 
diff --git a/reference/opt/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp b/reference/opt/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
new file mode 100644
index 0000000..fb97d0d
--- /dev/null
+++ b/reference/opt/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
@@ -0,0 +1,10 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+kernel void main0(texture2d<float, access::read_write> TargetTexture [[texture(0)]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
+{
+    TargetTexture.write((TargetTexture.read(uint2(gl_WorkGroupID.xy)).xy + float2(1.0)).xyyy, uint2((gl_WorkGroupID.xy + uint2(1u))));
+}
+
diff --git a/reference/shaders-msl/asm/comp/buffer-write.asm.comp b/reference/shaders-msl/asm/comp/buffer-write.asm.comp
index ddf9582..ab375a3 100644
--- a/reference/shaders-msl/asm/comp/buffer-write.asm.comp
+++ b/reference/shaders-msl/asm/comp/buffer-write.asm.comp
@@ -18,6 +18,6 @@
 
 kernel void main0(constant cb& _6 [[buffer(7)]], texture2d<float, access::write> _buffer [[texture(0)]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]])
 {
-    _buffer.write(_6.value, spvTexelBufferCoord(((32u * gl_WorkGroupID.x) + gl_LocalInvocationIndex)));
+    _buffer.write(float4(_6.value), spvTexelBufferCoord(((32u * gl_WorkGroupID.x) + gl_LocalInvocationIndex)));
 }
 
diff --git a/reference/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp b/reference/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
new file mode 100644
index 0000000..c90faf9
--- /dev/null
+++ b/reference/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
@@ -0,0 +1,21 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+void _main(thread const uint3& id, thread texture2d<float, access::read_write> TargetTexture)
+{
+    float2 loaded = TargetTexture.read(uint2(id.xy)).xy;
+    float2 storeTemp = loaded + float2(1.0);
+    TargetTexture.write(storeTemp.xyyy, uint2((id.xy + uint2(1u))));
+}
+
+kernel void main0(texture2d<float, access::read_write> TargetTexture [[texture(0)]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
+{
+    uint3 id = gl_WorkGroupID;
+    uint3 param = id;
+    _main(param, TargetTexture);
+}
+
diff --git a/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp b/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
new file mode 100644
index 0000000..8f75929
--- /dev/null
+++ b/shaders-msl/asm/comp/image-load-store-short-vector.asm.comp
@@ -0,0 +1,75 @@
+; SPIR-V
+; Version: 1.0
+; Generator: Khronos Glslang Reference Front End; 7
+; Bound: 44
+; Schema: 0
+               OpCapability Shader
+               OpCapability StorageImageExtendedFormats
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main" %id_1
+               OpExecutionMode %main LocalSize 1 1 1
+               OpSource HLSL 500
+               OpName %main "main"
+               OpName %_main_vu3_ "@main(vu3;"
+               OpName %id "id"
+               OpName %loaded "loaded"
+               OpName %TargetTexture "TargetTexture"
+               OpName %storeTemp "storeTemp"
+               OpName %id_0 "id"
+               OpName %id_1 "id"
+               OpName %param "param"
+               OpDecorate %TargetTexture DescriptorSet 0
+               OpDecorate %TargetTexture Binding 0
+               OpDecorate %id_1 BuiltIn WorkgroupId
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+     %v3uint = OpTypeVector %uint 3
+%_ptr_Function_v3uint = OpTypePointer Function %v3uint
+          %9 = OpTypeFunction %void %_ptr_Function_v3uint
+      %float = OpTypeFloat 32
+    %v2float = OpTypeVector %float 2
+%_ptr_Function_v2float = OpTypePointer Function %v2float
+         %17 = OpTypeImage %float 2D 0 0 0 2 Rg32f
+%_ptr_UniformConstant_17 = OpTypePointer UniformConstant %17
+%TargetTexture = OpVariable %_ptr_UniformConstant_17 UniformConstant
+     %v2uint = OpTypeVector %uint 2
+    %float_1 = OpConstant %float 1
+     %uint_1 = OpConstant %uint 1
+%_ptr_Input_v3uint = OpTypePointer Input %v3uint
+       %id_1 = OpVariable %_ptr_Input_v3uint Input
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+       %id_0 = OpVariable %_ptr_Function_v3uint Function
+      %param = OpVariable %_ptr_Function_v3uint Function
+         %40 = OpLoad %v3uint %id_1
+               OpStore %id_0 %40
+         %42 = OpLoad %v3uint %id_0
+               OpStore %param %42
+         %43 = OpFunctionCall %void %_main_vu3_ %param
+               OpReturn
+               OpFunctionEnd
+ %_main_vu3_ = OpFunction %void None %9
+         %id = OpFunctionParameter %_ptr_Function_v3uint
+         %12 = OpLabel
+     %loaded = OpVariable %_ptr_Function_v2float Function
+  %storeTemp = OpVariable %_ptr_Function_v2float Function
+         %20 = OpLoad %17 %TargetTexture
+         %22 = OpLoad %v3uint %id
+         %23 = OpVectorShuffle %v2uint %22 %22 0 1
+         %24 = OpImageRead %v2float %20 %23
+               OpStore %loaded %24
+         %26 = OpLoad %v2float %loaded
+         %28 = OpCompositeConstruct %v2float %float_1 %float_1
+         %29 = OpFAdd %v2float %26 %28
+               OpStore %storeTemp %29
+         %30 = OpLoad %17 %TargetTexture
+         %31 = OpLoad %v3uint %id
+         %32 = OpVectorShuffle %v2uint %31 %31 0 1
+         %34 = OpCompositeConstruct %v2uint %uint_1 %uint_1
+         %35 = OpIAdd %v2uint %32 %34
+         %36 = OpLoad %v2float %storeTemp
+               OpImageWrite %30 %35 %36
+               OpReturn
+               OpFunctionEnd
diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp
index 82ce73c..131bc80 100644
--- a/spirv_glsl.cpp
+++ b/spirv_glsl.cpp
@@ -4212,6 +4212,10 @@
 	if (is_legacy() && image_is_comparison(imgtype, img))
 		expr += ".r";
 
+	// Deals with reads from MSL. We might need to downconvert to fewer components.
+	if (op == OpImageRead)
+		expr = remap_swizzle(get<SPIRType>(result_type), 4, expr);
+
 	emit_op(result_type, id, expr, forward);
 	for (auto &inherit : inherited_expressions)
 		inherit_expression_dependencies(id, inherit);
diff --git a/spirv_msl.cpp b/spirv_msl.cpp
index 305e3e1..87b00ea 100644
--- a/spirv_msl.cpp
+++ b/spirv_msl.cpp
@@ -2711,8 +2711,13 @@
 		test(bias, ImageOperandsBiasMask);
 		test(lod, ImageOperandsLodMask);
 
+		auto &texel_type = expression_type(texel_id);
+		auto store_type = texel_type;
+		store_type.vecsize = 4;
+
 		statement(join(
-		    to_expression(img_id), ".write(", to_expression(texel_id), ", ",
+		    to_expression(img_id), ".write(", remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)),
+		    ", ",
 		    to_function_args(img_id, img_type, true, false, false, coord_id, 0, 0, 0, 0, lod, 0, 0, 0, 0, 0, &forward),
 		    ");"));