diff --git a/internal/errcheck/embedded_walker.go b/internal/errcheck/embedded_walker.go new file mode 100644 index 0000000..3b31925 --- /dev/null +++ b/internal/errcheck/embedded_walker.go @@ -0,0 +1,144 @@ +package errcheck + +import ( + "fmt" + "go/types" +) + +// walkThroughEmbeddedInterfaces returns a slice of Interfaces that +// we need to walk through in order to reach the actual definition, +// in an Interface, of the method selected by the given selection. +// +// false will be returned in the second return value if: +// - the right side of the selection is not a function +// - the actual definition of the function is not in an Interface +// +// The returned slice will contain all the interface types that need +// to be walked through to reach the actual definition. +// +// For example, say we have: +// +// type Inner interface {Method()} +// type Middle interface {Inner} +// type Outer interface {Middle} +// type T struct {Outer} +// type U struct {T} +// type V struct {U} +// +// And then the selector: +// +// V.Method +// +// We'll return [Outer, Middle, Inner] by first walking through the embedded structs +// until we reach the Outer interface, then descending through the embedded interfaces +// until we find the one that actually explicitly defines Method. +func walkThroughEmbeddedInterfaces(sel *types.Selection) ([]types.Type, bool) { + fn, ok := sel.Obj().(*types.Func) + if !ok { + return nil, false + } + + // Start off at the receiver. + currentT := sel.Recv() + + // First, we can walk through any Struct fields provided + // by the selection Index() method. We ignore the last + // index because it would give the method itself. + indexes := sel.Index() + for _, fieldIndex := range indexes[:len(indexes)-1] { + currentT = getTypeAtFieldIndex(currentT, fieldIndex) + } + + // Now currentT is either a type implementing the actual function, + // an Invalid type (if the receiver is a package), or an interface. + // + // If it's not an Interface, then we're done, as this function + // only cares about Interface-defined functions. + // + // If it is an Interface, we potentially need to continue digging until + // we find the Interface that actually explicitly defines the function. + interfaceT, ok := maybeUnname(currentT).(*types.Interface) + if !ok { + return nil, false + } + + // The first interface we pass through is this one we've found. We return the possibly + // wrapping types.Named because it is more useful to work with for callers. + result := []types.Type{currentT} + + // If this interface itself explicitly defines the given method + // then we're done digging. + for !explicitlyDefinesMethod(interfaceT, fn) { + // Otherwise, we find which of the embedded interfaces _does_ + // define the method, add it to our list, and loop. + namedInterfaceT, ok := getEmbeddedInterfaceDefiningMethod(interfaceT, fn) + if !ok { + // This should be impossible as long as we type-checked: either the + // interface or one of its embedded ones must implement the method... + panic(fmt.Sprintf("either %v or one of its embedded interfaces must implement %v", currentT, fn)) + } + result = append(result, namedInterfaceT) + interfaceT = namedInterfaceT.Underlying().(*types.Interface) + } + + return result, true +} + +func getTypeAtFieldIndex(startingAt types.Type, fieldIndex int) types.Type { + t := maybeUnname(maybeDereference(startingAt)) + s, ok := t.(*types.Struct) + if !ok { + panic(fmt.Sprintf("cannot get Field of a type that is not a struct, got a %T", t)) + } + + return s.Field(fieldIndex).Type() +} + +// getEmbeddedInterfaceDefiningMethod searches through any embedded interfaces of the +// passed interface searching for one that defines the given function. If found, the +// types.Named wrapping that interface will be returned along with true in the second value. +// +// If no such embedded interface is found, nil and false are returned. +func getEmbeddedInterfaceDefiningMethod(interfaceT *types.Interface, fn *types.Func) (*types.Named, bool) { + for i := 0; i < interfaceT.NumEmbeddeds(); i++ { + embedded := interfaceT.Embedded(i) + if definesMethod(embedded.Underlying().(*types.Interface), fn) { + return embedded, true + } + } + return nil, false +} + +func explicitlyDefinesMethod(interfaceT *types.Interface, fn *types.Func) bool { + for i := 0; i < interfaceT.NumExplicitMethods(); i++ { + if interfaceT.ExplicitMethod(i) == fn { + return true + } + } + return false +} + +func definesMethod(interfaceT *types.Interface, fn *types.Func) bool { + for i := 0; i < interfaceT.NumMethods(); i++ { + if interfaceT.Method(i) == fn { + return true + } + } + return false +} + +func maybeDereference(t types.Type) types.Type { + p, ok := t.(*types.Pointer) + if ok { + return p.Elem() + } + return t +} + +func maybeUnname(t types.Type) types.Type { + n, ok := t.(*types.Named) + if ok { + return n.Underlying() + } + return t +} diff --git a/internal/errcheck/embedded_walker_test.go b/internal/errcheck/embedded_walker_test.go new file mode 100644 index 0000000..51c13a2 --- /dev/null +++ b/internal/errcheck/embedded_walker_test.go @@ -0,0 +1,93 @@ +package errcheck + +import ( + "go/ast" + "go/parser" + "go/token" + "go/types" + "testing" +) + +const commonSrc = ` +package p + +type Inner struct {} +func (Inner) Method() + +type Outer struct {Inner} +type OuterP struct {*Inner} + +type InnerInterface interface { + Method() +} + +type OuterInterface interface {InnerInterface} +type MiddleInterfaceStruct struct {OuterInterface} +type OuterInterfaceStruct struct {MiddleInterfaceStruct} + +var c = ` + +type testCase struct { + selector string + expectedOk bool + expected []string +} + +func TestWalkThroughEmbeddedInterfaces(t *testing.T) { + cases := []testCase{ + testCase{"Inner{}.Method", false, nil}, + testCase{"(&Inner{}).Method", false, nil}, + testCase{"Outer{}.Method", false, nil}, + testCase{"InnerInterface.Method", true, []string{"test.InnerInterface"}}, + testCase{"OuterInterface.Method", true, []string{"test.OuterInterface", "test.InnerInterface"}}, + testCase{"OuterInterfaceStruct.Method", true, []string{"test.OuterInterface", "test.InnerInterface"}}, + } + + for _, c := range cases { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test", commonSrc+c.selector, 0) + if err != nil { + t.Fatal(err) + } + + conf := types.Config{} + info := types.Info{ + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + _, err = conf.Check("test", fset, []*ast.File{f}, &info) + if err != nil { + t.Fatal(err) + } + ast.Inspect(f, func(n ast.Node) bool { + s, ok := n.(*ast.SelectorExpr) + if ok { + selection, ok := info.Selections[s] + if !ok { + t.Fatalf("no Selection!") + } + ts, ok := walkThroughEmbeddedInterfaces(selection) + if ok != c.expectedOk { + t.Errorf("expected ok %v got %v", c.expectedOk, ok) + return false + } + if !ok { + return false + } + + if len(ts) != len(c.expected) { + t.Fatalf("expected %d types, got %d", len(c.expected), len(ts)) + } + + for i, e := range c.expected { + if e != ts[i].String() { + t.Errorf("mismatch at index %d: expected %s got %s", i, e, ts[i]) + } + } + } + + return true + }) + + } + +} diff --git a/internal/errcheck/errcheck.go b/internal/errcheck/errcheck.go index 694891e..c6615eb 100644 --- a/internal/errcheck/errcheck.go +++ b/internal/errcheck/errcheck.go @@ -143,6 +143,9 @@ func (c *Checker) SetExclude(l map[string]bool) { "(*strings.Builder).WriteByte", "(*strings.Builder).WriteRune", "(*strings.Builder).WriteString", + + // hash + "(hash.Hash).Write", } { c.exclude[exc] = true } @@ -236,29 +239,105 @@ type visitor struct { errors []UncheckedError } -func (v *visitor) fullName(call *ast.CallExpr) (string, bool) { +// selectorAndFunc tries to get the selector and function from call expression. +// For example, given the call expression representing "a.b()", the selector +// is "a.b" and the function is "b" itself. +// +// The final return value will be true if it is able to do extract a selector +// from the call and look up the function object it refers to. +// +// If the call does not include a selector (like if it is a plain "f()" function call) +// then the final return value will be false. +func (v *visitor) selectorAndFunc(call *ast.CallExpr) (*ast.SelectorExpr, *types.Func, bool) { sel, ok := call.Fun.(*ast.SelectorExpr) if !ok { - return "", false + return nil, nil, false } + fn, ok := v.pkg.ObjectOf(sel.Sel).(*types.Func) if !ok { // Shouldn't happen, but be paranoid - return "", false + return nil, nil, false } - // The name is fully qualified by the import path, possible type, - // function/method name and pointer receiver. - // + + return sel, fn, true + +} + +// fullName will return a package / receiver-type qualified name for a called function +// if the function is the result of a selector. Otherwise it will return +// the empty string. +// +// The name is fully qualified by the import path, possible type, +// function/method name and pointer receiver. +// +// For example, +// - for "fmt.Printf(...)" it will return "fmt.Printf" +// - for "base64.StdEncoding.Decode(...)" it will return "(*encoding/base64.Encoding).Decode" +// - for "myFunc()" it will return "" +func (v *visitor) fullName(call *ast.CallExpr) string { + _, fn, ok := v.selectorAndFunc(call) + if !ok { + return "" + } + // TODO(dh): vendored packages will have /vendor/ in their name, // thus not matching vendored standard library packages. If we // want to support vendored stdlib packages, we need to implement // FullName with our own logic. - return fn.FullName(), true + return fn.FullName() +} + +// namesForExcludeCheck will return a list of fully-qualified function names +// from a function call that can be used to check against the exclusion list. +// +// If a function call is against a local function (like "myFunc()") then no +// names are returned. If the function is package-qualified (like "fmt.Printf()") +// then just that function's fullName is returned. +// +// Otherwise, we walk through all the potentially embeddded interfaces of the receiver +// the collect a list of type-qualified function names that we will check. +func (v *visitor) namesForExcludeCheck(call *ast.CallExpr) []string { + sel, fn, ok := v.selectorAndFunc(call) + if !ok { + return nil + } + + name := v.fullName(call) + if name == "" { + return nil + } + + // This will be missing for functions without a receiver (like fmt.Printf), + // so just fall back to the the function's fullName in that case. + selection, ok := v.pkg.Selections[sel] + if !ok { + return []string{name} + } + + // This will return with ok false if the function isn't defined + // on an interface, so just fall back to the fullName. + ts, ok := walkThroughEmbeddedInterfaces(selection) + if !ok { + return []string{name} + } + + result := make([]string, len(ts)) + for i, t := range ts { + // Like in fullName, vendored packages will have /vendor/ in their name, + // thus not matching vendored standard library packages. If we + // want to support vendored stdlib packages, we need to implement + // additional logic here. + result[i] = fmt.Sprintf("(%s).%s", t.String(), fn.Name()) + } + return result } func (v *visitor) excludeCall(call *ast.CallExpr) bool { - if name, ok := v.fullName(call); ok { - return v.exclude[name] + for _, name := range v.namesForExcludeCheck(call) { + if v.exclude[name] { + return true + } } return false @@ -390,7 +469,7 @@ func (v *visitor) addErrorAtPosition(position token.Pos, call *ast.CallExpr) { var name string if call != nil { - name, _ = v.fullName(call) + name = v.fullName(call) } v.errors = append(v.errors, UncheckedError{pos, line, name}) diff --git a/internal/errcheck/errcheck_test.go b/internal/errcheck/errcheck_test.go index 99e12a0..84b04d6 100644 --- a/internal/errcheck/errcheck_test.go +++ b/internal/errcheck/errcheck_test.go @@ -187,6 +187,9 @@ func test(t *testing.T, f flags) { checker := NewChecker() checker.Asserts = asserts checker.Blank = blank + checker.SetExclude(map[string]bool{ + fmt.Sprintf("(%s.ErrorMakerInterface).MakeNilError", testPackage): true, + }) err := checker.CheckPackages(testPackage) uerr, ok := err.(*UncheckedErrors) if !ok { diff --git a/testdata/main.go b/testdata/main.go index 8599258..48d2c3b 100644 --- a/testdata/main.go +++ b/testdata/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "crypto/sha256" "fmt" "io/ioutil" "math/rand" @@ -65,6 +66,14 @@ func customPointerErrorTuple() (int, *MyPointerError) { return 0, &e } +// Test custom excludes +type ErrorMakerInterface interface { + MakeNilError() error +} +type ErrorMakerInterfaceWrapper interface { + ErrorMakerInterface +} + func main() { // Single error return _ = a() // BLANK @@ -137,6 +146,10 @@ func main() { b2.Write(nil) rand.Read(nil) mrand.Read(nil) + sha256.New().Write([]byte{}) ioutil.ReadFile("main.go") // UNCHECKED + + var emiw ErrorMakerInterfaceWrapper + emiw.MakeNilError() }