pemfile: Move file watcher plugin from advancedtls to gRPC (#3981)

diff --git a/credentials/tls/certprovider/pemfile/watcher.go b/credentials/tls/certprovider/pemfile/watcher.go
new file mode 100644
index 0000000..29ea8b2
--- /dev/null
+++ b/credentials/tls/certprovider/pemfile/watcher.go
@@ -0,0 +1,252 @@
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package pemfile provides a file watching certificate provider plugin
+// implementation which works for files with PEM contents.
+//
+// Experimental
+//
+// Notice: All APIs in this package are experimental and may be removed in a
+// later release.
+package pemfile
+
+import (
+	"bytes"
+	"context"
+	"crypto/tls"
+	"crypto/x509"
+	"fmt"
+	"io/ioutil"
+	"time"
+
+	"google.golang.org/grpc/credentials/tls/certprovider"
+	"google.golang.org/grpc/grpclog"
+)
+
+const (
+	defaultCertRefreshDuration = 1 * time.Hour
+	defaultRootRefreshDuration = 2 * time.Hour
+)
+
+var (
+	// For overriding from unit tests.
+	newDistributor = func() distributor { return certprovider.NewDistributor() }
+
+	logger = grpclog.Component("pemfile")
+)
+
+// Options configures a certificate provider plugin that watches a specified set
+// of files that contain certificates and keys in PEM format.
+type Options struct {
+	// CertFile is the file that holds the identity certificate.
+	// Optional. If this is set, KeyFile must also be set.
+	CertFile string
+	// KeyFile is the file that holds identity private key.
+	// Optional. If this is set, CertFile must also be set.
+	KeyFile string
+	// RootFile is the file that holds trusted root certificate(s).
+	// Optional.
+	RootFile string
+	// CertRefreshDuration is the amount of time the plugin waits before
+	// checking for updates in the specified identity certificate and key file.
+	// Optional. If not set, a default value (1 hour) will be used.
+	CertRefreshDuration time.Duration
+	// RootRefreshDuration is the amount of time the plugin waits before
+	// checking for updates in the specified root file.
+	// Optional. If not set, a default value (2 hour) will be used.
+	RootRefreshDuration time.Duration
+}
+
+// NewProvider returns a new certificate provider plugin that is configured to
+// watch the PEM files specified in the passed in options.
+func NewProvider(o Options) (certprovider.Provider, error) {
+	if o.CertFile == "" && o.KeyFile == "" && o.RootFile == "" {
+		return nil, fmt.Errorf("pemfile: at least one credential file needs to be specified")
+	}
+	if keySpecified, certSpecified := o.KeyFile != "", o.CertFile != ""; keySpecified != certSpecified {
+		return nil, fmt.Errorf("pemfile: private key file and identity cert file should be both specified or not specified")
+	}
+	if o.CertRefreshDuration == 0 {
+		o.CertRefreshDuration = defaultCertRefreshDuration
+	}
+	if o.RootRefreshDuration == 0 {
+		o.RootRefreshDuration = defaultRootRefreshDuration
+	}
+
+	provider := &watcher{opts: o}
+	if o.CertFile != "" && o.KeyFile != "" {
+		provider.identityDistributor = newDistributor()
+	}
+	if o.RootFile != "" {
+		provider.rootDistributor = newDistributor()
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	provider.cancel = cancel
+	go provider.run(ctx)
+
+	return provider, nil
+}
+
+// watcher is a certificate provider plugin that implements the
+// certprovider.Provider interface. It watches a set of certificate and key
+// files and provides the most up-to-date key material for consumption by
+// credentials implementation.
+type watcher struct {
+	identityDistributor distributor
+	rootDistributor     distributor
+	opts                Options
+	certFileContents    []byte
+	keyFileContents     []byte
+	rootFileContents    []byte
+	cancel              context.CancelFunc
+}
+
+// distributor wraps the methods on certprovider.Distributor which are used by
+// the plugin. This is very useful in tests which need to know exactly when the
+// plugin updates its key material.
+type distributor interface {
+	KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error)
+	Set(km *certprovider.KeyMaterial, err error)
+	Stop()
+}
+
+// updateIdentityDistributor checks if the cert/key files that the plugin is
+// watching have changed, and if so, reads the new contents and updates the
+// identityDistributor with the new key material.
+//
+// Skips updates when file reading or parsing fails.
+// TODO(easwars): Retry with limit (on the number of retries or the amount of
+// time) upon failures.
+func (w *watcher) updateIdentityDistributor() {
+	if w.identityDistributor == nil {
+		return
+	}
+
+	certFileContents, err := ioutil.ReadFile(w.opts.CertFile)
+	if err != nil {
+		logger.Warningf("certFile (%s) read failed: %v", w.opts.CertFile, err)
+		return
+	}
+	keyFileContents, err := ioutil.ReadFile(w.opts.KeyFile)
+	if err != nil {
+		logger.Warningf("keyFile (%s) read failed: %v", w.opts.KeyFile, err)
+		return
+	}
+	// If the file contents have not changed, skip updating the distributor.
+	if bytes.Equal(w.certFileContents, certFileContents) && bytes.Equal(w.keyFileContents, keyFileContents) {
+		return
+	}
+
+	cert, err := tls.X509KeyPair(certFileContents, keyFileContents)
+	if err != nil {
+		logger.Warningf("tls.X509KeyPair(%q, %q) failed: %v", certFileContents, keyFileContents, err)
+		return
+	}
+	w.certFileContents = certFileContents
+	w.keyFileContents = keyFileContents
+	w.identityDistributor.Set(&certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}, nil)
+}
+
+// updateRootDistributor checks if the root cert file that the plugin is
+// watching hs changed, and if so, updates the rootDistributor with the new key
+// material.
+//
+// Skips updates when root cert reading or parsing fails.
+// TODO(easwars): Retry with limit (on the number of retries or the amount of
+// time) upon failures.
+func (w *watcher) updateRootDistributor() {
+	if w.rootDistributor == nil {
+		return
+	}
+
+	rootFileContents, err := ioutil.ReadFile(w.opts.RootFile)
+	if err != nil {
+		logger.Warningf("rootFile (%s) read failed: %v", w.opts.RootFile, err)
+		return
+	}
+	trustPool := x509.NewCertPool()
+	if !trustPool.AppendCertsFromPEM(rootFileContents) {
+		logger.Warning("failed to parse root certificate")
+		return
+	}
+	// If the file contents have not changed, skip updating the distributor.
+	if bytes.Equal(w.rootFileContents, rootFileContents) {
+		return
+	}
+
+	w.rootFileContents = rootFileContents
+	w.rootDistributor.Set(&certprovider.KeyMaterial{Roots: trustPool}, nil)
+}
+
+// run is a long running goroutine which watches the configured files for
+// changes, and pushes new key material into the appropriate distributors which
+// is returned from calls to KeyMaterial().
+func (w *watcher) run(ctx context.Context) {
+	// Update both root and identity certs at the beginning. Subsequently,
+	// update only the appropriate file whose ticker has fired.
+	w.updateIdentityDistributor()
+	w.updateRootDistributor()
+
+	identityTicker := time.NewTicker(w.opts.CertRefreshDuration)
+	rootTicker := time.NewTicker(w.opts.RootRefreshDuration)
+	for {
+		select {
+		case <-ctx.Done():
+			identityTicker.Stop()
+			rootTicker.Stop()
+			if w.identityDistributor != nil {
+				w.identityDistributor.Stop()
+			}
+			if w.rootDistributor != nil {
+				w.rootDistributor.Stop()
+			}
+			return
+		case <-identityTicker.C:
+			w.updateIdentityDistributor()
+		case <-rootTicker.C:
+			w.updateRootDistributor()
+		}
+	}
+}
+
+// KeyMaterial returns the key material sourced by the watcher.
+// Callers are expected to use the returned value as read-only.
+func (w *watcher) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
+	km := &certprovider.KeyMaterial{}
+	if w.identityDistributor != nil {
+		identityKM, err := w.identityDistributor.KeyMaterial(ctx)
+		if err != nil {
+			return nil, err
+		}
+		km.Certs = identityKM.Certs
+	}
+	if w.rootDistributor != nil {
+		rootKM, err := w.rootDistributor.KeyMaterial(ctx)
+		if err != nil {
+			return nil, err
+		}
+		km.Roots = rootKM.Roots
+	}
+	return km, nil
+}
+
+// Close cleans up resources allocated by the watcher.
+func (w *watcher) Close() {
+	w.cancel()
+}
diff --git a/credentials/tls/certprovider/pemfile/watcher_test.go b/credentials/tls/certprovider/pemfile/watcher_test.go
new file mode 100644
index 0000000..092bd30
--- /dev/null
+++ b/credentials/tls/certprovider/pemfile/watcher_test.go
@@ -0,0 +1,426 @@
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package pemfile
+
+import (
+	"context"
+	"crypto/x509"
+	"io/ioutil"
+	"math/big"
+	"os"
+	"path"
+	"testing"
+	"time"
+
+	"github.com/google/go-cmp/cmp"
+
+	"google.golang.org/grpc/credentials/tls/certprovider"
+	"google.golang.org/grpc/internal/grpctest"
+	"google.golang.org/grpc/internal/testutils"
+	"google.golang.org/grpc/testdata"
+)
+
+const (
+	// These are the names of files inside temporary directories, which the
+	// plugin is asked to watch.
+	certFile = "cert.pem"
+	keyFile  = "key.pem"
+	rootFile = "ca.pem"
+
+	defaultTestRefreshDuration = 100 * time.Millisecond
+	defaultTestTimeout         = 5 * time.Second
+)
+
+type s struct {
+	grpctest.Tester
+}
+
+func Test(t *testing.T) {
+	grpctest.RunSubTests(t, s{})
+}
+
+// TestNewProvider tests the NewProvider() function with different inputs.
+func (s) TestNewProvider(t *testing.T) {
+	tests := []struct {
+		desc      string
+		options   Options
+		wantError bool
+	}{
+		{
+			desc:      "No credential files specified",
+			options:   Options{},
+			wantError: true,
+		},
+		{
+			desc: "Only identity cert is specified",
+			options: Options{
+				CertFile: testdata.Path("x509/client1_cert.pem"),
+			},
+			wantError: true,
+		},
+		{
+			desc: "Only identity key is specified",
+			options: Options{
+				KeyFile: testdata.Path("x509/client1_key.pem"),
+			},
+			wantError: true,
+		},
+		{
+			desc: "Identity cert/key pair is specified",
+			options: Options{
+				KeyFile:  testdata.Path("x509/client1_key.pem"),
+				CertFile: testdata.Path("x509/client1_cert.pem"),
+			},
+		},
+		{
+			desc: "Only root certs are specified",
+			options: Options{
+				RootFile: testdata.Path("x509/client_ca_cert.pem"),
+			},
+		},
+		{
+			desc: "Everything is specified",
+			options: Options{
+				KeyFile:  testdata.Path("x509/client1_key.pem"),
+				CertFile: testdata.Path("x509/client1_cert.pem"),
+				RootFile: testdata.Path("x509/client_ca_cert.pem"),
+			},
+			wantError: false,
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.desc, func(t *testing.T) {
+			provider, err := NewProvider(test.options)
+			if (err != nil) != test.wantError {
+				t.Fatalf("NewProvider(%v) = %v, want %v", test.options, err, test.wantError)
+			}
+			if err != nil {
+				return
+			}
+			provider.Close()
+		})
+	}
+}
+
+// wrappedDistributor wraps a distributor and pushes on a channel whenever new
+// key material is pushed to the distributor.
+type wrappedDistributor struct {
+	*certprovider.Distributor
+	distCh *testutils.Channel
+}
+
+func newWrappedDistributor(distCh *testutils.Channel) *wrappedDistributor {
+	return &wrappedDistributor{
+		distCh:      distCh,
+		Distributor: certprovider.NewDistributor(),
+	}
+}
+
+func (wd *wrappedDistributor) Set(km *certprovider.KeyMaterial, err error) {
+	wd.Distributor.Set(km, err)
+	wd.distCh.Send(nil)
+}
+
+func createTmpFile(t *testing.T, src, dst string) {
+	t.Helper()
+
+	data, err := ioutil.ReadFile(src)
+	if err != nil {
+		t.Fatalf("ioutil.ReadFile(%q) failed: %v", src, err)
+	}
+	if err := ioutil.WriteFile(dst, data, os.ModePerm); err != nil {
+		t.Fatalf("ioutil.WriteFile(%q) failed: %v", dst, err)
+	}
+	t.Logf("Wrote file at: %s", dst)
+	t.Logf("%s", string(data))
+}
+
+// createTempDirWithFiles creates a temporary directory under the system default
+// tempDir with the given dirSuffix. It also reads from certSrc, keySrc and
+// rootSrc files are creates appropriate files under the newly create tempDir.
+// Returns the name of the created tempDir.
+func createTmpDirWithFiles(t *testing.T, dirSuffix, certSrc, keySrc, rootSrc string) string {
+	t.Helper()
+
+	// Create a temp directory. Passing an empty string for the first argument
+	// uses the system temp directory.
+	dir, err := ioutil.TempDir("", dirSuffix)
+	if err != nil {
+		t.Fatalf("ioutil.TempDir() failed: %v", err)
+	}
+	t.Logf("Using tmpdir: %s", dir)
+
+	createTmpFile(t, testdata.Path(certSrc), path.Join(dir, certFile))
+	createTmpFile(t, testdata.Path(keySrc), path.Join(dir, keyFile))
+	createTmpFile(t, testdata.Path(rootSrc), path.Join(dir, rootFile))
+	return dir
+}
+
+// initializeProvider performs setup steps common to all tests (except the one
+// which uses symlinks).
+func initializeProvider(t *testing.T, testName string) (string, certprovider.Provider, *testutils.Channel, func()) {
+	t.Helper()
+
+	// Override the newDistributor to one which pushes on a channel that we
+	// can block on.
+	origDistributorFunc := newDistributor
+	distCh := testutils.NewChannel()
+	d := newWrappedDistributor(distCh)
+	newDistributor = func() distributor { return d }
+
+	// Create a new provider to watch the files in tmpdir.
+	dir := createTmpDirWithFiles(t, testName+"*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
+	opts := Options{
+		CertFile:            path.Join(dir, certFile),
+		KeyFile:             path.Join(dir, keyFile),
+		RootFile:            path.Join(dir, rootFile),
+		CertRefreshDuration: defaultTestRefreshDuration,
+		RootRefreshDuration: defaultTestRefreshDuration,
+	}
+	prov, err := NewProvider(opts)
+	if err != nil {
+		t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
+	}
+
+	// Make sure the provider picks up the files and pushes the key material on
+	// to the distributors.
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	for i := 0; i < 2; i++ {
+		// Since we have root and identity certs, we need to make sure the
+		// update is pushed on both of them.
+		if _, err := distCh.Receive(ctx); err != nil {
+			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
+		}
+	}
+
+	return dir, prov, distCh, func() {
+		newDistributor = origDistributorFunc
+		prov.Close()
+	}
+}
+
+// TestProvider_NoUpdate tests the case where a file watcher plugin is created
+// successfully, and the underlying files do not change. Verifies that the
+// plugin does not push new updates to the distributor in this case.
+func (s) TestProvider_NoUpdate(t *testing.T) {
+	_, prov, distCh, cancel := initializeProvider(t, "no_update")
+	defer cancel()
+
+	// Make sure the provider is healthy and returns key material.
+	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cc()
+	if _, err := prov.KeyMaterial(ctx); err != nil {
+		t.Fatalf("provider.KeyMaterial() failed: %v", err)
+	}
+
+	// Files haven't change. Make sure no updates are pushed by the provider.
+	sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
+	defer sc()
+	if _, err := distCh.Receive(sCtx); err == nil {
+		t.Fatal("new key material pushed to distributor when underlying files did not change")
+	}
+}
+
+// TestProvider_UpdateSuccess tests the case where a file watcher plugin is
+// created successfully and the underlying files change. Verifies that the
+// changes are picked up by the provider.
+func (s) TestProvider_UpdateSuccess(t *testing.T) {
+	dir, prov, distCh, cancel := initializeProvider(t, "update_success")
+	defer cancel()
+
+	// Make sure the provider is healthy and returns key material.
+	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cc()
+	km1, err := prov.KeyMaterial(ctx)
+	if err != nil {
+		t.Fatalf("provider.KeyMaterial() failed: %v", err)
+	}
+
+	// Change only the root file.
+	createTmpFile(t, testdata.Path("x509/server_ca_cert.pem"), path.Join(dir, rootFile))
+	if _, err := distCh.Receive(ctx); err != nil {
+		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
+	}
+
+	// Make sure update is picked up.
+	km2, err := prov.KeyMaterial(ctx)
+	if err != nil {
+		t.Fatalf("provider.KeyMaterial() failed: %v", err)
+	}
+	if cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
+		t.Fatal("expected provider to return new key material after update to underlying file")
+	}
+
+	// Change only cert/key files.
+	createTmpFile(t, testdata.Path("x509/client2_cert.pem"), path.Join(dir, certFile))
+	createTmpFile(t, testdata.Path("x509/client2_key.pem"), path.Join(dir, keyFile))
+	if _, err := distCh.Receive(ctx); err != nil {
+		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
+	}
+
+	// Make sure update is picked up.
+	km3, err := prov.KeyMaterial(ctx)
+	if err != nil {
+		t.Fatalf("provider.KeyMaterial() failed: %v", err)
+	}
+	if cmp.Equal(km2, km3, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
+		t.Fatal("expected provider to return new key material after update to underlying file")
+	}
+}
+
+// TestProvider_UpdateSuccessWithSymlink tests the case where a file watcher
+// plugin is created successfully to watch files through a symlink and the
+// symlink is updates to point to new files. Verifies that the changes are
+// picked up by the provider.
+func (s) TestProvider_UpdateSuccessWithSymlink(t *testing.T) {
+	// Override the newDistributor to one which pushes on a channel that we
+	// can block on.
+	origDistributorFunc := newDistributor
+	distCh := testutils.NewChannel()
+	d := newWrappedDistributor(distCh)
+	newDistributor = func() distributor { return d }
+	defer func() { newDistributor = origDistributorFunc }()
+
+	// Create two tempDirs with different files.
+	dir1 := createTmpDirWithFiles(t, "update_with_symlink1_*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
+	dir2 := createTmpDirWithFiles(t, "update_with_symlink2_*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/server_ca_cert.pem")
+
+	// Create a symlink under a new tempdir, and make it point to dir1.
+	tmpdir, err := ioutil.TempDir("", "test_symlink_*")
+	if err != nil {
+		t.Fatalf("ioutil.TempDir() failed: %v", err)
+	}
+	symLinkName := path.Join(tmpdir, "test_symlink")
+	if err := os.Symlink(dir1, symLinkName); err != nil {
+		t.Fatalf("failed to create symlink to %q: %v", dir1, err)
+	}
+
+	// Create a provider which watches the files pointed to by the symlink.
+	opts := Options{
+		CertFile:            path.Join(symLinkName, certFile),
+		KeyFile:             path.Join(symLinkName, keyFile),
+		RootFile:            path.Join(symLinkName, rootFile),
+		CertRefreshDuration: defaultTestRefreshDuration,
+		RootRefreshDuration: defaultTestRefreshDuration,
+	}
+	prov, err := NewProvider(opts)
+	if err != nil {
+		t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
+	}
+	defer prov.Close()
+
+	// Make sure the provider picks up the files and pushes the key material on
+	// to the distributors.
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	for i := 0; i < 2; i++ {
+		// Since we have root and identity certs, we need to make sure the
+		// update is pushed on both of them.
+		if _, err := distCh.Receive(ctx); err != nil {
+			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
+		}
+	}
+	km1, err := prov.KeyMaterial(ctx)
+	if err != nil {
+		t.Fatalf("provider.KeyMaterial() failed: %v", err)
+	}
+
+	// Update the symlink to point to dir2.
+	symLinkTmpName := path.Join(tmpdir, "test_symlink.tmp")
+	if err := os.Symlink(dir2, symLinkTmpName); err != nil {
+		t.Fatalf("failed to create symlink to %q: %v", dir2, err)
+	}
+	if err := os.Rename(symLinkTmpName, symLinkName); err != nil {
+		t.Fatalf("failed to update symlink: %v", err)
+	}
+
+	// Make sure the provider picks up the new files and pushes the key material
+	// on to the distributors.
+	for i := 0; i < 2; i++ {
+		// Since we have root and identity certs, we need to make sure the
+		// update is pushed on both of them.
+		if _, err := distCh.Receive(ctx); err != nil {
+			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
+		}
+	}
+	km2, err := prov.KeyMaterial(ctx)
+	if err != nil {
+		t.Fatalf("provider.KeyMaterial() failed: %v", err)
+	}
+
+	if cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
+		t.Fatal("expected provider to return new key material after symlink update")
+	}
+}
+
+// TestProvider_UpdateFailure_ThenSuccess tests the case where updating cert/key
+// files fail. Verifies that the failed update does not push anything on the
+// distributor. Then the update succeeds, and the test verifies that the key
+// material is updated.
+func (s) TestProvider_UpdateFailure_ThenSuccess(t *testing.T) {
+	dir, prov, distCh, cancel := initializeProvider(t, "update_failure")
+	defer cancel()
+
+	// Make sure the provider is healthy and returns key material.
+	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cc()
+	km1, err := prov.KeyMaterial(ctx)
+	if err != nil {
+		t.Fatalf("provider.KeyMaterial() failed: %v", err)
+	}
+
+	// Update only the cert file. The key file is left unchanged. This should
+	// lead to these two files being not compatible with each other. This
+	// simulates the case where the watching goroutine might catch the files in
+	// the midst of an update.
+	createTmpFile(t, testdata.Path("x509/server1_cert.pem"), path.Join(dir, certFile))
+
+	// Since the last update left the files in an incompatible state, the update
+	// should not be picked up by our provider.
+	sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
+	defer sc()
+	if _, err := distCh.Receive(sCtx); err == nil {
+		t.Fatal("new key material pushed to distributor when underlying files did not change")
+	}
+
+	// The provider should return key material corresponding to the old state.
+	km2, err := prov.KeyMaterial(ctx)
+	if err != nil {
+		t.Fatalf("provider.KeyMaterial() failed: %v", err)
+	}
+	if !cmp.Equal(km1, km2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
+		t.Fatal("expected provider to not update key material")
+	}
+
+	// Update the key file to match the cert file.
+	createTmpFile(t, testdata.Path("x509/server1_key.pem"), path.Join(dir, keyFile))
+
+	// Make sure update is picked up.
+	if _, err := distCh.Receive(ctx); err != nil {
+		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
+	}
+	km3, err := prov.KeyMaterial(ctx)
+	if err != nil {
+		t.Fatalf("provider.KeyMaterial() failed: %v", err)
+	}
+	if cmp.Equal(km2, km3, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
+		t.Fatal("expected provider to return new key material after update to underlying file")
+	}
+}
diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go
index 20a9a58..d554468 100644
--- a/security/advancedtls/advancedtls_integration_test.go
+++ b/security/advancedtls/advancedtls_integration_test.go
@@ -32,6 +32,8 @@
 
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/credentials"
+	"google.golang.org/grpc/credentials/tls/certprovider"
+	"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
 	pb "google.golang.org/grpc/examples/helloworld/helloworld"
 	"google.golang.org/grpc/security/advancedtls/internal/testutils"
 	"google.golang.org/grpc/security/advancedtls/testdata"
