Merge pull request #1773 from congyue1977/master

Support Metal 2.4 Intersection Query, Implement GL_EXT_ray_query.
diff --git a/reference/opt/shaders-msl/comp/ray-query.nocompat.spv14.vk.comp b/reference/opt/shaders-msl/comp/ray-query.nocompat.spv14.vk.comp
new file mode 100644
index 0000000..b03d524
--- /dev/null
+++ b/reference/opt/shaders-msl/comp/ray-query.nocompat.spv14.vk.comp
@@ -0,0 +1,92 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+#pragma clang diagnostic ignored "-Wmissing-braces"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+#if __METAL_VERSION__ >= 230
+#include <metal_raytracing>
+using namespace metal::raytracing;
+#endif
+
+using namespace metal;
+
+template<typename T, size_t Num>
+struct spvUnsafeArray
+{
+    T elements[Num ? Num : 1];
+    
+    thread T& operator [] (size_t pos) thread
+    {
+        return elements[pos];
+    }
+    constexpr const thread T& operator [] (size_t pos) const thread
+    {
+        return elements[pos];
+    }
+    
+    device T& operator [] (size_t pos) device
+    {
+        return elements[pos];
+    }
+    constexpr const device T& operator [] (size_t pos) const device
+    {
+        return elements[pos];
+    }
+    
+    constexpr const constant T& operator [] (size_t pos) const constant
+    {
+        return elements[pos];
+    }
+    
+    threadgroup T& operator [] (size_t pos) threadgroup
+    {
+        return elements[pos];
+    }
+    constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
+    {
+        return elements[pos];
+    }
+};
+
+struct Params
+{
+    uint ray_flags;
+    uint cull_mask;
+    char _m2_pad[8];
+    packed_float3 origin;
+    float tmin;
+    packed_float3 dir;
+    float tmax;
+    float thit;
+};
+
+kernel void main0(constant Params& _18 [[buffer(1)]], acceleration_structure<instancing> AS0 [[buffer(0)]], acceleration_structure<instancing> AS1 [[buffer(2)]])
+{
+    intersection_query<instancing, triangle_data> q;
+    q.reset(ray(_18.origin, _18.dir, _18.tmin, _18.tmax), AS0, intersection_params());
+    spvUnsafeArray<intersection_query<instancing, triangle_data>, 2> q2;
+    q2[1].reset(ray(_18.origin, _18.dir, _18.tmin, _18.tmax), AS1, intersection_params());
+    bool _63 = q.next();
+    q2[0].abort();
+    q.commit_bounding_box_intersection(_18.thit);
+    q2[1].commit_triangle_intersection();
+    float _71 = q.get_ray_min_distance();
+    float3 _74 = q.get_world_space_ray_direction();
+    float3 _75 = q.get_world_space_ray_origin();
+    uint _80 = uint(q2[1].get_committed_intersection_type());
+    uint _83 = uint(q2[0].get_candidate_intersection_type()) - 1;
+    bool _85 = q2[1].is_candidate_non_opaque_bounding_box();
+    float _87 = q2[1].get_committed_distance();
+    float _89 = q2[1].get_candidate_triangle_distance();
+    int _92 = q.get_committed_user_instance_id();
+    int _94 = q2[0].get_candidate_instance_id();
+    int _96 = q2[1].get_candidate_geometry_id();
+    int _97 = q.get_committed_primitive_id();
+    float2 _100 = q2[0].get_candidate_triangle_barycentric_coord();
+    bool _103 = q.is_committed_triangle_front_facing();
+    float3 _104 = q.get_candidate_ray_direction();
+    float3 _106 = q2[0].get_committed_ray_origin();
+    float4x3 _110 = q.get_candidate_object_to_world_transform();
+    float4x3 _112 = q2[1].get_committed_world_to_object_transform();
+}
+
diff --git a/reference/shaders-msl/comp/ray-query.spv14.vk.ios.msl24.comp b/reference/shaders-msl/comp/ray-query.spv14.vk.ios.msl24.comp
new file mode 100644
index 0000000..f73d491
--- /dev/null
+++ b/reference/shaders-msl/comp/ray-query.spv14.vk.ios.msl24.comp
@@ -0,0 +1,111 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+#pragma clang diagnostic ignored "-Wmissing-braces"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+#if __METAL_VERSION__ >= 230
+#include <metal_raytracing>
+using namespace metal::raytracing;
+#endif
+
+using namespace metal;
+
+template<typename T, size_t Num>
+struct spvUnsafeArray
+{
+    T elements[Num ? Num : 1];
+    
+    thread T& operator [] (size_t pos) thread
+    {
+        return elements[pos];
+    }
+    constexpr const thread T& operator [] (size_t pos) const thread
+    {
+        return elements[pos];
+    }
+    
+    device T& operator [] (size_t pos) device
+    {
+        return elements[pos];
+    }
+    constexpr const device T& operator [] (size_t pos) const device
+    {
+        return elements[pos];
+    }
+    
+    constexpr const constant T& operator [] (size_t pos) const constant
+    {
+        return elements[pos];
+    }
+    
+    threadgroup T& operator [] (size_t pos) threadgroup
+    {
+        return elements[pos];
+    }
+    constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
+    {
+        return elements[pos];
+    }
+};
+
+struct Params
+{
+    uint ray_flags;
+    uint cull_mask;
+    char _m2_pad[8];
+    packed_float3 origin;
+    float tmin;
+    packed_float3 dir;
+    float tmax;
+    float thit;
+};
+
+kernel void main0(constant Params& _18 [[buffer(1)]], acceleration_structure<instancing> AS0 [[buffer(0)]], acceleration_structure<instancing> AS1 [[buffer(2)]])
+{
+    intersection_query<instancing, triangle_data> q;
+    q.reset(ray(_18.origin, _18.dir, _18.tmin, _18.tmax), AS0, intersection_params());
+    spvUnsafeArray<intersection_query<instancing, triangle_data>, 2> q2;
+    q2[1].reset(ray(_18.origin, _18.dir, _18.tmin, _18.tmax), AS1, intersection_params());
+    bool _63 = q.next();
+    bool res = _63;
+    q2[0].abort();
+    q.commit_bounding_box_intersection(_18.thit);
+    q2[1].commit_triangle_intersection();
+    float _71 = q.get_ray_min_distance();
+    float fval = _71;
+    float3 _74 = q.get_world_space_ray_direction();
+    float3 fvals = _74;
+    float3 _75 = q.get_world_space_ray_origin();
+    fvals = _75;
+    uint _80 = uint(q2[1].get_committed_intersection_type());
+    uint type = _80;
+    uint _83 = uint(q2[0].get_candidate_intersection_type()) - 1;
+    type = _83;
+    bool _85 = q2[1].is_candidate_non_opaque_bounding_box();
+    res = _85;
+    float _87 = q2[1].get_committed_distance();
+    fval = _87;
+    float _89 = q2[1].get_candidate_triangle_distance();
+    fval = _89;
+    int _92 = q.get_committed_user_instance_id();
+    int ival = _92;
+    int _94 = q2[0].get_candidate_instance_id();
+    ival = _94;
+    int _96 = q2[1].get_candidate_geometry_id();
+    ival = _96;
+    int _97 = q.get_committed_primitive_id();
+    ival = _97;
+    float2 _100 = q2[0].get_candidate_triangle_barycentric_coord();
+    fvals = float3(_100.x, _100.y, fvals.z);
+    bool _103 = q.is_committed_triangle_front_facing();
+    res = _103;
+    float3 _104 = q.get_candidate_ray_direction();
+    fvals = _104;
+    float3 _106 = q2[0].get_committed_ray_origin();
+    fvals = _106;
+    float4x3 _110 = q.get_candidate_object_to_world_transform();
+    float4x3 matrices = _110;
+    float4x3 _112 = q2[1].get_committed_world_to_object_transform();
+    matrices = _112;
+}
+
diff --git a/shaders-msl/comp/ray-query.spv14.vk.ios.msl24.comp b/shaders-msl/comp/ray-query.spv14.vk.ios.msl24.comp
new file mode 100644
index 0000000..fba72ad
--- /dev/null
+++ b/shaders-msl/comp/ray-query.spv14.vk.ios.msl24.comp
@@ -0,0 +1,58 @@
+#version 460
+#extension GL_EXT_ray_query : require
+#extension GL_EXT_ray_tracing : require
+#extension GL_EXT_ray_flags_primitive_culling : require
+layout(primitive_culling);
+
+layout(set = 0, binding = 0) uniform accelerationStructureEXT AS0;
+layout(set = 0, binding = 1) uniform accelerationStructureEXT AS1;
+
+layout(set = 0, binding = 2) uniform Params
+{
+	uint ray_flags;
+	uint cull_mask;
+	vec3 origin;
+	float tmin;
+	vec3 dir;
+	float tmax;
+	float thit;
+};
+
+rayQueryEXT q2[2];
+
+void main()
+{
+	rayQueryEXT q;
+	bool res;
+	uint type;
+	float fval;
+	vec3 fvals;
+	int ival;
+	mat4x3 matrices;
+
+	rayQueryInitializeEXT(q, AS0, ray_flags, cull_mask, origin, tmin, dir, tmax);
+	rayQueryInitializeEXT(q2[1], AS1, ray_flags, cull_mask, origin, tmin, dir, tmax);
+
+	res = rayQueryProceedEXT(q);
+	rayQueryTerminateEXT(q2[0]);
+	rayQueryGenerateIntersectionEXT(q, thit);
+	rayQueryConfirmIntersectionEXT(q2[1]);
+	fval = rayQueryGetRayTMinEXT(q);
+	fvals = rayQueryGetWorldRayDirectionEXT(q);
+	fvals = rayQueryGetWorldRayOriginEXT(q);
+	type = rayQueryGetIntersectionTypeEXT(q2[1], true);
+	type = rayQueryGetIntersectionTypeEXT(q2[0], false);
+	res = rayQueryGetIntersectionCandidateAABBOpaqueEXT(q2[1]);
+	fval = rayQueryGetIntersectionTEXT(q2[1], true);
+	fval = rayQueryGetIntersectionTEXT(q2[1], false);
+	ival = rayQueryGetIntersectionInstanceCustomIndexEXT(q, true);
+	ival = rayQueryGetIntersectionInstanceIdEXT(q2[0], false);
+	ival = rayQueryGetIntersectionGeometryIndexEXT(q2[1], false);
+	ival = rayQueryGetIntersectionPrimitiveIndexEXT(q, true);
+	fvals.xy = rayQueryGetIntersectionBarycentricsEXT(q2[0], false);
+	res = rayQueryGetIntersectionFrontFaceEXT(q, true);
+	fvals = rayQueryGetIntersectionObjectRayDirectionEXT(q, false);
+	fvals = rayQueryGetIntersectionObjectRayOriginEXT(q2[0], true);
+	matrices = rayQueryGetIntersectionObjectToWorldEXT(q, false);
+	matrices = rayQueryGetIntersectionWorldToObjectEXT(q2[1], true);
+}
diff --git a/spirv_msl.cpp b/spirv_msl.cpp
index f12b1eb..fed295a 100644
--- a/spirv_msl.cpp
+++ b/spirv_msl.cpp
@@ -1512,6 +1512,14 @@
 	    (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
 	                          (need_subpass_input && !msl_options.use_framebuffer_fetch_subpasses))))
 		needs_sample_id = true;
