Skip to content

Commit

Permalink
chore: bunify db/postgres_tasks.go (#8764)
Browse files Browse the repository at this point in the history
  • Loading branch information
carolinaecalderon committed Feb 1, 2024
1 parent 07494cf commit 36a2e29
Show file tree
Hide file tree
Showing 27 changed files with 339 additions and 480 deletions.
2 changes: 1 addition & 1 deletion master/internal/api_checkpoint_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func createVersionTwoCheckpoint(
ResourcePool: "somethingelse",
StartTime: ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)),
}
require.NoError(t, api.m.db.AddAllocation(aIn))
require.NoError(t, db.AddAllocation(ctx, aIn))

checkpoint := &model.CheckpointV2{
ID: 0,
Expand Down
2 changes: 1 addition & 1 deletion master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -2735,7 +2735,7 @@ func (a *apiServer) createTrialTx(
nil,
0)

if err := a.m.db.AddTask(&model.Task{
if err := db.AddTask(ctx, &model.Task{
TaskID: taskID,
TaskType: model.TaskTypeTrial,
StartTime: time.Now(),
Expand Down
6 changes: 3 additions & 3 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func TestGetTaskContextDirectoryExperiment(t *testing.T) {
func TestGetTaskContextDirectoryTask(t *testing.T) {
api, _, ctx := setupAPITest(t, nil)
task := &model.Task{TaskType: model.TaskTypeNotebook, TaskID: model.NewTaskID()}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(ctx, task))

expectedContextDirectory := []byte("expectedContextDirectory")
_, err := db.Bun().NewInsert().Model(&model.TaskContextDirectory{
Expand Down Expand Up @@ -567,7 +567,7 @@ func TestGetExperiments(t *testing.T) {
require.NoError(t, api.m.db.AddExperiment(exp0, activeConfig0))
for i := 0; i < 3; i++ {
task := &model.Task{TaskType: model.TaskTypeTrial, TaskID: model.NewTaskID()}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(ctx, task))
require.NoError(t, db.AddTrial(ctx, &model.Trial{
State: model.PausedState,
ExperimentID: exp0.ID,
Expand Down Expand Up @@ -819,7 +819,7 @@ func TestSearchExperiments(t *testing.T) {
// Trial without validations doesn't cause issues.
noValidationsExp := createTestExpWithProjectID(t, api, curUser, projectIDInt)
task := &model.Task{TaskType: model.TaskTypeTrial, TaskID: model.NewTaskID()}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(ctx, task))
require.NoError(t, db.AddTrial(ctx, &model.Trial{
State: model.PausedState,
ExperimentID: noValidationsExp.ID,
Expand Down
18 changes: 9 additions & 9 deletions master/internal/api_tasks_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ func mockNotebookWithWorkspaceID(
TaskID: model.NewTaskID(),
TaskType: model.TaskTypeNotebook,
}
require.NoError(t, api.m.db.AddTask(nb))
require.NoError(t, db.AddTask(ctx, nb))

allocationID := model.AllocationID(string(nb.TaskID) + ".1")
require.NoError(t, api.m.db.AddAllocation(&model.Allocation{
require.NoError(t, db.AddAllocation(ctx, &model.Allocation{
TaskID: nb.TaskID,
AllocationID: allocationID,
}))
Expand Down Expand Up @@ -324,7 +324,7 @@ func TestAddAllocationAcceleratorData(t *testing.T) {
TaskType: model.TaskTypeTrial,
StartTime: time.Now().UTC().Truncate(time.Millisecond),
}
require.NoError(t, api.m.db.AddTask(task), "failed to add task")
require.NoError(t, db.AddTask(ctx, task), "failed to add task")

aID := tID + "-1"
a := &model.Allocation{
Expand All @@ -333,7 +333,7 @@ func TestAddAllocationAcceleratorData(t *testing.T) {
Slots: 1,
ResourcePool: "default",
}
require.NoError(t, api.m.db.AddAllocation(a), "failed to add allocation")
require.NoError(t, db.AddAllocation(ctx, a), "failed to add allocation")
accData := &model.AcceleratorData{
ContainerID: uuid.NewString(),
AllocationID: model.AllocationID(aID),
Expand Down Expand Up @@ -362,7 +362,7 @@ func TestGetAllocationAcceleratorDataWithNoData(t *testing.T) {
TaskType: model.TaskTypeTrial,
StartTime: time.Now().UTC().Truncate(time.Millisecond),
}
require.NoError(t, api.m.db.AddTask(task), "failed to add task")
require.NoError(t, db.AddTask(ctx, task), "failed to add task")

aID := tID + "-1"
a := &model.Allocation{
Expand All @@ -371,7 +371,7 @@ func TestGetAllocationAcceleratorDataWithNoData(t *testing.T) {
Slots: 1,
ResourcePool: "default",
}
require.NoError(t, api.m.db.AddAllocation(a), "failed to add allocation")
require.NoError(t, db.AddAllocation(ctx, a), "failed to add allocation")

resp, err := api.GetTaskAcceleratorData(ctx,
&apiv1.GetTaskAcceleratorDataRequest{TaskId: tID.String()})
Expand All @@ -390,7 +390,7 @@ func TestGetAllocationAcceleratorData(t *testing.T) {
TaskType: model.TaskTypeTrial,
StartTime: time.Now().UTC().Truncate(time.Millisecond),
}
require.NoError(t, api.m.db.AddTask(task), "failed to add task")
require.NoError(t, db.AddTask(ctx, task), "failed to add task")

aID1 := tID + "-1"
a1 := &model.Allocation{
Expand All @@ -399,7 +399,7 @@ func TestGetAllocationAcceleratorData(t *testing.T) {
Slots: 1,
ResourcePool: "default",
}
require.NoError(t, api.m.db.AddAllocation(a1), "failed to add allocation")
require.NoError(t, db.AddAllocation(ctx, a1), "failed to add allocation")
accData := &model.AcceleratorData{
ContainerID: uuid.NewString(),
AllocationID: model.AllocationID(aID1),
Expand All @@ -418,7 +418,7 @@ func TestGetAllocationAcceleratorData(t *testing.T) {
Slots: 1,
ResourcePool: "default",
}
require.NoError(t, api.m.db.AddAllocation(a2), "failed to add allocation")
require.NoError(t, db.AddAllocation(ctx, a2), "failed to add allocation")

resp, err := api.GetTaskAcceleratorData(ctx,
&apiv1.GetTaskAcceleratorDataRequest{TaskId: tID.String()})
Expand Down
18 changes: 9 additions & 9 deletions master/internal/api_trials_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func createTestTrial(
StartTime: time.Now(),
TaskID: trialTaskID(exp.ID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(context.TODO(), task))

trial := &model.Trial{
StartTime: time.Now(),
Expand Down Expand Up @@ -834,15 +834,15 @@ func TestTrialProtoTaskIDs(t *testing.T) {
StartTime: task0.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task1))
require.NoError(t, db.AddTask(ctx, task1))

task2 := &model.Task{
TaskType: model.TaskTypeTrial,
LogVersion: model.TaskLogVersion1,
StartTime: task1.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task2))
require.NoError(t, db.AddTask(ctx, task2))

_, err = db.Bun().NewInsert().Model(&[]model.RunTaskID{
{RunID: trial.ID, TaskID: task1.TaskID},
Expand Down Expand Up @@ -917,7 +917,7 @@ func TestExperimentIDFromTrialTaskID(t *testing.T) {
StartTime: time.Now(),
TaskID: model.TaskID(uuid.New().String()),
}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(context.TODO(), task))
_, err = experimentIDFromTrialTaskID(notTrialTask.TaskID)
require.ErrorIs(t, err, errIsNotTrialTaskID)

Expand All @@ -935,7 +935,7 @@ func TestTrialLogsBackported(t *testing.T) {
StartTime: time.Now(),
TaskID: model.TaskID(fmt.Sprintf("backported.%d", exp.ID)),
}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(ctx, task))

trial := &model.Trial{
StartTime: time.Now(),
Expand Down Expand Up @@ -972,15 +972,15 @@ func TestTrialLogs(t *testing.T) {
StartTime: task0.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task1))
require.NoError(t, db.AddTask(ctx, task1))

task2 := &model.Task{
TaskType: model.TaskTypeTrial,
LogVersion: model.TaskLogVersion1,
StartTime: task1.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task2))
require.NoError(t, db.AddTask(ctx, task2))

_, err := db.Bun().NewInsert().Model(&[]model.RunTaskID{
{RunID: trial.ID, TaskID: task1.TaskID},
Expand Down Expand Up @@ -1068,15 +1068,15 @@ func TestTrialLogFields(t *testing.T) {
StartTime: task0.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task1))
require.NoError(t, db.AddTask(ctx, task1))

task2 := &model.Task{
TaskType: model.TaskTypeTrial,
LogVersion: model.TaskLogVersion1,
StartTime: task1.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task2))
require.NoError(t, db.AddTask(ctx, task2))

_, err := db.Bun().NewInsert().Model(&[]model.RunTaskID{
{RunID: trial.ID, TaskID: task1.TaskID},
Expand Down
8 changes: 4 additions & 4 deletions master/internal/checkpoint_gc.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func runCheckpointGCForCheckpoints(

func runCheckpointGCTask(
rm rm.ResourceManager,
db *db.PgDB,
pgDB *db.PgDB,
taskID model.TaskID,
jobID model.JobID,
jobSubmissionTime time.Time,
Expand Down Expand Up @@ -140,7 +140,7 @@ func runCheckpointGCTask(
})
syslog := logrus.WithField("component", "checkpointgc").WithFields(logCtx.Fields())

if err := db.AddTask(&model.Task{
if err := db.AddTask(context.TODO(), &model.Task{
TaskID: taskID,
TaskType: model.TaskTypeCheckpointGC,
StartTime: time.Now().UTC(),
Expand All @@ -155,7 +155,7 @@ func runCheckpointGCTask(

resultChan := make(chan error, 1)
onExit := func(ae *task.AllocationExited) {
if err := db.CompleteTask(taskID, time.Now().UTC()); err != nil {
if err := db.CompleteTask(context.TODO(), taskID, time.Now().UTC()); err != nil {
syslog.WithError(err).Error("marking GC task complete")
}
if err := tasklist.GroupPriorityChangeRegistry.Delete(gcJobID); err != nil {
Expand All @@ -177,7 +177,7 @@ func runCheckpointGCTask(
SingleAgent: true,
},
ResourcePool: rp,
}, db, rm, gcSpec, onExit)
}, pgDB, rm, gcSpec, onExit)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions master/internal/command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (c *Command) OnExit(ae *task.AllocationExited) {

c.exitStatus = ae

if err := c.db.CompleteTask(c.taskID, time.Now().UTC()); err != nil {
if err := db.CompleteTask(context.TODO(), c.taskID, time.Now().UTC()); err != nil {
c.syslog.WithError(err).Error("marking task complete")
}
if err := user.DeleteSessionByToken(context.TODO(), c.GenericCommandSpec.Base.UserSessionToken); err != nil {
Expand All @@ -251,7 +251,7 @@ func (c *Command) garbageCollect() {
}

if c.exitStatus == nil {
if err := c.db.CompleteTask(c.taskID, time.Now().UTC()); err != nil {
if err := db.CompleteTask(context.Background(), c.taskID, time.Now().UTC()); err != nil {
c.syslog.WithError(err).Error("marking task complete")
}
}
Expand Down
8 changes: 4 additions & 4 deletions master/internal/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -791,9 +791,9 @@ func (m *Master) restoreNonTerminalExperiments() error {
return nil
}

func (m *Master) closeOpenAllocations() error {
func (m *Master) closeOpenAllocations(ctx context.Context) error {
allocationIds := task.DefaultService.GetAllAllocationIDs()
if err := m.db.CloseOpenAllocations(allocationIds); err != nil {
if err := db.CloseOpenAllocations(ctx, allocationIds); err != nil {
return err
}
return nil
Expand Down Expand Up @@ -1081,11 +1081,11 @@ func (m *Master) Run(ctx context.Context, gRPCLogInitDone chan struct{}) error {
return err
}

if err = m.closeOpenAllocations(); err != nil {
if err = m.closeOpenAllocations(ctx); err != nil {
return err
}

if err = m.db.EndAllTaskStats(); err != nil {
if err = db.EndAllTaskStats(ctx); err != nil {
return err
}

Expand Down
13 changes: 0 additions & 13 deletions master/internal/db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,8 @@ type DB interface {
id int,
experimentBest, trialBest, trialLatest int,
) ([]uuid.UUID, error)
AddTask(t *model.Task) error
UpdateTrial(id int, newState model.State) error
UpdateTrialRunnerState(id int, state string) error
UpdateTrialRunnerMetadata(id int, md *trialv1.TrialRunnerMetadata) error
AddAllocation(a *model.Allocation) error
CompleteAllocation(a *model.Allocation) error
CompleteAllocationTelemetry(aID model.AllocationID) ([]byte, error)
TrialRunIDAndRestarts(trialID int) (int, int, error)
UpdateTrialRunID(id, runID int) error
UpdateTrialRestarts(id, restarts int) error
Expand Down Expand Up @@ -89,11 +84,6 @@ type DB interface {
trials []*apiv1.TrialsSnapshotResponse_Trial, endTime time.Time, err error)
TopTrialsByTrainingLength(experimentID int, maxTrials int, metric string,
smallerIsBetter bool) (trials []int32, err error)
StartAllocationSession(allocationID model.AllocationID, owner *model.User) (string, error)
DeleteAllocationSession(allocationID model.AllocationID) error
UpdateAllocationState(allocation model.Allocation) error
UpdateAllocationStartTime(allocation model.Allocation) error
UpdateAllocationProxyAddress(allocation model.Allocation) error
ExperimentSnapshot(experimentID int) ([]byte, int, error)
SaveSnapshot(
experimentID int, version int, experimentSnapshot []byte,
Expand All @@ -114,9 +104,6 @@ type DB interface {
RecordInstanceStats(a *model.InstanceStats) error
EndInstanceStats(a *model.InstanceStats) error
EndAllInstanceStats() error
EndAllTaskStats() error
RecordTaskEndStats(stats *model.TaskStats) error
RecordTaskStats(stats *model.TaskStats) error
UpdateJobPosition(jobID model.JobID, position decimal.Decimal) error
}

Expand Down
38 changes: 0 additions & 38 deletions master/internal/db/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,25 +211,6 @@ func (db *PgDB) Close() error {
return db.sql.Close()
}

// namedGet is a convenience method for a named query for a single value.
//
// It prepares query as a named statement, binds arg to it, and scans the
// single resulting row into dest. The prepared statement is always closed
// before returning.
//
// Errors are reported in priority order: a prepare failure, then a scan
// failure, then a failure to close the statement. A close failure never
// masks a scan failure, and is no longer silently dropped when the scan
// itself succeeded.
func (db *PgDB) namedGet(dest interface{}, query string, arg interface{}) (err error) {
	nstmt, err := db.sql.PrepareNamed(query)
	if err != nil {
		return errors.Wrapf(err, "error preparing query %s", query)
	}

	// Close exactly once, on every exit path. Only surface the close error
	// when nothing else went wrong, so the more specific scan error wins.
	defer func() {
		if cErr := nstmt.Close(); cErr != nil && err == nil {
			err = errors.Wrap(cErr, "error closing named DB statement")
		}
	}()

	if sErr := nstmt.QueryRowx(arg).Scan(dest); sErr != nil {
		return errors.Wrapf(sErr, "error scanning query %s", query)
	}

	return nil
}

// namedExecOne is a convenience method for a NamedExec that should affect only one row.
func (db *PgDB) namedExecOne(query string, arg interface{}) error {
res, err := db.sql.NamedExec(query, arg)
Expand All @@ -249,25 +230,6 @@ func (db *PgDB) namedExecOne(query string, arg interface{}) error {
return nil
}

// namedExecOne executes a named statement on the supplied transaction and
// requires that it touch exactly one row; any other row count, or a failure
// to execute or to read the row count, is returned as an error.
func namedExecOne(tx *sqlx.Tx, query string, arg interface{}) error {
	result, execErr := tx.NamedExec(query, arg)
	if execErr != nil {
		return errors.Wrapf(execErr, "error in query %v \narg %v", query, arg)
	}

	affected, countErr := result.RowsAffected()
	switch {
	case countErr != nil:
		// The driver could not report how many rows were affected.
		return errors.Wrapf(
			countErr,
			"error checking rows affected for query %v\n arg %v",
			query, arg)
	case affected != 1:
		// Zero or multiple rows changed: the caller's single-row expectation failed.
		return errors.Errorf("error: %v rows affected on query %v \narg %v", affected, query, arg)
	}
	return nil
}

func queryBinds(fields []string) []string {
binds := make([]string, 0, len(fields))
for _, field := range fields {
Expand Down
Loading

0 comments on commit 36a2e29

Please sign in to comment.