Skip to content

Commit

Permalink
Introduce pre-order / post-order visitor pattern (#813)
Browse files Browse the repository at this point in the history
* Introduce pre-order / post-order visitor pattern
  • Loading branch information
TristonianJones committed Aug 19, 2023
1 parent eaebecb commit 1a6373d
Show file tree
Hide file tree
Showing 7 changed files with 341 additions and 28 deletions.
25 changes: 25 additions & 0 deletions common/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,13 @@ func Copy(a *AST) *AST {
return NewCheckedAST(NewAST(e, CopySourceInfo(a.SourceInfo())), typesCopy, refsCopy)
}

// MaxID returns the upper-bound, non-inclusive, of ids present within the AST's Expr value.
func MaxID(a *AST) int64 {
visitor := &maxIDVisitor{maxID: 1}
PostOrderVisit(a.Expr(), visitor)
return visitor.maxID + 1
}

// NewSourceInfo creates a simple SourceInfo object from an input common.Source value.
func NewSourceInfo(src common.Source) *SourceInfo {
var lineOffsets []int32
Expand Down Expand Up @@ -397,3 +404,21 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool {
}
return true
}

type maxIDVisitor struct {
maxID int64
}

// VisitExpr updates the max identifier if the incoming expression id is greater than previously observed.
func (v *maxIDVisitor) VisitExpr(e Expr) {
if v.maxID < e.ID() {
v.maxID = e.ID()
}
}

// VisitEntryExpr updates the max identifier if the incoming entry id is greater than previously observed.
func (v *maxIDVisitor) VisitEntryExpr(e EntryExpr) {
if v.maxID < e.ID() {
v.maxID = e.ID()
}
}
48 changes: 38 additions & 10 deletions common/ast/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"
)

