Merge pull request #35265 from cpuguy83/32609_defreference_voldriver_on_error

Fixup some issues with plugin refcounting
diff --git a/integration/plugin/volume/cmd/cmd_test.go b/integration/plugin/volume/cmd/cmd_test.go
new file mode 100644
index 0000000..1d619dd
--- /dev/null
+++ b/integration/plugin/volume/cmd/cmd_test.go
@@ -0,0 +1 @@
+package cmd
diff --git a/integration/plugin/volume/cmd/create-error/main.go b/integration/plugin/volume/cmd/create-error/main.go
new file mode 100644
index 0000000..f23be51
--- /dev/null
+++ b/integration/plugin/volume/cmd/create-error/main.go
@@ -0,0 +1,32 @@
+package main
+
+import (
+	"net"
+	"net/http"
+)
+
+// main serves a stub volume plugin on /run/docker/plugins/plugin.sock whose
+// /VolumeDriver.Create endpoint always answers HTTP 500, letting tests
+// exercise how the daemon handles a volume create that fails in the driver.
+func main() {
+	l, err := net.Listen("unix", "/run/docker/plugins/plugin.sock")
+	if err != nil {
+		panic(err)
+	}
+
+	mux := http.NewServeMux()
+	// Unconditionally fail every create request.
+	mux.HandleFunc("/VolumeDriver.Create", func(w http.ResponseWriter, r *http.Request) {
+		http.Error(w, "error during create", http.StatusInternalServerError)
+	})
+
+	// Serve the mux the Create handler is registered on; a fresh
+	// http.NewServeMux() here would answer every route with 404.
+	server := http.Server{
+		Addr:    l.Addr().String(),
+		Handler: mux,
+	}
+	if err := server.Serve(l); err != nil {
+		panic(err)
+	}
+}
diff --git a/integration/plugin/volume/cmd/create-error/main_test.go b/integration/plugin/volume/cmd/create-error/main_test.go
new file mode 100644
index 0000000..06ab7d0
--- /dev/null
+++ b/integration/plugin/volume/cmd/create-error/main_test.go
@@ -0,0 +1 @@
+package main
diff --git a/integration/plugin/volume/create_test.go b/integration/plugin/volume/create_test.go
new file mode 100644
index 0000000..ce9b4dc
--- /dev/null
+++ b/integration/plugin/volume/create_test.go
@@ -0,0 +1,51 @@
+// +build linux
+
+package volume
+
+import (
+	"context"
+	"testing"
+
+	"github.com/docker/docker/api/types"
+	"github.com/docker/docker/api/types/volume"
+	"github.com/docker/docker/integration-cli/daemon"
+)
+
+// TestCreateDerefOnError checks that a failed volume create releases the
+// reference it took on the volume-driver plugin. Each volume normally holds
+// one reference to its driver plugin and a referenced plugin cannot be
+// removed, so the test proves the release by removing the plugin afterwards.
+func TestCreateDerefOnError(t *testing.T) {
+	t.Parallel()
+
+	d := daemon.New(t, "", dockerdBinary, daemon.Config{})
+	d.Start(t)
+	defer d.Stop(t)
+
+	client, err := d.NewClient()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	const pluginName = "testderef"
+	createPlugin(t, client, pluginName, "create-error", asVolumeDriver)
+
+	ctx := context.Background()
+	enableOpts := types.PluginEnableOptions{Timeout: 30}
+	if err := client.PluginEnable(ctx, pluginName, enableOpts); err != nil {
+		t.Fatal(err)
+	}
+
+	createBody := volume.VolumesCreateBody{Name: "fake", Driver: pluginName}
+	if _, err := client.VolumeCreate(ctx, createBody); err == nil {
+		t.Fatal("volume create should have failed")
+	}
+
+	// Removal below can only succeed if the failed create dropped its reference.
+	if err := client.PluginDisable(ctx, pluginName, types.PluginDisableOptions{}); err != nil {
+		t.Fatal(err)
+	}
+	if err := client.PluginRemove(ctx, pluginName, types.PluginRemoveOptions{}); err != nil {
+		t.Fatal(err)
+	}
+}
diff --git a/integration/plugin/volume/main_test.go b/integration/plugin/volume/main_test.go
new file mode 100644
index 0000000..8cfbf37
--- /dev/null
+++ b/integration/plugin/volume/main_test.go
@@ -0,0 +1,82 @@
+package volume
+
+import (
+	"context"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/docker/docker/api/types"
+	"github.com/docker/docker/integration-cli/fixtures/plugin"
+	"github.com/docker/docker/pkg/locker"
+	"github.com/pkg/errors"
+)
+
+// dockerdBinary is the name of the daemon binary passed to daemon.New.
+const dockerdBinary = "dockerd"
+
+// pluginBuildLock serializes concurrent builds of the same plugin binary.
+var pluginBuildLock = locker.New()
+
+// ensurePlugin returns the path to the binary built from ./cmd/<name>. The
+// binary is built (with cgo disabled) into $GOPATH/bin/<name> unless a file
+// already exists there; concurrent calls for the same name are serialized.
+func ensurePlugin(t *testing.T, name string) string {
+	pluginBuildLock.Lock(name)
+	defer pluginBuildLock.Unlock(name)
+
+	installPath := filepath.Join(os.Getenv("GOPATH"), "bin", name)
+	if _, err := os.Stat(installPath); err == nil {
+		return installPath
+	}
+
+	goBin, err := exec.LookPath("go")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	cmd := exec.Command(goBin, "build", "-o", installPath, "./"+filepath.Join("cmd", name))
+	// Inherit the caller's environment: appending to the nil cmd.Env would run
+	// `go build` with CGO_ENABLED as its only variable (no PATH, HOME, GOPATH).
+	cmd.Env = append(os.Environ(), "CGO_ENABLED=0")
+	if out, err := cmd.CombinedOutput(); err != nil {
+		t.Fatal(errors.Wrapf(err, "error building basic plugin bin: %s", string(out)))
+	}
+
+	return installPath
+}
+
+// asVolumeDriver sets the plugin's interface types to volumedriver (prefix "docker", version "1.0").
+func asVolumeDriver(cfg *plugin.Config) {
+	cfg.Interface.Types = []types.PluginInterfaceType{
+		{Capability: "volumedriver", Prefix: "docker", Version: "1.0"},
+	}
+}
+
+// withSockPath returns an option that sets the plugin's interface socket name.
+func withSockPath(name string) func(*plugin.Config) {
+	return func(cfg *plugin.Config) {
+		cfg.Interface.Socket = name
+	}
+}
+
+// createPlugin ensures the plugin binary for bin is built, then creates a
+// plugin named alias from it with interface socket "plugin.sock". Caller opts
+// are placed ahead of the socket and binary options; any error fails the test.
+func createPlugin(t *testing.T, client plugin.CreateClient, alias, bin string, opts ...plugin.CreateOpt) {
+	pluginBin := ensurePlugin(t, bin)
+
+	// Build a fresh slice so the appends below never write into the backing
+	// array of a slice the caller spread into opts.
+	allOpts := make([]plugin.CreateOpt, 0, len(opts)+2)
+	allOpts = append(allOpts, opts...)
+	allOpts = append(allOpts, withSockPath("plugin.sock"), plugin.WithBinary(pluginBin))
+
+	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer cancel()
+	if err := plugin.Create(ctx, client, alias, allOpts...); err != nil {
+		t.Fatal(err)
+	}
+}
diff --git a/pkg/plugingetter/getter.go b/pkg/plugingetter/getter.go
index b04b7bc..b9ffa54 100644
--- a/pkg/plugingetter/getter.go
+++ b/pkg/plugingetter/getter.go
@@ -1,6 +1,8 @@
 package plugingetter
 
