Merge pull request #1766 from KhronosGroup/fix-1765

Fix some silly bugs in trivial mix op detection.
diff --git a/reference/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp b/reference/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..94aec45
--- /dev/null
+++ b/reference/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,15 @@
+static const uint3 gl_WorkGroupSize = uint3(1u, 1u, 1u);
+
+RWByteAddressBuffer _14 : register(u0);
+
+void comp_main()
+{
+    bool3 c = bool3(asfloat(_14.Load3(16)).x < 1.0f.xxx.x, asfloat(_14.Load3(16)).y < 1.0f.xxx.y, asfloat(_14.Load3(16)).z < 1.0f.xxx.z);
+    _14.Store3(0, asuint(float3(c.x ? float3(0.0f, 0.0f, 1.0f).x : float3(1.0f, 0.0f, 0.0f).x, c.y ? float3(0.0f, 0.0f, 1.0f).y : float3(1.0f, 0.0f, 0.0f).y, c.z ? float3(0.0f, 0.0f, 1.0f).z : float3(1.0f, 0.0f, 0.0f).z)));
+}
+
+[numthreads(1, 1, 1)]
+void main()
+{
+    comp_main();
+}
diff --git a/reference/shaders-hlsl-no-opt/comp/trivial-select-matrix.spv14.comp b/reference/shaders-hlsl-no-opt/comp/trivial-select-matrix.spv14.comp
new file mode 100644
index 0000000..7bd1c76
--- /dev/null
+++ b/reference/shaders-hlsl-no-opt/comp/trivial-select-matrix.spv14.comp
@@ -0,0 +1,22 @@
+static const uint3 gl_WorkGroupSize = uint3(1u, 1u, 1u);
+
+RWByteAddressBuffer _14 : register(u0);
+
+void comp_main()
+{
+    bool c = asfloat(_14.Load(48)) < 1.0f;
+    float3x3 _29 = c ? float3x3(1.0f.xxx, 1.0f.xxx, 1.0f.xxx) : float3x3(0.0f.xxx, 0.0f.xxx, 0.0f.xxx);
+    _14.Store3(0, asuint(_29[0]));
+    _14.Store3(16, asuint(_29[1]));
+    _14.Store3(32, asuint(_29[2]));
+    float3x3 _37 = c ? float3x3(float3(1.0f, 0.0f, 0.0f), float3(0.0f, 1.0f, 0.0f), float3(0.0f, 0.0f, 1.0f)) : float3x3(0.0f.xxx, 0.0f.xxx, 0.0f.xxx);
+    _14.Store3(0, asuint(_37[0]));
+    _14.Store3(16, asuint(_37[1]));
+    _14.Store3(32, asuint(_37[2]));
+}
+
+[numthreads(1, 1, 1)]
+void main()
+{
+    comp_main();
+}
diff --git a/reference/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp b/reference/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..328b42c
--- /dev/null
+++ b/reference/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,19 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct A
+{
+    float3 a;
+    float3 b;
+};
+
+constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
+
+kernel void main0(device A& _14 [[buffer(0)]])
+{
+    bool3 c = _14.b < float3(1.0);
+    _14.a = select(float3(1.0, 0.0, 0.0), float3(0.0, 0.0, 1.0), c);
+}
+
diff --git a/reference/shaders-msl-no-opt/comp/trivial-select-matrix.spv14.comp b/reference/shaders-msl-no-opt/comp/trivial-select-matrix.spv14.comp
new file mode 100644
index 0000000..2e37a32
--- /dev/null
+++ b/reference/shaders-msl-no-opt/comp/trivial-select-matrix.spv14.comp
@@ -0,0 +1,20 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct A
+{
+    float3x3 a;
+    float b;
+};
+
+constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
+
+kernel void main0(device A& _14 [[buffer(0)]])
+{
+    bool c = _14.b < 1.0;
+    _14.a = c ? float3x3(float3(1.0), float3(1.0), float3(1.0)) : float3x3(float3(0.0), float3(0.0), float3(0.0));
+    _14.a = c ? float3x3(float3(1.0, 0.0, 0.0), float3(0.0, 1.0, 0.0), float3(0.0, 0.0, 1.0)) : float3x3(float3(0.0), float3(0.0), float3(0.0));
+}
+
diff --git a/reference/shaders-no-opt/comp/trivial-select-cast-vector.comp b/reference/shaders-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..92573ff
--- /dev/null
+++ b/reference/shaders-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,15 @@
+#version 450
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+
+layout(binding = 0, std430) buffer A
+{
+    vec3 a;
+    vec3 b;
+} _14;
+
+void main()
+{
+    bvec3 c = lessThan(_14.b, vec3(1.0));
+    _14.a = mix(vec3(1.0, 0.0, 0.0), vec3(0.0, 0.0, 1.0), c);
+}
+
diff --git a/reference/shaders-no-opt/comp/trivial-select-matrix.spv14.comp b/reference/shaders-no-opt/comp/trivial-select-matrix.spv14.comp
new file mode 100644
index 0000000..dd227e8
--- /dev/null
+++ b/reference/shaders-no-opt/comp/trivial-select-matrix.spv14.comp
@@ -0,0 +1,16 @@
+#version 450
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+
+layout(binding = 0, std430) buffer A
+{
+    mat3 a;
+    float b;
+} _14;
+
+void main()
+{
+    bool c = _14.b < 1.0;
+    _14.a = c ? mat3(vec3(1.0), vec3(1.0), vec3(1.0)) : mat3(vec3(0.0), vec3(0.0), vec3(0.0));
+    _14.a = c ? mat3(vec3(1.0, 0.0, 0.0), vec3(0.0, 1.0, 0.0), vec3(0.0, 0.0, 1.0)) : mat3(vec3(0.0), vec3(0.0), vec3(0.0));
+}
+
diff --git a/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp b/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..c3e0922
--- /dev/null
+++ b/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,14 @@
+#version 450
+layout(local_size_x = 1) in;
+
+layout(set = 0, binding = 0) buffer A
+{
+	vec3 a;
+	vec3 b;
+};
+
+void main() // Regression test: trivial-mix detection must check every component, not only the last one.
+{
+	bvec3 c = lessThan(b, vec3(1.0));
+	a = mix(vec3(1, 0, 0), vec3(0, 0, 1), c); // Only .z has false=0/true=1, so this must stay a select, not a bool-to-float cast.
+}
diff --git a/shaders-hlsl-no-opt/comp/trivial-select-matrix.spv14.comp b/shaders-hlsl-no-opt/comp/trivial-select-matrix.spv14.comp
new file mode 100644
index 0000000..5ffcc3f
--- /dev/null
+++ b/shaders-hlsl-no-opt/comp/trivial-select-matrix.spv14.comp
@@ -0,0 +1,16 @@
+#version 450
+layout(local_size_x = 1) in;
+
+layout(set = 0, binding = 0) buffer A
+{
+	mat3 a;
+	float b;
+};
+
+void main()
+{
+	// Scalar bool selecting between whole matrices; matrix select with a scalar condition needs SPIR-V 1.4 (hence .spv14).
+	bool c = b < 1.0;
+	a = c ? mat3(vec3(1), vec3(1), vec3(1)) : mat3(vec3(0), vec3(0), vec3(0));
+	a = c ? mat3(1) : mat3(0); // mat3(1) is the identity matrix (diagonal only), not all ones.
+}
diff --git a/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp b/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..c3e0922
--- /dev/null
+++ b/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,14 @@
+#version 450
+layout(local_size_x = 1) in;
+
+layout(set = 0, binding = 0) buffer A
+{
+	vec3 a;
+	vec3 b;
+};
+
+void main() // Regression test: trivial-mix detection must check every component, not only the last one.
+{
+	bvec3 c = lessThan(b, vec3(1.0));
+	a = mix(vec3(1, 0, 0), vec3(0, 0, 1), c); // Only .z has false=0/true=1, so this must stay a select, not a bool-to-float cast.
+}
diff --git a/shaders-msl-no-opt/comp/trivial-select-matrix.spv14.comp b/shaders-msl-no-opt/comp/trivial-select-matrix.spv14.comp
new file mode 100644
index 0000000..5ffcc3f
--- /dev/null
+++ b/shaders-msl-no-opt/comp/trivial-select-matrix.spv14.comp
@@ -0,0 +1,16 @@
+#version 450
+layout(local_size_x = 1) in;
+
+layout(set = 0, binding = 0) buffer A
+{
+	mat3 a;
+	float b;
+};
+
+void main()
+{
+	// Scalar bool selecting between whole matrices; matrix select with a scalar condition needs SPIR-V 1.4 (hence .spv14).
+	bool c = b < 1.0;
+	a = c ? mat3(vec3(1), vec3(1), vec3(1)) : mat3(vec3(0), vec3(0), vec3(0));
+	a = c ? mat3(1) : mat3(0); // mat3(1) is the identity matrix (diagonal only), not all ones.
+}
diff --git a/shaders-no-opt/comp/trivial-select-cast-vector.comp b/shaders-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..c3e0922
--- /dev/null
+++ b/shaders-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,14 @@
+#version 450
+layout(local_size_x = 1) in;
+
+layout(set = 0, binding = 0) buffer A
+{
+	vec3 a;
+	vec3 b;
+};
+
+void main() // Regression test: trivial-mix detection must check every component, not only the last one.
+{
+	bvec3 c = lessThan(b, vec3(1.0));
+	a = mix(vec3(1, 0, 0), vec3(0, 0, 1), c); // Only .z has false=0/true=1, so this must stay a select, not a bool-to-float cast.
+}
diff --git a/shaders-no-opt/comp/trivial-select-matrix.spv14.comp b/shaders-no-opt/comp/trivial-select-matrix.spv14.comp
new file mode 100644
index 0000000..5ffcc3f
--- /dev/null
+++ b/shaders-no-opt/comp/trivial-select-matrix.spv14.comp
@@ -0,0 +1,16 @@
+#version 450
+layout(local_size_x = 1) in;
+
+layout(set = 0, binding = 0) buffer A
+{
+	mat3 a;
+	float b;
+};
+
+void main()
+{
+	// Scalar bool selecting between whole matrices; matrix select with a scalar condition needs SPIR-V 1.4 (hence .spv14).
+	bool c = b < 1.0;
+	a = c ? mat3(vec3(1), vec3(1), vec3(1)) : mat3(vec3(0), vec3(0), vec3(0));
+	a = c ? mat3(1) : mat3(0); // mat3(1) is the identity matrix (diagonal only), not all ones.
+}
diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp
index af486e4..ccab208 100644
--- a/spirv_glsl.cpp
+++ b/spirv_glsl.cpp
@@ -6263,48 +6263,49 @@
 	if (!backend.use_constructor_splatting && value_type.vecsize != lerptype.vecsize)
 		return false;
 
