From af5c046ec344397dc6c8da4a16d4bceb2e75828e Mon Sep 17 00:00:00 2001 From: Louis Sobel Date: Sat, 7 Apr 2018 11:32:45 -0700 Subject: [PATCH 1/7] check reciever type aginst exclude --- internal/errcheck/errcheck.go | 51 ++++++++++++++++++++++++++++-- internal/errcheck/errcheck_test.go | 5 +++ testdata/main.go | 20 ++++++++++++ 3 files changed, 73 insertions(+), 3 deletions(-) diff --git a/internal/errcheck/errcheck.go b/internal/errcheck/errcheck.go index 694891e..b8491a7 100644 --- a/internal/errcheck/errcheck.go +++ b/internal/errcheck/errcheck.go @@ -236,16 +236,27 @@ type visitor struct { errors []UncheckedError } -func (v *visitor) fullName(call *ast.CallExpr) (string, bool) { +func (v *visitor) getReceiverTypeAndFunc(call *ast.CallExpr) (types.Type, *types.Func, bool) { sel, ok := call.Fun.(*ast.SelectorExpr) if !ok { - return "", false + return nil, nil, false } + t := v.pkg.TypeOf(sel.X) fn, ok := v.pkg.ObjectOf(sel.Sel).(*types.Func) if !ok { // Shouldn't happen, but be paranoid + return nil, nil, false + } + + return t, fn, true +} + +func (v *visitor) fullName(call *ast.CallExpr) (string, bool) { + _, fn, ok := v.getReceiverTypeAndFunc(call) + if !ok { return "", false } + // The name is fully qualified by the import path, possible type, // function/method name and pointer receiver. // @@ -256,9 +267,43 @@ func (v *visitor) fullName(call *ast.CallExpr) (string, bool) { return fn.FullName(), true } +// fullNameWithReceiversType returns the full name of a function call, but taking +// into account embedded structs, unlike the plain fullName function above. +// +// For example, a struct defined like: +// +// type WriterWrapper { +// io.Writer +// } +// +// fullName(WriterWrapper{}.Write) would return "(io.Writer).Write", +// but fullNameWithReceiversType would return "(WriterWrapper).Write". +func (v *visitor) fullNameWithReceiversType(call *ast.CallExpr) (string, bool) { + t, fn, ok := v.getReceiverTypeAndFunc(call) + if !ok { + return "", false + } + + if t == types.Typ[types.Invalid] { + return "", false + } + + // The name is fully qualified by the import path, possible type, + // function/method name and pointer receiver. + return fmt.Sprintf("(%s).%s", t.String(), fn.Name()), true +} + func (v *visitor) excludeCall(call *ast.CallExpr) bool { if name, ok := v.fullName(call); ok { - return v.exclude[name] + if v.exclude[name] { + return true + } + } + + if name, ok := v.fullNameWithReceiversType(call); ok { + if v.exclude[name] { + return true + } } return false diff --git a/internal/errcheck/errcheck_test.go b/internal/errcheck/errcheck_test.go index 99e12a0..0923b7e 100644 --- a/internal/errcheck/errcheck_test.go +++ b/internal/errcheck/errcheck_test.go @@ -187,6 +187,11 @@ func test(t *testing.T, f flags) { checker := NewChecker() checker.Asserts = asserts checker.Blank = blank + checker.SetExclude(map[string]bool{ + fmt.Sprintf("(%s.ErrorMakerWrapper).MakeNilError", testPackage): true, + fmt.Sprintf("(*%s.ErrorMakerWrapper).MakeNilError", testPackage): true, + fmt.Sprintf("(%s.ErrorMaker).MakeAnotherNilError", testPackage): true, + }) err := checker.CheckPackages(testPackage) uerr, ok := err.(*UncheckedErrors) if !ok { diff --git a/testdata/main.go b/testdata/main.go index 8599258..7c15b54 100644 --- a/testdata/main.go +++ b/testdata/main.go @@ -65,6 +65,14 @@ func customPointerErrorTuple() (int, *MyPointerError) { return 0, &e } +// Test custom excludes +type ErrorMaker struct{} + +func (ErrorMaker) MakeNilError() error { return nil } +func (ErrorMaker) MakeAnotherNilError() error { return nil } + +type ErrorMakerWrapper struct{ ErrorMaker } + func main() { // Single error return _ = a() // BLANK @@ -139,4 +147,16 @@ func main() { mrand.Read(nil) ioutil.ReadFile("main.go") // UNCHECKED + + // We exclude (ErrorMakerWrapper).MakeNilError, but not + // (ErrorMaker).MakeNilError itself. + // + // We do exclude MakeAnotherNilError itself. + em := ErrorMaker{} + em.MakeNilError() // UNCHECKED + em.MakeAnotherNilError() + wem := ErrorMakerWrapper{ErrorMaker{}} + wem.MakeNilError() + (&wem).MakeNilError() + wem.MakeAnotherNilError() } From 10d5a4d6b000ac6f28a07dba7ffef970905856a1 Mon Sep 17 00:00:00 2001 From: Louis Sobel Date: Sat, 7 Apr 2018 15:43:00 -0700 Subject: [PATCH 2/7] ignore embedded --- internal/errcheck/embedded_walker.go | 114 ++++++++++++++++++++++ internal/errcheck/embedded_walker_test.go | 93 ++++++++++++++++++ internal/errcheck/errcheck.go | 61 ++++++------ internal/errcheck/errcheck_test.go | 4 +- testdata/main.go | 25 ++--- 5 files changed, 247 insertions(+), 50 deletions(-) create mode 100644 internal/errcheck/embedded_walker.go create mode 100644 internal/errcheck/embedded_walker_test.go diff --git a/internal/errcheck/embedded_walker.go b/internal/errcheck/embedded_walker.go new file mode 100644 index 0000000..90f6227 --- /dev/null +++ b/internal/errcheck/embedded_walker.go @@ -0,0 +1,114 @@ +package errcheck + +import ( + // "fmt" + "go/types" +) + +// walkThroughEmbeddedInterfaces returns a slice of types that we need to walk through +// in order to reach the actual interface definition of the function on the other end of this selection (x.f) +// +// False will be returned it: +// - the left side of the selection is not a function +// - the right side of the selection is an invalid type +// - we don't end at an interface-defined function +// +func walkThroughEmbeddedInterfaces(sel *types.Selection) ([]types.Type, bool) { + fn, ok := sel.Obj().(*types.Func) + if !ok { + return nil, false + } + + currentT := sel.Recv() + if currentT == types.Typ[types.Invalid] { + return nil, false + } + + // The first type is the immediate receiver itself + result := []types.Type{currentT} + + // First, we can walk through any Struct fields provided + // by the selection Index() method. + indexes := sel.Index() + for _, fieldIndex := range indexes[:len(indexes)-1] { + currentT = maybeUnname(maybeDereference(currentT)) + + // Because we have an entry in Index for this type, + // we know it has to be a Struct. + s, ok := currentT.(*types.Struct) + if !ok { + panic("expected Struct!") + } + + nextT := s.Field(fieldIndex).Type() + result = append(result, nextT) + currentT = nextT + } + + // Now currentT is either a Struct implementing the + // actual function or an interface. If it's an interface, + // we need to continue digging until we find the interface + // that actually explicitly defines the function! + // + // If it's a Struct, we return false; we're only interested in interface-defined + // functions here. + _, ok = maybeUnname(currentT).(*types.Interface) + if !ok { + return nil, false + } + + for { + interfaceT := maybeUnname(currentT).(*types.Interface) + if explicitlyDefinesMethod(interfaceT, fn) { + // then we're done + break + } + + // otherwise, search through the embedded interfaces to find + // the one that defines this method. + for i := 0; i < interfaceT.NumEmbeddeds(); i++ { + nextNamedInterface := interfaceT.Embedded(i) + if definesMethod(maybeUnname(nextNamedInterface).(*types.Interface), fn) { + result = append(result, nextNamedInterface) + currentT = nextNamedInterface + break + } + } + } + + return result, true +} + +func explicitlyDefinesMethod(interfaceT *types.Interface, fn *types.Func) bool { + for i := 0; i < interfaceT.NumExplicitMethods(); i++ { + if interfaceT.ExplicitMethod(i).Id() == fn.Id() { + return true + } + } + return false +} + +func definesMethod(interfaceT *types.Interface, fn *types.Func) bool { + for i := 0; i < interfaceT.NumMethods(); i++ { + if interfaceT.Method(i).Id() == fn.Id() { + return true + } + } + return false +} + +func maybeDereference(t types.Type) types.Type { + p, ok := t.(*types.Pointer) + if ok { + return p.Elem() + } + return t +} + +func maybeUnname(t types.Type) types.Type { + n, ok := t.(*types.Named) + if ok { + return n.Underlying() + } + return t +} diff --git a/internal/errcheck/embedded_walker_test.go b/internal/errcheck/embedded_walker_test.go new file mode 100644 index 0000000..59a763f --- /dev/null +++ b/internal/errcheck/embedded_walker_test.go @@ -0,0 +1,93 @@ +package errcheck + +import ( + "go/ast" + "go/parser" + "go/token" + "go/types" + "testing" +) + +const commonSrc = ` +package p + +type Inner struct {} +func (Inner) Method() + +type Outer struct {Inner} +type OuterP struct {*Inner} + +type InnerInterface interface { + Method() +} + +type OuterInterface interface {InnerInterface} +type MiddleInterfaceStruct struct {OuterInterface} +type OuterInterfaceStruct struct {MiddleInterfaceStruct} + +var c = ` + +type testCase struct { + selector string + expectedOk bool + expected []string +} + +func TestWalkThroughEmbeddedInterfaces(t *testing.T) { + cases := []testCase{ + testCase{"Inner{}.Method", false, nil}, + testCase{"(&Inner{}).Method", false, nil}, + testCase{"Outer{}.Method", false, nil}, + testCase{"InnerInterface.Method", true, []string{"test.InnerInterface"}}, + testCase{"OuterInterface.Method", true, []string{"test.OuterInterface", "test.InnerInterface"}}, + testCase{"OuterInterfaceStruct.Method", true, []string{"test.OuterInterfaceStruct", "test.MiddleInterfaceStruct", "test.OuterInterface", "test.InnerInterface"}}, + } + + for _, c := range cases { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test", commonSrc+c.selector, 0) + if err != nil { + t.Fatal(err) + } + + conf := types.Config{} + info := types.Info{ + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + _, err = conf.Check("test", fset, []*ast.File{f}, &info) + if err != nil { + t.Fatal(err) + } + ast.Inspect(f, func(n ast.Node) bool { + s, ok := n.(*ast.SelectorExpr) + if ok { + selection, ok := info.Selections[s] + if !ok { + t.Fatalf("no Selection!") + } + ts, ok := walkThroughEmbeddedInterfaces(selection) + if ok != c.expectedOk { + t.Errorf("expected ok %v got %v", c.expectedOk, ok) + return false + } + if !ok { + return false + } + + if len(ts) != len(c.expected) { + t.Fatalf("expected %d types, got %d", len(c.expected), len(ts)) + } + + for i, e := range c.expected { + if e != ts[i].String() { + t.Errorf("mismatch at index %d: expected %s got %s", i, e, ts[i]) + } + } + } + + return true + }) + + } + +} diff --git a/internal/errcheck/errcheck.go b/internal/errcheck/errcheck.go index b8491a7..85d70cf 100644 --- a/internal/errcheck/errcheck.go +++ b/internal/errcheck/errcheck.go @@ -236,23 +236,24 @@ type visitor struct { errors []UncheckedError } -func (v *visitor) getReceiverTypeAndFunc(call *ast.CallExpr) (types.Type, *types.Func, bool) { +func (v *visitor) selectorAndFunc(call *ast.CallExpr) (*ast.SelectorExpr, *types.Func, bool) { sel, ok := call.Fun.(*ast.SelectorExpr) if !ok { return nil, nil, false } - t := v.pkg.TypeOf(sel.X) + fn, ok := v.pkg.ObjectOf(sel.Sel).(*types.Func) if !ok { // Shouldn't happen, but be paranoid return nil, nil, false } - return t, fn, true + return sel, fn, true + } func (v *visitor) fullName(call *ast.CallExpr) (string, bool) { - _, fn, ok := v.getReceiverTypeAndFunc(call) + _, fn, ok := v.selectorAndFunc(call) if !ok { return "", false } @@ -267,40 +268,40 @@ func (v *visitor) fullName(call *ast.CallExpr) (string, bool) { return fn.FullName(), true } -// fullNameWithReceiversType returns the full name of a function call, but taking -// into account embedded structs, unlike the plain fullName function above. -// -// For example, a struct defined like: -// -// type WriterWrapper { -// io.Writer -// } -// -// fullName(WriterWrapper{}.Write) would return "(io.Writer).Write", -// but fullNameWithReceiversType would return "(WriterWrapper).Write". -func (v *visitor) fullNameWithReceiversType(call *ast.CallExpr) (string, bool) { - t, fn, ok := v.getReceiverTypeAndFunc(call) +func (v *visitor) namesForExcludeCheck(call *ast.CallExpr) []string { + sel, fn, ok := v.selectorAndFunc(call) if !ok { - return "", false + return nil } - if t == types.Typ[types.Invalid] { - return "", false + name, ok := v.fullName(call) + if !ok { + return nil } - // The name is fully qualified by the import path, possible type, - // function/method name and pointer receiver. - return fmt.Sprintf("(%s).%s", t.String(), fn.Name()), true -} + // This will have ok false for functions without a receiver type, + // so just return the functions full name. + selection, ok := v.pkg.Selections[sel] + if !ok { + return []string{name} + } -func (v *visitor) excludeCall(call *ast.CallExpr) bool { - if name, ok := v.fullName(call); ok { - if v.exclude[name] { - return true - } + ts, ok := walkThroughEmbeddedInterfaces(selection) + if !ok { + return []string{name} + } + + result := make([]string, len(ts)) + for i, t := range ts { + result[i] = fmt.Sprintf("(%s).%s", t.String(), fn.Name()) } + fmt.Printf("%v\n", ts) + return result +} - if name, ok := v.fullNameWithReceiversType(call); ok { +func (v *visitor) excludeCall(call *ast.CallExpr) bool { + for _, name := range v.namesForExcludeCheck(call) { + fmt.Printf("%v\n", name) if v.exclude[name] { return true } diff --git a/internal/errcheck/errcheck_test.go b/internal/errcheck/errcheck_test.go index 0923b7e..84b04d6 100644 --- a/internal/errcheck/errcheck_test.go +++ b/internal/errcheck/errcheck_test.go @@ -188,9 +188,7 @@ func test(t *testing.T, f flags) { checker.Asserts = asserts checker.Blank = blank checker.SetExclude(map[string]bool{ - fmt.Sprintf("(%s.ErrorMakerWrapper).MakeNilError", testPackage): true, - fmt.Sprintf("(*%s.ErrorMakerWrapper).MakeNilError", testPackage): true, - fmt.Sprintf("(%s.ErrorMaker).MakeAnotherNilError", testPackage): true, + fmt.Sprintf("(%s.ErrorMakerInterface).MakeNilError", testPackage): true, }) err := checker.CheckPackages(testPackage) uerr, ok := err.(*UncheckedErrors) diff --git a/testdata/main.go b/testdata/main.go index 7c15b54..e2b62fe 100644 --- a/testdata/main.go +++ b/testdata/main.go @@ -66,12 +66,12 @@ func customPointerErrorTuple() (int, *MyPointerError) { } // Test custom excludes -type ErrorMaker struct{} - -func (ErrorMaker) MakeNilError() error { return nil } -func (ErrorMaker) MakeAnotherNilError() error { return nil } - -type ErrorMakerWrapper struct{ ErrorMaker } +type ErrorMakerInterface interface { + MakeNilError() error +} +type ErrorMakerInterfaceWrapper interface { + ErrorMakerInterface +} func main() { // Single error return @@ -148,15 +148,6 @@ func main() { ioutil.ReadFile("main.go") // UNCHECKED - // We exclude (ErrorMakerWrapper).MakeNilError, but not - // (ErrorMaker).MakeNilError itself. - // - // We do exclude MakeAnotherNilError itself. - em := ErrorMaker{} - em.MakeNilError() // UNCHECKED - em.MakeAnotherNilError() - wem := ErrorMakerWrapper{ErrorMaker{}} - wem.MakeNilError() - (&wem).MakeNilError() - wem.MakeAnotherNilError() + var emiw ErrorMakerInterfaceWrapper + emiw.MakeNilError() } From 2ae7c6b2e370a46dd4966b74deab88bca643b5b6 Mon Sep 17 00:00:00 2001 From: Louis Sobel Date: Sat, 7 Apr 2018 15:46:42 -0700 Subject: [PATCH 3/7] exclude hash.Hash.Write --- internal/errcheck/errcheck.go | 9 +++------ testdata/main.go | 2 ++ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/internal/errcheck/errcheck.go b/internal/errcheck/errcheck.go index 85d70cf..ffa6bfb 100644 --- a/internal/errcheck/errcheck.go +++ b/internal/errcheck/errcheck.go @@ -143,13 +143,12 @@ func (c *Checker) SetExclude(l map[string]bool) { "(*strings.Builder).WriteByte", "(*strings.Builder).WriteRune", "(*strings.Builder).WriteString", + + // hash + "(hash.Hash).Write", } { c.exclude[exc] = true } - - for k := range l { - c.exclude[k] = true - } } func (c *Checker) logf(msg string, args ...interface{}) { @@ -295,13 +294,11 @@ func (v *visitor) namesForExcludeCheck(call *ast.CallExpr) []string { for i, t := range ts { result[i] = fmt.Sprintf("(%s).%s", t.String(), fn.Name()) } - fmt.Printf("%v\n", ts) return result } func (v *visitor) excludeCall(call *ast.CallExpr) bool { for _, name := range v.namesForExcludeCheck(call) { - fmt.Printf("%v\n", name) if v.exclude[name] { return true } diff --git a/testdata/main.go b/testdata/main.go index e2b62fe..48d2c3b 100644 --- a/testdata/main.go +++ b/testdata/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "crypto/sha256" "fmt" "io/ioutil" "math/rand" @@ -145,6 +146,7 @@ func main() { b2.Write(nil) rand.Read(nil) mrand.Read(nil) + sha256.New().Write([]byte{}) ioutil.ReadFile("main.go") // UNCHECKED From 05dbe35037c0818c6b0f59fbfc17285b9b8fe2ef Mon Sep 17 00:00:00 2001 From: Louis Sobel Date: Sat, 7 Apr 2018 15:49:51 -0700 Subject: [PATCH 4/7] touch up test a bit --- internal/errcheck/embedded_walker.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/internal/errcheck/embedded_walker.go b/internal/errcheck/embedded_walker.go index 90f6227..f7db421 100644 --- a/internal/errcheck/embedded_walker.go +++ b/internal/errcheck/embedded_walker.go @@ -6,7 +6,8 @@ import ( ) // walkThroughEmbeddedInterfaces returns a slice of types that we need to walk through -// in order to reach the actual interface definition of the function on the other end of this selection (x.f) +// in order to reach the actual definition, in an interface, of the function on +// the other end of this selection (x.f) // // False will be returned it: // - the left side of the selection is not a function @@ -48,10 +49,10 @@ func walkThroughEmbeddedInterfaces(sel *types.Selection) ([]types.Type, bool) { // Now currentT is either a Struct implementing the // actual function or an interface. If it's an interface, // we need to continue digging until we find the interface - // that actually explicitly defines the function! + // that actually explicitly defines the function. // - // If it's a Struct, we return false; we're only interested in interface-defined - // functions here. + // If it's a Struct, we return false, as we're only interested + // in interface-defined functions in this function. _, ok = maybeUnname(currentT).(*types.Interface) if !ok { return nil, false From 12b83e19c1ce0f37729eee2f746ff7ddd5f6bd2e Mon Sep 17 00:00:00 2001 From: Louis Sobel Date: Sun, 8 Apr 2018 11:07:54 -0700 Subject: [PATCH 5/7] improvements --- internal/errcheck/embedded_walker.go | 126 +++++++++++++--------- internal/errcheck/embedded_walker_test.go | 2 +- internal/errcheck/errcheck.go | 7 +- 3 files changed, 80 insertions(+), 55 deletions(-) diff --git a/internal/errcheck/embedded_walker.go b/internal/errcheck/embedded_walker.go index f7db421..9aac6c2 100644 --- a/internal/errcheck/embedded_walker.go +++ b/internal/errcheck/embedded_walker.go @@ -1,88 +1,112 @@ package errcheck import ( - // "fmt" + "fmt" "go/types" ) -// walkThroughEmbeddedInterfaces returns a slice of types that we need to walk through -// in order to reach the actual definition, in an interface, of the function on -// the other end of this selection (x.f) +// walkThroughEmbeddedInterfaces returns a slice of Interfaces that +// we need to walk through in order to reach the actual definition, +// in an Interface, of the method selected by the given selection. // -// False will be returned it: -// - the left side of the selection is not a function -// - the right side of the selection is an invalid type -// - we don't end at an interface-defined function +// false will be returned in the second return value if: +// - the right side of the selection is not a function +// - the actual definition of the function is not in an Interface // +// The returned slice will contain all the interface types that need +// to be walked through to reach the actual definition. +// +// For example, say we have: +// +// type Inner interface {Method()} +// type Middle interface {Inner} +// type Outer interface {Middle} +// type T struct {Outer} +// type U struct {T} +// type V struct {U} +// +// And then the selector: +// +// V.Method +// +// We'll return [Outer, Middle, Inner] by first walking through the embedded structs +// until we reach the Outer interface, then descending through the embedded interfaces +// until we find the one that actually explicitly defines Method. func walkThroughEmbeddedInterfaces(sel *types.Selection) ([]types.Type, bool) { fn, ok := sel.Obj().(*types.Func) if !ok { return nil, false } + // Start off at the receiver. currentT := sel.Recv() - if currentT == types.Typ[types.Invalid] { - return nil, false - } - - // The first type is the immediate receiver itself - result := []types.Type{currentT} // First, we can walk through any Struct fields provided - // by the selection Index() method. + // by the selection Index() method. We ignore the last + // index because it would give the method itself. indexes := sel.Index() for _, fieldIndex := range indexes[:len(indexes)-1] { - currentT = maybeUnname(maybeDereference(currentT)) - - // Because we have an entry in Index for this type, - // we know it has to be a Struct. - s, ok := currentT.(*types.Struct) - if !ok { - panic("expected Struct!") - } - - nextT := s.Field(fieldIndex).Type() - result = append(result, nextT) - currentT = nextT + currentT = getTypeAtFieldIndex(currentT, fieldIndex) } - // Now currentT is either a Struct implementing the - // actual function or an interface. If it's an interface, - // we need to continue digging until we find the interface - // that actually explicitly defines the function. + // Now currentT is either a type implementing the actual function, + // an Invalid type (if the receiver is a package), or an interface. // - // If it's a Struct, we return false, as we're only interested - // in interface-defined functions in this function. - _, ok = maybeUnname(currentT).(*types.Interface) + // If it's not an Interface, then we're done, as this function + // only cares about Interface-defined functions. + // + // If it is an Interface, we potentially need to continue digging until + // we find the Interface that actually explicitly defines the function. + interfaceT, ok := maybeUnname(currentT).(*types.Interface) if !ok { return nil, false } - for { - interfaceT := maybeUnname(currentT).(*types.Interface) - if explicitlyDefinesMethod(interfaceT, fn) { - // then we're done - break - } + // The first interface we pass through is this one we've found. We return the possibly + // wrapping types.Named because it is more useful to work with for callers. + result := []types.Type{currentT} - // otherwise, search through the embedded interfaces to find - // the one that defines this method. - for i := 0; i < interfaceT.NumEmbeddeds(); i++ { - nextNamedInterface := interfaceT.Embedded(i) - if definesMethod(maybeUnname(nextNamedInterface).(*types.Interface), fn) { - result = append(result, nextNamedInterface) - currentT = nextNamedInterface - break - } + // If this interface itself explicitly defines the given method + // then we're done digging. + for !explicitlyDefinesMethod(interfaceT, fn) { + // Otherwise, we find which of the embedded interfaces _does_ + // define the method, add it to our list, and loop. + namedInterfaceT, ok := getEmbeddedInterfaceDefiningMethod(interfaceT, fn) + if !ok { + // This should be impossible as long as we type-checked: either the + // interface or one of its embedded ones must implement the method... + panic(fmt.Sprintf("either %v or one of its embedded interfaces must implement %v", currentT, fn)) } + result = append(result, namedInterfaceT) + interfaceT = namedInterfaceT.Underlying().(*types.Interface) } return result, true } +func getTypeAtFieldIndex(startingAt types.Type, fieldIndex int) types.Type { + t := maybeUnname(maybeDereference(startingAt)) + s, ok := t.(*types.Struct) + if !ok { + panic(fmt.Sprintf("cannot get Field of a type that is not a struct, got a %T", t)) + } + + return s.Field(fieldIndex).Type() +} + +func getEmbeddedInterfaceDefiningMethod(interfaceT *types.Interface, fn *types.Func) (*types.Named, bool) { + for i := 0; i < interfaceT.NumEmbeddeds(); i++ { + embedded := interfaceT.Embedded(i) + if definesMethod(embedded.Underlying().(*types.Interface), fn) { + return embedded, true + } + } + return nil, false +} + func explicitlyDefinesMethod(interfaceT *types.Interface, fn *types.Func) bool { for i := 0; i < interfaceT.NumExplicitMethods(); i++ { - if interfaceT.ExplicitMethod(i).Id() == fn.Id() { + if interfaceT.ExplicitMethod(i) == fn { return true } } @@ -91,7 +115,7 @@ func explicitlyDefinesMethod(interfaceT *types.Interface, fn *types.Func) bool { func definesMethod(interfaceT *types.Interface, fn *types.Func) bool { for i := 0; i < interfaceT.NumMethods(); i++ { - if interfaceT.Method(i).Id() == fn.Id() { + if interfaceT.Method(i) == fn { return true } } diff --git a/internal/errcheck/embedded_walker_test.go b/internal/errcheck/embedded_walker_test.go index 59a763f..51c13a2 100644 --- a/internal/errcheck/embedded_walker_test.go +++ b/internal/errcheck/embedded_walker_test.go @@ -40,7 +40,7 @@ func TestWalkThroughEmbeddedInterfaces(t *testing.T) { testCase{"Outer{}.Method", false, nil}, testCase{"InnerInterface.Method", true, []string{"test.InnerInterface"}}, testCase{"OuterInterface.Method", true, []string{"test.OuterInterface", "test.InnerInterface"}}, - testCase{"OuterInterfaceStruct.Method", true, []string{"test.OuterInterfaceStruct", "test.MiddleInterfaceStruct", "test.OuterInterface", "test.InnerInterface"}}, + testCase{"OuterInterfaceStruct.Method", true, []string{"test.OuterInterface", "test.InnerInterface"}}, } for _, c := range cases { diff --git a/internal/errcheck/errcheck.go b/internal/errcheck/errcheck.go index ffa6bfb..c808d3d 100644 --- a/internal/errcheck/errcheck.go +++ b/internal/errcheck/errcheck.go @@ -256,7 +256,6 @@ func (v *visitor) fullName(call *ast.CallExpr) (string, bool) { if !ok { return "", false } - // The name is fully qualified by the import path, possible type, // function/method name and pointer receiver. // @@ -278,13 +277,15 @@ func (v *visitor) namesForExcludeCheck(call *ast.CallExpr) []string { return nil } - // This will have ok false for functions without a receiver type, - // so just return the functions full name. + // This will be missing for functions without a receiver (like fmt.Printf), + // so just fall back to the the function's fullName in that case. selection, ok := v.pkg.Selections[sel] if !ok { return []string{name} } + // This will return with ok false if the function isn't defined + // on an interface, so just fall back to the fullName. ts, ok := walkThroughEmbeddedInterfaces(selection) if !ok { return []string{name} From 070cb41b0746e0f9f889c80076ee81eb5d7a7a9c Mon Sep 17 00:00:00 2001 From: Louis Sobel Date: Sat, 19 May 2018 10:52:32 -0400 Subject: [PATCH 6/7] fix merge --- internal/errcheck/errcheck.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/errcheck/errcheck.go b/internal/errcheck/errcheck.go index c808d3d..d66c2ba 100644 --- a/internal/errcheck/errcheck.go +++ b/internal/errcheck/errcheck.go @@ -149,6 +149,10 @@ func (c *Checker) SetExclude(l map[string]bool) { } { c.exclude[exc] = true } + + for k := range l { + c.exclude[k] = true + } } func (c *Checker) logf(msg string, args ...interface{}) { From cd18fe797be8880ca6dfb65755352bed112f0265 Mon Sep 17 00:00:00 2001 From: Louis Sobel Date: Sat, 19 May 2018 11:36:21 -0400 Subject: [PATCH 7/7] documentation work --- internal/errcheck/embedded_walker.go | 5 +++ internal/errcheck/errcheck.go | 49 +++++++++++++++++++++++----- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/internal/errcheck/embedded_walker.go b/internal/errcheck/embedded_walker.go index 9aac6c2..3b31925 100644 --- a/internal/errcheck/embedded_walker.go +++ b/internal/errcheck/embedded_walker.go @@ -94,6 +94,11 @@ func getTypeAtFieldIndex(startingAt types.Type, fieldIndex int) types.Type { return s.Field(fieldIndex).Type() } +// getEmbeddedInterfaceDefiningMethod searches through any embedded interfaces of the +// passed interface searching for one that defines the given function. If found, the +// types.Named wrapping that interface will be returned along with true in the second value. +// +// If no such embedded interface is found, nil and false are returned. func getEmbeddedInterfaceDefiningMethod(interfaceT *types.Interface, fn *types.Func) (*types.Named, bool) { for i := 0; i < interfaceT.NumEmbeddeds(); i++ { embedded := interfaceT.Embedded(i) diff --git a/internal/errcheck/errcheck.go b/internal/errcheck/errcheck.go index d66c2ba..c6615eb 100644 --- a/internal/errcheck/errcheck.go +++ b/internal/errcheck/errcheck.go @@ -239,6 +239,15 @@ type visitor struct { errors []UncheckedError } +// selectorAndFunc tries to get the selector and function from call expression. +// For example, given the call expression representing "a.b()", the selector +// is "a.b" and the function is "b" itself. +// +// The final return value will be true if it is able to do extract a selector +// from the call and look up the function object it refers to. +// +// If the call does not include a selector (like if it is a plain "f()" function call) +// then the final return value will be false. func (v *visitor) selectorAndFunc(call *ast.CallExpr) (*ast.SelectorExpr, *types.Func, bool) { sel, ok := call.Fun.(*ast.SelectorExpr) if !ok { @@ -255,29 +264,47 @@ func (v *visitor) selectorAndFunc(call *ast.CallExpr) (*ast.SelectorExpr, *types } -func (v *visitor) fullName(call *ast.CallExpr) (string, bool) { +// fullName will return a package / receiver-type qualified name for a called function +// if the function is the result of a selector. Otherwise it will return +// the empty string. +// +// The name is fully qualified by the import path, possible type, +// function/method name and pointer receiver. +// +// For example, +// - for "fmt.Printf(...)" it will return "fmt.Printf" +// - for "base64.StdEncoding.Decode(...)" it will return "(*encoding/base64.Encoding).Decode" +// - for "myFunc()" it will return "" +func (v *visitor) fullName(call *ast.CallExpr) string { _, fn, ok := v.selectorAndFunc(call) if !ok { - return "", false + return "" } - // The name is fully qualified by the import path, possible type, - // function/method name and pointer receiver. - // + // TODO(dh): vendored packages will have /vendor/ in their name, // thus not matching vendored standard library packages. If we // want to support vendored stdlib packages, we need to implement // FullName with our own logic. - return fn.FullName(), true + return fn.FullName() } +// namesForExcludeCheck will return a list of fully-qualified function names +// from a function call that can be used to check against the exclusion list. +// +// If a function call is against a local function (like "myFunc()") then no +// names are returned. If the function is package-qualified (like "fmt.Printf()") +// then just that function's fullName is returned. +// +// Otherwise, we walk through all the potentially embeddded interfaces of the receiver +// the collect a list of type-qualified function names that we will check. func (v *visitor) namesForExcludeCheck(call *ast.CallExpr) []string { sel, fn, ok := v.selectorAndFunc(call) if !ok { return nil } - name, ok := v.fullName(call) - if !ok { + name := v.fullName(call) + if name == "" { return nil } @@ -297,6 +324,10 @@ func (v *visitor) namesForExcludeCheck(call *ast.CallExpr) []string { result := make([]string, len(ts)) for i, t := range ts { + // Like in fullName, vendored packages will have /vendor/ in their name, + // thus not matching vendored standard library packages. If we + // want to support vendored stdlib packages, we need to implement + // additional logic here. result[i] = fmt.Sprintf("(%s).%s", t.String(), fn.Name()) } return result @@ -438,7 +469,7 @@ func (v *visitor) addErrorAtPosition(position token.Pos, call *ast.CallExpr) { var name string if call != nil { - name, _ = v.fullName(call) + name = v.fullName(call) } v.errors = append(v.errors, UncheckedError{pos, line, name})