Fix a race in temp file creation. (#565)
The added test fails before the fix and passes after.
Fixes #500.
diff --git a/internal/driver/tempfile.go b/internal/driver/tempfile.go
index 28679f1..b6c8776 100644
--- a/internal/driver/tempfile.go
+++ b/internal/driver/tempfile.go
@@ -24,9 +24,11 @@
// newTempFile returns a new output file in dir with the provided prefix and suffix.
func newTempFile(dir, prefix, suffix string) (*os.File, error) {
for index := 1; index < 10000; index++ {
- path := filepath.Join(dir, fmt.Sprintf("%s%03d%s", prefix, index, suffix))
- if _, err := os.Stat(path); err != nil {
- return os.Create(path)
+ switch f, err := os.OpenFile(filepath.Join(dir, fmt.Sprintf("%s%03d%s", prefix, index, suffix)), os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666); {
+ case err == nil:
+ return f, nil
+ case !os.IsExist(err):
+ return nil, err
}
}
// Give up
@@ -44,11 +46,15 @@
}
// cleanupTempFiles removes any temporary files selected for deferred cleaning.
-func cleanupTempFiles() {
+func cleanupTempFiles() error {
tempFilesMu.Lock()
+ defer tempFilesMu.Unlock()
+ var lastErr error
for _, f := range tempFiles {
- os.Remove(f)
+ if err := os.Remove(f); err != nil {
+ lastErr = err
+ }
}
tempFiles = nil
- tempFilesMu.Unlock()
+ return lastErr
}
diff --git a/internal/driver/tempfile_test.go b/internal/driver/tempfile_test.go
new file mode 100644
index 0000000..7004353
--- /dev/null
+++ b/internal/driver/tempfile_test.go
@@ -0,0 +1,55 @@
+package driver
+
+import (
+ "os"
+ "sync"
+ "testing"
+)
+
+func TestNewTempFile(t *testing.T) {
+ const n = 100
+ // Line up ready to execute goroutines with a read-write lock.
+ var mu sync.RWMutex
+ mu.Lock()
+ var wg sync.WaitGroup
+ errc := make(chan error, n)
+ for i := 0; i < n; i++ {
+ wg.Add(1)
+ go func() {
+ mu.RLock()
+ defer mu.RUnlock()
+ defer wg.Done()
+ f, err := newTempFile(os.TempDir(), "profile", ".tmp")
+ errc <- err
+ deferDeleteTempFile(f.Name())
+ f.Close()
+ }()
+ }
+ // Start the file creation race.
+ mu.Unlock()
+ // Wait for the goroutines to finish.
+ wg.Wait()
+
+ for i := 0; i < n; i++ {
+ if err := <-errc; err != nil {
+ t.Fatalf("newTempFile(): got %v, want no error", err)
+ }
+ }
+ if len(tempFiles) != n {
+ t.Errorf("len(tempFiles): got %d, want %d", len(tempFiles), n)
+ }
+ names := map[string]bool{}
+ for _, name := range tempFiles {
+ if names[name] {
+ t.Errorf("got temp file %s created multiple times", name)
+ break
+ }
+ names[name] = true
+ }
+ if err := cleanupTempFiles(); err != nil {
+ t.Errorf("cleanupTempFiles(): got error %v, want no error", err)
+ }
+ if len(tempFiles) != 0 {
+ t.Errorf("len(tempFiles) after the cleanup: got %d, want 0", len(tempFiles))
+ }
+}