core: Report marshaller error for uncompressed size too large back to the client 2 (#12477)
Code mostly as in #12360 but trying to use `handleInternalError()` / `cancel()` as suggested by @ejona86
Fixes #11246
diff --git a/core/src/main/java/io/grpc/internal/CloseWithHeadersMarker.java b/core/src/main/java/io/grpc/internal/CloseWithHeadersMarker.java
new file mode 100644
index 0000000..376b9ed
--- /dev/null
+++ b/core/src/main/java/io/grpc/internal/CloseWithHeadersMarker.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright 2025 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.internal;
+
+import io.grpc.Status;
+
+/**
+ * Marker to be used for Status sent to {@link ServerStream#cancel(Status)} to signal that stream
+ * should be closed by sending headers.
+ */
+public class CloseWithHeadersMarker extends Throwable {
+ private static final long serialVersionUID = 0L;
+
+ @Override
+ public synchronized Throwable fillInStackTrace() {
+ return this;
+ }
+}
diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java
index e224384..1c1f76c 100644
--- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java
+++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java
@@ -280,6 +280,17 @@
}
/**
+ * Close the {@link ServerStream} because parsing request message failed.
+ * Similar to {@link #handleInternalError(Throwable)}.
+ */
+ private void handleParseError(StatusRuntimeException parseError) {
+ cancelled = true;
+ log.log(Level.WARNING, "Cancelling the stream because of parse error", parseError);
+ stream.cancel(parseError.getStatus().withCause(new CloseWithHeadersMarker()));
+ serverCallTracer.reportCallEnded(false); // error so always false
+ }
+
+ /**
* All of these callbacks are assumed to called on an application thread, and the caller is
* responsible for handling thrown exceptions.
*/
@@ -327,18 +338,23 @@
return;
}
- InputStream message;
+ InputStream message = null;
try {
while ((message = producer.next()) != null) {
+ ReqT parsed;
try {
- listener.onMessage(call.method.parseRequest(message));
- } catch (Throwable t) {
+ parsed = call.method.parseRequest(message);
+ } catch (StatusRuntimeException e) {
GrpcUtil.closeQuietly(message);
- throw t;
+ GrpcUtil.closeQuietly(producer);
+ call.handleParseError(e);
+ return;
}
message.close();
+ listener.onMessage(parsed);
}
} catch (Throwable t) {
+ GrpcUtil.closeQuietly(message);
GrpcUtil.closeQuietly(producer);
Throwables.throwIfUnchecked(t);
throw new RuntimeException(t);
diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java
index 7394c83..028f1ac 100644
--- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java
+++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java
@@ -48,9 +48,11 @@
import io.grpc.SecurityLevel;
import io.grpc.ServerCall;
import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl;
import io.perfmark.PerfMark;
import java.io.ByteArrayInputStream;
+import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import org.junit.Before;
@@ -69,6 +71,8 @@
@Mock private ServerStream stream;
@Mock private ServerCall.Listener<Long> callListener;
+ @Mock private StreamListener.MessageProducer messageProducer;
+ @Mock private InputStream message;
private final CallTracer serverCallTracer = CallTracer.getDefaultFactory().create();
private ServerCallImpl<Long, Long> call;
@@ -493,6 +497,44 @@
assertThat(e).hasMessageThat().isEqualTo("unexpected exception");
}
+ @Test
+ public void streamListener_statusRuntimeException() throws IOException {
+ MethodDescriptor<Long, Long> failingParseMethod = MethodDescriptor.<Long, Long>newBuilder()
+ .setType(MethodType.UNARY)
+ .setFullMethodName("service/method")
+ .setRequestMarshaller(new LongMarshaller() {
+ @Override
+ public Long parse(InputStream stream) {
+ throw new StatusRuntimeException(Status.RESOURCE_EXHAUSTED
+ .withDescription("Decompressed gRPC message exceeds maximum size"));
+ }
+ })
+ .setResponseMarshaller(new LongMarshaller())
+ .build();
+
+ call = new ServerCallImpl<>(stream, failingParseMethod, requestHeaders, context,
+ DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(),
+ serverCallTracer, PerfMark.createTag());
+
+ ServerStreamListenerImpl<Long> streamListener =
+ new ServerCallImpl.ServerStreamListenerImpl<>(call, callListener, context);
+
+ when(messageProducer.next()).thenReturn(message, (InputStream) null);
+ streamListener.messagesAvailable(messageProducer);
+ ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
+ verify(stream).cancel(statusCaptor.capture());
+ Status status = statusCaptor.getValue();
+ assertEquals(Status.Code.RESOURCE_EXHAUSTED, status.getCode());
+ assertEquals("Decompressed gRPC message exceeds maximum size", status.getDescription());
+
+ streamListener.halfClosed();
+ verify(callListener, never()).onHalfClose();
+
+ when(messageProducer.next()).thenReturn(message, (InputStream) null);
+ streamListener.messagesAvailable(messageProducer);
+ verify(callListener, never()).onMessage(any());
+ }
+
private static class LongMarshaller implements Marshaller<Long> {
@Override
public InputStream stream(Long value) {
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
index 5129528..f5cd111 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
@@ -2024,7 +2024,7 @@
}
}
- private static void assertCodeEquals(Status.Code expected, Status actual) {
+ protected static void assertCodeEquals(Status.Code expected, Status actual) {
assertWithMessage("Unexpected status: %s", actual).that(actual.getCode()).isEqualTo(expected);
}
diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java
index b969238..33cd624 100644
--- a/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java
+++ b/interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java
@@ -17,6 +17,7 @@
package io.grpc.testing.integration;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import com.google.protobuf.ByteString;
@@ -37,6 +38,8 @@
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
+import io.grpc.Status.Code;
+import io.grpc.StatusRuntimeException;
import io.grpc.internal.GrpcUtil;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalNettyServerBuilder;
@@ -53,7 +56,9 @@
import java.io.OutputStream;
import org.junit.Before;
import org.junit.BeforeClass;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -84,10 +89,16 @@
compressors.register(Codec.Identity.NONE);
}
+ @Rule
+ public final TestName currentTest = new TestName();
+
@Override
protected ServerBuilder<?> getServerBuilder() {
NettyServerBuilder builder = NettyServerBuilder.forPort(0, InsecureServerCredentials.create())
- .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
+ .maxInboundMessageSize(
+ DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME.equals(currentTest.getMethodName())
+ ? 1000
+ : AbstractInteropTest.MAX_MESSAGE_SIZE)
.compressorRegistry(compressors)
.decompressorRegistry(decompressors)
.intercept(new ServerInterceptor() {
@@ -126,6 +137,22 @@
assertTrue(FZIPPER.anyWritten);
}
+ private static final String DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME =
+ "decompressedMessageTooLong";
+
+ @Test
+ public void decompressedMessageTooLong() {
+ assertEquals(DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME, currentTest.getMethodName());
+ final SimpleRequest bigRequest = SimpleRequest.newBuilder()
+ .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[10_000])))
+ .build();
+ StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
+ () -> blockingStub.withCompression("gzip").unaryCall(bigRequest));
+ assertCodeEquals(Code.RESOURCE_EXHAUSTED, e.getStatus());
+ assertEquals("Decompressed gRPC message exceeds maximum size 1000",
+ e.getStatus().getDescription());
+ }
+
@Override
protected NettyChannelBuilder createChannelBuilder() {
NettyChannelBuilder builder = NettyChannelBuilder.forAddress(getListenAddress())
diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java
index 836f39d..681e649 100644
--- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java
+++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java
@@ -23,6 +23,7 @@
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.AbstractServerStream;
+import io.grpc.internal.CloseWithHeadersMarker;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.TransportTracer;
import io.grpc.internal.WritableBuffer;
@@ -130,7 +131,11 @@
@Override
public void cancel(Status status) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.cancel")) {
- writeQueue.enqueue(CancelServerStreamCommand.withReset(transportState(), status), true);
+ CancelServerStreamCommand cmd =
+ status.getCause() instanceof CloseWithHeadersMarker
+ ? CancelServerStreamCommand.withReason(transportState(), status)
+ : CancelServerStreamCommand.withReset(transportState(), status);
+ writeQueue.enqueue(cmd, true);
}
}
}