-import "github.com/docker/docker/pkg/plugins"
+import (
+	"github.com/docker/docker/pkg/plugins"
+)
 
 const (
 	// Lookup doesn't update RefCount
diff --git a/plugin/store.go b/plugin/store.go
index 8349f34..b339814 100644
--- a/plugin/store.go
+++ b/plugin/store.go
@@ -115,10 +115,15 @@
 	if ps != nil {
 		p, err := ps.GetV2Plugin(name)
 		if err == nil {
-			p.AddRefCount(mode)
 			if p.IsEnabled() {
-				return p.FilterByCap(capability)
+				fp, err := p.FilterByCap(capability)
+				if err != nil {
+					return nil, err
+				}
+				p.AddRefCount(mode)
+				return fp, nil
 			}
+
 			// Plugin was found but it is disabled, so we should not fall back to legacy plugins
 			// but we should error out right away
 			return nil, errDisabled(name)
diff --git a/plugin/store_test.go b/plugin/store_test.go
index d3876da..5c61cc6 100644
--- a/plugin/store_test.go
+++ b/plugin/store_test.go
@@ -4,6 +4,7 @@
 	"testing"
 
 	"github.com/docker/docker/api/types"
+	"github.com/docker/docker/pkg/plugingetter"
 	"github.com/docker/docker/plugin/v2"
 )
 
@@ -31,3 +32,38 @@
 		t.Fatalf("expected no error, got %v", err)
 	}
 }