func TestASTCopy(t *testing.T) {
Expand All @@ -40,12 +42,38 @@ func TestASTCopy(t *testing.T) {

for _, tst := range tests {
checked := mustTypeCheck(t, tst)
copied := ast.Copy(checked)
if !reflect.DeepEqual(copied.Expr(), checked.Expr()) {
t.Errorf("Copy() got expr %v, wanted %v", copied.Expr(), checked.Expr())
copyChecked := ast.Copy(checked)
if !reflect.DeepEqual(copyChecked.Expr(), checked.Expr()) {
t.Errorf("Copy() got expr %v, wanted %v", copyChecked.Expr(), checked.Expr())
}
if !reflect.DeepEqual(copied.SourceInfo(), checked.SourceInfo()) {
t.Errorf("Copy() got source info %v, wanted %v", copied.SourceInfo(), checked.SourceInfo())
if !reflect.DeepEqual(copyChecked.SourceInfo(), checked.SourceInfo()) {
t.Errorf("Copy() got source info %v, wanted %v", copyChecked.SourceInfo(), checked.SourceInfo())
}
copyParsed := ast.Copy(ast.NewAST(checked.Expr(), checked.SourceInfo()))
if !reflect.DeepEqual(copyParsed.Expr(), checked.Expr()) {
t.Errorf("Copy() got expr %v, wanted %v", copyParsed.Expr(), checked.Expr())
}
if !reflect.DeepEqual(copyParsed.SourceInfo(), checked.SourceInfo()) {
t.Errorf("Copy() got source info %v, wanted %v", copyParsed.SourceInfo(), checked.SourceInfo())
}
checkedPB, err := ast.ToProto(checked)
if err != nil {
t.Errorf("ast.ToProto() failed: %v", err)
}
copyCheckedPB, err := ast.ToProto(copyChecked)
if err != nil {
t.Errorf("ast.ToProto() failed: %v", err)
}
if !proto.Equal(checkedPB, copyCheckedPB) {
t.Errorf("Copy() produced different proto results, got %v, wanted %v",
prototext.Format(checkedPB), prototext.Format(copyCheckedPB))
}
checkedRoundtrip, err := ast.ToAST(checkedPB)
if err != nil {
t.Errorf("ast.ToAST() failed: %v", err)
}
if !reflect.DeepEqual(checked, checkedRoundtrip) {
t.Errorf("Roundtrip got %v, wanted %v", checkedRoundtrip, checked)
}
}
}
Expand All @@ -67,10 +95,10 @@ func TestASTNilSafety(t *testing.T) {
ast.NewAST(ex, info),
ast.NewCheckedAST(ast.NewAST(ex, info), map[int64]*types.Type{}, map[int64]*ast.ReferenceInfo{}),
}
for i, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
testAST := tc
for _, tst := range tests {
a := tst
asts := []*ast.AST{a, ast.Copy(a)}
for _, testAST := range asts {
if testAST.Expr().ID() != 0 {
t.Errorf("Expr().ID() got %v, wanted 0", testAST.Expr().ID())
}
Expand All @@ -86,7 +114,7 @@ func TestASTNilSafety(t *testing.T) {
if len(testAST.GetOverloadIDs(testAST.Expr().ID())) != 0 {
t.Errorf("GetOverloadIDs() got %v, wanted empty set", testAST.GetOverloadIDs(testAST.Expr().ID()))
}
})
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion common/ast/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ func SourceInfoToProto(info *SourceInfo) (*exprpb.SourceInfo, error) {
return sourceInfo, nil
}

// ProtoToSourceInfo deserializes
// ProtoToSourceInfo deserializes the protobuf into a native SourceInfo value.
func ProtoToSourceInfo(info *exprpb.SourceInfo) (*SourceInfo, error) {
sourceInfo := &SourceInfo{
syntax: info.GetSyntaxVersion(),
Expand Down
7 changes: 7 additions & 0 deletions common/ast/conversion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@ func TestConvertExpr(t *testing.T) {
if !reflect.DeepEqual(gotExpr, tc.wantExpr) {
t.Errorf("got %v, wanted %v", gotExpr, tc.wantExpr)
}
gotExprRoundtrip, err := ast.ProtoToExpr(gotPBExpr)
if err != nil {
t.Fatalf("ast.ProtoToExpr() failed: %v", err)
}
if !reflect.DeepEqual(parsed.Expr(), gotExprRoundtrip) {
t.Errorf("ast.ProtoToExpr() got %v, wanted %v", gotExprRoundtrip, parsed.Expr())
}
info := parsed.SourceInfo()
for id, wantCall := range tc.macroCalls {
call, found := info.GetMacroCall(id)
Expand Down
18 changes: 18 additions & 0 deletions common/ast/factory.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ast

import "github.com/google/cel-go/common/types/ref"
Expand Down Expand Up @@ -183,6 +197,10 @@ func (fac *baseExprFactory) NewUnspecifiedExpr(id int64) Expr {
}

func (fac *baseExprFactory) CopyExpr(e Expr) Expr {
// unwrap navigable expressions to avoid unnecessary allocations during copying.
if nav, ok := e.(*navigableExprImpl); ok {
e = nav.Expr
}
switch e.Kind() {
case CallKind:
c := e.AsCall()
Expand Down
163 changes: 146 additions & 17 deletions common/ast/navigable.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ast

import (
Expand Down Expand Up @@ -87,34 +101,149 @@ func AllMatcher() ExprMatcher {
}
}

// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values of the
// descendants which match.
// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values
// matching the input criteria in post-order (bottom up).
func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr {
return matchListInternal([]NavigableExpr{expr}, matcher, true)
matches := []NavigableExpr{}
navVisitor := &baseVisitor{
visitExpr: func(e Expr) {
nav := e.(NavigableExpr)
if matcher(nav) {
matches = append(matches, nav)
}
},
}
visit(expr, navVisitor, postOrder, 0, 0)
return matches
}

// MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a
// subset of NavigableExpr values which match.
func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr {
visit := make([]NavigableExpr, len(exprs))
copy(visit, exprs)
return matchListInternal(visit, matcher, false)
matches := []NavigableExpr{}
navVisitor := &baseVisitor{
visitExpr: func(e Expr) {
nav := e.(NavigableExpr)
if matcher(nav) {
matches = append(matches, nav)
}
},
}
for _, expr := range exprs {
visit(expr, navVisitor, postOrder, 0, 1)
}
return matches
}

// Visitor defines an object for visiting Expr and EntryExpr nodes within an expression graph.
type Visitor interface {
// VisitExpr visits the input expression.
VisitExpr(Expr)

// VisitEntryExpr visits the input entry expression, i.e. a struct field or map entry.
VisitEntryExpr(EntryExpr)
}

type baseVisitor struct {
visitExpr func(Expr)
visitEntryExpr func(EntryExpr)
}

// VisitExpr visits the Expr if the internal expr visitor has been configured.
func (v *baseVisitor) VisitExpr(e Expr) {
if v.visitExpr != nil {
v.visitExpr(e)
}
}

// VisitEntryExpr visits the entry if the internal expr entry visitor has been configured.
func (v *baseVisitor) VisitEntryExpr(e EntryExpr) {
if v.visitEntryExpr != nil {
v.visitEntryExpr(e)
}
}

// NewExprVisitor creates a visitor which only visits expression nodes.
func NewExprVisitor(v func(Expr)) Visitor {
return &baseVisitor{
visitExpr: v,
visitEntryExpr: nil,
}
}

func matchListInternal(visit []NavigableExpr, matcher ExprMatcher, visitDescendants bool) []NavigableExpr {
var matched []NavigableExpr
for len(visit) != 0 {
e := visit[0]
if matcher(e) {
matched = append(matched, e)
// PostOrderVisit walks the expression graph and calls the visitor in post-order (bottom-up).
func PostOrderVisit(expr Expr, visitor Visitor) {
visit(expr, visitor, postOrder, 0, 0)
}

// PreOrderVisit walks the expression graph and calls the visitor in pre-order (top-down).
func PreOrderVisit(expr Expr, visitor Visitor) {
visit(expr, visitor, preOrder, 0, 0)
}

type visitOrder int

const (
preOrder = iota + 1
postOrder
)

// TODO: consider exposing a way to configure a limit for the max visit depth.
// It's possible that we could want to configure this on the NewExprVisitor()
// and through MatchDescendents() / MaxID().
func visit(expr Expr, visitor Visitor, order visitOrder, depth, maxDepth int) {
if maxDepth > 0 && depth == maxDepth {
return
}
if order == preOrder {
visitor.VisitExpr(expr)
}
switch expr.Kind() {
case CallKind:
c := expr.AsCall()
if c.IsMemberFunction() {
visit(c.Target(), visitor, order, depth+1, maxDepth)
}
if visitDescendants {
visit = append(visit[1:], e.Children()...)
} else {
visit = visit[1:]
for _, arg := range c.Args() {
visit(arg, visitor, order, depth+1, maxDepth)
}
case ComprehensionKind:
c := expr.AsComprehension()
visit(c.IterRange(), visitor, order, depth+1, maxDepth)
visit(c.AccuInit(), visitor, order, depth+1, maxDepth)
visit(c.LoopCondition(), visitor, order, depth+1, maxDepth)
visit(c.LoopStep(), visitor, order, depth+1, maxDepth)
visit(c.Result(), visitor, order, depth+1, maxDepth)
case ListKind:
l := expr.AsList()
for _, elem := range l.Elements() {
visit(elem, visitor, order, depth+1, maxDepth)
}
case MapKind:
m := expr.AsMap()
for _, e := range m.Entries() {
if order == preOrder {
visitor.VisitEntryExpr(e)
}
entry := e.AsMapEntry()
visit(entry.Key(), visitor, order, depth+1, maxDepth)
visit(entry.Value(), visitor, order, depth+1, maxDepth)
if order == postOrder {
visitor.VisitEntryExpr(e)
}
}
case SelectKind:
visit(expr.AsSelect().Operand(), visitor, order, depth+1, maxDepth)
case StructKind:
s := expr.AsStruct()
for _, f := range s.Fields() {
visitor.VisitEntryExpr(f)
visit(f.AsStructField().Value(), visitor, order, depth+1, maxDepth)
}
}
if order == postOrder {
visitor.VisitExpr(expr)
}
return matched
}

func matchIsConstantValue(e NavigableExpr) bool {
Expand Down
Loading

0 comments on commit 1a6373d

Please sign in to comment.