Skip to content

Commit

Permalink
fix: in-line type assertions not comparable
Browse files Browse the repository at this point in the history
* engine,parse: fix in-line action error

This change fixes two bugs:

- Actions using in-line statements with type casts for ambiguous types would not
  generate SQL properly. This would only occur when performing an expression
  (such as a comparison). This is the main purpose of this PR.
- Another very minor bug that caused a stack overflow when using the parse
  debugger tool was fixed. It was caused by having circular pointer references
  while removing error positions.

* fix failing tests
  • Loading branch information
brennanjl committed Jul 15, 2024
1 parent b3460a7 commit faf879e
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 11 deletions.
7 changes: 6 additions & 1 deletion internal/engine/execution/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,12 @@ func (g *GlobalContext) Execute(ctx context.Context, tx sql.DB, dbid, query stri
args := orderAndCleanValueMap(values, params)
args = append([]any{pg.QueryModeExec}, args...)

return tx.Execute(ctx, sqlStmt, args...)
result, err := tx.Execute(ctx, sqlStmt, args...)
if err != nil {
return nil, decorateExecuteErr(err, query)
}

return result, nil
}

type dbQueryFn func(ctx context.Context, stmt string, args ...any) (*sql.ResultSet, error)
Expand Down
18 changes: 17 additions & 1 deletion internal/engine/execution/procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"reflect"
"strings"

"github.com/jackc/pgx/v5/pgconn"
"github.com/kwilteam/kwil-db/common"
sql "github.com/kwilteam/kwil-db/common/sql"
"github.com/kwilteam/kwil-db/core/types"
Expand Down Expand Up @@ -40,6 +41,7 @@ var (
ErrPrivateProcedure = errors.New("procedure is private")
ErrMutativeProcedure = errors.New("procedure is mutative")
ErrMaxStackDepth = errors.New("max call stack depth reached")
ErrCannotInferType = errors.New("cannot infer type")
)

// instruction is an instruction that can be executed.
Expand Down Expand Up @@ -313,6 +315,20 @@ type dmlStmt struct {
OrderedParameters []string
}

// decorateExecuteErr parses an execute error from postgres and tries to give a more helpful error message.
// this allows us to give a more helpful error message when users hit this,
// since the Postgres error message is not helpful, and this is a common error.
func decorateExecuteErr(err error, stmt string) error {
// this catches a common error case for in-line expressions, where the type cannot be inferred
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "42P08" || pgErr.Code == "42P18" {
return fmt.Errorf(`%w: could not dynamically determine the data type in statement "%s". try type casting using ::, e.g. $id::text`,
ErrCannotInferType, stmt)
}

return err
}

var _ instructionFunc = (&dmlStmt{}).execute

func (e *dmlStmt) execute(scope *precompiles.ProcedureContext, _ *GlobalContext, db sql.DB) error {
Expand All @@ -321,7 +337,7 @@ func (e *dmlStmt) execute(scope *precompiles.ProcedureContext, _ *GlobalContext,
// args := append([]any{pg.QueryModeExec}, params...)
results, err := db.Execute(scope.Ctx, e.SQLStatement, append([]any{pg.QueryModeExec}, params...)...)
if err != nil {
return err
return decorateExecuteErr(err, e.SQLStatement)
}

// we need to check for any pg numeric types returned, and convert them to int64
Expand Down
10 changes: 8 additions & 2 deletions internal/engine/generate/plpgsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package generate

import (
"fmt"
"strconv"
"strings"

"github.com/kwilteam/kwil-db/core/types"
Expand Down Expand Up @@ -126,15 +127,20 @@ func (s *sqlGenerator) VisitExpressionVariable(p0 *parse.ExpressionVariable) any
// if it already exists, we write it as that index.
for i, v := range s.orderedParams {
if v == str {
return "$" + fmt.Sprint(i+1)
return "$" + strconv.Itoa(i+1)
}
}

// otherwise, we add it to the list.
// Postgres uses $1, $2, etc. for numbered parameters.

s.orderedParams = append(s.orderedParams, str)
return "$" + fmt.Sprint(len(s.orderedParams))

res := strings.Builder{}
res.WriteString("$")
res.WriteString(strconv.Itoa(len(s.orderedParams)))
typeCast(p0, &res)
return res.String()
}

str := strings.Builder{}
Expand Down
17 changes: 17 additions & 0 deletions internal/engine/integration/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"testing"

"github.com/kwilteam/kwil-db/internal/engine/execution"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -202,6 +203,22 @@ func Test_SQL(t *testing.T) {
{"4a67d6ea-7ac8-453c-964e-5a144f9e3004"},
},
},
{
name: "inferred type - failure",
sql: "select $id is null",
values: map[string]any{
"$id": "4a67d6ea-7ac8-453c-964e-5a144f9e3004",
},
err: execution.ErrCannotInferType,
},
{
name: "inferred type - success",
sql: "select $id::text is null",
values: map[string]any{
"$id": "4a67d6ea-7ac8-453c-964e-5a144f9e3004",
},
want: [][]any{{false}},
},
}

for _, tt := range tests {
Expand Down
29 changes: 22 additions & 7 deletions parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,16 +454,18 @@ func setupParser(inputStream string, errLisName string) (errLis *errorListener,
// It is used in both parsing tools, as well as in tests.
// WARNING: This function should NEVER be used in consensus, since it is non-deterministic.
func RecursivelyVisitPositions(v any, fn func(GetPositioner)) {

visited := make(map[uintptr]struct{})
visitRecursive(reflect.ValueOf(v), reflect.TypeOf((*GetPositioner)(nil)).Elem(), func(v reflect.Value) {
if v.CanInterface() {
a := v.Interface().(GetPositioner)
fn(a)
}
})
}, visited)
}

// visitRecursive is a recursive function that visits all types that implement the target interface.
func visitRecursive(v reflect.Value, target reflect.Type, fn func(reflect.Value)) {
func visitRecursive(v reflect.Value, target reflect.Type, fn func(reflect.Value), visited map[uintptr]struct{}) {
if v.Type().Implements(target) {
// check if the value is nil
if !v.IsNil() {
Expand All @@ -472,23 +474,36 @@ func visitRecursive(v reflect.Value, target reflect.Type, fn func(reflect.Value)
}

switch v.Kind() {
case reflect.Ptr, reflect.Interface:
case reflect.Interface:
if v.IsNil() {
return
}

visitRecursive(v.Elem(), target, fn, visited)
case reflect.Ptr:
if v.IsNil() {
return
}

visitRecursive(v.Elem(), target, fn)
// check if we have visited this pointer before
ptr := v.Pointer()
if _, ok := visited[ptr]; ok {
return
}
visited[ptr] = struct{}{}

visitRecursive(v.Elem(), target, fn, visited)
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
visitRecursive(v.Field(i), target, fn)
visitRecursive(v.Field(i), target, fn, visited)
}
case reflect.Slice, reflect.Array:
for i := 0; i < v.Len(); i++ {
visitRecursive(v.Index(i), target, fn)
visitRecursive(v.Index(i), target, fn, visited)
}
case reflect.Map:
for _, key := range v.MapKeys() {
visitRecursive(v.MapIndex(key), target, fn)
visitRecursive(v.MapIndex(key), target, fn, visited)
}
}
}

0 comments on commit faf879e

Please sign in to comment.