+	// In SPIR-V 1.4, the only valid way to use matrices in a select is with a scalar condition.
+	// The matrix(scalar) constructor fills in only the diagonal, so handling this gets messy very quickly.
+	// Just avoid this case.
+	if (value_type.columns > 1)
+		return false;
+
 	// If our bool selects between 0 and 1, we can cast from bool instead, making our trivial constructor.
 	bool ret = true;
-	for (uint32_t col = 0; col < value_type.columns; col++)
+	for (uint32_t row = 0; ret && row < value_type.vecsize; row++)
 	{
-		for (uint32_t row = 0; row < value_type.vecsize; row++)
+		switch (type.basetype)
 		{
-			switch (type.basetype)
-			{
-			case SPIRType::Short:
-			case SPIRType::UShort:
-				ret = cleft->scalar_u16(col, row) == 0 && cright->scalar_u16(col, row) == 1;
-				break;
-
-			case SPIRType::Int:
-			case SPIRType::UInt:
-				ret = cleft->scalar(col, row) == 0 && cright->scalar(col, row) == 1;
-				break;
-
-			case SPIRType::Half:
-				ret = cleft->scalar_f16(col, row) == 0.0f && cright->scalar_f16(col, row) == 1.0f;
-				break;
-
-			case SPIRType::Float:
-				ret = cleft->scalar_f32(col, row) == 0.0f && cright->scalar_f32(col, row) == 1.0f;
-				break;
-
-			case SPIRType::Double:
-				ret = cleft->scalar_f64(col, row) == 0.0 && cright->scalar_f64(col, row) == 1.0;
-				break;
-
-			case SPIRType::Int64:
-			case SPIRType::UInt64:
-				ret = cleft->scalar_u64(col, row) == 0 && cright->scalar_u64(col, row) == 1;
-				break;
-
-			default:
-				return false;
-			}
-		}
-
-		if (!ret)
+		case SPIRType::Short:
+		case SPIRType::UShort:
+			ret = cleft->scalar_u16(0, row) == 0 && cright->scalar_u16(0, row) == 1;
 			break;
+
+		case SPIRType::Int:
+		case SPIRType::UInt:
+			ret = cleft->scalar(0, row) == 0 && cright->scalar(0, row) == 1;
+			break;
+
+		case SPIRType::Half:
+			ret = cleft->scalar_f16(0, row) == 0.0f && cright->scalar_f16(0, row) == 1.0f;
+			break;
+
+		case SPIRType::Float:
+			ret = cleft->scalar_f32(0, row) == 0.0f && cright->scalar_f32(0, row) == 1.0f;
+			break;
+
+		case SPIRType::Double:
+			ret = cleft->scalar_f64(0, row) == 0.0 && cright->scalar_f64(0, row) == 1.0;
+			break;
+
+		case SPIRType::Int64:
+		case SPIRType::UInt64:
+			ret = cleft->scalar_u64(0, row) == 0 && cright->scalar_u64(0, row) == 1;
+			break;
+
+		default:
+			ret = false;
+			break;
+		}
 	}
 
 	if (ret)
