Skip to content

Commit

Permalink
feat: Add model version streaming (#9029)
Browse files Browse the repository at this point in the history
  • Loading branch information
gt2345 authored Apr 4, 2024
1 parent 2c6fec7 commit 9dce6f0
Show file tree
Hide file tree
Showing 20 changed files with 748 additions and 46 deletions.
41 changes: 41 additions & 0 deletions e2e_tests/tests/streaming/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from determined.common import streams
from determined.common.api import bindings
from determined.experimental import client
from tests import api_utils
from tests import config as conf
from tests import experiment as exp


@pytest.mark.e2e_cpu
Expand Down Expand Up @@ -122,3 +125,41 @@ def test_client_subscribe() -> None:
else:
raise ValueError(f"Unexpected message from stream. {event}")
assert deleted


@pytest.mark.e2e_cpu
def test_subscribe_model_version() -> None:
    """Subscribe to model versions by model ID and verify a new version is streamed.

    The test runs an experiment to completion, registers a model, subscribes to
    model-version events for that model, then creates a model version from the
    experiment's top checkpoint and expects a matching ModelVersionMsg.
    """
    sess = api_utils.admin_session()
    ws = streams._client.LomondStreamWebSocket(sess)
    stream = streams._client.Stream(ws)
    sync_id = "sync2"
    model_name = "test_model_version_streaming"

    detobj = client.Determined._from_session(sess)

    exp_id = exp.create_experiment(
        sess,
        conf.fixtures_path("no_op/gc_checkpoints_decreasing.yaml"),
        conf.fixtures_path("no_op"),
    )
    exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.COMPLETED)

    ckpt = detobj.get_experiment(exp_id).top_checkpoint()

    resp_m = bindings.post_PostModel(sess, body=bindings.v1PostModelRequest(name=model_name))
    m = resp_m.model

    # Subscribe before creating the model version so the creation event is streamed to us.
    stream.subscribe(
        sync_id=sync_id,
        model_versions=streams._client.ModelVersionSpec(model_id=m.id),
    )

    bindings.post_PostModelVersion(
        sess,
        body=bindings.v1PostModelVersionRequest(checkpointUuid=ckpt.uuid, modelName=model_name),
        modelName=model_name,
    )

    # Track whether the expected message arrived so the test cannot pass
    # vacuously if the stream ends without delivering it.
    received = False
    for event in stream:
        if isinstance(event, streams.wire.ModelVersionMsg):
            assert event.model_id == m.id
            assert event.checkpoint_uuid == ckpt.uuid
            received = True
            break
    assert received, "stream ended without delivering a ModelVersionMsg"
40 changes: 37 additions & 3 deletions harness/determined/common/streams/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
def _copy(self) -> "ProjectSpec":
return ProjectSpec(self.workspace_id, self.project_id)

