Skip to content

Commit

Permalink
fix: report errors from deletecheckpoints endpoint + improve feedback (
Browse files Browse the repository at this point in the history
  • Loading branch information
ashtonG committed Apr 17, 2024
1 parent 1037d83 commit 26f5e0b
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 28 deletions.
71 changes: 66 additions & 5 deletions master/internal/api_checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,55 @@ func (a *apiServer) checkpointsRBACEditCheck(
return exps, groupCUUIDsByEIDs, nil
}

func makeRegisteredCheckpointErrorMessage(
ctx context.Context,
baseMessageFormat string,
checkpointMap map[uuid.UUID]checkpoints.ModelInfo,
) (*string, error) {
curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
}
var modelIDs []int
for _, v := range checkpointMap {
modelIDs = append(modelIDs, v.ID)
}
var models []struct {
ID int
}
modelQuery := internaldb.Bun().
NewSelect().
Model(&models).
Table("models").
ColumnExpr("id").
Where("id IN (?)", bun.In(modelIDs))
if modelQuery, err = modelauth.AuthZProvider.Get().
FilterReadableModelsQuery(ctx, *curUser, modelQuery); err != nil {
return nil, err
}
err = modelQuery.Scan(ctx)
if err != nil {
return nil, err
}
accessibleModels := make(map[int]bool, len(models))
for _, v := range models {
accessibleModels[v.ID] = true
}
var checkpointMsgs []string
for k, v := range checkpointMap {
if _, ok := accessibleModels[v.ID]; ok {
msg := fmt.Sprintf("%v, registered to %v (model #%d), version %d", k, v.Name, v.ID, v.Version)
checkpointMsgs = append(checkpointMsgs, msg)
} else {
msg := fmt.Sprintf("%v, registered to an unknown model", k)
checkpointMsgs = append(checkpointMsgs, msg)
}
}
retVal := fmt.Sprintf(baseMessageFormat, strings.Join(checkpointMsgs, ", "))

return &retVal, nil
}

