Merge pull request #1808 from billhollings/depth-img-vs-depth-cmp

Partially separate the tracking of depth images from the tracking of depth-compare operations.
diff --git a/reference/opt/shaders-msl/asm/frag/depth-image-color-format-fetch.asm.frag b/reference/opt/shaders-msl/asm/frag/depth-image-color-format-fetch.asm.frag
new file mode 100644
index 0000000..be9f133
--- /dev/null
+++ b/reference/opt/shaders-msl/asm/frag/depth-image-color-format-fetch.asm.frag
@@ -0,0 +1,33 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct _7
+{
+    float4 _m0[64];
+};
+
+struct main0_out
+{
+    float4 m_3 [[color(0)]];
+};
+
+struct main0_in
+{
+    float4 m_2 [[user(locn1)]];
+};
+
+fragment main0_out main0(main0_in in [[stage_in]], device _7& _10 [[buffer(0)]], texture2d<float> _8 [[texture(0)]])
+{
+    main0_out out = {};
+    for (int _154 = 0; _154 < 64; )
+    {
+        _10._m0[_154] = _8.read(uint2(int2(_154 - 8 * (_154 / 8), _154 / 8)), 0);
+        _154++;
+        continue;
+    }
+    out.m_3 = in.m_2;
+    return out;
+}
+
diff --git a/reference/opt/shaders-msl/asm/frag/depth-image-color-format-sampled.asm.frag b/reference/opt/shaders-msl/asm/frag/depth-image-color-format-sampled.asm.frag
new file mode 100644
index 0000000..bbe0acd
--- /dev/null
+++ b/reference/opt/shaders-msl/asm/frag/depth-image-color-format-sampled.asm.frag
@@ -0,0 +1,33 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct _7
+{
+    float4 _m0[64];
+};
+
+struct main0_out
+{
+    float4 m_3 [[color(0)]];
+};
+
+struct main0_in
+{
+    float4 m_2 [[user(locn1)]];
+};
+
+fragment main0_out main0(main0_in in [[stage_in]], device _7& _10 [[buffer(0)]], texture2d<float> _8 [[texture(0)]], sampler _9 [[sampler(0)]])
+{
+    main0_out out = {};
+    for (int _158 = 0; _158 < 64; )
+    {
+        _10._m0[_158] = _8.sample(_9, (float2(int2(_158 - 8 * (_158 / 8), _158 / 8)) * float2(0.125)), level(0.0));
+        _158++;
+        continue;
+    }
+    out.m_3 = in.m_2;
+    return out;
+}
+
diff --git a/reference/shaders-msl/asm/frag/depth-image-color-format-fetch.asm.frag b/reference/shaders-msl/asm/frag/depth-image-color-format-fetch.asm.frag
new file mode 100644
index 0000000..01c670d
--- /dev/null
+++ b/reference/shaders-msl/asm/frag/depth-image-color-format-fetch.asm.frag
@@ -0,0 +1,47 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct _7
+{
+    float4 _m0[64];
+};
+
+struct main0_out
+{
+    float4 m_3 [[color(0)]];
+};
+
+struct main0_in
+{
+    float4 m_2 [[user(locn1)]];
+};
+
+static inline __attribute__((always_inline))
+void _108(int _109, thread texture2d<float> v_8, device _7& v_10)
+{
+    int2 _113 = int2(_109 - 8 * (_109 / 8), _109 / 8);
+    v_10._m0[_109] = v_8.read(uint2(_113), 0);
+}
+
+static inline __attribute__((always_inline))
+float4 _98(float4 _119, thread texture2d<float> v_8, device _7& v_10)
+{
+    for (int _121 = 0; _121 < 64; _121++)
+    {
+        _108(_121, v_8, v_10);
+    }
+    return _119;
+}
+
+fragment main0_out main0(main0_in in [[stage_in]], device _7& v_10 [[buffer(0)]], texture2d<float> v_8 [[texture(0)]])
+{
+    main0_out out = {};
+    float4 _97 = _98(in.m_2, v_8, v_10);
+    out.m_3 = _97;
+    return out;
+}
+
diff --git a/reference/shaders-msl/asm/frag/depth-image-color-format-sampled.asm.frag b/reference/shaders-msl/asm/frag/depth-image-color-format-sampled.asm.frag
new file mode 100644
index 0000000..9e374c0
--- /dev/null
+++ b/reference/shaders-msl/asm/frag/depth-image-color-format-sampled.asm.frag
@@ -0,0 +1,46 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct _7
+{
+    float4 _m0[64];
+};
+
+struct main0_out
+{
+    float4 m_3 [[color(0)]];
+};
+
+struct main0_in
+{
+    float4 m_2 [[user(locn1)]];
+};
+
+static inline __attribute__((always_inline))
+void _108(int _109, thread texture2d<float> v_8, thread sampler v_9, device _7& v_10)
+{
+    v_10._m0[_109] = v_8.sample(v_9, (float2(int2(_109 - 8 * (_109 / 8), _109 / 8)) / float2(8.0)), level(0.0));
+}
+
+static inline __attribute__((always_inline))
+float4 _98(float4 _121, thread texture2d<float> v_8, thread sampler v_9, device _7& v_10)
+{
+    for (int _123 = 0; _123 < 64; _123++)
+    {
+        _108(_123, v_8, v_9, v_10);
+    }
+    return _121;
+}
+
+fragment main0_out main0(main0_in in [[stage_in]], device _7& v_10 [[buffer(0)]], texture2d<float> v_8 [[texture(0)]], sampler v_9 [[sampler(0)]])
+{
+    main0_out out = {};
+    float4 _97 = _98(in.m_2, v_8, v_9, v_10);
+    out.m_3 = _97;
+    return out;
+}
+
diff --git a/shaders-msl/asm/frag/depth-image-color-format-fetch.asm.frag b/shaders-msl/asm/frag/depth-image-color-format-fetch.asm.frag
new file mode 100644
index 0000000..0be26d1
--- /dev/null
+++ b/shaders-msl/asm/frag/depth-image-color-format-fetch.asm.frag
@@ -0,0 +1,170 @@
+; SPIR-V
+; Version: 1.0
+; Generator: Khronos SPIR-V Tools Assembler; 0
+; Bound: 132
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "main" %2 %3 %4
+               OpExecutionMode %1 OriginUpperLeft
+               OpDecorate %3 Location 0
+               OpDecorate %2 Location 1
+               OpDecorate %4 BuiltIn FragCoord
+               OpDecorate %5 ArrayStride 4
+               OpDecorate %6 ArrayStride 16
+               OpMemberDecorate %7 0 Offset 0
+               OpDecorate %7 BufferBlock
+               OpDecorate %8 DescriptorSet 0
+               OpDecorate %8 Binding 0
+               OpDecorate %9 DescriptorSet 0
+               OpDecorate %9 Binding 1
+               OpDecorate %10 DescriptorSet 0
+               OpDecorate %10 Binding 2
+         %11 = OpTypeVoid
+         %12 = OpTypeBool
+         %13 = OpTypeInt 32 1
+         %14 = OpTypeInt 32 0
+         %16 = OpTypeFloat 32
+         %17 = OpTypeVector %13 2
+         %18 = OpTypeVector %14 2
+         %19 = OpTypeVector %16 2
+         %20 = OpTypeVector %13 3
+         %21 = OpTypeVector %14 3
+         %22 = OpTypeVector %16 3
+         %23 = OpTypeVector %13 4
+         %24 = OpTypeVector %14 4
+         %25 = OpTypeVector %16 4
+         %26 = OpTypeVector %12 4
+         %27 = OpTypeFunction %25 %25
+         %28 = OpTypeFunction %12
+         %29 = OpTypeFunction %11
+         %30 = OpTypePointer Input %16
+         %31 = OpTypePointer Input %13
+         %32 = OpTypePointer Input %14
+         %33 = OpTypePointer Input %19
+         %34 = OpTypePointer Input %17
+         %35 = OpTypePointer Input %18
+         %38 = OpTypePointer Input %22
+         %40 = OpTypePointer Input %25
+         %41 = OpTypePointer Input %23
+         %42 = OpTypePointer Input %24
+         %43 = OpTypePointer Output %16
+         %44 = OpTypePointer Output %13
+         %45 = OpTypePointer Output %14
+         %46 = OpTypePointer Output %19
+         %47 = OpTypePointer Output %17
+         %48 = OpTypePointer Output %18
+         %49 = OpTypePointer Output %25
+         %50 = OpTypePointer Output %23
+         %51 = OpTypePointer Output %24
+         %52 = OpTypePointer Function %16
+         %53 = OpTypePointer Function %13
+         %54 = OpTypePointer Function %25
+         %55 = OpConstant %16 1
+         %56 = OpConstant %16 0
+         %57 = OpConstant %16 0.5
+         %58 = OpConstant %16 -1
+         %59 = OpConstant %16 7
+         %60 = OpConstant %16 8
+         %61 = OpConstant %13 0
+         %62 = OpConstant %13 1
+         %63 = OpConstant %13 2
+         %64 = OpConstant %13 3
+         %65 = OpConstant %13 4
+         %66 = OpConstant %14 0
+         %67 = OpConstant %14 1
+         %68 = OpConstant %14 2
+         %69 = OpConstant %14 3
+         %70 = OpConstant %14 32
+         %71 = OpConstant %14 4
+         %72 = OpConstant %14 2147483647
+         %73 = OpConstantComposite %25 %55 %55 %55 %55
+         %74 = OpConstantComposite %25 %55 %56 %56 %55
+         %75 = OpConstantComposite %25 %57 %57 %57 %57
+         %76 = OpTypeArray %16 %67
+         %77 = OpTypeArray %16 %68
+         %78 = OpTypeArray %25 %69
+         %79 = OpTypeArray %16 %71
+         %80 = OpTypeArray %25 %70
+         %81 = OpTypePointer Input %78
+         %82 = OpTypePointer Input %80
+         %83 = OpTypePointer Output %77
+         %84 = OpTypePointer Output %78
+         %85 = OpTypePointer Output %79
+          %4 = OpVariable %40 Input
+          %3 = OpVariable %49 Output
+          %2 = OpVariable %40 Input
+         %86 = OpConstant %14 64
+         %87 = OpConstant %13 64
+         %88 = OpConstant %13 8
+         %89 = OpConstantComposite %19 %60 %60
+          %5 = OpTypeArray %16 %86
+          %6 = OpTypeArray %25 %86
+         %90 = OpTypePointer Uniform %16
+         %91 = OpTypePointer Uniform %25
+          %7 = OpTypeStruct %6
+         %92 = OpTypePointer Uniform %7
+         %10 = OpVariable %92 Uniform
+         %93 = OpTypeImage %16 2D 1 0 0 1 Rgba32f
+         %94 = OpTypePointer UniformConstant %93
+          %8 = OpVariable %94 UniformConstant
+         %95 = OpTypeSampler
+         %96 = OpTypePointer UniformConstant %95
+          %9 = OpVariable %96 UniformConstant
+         %97 = OpTypeSampledImage %93
+         %98 = OpTypeFunction %11 %13
+          %1 = OpFunction %11 None %29
+         %99 = OpLabel
+        %100 = OpLoad %25 %2
+        %101 = OpFunctionCall %25 %102 %100
+               OpStore %3 %101
+               OpReturn
+               OpFunctionEnd
+        %103 = OpFunction %12 None %28
+        %104 = OpLabel
+        %105 = OpAccessChain %30 %4 %61
+        %106 = OpAccessChain %30 %4 %62
+        %107 = OpLoad %16 %105
+        %108 = OpLoad %16 %106
+        %109 = OpFOrdEqual %12 %107 %57
+        %110 = OpFOrdEqual %12 %108 %57
+        %111 = OpLogicalAnd %12 %109 %110
+               OpReturnValue %111
+               OpFunctionEnd
+        %112 = OpFunction %11 None %98
+        %113 = OpFunctionParameter %13
+        %114 = OpLabel
+        %115 = OpSRem %13 %113 %88
+        %116 = OpSDiv %13 %113 %88
+        %117 = OpCompositeConstruct %17 %115 %116
+        %118 = OpConvertSToF %19 %117
+        %119 = OpFDiv %19 %118 %89
+        %120 = OpLoad %93 %8
+        %121 = OpImageFetch %25 %120 %117
+         %36 = OpAccessChain %91 %10 %61 %113
+               OpStore %36 %121
+               OpReturn
+               OpFunctionEnd
+        %102 = OpFunction %25 None %27
+        %122 = OpFunctionParameter %25
+        %123 = OpLabel
+        %124 = OpVariable %53 Function
+               OpStore %124 %61
+               OpBranch %125
+        %125 = OpLabel
+         %15 = OpLoad %13 %124
+        %126 = OpSLessThan %12 %15 %87
+               OpLoopMerge %127 %128 None
+               OpBranchConditional %126 %129 %127
+        %129 = OpLabel
+        %130 = OpLoad %13 %124
+        %131 = OpFunctionCall %11 %112 %130
+               OpBranch %128
+        %128 = OpLabel
+         %37 = OpLoad %13 %124
+         %39 = OpIAdd %13 %37 %62
+               OpStore %124 %39
+               OpBranch %125
+        %127 = OpLabel
+               OpReturnValue %122
+               OpFunctionEnd
diff --git a/shaders-msl/asm/frag/depth-image-color-format-sampled.asm.frag b/shaders-msl/asm/frag/depth-image-color-format-sampled.asm.frag
new file mode 100644
index 0000000..97e88b5
--- /dev/null
+++ b/shaders-msl/asm/frag/depth-image-color-format-sampled.asm.frag
@@ -0,0 +1,173 @@
+; SPIR-V
+; Version: 1.0
+; Generator: Khronos SPIR-V Tools Assembler; 0
+; Bound: 134
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %1 "main" %2 %3 %4
+               OpExecutionMode %1 OriginUpperLeft
+               OpDecorate %3 Location 0
+               OpDecorate %2 Location 1
+               OpDecorate %4 BuiltIn FragCoord
+               OpDecorate %5 ArrayStride 4
+               OpDecorate %6 ArrayStride 16
+               OpMemberDecorate %7 0 Offset 0
+               OpDecorate %7 BufferBlock
+               OpDecorate %8 DescriptorSet 0
+               OpDecorate %8 Binding 0
+               OpDecorate %9 DescriptorSet 0
+               OpDecorate %9 Binding 1
+               OpDecorate %10 DescriptorSet 0
+               OpDecorate %10 Binding 2
+         %11 = OpTypeVoid
+         %12 = OpTypeBool
+         %13 = OpTypeInt 32 1
+         %14 = OpTypeInt 32 0
+         %16 = OpTypeFloat 32
+         %17 = OpTypeVector %13 2
+         %18 = OpTypeVector %14 2
+         %19 = OpTypeVector %16 2
+         %20 = OpTypeVector %13 3
+         %21 = OpTypeVector %14 3
+         %22 = OpTypeVector %16 3
+         %23 = OpTypeVector %13 4
+         %24 = OpTypeVector %14 4
+         %25 = OpTypeVector %16 4
+         %26 = OpTypeVector %12 4
+         %27 = OpTypeFunction %25 %25
+         %28 = OpTypeFunction %12
+         %29 = OpTypeFunction %11
+         %30 = OpTypePointer Input %16
+         %31 = OpTypePointer Input %13
+         %32 = OpTypePointer Input %14
+         %33 = OpTypePointer Input %19
+         %34 = OpTypePointer Input %17
+         %35 = OpTypePointer Input %18
+         %38 = OpTypePointer Input %22
+         %40 = OpTypePointer Input %25
+         %41 = OpTypePointer Input %23
+         %42 = OpTypePointer Input %24
+         %43 = OpTypePointer Output %16
+         %44 = OpTypePointer Output %13
+         %45 = OpTypePointer Output %14
+         %46 = OpTypePointer Output %19
+         %47 = OpTypePointer Output %17
+         %48 = OpTypePointer Output %18
+         %49 = OpTypePointer Output %25
+         %50 = OpTypePointer Output %23
+         %51 = OpTypePointer Output %24
+         %52 = OpTypePointer Function %16
+         %53 = OpTypePointer Function %13
+         %54 = OpTypePointer Function %25
+         %55 = OpConstant %16 1
+         %56 = OpConstant %16 0
+         %57 = OpConstant %16 0.5
+         %58 = OpConstant %16 -1
+         %59 = OpConstant %16 7
+         %60 = OpConstant %16 8
+         %61 = OpConstant %13 0
+         %62 = OpConstant %13 1
+         %63 = OpConstant %13 2
+         %64 = OpConstant %13 3
+         %65 = OpConstant %13 4
+         %66 = OpConstant %14 0
+         %67 = OpConstant %14 1
+         %68 = OpConstant %14 2
+         %69 = OpConstant %14 3
+         %70 = OpConstant %14 32
+         %71 = OpConstant %14 4
+         %72 = OpConstant %14 2147483647
+         %73 = OpConstantComposite %25 %55 %55 %55 %55
+         %74 = OpConstantComposite %25 %55 %56 %56 %55
+         %75 = OpConstantComposite %25 %57 %57 %57 %57
+         %76 = OpTypeArray %16 %67
+         %77 = OpTypeArray %16 %68
+         %78 = OpTypeArray %25 %69
+         %79 = OpTypeArray %16 %71
+         %80 = OpTypeArray %25 %70
+         %81 = OpTypePointer Input %78
+         %82 = OpTypePointer Input %80
+         %83 = OpTypePointer Output %77
+         %84 = OpTypePointer Output %78
+         %85 = OpTypePointer Output %79
+          %4 = OpVariable %40 Input
+          %3 = OpVariable %49 Output
+          %2 = OpVariable %40 Input
+         %86 = OpConstant %14 64
+         %87 = OpConstant %13 64
+         %88 = OpConstant %13 8
+         %89 = OpConstantComposite %19 %60 %60
+          %5 = OpTypeArray %16 %86
+          %6 = OpTypeArray %25 %86
+         %90 = OpTypePointer Uniform %16
+         %91 = OpTypePointer Uniform %25
+          %7 = OpTypeStruct %6
+         %92 = OpTypePointer Uniform %7
+         %10 = OpVariable %92 Uniform
+         %93 = OpTypeImage %16 2D 1 0 0 1 Rgba32f
+         %94 = OpTypePointer UniformConstant %93
+          %8 = OpVariable %94 UniformConstant
+         %95 = OpTypeSampler
+         %96 = OpTypePointer UniformConstant %95
+          %9 = OpVariable %96 UniformConstant
+         %97 = OpTypeSampledImage %93
+         %98 = OpTypeFunction %11 %13
+          %1 = OpFunction %11 None %29
+         %99 = OpLabel
+        %100 = OpLoad %25 %2
+        %101 = OpFunctionCall %25 %102 %100
+               OpStore %3 %101
+               OpReturn
+               OpFunctionEnd
+        %103 = OpFunction %12 None %28
+        %104 = OpLabel
+        %105 = OpAccessChain %30 %4 %61
+        %106 = OpAccessChain %30 %4 %62
+        %107 = OpLoad %16 %105
+        %108 = OpLoad %16 %106
+        %109 = OpFOrdEqual %12 %107 %57
+        %110 = OpFOrdEqual %12 %108 %57
+        %111 = OpLogicalAnd %12 %109 %110
+               OpReturnValue %111
+               OpFunctionEnd
+        %112 = OpFunction %11 None %98
+        %113 = OpFunctionParameter %13
+        %114 = OpLabel
+        %115 = OpSRem %13 %113 %88
+        %116 = OpSDiv %13 %113 %88
+        %117 = OpCompositeConstruct %17 %115 %116
+        %118 = OpConvertSToF %19 %117
+        %119 = OpFDiv %19 %118 %89
+        %120 = OpLoad %93 %8
+        %121 = OpLoad %95 %9
+        %122 = OpSampledImage %97 %120 %121
+        %123 = OpImageSampleExplicitLod %25 %122 %119 Lod %56
+         %36 = OpAccessChain %91 %10 %61 %113
+               OpStore %36 %123
+               OpReturn
+               OpFunctionEnd
+        %102 = OpFunction %25 None %27
+        %124 = OpFunctionParameter %25
+        %125 = OpLabel
+        %126 = OpVariable %53 Function
+               OpStore %126 %61
+               OpBranch %127
+        %127 = OpLabel
+         %15 = OpLoad %13 %126
+        %128 = OpSLessThan %12 %15 %87
+               OpLoopMerge %129 %130 None
+               OpBranchConditional %128 %131 %129
+        %131 = OpLabel
+        %132 = OpLoad %13 %126
+        %133 = OpFunctionCall %11 %112 %132
+               OpBranch %130
+        %130 = OpLabel
+         %37 = OpLoad %13 %126
+         %39 = OpIAdd %13 %37 %62
+               OpStore %126 %39
+               OpBranch %127
+        %129 = OpLabel
+               OpReturnValue %124
+               OpFunctionEnd
+
diff --git a/spirv_cross.cpp b/spirv_cross.cpp
index 07fdb94..13e4135 100644
--- a/spirv_cross.cpp
+++ b/spirv_cross.cpp
@@ -4448,16 +4448,13 @@
 		if (length < 4)
 			return false;
 