+
+	if (is_intersection_query())
+	{
+		add_header_line("#if __METAL_VERSION__ >= 230");
+		add_header_line("#include <metal_raytracing>");
+		add_header_line("using namespace metal::raytracing;");
+		add_header_line("#endif");
+	}
 }
 
 // Move the Private and Workgroup global variables to the entry function.
@@ -8373,6 +8381,100 @@
 			SPIRV_CROSS_THROW("Raster order groups require MSL 2.0.");
 		break; // Nothing to do in the body
 
+	case OpConvertUToAccelerationStructureKHR:
+		SPIRV_CROSS_THROW("ConvertUToAccelerationStructure is not supported in MSL.");
+		break; // Nothing to do in the body
+	case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
+		SPIRV_CROSS_THROW("BindingTableRecordOffset is not supported in MSL.");
+		break; // Nothing to do in the body
+
+	case OpRayQueryInitializeKHR:
+	{
+		flush_variable_declaration(ops[0]);
+
+		statement(to_expression(ops[0]), ".reset(", "ray(", to_expression(ops[4]), ", ", to_expression(ops[6]), ", ",
+		          to_expression(ops[5]), ", ", to_expression(ops[7]), "), ", to_expression(ops[1]),
+		          ", intersection_params());");
+		break;
+	}
+	case OpRayQueryProceedKHR:
+	{
+		flush_variable_declaration(ops[2]); // ops[0]/ops[1] are the result type/id; the ray query is ops[2].
+		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".next()"), false);
+		break;
+	}
+#define MSL_RAY_QUERY_IS_CANDIDATE get<SPIRConstant>(ops[3]).scalar_i32() == 0
+
+#define MSL_RAY_QUERY_GET_OP(op, msl_op)                                                   \
+	case OpRayQueryGet##op##KHR:                                                           \
+		flush_variable_declaration(ops[2]);                                                \
+		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_" #msl_op "()"), false); \
+		break
+
+#define MSL_RAY_QUERY_OP_INNER2(op, msl_prefix, msl_op)                                                          \
+	case OpRayQueryGet##op##KHR:                                                                                 \
+		flush_variable_declaration(ops[2]);                                                                      \
+		if (MSL_RAY_QUERY_IS_CANDIDATE)                                                                          \
+			emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_candidate_" #msl_op "()"), false); \
+		else                                                                                                     \
+			emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_committed_" #msl_op "()"), false); \
+		break
+
+#define MSL_RAY_QUERY_GET_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .get, msl_op)
+#define MSL_RAY_QUERY_IS_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .is, msl_op)
+
+		MSL_RAY_QUERY_GET_OP(RayTMin, ray_min_distance);
+		MSL_RAY_QUERY_GET_OP(WorldRayOrigin, world_space_ray_origin);
+		MSL_RAY_QUERY_GET_OP(WorldRayDirection, world_space_ray_direction);
+		MSL_RAY_QUERY_GET_OP2(IntersectionInstanceId, instance_id);
+		MSL_RAY_QUERY_GET_OP2(IntersectionInstanceCustomIndex, user_instance_id);
+		MSL_RAY_QUERY_GET_OP2(IntersectionBarycentrics, triangle_barycentric_coord);
+		MSL_RAY_QUERY_GET_OP2(IntersectionPrimitiveIndex, primitive_id);
+		MSL_RAY_QUERY_GET_OP2(IntersectionGeometryIndex, geometry_id);
+		MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayOrigin, ray_origin);
+		MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayDirection, ray_direction);
+		MSL_RAY_QUERY_GET_OP2(IntersectionObjectToWorld, object_to_world_transform);
+		MSL_RAY_QUERY_GET_OP2(IntersectionWorldToObject, world_to_object_transform);
+		MSL_RAY_QUERY_IS_OP2(IntersectionFrontFace, triangle_front_facing);
+
+	case OpRayQueryGetIntersectionTypeKHR:
+		flush_variable_declaration(ops[2]);
+		if (MSL_RAY_QUERY_IS_CANDIDATE)
+			emit_op(ops[0], ops[1], join("uint(", to_expression(ops[2]), ".get_candidate_intersection_type()) - 1"),
+			        false);
+		else
+			emit_op(ops[0], ops[1], join("uint(", to_expression(ops[2]), ".get_committed_intersection_type())"), false);
+		break;
+	case OpRayQueryGetIntersectionTKHR:
+		flush_variable_declaration(ops[2]);
+		if (MSL_RAY_QUERY_IS_CANDIDATE)
+			emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_candidate_triangle_distance()"), false);
+		else
+			emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_committed_distance()"), false);
+		break;
+	case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR:
+	{ // NOTE(review): SPIR-V asks "is opaque" while the MSL call reports "non-opaque" - confirm whether to negate.
+		flush_variable_declaration(ops[2]); // ops[0]/ops[1] are the result type/id; the ray query is ops[2].
+		emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".is_candidate_non_opaque_bounding_box()"), false);
+		break;
+	}
+	case OpRayQueryConfirmIntersectionKHR:
+		flush_variable_declaration(ops[0]); // The ray query is the instruction's only operand.
+		statement(to_expression(ops[0]), ".commit_triangle_intersection();");
+		break;
+	case OpRayQueryGenerateIntersectionKHR:
+		flush_variable_declaration(ops[0]);
+		statement(to_expression(ops[0]), ".commit_bounding_box_intersection(", to_expression(ops[1]), ");");
+		break;
+	case OpRayQueryTerminateKHR:
+		flush_variable_declaration(ops[0]);
+		statement(to_expression(ops[0]), ".abort();");
+		break;
+#undef MSL_RAY_QUERY_GET_OP
+#undef MSL_RAY_QUERY_IS_CANDIDATE
+#undef MSL_RAY_QUERY_IS_OP2
+#undef MSL_RAY_QUERY_GET_OP2
+#undef MSL_RAY_QUERY_OP_INNER2
 	default:
 		CompilerGLSL::emit_instruction(instruction);
 		break;
