blob: 67ff5fc1d7c40256c91755a34e9271514587ddf3 [file] [log] [blame]
/*
*
* Copyright 2019 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/*
Package flags provide convenience types and routines to accept specific types
of flag values on the command line.
*/
package flags
import (
"bytes"
"encoding/csv"
"flag"
"fmt"
"strconv"
"strings"
"time"
)
// stringFlagWithAllowedValues represents a string flag which can only take a
// predefined set of values.
type stringFlagWithAllowedValues struct {
val string
allowed []string
}
// StringWithAllowedValues returns a flag variable of type
// stringFlagWithAllowedValues configured with the provided parameters.
// 'allowed` is the set of values that this flag can be set to.
func StringWithAllowedValues(name, defaultVal, usage string, allowed []string) *string {
as := &stringFlagWithAllowedValues{defaultVal, allowed}
flag.CommandLine.Var(as, name, usage)
return &as.val
}
// String implements the flag.Value interface.
func (as *stringFlagWithAllowedValues) String() string {
return as.val
}
// Set implements the flag.Value interface.
func (as *stringFlagWithAllowedValues) Set(val string) error {
for _, a := range as.allowed {
if a == val {
as.val = val
return nil
}
}
return fmt.Errorf("want one of: %v", strings.Join(as.allowed, ", "))
}
type durationSliceValue []time.Duration
// DurationSlice returns a flag representing a slice of time.Duration objects.
func DurationSlice(name string, defaultVal []time.Duration, usage string) *[]time.Duration {
ds := make([]time.Duration, len(defaultVal))
copy(ds, defaultVal)
dsv := (*durationSliceValue)(&ds)
flag.CommandLine.Var(dsv, name, usage)
return &ds
}
// Set implements the flag.Value interface.
func (dsv *durationSliceValue) Set(s string) error {
ds := strings.Split(s, ",")
var dd []time.Duration
for _, n := range ds {
d, err := time.ParseDuration(n)
if err != nil {
return err
}
dd = append(dd, d)
}
*dsv = durationSliceValue(dd)
return nil
}
// String implements the flag.Value interface.
func (dsv *durationSliceValue) String() string {
var b bytes.Buffer
for i, d := range *dsv {
if i > 0 {
b.WriteRune(',')
}
b.WriteString(d.String())
}
return b.String()
}
type intSliceValue []int
// IntSlice returns a flag representing a slice of ints.
func IntSlice(name string, defaultVal []int, usage string) *[]int {
is := make([]int, len(defaultVal))
copy(is, defaultVal)
isv := (*intSliceValue)(&is)
flag.CommandLine.Var(isv, name, usage)
return &is
}
// Set implements the flag.Value interface.
func (isv *intSliceValue) Set(s string) error {
is := strings.Split(s, ",")
var ret []int
for _, n := range is {
i, err := strconv.Atoi(n)
if err != nil {
return err
}
ret = append(ret, i)
}
*isv = intSliceValue(ret)
return nil
}
// String implements the flag.Value interface.
func (isv *intSliceValue) String() string {
var b bytes.Buffer
for i, n := range *isv {
if i > 0 {
b.WriteRune(',')
}
b.WriteString(strconv.Itoa(n))
}
return b.String()
}
type stringSliceValue []string
// StringSlice returns a flag representing a slice of strings.
func StringSlice(name string, defaultVal []string, usage string) *[]string {
ss := make([]string, len(defaultVal))
copy(ss, defaultVal)
ssv := (*stringSliceValue)(&ss)
flag.CommandLine.Var(ssv, name, usage)
return &ss
}
// escapedCommaSplit splits a comma-separated list of strings in the same way
// CSV files work (escaping a comma requires double-quotes).
func escapedCommaSplit(str string) ([]string, error) {
r := csv.NewReader(strings.NewReader(str))
ret, err := r.Read()
if err != nil {
return nil, err
}
return ret, nil
}
// Set implements the flag.Value interface.
func (ss *stringSliceValue) Set(str string) error {
var err error
*ss, err = escapedCommaSplit(str)
if err != nil {
return err
}
return nil
}
// String implements the flag.Value interface.
func (ss *stringSliceValue) String() string {
return strings.Join(*ss, ",")
}