Skip to content

Commit

Permalink
feat: flat runs comparison view [ET-190] (#9477)
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilyBonar authored Jun 13, 2024
1 parent d44013c commit 32585ad
Show file tree
Hide file tree
Showing 17 changed files with 1,122 additions and 356 deletions.
73 changes: 46 additions & 27 deletions webui/react/src/components/CompareHeatMaps.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,21 @@ import MetricBadgeTag from 'components/MetricBadgeTag';
import useUI from 'components/ThemeProvider';
import { UPlotScatterProps } from 'components/UPlot/types';
import UPlotScatter from 'components/UPlot/UPlotScatter';
import { RunMetricData } from 'hooks/useMetrics';
import useResize from 'hooks/useResize';
import { TrialMetricData } from 'pages/TrialDetails/useTrialMetrics';
import {
ExperimentWithTrial,
FlatRun,
Hyperparameter,
HyperparameterType,
MetricType,
Primitive,
Range,
Scale,
TrialHyperparameters,
TrialItem,
XAxisDomain,
XOR,
} from 'types';
import { getColorScale } from 'utils/chart';
import { rgba2str, str2rgba } from 'utils/color';
Expand All @@ -34,14 +37,18 @@ import { CompareHyperparametersSettings } from './CompareHyperparameters.setting
export const COMPARE_HEAT_MAPS = 'compare-heatmaps';
export const HEAT_MAPS_TITLE = 'Heat Maps';

interface Props {
selectedExperiments: ExperimentWithTrial[];
trials: TrialItem[];
metricData: TrialMetricData;
interface BaseProps {
metricData: RunMetricData;
fullHParams: string[];
settings: CompareHyperparametersSettings;
}

type Props = XOR<
{ selectedExperiments: ExperimentWithTrial[]; trials: TrialItem[] },
{ selectedRuns: FlatRun[] }
> &
BaseProps;

type HpValue = Record<string, (number | string)[]>;

interface HpData {
Expand All @@ -51,7 +58,7 @@ interface HpData {
hpMetrics: Record<string, (number | undefined)[]>;
hpValues: HpValue;
metricRange: Range<number>;
trialIds: number[];
recordIds: number[];
}

const generateHpKey = (hParam1: string, hParam2: string): string => {
Expand All @@ -65,6 +72,7 @@ const parseHpKey = (key: string): [hParam1: string, hParam2: string] => {

const CompareHeatMaps: React.FC<Props> = ({
selectedExperiments,
selectedRuns,
trials,
metricData,
fullHParams,
Expand All @@ -91,11 +99,11 @@ const CompareHeatMaps: React.FC<Props> = ({

const smallerIsBetter = useMemo(() => {
if (selectedMetric && selectedMetric.group === MetricType.Validation) {
const selectedExperimentsWithMetric = selectedExperiments.filter((exp) => {
const selectedExperimentsWithMetric = selectedExperiments?.filter((exp) => {
return selectedMetric.name === exp?.experiment?.config?.searcher?.metric;
});

return selectedExperimentsWithMetric.some((exp) => {
return selectedExperimentsWithMetric?.some((exp) => {
return exp?.experiment?.config?.searcher?.smallerIsBetter;
});
}
Expand Down Expand Up @@ -172,7 +180,7 @@ const CompareHeatMaps: React.FC<Props> = ({
null,
chartData?.hpMetrics[key] || [],
chartData?.hpMetrics[key] || [],
chartData?.trialIds || [],
chartData?.recordIds || [],
],
],
options: {
Expand All @@ -184,7 +192,7 @@ const CompareHeatMaps: React.FC<Props> = ({
series: [{}, { fill, stroke }],
title,
},
tooltipLabels: [xLabel, yLabel, null, metricToStr(selectedMetric), null, 'trial ID'],
tooltipLabels: [xLabel, yLabel, null, metricToStr(selectedMetric), null, 'record ID'],
};
});
});
Expand Down Expand Up @@ -234,7 +242,7 @@ const CompareHeatMaps: React.FC<Props> = ({

const experimentHyperparameters = useMemo(() => {
const hpMap: Record<string, Hyperparameter> = {};
selectedExperiments.forEach((exp) => {
selectedExperiments?.forEach((exp) => {
const hps = Object.keys(exp.experiment.hyperparameters);
hps.forEach((hp) => (hpMap[hp] = exp.experiment.hyperparameters[hp]));
});
Expand All @@ -244,7 +252,7 @@ const CompareHeatMaps: React.FC<Props> = ({
useEffect(() => {
if (ui.isPageHidden || !selectedMetric) return;

const trialIds: number[] = [];
const recordIds: number[] = [];
const hpMetricMap: Record<number, Record<string, number | undefined>> = {};
const hpValueMap: Record<number, Record<string, Primitive>> = {};
const hpLabelMap: Record<string, string[]> = {};
Expand All @@ -255,11 +263,14 @@ const CompareHeatMaps: React.FC<Props> = ({
const hpValues: HpValue = {};
const metricRange: Range<number> = [Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY];

trials.forEach((trial) => {
if (!isObject(trial.hyperparameters)) return;
const recordHyperparameters: [number, TrialHyperparameters][] = selectedRuns
? selectedRuns.flatMap((run) => (run.hyperparameters ? [[run.id, run.hyperparameters]] : []))
: trials.map((trial) => [trial.id, trial.hyperparameters]);

recordHyperparameters?.forEach(([recordId, recordHp]) => {
if (!isObject(recordHp)) return;

const trialId = trial.id;
const flatHParams = flattenObject(trial.hyperparameters);
const flatHParams = flattenObject(recordHp);
const trialHParams = Object.keys(flatHParams)
.filter((hParam) => fullHParams.includes(hParam))
.sort((a, b) => a.localeCompare(b, 'en'));
Expand All @@ -269,16 +280,16 @@ const CompareHeatMaps: React.FC<Props> = ({
* dynamic min/max ranges via uPlot.Scales.
*/
const key = metricToKey(selectedMetric);
const trialMetric = data?.[trial.id]?.[key]?.data?.[XAxisDomain.Batches]?.at(-1)?.[1];
const trialMetric = data?.[recordId]?.[key]?.data?.[XAxisDomain.Batches]?.at(-1)?.[1];

trialIds.push(trialId);
hpMetricMap[trialId] = hpMetricMap[trialId] || {};
hpValueMap[trialId] = hpValueMap[trialId] || {};
recordIds.push(recordId);
hpMetricMap[recordId] = hpMetricMap[recordId] || {};
hpValueMap[recordId] = hpValueMap[recordId] || {};
trialHParams.forEach((hParam1) => {
hpValueMap[trialId][hParam1] = flatHParams[hParam1];
hpValueMap[recordId][hParam1] = flatHParams[hParam1];
trialHParams.forEach((hParam2) => {
const key = generateHpKey(hParam1, hParam2);
hpMetricMap[trialId][key] = trialMetric;
hpMetricMap[recordId][key] = trialMetric;
});
});

Expand All @@ -295,8 +306,8 @@ const CompareHeatMaps: React.FC<Props> = ({
hpLabelValueMap[hParam1] = [];
hpValues[hParam1] = [];

trialIds.forEach((trialId) => {
const hpRawValue = hpValueMap[trialId][hParam1];
recordIds.forEach((recordId) => {
const hpRawValue = hpValueMap[recordId][hParam1];
const hpValue = isBoolean(hpRawValue) ? hpRawValue.toString() : hpRawValue;

hpValues[hParam1].push(hpValue);
Expand All @@ -316,7 +327,7 @@ const CompareHeatMaps: React.FC<Props> = ({

fullHParams.forEach((hParam2) => {
const key = generateHpKey(hParam1, hParam2);
hpMetrics[key] = trialIds.map((trialId) => hpMetricMap[trialId][key]);
hpMetrics[key] = recordIds.map((recordId) => hpMetricMap[recordId][key]);
});
});

Expand All @@ -327,9 +338,17 @@ const CompareHeatMaps: React.FC<Props> = ({
hpMetrics,
hpValues,
metricRange,
trialIds,
recordIds,
});
}, [fullHParams, selectedMetric, ui.isPageHidden, trials, data, experimentHyperparameters]);
}, [
fullHParams,
selectedMetric,
ui.isPageHidden,
trials,
data,
experimentHyperparameters,
selectedRuns,
]);

if (!metricsLoaded || !chartData) {
return <Spinner center spinning />;
Expand Down
23 changes: 20 additions & 3 deletions webui/react/src/components/CompareHyperparameters.test.mock.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import React from 'react';

import { TrialMetricData } from 'pages/TrialDetails/useTrialMetrics';
import { RunMetricData } from 'hooks/useMetrics';
import { Scale } from 'types';
import { generateTestRunData } from 'utils/tests/generateTestData';

import CompareHyperparameters from './CompareHyperparameters';
export const METRIC_DATA: TrialMetricData = {
export const METRIC_DATA: RunMetricData = {
data: {
3400: {
'{"group":"training","name":"loss"}': {
Expand Down Expand Up @@ -395,10 +396,14 @@ export const TRIALS = [
},
];

export const SELECTED_RUNS = [generateTestRunData(), generateTestRunData(), generateTestRunData()];

interface Props {
empty?: boolean;
}
export const CompareHyperparametersWithMocks: React.FC<Props> = ({ empty }: Props): JSX.Element => {
export const CompareTrialHyperparametersWithMocks: React.FC<Props> = ({
empty,
}: Props): JSX.Element => {
return (
<CompareHyperparameters
metricData={METRIC_DATA}
Expand All @@ -410,3 +415,15 @@ export const CompareHyperparametersWithMocks: React.FC<Props> = ({ empty }: Prop
/>
);
};

export const CompareRunHyperparametersWithMocks: React.FC<Props> = ({
empty,
}: Props): JSX.Element => {
return (
<CompareHyperparameters
metricData={METRIC_DATA}
projectId={1}
selectedRuns={empty ? [] : SELECTED_RUNS}
/>
);
};
53 changes: 30 additions & 23 deletions webui/react/src/components/CompareHyperparameters.test.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
//

import { render, screen } from '@testing-library/react';
import UIProvider, { DefaultTheme } from 'hew/Theme';
import { BrowserRouter } from 'react-router-dom';
Expand All @@ -8,7 +6,10 @@ import { SettingsProvider } from 'hooks/useSettingsProvider';

import { COMPARE_HEAT_MAPS } from './CompareHeatMaps';
import { NO_DATA_MESSAGE } from './CompareHyperparameters';
import { CompareHyperparametersWithMocks } from './CompareHyperparameters.test.mock';
import {
CompareRunHyperparametersWithMocks,
CompareTrialHyperparametersWithMocks,
} from './CompareHyperparameters.test.mock';
import { COMPARE_PARALLEL_COORDINATES } from './CompareParallelCoordinates';
import { COMPARE_SCATTER_PLOTS } from './CompareScatterPlots';
import { ThemeProvider } from './ThemeProvider';
Expand All @@ -35,13 +36,17 @@ vi.mock('hooks/useSettings', async (importOriginal) => {
};
});

const setup = (empty?: boolean) => {
const setup = (type: 'trials' | 'runs', empty?: boolean) => {
render(
<BrowserRouter>
<UIProvider theme={DefaultTheme.Light}>
<ThemeProvider>
<SettingsProvider>
<CompareHyperparametersWithMocks empty={empty} />
{type === 'trials' ? (
<CompareTrialHyperparametersWithMocks empty={empty} />
) : (
<CompareRunHyperparametersWithMocks empty={empty} />
)}
</SettingsProvider>
</ThemeProvider>
</UIProvider>
Expand All @@ -50,23 +55,25 @@ const setup = (empty?: boolean) => {
};

describe('CompareHyperparameters component', () => {
it('renders Parallel Coordinates', () => {
setup();
expect(screen.getByTestId(COMPARE_PARALLEL_COORDINATES)).toBeInTheDocument();
});
it('renders Scatter Plots', () => {
setup();
expect(screen.getByTestId(COMPARE_SCATTER_PLOTS)).toBeInTheDocument();
});
it('renders Heat Maps', () => {
setup();
expect(screen.getByTestId(COMPARE_HEAT_MAPS)).toBeInTheDocument();
});
it('renders no data state', () => {
setup(true);
expect(screen.queryByTestId(COMPARE_PARALLEL_COORDINATES)).not.toBeInTheDocument();
expect(screen.queryByTestId(COMPARE_SCATTER_PLOTS)).not.toBeInTheDocument();
expect(screen.queryByTestId(COMPARE_HEAT_MAPS)).not.toBeInTheDocument();
expect(screen.queryByText(NO_DATA_MESSAGE)).toBeInTheDocument();
describe.each(['trials', 'runs'] as const)('%s', (type) => {
it('renders Parallel Coordinates', () => {
setup(type);
expect(screen.getByTestId(COMPARE_PARALLEL_COORDINATES)).toBeInTheDocument();
});
it('renders Scatter Plots', () => {
setup(type);
expect(screen.getByTestId(COMPARE_SCATTER_PLOTS)).toBeInTheDocument();
});
it('renders Heat Maps', () => {
setup(type);
expect(screen.getByTestId(COMPARE_HEAT_MAPS)).toBeInTheDocument();
});
it('renders no data state', () => {
setup(type, true);
expect(screen.queryByTestId(COMPARE_PARALLEL_COORDINATES)).not.toBeInTheDocument();
expect(screen.queryByTestId(COMPARE_SCATTER_PLOTS)).not.toBeInTheDocument();
expect(screen.queryByTestId(COMPARE_HEAT_MAPS)).not.toBeInTheDocument();
expect(screen.queryByText(NO_DATA_MESSAGE)).toBeInTheDocument();
});
});
});
Loading

0 comments on commit 32585ad

Please sign in to comment.