Merge pull request #2035 from KhronosGroup/fix-2032

HLSL: Improve support for VertexInfo aux struct.
diff --git a/main.cpp b/main.cpp
index 8b8ca7e..7f7bda4 100644
--- a/main.cpp
+++ b/main.cpp
@@ -710,7 +710,12 @@
 	bool msl = false;
 	bool hlsl = false;
 	bool hlsl_compat = false;
+
 	bool hlsl_support_nonzero_base = false;
+	bool hlsl_base_vertex_index_explicit_binding = false;
+	uint32_t hlsl_base_vertex_index_register_index = 0;
+	uint32_t hlsl_base_vertex_index_register_space = 0;
+
 	bool hlsl_force_storage_buffer_as_uav = false;
 	bool hlsl_nonwritable_uav_texture_as_srv = false;
 	bool hlsl_enable_16bit_types = false;
@@ -807,6 +812,7 @@
 	                "\t\tPointSize is ignored, and PointCoord returns (0.5, 0.5).\n"
 	                "\t[--hlsl-support-nonzero-basevertex-baseinstance]:\n\t\tSupport base vertex and base instance by emitting a special cbuffer declared as:\n"
 	                "\t\tcbuffer SPIRV_Cross_VertexInfo { int SPIRV_Cross_BaseVertex; int SPIRV_Cross_BaseInstance; };\n"
+	                "\t[--hlsl-basevertex-baseinstance-binding <register index> <register space>]:\n\t\tAssign a fixed binding to SPIRV_Cross_VertexInfo.\n"
 	                "\t[--hlsl-auto-binding (push, cbv, srv, uav, sampler, all)]\n"
 	                "\t\tDo not emit any : register(#) bindings for specific resource types, and rely on HLSL compiler to assign something.\n"
 	                "\t[--hlsl-force-storage-buffer-as-uav]:\n\t\tAlways emit SSBOs as UAVs, even when marked as read-only.\n"
@@ -1371,6 +1377,12 @@
 		hlsl_opts.flatten_matrix_vertex_input_semantics = args.hlsl_flatten_matrix_vertex_input_semantics;
 		hlsl->set_hlsl_options(hlsl_opts);
 		hlsl->set_resource_binding_flags(args.hlsl_binding_flags);
+		if (args.hlsl_base_vertex_index_explicit_binding)
+		{
+			hlsl->set_hlsl_aux_buffer_binding(HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE,
+			                                  args.hlsl_base_vertex_index_register_index,
+			                                  args.hlsl_base_vertex_index_register_space);
+		}
 	}
 
 	if (build_dummy_sampler)
@@ -1533,6 +1545,11 @@
 	cbs.add("--hlsl-enable-compat", [&args](CLIParser &) { args.hlsl_compat = true; });
 	cbs.add("--hlsl-support-nonzero-basevertex-baseinstance",
 	        [&args](CLIParser &) { args.hlsl_support_nonzero_base = true; });
+	cbs.add("--hlsl-basevertex-baseinstance-binding", [&args](CLIParser &parser) {
+		args.hlsl_base_vertex_index_explicit_binding = true;
+		args.hlsl_base_vertex_index_register_index = parser.next_uint();
+		args.hlsl_base_vertex_index_register_space = parser.next_uint();
+	});
 	cbs.add("--hlsl-auto-binding", [&args](CLIParser &parser) {
 		args.hlsl_binding_flags |= hlsl_resource_type_to_flag(parser.next_string());
 	});
