xds: float LRU cache across interceptors (#11992)

diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java
index b5568ef..8ec02f4 100644
--- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java
+++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java
@@ -59,15 +59,17 @@
 
   static final String TYPE_URL =
       "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig";
-
+  private final LruCache<String, CallCredentials> callCredentialsCache;
   final String filterInstanceName;
 
-  GcpAuthenticationFilter(String name) {
+  GcpAuthenticationFilter(String name, int cacheSize) {
     filterInstanceName = checkNotNull(name, "name");
+    this.callCredentialsCache = new LruCache<>(cacheSize);
   }
 
-
   static final class Provider implements Filter.Provider {
+    private final int cacheSize = 10;
+
     @Override
     public String[] typeUrls() {
       return new String[]{TYPE_URL};
@@ -80,7 +82,7 @@
 
     @Override
     public GcpAuthenticationFilter newInstance(String name) {
-      return new GcpAuthenticationFilter(name);
+      return new GcpAuthenticationFilter(name, cacheSize);
     }
 
     @Override
@@ -101,11 +103,14 @@
       // Validate cache_config
       if (gcpAuthnProto.hasCacheConfig()) {
         TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig();
-        cacheSize = cacheConfig.getCacheSize().getValue();
-        if (cacheSize == 0) {
-          return ConfigOrError.fromError(
-              "cache_config.cache_size must be greater than zero");
+        if (cacheConfig.hasCacheSize()) {
+          cacheSize = cacheConfig.getCacheSize().getValue();
+          if (cacheSize == 0) {
+            return ConfigOrError.fromError(
+                "cache_config.cache_size must be greater than zero");
+          }
         }
+
         // LruCache's size is an int and briefly exceeds its maximum size before evicting entries
         cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1);
       }
@@ -127,8 +132,9 @@
       @Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) {
 
     ComputeEngineCredentials credentials = ComputeEngineCredentials.create();
-    LruCache<String, CallCredentials> callCredentialsCache =
-        new LruCache<>(((GcpAuthenticationConfig) config).getCacheSize());
+    synchronized (callCredentialsCache) {
+      callCredentialsCache.resizeCache(((GcpAuthenticationConfig) config).getCacheSize());
+    }
     return new ClientInterceptor() {
       @Override
       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
@@ -254,23 +260,37 @@
 
   private static final class LruCache<K, V> {
 
-    private final Map<K, V> cache;
+    private Map<K, V> cache;
+    private int maxSize;
 
     LruCache(int maxSize) {
-      this.cache = new LinkedHashMap<K, V>(
-          maxSize,
-          0.75f,
-          true) {
-        @Override
-        protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
-          return size() > maxSize;
-        }
-      };
+      this.maxSize = maxSize;
+      this.cache = createEvictingMap(maxSize);
     }
 
     V getOrInsert(K key, Function<K, V> create) {
       return cache.computeIfAbsent(key, create);
     }
+
+    private void resizeCache(int newSize) {
+      if (newSize >= maxSize) {
+        maxSize = newSize;
+        return;
+      }
+      Map<K, V> newCache = createEvictingMap(newSize);
+      maxSize = newSize;
+      newCache.putAll(cache);
+      cache = newCache;
+    }
+
+    private Map<K, V> createEvictingMap(int size) {
+      return new LinkedHashMap<K, V>(size, 0.75f, true) {
+        @Override
+        protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
+          return size() > LruCache.this.maxSize;
+        }
+      };
+    }
   }
 
   static class AudienceMetadataParser implements MetadataValueParser {
diff --git a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java
index a5e142b..d84d8c9 100644
--- a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java
+++ b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java
@@ -28,11 +28,13 @@
 import static io.grpc.xds.XdsTestUtils.getWrrLbConfigAsMap;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
 import com.google.common.collect.ImmutableList;
@@ -89,8 +91,8 @@
 
   @Test
   public void testNewFilterInstancesPerFilterName() {
-    assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1"))
-        .isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1"));
+    assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10))
+        .isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10));
   }
 
   @Test
@@ -152,7 +154,7 @@
         .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
         .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
     GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
-    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
     ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
     MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
     Channel mockChannel = Mockito.mock(Channel.class);
@@ -181,7 +183,7 @@
         .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
         .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
     GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
-    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
     ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
     MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
     Channel mockChannel = Mockito.mock(Channel.class);
@@ -190,7 +192,7 @@
     interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
     interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
 
-    verify(mockChannel, Mockito.times(2))
+    verify(mockChannel, times(2))
         .newCall(eq(methodDescriptor), callOptionsCaptor.capture());
     CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0);
     CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1);
