Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add map typed flags #15

Merged
merged 11 commits into from
Feb 4, 2018
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ And you can use your favorite flag or cli library!
- [x] net.IP
- [x] time.Duration
- [x] regexp.Regexp
- [ ] map[string]string
- [ ] map[string]int
- [x] map for all previous types (e.g. `map[int64]bool`, `map[string]float64`)

## Custom types:
- [x] HexBytes
Expand Down
255 changes: 246 additions & 9 deletions cmd/genvalues/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import (
"os"
"os/exec"
"strings"
"reflect"
"text/template"
"unicode"
"unicode/utf8"
"math/rand"
)

const (
Expand All @@ -25,6 +27,13 @@ import (
"{{.}}"{{end}}
)

{{$mapKeyTypes := .MapKeysTypes}}

// MapAllowedKinds stores list of kinds allowed for map keys.
var MapAllowedKinds = []reflect.Kind{ \nn
{{range $mapKeyTypes}}
reflect.{{. | Title}},{{end}}
}

func parseGenerated(value interface{}) Value {
switch value.(type) {
Expand Down Expand Up @@ -52,6 +61,18 @@ func parseGeneratedPtrs(value interface{}) Value {
}
}

func parseGeneratedMap(value interface{}) Value {
switch value.(type) {
{{range .Values}}{{ if not .NoMap }}\nn
{{ $value := . }}{{range $mapKeyTypes}}\nn
case *map[{{.}}]{{$value.Type}}:
return new{{MapValueName $value . | Title}}(value.(*map[{{.}}]{{$value.Type}}))
{{end}}{{end}}{{end}}\nn
default:
return nil
}
}

{{range .Values}}
{{if not .NoValueParser}}
// -- {{.Type}} Value
Expand Down Expand Up @@ -174,6 +195,99 @@ func (v *{{.|SliceValueName}}) IsCumulative() bool {

{{end}}

{{ if not .NoMap }}
{{ $value := . }}
{{range $mapKeyTypes}}
// -- {{ MapValueName $value . }}
type {{ MapValueName $value . }} struct {
value *map[{{.}}]{{$value.Type}}
}

var _ RepeatableFlag = (*{{MapValueName $value .}})(nil)
var _ Value = (*{{MapValueName $value .}})(nil)
var _ Getter = (*{{MapValueName $value .}})(nil)


func new{{MapValueName $value . | Title}}(m *map[{{.}}]{{$value.Type}}) *{{MapValueName $value .}} {
return &{{MapValueName $value .}}{
value: m,
}
}

func (v *{{MapValueName $value .}}) Set(s string) error {
ss := strings.Split(s, ":")
if len(ss) < 2 {
return errors.New("invalid map flag syntax, use -map=key1:val1")
}

{{ $kindVal := KindValue . }}

s = ss[0]

{{if $kindVal.Parser }}\nn
parsedKey, err := {{$kindVal.Parser}}
if err != nil {
return err
}

{{if $kindVal.Convert}}\nn
key := ({{$kindVal.Type}})(parsedKey)
{{else}}\nn
key := parsedKey
{{end}}\nn

{{ else }}\nn
key := s
{{end}}\nn


s = ss[1]

{{if $value.Parser }}\nn
parsedVal, err := {{$value.Parser}}
if err != nil {
return err
}

{{if $value.Convert}}\nn
val := ({{$value.Type}})(parsedVal)
{{else}}\nn
val := parsedVal
{{end}}\nn

{{ else }}\nn
val := s
{{end}}\nn

(*v.value)[key] = val

return nil
}

func (v *{{MapValueName $value .}}) Get() interface{} {
if v != nil && v.value != nil {
{{/* flag package create zero Value and compares it to actual Value */}}\nn
return *v.value
}
return nil
}

func (v *{{MapValueName $value .}}) String() string {
if v != nil && v.value != nil && len(*v.value) > 0 {
{{/* flag package create zero Value and compares it to actual Value */}}\nn
return fmt.Sprintf("%v", *v.value)
}
return ""
}

func (v *{{MapValueName $value .}}) Type() string { return "map[{{.}}]{{$value.Type}}" }

func (v *{{MapValueName $value .}}) IsCumulative() bool {
return true
}
{{end}}
{{end}}

{{end}}


Expand All @@ -191,6 +305,9 @@ import (
{{end}}\nn
)

{{$mapKeyTypes := .MapKeysTypes}}


{{range .Values}}

func Test{{.|Name}}Value_Zero(t *testing.T) {
Expand Down Expand Up @@ -245,6 +362,21 @@ func Test{{.|Name}}SliceValue_Zero(t *testing.T) {
}{{end}}


{{ if not .NoMap }}
{{ $value := . }}
{{range $mapKeyTypes}}
func Test{{MapValueName $value . | Title}}_Zero(t *testing.T) {
var nilValue {{MapValueName $value .}}
assert.Equal(t, "", nilValue.String())
assert.Nil(t, nilValue.Get())
nilObj := (*{{MapValueName $value . }})(nil)
assert.Equal(t, "", nilObj.String())
assert.Nil(t, nilObj.Get())
}
{{end}}
{{end}}


{{ if .SliceTests }}{{ $value := . }}
func Test{{.|Name}}SliceValue(t *testing.T) {
{{range .SliceTests}}{{ $test := . }}\nn
Expand All @@ -269,13 +401,61 @@ func Test{{.|Name}}SliceValue(t *testing.T) {
{{end}}
}{{end}}

{{ if .MapTests }}
{{ $value := . }}
{{range $mapKeyTypes}}{{ $keyType := . }}
func Test{{MapValueName $value $keyType | Title}}(t *testing.T) {
{{range $value.MapTests}}{{ $test := . }}\nn
t.Run("{{.}}", func(t *testing.T) {
var err error
a := make(map[{{$keyType}}]{{$value.Type}})
v := new{{MapValueName $value $keyType | Title}}(&a)
assert.Equal(t, parseGeneratedMap(&a), v)
assert.True(t, v.IsCumulative())
{{range .In}}\nn
err = v.Set("{{$keyType | KindTest}}{{.}}")
assert.EqualError(t, err, "invalid map flag syntax, use -map=key1:val1")
{{if ne $keyType "string"}}\nn
err = v.Set(":{{.}}")
assert.NotNil(t, err)
{{end}}\nn
err = v.Set("{{$keyType | KindTest}}:{{.}}")
{{if $test.Err}}\nn
assert.EqualError(t, err, "{{$test.Err}}")
{{ else }}\nn
assert.Nil(t, err)
{{end}}\nn
{{end}}\nn
assert.Equal(t, a, v.Get())
assert.Equal(t, "map[{{$keyType}}]{{$value.Type}}", v.Type())
{{if $test.Err}}\nn
assert.Empty(t, v.String())
{{else}}\nn
assert.NotEmpty(t, v.String())
{{end}}\nn
})
{{end}}\nn
}
{{end}}
{{end}}

{{end}}

func TestParseGeneratedMap_NilDefault(t *testing.T) {
a := new(bool)
v := parseGeneratedMap(a)
assert.Nil(t, v)
}

`
)

// MapAllowedKinds stores list of kinds allowed for map keys.
var mapAllowedKinds = []reflect.Kind{
reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
}

type test struct {
In string
Out string
Expand All @@ -296,6 +476,15 @@ func (t *sliceTest) String() string {
return fmt.Sprintf("in: %v", t.In)
}

type mapTest struct {
In []string
Err string
}

func (t *mapTest) String() string {
return fmt.Sprintf("in: %v", t.In)
}

type value struct {
Name string `json:"name"`
Kind string `json:"kind"`
Expand All @@ -308,8 +497,10 @@ type value struct {
Help string `json:"help"`
Import []string `json:"import"`
Tests []test `json:"tests"`
SliceTests []sliceTest `json:"slice_tests"`
NoSlice bool `json:"no_slice"`
SliceTests []sliceTest `json:"slice_tests"`
NoMap bool `json:"no_map"`
MapTests []mapTest `json:"map_tests"`
}

func fatalIfError(err error) {
Expand Down Expand Up @@ -345,6 +536,7 @@ func main() {

baseT := template.New("genvalues").Funcs(template.FuncMap{
"Lower": strings.ToLower,
"Title": strings.Title,
"Format": func(v *value) string {
if v.Format != "" {
return v.Format
Expand All @@ -362,6 +554,27 @@ func main() {
name := valueName(v)
return camelToLower(name) + "SliceValue"
},
"MapValueName": func(v *value, kind string) string {
name := valueName(v)

return kind + name + "MapValue"
},
"KindValue": func(kind string) value {
for _, value := range values {
if value.Type == kind {
return value
}
}

return value{}
},
"KindTest": func(kind string) interface{} {
if kind == "string" {
return randStr(5)
}

return rand.Intn(8)
},
"Name": valueName,
"Plural": func(v *value) string {
if v.Plural != "" {
Expand Down Expand Up @@ -394,11 +607,13 @@ func main() {
defer w.Close()

err = t.Execute(w, struct {
Values []value
Imports []string
Values []value
Imports []string
MapKeysTypes []string
}{
Values: values,
Imports: imports,
Values: values,
Imports: imports,
MapKeysTypes: stringifyKinds(mapAllowedKinds),
})
fatalIfError(err)

Expand All @@ -414,11 +629,13 @@ func main() {
defer w.Close()

err = t.Execute(w, struct {
Values []value
Imports []string
Values []value
Imports []string
MapKeysTypes []string
}{
Values: values,
Imports: imports,
Values: values,
Imports: imports,
MapKeysTypes: stringifyKinds(mapAllowedKinds),
})
fatalIfError(err)

Expand All @@ -427,6 +644,16 @@ func main() {

}

func stringifyKinds(kinds []reflect.Kind) []string {
var l []string

for _, kind := range kinds {
l = append(l, kind.String())
}

return l
}

func gofmt(path string) {
cmd := exec.Command("goimports", "-w", path)
b, err := cmd.CombinedOutput()
Expand Down Expand Up @@ -527,3 +754,13 @@ func split(src string) (entries []string) {
}
return
}

const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

func randStr(n int) string {
b := make([]byte, n)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return string(b)
}
3 changes: 3 additions & 0 deletions examples/flag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type httpConfig struct {
SSL bool
Timeout time.Duration
Addr *net.TCPAddr
Methods map[string]int64 `desc:"HTTP Methods"`
}

type config struct {
Expand Down Expand Up @@ -59,6 +60,8 @@ func main() {
"-http-addr", "google.com:8000",
"-regexp", "ddfd",
"-count", "-count",
"-http-methods", "post:15",
"-http-methods", "get:25",
})
if err != nil {
fmt.Printf("err: %v", err)
Expand Down
Loading