Merge pull request #155 from erickt/atomic

Atomically write metadata to file system
diff --git a/src/repository.rs b/src/repository.rs
index 2878ef0..8c851f7 100644
--- a/src/repository.rs
+++ b/src/repository.rs
@@ -5,10 +5,10 @@
 use hyper::status::StatusCode;
 use hyper::{Client, Url};
 use std::collections::HashMap;
-use std::fs::{self, DirBuilder, File};
-use std::io::{Cursor, Read, Write};
+use std::fs::{DirBuilder, File};
+use std::io::{self, Cursor, Read, Write};
 use std::marker::PhantomData;
-use std::path::PathBuf;
+use std::path::{Path, PathBuf};
 use std::sync::{Arc, RwLock};
 use tempfile::NamedTempFile;
 
@@ -129,25 +129,18 @@
         M: Metadata,
     {
         Self::check::<M>(meta_path)?;
-        let components = meta_path.components::<D>(version);
 
         let mut path = self.local_path.join("metadata");
-        path.extend(&components);
+        path.extend(meta_path.components::<D>(version));
 
         if path.exists() {
-            debug!("Metadata path exists. Deleting: {:?}", path);
-            fs::remove_file(&path)?
+            debug!("Metadata path exists. Overwriting: {:?}", path);
         }
 
-        if components.len() > 1 {
-            let mut path = self.local_path.clone();
-            path.extend(&components[..(components.len() - 1)]);
-            DirBuilder::new().recursive(true).create(path)?;
-        }
-
-        let mut file = File::create(&path)?;
-        D::to_writer(&mut file, metadata)?;
-        Ok(())
+        atomically_write(&path, |write| {
+            D::to_writer(write, metadata)?;
+            Ok(())
+        })
     }
 
     /// Fetch signed metadata.
@@ -181,29 +174,17 @@
     where
         R: Read,
     {
-        let mut temp_file = NamedTempFile::new_in(self.local_path.join("temp"))?;
-        let mut buf = [0; 1024];
-        loop {
-            let bytes_read = read.read(&mut buf)?;
-            if bytes_read == 0 {
-                break;
-            }
-            temp_file.write_all(&buf[..bytes_read])?
+        let mut path = self.local_path.join("targets");
+        path.extend(target_path.components());
+
+        if path.exists() {
+            debug!("Target path exists. Overwriting: {:?}", path);
         }
 
-        let mut path = self.local_path.clone().join("targets");
-        let components = target_path.components();
-
-        if components.len() > 1 {
-            let mut path = path.clone();
-            path.extend(&components[..(components.len() - 1)]);
-            DirBuilder::new().recursive(true).create(path)?;
-        }
-
-        path.extend(components);
-        temp_file.persist(&path)?;
-
-        Ok(())
+        atomically_write(&path, |write| {
+            io::copy(&mut read, write)?;
+            Ok(())
+        })
     }
 
     fn fetch_target(
@@ -230,6 +211,28 @@
     }
 }
 
+fn atomically_write<F>(path: &Path, mut f: F) -> Result<()>
+where
+    F: FnMut(&mut Write) -> Result<()>,
+{
+    // Goal: clients must never see a partially written file at `path`. `f` writes into a temporary file,
+    // which is then persisted over `path`. The temporary file is created in the target's own directory
+    // (created first if missing): a temp file on another mountpoint would need a non-atomic copy instead.
+
+    let mut temp_file = if let Some(parent) = path.parent() {
+        DirBuilder::new().recursive(true).create(parent)?;
+        NamedTempFile::new_in(parent)?
+    } else {
+        NamedTempFile::new_in(".")? // `path.parent()` is `None`: fall back to the current directory
+    };
+
+    f(&mut temp_file)?;
+
+    temp_file.persist(&path)?;
+
+    Ok(())
+}
+
 /// A repository accessible over HTTP.
 pub struct HttpRepository<D>
 where