| package main |
| |
| import ( |
| "errors" |
| "fmt" |
| "go/ast" |
| "go/parser" |
| "go/token" |
| "path" |
| "reflect" |
| "strings" |
| ) |
| |
// errBadReturn is returned by parseArgs when a field in a method's parameter
// or result list has no name; every argument must be named.
var errBadReturn = errors.New("found arg with no name: all args must be named")
| |
// errUnexpectedType reports that an AST node did not have the concrete type
// the parser needed at that point.
type errUnexpectedType struct {
	expected string      // name of the type the parser wanted, e.g. "*ast.TypeSpec"
	actual   interface{} // the value actually found; only its dynamic type is reported
}

// Error implements the error interface. It names the expected type and the
// dynamic type of the offending value ("<nil>" when actual is nil).
func (e errUnexpectedType) Error() string {
	gotType := reflect.TypeOf(e.actual)
	return fmt.Sprintf("got wrong type expecting %s, got: %v", e.expected, gotType)
}
| |
// ParsedPkg holds the result of parsing an interface definition: the name of
// the package that declares it, the interface's methods, and the imports that
// the methods' argument and return types refer to.
type ParsedPkg struct {
	Name      string
	Functions []function
	Imports   []importSpec
}

// function describes a single interface method.
type function struct {
	Name    string
	Args    []arg // parameters, in declaration order
	Returns []arg // results, in declaration order
	// Doc is not filled in by parseFunc in this file.
	// NOTE(review): presumably meant for the method's doc comment; confirm.
	Doc string
}

// arg is one named parameter or result of a method.
// Field order matters: values are built with positional literals.
type arg struct {
	Name            string // identifier as written in the signature
	ArgType         string // the type rendered back to Go source, e.g. "*foo.Bar"
	PackageSelector string // package qualifier used by ArgType (e.g. "foo"), or "" if none
}

// String renders the argument as it appears in a Go parameter list,
// e.g. "ctx context.Context".
func (a *arg) String() string {
	return a.Name + " " + a.ArgType
}

// importSpec is an import needed by the parsed methods' signatures.
type importSpec struct {
	Name string // explicit import name, or "" for the package's default name
	Path string // import path exactly as in the source, including its double quotes
}