@@ -511,38 +513,38 @@
 
 // Create PEMFileProvider(s) watching the content changes of temporary
 // files.
-func createProviders(tmpFiles *tmpCredsFiles) (*PEMFileProvider, *PEMFileProvider, *PEMFileProvider, *PEMFileProvider, error) {
-	clientIdentityOptions := PEMFileProviderOptions{
-		CertFile:         tmpFiles.clientCertTmp.Name(),
-		KeyFile:          tmpFiles.clientKeyTmp.Name(),
-		IdentityInterval: credRefreshingInterval,
+func createProviders(tmpFiles *tmpCredsFiles) (certprovider.Provider, certprovider.Provider, certprovider.Provider, certprovider.Provider, error) {
+	clientIdentityOptions := pemfile.Options{
+		CertFile:            tmpFiles.clientCertTmp.Name(),
+		KeyFile:             tmpFiles.clientKeyTmp.Name(),
+		CertRefreshDuration: credRefreshingInterval,
 	}
-	clientIdentityProvider, err := NewPEMFileProvider(clientIdentityOptions)
+	clientIdentityProvider, err := pemfile.NewProvider(clientIdentityOptions)
 	if err != nil {
 		return nil, nil, nil, nil, err
 	}
-	clientRootOptions := PEMFileProviderOptions{
-		TrustFile:    tmpFiles.clientTrustTmp.Name(),
-		RootInterval: credRefreshingInterval,
+	clientRootOptions := pemfile.Options{
+		RootFile:            tmpFiles.clientTrustTmp.Name(),
+		RootRefreshDuration: credRefreshingInterval,
 	}
-	clientRootProvider, err := NewPEMFileProvider(clientRootOptions)
+	clientRootProvider, err := pemfile.NewProvider(clientRootOptions)
 	if err != nil {
 		return nil, nil, nil, nil, err
 	}
-	serverIdentityOptions := PEMFileProviderOptions{
-		CertFile:         tmpFiles.serverCertTmp.Name(),
-		KeyFile:          tmpFiles.serverKeyTmp.Name(),
-		IdentityInterval: credRefreshingInterval,
+	serverIdentityOptions := pemfile.Options{
+		CertFile:            tmpFiles.serverCertTmp.Name(),
+		KeyFile:             tmpFiles.serverKeyTmp.Name(),
+		CertRefreshDuration: credRefreshingInterval,
 	}
-	serverIdentityProvider, err := NewPEMFileProvider(serverIdentityOptions)
+	serverIdentityProvider, err := pemfile.NewProvider(serverIdentityOptions)
 	if err != nil {
 		return nil, nil, nil, nil, err
 	}
-	serverRootOptions := PEMFileProviderOptions{
-		TrustFile:    tmpFiles.serverTrustTmp.Name(),
-		RootInterval: credRefreshingInterval,
+	serverRootOptions := pemfile.Options{
+		RootFile:            tmpFiles.serverTrustTmp.Name(),
+		RootRefreshDuration: credRefreshingInterval,
 	}
-	serverRootProvider, err := NewPEMFileProvider(serverRootOptions)
+	serverRootProvider, err := pemfile.NewProvider(serverRootOptions)
 	if err != nil {
 		return nil, nil, nil, nil, err
 	}
diff --git a/security/advancedtls/pemfile_provider.go b/security/advancedtls/pemfile_provider.go
deleted file mode 100644
index 96b3587..0000000
--- a/security/advancedtls/pemfile_provider.go
+++ /dev/null
@@ -1,197 +0,0 @@
-/*
- *
- * Copyright 2020 gRPC authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
-
-package advancedtls
-
-import (
-	"context"
-	"crypto/tls"
-	"crypto/x509"
-	"fmt"
-	"io/ioutil"
-	"time"
-
-	"google.golang.org/grpc/credentials/tls/certprovider"
-	"google.golang.org/grpc/grpclog"
-)
-
-const defaultIdentityInterval = 1 * time.Hour
-const defaultRootInterval = 2 * time.Hour
-
-// readKeyCertPairFunc will be overridden from unit tests.
-var readKeyCertPairFunc = tls.LoadX509KeyPair
-
-// readTrustCertFunc will be overridden from unit tests.
-var readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) {
-	trustData, err := ioutil.ReadFile(trustFile)
-	if err != nil {
-		return nil, err
-	}
-	trustPool := x509.NewCertPool()
-	if !trustPool.AppendCertsFromPEM(trustData) {
-		return nil, fmt.Errorf("AppendCertsFromPEM failed to parse certificates")
-	}
-	return trustPool, nil
-}
-
-var logger = grpclog.Component("advancedtls")
-
-// PEMFileProviderOptions contains options to configure a PEMFileProvider.
-// Note that these fields will only take effect during construction. Once the
-// PEMFileProvider starts, changing fields in PEMFileProviderOptions will have
-// no effect.
-type PEMFileProviderOptions struct {
-	// CertFile is the file path that holds identity certificate whose updates
-	// will be captured by a watching goroutine.
-	// Optional. If this is set, KeyFile must also be set.
-	CertFile string
-	// KeyFile is the file path that holds identity private key whose updates
-	// will be captured by a watching goroutine.
-	// Optional. If this is set, CertFile must also be set.
-	KeyFile string
-	// TrustFile is the file path that holds trust certificate whose updates will
-	// be captured by a watching goroutine.
-	// Optional.
-	TrustFile string
-	// IdentityInterval is the time duration between two credential update checks
-	// for identity certs.
-	// Optional. If not set, we will use the default interval(1 hour).
-	IdentityInterval time.Duration
-	// RootInterval is the time duration between two credential update checks
-	// for root certs.
-	// Optional. If not set, we will use the default interval(2 hours).
-	RootInterval time.Duration
-}
-
-// PEMFileProvider implements certprovider.Provider.
-// It provides the most up-to-date identity private key-cert pairs and/or
-// root certificates.
-type PEMFileProvider struct {
-	identityDistributor *certprovider.Distributor
-	rootDistributor     *certprovider.Distributor
-	cancel              context.CancelFunc
-}
-
-func updateIdentityDistributor(distributor *certprovider.Distributor, certFile, keyFile string) {
-	if distributor == nil {
-		return
-	}
-	// Read identity certs from PEM files.
-	identityCert, err := readKeyCertPairFunc(certFile, keyFile)
-	if err != nil {
-		// If the reading produces an error, we will skip the update for this
-		// round and log the error.
-		logger.Warningf("tls.LoadX509KeyPair reads %s and %s failed: %v", certFile, keyFile, err)
-		return
-	}
-	distributor.Set(&certprovider.KeyMaterial{Certs: []tls.Certificate{identityCert}}, nil)
-}
-
-func updateRootDistributor(distributor *certprovider.Distributor, trustFile string) {
-	if distributor == nil {
-		return
-	}
-	// Read root certs from PEM files.
-	trustPool, err := readTrustCertFunc(trustFile)
-	if err != nil {
-		// If the reading produces an error, we will skip the update for this
-		// round and log the error.
-		logger.Warningf("readTrustCertFunc reads %v failed: %v", trustFile, err)
-		return
-	}
-	distributor.Set(&certprovider.KeyMaterial{Roots: trustPool}, nil)
-}
-
-// NewPEMFileProvider returns a new PEMFileProvider constructed using the
-// provided options.
-func NewPEMFileProvider(o PEMFileProviderOptions) (*PEMFileProvider, error) {
-	if o.CertFile == "" && o.KeyFile == "" && o.TrustFile == "" {
-		return nil, fmt.Errorf("at least one credential file needs to be specified")
-	}
-	if keySpecified, certSpecified := o.KeyFile != "", o.CertFile != ""; keySpecified != certSpecified {
-		return nil, fmt.Errorf("private key file and identity cert file should be both specified or not specified")
-	}
-	if o.IdentityInterval == 0 {
-		o.IdentityInterval = defaultIdentityInterval
-	}
-	if o.RootInterval == 0 {
-		o.RootInterval = defaultRootInterval
-	}
-	provider := &PEMFileProvider{}
-	if o.CertFile != "" && o.KeyFile != "" {
-		provider.identityDistributor = certprovider.NewDistributor()
-	}
-	if o.TrustFile != "" {
-		provider.rootDistributor = certprovider.NewDistributor()
-	}
-	// A goroutine to pull file changes.
-	identityTicker := time.NewTicker(o.IdentityInterval)
-	rootTicker := time.NewTicker(o.RootInterval)
-	ctx, cancel := context.WithCancel(context.Background())
-
-	go func() {
-		for {
-			updateIdentityDistributor(provider.identityDistributor, o.CertFile, o.KeyFile)
-			updateRootDistributor(provider.rootDistributor, o.TrustFile)
-			select {
-			case <-ctx.Done():
-				identityTicker.Stop()
-				rootTicker.Stop()
-				return
-			case <-identityTicker.C:
-				break
-			case <-rootTicker.C:
-				break
-			}
-		}
-	}()
-	provider.cancel = cancel
-	return provider, nil
-}
-
-// KeyMaterial returns the key material sourced by the PEMFileProvider.
-// Callers are expected to use the returned value as read-only.
-func (p *PEMFileProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
-	km := &certprovider.KeyMaterial{}
-	if p.identityDistributor != nil {
-		identityKM, err := p.identityDistributor.KeyMaterial(ctx)
-		if err != nil {
-			return nil, err
-		}
-		km.Certs = identityKM.Certs
-	}
-	if p.rootDistributor != nil {
-		rootKM, err := p.rootDistributor.KeyMaterial(ctx)
-		if err != nil {
-			return nil, err
-		}
-		km.Roots = rootKM.Roots
-	}
-	return km, nil
-}
-
-// Close cleans up resources allocated by the PEMFileProvider.
-func (p *PEMFileProvider) Close() {
-	p.cancel()
-	if p.identityDistributor != nil {
-		p.identityDistributor.Stop()
-	}
-	if p.rootDistributor != nil {
-		p.rootDistributor.Stop()
-	}
-}
diff --git a/security/advancedtls/pemfile_provider_test.go b/security/advancedtls/pemfile_provider_test.go
deleted file mode 100644
index 48e0bd2..0000000
--- a/security/advancedtls/pemfile_provider_test.go
+++ /dev/null
@@ -1,220 +0,0 @@
-/*
- *
- * Copyright 2020 gRPC authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
-
-package advancedtls
-
-import (
-	"context"
-	"crypto/tls"
-	"crypto/x509"
-	"fmt"
-	"math/big"
-	"testing"
-	"time"
-
-	"github.com/google/go-cmp/cmp"
-	"google.golang.org/grpc/credentials/tls/certprovider"
-	"google.golang.org/grpc/security/advancedtls/internal/testutils"
-	"google.golang.org/grpc/security/advancedtls/testdata"
-)
-
-func (s) TestNewPEMFileProvider(t *testing.T) {
-	tests := []struct {
-		desc      string
-		options   PEMFileProviderOptions
-		certFile  string
-		keyFile   string
-		trustFile string
-		wantError bool
-	}{
-		{
-			desc:      "Expect error if no credential files specified",
-			options:   PEMFileProviderOptions{},
-			wantError: true,
-		},
-		{
-			desc: "Expect error if only certFile is specified",
-			options: PEMFileProviderOptions{
-				CertFile: testdata.Path("client_cert_1.pem"),
-			},
-			wantError: true,
-		},
-		{
-			desc: "Should be good if only identity key cert pairs are specified",
-			options: PEMFileProviderOptions{
-				KeyFile:  testdata.Path("client_key_1.pem"),
-				CertFile: testdata.Path("client_cert_1.pem"),
-			},
-			wantError: false,
-		},
-		{
-			desc: "Should be good if only root certs are specified",
-			options: PEMFileProviderOptions{
-				TrustFile: testdata.Path("client_trust_cert_1.pem"),
-			},
-			wantError: false,
-		},
-		{
-			desc: "Should be good if both identity pairs and root certs are specified",
-			options: PEMFileProviderOptions{
-				KeyFile:   testdata.Path("client_key_1.pem"),
-				CertFile:  testdata.Path("client_cert_1.pem"),
-				TrustFile: testdata.Path("client_trust_cert_1.pem"),
-			},
-			wantError: false,
-		},
-	}
-	for _, test := range tests {
-		t.Run(test.desc, func(t *testing.T) {
-			provider, err := NewPEMFileProvider(test.options)
-			if (err != nil) != test.wantError {
-				t.Fatalf("NewPEMFileProvider(%v) = %v, want %v", test.options, err, test.wantError)
-			}
-			if err != nil {
-				return
-			}
-			provider.Close()
-		})
-	}
-
-}
-
-// This test overwrites the credential reading function used by the watching
-// goroutine. It is tested under different stages:
-// At stage 0, we force reading function to load ClientCert1 and ServerTrust1,
-// and see if the credentials are picked up by the watching go routine.
-// At stage 1, we force reading function to cause an error. The watching go
-// routine should log the error while leaving the credentials unchanged.
-// At stage 2, we force reading function to load ClientCert2 and ServerTrust2,
-// and see if the new credentials are picked up.
-func (s) TestWatchingRoutineUpdates(t *testing.T) {
-	// Load certificates.
-	cs := &testutils.CertStore{}
-	if err := cs.LoadCerts(); err != nil {
-		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
-	}
-	tests := []struct {
-		desc         string
-		options      PEMFileProviderOptions
-		wantKmStage0 certprovider.KeyMaterial
-		wantKmStage1 certprovider.KeyMaterial
-		wantKmStage2 certprovider.KeyMaterial
-	}{
-		{
-			desc: "use identity certs and root certs",
-			options: PEMFileProviderOptions{
-				CertFile:  "not_empty_cert_file",
-				KeyFile:   "not_empty_key_file",
-				TrustFile: "not_empty_trust_file",
-			},
-			wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1},
-			wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1},
-			wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}, Roots: cs.ServerTrust2},
-		},
-		{
-			desc: "use identity certs only",
-			options: PEMFileProviderOptions{
-				CertFile: "not_empty_cert_file",
-				KeyFile:  "not_empty_key_file",
-			},
-			wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}},
-			wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}},
-			wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}},
-		},
-		{
-			desc: "use trust certs only",
-			options: PEMFileProviderOptions{
-				TrustFile: "not_empty_trust_file",
-			},
-			wantKmStage0: certprovider.KeyMaterial{Roots: cs.ServerTrust1},
-			wantKmStage1: certprovider.KeyMaterial{Roots: cs.ServerTrust1},
-			wantKmStage2: certprovider.KeyMaterial{Roots: cs.ServerTrust2},
-		},
-	}
-	for _, test := range tests {
-		testInterval := 200 * time.Millisecond
-		test.options.IdentityInterval = testInterval
-		test.options.RootInterval = testInterval
-		t.Run(test.desc, func(t *testing.T) {
-			stage := &stageInfo{}
-			oldReadKeyCertPairFunc := readKeyCertPairFunc
-			readKeyCertPairFunc = func(certFile, keyFile string) (tls.Certificate, error) {
-				switch stage.read() {
-				case 0:
-					return cs.ClientCert1, nil
-				case 1:
-					return tls.Certificate{}, fmt.Errorf("error occurred while reloading")
-				case 2:
-					return cs.ClientCert2, nil
-				default:
-					return tls.Certificate{}, fmt.Errorf("test stage not supported")
-				}
-			}
-			defer func() {
-				readKeyCertPairFunc = oldReadKeyCertPairFunc
-			}()
-			oldReadTrustCertFunc := readTrustCertFunc
-			readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) {
-				switch stage.read() {
-				case 0:
-					return cs.ServerTrust1, nil
-				case 1:
-					return nil, fmt.Errorf("error occurred while reloading")
-				case 2:
-					return cs.ServerTrust2, nil
-				default:
-					return nil, fmt.Errorf("test stage not supported")
-				}
-			}
-			defer func() {
-				readTrustCertFunc = oldReadTrustCertFunc
-			}()
-			provider, err := NewPEMFileProvider(test.options)
-			if err != nil {
-				t.Fatalf("NewPEMFileProvider failed: %v", err)
-			}
-			defer provider.Close()
-			ctx, cancel := context.WithCancel(context.Background())
-			defer cancel()
-			//// ------------------------Stage 0------------------------------------
-			// Wait for the refreshing go-routine to pick up the changes.
-			time.Sleep(1 * time.Second)
-			gotKM, err := provider.KeyMaterial(ctx)
-			if !cmp.Equal(*gotKM, test.wantKmStage0, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
-				t.Fatalf("provider.KeyMaterial() = %+v, want %+v", *gotKM, test.wantKmStage0)
-			}
-			// ------------------------Stage 1------------------------------------
-			stage.increase()
-			// Wait for the refreshing go-routine to pick up the changes.
-			time.Sleep(1 * time.Second)
-			gotKM, err = provider.KeyMaterial(ctx)
-			if !cmp.Equal(*gotKM, test.wantKmStage1, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
-				t.Fatalf("provider.KeyMaterial() = %+v, want %+v", *gotKM, test.wantKmStage1)
-			}
-			//// ------------------------Stage 2------------------------------------
-			// Wait for the refreshing go-routine to pick up the changes.
-			stage.increase()
-			time.Sleep(1 * time.Second)
-			gotKM, err = provider.KeyMaterial(ctx)
-			if !cmp.Equal(*gotKM, test.wantKmStage2, cmp.AllowUnexported(big.Int{}, x509.CertPool{})) {
-				t.Fatalf("provider.KeyMaterial() = %+v, want %+v", *gotKM, test.wantKmStage2)
-			}
-			stage.reset()
-		})
-	}
-}