@@ -11295,6 +11397,12 @@
 	        (msl_options.use_framebuffer_fetch_subpasses && need_subpass_input));
 }
 
+bool CompilerMSL::is_intersection_query() const // True when the module declares CapabilityRayQueryKHR.
+{
+	const auto &declared_caps = get_declared_capabilities();
+	return std::count(declared_caps.begin(), declared_caps.end(), CapabilityRayQueryKHR) > 0;
+}
+
 void CompilerMSL::entry_point_args_builtin(string &ep_args)
 {
 	// Builtin variables
@@ -11773,6 +11881,10 @@
 			}
 			break;
 		}
+		case SPIRType::AccelerationStructure:
+			ep_args += (ep_args.empty() ? "" : ", ") + type_to_glsl(type, var_id) + " " + r.name; // No leading comma for the first argument.
+			ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]";
+			break;
 		default:
 			if (!ep_args.empty())
 				ep_args += ", ";
@@ -13283,6 +13395,17 @@
 	case SPIRType::Double:
 		type_name = "double"; // Currently unsupported
 		break;
+	case SPIRType::AccelerationStructure:
+		if (msl_options.supports_msl_version(2, 4))
+			type_name = "acceleration_structure<instancing>";
+		else if (msl_options.supports_msl_version(2, 3))
+			type_name = "instance_acceleration_structure";
+		else
+			SPIRV_CROSS_THROW("Acceleration Structure Type is supported in MSL 2.3 and above.");
+		break;
+	case SPIRType::RayQuery:
+		type_name = "intersection_query<instancing, triangle_data>";
+		break;
 
 	default:
 		return "unknown_type";
