Skip to content

Commit

Permalink
ruleguard: add support for local functions
Browse files Browse the repository at this point in the history
This feature is useful for rule filters readability improvements.

Instead of copying a complex `Where()` expression several times,
one can now use a local function literal to define that filter
operation and use it inside `Where()` expressions.

Here is an example:

```go
func preferFprint(m dsl.Matcher) {
	isFmtPackage := func(v dsl.Var) bool {
		return v.Text == "fmt" && v.Object.Is(`PkgName`)
	}

	m.Match(`$w.Write([]byte($fmt.Sprint($*args)))`).
		Where(m["w"].Type.Implements("io.Writer") && isFmtPackage(m["fmt"])).
		Suggest("fmt.Fprint($w, $args)").
		Report(`fmt.Fprint($w, $args) should be preferred to the $$`)

	m.Match(`$w.Write([]byte($fmt.Sprintf($*args)))`).
		Where(m["w"].Type.Implements("io.Writer") && isFmtPackage(m["fmt"])).
		Suggest("fmt.Fprintf($w, $args)").
		Report(`fmt.Fprintf($w, $args) should be preferred to the $$`)

	// ...etc
}
```

Note that we used `isFmtPackage` in more than 1 rule.

Functions can accept almost arbitrary params, but there are some
restrictions on what kinds of arguments they can receive right now.

These arguments work:

* Matcher var expressions like `m["varname"]`
* Basic literals like `"foo"`, `104`, `5.2`
* Constants
  • Loading branch information
quasilyte committed Nov 7, 2021
1 parent 34f5283 commit e36eeaf
Show file tree
Hide file tree
Showing 11 changed files with 297 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ test-release:
@echo "everything is OK"

lint:
curl -sSfL https://github.com/raw/golangci/golangci-lint/master/install.sh | sh -s -- -b $(GOPATH_DIR)/bin v1.30.0
curl -sSfL https://github.com/raw/golangci/golangci-lint/master/install.sh | sh -s -- -b $(GOPATH_DIR)/bin v1.43.0
$(GOPATH_DIR)/bin/golangci-lint run ./...
go build -o go-ruleguard ./cmd/ruleguard
./go-ruleguard -debug-imports -rules rules.go ./...
Expand Down
1 change: 1 addition & 0 deletions analyzer/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ var tests = []struct {
{name: "comments"},
{name: "stdlib"},
{name: "uber"},
{name: "localfunc"},
{name: "goversion", flags: map[string]string{"go": "1.16"}},
}

