Skip to content

Commit

Permalink
fix: cease many model fetch api calls in checkpoint tab (#8749)
Browse files Browse the repository at this point in the history
* fix: cease many model fetch api call in checkpoint tab

* fix: move model fetch api into the top level component

* refactor: custom hook for models fetch

* chore: feedback
  • Loading branch information
keita-determined authored Jan 25, 2024
1 parent 96b9064 commit f771acb
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 89 deletions.
8 changes: 2 additions & 6 deletions webui/react/src/components/CheckpointModalTrigger.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { render, screen, waitFor } from '@testing-library/react';
import userEvent from '@testing-library/user-event';
import { DefaultTheme, UIProvider } from 'hew/Theme';
import { ConfirmationProvider } from 'hew/useConfirm';
import { Loaded } from 'hew/utils/loadable';
import React, { useEffect } from 'react';
import { BrowserRouter } from 'react-router-dom';

Expand All @@ -13,12 +14,6 @@ import { generateTestExperimentData } from 'utils/tests/generateTestData';
const TEST_MODAL_TITLE = 'Checkpoint Modal Test';
const REGISTER_CHECKPOINT_TEXT = 'Register Checkpoint';

vi.mock('services/api', () => ({
getModels: () => {
return Promise.resolve({ models: [] });
},
}));

const user = userEvent.setup();

const ModalTrigger: React.FC = () => {
Expand All @@ -32,6 +27,7 @@ const ModalTrigger: React.FC = () => {
<CheckpointModalTrigger
checkpoint={checkpoint}
experiment={experiment}
models={Loaded([])}
title={TEST_MODAL_TITLE}
/>
);
Expand Down
44 changes: 4 additions & 40 deletions webui/react/src/components/CheckpointModalTrigger.tsx
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
import Button from 'hew/Button';
import Icon from 'hew/Icon';
import { ModalCloseReason, useModal } from 'hew/Modal';
import { Loadable, Loaded, NotLoaded } from 'hew/utils/loadable';
import { isEqual } from 'lodash';
import React, { useCallback, useEffect, useState } from 'react';
import { Loadable } from 'hew/utils/loadable';
import React, { useCallback, useState } from 'react';

import ModelCreateModal from 'components/ModelCreateModal';
import RegisterCheckpointModal from 'components/RegisterCheckpointModal';
import { getModels } from 'services/api';
import { V1GetModelsRequestSortBy } from 'services/api-ts-sdk';
import {
CheckpointWorkloadExtended,
CoreApiGenericCheckpoint,
ExperimentBase,
ModelItem,
} from 'types';
import handleError, { ErrorType } from 'utils/error';
import { validateDetApiEnum } from 'utils/service';

import CheckpointModalComponent from './CheckpointModal';

Expand All @@ -25,22 +20,22 @@ interface Props {
children?: React.ReactNode;
experiment: ExperimentBase;
title: string;
models: Loadable<ModelItem[]>;
}

const CheckpointModalTrigger: React.FC<Props> = ({
checkpoint,
experiment,
title,
children,
models,
}: Props) => {
const modelCreateModal = useModal(ModelCreateModal);
const checkpointModal = useModal(CheckpointModalComponent);

const registerModal = useModal(RegisterCheckpointModal);

const [models, setModels] = useState<Loadable<ModelItem[]>>(NotLoaded);
const [selectedModelName, setSelectedModelName] = useState<string>();
const [canceler] = useState(new AbortController());

const handleOnCloseCreateModel = useCallback(
(modelName?: string) => {
Expand All @@ -52,37 +47,6 @@ const CheckpointModalTrigger: React.FC<Props> = ({
[setSelectedModelName, registerModal],
);

const fetchModels = useCallback(async () => {
try {
const response = await getModels(
{
archived: false,
orderBy: 'ORDER_BY_DESC',
sortBy: validateDetApiEnum(
V1GetModelsRequestSortBy,
V1GetModelsRequestSortBy.LASTUPDATEDTIME,
),
},
{ signal: canceler.signal },
);
setModels((prev) => {
const loadedModels = Loaded(response.models);
if (isEqual(prev, loadedModels)) return prev;
return loadedModels;
});
} catch (e) {
handleError(e, {
publicSubject: 'Unable to fetch models.',
silent: true,
type: ErrorType.Api,
});
}
}, [canceler.signal]);

useEffect(() => {
fetchModels();
}, [fetchModels]);

const handleModalCheckpointClick = useCallback(() => {
checkpointModal.open();
}, [checkpointModal]);
Expand Down
35 changes: 35 additions & 0 deletions webui/react/src/hooks/useFetchModels.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { ErrorType } from 'hew/utils/error';
import { Loadable, NotLoaded } from 'hew/utils/loadable';

import { useAsync } from 'hooks/useAsync';
import { getModels } from 'services/api';
import { V1GetModelsRequestSortBy } from 'services/api-ts-sdk';
import { ModelItem } from 'types';
import handleError from 'utils/error';
import { validateDetApiEnum } from 'utils/service';

export const useFetchModels = (): Loadable<ModelItem[]> => {
return useAsync(async (canceler) => {
try {
const response = await getModels(
{
archived: false,
orderBy: 'ORDER_BY_DESC',
sortBy: validateDetApiEnum(
V1GetModelsRequestSortBy,
V1GetModelsRequestSortBy.LASTUPDATEDTIME,
),
},
{ signal: canceler.signal },
);
return response.models;
} catch (e) {
handleError(e, {
publicSubject: 'Unable to fetch models.',
silent: true,
type: ErrorType.Api,
});
return NotLoaded;
}
}, []);
};
54 changes: 13 additions & 41 deletions webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { FilterDropdownProps } from 'antd/es/table/interface';
import { useModal } from 'hew/Modal';
import useConfirm from 'hew/useConfirm';
import { Loadable, Loaded, NotLoaded } from 'hew/utils/loadable';
import { isEqual } from 'lodash';
import React, { Key, useCallback, useEffect, useMemo, useState } from 'react';

Expand All @@ -20,14 +19,11 @@ import {
} from 'components/Table/Table';
import TableBatch from 'components/Table/TableBatch';
import TableFilterDropdown from 'components/Table/TableFilterDropdown';
import { useFetchModels } from 'hooks/useFetchModels';
import usePolling from 'hooks/usePolling';
import { useSettings } from 'hooks/useSettings';
import { getExperimentCheckpoints, getModels } from 'services/api';
import {
Checkpointv1SortBy,
Checkpointv1State,
V1GetModelsRequestSortBy,
} from 'services/api-ts-sdk';
import { getExperimentCheckpoints } from 'services/api';
import { Checkpointv1SortBy, Checkpointv1State } from 'services/api-ts-sdk';
import { detApi } from 'services/apiConfig';
import { encodeCheckpointState } from 'services/decoder';
import { readStream } from 'services/utils';
Expand All @@ -37,7 +33,6 @@ import {
CheckpointState,
CoreApiGenericCheckpoint,
ExperimentBase,
ModelItem,
RecordKey,
} from 'types';
import { canActionCheckpoint, getActionsForCheckpointsUnion } from 'utils/checkpoint';
Expand All @@ -61,10 +56,10 @@ const ExperimentCheckpoints: React.FC<Props> = ({ experiment, pageRef }: Props)
const [total, setTotal] = useState(0);
const [isLoading, setIsLoading] = useState(true);
const [checkpoints, setCheckpoints] = useState<CoreApiGenericCheckpoint[]>();
const [models, setModels] = useState<Loadable<ModelItem[]>>(NotLoaded);
const [selectedCheckpoints, setSelectedCheckpoints] = useState<string[]>();
const [selectedModelName, setSelectedModelName] = useState<string>();
const [canceler] = useState(new AbortController());
const models = useFetchModels();

const config = useMemo(() => configForExperiment(experiment.id), [experiment.id]);
const { settings, updateSettings } = useSettings<Settings>(config);
Expand Down Expand Up @@ -115,37 +110,6 @@ const ExperimentCheckpoints: React.FC<Props> = ({ experiment, pageRef }: Props)
[handleStateFilterApply, handleStateFilterReset, settings.state],
);

const fetchModels = useCallback(async () => {
try {
const response = await getModels(
{
archived: false,
orderBy: 'ORDER_BY_DESC',
sortBy: validateDetApiEnum(
V1GetModelsRequestSortBy,
V1GetModelsRequestSortBy.LASTUPDATEDTIME,
),
},
{ signal: canceler.signal },
);
setModels((prev) => {
const loadedModels = Loaded(response.models);
if (isEqual(prev, loadedModels)) return prev;
return loadedModels;
});
} catch (e) {
handleError(e, {
publicSubject: 'Unable to fetch models.',
silent: true,
type: ErrorType.Api,
});
}
}, [canceler.signal]);

useEffect(() => {
fetchModels();
}, [fetchModels]);

const handleRegisterCheckpoint = useCallback(
(checkpoints: string[]) => {
setSelectedCheckpoints(checkpoints);
Expand Down Expand Up @@ -239,6 +203,7 @@ const ExperimentCheckpoints: React.FC<Props> = ({ experiment, pageRef }: Props)
<CheckpointModalTrigger
checkpoint={record}
experiment={experiment}
models={models}
title={`Checkpoint ${record.uuid}`}
/>
);
Expand Down Expand Up @@ -270,7 +235,14 @@ const ExperimentCheckpoints: React.FC<Props> = ({ experiment, pageRef }: Props)
});

return newColumns;
}, [dropDownOnTrigger, experiment, settings.sortDesc, settings.sortKey, stateFilterDropdown]);
}, [
dropDownOnTrigger,
experiment,
models,
settings.sortDesc,
settings.sortKey,
stateFilterDropdown,
]);

