core: Report marshaller error for uncompressed size too large back to the client 2 (#12477)

Code mostly as in #12360 but trying to use `handleInternalError()` / `cancel()` as suggested by @ejona86

Fixes #11246
diff --git a/core/src/main/java/io/grpc/internal/CloseWithHeadersMarker.java b/core/src/main/java/io/grpc/internal/CloseWithHeadersMarker.java
new file mode 100644
index 0000000..376b9ed
--- /dev/null
+++ b/core/src/main/java/io/grpc/internal/CloseWithHeadersMarker.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright 2025 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.internal;
+
+import io.grpc.Status;
+
+/**
+ * Marker to be used for Status sent to {@link ServerStream#cancel(Status)} to signal that stream
+ * should be closed by sending headers.
+ */
+public class CloseWithHeadersMarker extends Throwable {
+  private static final long serialVersionUID = 0L;
+
+  @Override
+  public synchronized Throwable fillInStackTrace() {
+    return this;
+  }
+}
diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java
index e224384..1c1f76c 100644
--- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java
+++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java
@@ -280,6 +280,17 @@
   }
 
   /**
+   * Close the {@link ServerStream} because parsing request message failed.
+   * Similar to {@link #handleInternalError(Throwable)}.
+   */
+  private void handleParseError(StatusRuntimeException parseError) {
+    cancelled = true;
+    log.log(Level.WARNING, "Cancelling the stream because of parse error", parseError);
+    stream.cancel(parseError.getStatus().withCause(new CloseWithHeadersMarker()));
+    serverCallTracer.reportCallEnded(false); // error so always false
+  }
+
+  /**
    * All of these callbacks are assumed to called on an application thread, and the caller is
    * responsible for handling thrown exceptions.
    */
@@ -327,18 +338,23 @@
         return;
       }
 
-      InputStream message;
+      InputStream message = null;
       try {
         while ((message = producer.next()) != null) {
+          ReqT parsed;
           try {
-            listener.onMessage(call.method.parseRequest(message));
-          } catch (Throwable t) {
+            parsed = call.method.parseRequest(message);
+          } catch (StatusRuntimeException e) {
             GrpcUtil.closeQuietly(message);
-            throw t;
+            GrpcUtil.closeQuietly(producer);
+            call.handleParseError(e);
+            return;
           }
           message.close();
+          listener.onMessage(parsed);
         }
       } catch (Throwable t) {
+        GrpcUtil.closeQuietly(message);
         GrpcUtil.closeQuietly(producer);
         Throwables.throwIfUnchecked(t);
         throw new RuntimeException(t);
diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java
index 7394c83..028f1ac 100644
--- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java
+++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java
@@ -48,9 +48,11 @@
 import io.grpc.SecurityLevel;
 import io.grpc.ServerCall;
 import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
 import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl;
 import io.perfmark.PerfMark;
 import java.io.ByteArrayInputStream;
+import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import org.junit.Before;
@@ -69,6 +71,8 @@
 
   @Mock private ServerStream stream;
   @Mock private ServerCall.Listener<Long> callListener;
+  @Mock private StreamListener.MessageProducer messageProducer;
+  @Mock private InputStream message;
 
   private final CallTracer serverCallTracer = CallTracer.getDefaultFactory().create();
   private ServerCallImpl<Long, Long> call;
@@ -493,6 +497,44 @@
     assertThat(e).hasMessageThat().isEqualTo("unexpected exception");
   }
 
+  @Test
+  public void streamListener_statusRuntimeException() throws IOException {
+    MethodDescriptor<Long, Long> failingParseMethod = MethodDescriptor.<Long, Long>newBuilder()
+        .setType(MethodType.UNARY)
+        .setFullMethodName("service/method")
+        .setRequestMarshaller(new LongMarshaller() {
+            @Override
+            public Long parse(InputStream stream) {
+              throw new StatusRuntimeException(Status.RESOURCE_EXHAUSTED
+                  .withDescription("Decompressed gRPC message exceeds maximum size"));
+            }
+        })
+        .setResponseMarshaller(new LongMarshaller())
+        .build();
+
+    call =  new ServerCallImpl<>(stream, failingParseMethod, requestHeaders, context,
+            DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(),
+            serverCallTracer, PerfMark.createTag());
+
+    ServerStreamListenerImpl<Long> streamListener =
+            new ServerCallImpl.ServerStreamListenerImpl<>(call, callListener, context);
+
+    when(messageProducer.next()).thenReturn(message, (InputStream) null);
+    streamListener.messagesAvailable(messageProducer);
+    ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
+    verify(stream).cancel(statusCaptor.capture());
+    Status status = statusCaptor.getValue();
+    assertEquals(Status.Code.RESOURCE_EXHAUSTED, status.getCode());
+    assertEquals("Decompressed gRPC message exceeds maximum size", status.getDescription());
+
+    streamListener.halfClosed();
+    verify(callListener, never()).onHalfClose();
+
+    when(messageProducer.next()).thenReturn(message, (InputStream) null);
+    streamListener.messagesAvailable(messageProducer);
+    verify(callListener, never()).onMessage(any());
+  }
+
   private static class LongMarshaller implements Marshaller<Long> {
     @Override
     public InputStream stream(Long value) {
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
index 5129528..f5cd111 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
@@ -2024,7 +2024,7 @@
     }
   }
 
