ac,aco: move gfx10 ngg prim count zero workaround to nir

This simplifies both the LLVM and ACO backends and removes the now-unnecessary
workaround code in paths where the primitive count is known to be non-zero.
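
Both backends now build the GS_ALLOC_REQ message payload directly; as a
sketch (assuming 32-bit counts, variable names illustrative only), the
packing is:

    /* bits 0..10: vertices in group, bits 12..22: primitives in group */
    m0_payload = (prim_cnt << 12) | vtx_cnt;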

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Signed-off-by: Qiang Yu <yuq825@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22381>
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 1dcdda9..27d9ca4 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -448,6 +448,50 @@
 }
 
 static void
+alloc_vertices_and_primitives_gfx10_workaround(nir_builder *b,
+                                               nir_ssa_def *num_vtx,
+                                               nir_ssa_def *num_prim)
+{
+   /* HW workaround for a GPU hang with 100% culling on GFX10.
+    * We always have to export at least 1 primitive.
+    * Export a degenerate triangle using vertex 0 for all 3 vertices.
+    *
+    * NOTE: We rely on the caller to set the vertex count also to 0 when the primitive count is 0.
+    */
+   nir_ssa_def *is_prim_cnt_0 = nir_ieq_imm(b, num_prim, 0);
+   nir_if *if_prim_cnt_0 = nir_push_if(b, is_prim_cnt_0);
+   {
+      nir_ssa_def *one = nir_imm_int(b, 1);
+      nir_alloc_vertices_and_primitives_amd(b, one, one);
+
+      nir_ssa_def *tid = nir_load_subgroup_invocation(b);
+      nir_ssa_def *is_thread_0 = nir_ieq_imm(b, tid, 0);
+      nir_if *if_thread_0 = nir_push_if(b, is_thread_0);
+      {
+         /* The vertex indices are 0, 0, 0. */
+         nir_export_amd(b, nir_imm_zero(b, 4, 32),
+                        .base = V_008DFC_SQ_EXP_PRIM,
+                        .flags = AC_EXP_FLAG_DONE,
+                        .write_mask = 1);
+
+         /* The HW culls primitives with NaN. -1 is also a NaN and saves
+          * a dword in the binary code by being inlined as a constant.
+          */
+         nir_export_amd(b, nir_imm_ivec4(b, -1, -1, -1, -1),
+                        .base = V_008DFC_SQ_EXP_POS,
+                        .flags = AC_EXP_FLAG_DONE,
+                        .write_mask = 0xf);
+      }
+      nir_pop_if(b, if_thread_0);
+   }
+   nir_push_else(b, if_prim_cnt_0);
+   {
+      nir_alloc_vertices_and_primitives_amd(b, num_vtx, num_prim);
+   }
+   nir_pop_if(b, if_prim_cnt_0);
+}
+
+static void
 ngg_nogs_init_vertex_indices_vars(nir_builder *b, nir_function_impl *impl, lower_ngg_nogs_state *s)
 {
    for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v) {
@@ -1565,7 +1609,13 @@
       nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_int(b, 0)));
       {
          /* Tell the final vertex and primitive count to the HW. */
-         nir_alloc_vertices_and_primitives_amd(b, num_live_vertices_in_workgroup, num_exported_prims);
+         if (s->options->gfx_level == GFX10) {
+            alloc_vertices_and_primitives_gfx10_workaround(
+               b, num_live_vertices_in_workgroup, num_exported_prims);
+         } else {
+            nir_alloc_vertices_and_primitives_amd(
+               b, num_live_vertices_in_workgroup, num_exported_prims);
+         }
       }
       nir_pop_if(b, if_wave_0);
 
@@ -3309,7 +3359,12 @@
 
    /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
    nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
-   nir_alloc_vertices_and_primitives_amd(b, workgroup_num_vertices, max_prmcnt);
+   {
+      if (s->options->gfx_level == GFX10)
+         alloc_vertices_and_primitives_gfx10_workaround(b, workgroup_num_vertices, max_prmcnt);
+      else
+         nir_alloc_vertices_and_primitives_amd(b, workgroup_num_vertices, max_prmcnt);
+   }
    nir_pop_if(b, if_wave_0);
 
    /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index 38dedd5..d58bdbb 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -8053,7 +8053,6 @@
 
 Temp merged_wave_info_to_mask(isel_context* ctx, unsigned i);
 Temp lanecount_to_mask(isel_context* ctx, Temp count);
-void ngg_emit_sendmsg_gs_alloc_req(isel_context* ctx, Temp vtx_cnt, Temp prm_cnt);
 
 Temp
 get_interp_param(isel_context* ctx, nir_intrinsic_op intrin,
@@ -8963,7 +8962,18 @@
       assert(ctx->stage.hw == HWStage::NGG);
       Temp num_vertices = get_ssa_temp(ctx, instr->src[0].ssa);
       Temp num_primitives = get_ssa_temp(ctx, instr->src[1].ssa);
-      ngg_emit_sendmsg_gs_alloc_req(ctx, num_vertices, num_primitives);
+
+      /* Put the number of vertices and primitives into m0 for the GS_ALLOC_REQ */
+      Temp tmp =
+         bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc),
+                  num_primitives, Operand::c32(12u));
+      tmp = bld.sop2(aco_opcode::s_or_b32, bld.m0(bld.def(s1)), bld.def(s1, scc),
+                     tmp, num_vertices);
+
+      /* Request the SPI to allocate space for the primitives and vertices
+       * that will be exported by the threadgroup.
+       */
+      bld.sopp(aco_opcode::s_sendmsg, bld.m0(tmp), -1, sendmsg_gs_alloc_req);
       break;
    }
    case nir_intrinsic_gds_atomic_add_amd: {
@@ -11430,70 +11440,6 @@
    return lanecount_to_mask(ctx, count);
 }
 
-void
-ngg_emit_sendmsg_gs_alloc_req(isel_context* ctx, Temp vtx_cnt, Temp prm_cnt)
-{
-   assert(vtx_cnt.id() && prm_cnt.id());
-
-   Builder bld(ctx->program, ctx->block);
-   Temp prm_cnt_0;
-
-   if (ctx->program->gfx_level == GFX10 &&
-       (ctx->stage.has(SWStage::GS) || ctx->program->info.has_ngg_culling)) {
-      /* Navi 1x workaround: check whether the workgroup has no output.
-       * If so, change the number of exported vertices and primitives to 1.
-       */
-      prm_cnt_0 = bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), prm_cnt, Operand::zero());
-      prm_cnt = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), Operand::c32(1u), prm_cnt,
-                         bld.scc(prm_cnt_0));
-      vtx_cnt = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), Operand::c32(1u), vtx_cnt,
-                         bld.scc(prm_cnt_0));
-   }
-
-   /* Put the number of vertices and primitives into m0 for the GS_ALLOC_REQ */
-   Temp tmp =
-      bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), prm_cnt, Operand::c32(12u));
-   tmp = bld.sop2(aco_opcode::s_or_b32, bld.m0(bld.def(s1)), bld.def(s1, scc), tmp, vtx_cnt);
-
-   /* Request the SPI to allocate space for the primitives and vertices
-    * that will be exported by the threadgroup.
-    */
-   bld.sopp(aco_opcode::s_sendmsg, bld.m0(tmp), -1, sendmsg_gs_alloc_req);
-
-   if (prm_cnt_0.id()) {
-      /* Navi 1x workaround: export a triangle with NaN coordinates when NGG has no output.
-       * It can't have all-zero positions because that would render an undesired pixel with
-       * conservative rasterization.
-       */
-      Temp first_lane = bld.sop1(Builder::s_ff1_i32, bld.def(s1), Operand(exec, bld.lm));
-      Temp cond = bld.sop2(Builder::s_lshl, bld.def(bld.lm), bld.def(s1, scc),
-                           Operand::c32_or_c64(1u, ctx->program->wave_size == 64), first_lane);
-      cond = bld.sop2(Builder::s_cselect, bld.def(bld.lm), cond,
-                      Operand::zero(ctx->program->wave_size == 64 ? 8 : 4), bld.scc(prm_cnt_0));
-
-      if_context ic_prim_0;
-      begin_divergent_if_then(ctx, &ic_prim_0, cond);
-      bld.reset(ctx->block);
-      ctx->block->kind |= block_kind_export_end;
-
-      /* Use zero: means that it's a triangle whose every vertex index is 0. */
-      Temp zero = bld.copy(bld.def(v1), Operand::zero());
-      /* Use NaN for the coordinates, so that the rasterizer allways culls it.  */
-      Temp nan_coord = bld.copy(bld.def(v1), Operand::c32(-1u));
-
-      bld.exp(aco_opcode::exp, zero, Operand(v1), Operand(v1), Operand(v1), 1 /* enabled mask */,
-              V_008DFC_SQ_EXP_PRIM /* dest */, false /* compressed */, true /* done */,
-              false /* valid mask */);
-      bld.exp(aco_opcode::exp, nan_coord, nan_coord, nan_coord, nan_coord, 0xf /* enabled mask */,
-              V_008DFC_SQ_EXP_POS /* dest */, false /* compressed */, true /* done */,
-              true /* valid mask */);
-
-      begin_divergent_if_else(ctx, &ic_prim_0);
-      end_divergent_if(ctx, &ic_prim_0);
-      bld.reset(ctx->block);
-   }
-}
-
 } /* end namespace */
 
 void
diff --git a/src/amd/llvm/ac_llvm_build.c b/src/amd/llvm/ac_llvm_build.c
index c1b86bc..dde69a0 100644
--- a/src/amd/llvm/ac_llvm_build.c
+++ b/src/amd/llvm/ac_llvm_build.c
@@ -3970,65 +3970,6 @@
    args->enabled_channels = mask;
 }
 
-/* Send GS Alloc Req message from the first wave of the group to SPI.
- * Message payload is:
- * - bits 0..10: vertices in group
- * - bits 12..22: primitives in group
- */
-void ac_build_sendmsg_gs_alloc_req(struct ac_llvm_context *ctx, LLVMValueRef wave_id,
-                                   LLVMValueRef vtx_cnt, LLVMValueRef prim_cnt)
-{
-   LLVMBuilderRef builder = ctx->builder;
-
-   if (wave_id)
-      ac_build_ifcc(ctx, LLVMBuildICmp(builder, LLVMIntEQ, wave_id, ctx->i32_0, ""), 5020);
-
-   /* HW workaround for a GPU hang with 100% culling on GFX10.
-    * We always have to export at least 1 primitive.
-    * Export a degenerate triangle using vertex 0 for all 3 vertices.
-    * 
-    * NOTE: We rely on the caller to set the vertex count also to 0 when the primitive count is 0.
-    */
-   if (ctx->gfx_level == GFX10) {
-      ac_build_ifcc(ctx, LLVMBuildICmp(builder, LLVMIntEQ, ac_get_thread_id(ctx), ctx->i32_0, ""), 5021);
-      LLVMValueRef prim_cnt_is_0 = LLVMBuildICmp(builder, LLVMIntEQ, prim_cnt, ctx->i32_0, "");
-      ac_build_ifcc(ctx, prim_cnt_is_0, 5022);
-      {
-         LLVMValueRef x = LLVMBuildShl(builder, ctx->i32_1, LLVMConstInt(ctx->i32, 12, false), "");
-         x = LLVMBuildOr(builder, x, ctx->i32_1, "");
-         ac_build_sendmsg(ctx, AC_SENDMSG_GS_ALLOC_REQ, x);
-
-         /* The vertex indices are 0, 0, 0. */
-         struct ac_ngg_prim prim = {0};
-         prim.passthrough = ctx->i32_0;
-
-         /* The HW culls primitives with NaN. */
-         struct ac_export_args pos = {0};
-         pos.out[0] = pos.out[1] = pos.out[2] = pos.out[3] = LLVMConstReal(ctx->f32, NAN);
-         pos.target = V_008DFC_SQ_EXP_POS;
-         pos.enabled_channels = 0xf;
-         pos.done = true;
-
-         ac_build_export_prim(ctx, &prim);
-         ac_build_export(ctx, &pos);
-      }
-      ac_build_else(ctx, 5022);
-   }
-
-   LLVMValueRef x = LLVMBuildShl(builder, prim_cnt, LLVMConstInt(ctx->i32, 12, false), "");
-   x = LLVMBuildOr(builder, x, vtx_cnt, "");
-   ac_build_sendmsg(ctx, AC_SENDMSG_GS_ALLOC_REQ, x);
-
-   if (ctx->gfx_level == GFX10) {
-      ac_build_endif(ctx, 5022);
-      ac_build_endif(ctx, 5021);
-   }
-
-   if (wave_id)
-      ac_build_endif(ctx, 5020);
-}
-
-
 LLVMValueRef ac_pack_edgeflags_for_export(struct ac_llvm_context *ctx,
                                           const struct ac_shader_args *args)
 {
@@ -4044,53 +3985,6 @@
    return LLVMBuildAnd(ctx->builder, tmp, LLVMConstInt(ctx->i32, 0x20080200, 0), "");
 }
 
-LLVMValueRef ac_pack_prim_export(struct ac_llvm_context *ctx, const struct ac_ngg_prim *prim)
-{
-   /* The prim export format is:
-    *  - bits 0..8: index 0
-    *  - bit 9: edge flag 0
-    *  - bits 10..18: index 1
-    *  - bit 19: edge flag 1
-    *  - bits 20..28: index 2
-    *  - bit 29: edge flag 2
-    *  - bit 31: null primitive (skip)
-    */
-   LLVMBuilderRef builder = ctx->builder;
-   LLVMValueRef tmp = LLVMBuildZExt(builder, prim->isnull, ctx->i32, "");
-   LLVMValueRef result = LLVMBuildShl(builder, tmp, LLVMConstInt(ctx->i32, 31, false), "");
-   result = LLVMBuildOr(ctx->builder, result, prim->edgeflags, "");
-
-   for (unsigned i = 0; i < prim->num_vertices; ++i) {
-      tmp = LLVMBuildShl(builder, prim->index[i], LLVMConstInt(ctx->i32, 10 * i, false), "");
-      result = LLVMBuildOr(builder, result, tmp, "");
-   }
-   return result;
-}
-
-void ac_build_export_prim(struct ac_llvm_context *ctx, const struct ac_ngg_prim *prim)
-{
-   struct ac_export_args args;
-
-   if (prim->passthrough) {
-      args.out[0] = prim->passthrough;
-   } else {
-      args.out[0] = ac_pack_prim_export(ctx, prim);
-   }
-
-   args.out[0] = LLVMBuildBitCast(ctx->builder, args.out[0], ctx->f32, "");
-   args.out[1] = LLVMGetUndef(ctx->f32);
-   args.out[2] = LLVMGetUndef(ctx->f32);
-   args.out[3] = LLVMGetUndef(ctx->f32);
-
-   args.target = V_008DFC_SQ_EXP_PRIM;
-   args.enabled_channels = 1;
-   args.done = true;
-   args.valid_mask = false;
-   args.compr = false;
-
-   ac_build_export(ctx, &args);
-}
-
 static LLVMTypeRef arg_llvm_type(enum ac_arg_type type, unsigned size, struct ac_llvm_context *ctx)
 {
    LLVMTypeRef base;
diff --git a/src/amd/llvm/ac_llvm_build.h b/src/amd/llvm/ac_llvm_build.h
index 299fe8f..373ded5 100644
--- a/src/amd/llvm/ac_llvm_build.h
+++ b/src/amd/llvm/ac_llvm_build.h
@@ -552,9 +552,6 @@
                      LLVMValueRef samplemask, LLVMValueRef mrt0_alpha, bool is_last,
                      struct ac_export_args *args);
 
-void ac_build_sendmsg_gs_alloc_req(struct ac_llvm_context *ctx, LLVMValueRef wave_id,
-                                   LLVMValueRef vtx_cnt, LLVMValueRef prim_cnt);
-
 struct ac_ngg_prim {
    unsigned num_vertices;
    LLVMValueRef isnull;
@@ -565,8 +562,6 @@
 
 LLVMValueRef ac_pack_edgeflags_for_export(struct ac_llvm_context *ctx,
                                           const struct ac_shader_args *args);
-LLVMValueRef ac_pack_prim_export(struct ac_llvm_context *ctx, const struct ac_ngg_prim *prim);
-void ac_build_export_prim(struct ac_llvm_context *ctx, const struct ac_ngg_prim *prim);
 
 LLVMTypeRef ac_arg_type_to_pointee_type(struct ac_llvm_context *ctx, enum ac_arg_type type);
 
diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c
index e7db715..e464a55 100644
--- a/src/amd/llvm/ac_nir_to_llvm.c
+++ b/src/amd/llvm/ac_nir_to_llvm.c
@@ -4099,14 +4099,22 @@
    case nir_intrinsic_load_workgroup_num_input_primitives_amd:
       result = ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->gs_tg_info), 22, 9);
       break;
-   case nir_intrinsic_alloc_vertices_and_primitives_amd:
-      /* The caller should only call this conditionally for wave 0, so pass NULL to disable
-       * the wave 0 check inside this function.
+   case nir_intrinsic_alloc_vertices_and_primitives_amd: {
+      /* The caller should only call this conditionally for wave 0.
+       *
+       * Send GS Alloc Req message from the first wave of the group to SPI.
+       * Message payload is:
+       * - bits 0..10: vertices in group
+       * - bits 12..22: primitives in group
        */
-      ac_build_sendmsg_gs_alloc_req(&ctx->ac, NULL,
-                                    get_src(ctx, instr->src[0]),
-                                    get_src(ctx, instr->src[1]));
+      LLVMValueRef vtx_cnt = get_src(ctx, instr->src[0]);
+      LLVMValueRef prim_cnt = get_src(ctx, instr->src[1]);
+      LLVMValueRef msg = LLVMBuildShl(ctx->ac.builder, prim_cnt,
+                                      LLVMConstInt(ctx->ac.i32, 12, false), "");
+      msg = LLVMBuildOr(ctx->ac.builder, msg, vtx_cnt, "");
+      ac_build_sendmsg(&ctx->ac, AC_SENDMSG_GS_ALLOC_REQ, msg);
       break;
+   }
    case nir_intrinsic_overwrite_vs_arguments_amd:
       ctx->abi->vertex_id_replaced = get_src(ctx, instr->src[0]);
       ctx->abi->instance_id_replaced = get_src(ctx, instr->src[1]);