const stateString = settings.state?.join('.');
const fetchExperimentCheckpoints = useCallback(async () => {
Expand Down
12 changes: 11 additions & 1 deletion webui/react/src/pages/ExperimentDetails/ExperimentTrials.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { defaultRowClassName, getFullPaginationConfig, Renderer } from 'componen
import TableBatch from 'components/Table/TableBatch';
import TableFilterDropdown from 'components/Table/TableFilterDropdown';
import { terminalRunStates } from 'constants/states';
import { useFetchModels } from 'hooks/useFetchModels';
import usePermissions from 'hooks/usePermissions';
import usePolling from 'hooks/usePolling';
import { useSettings } from 'hooks/useSettings';
Expand Down Expand Up @@ -67,6 +68,7 @@ const ExperimentTrials: React.FC<Props> = ({ experiment, pageRef }: Props) => {
const trialsComparisonModal = useModal(TrialsComparisonModalComponent);
const config = useMemo(() => configForExperiment(experiment.id), [experiment.id]);
const { settings, updateSettings } = useSettings<Settings>(config);
const models = useFetchModels();

const workspace = { id: experiment.workspaceId };
const { canCreateExperiment, canViewExperimentArtifacts } = usePermissions();
Expand Down Expand Up @@ -187,6 +189,7 @@ const ExperimentTrials: React.FC<Props> = ({ experiment, pageRef }: Props) => {
<CheckpointModalTrigger
checkpoint={checkpoint}
experiment={experiment}
models={models}
title={`Best Checkpoint for Trial ${checkpoint.trialId}`}
/>
);
Expand Down Expand Up @@ -240,7 +243,14 @@ const ExperimentTrials: React.FC<Props> = ({ experiment, pageRef }: Props) => {
});

return newColumns;
}, [experiment, settings, stateFilterDropdown, dropDownOnTrigger]);
}, [
experiment,
models,
dropDownOnTrigger,
settings.sortKey,
settings.sortDesc,
stateFilterDropdown,
]);