func (a *apiServer) PatchCheckpoints(
ctx context.Context,
req *apiv1.PatchCheckpointsRequest,
Expand All @@ -222,9 +271,15 @@ func (a *apiServer) PatchCheckpoints(
return nil, err
}
if len(registeredCheckpointUUIDs) > 0 {
return nil, status.Errorf(codes.InvalidArgument,
"this subset of checkpoints provided are in the model registry and cannot be deleted: %v.",
registeredCheckpointUUIDs)
errMsg, err := makeRegisteredCheckpointErrorMessage(
ctx,
"this subset of checkpoints provided are in the model registry and cannot be patched: %v.",
registeredCheckpointUUIDs,
)
if err != nil {
return nil, err
}
return nil, status.Errorf(codes.InvalidArgument, *errMsg)
}

err = internaldb.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
Expand Down Expand Up @@ -327,9 +382,15 @@ func (a *apiServer) CheckpointsRemoveFiles(
return nil, err
}
if len(registeredCheckpointUUIDs) > 0 {
return nil, status.Errorf(codes.InvalidArgument,
errMsg, err := makeRegisteredCheckpointErrorMessage(
ctx,
"this subset of checkpoints provided are in the model registry and cannot be deleted: %v.",
registeredCheckpointUUIDs)
registeredCheckpointUUIDs,
)
if err != nil {
return nil, err
}
return nil, status.Errorf(codes.InvalidArgument, *errMsg)
}

taskSpec := *a.m.taskSpec
Expand Down
27 changes: 21 additions & 6 deletions master/internal/checkpoints/postgres_checkpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,40 @@ func GetModelIDsAssociatedWithCheckpoint(ctx context.Context, ckptUUID uuid.UUID
return modelIDs, nil
}

// ModelInfo is a struct containing info used for locating models.
type ModelInfo struct {
ID int
Version int
Name string
}

// GetRegisteredCheckpoints gets the checkpoints in
// the model registrys from the list of checkpoints provided.
func GetRegisteredCheckpoints(ctx context.Context, checkpoints []uuid.UUID) (map[uuid.UUID]bool, error) {
func GetRegisteredCheckpoints(ctx context.Context, checkpoints []uuid.UUID) (map[uuid.UUID]ModelInfo, error) {
var checkpointIDRows []struct {
ID uuid.UUID
ID uuid.UUID
ModelID int
ModelVersion int
ModelName string
}

if err := db.Bun().NewRaw(`
SELECT DISTINCT(mv.checkpoint_uuid) as ID FROM model_versions AS mv
WHERE mv.checkpoint_uuid IN (SELECT UNNEST(?::uuid[]));`,
SELECT DISTINCT(mv.checkpoint_uuid) as ID, mv.model_id as model_id, mv.version as model_version, m.name as model_name
FROM model_versions AS mv LEFT JOIN models as m on mv.model_id=m.id WHERE mv.checkpoint_uuid
IN (SELECT UNNEST(?::uuid[]));`,
pgdialect.Array(checkpoints)).Scan(ctx, &checkpointIDRows); err != nil {
return nil, fmt.Errorf(
"filtering checkpoint uuids by those registered in the model registry: %w", err)
}

checkpointIDs := make(map[uuid.UUID]bool, len(checkpointIDRows))
checkpointIDs := make(map[uuid.UUID]ModelInfo, len(checkpointIDRows))

for _, cRow := range checkpointIDRows {
checkpointIDs[cRow.ID] = true
checkpointIDs[cRow.ID] = ModelInfo{
ID: cRow.ModelID,
Version: cRow.ModelVersion,
Name: cRow.ModelName,
}
}

return checkpointIDs, nil
Expand Down
10 changes: 7 additions & 3 deletions master/internal/checkpoints/postgres_checkpoints_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func TestGetRegisteredCheckpoints(t *testing.T) {
Name: "checkpoint 1",
Comment: "empty",
}
_, err = db.InsertModelVersion(ctx, pmdl.Id, retCkpt1.Uuid, mv1.Name, mv1.Comment,
mv1mdl, err := db.InsertModelVersion(ctx, pmdl.Id, retCkpt1.Uuid, mv1.Name, mv1.Comment,
emptyMetadata, strings.Join(mv1.Labels, ","), mv1.Notes, user.ID,
)
require.NoError(t, err)
Expand All @@ -233,8 +233,12 @@ func TestGetRegisteredCheckpoints(t *testing.T) {
require.NoError(t, err)

checkpoints := []uuid.UUID{checkpoint1.UUID, checkpoint3.UUID}
expectedRegisteredCheckpoints := make(map[uuid.UUID]bool)
expectedRegisteredCheckpoints[checkpoint1.UUID] = true
expectedRegisteredCheckpoints := make(map[uuid.UUID]ModelInfo)
expectedRegisteredCheckpoints[checkpoint1.UUID] = ModelInfo{
ID: int(pmdl.Id),
Version: int(mv1mdl.Version),
Name: pmdl.Name,
}
dCheckpointsInRegistry, err := GetRegisteredCheckpoints(ctx, checkpoints)
require.NoError(t, err)
require.Equal(t, expectedRegisteredCheckpoints, dCheckpointsInRegistry)
Expand Down
12 changes: 8 additions & 4 deletions webui/react/src/components/CheckpointModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ import useConfirm from 'hew/useConfirm';
import React, { useCallback, useMemo } from 'react';

import { paths } from 'routes/utils';
import { detApi } from 'services/apiConfig';
import { readStream } from 'services/utils';
import { deleteCheckpoints } from 'services/api';
import {
CheckpointState,
CheckpointStorageType,
Expand Down Expand Up @@ -86,12 +85,17 @@ const CheckpointModalComponent: React.FC<Props> = ({

const handleDelete = useCallback(async () => {
if (!checkpoint?.uuid) return;
await readStream(detApi.Checkpoint.deleteCheckpoints({ checkpointUuids: [checkpoint.uuid] }));
try {
await deleteCheckpoints({ checkpointUuids: [checkpoint.uuid] });
} catch (e) {
// modal error handling overwrites error message
handleError(e);
}
}, [checkpoint]);

const onClickDelete = useCallback(() => {
const content = `Are you sure you want to request checkpoint deletion for batch
${checkpoint?.totalBatches}. This action may complete or fail without further notification.`;
${checkpoint?.totalBatches}? This action may complete or fail without further notification.`;

confirm({
content,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ import { useCheckpointFlow } from 'hooks/useCheckpointFlow';
import { useFetchModels } from 'hooks/useFetchModels';
import usePolling from 'hooks/usePolling';
import { useSettings } from 'hooks/useSettings';
import { getExperimentCheckpoints } from 'services/api';
import { deleteCheckpoints, getExperimentCheckpoints } from 'services/api';
import { Checkpointv1SortBy, Checkpointv1State } from 'services/api-ts-sdk';
import { detApi } from 'services/apiConfig';
import { encodeCheckpointState } from 'services/decoder';
import { readStream } from 'services/utils';
import {
checkpointAction,
CheckpointAction,
Expand Down Expand Up @@ -128,12 +126,13 @@ const ExperimentCheckpoints: React.FC<Props> = ({ experiment, pageRef }: Props)
[registerModal],
);

const handleDelete = useCallback((checkpoints: string[]) => {
readStream(
detApi.Checkpoint.deleteCheckpoints({
checkpointUuids: checkpoints,
}),
);
const handleDelete = useCallback(async (checkpointUuids: string[]) => {
try {
await deleteCheckpoints({ checkpointUuids });
} catch (e) {
// confirm modal overwrites error message
handleError(e);
}
}, []);

const handleDeleteCheckpoint = useCallback(
Expand Down
6 changes: 6 additions & 0 deletions webui/react/src/services/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,12 @@ export const launchTensorBoard = generateDetApi<
Type.CommandResponse
>(Config.launchTensorBoard);

export const deleteCheckpoints = generateDetApi<
Api.V1DeleteCheckpointsRequest,
Api.V1DeleteCheckpointsResponse,
Api.V1DeleteCheckpointsResponse
>(Config.deleteCheckpoints);

export const openOrCreateTensorBoard = async (
params: Service.LaunchTensorBoardParams,
): Promise<Type.CommandResponse> => {
Expand Down
12 changes: 11 additions & 1 deletion webui/react/src/services/apiConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const generateApiConfig = (apiConfig?: Api.ConfigurationParameters) => {
const config = updatedApiConfigParams(apiConfig);
return {
Auth: new Api.AuthenticationApi(config),
Checkpoint: Api.CheckpointsApiFetchParamCreator(config),
Checkpoint: new Api.CheckpointsApi(config),
Cluster: new Api.ClusterApi(config),
Commands: new Api.CommandsApi(config),
Experiments: new Api.ExperimentsApi(config),
Expand Down Expand Up @@ -1952,3 +1952,13 @@ export const updateJobQueue: DetApi<
postProcess: identity,
request: (params: Api.V1UpdateJobQueueRequest) => detApi.Internal.updateJobQueue(params),
};

export const deleteCheckpoints: DetApi<
Api.V1DeleteCheckpointsRequest,
Api.V1DeleteCheckpointsResponse,
Api.V1DeleteCheckpointsResponse
> = {
name: 'deleteCheckpoints',
postProcess: identity,
request: (params, options) => detApi.Checkpoint.deleteCheckpoints(params, options),
};

0 comments on commit 26f5e0b

Please sign in to comment.