diff --git a/spirv_msl.hpp b/spirv_msl.hpp
index 50d0686..ce43cd9 100644
--- a/spirv_msl.hpp
+++ b/spirv_msl.hpp
@@ -864,6 +864,7 @@
 	std::string to_swizzle_expression(uint32_t id);
 	std::string to_buffer_size_expression(uint32_t id);
 	bool is_sample_rate() const;
+	bool is_intersection_query() const;
 	bool is_direct_input_builtin(spv::BuiltIn builtin);
 	std::string builtin_qualifier(spv::BuiltIn builtin);
 	std::string builtin_type_decl(spv::BuiltIn builtin, uint32_t id = 0);
diff --git a/test_shaders.py b/test_shaders.py
index eab10a6..6963cdd 100755
--- a/test_shaders.py
+++ b/test_shaders.py
@@ -132,6 +132,8 @@
             return '-std=ios-metal2.2'
         elif '.msl23.' in shader:
             return '-std=ios-metal2.3'
+        elif '.msl24.' in shader:
+            return '-std=ios-metal2.4'
         elif '.msl11.' in shader:
             return '-std=ios-metal1.1'
         elif '.msl10.' in shader:
@@ -147,6 +149,8 @@
             return '-std=macos-metal2.2'
         elif '.msl23.' in shader:
             return '-std=macos-metal2.3'