const handleTableChange = useCallback(
(
Expand Down
6 changes: 5 additions & 1 deletion webui/react/src/pages/TrialDetails/TrialDetailsWorkloads.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import ResponsiveFilters from 'components/ResponsiveFilters';
import Section from 'components/Section';
import ResponsiveTable from 'components/Table/ResponsiveTable';
import { defaultRowClassName, getFullPaginationConfig } from 'components/Table/Table';
import { useFetchModels } from 'hooks/useFetchModels';
import usePolling from 'hooks/usePolling';
import { getTrialWorkloads } from 'services/api';
import {
Expand Down Expand Up @@ -52,6 +53,8 @@ const TrialDetailsWorkloads: React.FC<Props> = ({
trial,
updateSettings,
}: Props) => {
const models = useFetchModels();

const hasFiltersApplied = useMemo(() => {
const metricsApplied = !_.isEqual(metrics, defaultMetrics);
const checkpointValidationFilterApplied = settings.filter !== TrialWorkloadFilter.All;
Expand All @@ -70,6 +73,7 @@ const TrialDetailsWorkloads: React.FC<Props> = ({
<CheckpointModalTrigger
checkpoint={checkpoint}
experiment={experiment}
models={models}
title={`Checkpoint for Batch ${checkpoint.totalBatches}`}
/>
);
Expand Down Expand Up @@ -121,7 +125,7 @@ const TrialDetailsWorkloads: React.FC<Props> = ({
}
return column;
});
}, [metrics, settings, trial, experiment]);
}, [experiment, metrics, trial, models, settings.sortDesc, settings.sortKey]);

const [workloads, setWorkloads] = useState<Loadable<WorkloadGroup[]>>(NotLoaded);
const [workloadCount, setWorkloadCount] = useState<number>(0);
Expand Down

0 comments on commit f771acb

Please sign in to comment.