def _to_wire(self, since: int) -> Dict[str, Any]:
def _to_wire(self) -> Dict[str, Any]:
return wire.ProjectSubscriptionSpec(
workspace_ids=int_or_list(self.workspace_id),
project_ids=int_or_list(self.project_id),
Expand All @@ -126,14 +126,36 @@ def __init__(
def _copy(self) -> "ModelSpec":
return ModelSpec(self.workspace_id, self.model_id, self.user_id)

def _to_wire(self, since: int) -> Dict[str, Any]:
def _to_wire(self) -> Dict[str, Any]:
return wire.ModelSubscriptionSpec(
workspace_ids=int_or_list(self.workspace_id),
model_ids=int_or_list(self.model_id),
user_ids=int_or_list(self.user_id),
).to_json()


class ModelVersionSpec:
    """Filter describing which model-version events a subscription asks for.

    Each argument may be a single int, a sequence of ints, or None (the default).
    """

    def __init__(
        self,
        model_version_id: Optional[Union[int, Sequence[int]]] = None,
        model_id: Optional[Union[int, Sequence[int]]] = None,
        user_id: Optional[Union[int, Sequence[int]]] = None,
    ) -> None:
        self.model_version_id = model_version_id
        self.model_id = model_id
        self.user_id = user_id

    def _copy(self) -> "ModelVersionSpec":
        """Return a new ModelVersionSpec holding the same filter values.

        The copy must be a ModelVersionSpec (not a ModelSpec) so that a later
        _to_wire() call emits a model-version subscription with the IDs in the
        right fields. Sequence values are shared by reference, not duplicated.
        """
        return ModelVersionSpec(self.model_version_id, self.model_id, self.user_id)

    def _to_wire(self) -> Dict[str, Any]:
        """Build a wire.ModelVersionSubscriptionSpec from the filters and return its to_json()."""
        return wire.ModelVersionSubscriptionSpec(
            model_version_ids=int_or_list(self.model_version_id),
            model_ids=int_or_list(self.model_id),
            user_ids=int_or_list(self.user_id),
        ).to_json()


class Sync:
def __init__(self, sync_id: Any, complete: bool) -> None:
self.sync_id = sync_id
Expand Down Expand Up @@ -251,6 +273,7 @@ def __init__(self, ws: StreamWebSocket) -> None:
# Our stream-level in-memory cache: just enough to handle automatic reconnects.
self._projects = KeyCache()
self._models = KeyCache()
self._model_versions = KeyCache()
# The websocket events. We'll connect (and reconnect) lazily.
self._ws_iter: Optional[Iterable] = None
self._closed = False
Expand Down Expand Up @@ -284,6 +307,12 @@ def __init__(self, ws: StreamWebSocket) -> None:
"projects_deleted": self._make_deletion_handler(wire.ProjectsDeleted, self._projects),
"model": self._make_upsertion_handler(wire.ModelMsg, self._models),
"models_deleted": self._make_deletion_handler(wire.ModelsDeleted, self._models),
"modelversion": self._make_upsertion_handler(
wire.ModelVersionMsg, self._model_versions
),
"modelversions_deleted": self._make_deletion_handler(
wire.ModelVersionMsg, self._model_versions
),
}

self._retries = 0
Expand Down Expand Up @@ -321,8 +350,9 @@ def _send_spec(self, spec: Dict[str, Any]) -> None:
since = {
"projects": self._projects.maxseq,
"models": self._models.maxseq,
"modelversions": self._model_versions.maxseq,
}
subscribe = {k: v._to_wire(since[k]) for k, v in spec.items()}
subscribe = {k: v._to_wire() for k, v in spec.items()}
# add since info to our initial subscriptions
#
# Note: it is important that we calculate our since values at the moment that we send the
Expand All @@ -343,6 +373,7 @@ def _send_spec(self, spec: Dict[str, Any]) -> None:
for k, v in {
"projects": self._projects.known(),
"models": self._models.known(),
"modelversions": self._models.known(),
}.items()
if v
},
Expand Down Expand Up @@ -492,6 +523,7 @@ def subscribe(
*,
projects: Optional[ProjectSpec] = None,
models: Optional[ModelSpec] = None,
model_versions: Optional[ModelVersionSpec] = None,
) -> "Stream":
# Capture what the user asked for immediately, but we won't fill since or known values until
# we send it.
Expand All @@ -500,6 +532,8 @@ def subscribe(
spec["projects"] = projects._copy()
if models:
spec["models"] = models._copy()
if model_versions:
spec["modelversions"] = model_versions._copy()
self._specs.append((sync_id, spec))
# Adding a spec can trigger sending a subscription.
self._advance_subscription()
Expand Down
56 changes: 54 additions & 2 deletions harness/determined/common/streams/wire.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion master/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ build/mock_gen.stamp: $(MOCK_INPUTS)
check-gen: force-gen gen build/mock_gen.stamp
# Checking that committed, generated code is up-to-date by ensuring that
# git reports the files as unchanged after forcibly regenerating the files:
test -z "$(shell git status --porcelain '**/zgen*' $(STREAM_PYTHON_CLIENT))"
test -z "$(shell git status --porcelain '**/zgen*' $(STREAM_PYTHON_CLIENT) $(STREAM_TS_CLIENT))" || (git diff; false)

.PHONY: get-deps
get-deps:
Expand Down
7 changes: 7 additions & 0 deletions master/cmd/stream-gen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"io/fs"
"os"
"sort"
"strconv"
"strings"

Expand All @@ -24,6 +25,7 @@ type streamType string
const (
json streamType = "JSONB"
text streamType = "string"
textArr streamType = "[]string"
integer streamType = "int"
integer64 streamType = "int64"
intArr streamType = "[]int"
Expand Down Expand Up @@ -270,6 +272,7 @@ func genTypescript(streamables []Streamable) ([]byte, error) {
x := map[streamType]([2]string){
json: {"any", "{}"},
text: {"string", ""},
textArr: {"Array<string>", "[]"},
boolean: {"bool", "false"},
integer: {"number", "0"},
integer64: {"number", "0"},
Expand Down Expand Up @@ -371,6 +374,7 @@ func genPython(streamables []Streamable) ([]byte, error) {
x := map[streamType]string{
json: "typing.Any",
text: "str",
textArr: "typing.List[str]",
boolean: "bool",
integer: "int",
integer64: "int",
Expand Down Expand Up @@ -598,6 +602,9 @@ func main() {
verifyArgs(results)

// generate the language bindings
sort.Slice(results, func(i, j int) bool {
return results[i].Name < results[j].Name
})
var content []byte
switch lang {
case python:
Expand Down
8 changes: 8 additions & 0 deletions master/internal/stream/authz_basic_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ func (a *StreamAuthZBasic) GetModelStreamableScopes(
return model.AccessScopeSet{model.GlobalAccessScopeID: true}, nil
}

// GetModelVersionStreamableScopes ignores its arguments and always reports the
// global access scope, together with a nil error.
func (a *StreamAuthZBasic) GetModelVersionStreamableScopes(
	_ context.Context,
	_ model.User,
) (model.AccessScopeSet, error) {
	globalOnly := model.AccessScopeSet{model.GlobalAccessScopeID: true}
	return globalOnly, nil
}

// GetPermissionChangeListener always returns a nil pointer and a nil error.
func (a *StreamAuthZBasic) GetPermissionChangeListener() (*pq.Listener, error) {
return nil, nil
Expand Down
3 changes: 3 additions & 0 deletions master/internal/stream/authz_iface.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ type StreamAuthZ interface {
// GetModelStreamableScopes returns an AccessScopeSet where the user has permission to view models.
GetModelStreamableScopes(ctx context.Context, curUser model.User) (model.AccessScopeSet, error)

// GetModelVersionStreamableScopes returns an AccessScopeSet where the user has permission
// to view model versions.
GetModelVersionStreamableScopes(ctx context.Context, curUser model.User) (model.AccessScopeSet, error)

// GetPermissionChangeListener returns a pointer listener
// listening for permission change notifications if applicable.
GetPermissionChangeListener() (*pq.Listener, error)
Expand Down
5 changes: 3 additions & 2 deletions master/internal/stream/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ type StartupMsg struct {
//
// Each field of a KnownKeySet is a comma-separated list of int64s and ranges like "a,b-c,d".
type KnownKeySet struct {
Projects string `json:"projects"`
Models string `json:"models"`
Projects string `json:"projects"`
Models string `json:"models"`
ModelVersions string `json:"modelversions"`
}

// prepareWebsocketMessage converts the MarshallableMsg into a websocket.PreparedMessage.
Expand Down
Loading

0 comments on commit 9dce6f0

Please sign in to comment.