// String renders the spec as it appears inside an import declaration:
// `name "path"` for a named import, or just `"path"` otherwise.
func (s *importSpec) String() string {
	if s.Name == "" {
		return s.Path
	}
	// Separate name and path the way gofmt would, matching arg.String.
	return s.Name + " " + s.Path
}
| |
| // Parse parses the given file for an interface definition with the given name. |
| func Parse(filePath string, objName string) (*ParsedPkg, error) { |
| fs := token.NewFileSet() |
| pkg, err := parser.ParseFile(fs, filePath, nil, parser.AllErrors) |
| if err != nil { |
| return nil, err |
| } |
| p := &ParsedPkg{} |
| p.Name = pkg.Name.Name |
| obj, exists := pkg.Scope.Objects[objName] |
| if !exists { |
| return nil, fmt.Errorf("could not find object %s in %s", objName, filePath) |
| } |
| if obj.Kind != ast.Typ { |
| return nil, fmt.Errorf("exected type, got %s", obj.Kind) |
| } |
| spec, ok := obj.Decl.(*ast.TypeSpec) |
| if !ok { |
| return nil, errUnexpectedType{"*ast.TypeSpec", obj.Decl} |
| } |
| iface, ok := spec.Type.(*ast.InterfaceType) |
| if !ok { |
| return nil, errUnexpectedType{"*ast.InterfaceType", spec.Type} |
| } |
| |
| p.Functions, err = parseInterface(iface) |
| if err != nil { |
| return nil, err |
| } |
| |
| // figure out what imports will be needed |
| imports := make(map[string]importSpec) |
| for _, f := range p.Functions { |
| args := append(f.Args, f.Returns...) |
| for _, arg := range args { |
| if len(arg.PackageSelector) == 0 { |
| continue |
| } |
| |
| for _, i := range pkg.Imports { |
| if i.Name != nil { |
| if i.Name.Name != arg.PackageSelector { |
| continue |
| } |
| imports[i.Path.Value] = importSpec{Name: arg.PackageSelector, Path: i.Path.Value} |
| break |
| } |
| |
| _, name := path.Split(i.Path.Value) |
| splitName := strings.Split(name, "-") |
| if len(splitName) > 1 { |
| name = splitName[len(splitName)-1] |
| } |
| // import paths have quotes already added in, so need to remove them for name comparison |
| name = strings.TrimPrefix(name, `"`) |
| name = strings.TrimSuffix(name, `"`) |
| if name == arg.PackageSelector { |
| imports[i.Path.Value] = importSpec{Path: i.Path.Value} |
| break |
| } |
| } |
| } |
| } |
| |
| for _, spec := range imports { |
| p.Imports = append(p.Imports, spec) |
| } |
| |
| return p, nil |
| } |
| |
| func parseInterface(iface *ast.InterfaceType) ([]function, error) { |
| var functions []function |
| for _, field := range iface.Methods.List { |
| switch f := field.Type.(type) { |
| case *ast.FuncType: |
| method, err := parseFunc(field) |
| if err != nil { |
| return nil, err |
| } |
| if method == nil { |
| continue |
| } |
| functions = append(functions, *method) |
| case *ast.Ident: |
| spec, ok := f.Obj.Decl.(*ast.TypeSpec) |
| if !ok { |
| return nil, errUnexpectedType{"*ast.TypeSpec", f.Obj.Decl} |
| } |
| iface, ok := spec.Type.(*ast.InterfaceType) |
| if !ok { |
| return nil, errUnexpectedType{"*ast.TypeSpec", spec.Type} |
| } |
| funcs, err := parseInterface(iface) |
| if err != nil { |
| fmt.Println(err) |
| continue |
| } |
| functions = append(functions, funcs...) |
| default: |
| return nil, errUnexpectedType{"*astFuncType or *ast.Ident", f} |
| } |
| } |
| return functions, nil |
| } |
| |
| func parseFunc(field *ast.Field) (*function, error) { |
| f := field.Type.(*ast.FuncType) |
| method := &function{Name: field.Names[0].Name} |
| if _, exists := skipFuncs[method.Name]; exists { |
| fmt.Println("skipping:", method.Name) |
| return nil, nil |
| } |
| if f.Params != nil { |
| args, err := parseArgs(f.Params.List) |
| if err != nil { |
| return nil, err |
| } |
| method.Args = args |
| } |
| if f.Results != nil { |
| returns, err := parseArgs(f.Results.List) |
| if err != nil { |
| return nil, fmt.Errorf("error parsing function returns for %q: %v", method.Name, err) |
| } |
| method.Returns = returns |
| } |
| return method, nil |
| } |
| |
| func parseArgs(fields []*ast.Field) ([]arg, error) { |
| var args []arg |
| for _, f := range fields { |
| if len(f.Names) == 0 { |
| return nil, errBadReturn |
| } |
| for _, name := range f.Names { |
| p, err := parseExpr(f.Type) |
| if err != nil { |
| return nil, err |
| } |
| args = append(args, arg{name.Name, p.value, p.pkg}) |
| } |
| } |
| return args, nil |
| } |
| |
// parsedExpr is a type expression rendered back to Go source by parseExpr.
// value holds the source text (e.g. "map[string]*foo.Bar"). pkg holds the
// package selector that text depends on ("foo"), or "" if it uses none.
type parsedExpr struct {
	value, pkg string
}
| |
| func parseExpr(e ast.Expr) (parsedExpr, error) { |
| var parsed parsedExpr |
| switch i := e.(type) { |
| case *ast.Ident: |
| parsed.value += i.Name |
| case *ast.StarExpr: |
| p, err := parseExpr(i.X) |
| if err != nil { |
| return parsed, err |
| } |
| parsed.value += "*" |
| parsed.value += p.value |
| parsed.pkg = p.pkg |
| case *ast.SelectorExpr: |
| p, err := parseExpr(i.X) |
| if err != nil { |
| return parsed, err |
| } |
| parsed.pkg = p.value |
| parsed.value += p.value + "." |
| parsed.value += i.Sel.Name |
| case *ast.MapType: |
| parsed.value += "map[" |
| p, err := parseExpr(i.Key) |
| if err != nil { |
| return parsed, err |
| } |
| parsed.value += p.value |
| parsed.value += "]" |
| p, err = parseExpr(i.Value) |
| if err != nil { |
| return parsed, err |
| } |
| parsed.value += p.value |
| parsed.pkg = p.pkg |
| case *ast.ArrayType: |
| parsed.value += "[]" |
| p, err := parseExpr(i.Elt) |
| if err != nil { |
| return parsed, err |
| } |
| parsed.value += p.value |
| parsed.pkg = p.pkg |
| default: |
| return parsed, errUnexpectedType{"*ast.Ident or *ast.StarExpr", i} |
| } |
| return parsed, nil |
| } |