Skip to content

Commit

Permalink
feat: Create MoveRun endpoint (#9001)
Browse files Browse the repository at this point in the history
  • Loading branch information
AmanuelAaron authored Apr 4, 2024
1 parent 91d7e08 commit d6059e9
Show file tree
Hide file tree
Showing 11 changed files with 1,761 additions and 669 deletions.
118 changes: 118 additions & 0 deletions harness/determined/common/api/bindings.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

192 changes: 178 additions & 14 deletions master/internal/api_runs.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/pkg/errors"

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/db/bunutils"
"github.com/determined-ai/determined/master/internal/experiment"
Expand All @@ -18,13 +20,21 @@ import (
"github.com/determined-ai/determined/master/internal/trials"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
"github.com/determined-ai/determined/master/pkg/set"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/projectv1"
"github.com/determined-ai/determined/proto/pkg/rbacv1"
"github.com/determined-ai/determined/proto/pkg/runv1"
"github.com/determined-ai/determined/proto/pkg/trialv1"
)

type archiveRunOKResult struct {
Archived bool
ID int32
ExpID *int32
IsMultitrial bool
}

func (a *apiServer) RunPrepareForReporting(
ctx context.Context, req *apiv1.RunPrepareForReportingRequest,
) (*apiv1.RunPrepareForReportingResponse, error) {
Expand Down Expand Up @@ -89,20 +99,7 @@ func (a *apiServer) SearchRuns(
}

if req.Filter != nil {
var efr experimentFilterRoot
err := json.Unmarshal([]byte(*req.Filter), &efr)
if err != nil {
return nil, err
}
query = query.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
_, err = efr.toSQL(q)
return q
}).WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
if !efr.ShowArchived {
return q.Where(`e.archived = false`)
}
return q
})
query, err = filterRunQuery(query, req.Filter)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -239,3 +236,170 @@ func sortRuns(sortString *string, runQuery *bun.SelectQuery) error {
}
return nil
}

func filterRunQuery(getQ *bun.SelectQuery, filter *string) (*bun.SelectQuery, error) {
var efr experimentFilterRoot
err := json.Unmarshal([]byte(*filter), &efr)
if err != nil {
return nil, err
}
getQ = getQ.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
_, err = efr.toSQL(q)
return q
}).WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
if !efr.ShowArchived {
return q.Where(`e.archived = false`)
}
return q
})
if err != nil {
return nil, err
}
return getQ, nil
}

func (a *apiServer) MoveRuns(
ctx context.Context, req *apiv1.MoveRunsRequest,
) (*apiv1.MoveRunsResponse, error) {
curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
}
// check that user can view source project
srcProject, err := a.GetProjectByID(ctx, req.SourceProjectId, *curUser)
if err != nil {
return nil, err
}
if srcProject.Archived {
return nil, errors.Errorf("project (%v) is archived and cannot have runs moved from it",
srcProject.Id)
}

// check suitable destination project
destProject, err := a.GetProjectByID(ctx, req.DestinationProjectId, *curUser)
if err != nil {
return nil, err
}
if destProject.Archived {
return nil, errors.Errorf("project (%v) is archived and cannot add new runs",
req.DestinationProjectId)
}
if err = experiment.AuthZProvider.Get().CanCreateExperiment(ctx, *curUser, destProject); err != nil {
return nil, status.Error(codes.PermissionDenied, err.Error())
}

var runChecks []archiveRunOKResult
getQ := db.Bun().NewSelect().
ModelTableExpr("runs AS r").
Model(&runChecks).
Column("r.id").
ColumnExpr("COALESCE((e.archived OR p.archived OR w.archived), FALSE) AS archived").
ColumnExpr("r.experiment_id as exp_id").
ColumnExpr("((SELECT COUNT(*) FROM runs r WHERE e.id = r.experiment_id) > 1) as is_multitrial").
Join("LEFT JOIN experiments e ON r.experiment_id=e.id").
Join("JOIN projects p ON r.project_id = p.id").
Join("JOIN workspaces w ON p.workspace_id = w.id").
Where("r.project_id = ?", req.SourceProjectId)

if req.Filter == nil {
getQ = getQ.Where("r.id IN (?)", bun.In(req.RunIds))
} else {
getQ, err = filterRunQuery(getQ, req.Filter)
if err != nil {
return nil, err
}
}

if getQ, err = experiment.AuthZProvider.Get().FilterExperimentsQuery(ctx, *curUser, nil, getQ,
[]rbacv1.PermissionType{
rbacv1.PermissionType_PERMISSION_TYPE_VIEW_EXPERIMENT_METADATA,
rbacv1.PermissionType_PERMISSION_TYPE_DELETE_EXPERIMENT,
}); err != nil {
return nil, err
}

err = getQ.Scan(ctx)
if err != nil {
return nil, err
}

var results []*apiv1.RunActionResult
visibleIDs := set.New[int32]()
var validIDs []int32
// associated experiments to move
var expMoveIds []int32
for _, check := range runChecks {
visibleIDs.Insert(check.ID)
if check.Archived {
results = append(results, &apiv1.RunActionResult{
Error: "Run is archived.",
Id: check.ID,
})
continue
}
if check.IsMultitrial && req.SkipMultitrial {
results = append(results, &apiv1.RunActionResult{
Error: fmt.Sprintf("Skipping run '%d' (part of multi-trial).", check.ID),
Id: check.ID,
})
continue
}
if check.ExpID != nil {
expMoveIds = append(expMoveIds, *check.ExpID)
}
validIDs = append(validIDs, check.ID)
}
if req.Filter == nil {
for _, originalID := range req.RunIds {
if !visibleIDs.Contains(originalID) {
results = append(results, &apiv1.RunActionResult{
Error: fmt.Sprintf("Run with id '%d' not found in project with id '%d'", originalID, req.SourceProjectId),
Id: originalID,
})
}
}
}
if len(validIDs) > 0 {
expMoveResults, err := experiment.MoveExperiments(ctx, expMoveIds, nil, req.DestinationProjectId)
if err != nil {
return nil, err
}
failedExpMoveIds := []int32{-1}
for _, res := range expMoveResults {
if res.Error != nil {
failedExpMoveIds = append(failedExpMoveIds, res.ID)
}
}
var acceptedIDs []int32
if _, err = db.Bun().NewUpdate().Table("runs").
Set("project_id = ?", req.DestinationProjectId).
Where("runs.id IN (?)", bun.In(validIDs)).
Where("runs.experiment_id NOT IN (?)", bun.In(failedExpMoveIds)).
Returning("runs.id").
Model(&acceptedIDs).
Exec(ctx); err != nil {
return nil, fmt.Errorf("updating run's project IDs: %w", err)
}

for _, acceptID := range acceptedIDs {
results = append(results, &apiv1.RunActionResult{
Error: "",
Id: acceptID,
})
}
var failedRunIDs []int32
if err = db.Bun().NewSelect().Table("runs").
Where("runs.id IN (?)", bun.In(validIDs)).
Where("runs.experiment_id IN (?)", bun.In(failedExpMoveIds)).
Scan(ctx, &failedRunIDs); err != nil {
return nil, fmt.Errorf("getting failed experiment move run IDs: %w", err)
}
for _, failedRunID := range failedRunIDs {
results = append(results, &apiv1.RunActionResult{
Error: "Failed to move associated experiment",
Id: failedRunID,
})
}
}
return &apiv1.MoveRunsResponse{Results: results}, nil
}
Loading

0 comments on commit d6059e9

Please sign in to comment.