-		uint32_t result_type = args[0];
-		uint32_t result_id = args[1];
-		auto &type = compiler.get<SPIRType>(result_type);
-
 		// If the underlying resource has been used for comparison then duplicate loads of that resource must be too.
 		// This image must be a depth image.
+		uint32_t result_id = args[1];
 		uint32_t image = args[2];
 		uint32_t sampler = args[3];
 
-		if (type.image.depth || dref_combined_samplers.count(result_id) != 0)
+		if (dref_combined_samplers.count(result_id) != 0)
 		{
 			add_hierarchy_to_comparison_ids(image);
 
@@ -4717,9 +4714,12 @@
 	return false;
 }
 
-bool Compiler::image_is_comparison(const SPIRType &type, uint32_t id) const
+// An image is treated as a depth image if either:
+//  - it is declared as a depth image and is not also explicitly given a color format (format is ImageFormatUnknown), or
+//  - its ID has been recorded in comparison_ids, i.e. it is used by a sample/gather compare operation.
+bool Compiler::is_depth_image(const SPIRType &type, uint32_t id) const
 {
-	return type.image.depth || (comparison_ids.count(id) != 0);
+	return (type.image.depth && type.image.format == ImageFormatUnknown) || comparison_ids.count(id) != 0;
 }
 
 bool Compiler::type_is_opaque_value(const SPIRType &type) const