diff --git a/reference/shaders-hlsl-no-opt/vert/base-instance.vert b/reference/shaders-hlsl-no-opt/vert/base-instance.vert
new file mode 100644
index 0000000..de31f2c
--- /dev/null
+++ b/reference/shaders-hlsl-no-opt/vert/base-instance.vert
@@ -0,0 +1,30 @@
+static float4 gl_Position;
+static int gl_BaseInstanceARB;
+cbuffer SPIRV_Cross_VertexInfo
+{
+    int SPIRV_Cross_BaseVertex;
+    int SPIRV_Cross_BaseInstance;
+};
+
+struct SPIRV_Cross_Input
+{
+};
+
+struct SPIRV_Cross_Output
+{
+    float4 gl_Position : SV_Position;
+};
+
+void vert_main()
+{
+    gl_Position = float(gl_BaseInstanceARB).xxxx;
+}
+
+SPIRV_Cross_Output main(SPIRV_Cross_Input stage_input)
+{
+    gl_BaseInstanceARB = SPIRV_Cross_BaseInstance;
+    vert_main();
+    SPIRV_Cross_Output stage_output;
+    stage_output.gl_Position = gl_Position;
+    return stage_output;
+}
diff --git a/reference/shaders-hlsl-no-opt/vert/base-vertex.vert b/reference/shaders-hlsl-no-opt/vert/base-vertex.vert
new file mode 100644
index 0000000..6b9b62b
--- /dev/null
+++ b/reference/shaders-hlsl-no-opt/vert/base-vertex.vert
@@ -0,0 +1,30 @@
+static float4 gl_Position;
+static int gl_BaseVertexARB;
+cbuffer SPIRV_Cross_VertexInfo
+{
+    int SPIRV_Cross_BaseVertex;
+    int SPIRV_Cross_BaseInstance;
+};
+
+struct SPIRV_Cross_Input
+{
+};
+
+struct SPIRV_Cross_Output
+{
+    float4 gl_Position : SV_Position;
+};
+
+void vert_main()
+{
+    gl_Position = float(gl_BaseVertexARB).xxxx;
+}
+
+SPIRV_Cross_Output main(SPIRV_Cross_Input stage_input)
+{
+    gl_BaseVertexARB = SPIRV_Cross_BaseVertex;
+    vert_main();
+    SPIRV_Cross_Output stage_output;
+    stage_output.gl_Position = gl_Position;
+    return stage_output;
+}
diff --git a/shaders-hlsl-no-opt/vert/base-instance.vert b/shaders-hlsl-no-opt/vert/base-instance.vert
new file mode 100644
index 0000000..20b686c
--- /dev/null
+++ b/shaders-hlsl-no-opt/vert/base-instance.vert
@@ -0,0 +1,7 @@
+#version 450
+#extension GL_ARB_shader_draw_parameters : require
+
+void main()
+{
+	gl_Position = vec4(gl_BaseInstanceARB); // HLSL output reads this from SPIRV_Cross_BaseInstance.
+}
diff --git a/shaders-hlsl-no-opt/vert/base-vertex.vert b/shaders-hlsl-no-opt/vert/base-vertex.vert
new file mode 100644
index 0000000..ef486c8
--- /dev/null
+++ b/shaders-hlsl-no-opt/vert/base-vertex.vert
@@ -0,0 +1,7 @@
+#version 450
+#extension GL_ARB_shader_draw_parameters : require
+
+void main()
+{
+	gl_Position = vec4(gl_BaseVertexARB); // HLSL output reads this from SPIRV_Cross_BaseVertex.
+}
diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp
index 919727d..1291f7e 100644
--- a/spirv_hlsl.cpp
+++ b/spirv_hlsl.cpp
@@ -748,6 +748,8 @@
 		case BuiltInSubgroupLeMask:
 		case BuiltInSubgroupGtMask:
 		case BuiltInSubgroupGeMask:
+		case BuiltInBaseVertex:
+		case BuiltInBaseInstance:
 			// Handled specially.
 			break;
 
@@ -1032,8 +1034,6 @@
 	Bitset builtins = active_input_builtins;
 	builtins.merge_or(active_output_builtins);
 
-	bool need_base_vertex_info = false;
-
 	std::unordered_map<uint32_t, ID> builtin_to_initializer;
 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
 		if (!is_builtin_variable(var) || var.storage != StorageClassOutput || !var.initializer)
@@ -1087,7 +1087,13 @@
 		case BuiltInInstanceIndex:
 			type = "int";
 			if (hlsl_options.support_nonzero_base_vertex_base_instance)
-				need_base_vertex_info = true;
+				base_vertex_info.used = true;
+			break;
+
+		case BuiltInBaseVertex:
+		case BuiltInBaseInstance:
+			type = "int";
+			base_vertex_info.used = true;
 			break;
 
 		case BuiltInInstanceId:
@@ -1187,9 +1193,17 @@
 		}
 	});
 
-	if (need_base_vertex_info)
+	if (base_vertex_info.used)
 	{
-		statement("cbuffer SPIRV_Cross_VertexInfo");
+		string binding_info;
+		if (base_vertex_info.explicit_binding)
+		{
+			binding_info = join(" : register(b", base_vertex_info.register_index);
+			if (base_vertex_info.register_space)
+				binding_info += join(", space", base_vertex_info.register_space);
+			binding_info += ")";
+		}
+		statement("cbuffer SPIRV_Cross_VertexInfo", binding_info);
 		begin_scope();
 		statement("int SPIRV_Cross_BaseVertex;");
 		statement("int SPIRV_Cross_BaseInstance;");
@@ -1198,6 +1212,30 @@
 	}
 }
 