@@ -202,7 +204,7 @@
   @Test
   public void testClientInterceptor_withoutClusterSelectionKey() throws Exception {
     GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
-    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
     ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
     MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
     Channel mockChannel = mock(Channel.class);
@@ -233,7 +235,7 @@
     Channel mockChannel = mock(Channel.class);
 
     GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
-    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
     ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
     MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
     interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
@@ -244,7 +246,7 @@
   @Test
   public void testClientInterceptor_xdsConfigDoesNotExist() throws Exception {
     GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
-    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
     ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
     MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
     Channel mockChannel = mock(Channel.class);
@@ -274,7 +276,7 @@
         .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster")
         .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
     GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
-    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
     ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
     MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
     Channel mockChannel = mock(Channel.class);
@@ -300,7 +302,7 @@
         .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
         .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
     GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
-    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
     ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
     MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
     Channel mockChannel = mock(Channel.class);
@@ -329,7 +331,7 @@
         .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
         .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
     GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
-    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
+    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
     ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
     MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
     Channel mockChannel = Mockito.mock(Channel.class);
@@ -342,6 +344,115 @@
     assertThat(clientCall.error.getDescription()).contains("GCP Authn found wrong type");
   }
 
+  @Test
+  public void testLruCacheAcrossInterceptors() throws IOException, ResourceInvalidException {
+    XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig(
+        CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
+    XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
+        .setListener(ldsUpdate)
+        .setRoute(rdsUpdate)
+        .setVirtualHost(rdsUpdate.virtualHosts.get(0))
+        .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
+    CallOptions callOptionsWithXds = CallOptions.DEFAULT
+        .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
+        .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
+    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2);
+    ClientInterceptor interceptor1
+        = filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null);
+    MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
+    Channel mockChannel = Mockito.mock(Channel.class);
+    ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class);
+
+    interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
+    verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture());
+    CallOptions capturedOptions1 = callOptionsCaptor.getAllValues().get(0);
+    assertNotNull(capturedOptions1.getCredentials());
+    ClientInterceptor interceptor2
+        = filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
+    interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
+    verify(mockChannel, times(2))
+        .newCall(eq(methodDescriptor), callOptionsCaptor.capture());
+    CallOptions capturedOptions2 = callOptionsCaptor.getAllValues().get(1);
+    assertNotNull(capturedOptions2.getCredentials());
+
+    assertSame(capturedOptions1.getCredentials(), capturedOptions2.getCredentials());
+  }
+
+  @Test
+  public void testLruCacheEvictionOnResize() throws IOException, ResourceInvalidException {
+    XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig(
+        CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
+    XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
+        .setListener(ldsUpdate)
+        .setRoute(rdsUpdate)
+        .setVirtualHost(rdsUpdate.virtualHosts.get(0))
+        .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
+    CallOptions callOptionsWithXds = CallOptions.DEFAULT
+        .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
+        .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
+    GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2);
+    MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
+
+    ClientInterceptor interceptor1 =
+        filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null);
+    Channel mockChannel1 = Mockito.mock(Channel.class);
+    ArgumentCaptor<CallOptions> captor = ArgumentCaptor.forClass(CallOptions.class);
+    interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel1);
+    verify(mockChannel1).newCall(eq(methodDescriptor), captor.capture());
+    CallOptions options1 = captor.getValue();
+    // This will recreate the cache with max size of 1 and copy the credential for audience1.
+    ClientInterceptor interceptor2 =
+        filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
+    Channel mockChannel2 = Mockito.mock(Channel.class);
+    interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel2);
+    verify(mockChannel2).newCall(eq(methodDescriptor), captor.capture());
+    CallOptions options2 = captor.getValue();
+
+    assertSame(options1.getCredentials(), options2.getCredentials());
+
+    clusterConfig = new XdsConfig.XdsClusterConfig(
+        CLUSTER_NAME, getCdsUpdate2(), new EndpointConfig(StatusOr.fromValue(edsUpdate)));
+    defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
+        .setListener(ldsUpdate)
+        .setRoute(rdsUpdate)
+        .setVirtualHost(rdsUpdate.virtualHosts.get(0))
+        .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
+    callOptionsWithXds = CallOptions.DEFAULT
+        .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
+        .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
+
+    // This will evict the credential for audience1 and add new credential for audience2
+    ClientInterceptor interceptor3 =
+        filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
+    Channel mockChannel3 = Mockito.mock(Channel.class);
+    interceptor3.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel3);
+    verify(mockChannel3).newCall(eq(methodDescriptor), captor.capture());
+    CallOptions options3 = captor.getValue();
+
+    assertNotSame(options1.getCredentials(), options3.getCredentials());
+
+    clusterConfig = new XdsConfig.XdsClusterConfig(
+        CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
+    defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
+        .setListener(ldsUpdate)
+        .setRoute(rdsUpdate)
+        .setVirtualHost(rdsUpdate.virtualHosts.get(0))
+        .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
+    callOptionsWithXds = CallOptions.DEFAULT
+        .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
+        .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
+
+    // This will create new credential for audience1 because it has been evicted
+    ClientInterceptor interceptor4 =
+        filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
+    Channel mockChannel4 = Mockito.mock(Channel.class);
+    interceptor4.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel4);
+    verify(mockChannel4).newCall(eq(methodDescriptor), captor.capture());
+    CallOptions options4 = captor.getValue();
+
+    assertNotSame(options1.getCredentials(), options4.getCredentials());
+  }
+
   private static LdsUpdate getLdsUpdate() {
     Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig(
         serverName, RouterFilter.ROUTER_CONFIG);
@@ -384,6 +495,19 @@
     }
   }
 
+  private static CdsUpdate getCdsUpdate2() {
+    ImmutableMap.Builder<String, Object> parsedMetadata = ImmutableMap.builder();
+    parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("NEW_TEST_AUDIENCE"));
+    try {
+      CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds(
+              CLUSTER_NAME, EDS_NAME, null, null, null, null, false)
+          .lbPolicyConfig(getWrrLbConfigAsMap());
+      return cdsUpdate.parsedMetadata(parsedMetadata.build()).build();
+    } catch (IOException ex) {
+      return null;
+    }
+  }
+
   private static CdsUpdate getCdsUpdateWithIncorrectAudienceWrapper() throws IOException {
     ImmutableMap.Builder<String, Object> parsedMetadata = ImmutableMap.builder();
     parsedMetadata.put("FILTER_INSTANCE_NAME", "TEST_AUDIENCE");