blob: 18ca44d005f0b28b7a1b0cf23b77de51aff83e4e [file] [log] [blame]
// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package bootserver
import (
"bytes"
"compress/flate"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"path/filepath"
"testing"
constants "go.fuchsia.dev/fuchsia/tools/bootserver/bootserverconstants"
"go.fuchsia.dev/fuchsia/tools/lib/osmisc"
)
func TestDownloadImagesToDir(t *testing.T) {
tmpDir := t.TempDir()
var imgs []Image
numImages := 4
for i := 0; i < numImages; i++ {
imgs = append(imgs, Image{
Name: fmt.Sprintf("image%d", i),
Reader: bytes.NewReader([]byte(fmt.Sprintf("content of image%d", i))),
Args: []string{"--arg"},
})
}
// Add another image without Args. This image should not be downloaded.
imgs = append(imgs, Image{
Name: "noArgsImage",
Reader: bytes.NewReader([]byte("content of noArgsImage")),
})
newImgs, closeFunc, err := downloadImagesToDir(context.Background(), tmpDir, imgs)
if err != nil {
t.Fatalf("failed to download image: %v", err)
}
defer closeFunc()
if len(newImgs) != numImages {
t.Errorf("unexpected number of images downloaded; expected: %d, actual: %d", numImages, len(newImgs))
}
for _, img := range newImgs {
if img.Name == "noArgsImage" {
t.Errorf("downloaded an image with no args")
}
content, err := ioutil.ReadFile(filepath.Join(tmpDir, img.Name))
if err != nil {
t.Fatalf("failed to read file: %v", err)
}
expectedData := fmt.Sprintf("content of %s", img.Name)
if string(content) != expectedData {
t.Errorf("unexpected content: expected: %s, actual: %s", expectedData, content)
}
if int(img.Size) != len(content) {
t.Errorf("incorrect size: expected: %d, actual: %d", img.Size, len(content))
}
}
}
// A mock tftp.Client that supports setting expectations on Write() calls.
type mockTftpClient struct {
// Testing class to mark any failures.
t *testing.T
// Expected files that should be written (in order), with the error to
// return from the corresponding call to Write().
expectedWrites []expectedWrite
// Expected output from Read().
expectedReads []string
}
type expectedWrite struct {
// FTP filename to expect.
filename string
// Error to return from Write().
ret error
}
func (c *mockTftpClient) Read(_ context.Context, _ string) (*bytes.Reader, error) {
expected := c.expectedReads[0]
c.expectedReads = c.expectedReads[1:]
return bytes.NewReader([]byte(expected)), nil
}
func (c *mockTftpClient) RemoteAddr() *net.UDPAddr {
c.t.Fatal("Unexpected call to mockTftpClient.RemoteAddr()")
panic("notreached")
}
func (c *mockTftpClient) Write(_ context.Context, filename string, _ io.ReaderAt, _ int64) error {
if len(c.expectedWrites) == 0 {
c.t.Fatalf("No writes expected but got %q", filename)
panic("notreached")
}
expected := c.expectedWrites[0]
if expected.filename != filename {
c.t.Fatalf("Expected %q but got %q", expected.filename, filename)
panic("notreached")
}
c.expectedWrites = c.expectedWrites[1:]
return expected.ret
}
// Creates a test Image with the given bootserver arg and small contents.
func testImage(arg string) Image {
return Image{
Reader: bytes.NewReader([]byte(fmt.Sprintf("image contents for %q arg", arg))),
Args: []string{arg},
}
}
// Creates Images based on imageArgs, then calls transferImages() and verifies
// that the expected writes were called.
func validateTransferImages(t *testing.T, imageArgs []string, expectedWrites []expectedWrite) {
// Convert the bootserver args to test Images, contents don't matter.
images := []Image{}
for _, arg := range imageArgs {
images = append(images, testImage(arg))
}
client := mockTftpClient{t, expectedWrites, []string{}}
_, err := transferImages(context.Background(), &client, images, nil, nil)
if err != nil {
t.Errorf("transferImages() failed: %v", err)
}
// Make sure all the expected writes were consumed.
if len(client.expectedWrites) > 0 {
t.Errorf("Expected writes were never made: %+v\n", client.expectedWrites)
}
}
func TestTransferImagesZirconA(t *testing.T) {
validateTransferImages(
t,
[]string{"--zircona"},
[]expectedWrite{{filename: "<<image>>zircona.img"}},
)
}
func TestTransferImagesUntypedFirmware(t *testing.T) {
validateTransferImages(
t,
[]string{"--firmware"},
[]expectedWrite{{filename: "<<image>>firmware_"}},
)
}
func TestTransferImagesUntypedFirmwareTrailingDash(t *testing.T) {
validateTransferImages(
t,
[]string{"--firmware-"},
[]expectedWrite{{filename: "<<image>>firmware_"}},
)
}
func TestTransferImagesTypedFirmware(t *testing.T) {
validateTransferImages(
t,
[]string{"--firmware-foo"},
[]expectedWrite{{filename: "<<image>>firmware_foo"}},
)
}
func TestTransferImagesOrdering(t *testing.T) {
validateTransferImages(
t,
[]string{
"--vbmetab",
"--zircona",
"--firmware-foo",
"--vbmetaa",
"--firmware",
"--zirconb",
},
[]expectedWrite{
{filename: "<<image>>firmware_"},
{filename: "<<image>>firmware_foo"},
{filename: "<<image>>zircona.img"},
{filename: "<<image>>zirconb.img"},
{filename: "<<image>>vbmetaa.img"},
{filename: "<<image>>vbmetab.img"},
},
)
}
func TestTransferImagesSkipFirmwareFailure(t *testing.T) {
// Transfer should skip a failed firmware write and continue to send
// the remaining images.
validateTransferImages(
t,
[]string{
"--firmware",
"--zircona",
"--zirconb",
},
[]expectedWrite{
{filename: "<<image>>firmware_", ret: errors.New("expected failure")},
{filename: "<<image>>zircona.img"},
{filename: "<<image>>zirconb.img"},
},
)
}
func TestValidateBoard(t *testing.T) {
for _, test := range []struct {
name string
board string
expectedRead string
wantErr bool
}{
{
name: "board is valid",
board: "x64",
expectedRead: "x64",
wantErr: false,
},
{
name: "null-terminated board is valid",
board: "x64",
expectedRead: "x64\x00",
wantErr: false,
},
{
name: "board is invalid",
board: "x64",
expectedRead: "arm64",
wantErr: true,
},
} {
t.Run(test.name, func(t *testing.T) {
client := &mockTftpClient{t: t, expectedReads: []string{test.expectedRead}}
err := ValidateBoard(context.Background(), client, test.board)
if test.wantErr != (err != nil) {
t.Errorf("failed to validate board; want err: %v, err: %v", test.wantErr, err)
}
})
}
}
func TestDownloadWithRetries(t *testing.T) {
// Temporarily override the global variable to avoid sleeping during tests.
originalSleep := downloadRetrySleep
downloadRetrySleep = 0
defer func() {
downloadRetrySleep = originalSleep
}()
tests := []struct {
name string
wantErr bool
// errFunc is a function that determines the fake error the download
// function should return, based on the index of the attempt.
errFunc func(attempt int) error
wantAttempts int
}{
{
name: "succeeds",
errFunc: func(_ int) error {
return nil
},
wantAttempts: 1,
},
{
name: "exits immediately after non-transient failure",
errFunc: func(_ int) error {
return errors.New("failure")
},
wantAttempts: 1,
wantErr: true,
},
{
name: "retries after flate error",
errFunc: func(_ int) error {
return flate.CorruptInputError(123)
},
wantAttempts: maxDownloadAttempts,
wantErr: true,
},
{
name: "passes on retry if flate error goes away",
errFunc: func(attempt int) error {
if attempt == 0 {
return flate.CorruptInputError(123)
}
return nil
},
wantAttempts: 2,
},
{
name: "passes on retry if CRC error goes away",
errFunc: func(attempt int) error {
if attempt == 0 {
return fmt.Errorf("download failed: %s", constants.BadCRCErrorMsg)
}
return nil
},
wantAttempts: 2,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
dest := filepath.Join(t.TempDir(), "foo.txt")
var attempts int
download := func() error {
defer func() {
attempts++
}()
createFile(t, dest)
return test.errFunc(attempts)
}
if err := DownloadWithRetries(context.Background(), dest, download); (err != nil) != test.wantErr {
t.Errorf("DownloadWithRetries() error = %q, wantErr %t", err, test.wantErr)
}
if test.wantAttempts != attempts {
t.Errorf("Wrong number of download attempts: wanted %d, got %d", test.wantAttempts, attempts)
}
exists, err := osmisc.FileExists(dest)
if err != nil {
t.Fatal(err)
}
if test.wantErr && exists {
t.Errorf("DownloadWithRetries() should delete the file after a failure")
} else if !test.wantErr && !exists {
t.Errorf("DownloadWithRetries() did not create the file")
}
})
}
}
func createFile(t *testing.T, path string) {
t.Helper()
f, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
if err := f.Close(); err != nil {
t.Fatal(err)
}
}