| package stdcopy |
| |
| import ( |
| "bytes" |
| "errors" |
| "io" |
| "io/ioutil" |
| "strings" |
| "testing" |
| ) |
| |
| func TestNewStdWriter(t *testing.T) { |
| writer := NewStdWriter(ioutil.Discard, Stdout) |
| if writer == nil { |
| t.Fatalf("NewStdWriter with an invalid StdType should not return nil.") |
| } |
| } |
| |
| func TestWriteWithUnitializedStdWriter(t *testing.T) { |
| writer := StdWriter{ |
| Writer: nil, |
| prefix: Stdout, |
| sizeBuf: make([]byte, 4), |
| } |
| n, err := writer.Write([]byte("Something here")) |
| if n != 0 || err == nil { |
| t.Fatalf("Should fail when given an uncomplete or uninitialized StdWriter") |
| } |
| } |
| |
| func TestWriteWithNilBytes(t *testing.T) { |
| writer := NewStdWriter(ioutil.Discard, Stdout) |
| n, err := writer.Write(nil) |
| if err != nil { |
| t.Fatalf("Shouldn't have fail when given no data") |
| } |
| if n > 0 { |
| t.Fatalf("Write should have written 0 byte, but has written %d", n) |
| } |
| } |
| |
| func TestWrite(t *testing.T) { |
| writer := NewStdWriter(ioutil.Discard, Stdout) |
| data := []byte("Test StdWrite.Write") |
| n, err := writer.Write(data) |
| if err != nil { |
| t.Fatalf("Error while writing with StdWrite") |
| } |
| if n != len(data) { |
| t.Fatalf("Write should have written %d byte but wrote %d.", len(data), n) |
| } |
| } |
| |
// errWriter is an io.Writer stub. It ignores its input and always answers
// with the preconfigured byte count and error.
type errWriter struct {
	n   int   // count returned from every Write call
	err error // error returned from every Write call
}

// Write discards buf and returns the canned (n, err) pair.
func (w *errWriter) Write(_ []byte) (int, error) {
	count, failure := w.n, w.err
	return count, failure
}
| |
| func TestWriteWithWriterError(t *testing.T) { |
| expectedError := errors.New("expected") |
| expectedReturnedBytes := 10 |
| writer := NewStdWriter(&errWriter{ |
| n: stdWriterPrefixLen + expectedReturnedBytes, |
| err: expectedError}, Stdout) |
| data := []byte("This won't get written, sigh") |
| n, err := writer.Write(data) |
| if err != expectedError { |
| t.Fatalf("Didn't get expected error.") |
| } |
| if n != expectedReturnedBytes { |
| t.Fatalf("Didn't get expected writen bytes %d, got %d.", |
| expectedReturnedBytes, n) |
| } |
| } |
| |
| func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) { |
| writer := NewStdWriter(&errWriter{n: -1}, Stdout) |
| data := []byte("This won't get written, sigh") |
| actual, _ := writer.Write(data) |
| if actual != 0 { |
| t.Fatalf("Expected returned written bytes equal to 0, got %d", actual) |
| } |
| } |
| |
| func getSrcBuffer(stdOutBytes, stdErrBytes []byte) (buffer *bytes.Buffer, err error) { |
| buffer = new(bytes.Buffer) |
| dstOut := NewStdWriter(buffer, Stdout) |
| _, err = dstOut.Write(stdOutBytes) |
| if err != nil { |
| return |
| } |
| dstErr := NewStdWriter(buffer, Stderr) |
| _, err = dstErr.Write(stdErrBytes) |
| return |
| } |
| |
| func TestStdCopyWriteAndRead(t *testing.T) { |
| stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) |
| stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) |
| buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes) |
| if err != nil { |
| t.Fatal(err) |
| } |
| written, err := StdCopy(ioutil.Discard, ioutil.Discard, buffer) |
| if err != nil { |
| t.Fatal(err) |
| } |
| expectedTotalWritten := len(stdOutBytes) + len(stdErrBytes) |
| if written != int64(expectedTotalWritten) { |
| t.Fatalf("Expected to have total of %d bytes written, got %d", expectedTotalWritten, written) |
| } |
| } |
| |
// customReader is an io.Reader stub. Its first correctCalls reads are served
// from src; every later call returns the canned (n, err) pair.
type customReader struct {
	n            int           // count returned once src is no longer used
	err          error         // error returned once src is no longer used
	totalCalls   int           // number of Read calls made so far
	correctCalls int           // how many leading calls are delegated to src
	src          *bytes.Buffer // backing data for the delegated calls
}