diff --git a/test_shaders.py b/test_shaders.py
index 7cb2850..eab10a6 100755
--- a/test_shaders.py
+++ b/test_shaders.py
@@ -186,7 +186,8 @@
     spirv_path = create_temporary()
     msl_path = create_temporary(os.path.basename(shader))
 
-    spirv_env = 'vulkan1.1spv1.4' if ('.spv14.' in shader) else 'vulkan1.1'
+    spirv_14 = '.spv14.' in shader
+    spirv_env = 'vulkan1.1spv1.4' if spirv_14 else 'vulkan1.1'
 
     spirv_cmd = [paths.spirv_as, '--target-env', spirv_env, '-o', spirv_path, shader]
     if '.preserve.' in shader:
@@ -195,7 +196,8 @@
     if spirv:
         subprocess.check_call(spirv_cmd)
     else:
-        subprocess.check_call([paths.glslang, '--amb' ,'--target-env', 'vulkan1.1', '-V', '-o', spirv_path, shader])
+        glslang_env = 'spirv1.4' if spirv_14 else 'vulkan1.1'
+        subprocess.check_call([paths.glslang, '--amb' ,'--target-env', glslang_env, '-V', '-o', spirv_path, shader])
 
     if opt and (not shader_is_invalid_spirv(shader)):
         if '.graphics-robust-access.' in shader:
