xds: float LRU cache across interceptors (#11992)
diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java
index b5568ef..8ec02f4 100644
--- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java
+++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java
@@ -59,15 +59,17 @@
static final String TYPE_URL =
"type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig";
-
+ private final LruCache<String, CallCredentials> callCredentialsCache;
final String filterInstanceName;
- GcpAuthenticationFilter(String name) {
+ GcpAuthenticationFilter(String name, int cacheSize) {
filterInstanceName = checkNotNull(name, "name");
+ this.callCredentialsCache = new LruCache<>(cacheSize);
}
-
static final class Provider implements Filter.Provider {
+ private final int cacheSize = 10;
+
@Override
public String[] typeUrls() {
return new String[]{TYPE_URL};
@@ -80,7 +82,7 @@
@Override
public GcpAuthenticationFilter newInstance(String name) {
- return new GcpAuthenticationFilter(name);
+ return new GcpAuthenticationFilter(name, cacheSize);
}
@Override
@@ -101,11 +103,14 @@
// Validate cache_config
if (gcpAuthnProto.hasCacheConfig()) {
TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig();
- cacheSize = cacheConfig.getCacheSize().getValue();
- if (cacheSize == 0) {
- return ConfigOrError.fromError(
- "cache_config.cache_size must be greater than zero");
+ if (cacheConfig.hasCacheSize()) {
+ cacheSize = cacheConfig.getCacheSize().getValue();
+ if (cacheSize == 0) {
+ return ConfigOrError.fromError(
+ "cache_config.cache_size must be greater than zero");
+ }
}
+
// LruCache's size is an int and briefly exceeds its maximum size before evicting entries
cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1);
}
@@ -127,8 +132,9 @@
@Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) {
ComputeEngineCredentials credentials = ComputeEngineCredentials.create();
- LruCache<String, CallCredentials> callCredentialsCache =
- new LruCache<>(((GcpAuthenticationConfig) config).getCacheSize());
+ synchronized (callCredentialsCache) {
+ callCredentialsCache.resizeCache(((GcpAuthenticationConfig) config).getCacheSize());
+ }
return new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
@@ -254,23 +260,37 @@
private static final class LruCache<K, V> {
- private final Map<K, V> cache;
+ private Map<K, V> cache;
+ private int maxSize;
LruCache(int maxSize) {
- this.cache = new LinkedHashMap<K, V>(
- maxSize,
- 0.75f,
- true) {
- @Override
- protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
- return size() > maxSize;
- }
- };
+ this.maxSize = maxSize;
+ this.cache = createEvictingMap(maxSize);
}
V getOrInsert(K key, Function<K, V> create) {
return cache.computeIfAbsent(key, create);
}
+
+ private void resizeCache(int newSize) {
+ if (newSize >= maxSize) {
+ maxSize = newSize;
+ return;
+ }
+ Map<K, V> newCache = createEvictingMap(newSize);
+ maxSize = newSize;
+ newCache.putAll(cache);
+ cache = newCache;
+ }
+
+ private Map<K, V> createEvictingMap(int size) {
+ return new LinkedHashMap<K, V>(size, 0.75f, true) {
+ @Override
+ protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
+ return size() > LruCache.this.maxSize;
+ }
+ };
+ }
}
static class AudienceMetadataParser implements MetadataValueParser {
diff --git a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java
index a5e142b..d84d8c9 100644
--- a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java
+++ b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java
@@ -28,11 +28,13 @@
import static io.grpc.xds.XdsTestUtils.getWrrLbConfigAsMap;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import com.google.common.collect.ImmutableList;
@@ -89,8 +91,8 @@
@Test
public void testNewFilterInstancesPerFilterName() {
- assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1"))
- .isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1"));
+ assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10))
+ .isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10));
}
@Test
@@ -152,7 +154,7 @@
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
- GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+ GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = Mockito.mock(Channel.class);
@@ -181,7 +183,7 @@
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
- GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+ GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = Mockito.mock(Channel.class);
@@ -190,7 +192,7 @@
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
- verify(mockChannel, Mockito.times(2))
+ verify(mockChannel, times(2))
.newCall(eq(methodDescriptor), callOptionsCaptor.capture());
CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0);
CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1);
@@ -202,7 +204,7 @@
@Test
public void testClientInterceptor_withoutClusterSelectionKey() throws Exception {
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
- GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+ GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
@@ -233,7 +235,7 @@
Channel mockChannel = mock(Channel.class);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
- GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+ GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
@@ -244,7 +246,7 @@
@Test
public void testClientInterceptor_xdsConfigDoesNotExist() throws Exception {
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
- GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+ GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
@@ -274,7 +276,7 @@
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
- GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+ GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
@@ -300,7 +302,7 @@
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
- GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+ GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
@@ -329,7 +331,7 @@
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
- GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+ GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = Mockito.mock(Channel.class);
@@ -342,6 +344,115 @@
assertThat(clientCall.error.getDescription()).contains("GCP Authn found wrong type");
}
+ @Test
+ public void testLruCacheAcrossInterceptors() throws IOException, ResourceInvalidException {
+ XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig(
+ CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
+ XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
+ .setListener(ldsUpdate)
+ .setRoute(rdsUpdate)
+ .setVirtualHost(rdsUpdate.virtualHosts.get(0))
+ .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
+ CallOptions callOptionsWithXds = CallOptions.DEFAULT
+ .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
+ .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
+ GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2);
+ ClientInterceptor interceptor1
+ = filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null);
+ MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
+ Channel mockChannel = Mockito.mock(Channel.class);
+ ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class);
+
+ interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
+ verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture());
+ CallOptions capturedOptions1 = callOptionsCaptor.getAllValues().get(0);
+ assertNotNull(capturedOptions1.getCredentials());
+ ClientInterceptor interceptor2
+ = filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
+ interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
+ verify(mockChannel, times(2))
+ .newCall(eq(methodDescriptor), callOptionsCaptor.capture());
+ CallOptions capturedOptions2 = callOptionsCaptor.getAllValues().get(1);
+ assertNotNull(capturedOptions2.getCredentials());
+
+ assertSame(capturedOptions1.getCredentials(), capturedOptions2.getCredentials());
+ }
+
+ @Test
+ public void testLruCacheEvictionOnResize() throws IOException, ResourceInvalidException {
+ XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig(
+ CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
+ XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
+ .setListener(ldsUpdate)
+ .setRoute(rdsUpdate)
+ .setVirtualHost(rdsUpdate.virtualHosts.get(0))
+ .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
+ CallOptions callOptionsWithXds = CallOptions.DEFAULT
+ .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
+ .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
+ GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2);
+ MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
+
+ ClientInterceptor interceptor1 =
+ filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null);
+ Channel mockChannel1 = Mockito.mock(Channel.class);
+ ArgumentCaptor<CallOptions> captor = ArgumentCaptor.forClass(CallOptions.class);
+ interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel1);
+ verify(mockChannel1).newCall(eq(methodDescriptor), captor.capture());
+ CallOptions options1 = captor.getValue();
+ // This will recreate the cache with max size of 1 and copy the credential for audience1.
+ ClientInterceptor interceptor2 =
+ filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
+ Channel mockChannel2 = Mockito.mock(Channel.class);
+ interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel2);
+ verify(mockChannel2).newCall(eq(methodDescriptor), captor.capture());
+ CallOptions options2 = captor.getValue();
+
+ assertSame(options1.getCredentials(), options2.getCredentials());
+
+ clusterConfig = new XdsConfig.XdsClusterConfig(
+ CLUSTER_NAME, getCdsUpdate2(), new EndpointConfig(StatusOr.fromValue(edsUpdate)));
+ defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
+ .setListener(ldsUpdate)
+ .setRoute(rdsUpdate)
+ .setVirtualHost(rdsUpdate.virtualHosts.get(0))
+ .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
+ callOptionsWithXds = CallOptions.DEFAULT
+ .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
+ .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
+
+ // This will evict the credential for audience1 and add new credential for audience2
+ ClientInterceptor interceptor3 =
+ filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
+ Channel mockChannel3 = Mockito.mock(Channel.class);
+ interceptor3.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel3);
+ verify(mockChannel3).newCall(eq(methodDescriptor), captor.capture());
+ CallOptions options3 = captor.getValue();
+
+ assertNotSame(options1.getCredentials(), options3.getCredentials());
+
+ clusterConfig = new XdsConfig.XdsClusterConfig(
+ CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
+ defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
+ .setListener(ldsUpdate)
+ .setRoute(rdsUpdate)
+ .setVirtualHost(rdsUpdate.virtualHosts.get(0))
+ .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
+ callOptionsWithXds = CallOptions.DEFAULT
+ .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
+ .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
+
+ // This will create new credential for audience1 because it has been evicted
+ ClientInterceptor interceptor4 =
+ filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
+ Channel mockChannel4 = Mockito.mock(Channel.class);
+ interceptor4.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel4);
+ verify(mockChannel4).newCall(eq(methodDescriptor), captor.capture());
+ CallOptions options4 = captor.getValue();
+
+ assertNotSame(options1.getCredentials(), options4.getCredentials());
+ }
+
private static LdsUpdate getLdsUpdate() {
Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig(
serverName, RouterFilter.ROUTER_CONFIG);
@@ -384,6 +495,19 @@
}
}
+ private static CdsUpdate getCdsUpdate2() {
+ ImmutableMap.Builder<String, Object> parsedMetadata = ImmutableMap.builder();
+ parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("NEW_TEST_AUDIENCE"));
+ try {
+ CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds(
+ CLUSTER_NAME, EDS_NAME, null, null, null, null, false)
+ .lbPolicyConfig(getWrrLbConfigAsMap());
+ return cdsUpdate.parsedMetadata(parsedMetadata.build()).build();
+ } catch (IOException ex) {
+ return null;
+ }
+ }
+
private static CdsUpdate getCdsUpdateWithIncorrectAudienceWrapper() throws IOException {
ImmutableMap.Builder<String, Object> parsedMetadata = ImmutableMap.builder();
parsedMetadata.put("FILTER_INSTANCE_NAME", "TEST_AUDIENCE");