diff --git a/spirv_cross.hpp b/spirv_cross.hpp
index d896796..2674437 100644
--- a/spirv_cross.hpp
+++ b/spirv_cross.hpp
@@ -1107,7 +1107,7 @@
 	Bitset combined_decoration_for_member(const SPIRType &type, uint32_t index) const;
 	static bool is_desktop_only_format(spv::ImageFormat format);
 
-	bool image_is_comparison(const SPIRType &type, uint32_t id) const;
+	bool is_depth_image(const SPIRType &type, uint32_t id) const;
 
 	void set_extended_decoration(uint32_t id, ExtendedDecorations decoration, uint32_t value = 0);
 	uint32_t get_extended_decoration(uint32_t id, ExtendedDecorations decoration) const;
diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp
index 9578502..a663318 100644
--- a/spirv_glsl.cpp
+++ b/spirv_glsl.cpp
@@ -6215,7 +6215,7 @@
 	// GLES has very limited support for shadow samplers.
 	// Basically shadow2D and shadow2DProj work through EXT_shadow_samplers,
 	// everything else can just throw
-	bool is_comparison = image_is_comparison(imgtype, tex);
+	bool is_comparison = is_depth_image(imgtype, tex);
 	if (is_comparison && is_legacy_es())
 	{
 		if (op == "texture" || op == "textureProj")
@@ -6842,7 +6842,7 @@
 	expr += ")";
 
 	// texture(samplerXShadow) returns float. shadowX() returns vec4. Swizzle here.
-	if (is_legacy() && image_is_comparison(imgtype, img))
+	if (is_legacy() && is_depth_image(imgtype, img))
 		expr += ".r";
 
 	// Sampling from a texture which was deduced to be a depth image, might actually return 1 component here.
@@ -6853,16 +6853,16 @@
 		const auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
 		VariableID image_id = combined ? combined->image : img;
 
-		if (combined && image_is_comparison(imgtype, combined->image))
+		if (combined && is_depth_image(imgtype, combined->image))
 			image_is_depth = true;
-		else if (image_is_comparison(imgtype, img))
+		else if (is_depth_image(imgtype, img))
 			image_is_depth = true;
 
 		// We must also check the backing variable for the image.
 		// We might have loaded an OpImage, and used that handle for two different purposes.
 		// Once with comparison, once without.
 		auto *image_variable = maybe_get_backing_variable(image_id);
-		if (image_variable && image_is_comparison(get<SPIRType>(image_variable->basetype), image_variable->self))
+		if (image_variable && is_depth_image(get<SPIRType>(image_variable->basetype), image_variable->self))
 			image_is_depth = true;
 
 		if (image_is_depth)
@@ -6930,7 +6930,7 @@
 	// This happens for HLSL SampleCmpLevelZero on Texture2DArray and TextureCube.
 	bool workaround_lod_array_shadow_as_grad = false;
 	if (((imgtype.image.arrayed && imgtype.image.dim == Dim2D) || imgtype.image.dim == DimCube) &&
-	    image_is_comparison(imgtype, tex) && args.lod)
+	    is_depth_image(imgtype, tex) && args.lod)
 	{
 		if (!expression_is_constant_null(args.lod))
 		{
@@ -7074,7 +7074,7 @@
 	// This happens for HLSL SampleCmpLevelZero on Texture2DArray and TextureCube.
 	bool workaround_lod_array_shadow_as_grad =
 	    ((imgtype.image.arrayed && imgtype.image.dim == Dim2D) || imgtype.image.dim == DimCube) &&
-	    image_is_comparison(imgtype, img) && args.lod != 0;
+	    is_depth_image(imgtype, img) && args.lod != 0;
 
 	if (args.dref)
 	{
@@ -13392,7 +13392,7 @@
 
 	// "Shadow" state in GLSL only exists for samplers and combined image samplers.
 	if (((type.basetype == SPIRType::SampledImage) || (type.basetype == SPIRType::Sampler)) &&
-	    image_is_comparison(type, id))
+	    is_depth_image(type, id))
 	{
 		res += "Shadow";
 	}
@@ -15831,7 +15831,9 @@
 
 bool CompilerGLSL::variable_is_depth_or_compare(VariableID id) const
 {
-	return image_is_comparison(get<SPIRType>(get<SPIRVariable>(id).basetype), id);
+	// Classify the variable by its declared base type plus any recorded compare usage of its ID.
+	auto &var = get<SPIRVariable>(id);
+	return is_depth_image(get<SPIRType>(var.basetype), id);
 }
 
 const char *CompilerGLSL::ShaderSubgroupSupportHelper::get_extension_name(Candidate c)
diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp
index 1f04c40..bdcb6dd 100644
--- a/spirv_hlsl.cpp
+++ b/spirv_hlsl.cpp
@@ -2380,7 +2380,7 @@
 		    arg_type.image.dim != DimBuffer)
 		{
 			// Manufacture automatic sampler arg for SampledImage texture
-			arglist.push_back(join(image_is_comparison(arg_type, arg.id) ? "SamplerComparisonState " : "SamplerState ",
+			arglist.push_back(join(is_depth_image(arg_type, arg.id) ? "SamplerComparisonState " : "SamplerState ",
 			                       to_sampler_expression(arg.id), type_to_array_glsl(arg_type)));
 		}
 
@@ -2910,7 +2910,7 @@
 		{
 			texop += img_expr;
 
-			if (image_is_comparison(imgtype, img))
+			if (is_depth_image(imgtype, img))
 			{
 				if (gather)
 				{
@@ -3386,7 +3386,7 @@
 		if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
 		{
 			// For combined image samplers, also emit a combined image sampler.
-			if (image_is_comparison(type, var.self))
+			if (is_depth_image(type, var.self))
 				statement("SamplerComparisonState ", to_sampler_expression(var.self), type_to_array_glsl(type),
 				          to_resource_binding_sampler(var), ";");
 			else
diff --git a/spirv_msl.cpp b/spirv_msl.cpp
index 3f53ebd..b08bb86 100644
--- a/spirv_msl.cpp
+++ b/spirv_msl.cpp
@@ -9393,8 +9393,6 @@
 string CompilerMSL::to_function_name(const TextureFunctionNameArguments &args)
 {
 	VariableID img = args.base.img;
-	auto &imgtype = *args.base.imgtype;
-
 	const MSLConstexprSampler *constexpr_sampler = nullptr;
 	bool is_dynamic_img_sampler = false;
 	if (auto *var = maybe_get_backing_variable(img))
@@ -9408,8 +9406,9 @@
 	if (msl_options.swizzle_texture_samples && args.base.is_gather && !is_dynamic_img_sampler &&
 	    (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
 	{
-		add_spv_func_and_recompile(imgtype.image.depth ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle);
-		return imgtype.image.depth ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
+		bool is_compare = comparison_ids.count(img);
+		add_spv_func_and_recompile(is_compare ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle);
+		return is_compare ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
 	}
 
 	auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
@@ -10021,7 +10020,7 @@
 				image_var = var->self;
 			}
 
-			if (image_var == 0 || !image_is_comparison(expression_type(image_var), image_var))
+			if (image_var == 0 || !is_depth_image(expression_type(image_var), image_var))
 				farg_str += ", " + to_component_argument(args.component);
 		}
 	}
@@ -13631,7 +13630,7 @@
 
 	// Bypass pointers because we need the real image struct
 	auto &img_type = get<SPIRType>(type.self).image;
-	if (image_is_comparison(type, id))
+	if (is_depth_image(type, id))
 	{
 		switch (img_type.dim)
 		{