@@ -432,7 +434,8 @@
     spirv_path = create_temporary()
     hlsl_path = create_temporary(os.path.basename(shader))
 
-    spirv_env = 'vulkan1.1spv1.4' if '.spv14.' in shader else 'vulkan1.1'
+    spirv_14 = '.spv14.' in shader
+    spirv_env = 'vulkan1.1spv1.4' if spirv_14 else 'vulkan1.1'
     spirv_cmd = [paths.spirv_as, '--target-env', spirv_env, '-o', spirv_path, shader]
     if '.preserve.' in shader:
         spirv_cmd.append('--preserve-numeric-ids')
@@ -440,7 +443,8 @@
     if spirv:
         subprocess.check_call(spirv_cmd)
     else:
-        subprocess.check_call([paths.glslang, '--amb', '--target-env', 'vulkan1.1', '-V', '-o', spirv_path, shader])
+        glslang_env = 'spirv1.4' if spirv_14 else 'vulkan1.1'
+        subprocess.check_call([paths.glslang, '--amb', '--target-env', glslang_env, '-V', '-o', spirv_path, shader])
 
     if opt and (not shader_is_invalid_spirv(hlsl_path)):
         subprocess.check_call([paths.spirv_opt, '--skip-validation', '-O', '-o', spirv_path, spirv_path])