// Read counts the call. It delegates to src while the call count is still
// within correctCalls, and otherwise returns the canned result.
func (r *customReader) Read(p []byte) (int, error) {
	r.totalCalls++
	if r.totalCalls > r.correctCalls {
		return r.n, r.err
	}
	return r.src.Read(p)
}
| |
| func TestStdCopyReturnsErrorReadingHeader(t *testing.T) { |
| expectedError := errors.New("error") |
| reader := &customReader{ |
| err: expectedError} |
| written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader) |
| if written != 0 { |
| t.Fatalf("Expected 0 bytes read, got %d", written) |
| } |
| if err != expectedError { |
| t.Fatalf("Didn't get expected error") |
| } |
| } |
| |
| func TestStdCopyReturnsErrorReadingFrame(t *testing.T) { |
| expectedError := errors.New("error") |
| stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) |
| stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) |
| buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes) |
| if err != nil { |
| t.Fatal(err) |
| } |
| reader := &customReader{ |
| correctCalls: 1, |
| n: stdWriterPrefixLen + 1, |
| err: expectedError, |
| src: buffer} |
| written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader) |
| if written != 0 { |
| t.Fatalf("Expected 0 bytes read, got %d", written) |
| } |
| if err != expectedError { |
| t.Fatalf("Didn't get expected error") |
| } |
| } |
| |
| func TestStdCopyDetectsCorruptedFrame(t *testing.T) { |
| stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) |
| stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) |
| buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes) |
| if err != nil { |
| t.Fatal(err) |
| } |
| reader := &customReader{ |
| correctCalls: 1, |
| n: stdWriterPrefixLen + 1, |
| err: io.EOF, |
| src: buffer} |
| written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader) |
| if written != startingBufLen { |
| t.Fatalf("Expected 0 bytes read, got %d", written) |
| } |
| if err != nil { |
| t.Fatal("Didn't get nil error") |
| } |
| } |
| |
| func TestStdCopyWithInvalidInputHeader(t *testing.T) { |
| dstOut := NewStdWriter(ioutil.Discard, Stdout) |
| dstErr := NewStdWriter(ioutil.Discard, Stderr) |
| src := strings.NewReader("Invalid input") |
| _, err := StdCopy(dstOut, dstErr, src) |
| if err == nil { |
| t.Fatal("StdCopy with invalid input header should fail.") |
| } |
| } |
| |
| func TestStdCopyWithCorruptedPrefix(t *testing.T) { |
| data := []byte{0x01, 0x02, 0x03} |
| src := bytes.NewReader(data) |
| written, err := StdCopy(nil, nil, src) |
| if err != nil { |
| t.Fatalf("StdCopy should not return an error with corrupted prefix.") |
| } |
| if written != 0 { |
| t.Fatalf("StdCopy should have written 0, but has written %d", written) |
| } |
| } |
| |
| func TestStdCopyReturnsWriteErrors(t *testing.T) { |
| stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) |
| stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) |
| buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes) |
| if err != nil { |
| t.Fatal(err) |
| } |
| expectedError := errors.New("expected") |
| |
| dstOut := &errWriter{err: expectedError} |
| |
| written, err := StdCopy(dstOut, ioutil.Discard, buffer) |
| if written != 0 { |
| t.Fatalf("StdCopy should have written 0, but has written %d", written) |
| } |
| if err != expectedError { |
| t.Fatalf("Didn't get expected error, got %v", err) |
| } |
| } |
| |
| func TestStdCopyDetectsNotFullyWrittenFrames(t *testing.T) { |
| stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) |
| stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) |
| buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes) |
| if err != nil { |
| t.Fatal(err) |
| } |
| dstOut := &errWriter{n: startingBufLen - 10} |
| |
| written, err := StdCopy(dstOut, ioutil.Discard, buffer) |
| if written != 0 { |
| t.Fatalf("StdCopy should have return 0 written bytes, but returned %d", written) |
| } |
| if err != io.ErrShortWrite { |
| t.Fatalf("Didn't get expected io.ErrShortWrite error") |
| } |
| } |
| |
| func BenchmarkWrite(b *testing.B) { |
| w := NewStdWriter(ioutil.Discard, Stdout) |
| data := []byte("Test line for testing stdwriter performance\n") |
| data = bytes.Repeat(data, 100) |
| b.SetBytes(int64(len(data))) |
| b.ResetTimer() |
| for i := 0; i < b.N; i++ { |
| if _, err := w.Write(data); err != nil { |
| b.Fatal(err) |
| } |
| } |
| } |