+        elif '.msl24.' in shader:
+            return '-std=macos-metal2.4'
         elif '.msl11.' in shader:
             return '-std=macos-metal1.1'
         else:
@@ -161,6 +165,8 @@
         return '20200'
     elif '.msl23.' in shader:
         return '20300'
+    elif '.msl24.' in shader:
+        return '20400'
     elif '.msl11.' in shader:
         return '10100'
     else:
@@ -768,7 +774,8 @@
 
     shader_is_msl22 = 'msl22' in joined_path
     shader_is_msl23 = 'msl23' in joined_path
-    skip_validation = (shader_is_msl22 and (not args.msl22)) or (shader_is_msl23 and (not args.msl23))
+    shader_is_msl24 = 'msl24' in joined_path
+    skip_validation = (shader_is_msl22 and (not args.msl22)) or (shader_is_msl23 and (not args.msl23)) or (shader_is_msl24 and (not args.msl24))
     if '.invalid.' in joined_path:
         skip_validation = True
 
@@ -917,10 +924,12 @@
 
     args.msl22 = False
     args.msl23 = False
+    args.msl24 = False
     if args.msl:
         print_msl_compiler_version()
         args.msl22 = msl_compiler_supports_version('2.2')
         args.msl23 = msl_compiler_supports_version('2.3')
+        args.msl24 = msl_compiler_supports_version('2.4')
 
     backend = 'glsl'
     if (args.msl or args.metal):