+
+// TestStoreGetPluginNotMatchCapRefs checks that Store.Get in Acquire mode
+// returns an error and leaves the reference count at 0 when the plugin does
+// not implement the requested capability, both while disabled and enabled.
+func TestStoreGetPluginNotMatchCapRefs(t *testing.T) {
+	s := NewStore()
+	p := v2.Plugin{PluginObj: types.Plugin{Name: "test:latest"}}
+
+	iType := types.PluginInterfaceType{Capability: "whatever", Prefix: "docker", Version: "1.0"}
+	i := types.PluginConfigInterface{Socket: "plugins.sock", Types: []types.PluginInterfaceType{iType}}
+	p.PluginObj.Config.Interface = i
+
+	if err := s.Add(&p); err != nil {
+		t.Fatal(err)
+	}
+
+	// Disabled plugin: Get must fail without taking a reference.
+	if _, err := s.Get("test", "volumedriver", plugingetter.Acquire); err == nil {
+		t.Fatal("expected error when getting plugin that doesn't match the passed in capability")
+	}
+
+	if refs := p.GetRefCount(); refs != 0 {
+		t.Fatalf("reference count should be 0, got: %d", refs)
+	}
+
+	// Enabled plugin: the capability filter fails, which must not take a reference either.
+	p.PluginObj.Enabled = true
+	if _, err := s.Get("test", "volumedriver", plugingetter.Acquire); err == nil {
+		t.Fatal("expected error when getting plugin that doesn't match the passed in capability")
+	}
+
+	if refs := p.GetRefCount(); refs != 0 {
+		t.Fatalf("reference count should be 0, got: %d", refs)
+	}
+}
diff --git a/volume/drivers/extpoint.go b/volume/drivers/extpoint.go
index ee42f2f..c360b37 100644
--- a/volume/drivers/extpoint.go
+++ b/volume/drivers/extpoint.go
@@ -11,6 +11,7 @@
 	getter "github.com/docker/docker/pkg/plugingetter"
 	"github.com/docker/docker/volume"
 	"github.com/pkg/errors"
+	"github.com/sirupsen/logrus"
 )
 
 // currently created by hand. generation tool would generate this like:
@@ -130,6 +131,12 @@
 
 		d := NewVolumeDriver(p.Name(), p.BasePath(), p.Client())
 		if err := validateDriver(d); err != nil {
+			if mode > 0 {
+				// Undo any reference count changes from the initial `Get`
+				if _, err := drivers.plugingetter.Get(name, extName, mode*-1); err != nil {
+					logrus.WithError(err).WithField("action", "validate-driver").WithField("plugin", name).Error("error releasing reference to plugin")
+				}
+			}
 			return nil, err
 		}
 
@@ -169,9 +176,9 @@
 	return lookup(name, getter.Acquire)
 }
 
-// RemoveDriver returns a volume driver by its name and decrements RefCount..
+// ReleaseDriver returns a volume driver by its name and decrements RefCount.
 // If the driver is empty, it looks for the local driver.
-func RemoveDriver(name string) (volume.Driver, error) {
+func ReleaseDriver(name string) (volume.Driver, error) {
 	if name == "" {
 		name = volume.DefaultDriverName
 	}
diff --git a/volume/store/store.go b/volume/store/store.go
index e47ec0e..5402c6b 100644
--- a/volume/store/store.go
+++ b/volume/store/store.go
@@ -145,7 +145,7 @@
 	s.globalLock.Lock()
 	v, exists := s.names[name]
 	if exists {
-		if _, err := volumedrivers.RemoveDriver(v.DriverName()); err != nil {
+		if _, err := volumedrivers.ReleaseDriver(v.DriverName()); err != nil {
 			logrus.Errorf("Error dereferencing volume driver: %v", err)
 		}
 	}
@@ -409,6 +409,9 @@
 	}
 	v, err = vd.Create(name, opts)
 	if err != nil {
+		if _, err := volumedrivers.ReleaseDriver(driverName); err != nil {
+			logrus.WithError(err).WithField("driver", driverName).Error("Error releasing reference to volume driver")
+		}
 		return nil, err
 	}
 	s.globalLock.Lock()