Fix wrong detection of trivial_mix_op.

Effectively, only the last component of the select was considered, need
to correctly early out if any case is hit.
diff --git a/reference/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp b/reference/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..94aec45
--- /dev/null
+++ b/reference/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,15 @@
+static const uint3 gl_WorkGroupSize = uint3(1u, 1u, 1u);
+
+RWByteAddressBuffer _14 : register(u0);
+
+void comp_main()
+{
+    bool3 c = bool3(asfloat(_14.Load3(16)).x < 1.0f.xxx.x, asfloat(_14.Load3(16)).y < 1.0f.xxx.y, asfloat(_14.Load3(16)).z < 1.0f.xxx.z);
+    _14.Store3(0, asuint(float3(c.x ? float3(0.0f, 0.0f, 1.0f).x : float3(1.0f, 0.0f, 0.0f).x, c.y ? float3(0.0f, 0.0f, 1.0f).y : float3(1.0f, 0.0f, 0.0f).y, c.z ? float3(0.0f, 0.0f, 1.0f).z : float3(1.0f, 0.0f, 0.0f).z)));
+}
+
+[numthreads(1, 1, 1)]
+void main()
+{
+    comp_main();
+}
diff --git a/reference/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp b/reference/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..328b42c
--- /dev/null
+++ b/reference/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,19 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct A
+{
+    float3 a;
+    float3 b;
+};
+
+constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
+
+kernel void main0(device A& _14 [[buffer(0)]])
+{
+    bool3 c = _14.b < float3(1.0);
+    _14.a = select(float3(1.0, 0.0, 0.0), float3(0.0, 0.0, 1.0), c);
+}
+
diff --git a/reference/shaders-no-opt/comp/trivial-select-cast-vector.comp b/reference/shaders-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..92573ff
--- /dev/null
+++ b/reference/shaders-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,15 @@
+#version 450
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+
+layout(binding = 0, std430) buffer A
+{
+    vec3 a;
+    vec3 b;
+} _14;
+
+void main()
+{
+    bvec3 c = lessThan(_14.b, vec3(1.0));
+    _14.a = mix(vec3(1.0, 0.0, 0.0), vec3(0.0, 0.0, 1.0), c);
+}
+
diff --git a/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp b/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..c3e0922
--- /dev/null
+++ b/shaders-hlsl-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,14 @@
+#version 450
+layout(local_size_x = 1) in;
+
+layout(set = 0, binding = 0) buffer A
+{
+	vec3 a;
+	vec3 b;
+};
+
+void main()
+{
+	bvec3 c = lessThan(b, vec3(1.0));
+	a = mix(vec3(1, 0, 0), vec3(0, 0, 1), c);
+}
diff --git a/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp b/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..c3e0922
--- /dev/null
+++ b/shaders-msl-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,14 @@
+#version 450
+layout(local_size_x = 1) in;
+
+layout(set = 0, binding = 0) buffer A
+{
+	vec3 a;
+	vec3 b;
+};
+
+void main()
+{
+	bvec3 c = lessThan(b, vec3(1.0));
+	a = mix(vec3(1, 0, 0), vec3(0, 0, 1), c);
+}
diff --git a/shaders-no-opt/comp/trivial-select-cast-vector.comp b/shaders-no-opt/comp/trivial-select-cast-vector.comp
new file mode 100644
index 0000000..c3e0922
--- /dev/null
+++ b/shaders-no-opt/comp/trivial-select-cast-vector.comp
@@ -0,0 +1,14 @@
+#version 450
+layout(local_size_x = 1) in;
+
+layout(set = 0, binding = 0) buffer A
+{
+	vec3 a;
+	vec3 b;
+};
+
+void main()
+{
+	bvec3 c = lessThan(b, vec3(1.0));
+	a = mix(vec3(1, 0, 0), vec3(0, 0, 1), c);
+}
diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp
index af486e4..36697d8 100644
--- a/spirv_glsl.cpp
+++ b/spirv_glsl.cpp
@@ -6265,9 +6265,9 @@
 
 	// If our bool selects between 0 and 1, we can cast from bool instead, making our trivial constructor.
 	bool ret = true;
-	for (uint32_t col = 0; col < value_type.columns; col++)
+	for (uint32_t col = 0; ret && col < value_type.columns; col++)
 	{
-		for (uint32_t row = 0; row < value_type.vecsize; row++)
+		for (uint32_t row = 0; ret && row < value_type.vecsize; row++)
 		{
 			switch (type.basetype)
 			{
@@ -6299,12 +6299,10 @@
 				break;
 
 			default:
-				return false;
+				ret = false;
+				break;
 			}
 		}
-
-		if (!ret)
-			break;
 	}
 
 	if (ret)