Skip to content

Commit

Permalink
chore: add trigger to abort checkpoint deletion (#8878)
Browse files Browse the repository at this point in the history
  • Loading branch information
amandavialva01 authored Feb 27, 2024
1 parent 2689b0b commit 9817a4d
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 55 deletions.
101 changes: 46 additions & 55 deletions master/internal/checkpoints/postgres_checkpoints_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ func TestUpdateCheckpointSize(t *testing.T) {
user := db.RequireMockUser(t, db.SingleDB())

var resources []map[string]int64
for i := 0; i < 8; i++ {
resources = append(resources, map[string]int64{"TEST": int64(i) + 1})
for i := 1; i <= 8; i++ {
resources = append(resources, map[string]int64{"TEST": int64(i)})
}

// Create two experiments with two trials each with two checkpoints.
Expand Down Expand Up @@ -366,7 +366,7 @@ func TestUpdateCheckpointSize(t *testing.T) {
verifySizes(e)
}

func TestDeleteCheckpoints(t *testing.T) {
func TestUpdateCheckpointStateToDeleted(t *testing.T) {
ctx := context.Background()

user := db.RequireMockUser(t, db.SingleDB())
Expand All @@ -385,73 +385,64 @@ func TestDeleteCheckpoints(t *testing.T) {
err = db.AddCheckpointMetadata(ctx, &checkpoint2, tr.ID)
require.NoError(t, err)

ckpt3 := uuid.New()
checkpoint3 := db.MockModelCheckpoint(ckpt3, allocation)
err = db.AddCheckpointMetadata(ctx, &checkpoint3, tr.ID)
require.NoError(t, err)
require.NoError(t, MarkCheckpointsDeleted(ctx, []uuid.UUID{checkpoint1.UUID}))

// Insert a model.
now := time.Now()
mdl := db.Model{
Name: uuid.NewString(),
Description: "some important model",
CreationTime: now,
LastUpdatedTime: now,
Labels: []string{"some other label"},
UserID: user.ID,
WorkspaceID: 1,
}
mdlNotes := "some notes3"
pmdl, err := db.InsertModel(ctx, mdl.Name, mdl.Description, emptyMetadata,
strings.Join(mdl.Labels, ","), mdlNotes, user.ID, mdl.WorkspaceID,
)
var numDStateCheckpoints int

err = db.Bun().NewSelect().
TableExpr("checkpoints_view AS c").
ColumnExpr("count(c.uuid) AS numC").
Where("c.uuid::text = ? AND c.state = 'DELETED'", checkpoint1.UUID).
Scan(ctx, &numDStateCheckpoints)
require.NoError(t, err)

// Register checkpoint_1 and checkpoint_2 in ModelRegistry
retCkpt1, err := db.GetCheckpoint(ctx, checkpoint1.UUID.String())
require.NoError(t, err)
require.Equal(t, 1, numDStateCheckpoints, "didn't mark checkpoint as deleted")

retCkpt2, err := db.GetCheckpoint(ctx, checkpoint2.UUID.String())
err = db.Bun().NewSelect().
TableExpr("checkpoints_view AS c").
ColumnExpr("count(c.uuid) AS numC").
Where("c.uuid::text = ? AND c.state = 'DELETED'", checkpoint2.UUID).
Scan(ctx, &numDStateCheckpoints)
require.NoError(t, err)

mv := modelv1.ModelVersion{
Model: pmdl,
Checkpoint: retCkpt1,
Name: "checkpoint 1",
Comment: "empty",
}
_, err = db.InsertModelVersion(ctx, pmdl.Id, retCkpt1.Uuid, mv.Name, mv.Comment,
emptyMetadata, strings.Join(mv.Labels, ","), mv.Notes, user.ID,
)
require.Equal(t, 0, numDStateCheckpoints)
}

func TestDeleteCheckpoints(t *testing.T) {
// Verify that checkpoints only get deleted when their state is 'DELETED', indicating that all
// corresponding checkpoint files were thoroughly removed from storage.
ctx := context.Background()

user := db.RequireMockUser(t, db.SingleDB())
exp := db.RequireMockExperiment(t, db.SingleDB(), user)
tr, task := db.RequireMockTrial(t, db.SingleDB(), exp)
allocation := db.RequireMockAllocation(t, db.SingleDB(), task.TaskID)

// Create checkpoints
ckpt1 := uuid.New()
checkpoint1 := db.MockModelCheckpoint(ckpt1, allocation)
err := db.AddCheckpointMetadata(ctx, &checkpoint1, tr.ID)
require.NoError(t, err)

mv = modelv1.ModelVersion{
Model: pmdl,
Checkpoint: retCkpt2,
Name: "checkpoint 2",
Comment: "empty",
}
_, err = db.InsertModelVersion(ctx, pmdl.Id, retCkpt2.Uuid, mv.Name, mv.Comment,
emptyMetadata, strings.Join(mv.Labels, ","), mv.Notes, user.ID,
)
_, err = db.Bun().NewDelete().Model(&model.CheckpointV2{}).Where("uuid = ?", ckpt1).Exec(ctx)
require.NoError(t, err)

validDeleteCheckpoint := checkpoint3.UUID
numValidDCheckpoints := 1
// Verify that checkpoint wasn't deleted since its state is not 'DELETED'.
ct, err := db.Bun().NewSelect().Model(&model.CheckpointV2{}).Where("uuid = ?", ckpt1).Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, ct)

require.NoError(t, MarkCheckpointsDeleted(ctx, []uuid.UUID{validDeleteCheckpoint}))
_, err = db.Bun().NewUpdate().Model(&model.CheckpointV2{}).Set("state = ?", "DELETED").
Where("uuid = ?", ckpt1).Exec(ctx)
require.NoError(t, err)

var numDStateCheckpoints int
err = db.Bun().NewSelect().
TableExpr("checkpoints_view AS c").
ColumnExpr("count(c.uuid) AS numC").
Where("c.uuid::text = ? AND c.state = 'DELETED'", validDeleteCheckpoint).
Scan(ctx, &numDStateCheckpoints)
_, err = db.Bun().NewDelete().Model(&model.CheckpointV2{}).Where("uuid = ?", ckpt1).Exec(ctx)
require.NoError(t, err)

require.Equal(t, numValidDCheckpoints, numDStateCheckpoints,
"didn't correctly delete the valid checkpoints")
// Verify that checkpoint was deleted once its state was marked 'DELETED'.
ct, err = db.Bun().NewSelect().Model(&model.CheckpointV2{}).Where("uuid = ?", ckpt1).Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, ct)
}

func BenchmarkUpdateCheckpointSize(b *testing.B) {
Expand Down
5 changes: 5 additions & 0 deletions master/internal/db/postgres_experiments_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,11 @@ func TestDeleteExperiments(t *testing.T) {
for k := 0; k < numChkpts; k++ { // Create checkpoints
ckpt := uuid.New()
checkpoint := MockModelCheckpoint(ckpt, allocation)

// Set checkpoint state to 'DELETED' (indicating that they cease to exist in
// storage) so that their deletion isn't blocked by the on_checkpoint_deletion
// trigger.
checkpoint.State = model.DeletedState
err := AddCheckpointMetadata(ctx, &checkpoint, tr.ID)
require.NoError(t, err)
checkpointIDs = append(checkpointIDs, checkpoint.ID)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
DROP TRIGGER IF EXISTS on_checkpoint_deletion ON checkpoints_v2;

DROP FUNCTION IF EXISTS abort_checkpoint_delete();
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
CREATE OR REPLACE FUNCTION abort_checkpoint_delete() RETURNS TRIGGER AS $$
BEGIN
IF OLD.state <> 'DELETED' THEN
RETURN NULL;
END IF;
RETURN OLD;
END
$$ LANGUAGE plpgsql;

CREATE TRIGGER on_checkpoint_deletion
BEFORE DELETE ON checkpoints_v2
FOR EACH ROW EXECUTE PROCEDURE abort_checkpoint_delete();

0 comments on commit 9817a4d

Please sign in to comment.