Skip to content

Commit

Permalink
fix UPDATE behavior with primary keys
Browse files Browse the repository at this point in the history
  • Loading branch information
asdine committed Apr 7, 2022
1 parent 2d67d57 commit f81f3f1
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 3 deletions.
34 changes: 33 additions & 1 deletion internal/query/statement/update.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package statement

import (
"github.com/cockroachdb/errors"
"github.com/genjidb/genji/document"
"github.com/genjidb/genji/internal/expr"
"github.com/genjidb/genji/internal/stream"
Expand Down Expand Up @@ -42,18 +43,44 @@ type UpdateSetPair struct {

// Prepare implements the Preparer interface.
func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) {
ti, err := c.Catalog.GetTableInfo(stmt.TableName)
if err != nil {
return nil, err
}
pk := ti.GetPrimaryKey()

s := stream.New(stream.TableScan(stmt.TableName))

if stmt.WhereExpr != nil {
s = s.Pipe(stream.DocsFilter(stmt.WhereExpr))
}

var pkModified bool
if stmt.SetPairs != nil {
for _, pair := range stmt.SetPairs {
// if we modify the primary key,
// we must remove the old document and create an new one
if pk != nil && !pkModified {
for _, p := range pk.Paths {
if p.IsEqual(pair.Path) {
pkModified = true
break
}
}
}
s = s.Pipe(stream.PathsSet(pair.Path, pair.E))
}
} else if stmt.UnsetFields != nil {
for _, name := range stmt.UnsetFields {
// ensure we do not unset any path the is used in the primary key
if pk != nil {
path := document.NewPath(name)
for _, p := range pk.Paths {
if p.IsEqual(path) {
return nil, errors.New("cannot unset primary key path")
}
}
}
s = s.Pipe(stream.PathsUnset(name))
}
}
Expand All @@ -69,7 +96,12 @@ func (stmt *UpdateStmt) Prepare(c *Context) (Statement, error) {
s = s.Pipe(stream.IndexDelete(indexName))
}

s = s.Pipe(stream.TableReplace(stmt.TableName))
if pkModified {
s = s.Pipe(stream.TableDelete(stmt.TableName))
s = s.Pipe(stream.TableInsert(stmt.TableName))
} else {
s = s.Pipe(stream.TableReplace(stmt.TableName))
}

for _, indexName := range indexNames {
s = s.Pipe(stream.IndexInsert(indexName))
Expand Down
39 changes: 39 additions & 0 deletions sqltests/UPDATE/pk.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
-- test: set primary key
CREATE TABLE test (a int primary key, b int);
INSERT INTO test (a, b) VALUES (1, 10);
UPDATE test SET a = 2, b = 20 WHERE a = 1;
INSERT INTO test (a, b) VALUES (1, 10);
SELECT pk(), * FROM test;
/* result:
{"pk()": [1], a: 1, b: 10}
{"pk()": [2], a: 2, b: 20}
*/

-- test: set primary key / conflict
CREATE TABLE test (a int primary key, b int);
INSERT INTO test (a, b) VALUES (1, 10), (2, 20);
UPDATE test SET a = 2, b = 20 WHERE a = 1;
-- error: PRIMARY KEY constraint error: [a]

-- test: set composite primary key
CREATE TABLE test (a int, b int, c int, PRIMARY KEY(a, b));
INSERT INTO test (a, b, c) VALUES (1, 10, 100);
UPDATE test SET a = 2, b = 20 WHERE a = 1;
INSERT INTO test (a, b, c) VALUES (1, 10, 100);
SELECT pk(), * FROM test;
/* result:
{"pk()": [1, 10], a: 1, b: 10, c: 100}
{"pk()": [2, 20], a: 2, b: 20, c: 100}
*/

-- test: unset primary key
CREATE TABLE test (a int primary key, b int);
INSERT INTO test (a, b) VALUES (1, 10);
UPDATE test UNSET a WHERE a = 1;
-- error: cannot unset primary key path

-- test: unset composite primary key
CREATE TABLE test (a int, b int, c int, PRIMARY KEY(a, b));
INSERT INTO test (a, b, c) VALUES (1, 10, 100);
UPDATE test UNSET b WHERE a = 1;
-- error: cannot unset primary key path
3 changes: 1 addition & 2 deletions sqltests/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"io/fs"
"os"
"path/filepath"
"regexp"
"strings"
"testing"

Expand Down Expand Up @@ -89,7 +88,7 @@ func TestSQL(t *testing.T) {

if test.ErrorMatch != "" {
require.NotNilf(t, err, "%s:%d expected error, got nil", absPath, test.Line)
require.Regexpf(t, regexp.MustCompile(test.ErrorMatch), err.Error(), "Source %s:%d", absPath, test.Line)
require.Equal(t, test.ErrorMatch, err.Error(), "Source %s:%d", absPath, test.Line)
} else {
assert.Errorf(t, err, "\nSource:%s:%d expected\n%s\nto raise an error but got none", absPath, test.Line, test.Expr)
}
Expand Down

0 comments on commit f81f3f1

Please sign in to comment.