Merge pull request #272 from pkg/271-alt
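Fix race condition in fileget/fileput: open the reader/writer while holding the state write lock (re-checking for nil under the lock), so concurrent packets on the same request can no longer each call Fileread/Filewrite.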
diff --git a/request.go b/request.go
index d668c29..8188c77 100644
--- a/request.go
+++ b/request.go
@@ -116,34 +116,12 @@
}
// manage file read/write state
-func (r *Request) setWriterState(wa io.WriterAt) {
- r.state.Lock()
- defer r.state.Unlock()
- r.state.writerAt = wa
-}
-func (r *Request) setReaderState(ra io.ReaderAt) {
- r.state.Lock()
- defer r.state.Unlock()
- r.state.readerAt = ra
-}
func (r *Request) setListerState(la ListerAt) {
r.state.Lock()
defer r.state.Unlock()
r.state.listerAt = la
}
-func (r *Request) getWriter() io.WriterAt {
- r.state.RLock()
- defer r.state.RUnlock()
- return r.state.writerAt
-}
-
-func (r *Request) getReader() io.ReaderAt {
- r.state.RLock()
- defer r.state.RUnlock()
- return r.state.readerAt
-}
-
func (r *Request) getLister() ListerAt {
r.state.RLock()
defer r.state.RUnlock()
@@ -157,11 +135,15 @@
r.cancelCtx()
}
}()
- rd := r.getReader()
+ r.state.RLock()
+ rd := r.state.readerAt
+ r.state.RUnlock()
if c, ok := rd.(io.Closer); ok {
return c.Close()
}
- wt := r.getWriter()
+ r.state.RLock()
+ wt := r.state.writerAt
+ r.state.RUnlock()
if c, ok := wt.(io.Closer); ok {
return c.Close()
}
@@ -204,13 +186,20 @@
// wrap FileReader handler
func fileget(h FileReader, r *Request, pkt requestPacket) responsePacket {
var err error
- reader := r.getReader()
+ r.state.RLock()
+ reader := r.state.readerAt
+ r.state.RUnlock()
if reader == nil {
- reader, err = h.Fileread(r)
- if err != nil {
- return statusFromError(pkt, err)
+ r.state.Lock()
+ if r.state.readerAt == nil {
+ r.state.readerAt, err = h.Fileread(r)
+ if err != nil {
+ r.state.Unlock()
+ return statusFromError(pkt, err)
+ }
}
- r.setReaderState(reader)
+ reader = r.state.readerAt
+ r.state.Unlock()
}
_, offset, length := packetData(pkt)
@@ -230,13 +219,20 @@
// wrap FileWriter handler
func fileput(h FileWriter, r *Request, pkt requestPacket) responsePacket {
var err error
- writer := r.getWriter()
+ r.state.RLock()
+ writer := r.state.writerAt
+ r.state.RUnlock()
if writer == nil {
- writer, err = h.Filewrite(r)
- if err != nil {
- return statusFromError(pkt, err)
+ r.state.Lock()
+ if r.state.writerAt == nil {
+ r.state.writerAt, err = h.Filewrite(r)
+ if err != nil {
+ r.state.Unlock()
+ return statusFromError(pkt, err)
+ }
}
- r.setWriterState(writer)
+ writer = r.state.writerAt
+ r.state.Unlock()
}
data, offset, _ := packetData(pkt)