Skip to content

Commit

Permalink
fix: allow doesnotcontains filters on hyperparameter column (#8842)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashtonG authored Jun 13, 2024
1 parent 382995c commit ee66d15
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
43 changes: 42 additions & 1 deletion master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"os"
"sort"
"strconv"
"sync"
Expand Down Expand Up @@ -1170,6 +1171,46 @@ func TestSearchExperiments(t *testing.T) {
require.Equal(t, int32(5), resp.Experiments[2].BestTrial.Restarts)
}

func TestSearchExperimentsFilters(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
_, projectIDInt := createProjectAndWorkspace(ctx, t, api)
projectID := int32(projectIDInt)

// RequireMockExperimentParams expects to be run in the db folder.
err := os.Chdir("./db")
require.NoError(t, err)
paramNames := []string{"foo"}
db.RequireMockExperimentParams(t, api.m.db, curUser, db.MockExperimentParams{
HParamNames: &paramNames,
}, projectIDInt)
err = os.Chdir("./..")
require.NoError(t, err)

tests := map[string]struct {
expectedNumExperiments int
filter string
}{
"ExpHPNotContains": {
expectedNumExperiments: 0,
filter: `{"filterGroup":{"children":[{"columnName":"hp.foo","kind":"field",` +
`"location":"LOCATION_TYPE_HYPERPARAMETERS","operator":"notContains","type":"COLUMN_TYPE_NUMBER","value":1}],` +
`"conjunction":"and","kind":"group"},"showArchived":false}`,
},
}

for testCase, testVars := range tests {
t.Run(testCase, func(t *testing.T) {
resp, requestError := api.SearchExperiments(ctx, &apiv1.SearchExperimentsRequest{
ProjectId: &projectID,
Filter: ptrs.Ptr(testVars.filter),
})

require.NoError(t, requestError)
require.Len(t, resp.Experiments, testVars.expectedNumExperiments)
})
}
}

func TestSearchExperimentsMalformed(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
_, projectIDInt := createProjectAndWorkspace(ctx, t, api)
Expand Down Expand Up @@ -2063,7 +2104,7 @@ func TestExperimentSearchApiFilterParsing(t *testing.T) {
`{"filterGroup":{"children":[{"type":"COLUMN_TYPE_TEXT","location":"LOCATION_TYPE_HYPERPARAMETERS", "columnName":"hp.clip_grad.clip.grad","kind":"field","operator":"notContains", "value":"some_string"}],"conjunction":"and","kind":"group"},"showArchived":true}`,
`((((CASE
WHEN config->'hyperparameters'->'clip_grad'->'clip'->'grad'->>'type' = 'const' THEN config->'hyperparameters'->'clip_grad'->'clip'->'grad'->>'val' NOT LIKE '%some_string%'
WHEN config->'hyperparameters'->'clip_grad'->'clip'->'grad'->>'type' = 'categorical' THEN (config->'hyperparameters'->'clip_grad'->'clip'->'grad'->>'vals')::jsonb ? 'some_string') IS NOT TRUE
WHEN config->'hyperparameters'->'clip_grad'->'clip'->'grad'->>'type' = 'categorical' THEN (config->'hyperparameters'->'clip_grad'->'clip'->'grad'->>'vals')::jsonb ? 'some_string' IS NOT TRUE
ELSE false
END))))`,
},
Expand Down
2 changes: 1 addition & 1 deletion master/internal/experiment_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func hpToSQL(c string, filterColumnType *string, filterValue *interface{},
queryArgs = append(queryArgs, bun.Safe("?"), queryValue)
queryString = fmt.Sprintf(`(CASE
WHEN config->'hyperparameters'->%s->>'type' = 'const' THEN config->'hyperparameters'->%s->>'val' NOT LIKE %s
WHEN config->'hyperparameters'->%s->>'type' = 'categorical' THEN (config->'hyperparameters'->%s->>'vals')::jsonb %s %s) IS NOT TRUE
WHEN config->'hyperparameters'->%s->>'type' = 'categorical' THEN (config->'hyperparameters'->%s->>'vals')::jsonb %s %s IS NOT TRUE
ELSE false
END)`, hpQuery, hpQuery, "?", hpQuery, hpQuery, "?", "?")
default:
Expand Down

0 comments on commit ee66d15

Please sign in to comment.