bytes: add support for base64 encoded flags (#177)

Signed-off-by: Gorka Lerchundi Osa <glertxundi@gmail.com>
diff --git a/bytes.go b/bytes.go
index 12c58db..67d5304 100644
--- a/bytes.go
+++ b/bytes.go
@@ -1,6 +1,7 @@
 package pflag
 
 import (
+	"encoding/base64"
 	"encoding/hex"
 	"fmt"
 	"strings"
@@ -9,10 +10,12 @@
 // BytesHex adapts []byte for use as a flag. Value of flag is HEX encoded
 type bytesHexValue []byte
 
+// String implements pflag.Value.String, formatting the bytes as upper-case hex.
 func (bytesHex bytesHexValue) String() string {
 	return fmt.Sprintf("%X", []byte(bytesHex))
 }
 
+// Set implements pflag.Value.Set.
 func (bytesHex *bytesHexValue) Set(value string) error {
 	bin, err := hex.DecodeString(strings.TrimSpace(value))
 
@@ -25,6 +28,7 @@
 	return nil
 }
 
+// Type implements pflag.Value.Type and reports the type name "bytesHex".
 func (*bytesHexValue) Type() string {
 	return "bytesHex"
 }
@@ -103,3 +107,104 @@
 func BytesHexP(name, shorthand string, value []byte, usage string) *[]byte {
 	return CommandLine.BytesHexP(name, shorthand, value, usage)
 }
+
+// bytesBase64Value adapts []byte for use as a flag. The value of the flag is
+// encoded with standard, padded Base64 (base64.StdEncoding).
+type bytesBase64Value []byte
+
+// String implements pflag.Value.String by returning the standard Base64
+// encoding of the stored bytes.
+func (bytesBase64 bytesBase64Value) String() string {
+	return base64.StdEncoding.EncodeToString([]byte(bytesBase64))
+}
+
+// Set implements pflag.Value.Set. Surrounding whitespace is trimmed before
+// decoding; if decoding fails the current value is left unchanged.
+func (bytesBase64 *bytesBase64Value) Set(value string) error {
+	bin, err := base64.StdEncoding.DecodeString(strings.TrimSpace(value))
+	if err != nil {
+		return err
+	}
+
+	*bytesBase64 = bin
+	return nil
+}
+
+// Type implements pflag.Value.Type.
+func (*bytesBase64Value) Type() string {
+	return "bytesBase64"
+}
+
+// newBytesBase64Value stores val (not a copy) in *p and returns p viewed as a
+// *bytesBase64Value, so that later Set calls write through to *p.
+func newBytesBase64Value(val []byte, p *[]byte) *bytesBase64Value {
+	*p = val
+	return (*bytesBase64Value)(p)
+}
+
+// bytesBase64ValueConv decodes the string form of a bytesBase64 flag back
+// into []byte. GetBytesBase64 passes it to getFlagType as the converter.
+func bytesBase64ValueConv(sval string) (interface{}, error) {
+	bin, err := base64.StdEncoding.DecodeString(sval)
+	if err != nil {
+		return nil, fmt.Errorf("invalid string being converted to Bytes: %s %s", sval, err)
+	}
+	return bin, nil
+}
+
+// GetBytesBase64 returns the []byte value of a flag with the given name.
+func (f *FlagSet) GetBytesBase64(name string) ([]byte, error) {
+	val, err := f.getFlagType(name, "bytesBase64", bytesBase64ValueConv)
+	if err != nil {
+		return []byte{}, err
+	}
+	return val.([]byte), nil
+}
+
+// BytesBase64Var defines a []byte flag with specified name, default value, and usage string.
+// The argument p points to a []byte variable in which to store the value of the flag.
+func (f *FlagSet) BytesBase64Var(p *[]byte, name string, value []byte, usage string) {
+	f.VarP(newBytesBase64Value(value, p), name, "", usage)
+}
+
+// BytesBase64VarP is like BytesBase64Var, but accepts a shorthand letter that can be used after a single dash.
+func (f *FlagSet) BytesBase64VarP(p *[]byte, name, shorthand string, value []byte, usage string) {
+	f.VarP(newBytesBase64Value(value, p), name, shorthand, usage)
+}
+
+// BytesBase64Var defines a []byte flag with specified name, default value, and usage string.
+// The argument p points to a []byte variable in which to store the value of the flag.
+func BytesBase64Var(p *[]byte, name string, value []byte, usage string) {
+	CommandLine.VarP(newBytesBase64Value(value, p), name, "", usage)
+}
+
+// BytesBase64VarP is like BytesBase64Var, but accepts a shorthand letter that can be used after a single dash.
+func BytesBase64VarP(p *[]byte, name, shorthand string, value []byte, usage string) {
+	CommandLine.VarP(newBytesBase64Value(value, p), name, shorthand, usage)
+}
+
+// BytesBase64 defines a []byte flag with specified name, default value, and usage string.
+// The return value is the address of a []byte variable that stores the value of the flag.
+func (f *FlagSet) BytesBase64(name string, value []byte, usage string) *[]byte {
+	p := new([]byte)
+	f.BytesBase64VarP(p, name, "", value, usage)
+	return p
+}
+
+// BytesBase64P is like BytesBase64, but accepts a shorthand letter that can be used after a single dash.
+func (f *FlagSet) BytesBase64P(name, shorthand string, value []byte, usage string) *[]byte {
+	p := new([]byte)
+	f.BytesBase64VarP(p, name, shorthand, value, usage)
+	return p
+}
+
+// BytesBase64 defines a []byte flag with specified name, default value, and usage string.
+// The return value is the address of a []byte variable that stores the value of the flag.
+func BytesBase64(name string, value []byte, usage string) *[]byte {
+	return CommandLine.BytesBase64P(name, "", value, usage)
+}
+
+// BytesBase64P is like BytesBase64, but accepts a shorthand letter that can be used after a single dash.
+func BytesBase64P(name, shorthand string, value []byte, usage string) *[]byte {
+	return CommandLine.BytesBase64P(name, shorthand, value, usage)
+}
diff --git a/bytes_test.go b/bytes_test.go
index cc4a769..5251f34 100644
--- a/bytes_test.go
+++ b/bytes_test.go
@@ -1,6 +1,7 @@
 package pflag
 
 import (
+	"encoding/base64"
 	"fmt"
 	"os"
 	"testing"
@@ -61,7 +62,7 @@
 			} else if tc.success {
 				bytesHex, err := f.GetBytesHex("bytes")
 				if err != nil {
-					t.Errorf("Got error trying to fetch the IP flag: %v", err)
+					t.Errorf("Got error trying to fetch the 'bytes' flag: %v", err)
 				}
 				if fmt.Sprintf("%X", bytesHex) != tc.expected {
 					t.Errorf("expected %q, got '%X'", tc.expected, bytesHex)
@@ -70,3 +71,74 @@
 		}
 	}
 }
+
+func setUpBytesBase64(bytesBase64 *[]byte) *FlagSet {
+	f := NewFlagSet("test", ContinueOnError)
+	f.BytesBase64Var(bytesBase64, "bytes", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in Base64")
+	f.BytesBase64VarP(bytesBase64, "bytes2", "B", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in Base64")
+	return f
+}
+
+func TestBytesBase64(t *testing.T) {
+	testCases := []struct {
+		input    string
+		success  bool
+		expected string
+	}{
+		// Positive cases
+		{"", true, ""}, // empty input decodes to no bytes
+		{"AQ==", true, "AQ=="},
+		{"AQID", true, "AQID"},
+
+		// Negative cases
+		{"AQ", false, ""}, // Padding removed
+		{"ï", false, ""},  // non-base64 characters
+	}
+
+	// Discard anything written to stderr during parsing (presumably the
+	// error/usage text on failures) and restore the real stderr afterwards.
+	devnull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
+	if err != nil {
+		t.Fatalf("opening %s: %v", os.DevNull, err)
+	}
+	defer devnull.Close()
+	stderr := os.Stderr
+	os.Stderr = devnull
+	defer func() { os.Stderr = stderr }()
+
+	for i := range testCases {
+		var bytesBase64 []byte
+		f := setUpBytesBase64(&bytesBase64)
+
+		tc := &testCases[i]
+
+		// Exercise --bytes, the -B shorthand and --bytes2.
+		args := []string{
+			fmt.Sprintf("--bytes=%s", tc.input),
+			fmt.Sprintf("-B  %s", tc.input),
+			fmt.Sprintf("--bytes2=%s", tc.input),
+		}
+
+		for _, arg := range args {
+			err := f.Parse([]string{arg})
+
+			switch {
+			case err != nil && tc.success:
+				t.Errorf("expected success for %q, got %v", arg, err)
+			case err == nil && !tc.success:
+				t.Errorf("expected failure while processing %q", arg)
+			case tc.success:
+				// Both flags write to the same variable, so reading the
+				// "bytes" flag observes whichever form was just parsed.
+				value, err := f.GetBytesBase64("bytes")
+				if err != nil {
+					t.Errorf("Got error trying to fetch the 'bytes' flag: %v", err)
+					continue
+				}
+				if got := base64.StdEncoding.EncodeToString(value); got != tc.expected {
+					t.Errorf("expected %q, got %q", tc.expected, got)
+				}
+			}
+		}
+	}
+}