blob: 9354b7367119c079e35969203a82f8685088bb04 [file] [log] [blame]
/*
*
* Copyright 2019 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/*
Package flags provides convenience types and routines to accept specific types
of flag values on the command line.
*/
package flags
import (
"bytes"
"flag"
"fmt"
"strconv"
"strings"
"time"
)
// stringFlagWithAllowedValues is a flag.Value holding a string which can only
// be set to one of a predefined set of values.
type stringFlagWithAllowedValues struct {
	val     string   // current value of the flag
	allowed []string // values accepted by Set
}

// StringWithAllowedValues registers a string flag called name on
// flag.CommandLine and returns a pointer to the variable that stores its
// value. The value starts out as defaultVal, and usage is the help text
// registered for the flag.
//
// `allowed` is the set of values that this flag can be set to. The slice is
// copied, so later modifications to the caller's slice do not change what the
// flag accepts. defaultVal itself is not checked against `allowed`.
func StringWithAllowedValues(name, defaultVal, usage string, allowed []string) *string {
	as := &stringFlagWithAllowedValues{
		val: defaultVal,
		// Detach from the caller's backing array.
		allowed: append([]string(nil), allowed...),
	}
	flag.CommandLine.Var(as, name, usage)
	return &as.val
}

// String implements the flag.Value interface. It returns the current value of
// the flag, or the empty string when called on a nil receiver: the flag
// package documents that String may be invoked on a zero-valued receiver such
// as a nil pointer.
func (as *stringFlagWithAllowedValues) String() string {
	if as == nil {
		return ""
	}
	return as.val
}

// Set implements the flag.Value interface. It stores val if it exactly matches
// one of the allowed values; otherwise it leaves the current value untouched
// and returns an error listing the allowed values.
func (as *stringFlagWithAllowedValues) Set(val string) error {
	for _, a := range as.allowed {
		if a == val {
			as.val = val
			return nil
		}
	}
	return fmt.Errorf("want one of: %v", strings.Join(as.allowed, ", "))
}
// durationSliceValue is a flag.Value that holds a slice of time.Duration
// values, written on the command line as a comma-separated list.
type durationSliceValue []time.Duration

// DurationSlice registers a flag called name on flag.CommandLine and returns a
// pointer to the slice of time.Duration objects holding its value. The slice
// starts out as a copy of defaultVal, so the caller's slice is never modified
// by the flag. usage is the help text registered for the flag.
func DurationSlice(name string, defaultVal []time.Duration, usage string) *[]time.Duration {
	ds := make([]time.Duration, len(defaultVal))
	copy(ds, defaultVal)
	// dsv and the returned pointer refer to the same slice variable, so values
	// stored by Set are visible through the returned pointer.
	dsv := (*durationSliceValue)(&ds)
	flag.CommandLine.Var(dsv, name, usage)
	return &ds
}

// Set implements the flag.Value interface. It parses s as a comma-separated
// list of durations in the format accepted by time.ParseDuration and replaces
// the current value with the result. Whitespace around each element is
// ignored, so "1s, 2s" is accepted. If any element fails to parse, the error
// from time.ParseDuration is returned and the current value is left untouched.
func (dsv *durationSliceValue) Set(s string) error {
	parts := strings.Split(s, ",")
	parsed := make([]time.Duration, 0, len(parts))
	for _, p := range parts {
		d, err := time.ParseDuration(strings.TrimSpace(p))
		if err != nil {
			return err
		}
		parsed = append(parsed, d)
	}
	// Assign only after every element has parsed successfully.
	*dsv = durationSliceValue(parsed)
	return nil
}

// String implements the flag.Value interface. It returns the durations joined
// by commas, or the empty string when called on a nil receiver: the flag
// package documents that String may be invoked on a zero-valued receiver such
// as a nil pointer.
func (dsv *durationSliceValue) String() string {
	if dsv == nil {
		return ""
	}
	var b bytes.Buffer
	for i, d := range *dsv {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(d.String())
	}
	return b.String()
}
// intSliceValue is a flag.Value that holds a slice of ints, written on the
// command line as a comma-separated list.
type intSliceValue []int

// IntSlice registers a flag called name on flag.CommandLine and returns a
// pointer to the slice of ints holding its value. The slice starts out as a
// copy of defaultVal, so the caller's slice is never modified by the flag.
// usage is the help text registered for the flag.
func IntSlice(name string, defaultVal []int, usage string) *[]int {
	is := make([]int, len(defaultVal))
	copy(is, defaultVal)
	// isv and the returned pointer refer to the same slice variable, so values
	// stored by Set are visible through the returned pointer.
	isv := (*intSliceValue)(&is)
	flag.CommandLine.Var(isv, name, usage)
	return &is
}

// Set implements the flag.Value interface. It parses s as a comma-separated
// list of base-10 integers and replaces the current value with the result.
// Whitespace around each element is ignored, so "1, 2" is accepted. If any
// element fails to parse, the error from strconv.Atoi is returned and the
// current value is left untouched.
func (isv *intSliceValue) Set(s string) error {
	parts := strings.Split(s, ",")
	parsed := make([]int, 0, len(parts))
	for _, p := range parts {
		n, err := strconv.Atoi(strings.TrimSpace(p))
		if err != nil {
			return err
		}
		parsed = append(parsed, n)
	}
	// Assign only after every element has parsed successfully.
	*isv = intSliceValue(parsed)
	return nil
}

// String implements the flag.Value interface. It returns the ints joined by
// commas, or the empty string when called on a nil receiver: the flag package
// documents that String may be invoked on a zero-valued receiver such as a
// nil pointer.
func (isv *intSliceValue) String() string {
	if isv == nil {
		return ""
	}
	var b bytes.Buffer
	for i, n := range *isv {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(strconv.Itoa(n))
	}
	return b.String()
}