-  private static void assertCodeEquals(Status.Code expected, Status actual) {
+  protected static void assertCodeEquals(Status.Code expected, Status actual) {
     assertWithMessage("Unexpected status: %s", actual).that(actual.getCode()).isEqualTo(expected);
   }
 
diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java
index b969238..33cd624 100644
--- a/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java
+++ b/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java
@@ -17,6 +17,7 @@
 package io.grpc.testing.integration;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 import com.google.protobuf.ByteString;
@@ -37,6 +38,8 @@
 import io.grpc.ServerCall.Listener;
 import io.grpc.ServerCallHandler;
 import io.grpc.ServerInterceptor;
+import io.grpc.Status.Code;
+import io.grpc.StatusRuntimeException;
 import io.grpc.internal.GrpcUtil;
 import io.grpc.netty.InternalNettyChannelBuilder;
 import io.grpc.netty.InternalNettyServerBuilder;
@@ -53,7 +56,9 @@
 import java.io.OutputStream;
 import org.junit.Before;
 import org.junit.BeforeClass;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TestName;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
@@ -84,10 +89,16 @@
     compressors.register(Codec.Identity.NONE);
   }
 
+  @Rule
+  public final TestName currentTest = new TestName();
+
   @Override
   protected ServerBuilder<?> getServerBuilder() {
     NettyServerBuilder builder = NettyServerBuilder.forPort(0, InsecureServerCredentials.create())
-        .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
+        .maxInboundMessageSize(
+            DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME.equals(currentTest.getMethodName())
+                ? 1000
+                : AbstractInteropTest.MAX_MESSAGE_SIZE)
         .compressorRegistry(compressors)
         .decompressorRegistry(decompressors)
         .intercept(new ServerInterceptor() {
@@ -126,6 +137,22 @@
     assertTrue(FZIPPER.anyWritten);
   }
 
+  private static final String DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME =
+      "decompressedMessageTooLong";
+
+  @Test
+  public void decompressedMessageTooLong() {
+    assertEquals(DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME, currentTest.getMethodName());
+    final SimpleRequest bigRequest = SimpleRequest.newBuilder()
+        .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[10_000])))
+        .build();
+    StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
+        () -> blockingStub.withCompression("gzip").unaryCall(bigRequest));
+    assertCodeEquals(Code.RESOURCE_EXHAUSTED, e.getStatus());
+    assertEquals("Decompressed gRPC message exceeds maximum size 1000",
+        e.getStatus().getDescription());
+  }
+
   @Override
   protected NettyChannelBuilder createChannelBuilder() {
     NettyChannelBuilder builder = NettyChannelBuilder.forAddress(getListenAddress())
diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java
index 836f39d..681e649 100644
--- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java
+++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java
@@ -23,6 +23,7 @@
 import io.grpc.Metadata;
 import io.grpc.Status;
 import io.grpc.internal.AbstractServerStream;
+import io.grpc.internal.CloseWithHeadersMarker;
 import io.grpc.internal.StatsTraceContext;
 import io.grpc.internal.TransportTracer;
 import io.grpc.internal.WritableBuffer;
@@ -130,7 +131,11 @@
     @Override
     public void cancel(Status status) {
       try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.cancel")) {
-        writeQueue.enqueue(CancelServerStreamCommand.withReset(transportState(), status), true);
+        CancelServerStreamCommand cmd =
+            status.getCause() instanceof CloseWithHeadersMarker
+                ? CancelServerStreamCommand.withReason(transportState(), status)
+                : CancelServerStreamCommand.withReset(transportState(), status);
+        writeQueue.enqueue(cmd, true);
       }
     }
   }