diff --git a/internal/errcheck/embedded_walker.go b/internal/errcheck/embedded_walker.go index 9aac6c2..3b31925 100644 --- a/internal/errcheck/embedded_walker.go +++ b/internal/errcheck/embedded_walker.go @@ -94,6 +94,11 @@ func getTypeAtFieldIndex(startingAt types.Type, fieldIndex int) types.Type { return s.Field(fieldIndex).Type() } +// getEmbeddedInterfaceDefiningMethod searches through any embedded interfaces of the +// passed interface searching for one that defines the given function. If found, the +// types.Named wrapping that interface will be returned along with true in the second value. +// +// If no such embedded interface is found, nil and false are returned. func getEmbeddedInterfaceDefiningMethod(interfaceT *types.Interface, fn *types.Func) (*types.Named, bool) { for i := 0; i < interfaceT.NumEmbeddeds(); i++ { embedded := interfaceT.Embedded(i) diff --git a/internal/errcheck/errcheck.go b/internal/errcheck/errcheck.go index d66c2ba..c6615eb 100644 --- a/internal/errcheck/errcheck.go +++ b/internal/errcheck/errcheck.go @@ -239,6 +239,15 @@ type visitor struct { errors []UncheckedError } +// selectorAndFunc tries to get the selector and function from call expression. +// For example, given the call expression representing "a.b()", the selector +// is "a.b" and the function is "b" itself. +// +// The final return value will be true if it is able to do extract a selector +// from the call and look up the function object it refers to. +// +// If the call does not include a selector (like if it is a plain "f()" function call) +// then the final return value will be false. func (v *visitor) selectorAndFunc(call *ast.CallExpr) (*ast.SelectorExpr, *types.Func, bool) { sel, ok := call.Fun.(*ast.SelectorExpr) if !ok { @@ -255,29 +264,47 @@ func (v *visitor) selectorAndFunc(call *ast.CallExpr) (*ast.SelectorExpr, *types } -func (v *visitor) fullName(call *ast.CallExpr) (string, bool) { +// fullName will return a package / receiver-type qualified name for a called function +// if the function is the result of a selector. Otherwise it will return +// the empty string. +// +// The name is fully qualified by the import path, possible type, +// function/method name and pointer receiver. +// +// For example, +// - for "fmt.Printf(...)" it will return "fmt.Printf" +// - for "base64.StdEncoding.Decode(...)" it will return "(*encoding/base64.Encoding).Decode" +// - for "myFunc()" it will return "" +func (v *visitor) fullName(call *ast.CallExpr) string { _, fn, ok := v.selectorAndFunc(call) if !ok { - return "", false + return "" } - // The name is fully qualified by the import path, possible type, - // function/method name and pointer receiver. - // + // TODO(dh): vendored packages will have /vendor/ in their name, // thus not matching vendored standard library packages. If we // want to support vendored stdlib packages, we need to implement // FullName with our own logic. - return fn.FullName(), true + return fn.FullName() } +// namesForExcludeCheck will return a list of fully-qualified function names +// from a function call that can be used to check against the exclusion list. +// +// If a function call is against a local function (like "myFunc()") then no +// names are returned. If the function is package-qualified (like "fmt.Printf()") +// then just that function's fullName is returned. +// +// Otherwise, we walk through all the potentially embeddded interfaces of the receiver +// the collect a list of type-qualified function names that we will check. func (v *visitor) namesForExcludeCheck(call *ast.CallExpr) []string { sel, fn, ok := v.selectorAndFunc(call) if !ok { return nil } - name, ok := v.fullName(call) - if !ok { + name := v.fullName(call) + if name == "" { return nil } @@ -297,6 +324,10 @@ func (v *visitor) namesForExcludeCheck(call *ast.CallExpr) []string { result := make([]string, len(ts)) for i, t := range ts { + // Like in fullName, vendored packages will have /vendor/ in their name, + // thus not matching vendored standard library packages. If we + // want to support vendored stdlib packages, we need to implement + // additional logic here. result[i] = fmt.Sprintf("(%s).%s", t.String(), fn.Name()) } return result @@ -438,7 +469,7 @@ func (v *visitor) addErrorAtPosition(position token.Pos, call *ast.CallExpr) { var name string if call != nil { - name, _ = v.fullName(call) + name = v.fullName(call) } v.errors = append(v.errors, UncheckedError{pos, line, name})