Expand Down
13 changes: 7 additions & 6 deletions analyzer/testdata/src/gocritic/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,21 +170,22 @@ func appendAssign(m dsl.Matcher) {
//doc:before w.Write([]byte(fmt.Sprintf("%x", 10)))
//doc:after fmt.Fprintf(w, "%x", 10)
func preferFprint(m dsl.Matcher) {
isFmtPackage := func(v dsl.Var) bool {
return v.Text == "fmt" && v.Object.Is(`PkgName`)
}

m.Match(`$w.Write([]byte($fmt.Sprint($*args)))`).
Where(m["w"].Type.Implements("io.Writer") &&
m["fmt"].Text == "fmt" && m["fmt"].Object.Is(`PkgName`)).
Where(m["w"].Type.Implements("io.Writer") && isFmtPackage(m["fmt"])).
Suggest("fmt.Fprint($w, $args)").
Report(`fmt.Fprint($w, $args) should be preferred to the $$`)

m.Match(`$w.Write([]byte($fmt.Sprintf($*args)))`).
Where(m["w"].Type.Implements("io.Writer") &&
m["fmt"].Text == "fmt" && m["fmt"].Object.Is(`PkgName`)).
Where(m["w"].Type.Implements("io.Writer") && isFmtPackage(m["fmt"])).
Suggest("fmt.Fprintf($w, $args)").
Report(`fmt.Fprintf($w, $args) should be preferred to the $$`)

m.Match(`$w.Write([]byte($fmt.Sprintln($*args)))`).
Where(m["w"].Type.Implements("io.Writer") &&
m["fmt"].Text == "fmt" && m["fmt"].Object.Is(`PkgName`)).
Where(m["w"].Type.Implements("io.Writer") && isFmtPackage(m["fmt"])).
Suggest("fmt.Fprintln($w, $args)").
Report(`fmt.Fprintln($w, $args) should be preferred to the $$`)
}
Expand Down
43 changes: 43 additions & 0 deletions analyzer/testdata/src/localfunc/rules.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//go:build ignore
// +build ignore

package gorules

import "github.com/quasilyte/go-ruleguard/dsl"

func testRules(m dsl.Matcher) {
bothConst := func(x, y dsl.Var) bool {
return x.Const && y.Const
}
m.Match(`test("both const", $x, $y)`).
Where(bothConst(m["x"], m["y"])).
Report(`true`)

intValue := func(x dsl.Var, val int) bool {
return x.Value.Int() == val
}
m.Match(`test("== 10", $x)`).
Where(intValue(m["x"], 10)).
Report(`true`)

isZero := func(x dsl.Var) bool { return x.Value.Int() == 0 }
m.Match(`test("== 0", $x)`).
Where(isZero(m["x"])).
Report(`true`)

// Testing closure-captured m variable.
fmtIsImported := func() bool {
return m.File().Imports(`fmt`)
}
m.Match(`test("fmt is imported")`).
Where(fmtIsImported()).
Report(`true`)

// Testing explicitly passed matcher.
ioutilIsImported := func(m2 dsl.Matcher) bool {
return m2.File().Imports(`io/ioutil`)
}
m.Match(`test("ioutil is imported")`).
Where(ioutilIsImported(m)).
Report(`true`)
}
35 changes: 35 additions & 0 deletions analyzer/testdata/src/localfunc/target.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package localfunc

import (
"fmt"
"io/ioutil"
)

func test(args ...interface{}) {}

func _() {
fmt.Println("ok")
_ = ioutil.Discard

var i int

test("both const", 1, 2) // want `true`
test("both const", 1, 2+2) // want `true`
test("both const", i, 2)
test("both const", 1, i)
test("both const", i, i)

test("== 10", 10) // want `true`
test("== 10", 9+1) // want `true`
test("== 10", 11)
test("== 10", i)

test("== 0", 0) // want `true`
test("== 0", 1-1) // want `true`
test("== 0", 11)
test("== 0", i)

test("fmt is imported") // want `true`

test("ioutil is imported") // want `true`
}
7 changes: 7 additions & 0 deletions analyzer/testdata/src/localfunc/target2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package localfunc

func _() {
test("fmt is imported")

test("ioutil is imported")
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/quasilyte/go-ruleguard
go 1.15

require (
github.com/go-toolsmith/astcopy v1.0.0
github.com/go-toolsmith/astequal v1.0.1
github.com/google/go-cmp v0.5.2
github.com/quasilyte/go-ruleguard/dsl v0.3.10
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
github.com/go-toolsmith/astcopy v1.0.0 h1:OMgl1b1MEpjFQ1m5ztEO06rz5CUd3oBv9RF7+DyvdG8=
github.com/go-toolsmith/astcopy v1.0.0/go.mod h1:vrgyG+5Bxrnz4MZWPF+pI4R8h3qKRjjyvV/DSez4WVQ=
github.com/go-toolsmith/astequal v1.0.0/go.mod h1:H+xSiq0+LtiDC11+h1G32h7Of5O3CYFJ99GVbS5lDKY=
github.com/go-toolsmith/astequal v1.0.1 h1:JbSszi42Jiqu36Gnf363HWS9MTEAz67vTQLponh3Moc=
github.com/go-toolsmith/astequal v1.0.1/go.mod h1:4oGA3EZXTVItV/ipGiOx7NWkY5veFfcsOJVS2YxltLw=
github.com/go-toolsmith/strparse v1.0.0 h1:Vcw78DnpCAKlM20kSbAyO4mPfJn/lyYA4BJUDxe2Jb4=
Expand Down
14 changes: 14 additions & 0 deletions ruleguard/debug_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,20 @@ func TestDebug(t *testing.T) {
` $x []string: []string{"x"}`,
},
},

`isConst := func(v dsl.Var) bool { return v.Const }; m.Match("_ = $x").Where(isConst(m["x"]) && !m["x"].Type.Is("string"))`: {
`_ = 10`: nil,

`_ = "str"`: {
`input.go:4: [rules.go:5] rejected by !m["x"].Type.Is("string")`,
` $x string: "str"`,
},

`_ = f()`: {
`input.go:4: [rules.go:5] rejected by isConst(m["x"])`,
` $x interface{}: f()`,
},
},
}

loadRulesFromExpr := func(e *Engine, s string) {
Expand Down
160 changes: 159 additions & 1 deletion ruleguard/irconv/irconv.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strconv"
"strings"

"github.com/go-toolsmith/astcopy"
"github.com/quasilyte/go-ruleguard/ruleguard/goutil"
"github.com/quasilyte/go-ruleguard/ruleguard/ir"
"golang.org/x/tools/go/ast/astutil"
Expand Down Expand Up @@ -52,13 +53,20 @@ type convError struct {
err error
}

type localMacroFunc struct {
name string
params []string
template ast.Expr
}

type converter struct {
types *types.Info
pkg *types.Package
fset *token.FileSet
src []byte

group *ir.RuleGroup
group *ir.RuleGroup
groupFuncs []localMacroFunc

dslPkgname string // The local name of the "ruleguard/dsl" package (usually its just "dsl")
}
Expand Down Expand Up @@ -171,6 +179,7 @@ func (conv *converter) convertRuleGroup(decl *ast.FuncDecl) *ir.RuleGroup {
Line: conv.fset.Position(decl.Name.Pos()).Line,
}
conv.group = result
conv.groupFuncs = conv.groupFuncs[:0]

result.Name = decl.Name.String()
result.MatcherName = decl.Type.Params.List[0].Names[0].String()
Expand All @@ -181,6 +190,11 @@ func (conv *converter) convertRuleGroup(decl *ast.FuncDecl) *ir.RuleGroup {

seenRules := false
for _, stmt := range decl.Body.List {
if assign, ok := stmt.(*ast.AssignStmt); ok && assign.Tok == token.DEFINE {
conv.localDefine(assign)
continue
}

if _, ok := stmt.(*ast.DeclStmt); ok {
continue
}
Expand Down Expand Up @@ -208,6 +222,146 @@ func (conv *converter) convertRuleGroup(decl *ast.FuncDecl) *ir.RuleGroup {
return result
}

func (conv *converter) findLocalMacro(call *ast.CallExpr) *localMacroFunc {
fn, ok := call.Fun.(*ast.Ident)
if !ok {
return nil
}
for i := range conv.groupFuncs {
if conv.groupFuncs[i].name == fn.Name {
return &conv.groupFuncs[i]
}
}
return nil
}

func (conv *converter) expandMacro(macro *localMacroFunc, call *ast.CallExpr) ir.FilterExpr {
// Check that call args are OK.
// Since "function calls" are implemented as a macro expansion here,
// we don't allow arguments that have a non-trivial evaluation.
isSafe := func(arg ast.Expr) bool {
switch arg := astutil.Unparen(arg).(type) {
case *ast.BasicLit, *ast.Ident:
return true

case *ast.IndexExpr:
mapIdent, ok := astutil.Unparen(arg.X).(*ast.Ident)
if !ok {
return false
}
if mapIdent.Name != conv.group.MatcherName {
return false
}
key, ok := astutil.Unparen(arg.Index).(*ast.BasicLit)
if !ok || key.Kind != token.STRING {
return false
}
return true

default:
return false
}
}
args := map[string]ast.Expr{}
for i, arg := range call.Args {
paramName := macro.params[i]
if !isSafe(arg) {
panic(conv.errorf(arg, "unsupported/too complex %s argument", paramName))
}
args[paramName] = astutil.Unparen(arg)
}

body := astcopy.Expr(macro.template)
expanded := astutil.Apply(body, nil, func(cur *astutil.Cursor) bool {
if ident, ok := cur.Node().(*ast.Ident); ok {
arg, ok := args[ident.Name]
if ok {
cur.Replace(arg)
return true
}
}
// astcopy above will copy the AST tree, but it won't update
// the associated types.Info map of const values.
// We'll try to solve that issue at least partially here.
if lit, ok := cur.Node().(*ast.BasicLit); ok {
switch lit.Kind {
case token.STRING:
val, err := strconv.Unquote(lit.Value)
if err == nil {
conv.types.Types[lit] = types.TypeAndValue{
Type: types.Typ[types.UntypedString],
Value: constant.MakeString(val),
}
}
case token.INT:
val, err := strconv.ParseInt(lit.Value, 0, 64)
if err == nil {
conv.types.Types[lit] = types.TypeAndValue{
Type: types.Typ[types.UntypedInt],
Value: constant.MakeInt64(val),
}
}
case token.FLOAT:
val, err := strconv.ParseFloat(lit.Value, 64)
if err == nil {
conv.types.Types[lit] = types.TypeAndValue{
Type: types.Typ[types.UntypedFloat],
Value: constant.MakeFloat64(val),
}
}
}
}
return true
})

return conv.convertFilterExpr(expanded.(ast.Expr))
}

func (conv *converter) localDefine(assign *ast.AssignStmt) {
if len(assign.Lhs) != 1 || len(assign.Rhs) != 1 {
panic(conv.errorf(assign, "multi-value := is not supported"))
}
lhs, ok := assign.Lhs[0].(*ast.Ident)
if !ok {
panic(conv.errorf(assign.Lhs[0], "only simple ident lhs is supported"))
}
rhs := assign.Rhs[0]
fn, ok := rhs.(*ast.FuncLit)
if !ok {
panic(conv.errorf(rhs, "only func literals are supported on the rhs"))
}
typ := conv.types.TypeOf(fn).(*types.Signature)
isBoolResult := typ.Results() != nil &&
typ.Results().Len() == 1 &&
typ.Results().At(0).Type() == types.Typ[types.Bool]
if !isBoolResult {
var loc ast.Node = fn.Type
if fn.Type.Results != nil {
loc = fn.Type.Results
}
panic(conv.errorf(loc, "only funcs returning bool are supported"))
}
if len(fn.Body.List) != 1 {
panic(conv.errorf(fn.Body, "only simple 1 return statement funcs are supported"))
}
stmt, ok := fn.Body.List[0].(*ast.ReturnStmt)
if !ok {
panic(conv.errorf(fn.Body.List[0], "expected a return statement, found %T", fn.Body.List[0]))
}
var params []string
for _, field := range fn.Type.Params.List {
for _, id := range field.Names {
params = append(params, id.Name)
}
}
macro := localMacroFunc{
name: lhs.Name,
params: params,
template: stmt.Results[0],
}
conv.groupFuncs = append(conv.groupFuncs, macro)
}

func (conv *converter) doMatcherImport(call *ast.CallExpr) {
pkgPath := conv.parseStringArg(call.Args[0])
pkgName := path.Base(pkgPath)
Expand Down Expand Up @@ -518,6 +672,10 @@ func (conv *converter) convertFilterExprImpl(e ast.Expr) ir.FilterExpr {
return ir.FilterExpr{Op: ir.FilterVarFilterOp, Value: op.varName, Args: args}
}

if macro := conv.findLocalMacro(e); macro != nil {
return conv.expandMacro(macro, e)
}

args := convertExprList(e.Args)
switch op.path {
case "Value.Int":
Expand Down
Loading

0 comments on commit e36eeaf

Please sign in to comment.