diff --git a/internal/query/statement/update.go b/internal/query/statement/update.go index 556a28989..b4cce7787 100644 --- a/internal/query/statement/update.go +++ b/internal/query/statement/update.go @@ -1,6 +1,7 @@ package statement import ( + "github.com/cockroachdb/errors" "github.com/genjidb/genji/document" "github.com/genjidb/genji/internal/expr" "github.com/genjidb/genji/internal/stream" @@ -42,18 +43,44 @@ type UpdateSetPair struct { // Prepare implements the Preparer interface. func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) { + ti, err := c.Catalog.GetTableInfo(stmt.TableName) + if err != nil { + return nil, err + } + pk := ti.GetPrimaryKey() + s := stream.New(stream.TableScan(stmt.TableName)) if stmt.WhereExpr != nil { s = s.Pipe(stream.DocsFilter(stmt.WhereExpr)) } + var pkModified bool if stmt.SetPairs != nil { for _, pair := range stmt.SetPairs { + // if we modify the primary key, + // we must remove the old document and create an new one + if pk != nil && !pkModified { + for _, p := range pk.Paths { + if p.IsEqual(pair.Path) { + pkModified = true + break + } + } + } s = s.Pipe(stream.PathsSet(pair.Path, pair.E)) } } else if stmt.UnsetFields != nil { for _, name := range stmt.UnsetFields { + // ensure we do not unset any path the is used in the primary key + if pk != nil { + path := document.NewPath(name) + for _, p := range pk.Paths { + if p.IsEqual(path) { + return nil, errors.New("cannot unset primary key path") + } + } + } s = s.Pipe(stream.PathsUnset(name)) } } @@ -69,7 +96,12 @@ func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) { s = s.Pipe(stream.IndexDelete(indexName)) } - s = s.Pipe(stream.TableReplace(stmt.TableName)) + if pkModified { + s = s.Pipe(stream.TableDelete(stmt.TableName)) + s = s.Pipe(stream.TableInsert(stmt.TableName)) + } else { + s = s.Pipe(stream.TableReplace(stmt.TableName)) + } for _, indexName := range indexNames { s = s.Pipe(stream.IndexInsert(indexName)) diff --git a/sqltests/UPDATE/pk.sql b/sqltests/UPDATE/pk.sql new file mode 100644 index 000000000..d29ba3649 --- /dev/null +++ b/sqltests/UPDATE/pk.sql @@ -0,0 +1,39 @@ +-- test: set primary key +CREATE TABLE test (a int primary key, b int); +INSERT INTO test (a, b) VALUES (1, 10); +UPDATE test SET a = 2, b = 20 WHERE a = 1; +INSERT INTO test (a, b) VALUES (1, 10); +SELECT pk(), * FROM test; +/* result: +{"pk()": [1], a: 1, b: 10} +{"pk()": [2], a: 2, b: 20} +*/ + +-- test: set primary key / conflict +CREATE TABLE test (a int primary key, b int); +INSERT INTO test (a, b) VALUES (1, 10), (2, 20); +UPDATE test SET a = 2, b = 20 WHERE a = 1; +-- error: PRIMARY KEY constraint error: [a] + +-- test: set composite primary key +CREATE TABLE test (a int, b int, c int, PRIMARY KEY(a, b)); +INSERT INTO test (a, b, c) VALUES (1, 10, 100); +UPDATE test SET a = 2, b = 20 WHERE a = 1; +INSERT INTO test (a, b, c) VALUES (1, 10, 100); +SELECT pk(), * FROM test; +/* result: +{"pk()": [1, 10], a: 1, b: 10, c: 100} +{"pk()": [2, 20], a: 2, b: 20, c: 100} +*/ + +-- test: unset primary key +CREATE TABLE test (a int primary key, b int); +INSERT INTO test (a, b) VALUES (1, 10); +UPDATE test UNSET a WHERE a = 1; +-- error: cannot unset primary key path + +-- test: unset composite primary key +CREATE TABLE test (a int, b int, c int, PRIMARY KEY(a, b)); +INSERT INTO test (a, b, c) VALUES (1, 10, 100); +UPDATE test UNSET b WHERE a = 1; +-- error: cannot unset primary key path diff --git a/sqltests/sql_test.go b/sqltests/sql_test.go index dbc52013d..8c4f6c850 100644 --- a/sqltests/sql_test.go +++ b/sqltests/sql_test.go @@ -6,7 +6,6 @@ import ( "io/fs" "os" "path/filepath" - "regexp" "strings" "testing" @@ -89,7 +88,7 @@ func TestSQL(t *testing.T) { if test.ErrorMatch != "" { require.NotNilf(t, err, "%s:%d expected error, got nil", absPath, test.Line) - require.Regexpf(t, regexp.MustCompile(test.ErrorMatch), err.Error(), "Source %s:%d", absPath, test.Line) + require.Equal(t, test.ErrorMatch, err.Error(), "Source %s:%d", absPath, test.Line) } else { assert.Errorf(t, err, "\nSource:%s:%d expected\n%s\nto raise an error but got none", absPath, test.Line, test.Expr) }