+void CompilerHLSL::set_hlsl_aux_buffer_binding(HLSLAuxBinding binding, uint32_t register_index, uint32_t register_space)
+{
+	// Only the base vertex / base instance cbuffer can be bound explicitly; other values are ignored.
+	if (binding != HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE)
+		return;
+	base_vertex_info.register_index = register_index;
+	base_vertex_info.register_space = register_space;
+	base_vertex_info.explicit_binding = true;
+}
+
+void CompilerHLSL::unset_hlsl_aux_buffer_binding(HLSLAuxBinding binding)
+{
+	// Revert to an unbound cbuffer; the stored register index/space are left as-is.
+	if (binding != HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE)
+		return;
+	base_vertex_info.explicit_binding = false;
+}
+
+bool CompilerHLSL::is_hlsl_aux_buffer_binding_used(HLSLAuxBinding binding) const
+{
+	// 'used' is set while builtin variables are declared, so presumably query this after compiling.
+	return binding == HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE && base_vertex_info.used;
+}
+
 void CompilerHLSL::emit_composite_constants()
 {
 	// HLSL cannot declare structs or arrays inline, so we must move them out to
@@ -2612,6 +2650,14 @@
 				statement(builtin, " = int(stage_input.", builtin, ");");
 			break;
 
+		case BuiltInBaseVertex:
+			statement(builtin, " = SPIRV_Cross_BaseVertex;");
+			break;
+
+		case BuiltInBaseInstance:
+			statement(builtin, " = SPIRV_Cross_BaseInstance;");
+			break;
+
 		case BuiltInInstanceId:
 			// D3D semantics are uint, but shader wants int.
 			statement(builtin, " = int(stage_input.", builtin, ");");
diff --git a/spirv_hlsl.hpp b/spirv_hlsl.hpp
index f01bcf9..41ce73b 100644
--- a/spirv_hlsl.hpp
+++ b/spirv_hlsl.hpp
@@ -98,6 +98,11 @@
 	} cbv, uav, srv, sampler;
 };
 
+enum HLSLAuxBinding // Auxiliary cbuffers that SPIRV-Cross may declare on its own in HLSL output.
+{
+	HLSL_AUX_BINDING_BASE_VERTEX_INSTANCE = 0 // cbuffer SPIRV_Cross_VertexInfo (SPIRV_Cross_BaseVertex, SPIRV_Cross_BaseInstance).
+};
+
 class CompilerHLSL : public CompilerGLSL
 {
 public:
@@ -211,6 +216,11 @@
 	// Controls which storage buffer bindings will be forced to be declared as UAVs.
 	void set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set, uint32_t binding);
 
+	// Aux cbuffers (e.g. SPIRV_Cross_VertexInfo) are declared without a register() binding unless one is set here.
+	void set_hlsl_aux_buffer_binding(HLSLAuxBinding binding, uint32_t register_index, uint32_t register_space); // register(b#, space#)
+	void unset_hlsl_aux_buffer_binding(HLSLAuxBinding binding); // Back to no explicit register() binding.
+	bool is_hlsl_aux_buffer_binding_used(HLSLAuxBinding binding) const; // True once the buffer was emitted; query after compiling.
+
 private:
 	std::string type_to_glsl(const SPIRType &type, uint32_t id = 0) override;
 	std::string image_type_hlsl(const SPIRType &type, uint32_t id);
@@ -373,6 +383,14 @@
 
 	std::unordered_set<SetBindingPair, InternalHasher> force_uav_buffer_bindings;
 
+	struct // State for the SPIRV_Cross_VertexInfo aux cbuffer.
+	{
+		uint32_t register_index = 0; // b# in ": register(b#, space#)" when explicit_binding is set.
+		uint32_t register_space = 0; // space#; omitted from the register() clause when 0.
+		bool explicit_binding = false; // Emit a register() binding on cbuffer SPIRV_Cross_VertexInfo.
+		bool used = false; // Set during builtin emission when a builtin needs SPIRV_Cross_VertexInfo.
+	} base_vertex_info;
+
 	// Returns true for BuiltInSampleMask because gl_SampleMask[] is an array in SPIR-V, but SV_Coverage is a scalar in HLSL.
 	bool builtin_translates_to_nonarray(spv::BuiltIn builtin) const override;