Skip to content

Commit

Permalink
feat: add model streaming (#8973)
Browse files Browse the repository at this point in the history
  • Loading branch information
gt2345 authored Mar 19, 2024
1 parent 8bf280d commit 137bfcd
Show file tree
Hide file tree
Showing 16 changed files with 720 additions and 74 deletions.
64 changes: 55 additions & 9 deletions e2e_tests/tests/streaming/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def test_client_subscribe() -> None:
syncId = "sync1"
projectName = "streaming_project"
newProjectName = "streaming_project_1"
modelName = "streaming_model"
newModelName = "streaming_model_1"
pSeq = 0
mSeq = 0

resp_w = bindings.post_PostWorkspace(
sess, body=bindings.v1PostWorkspaceRequest(name=f"streaming_workspace_{random.random()}")
Expand All @@ -48,31 +52,73 @@ def test_client_subscribe() -> None:
workspaceId=w.id,
)
p = resp_p.project
resp_m = bindings.post_PostModel(
sess, body=bindings.v1PostModelRequest(name=modelName, workspaceId=w.id)
)
m = resp_m.model

stream.subscribe(sync_id=syncId, projects=streams._client.ProjectSpec(workspace_id=w.id))
stream.subscribe(
sync_id=syncId,
projects=streams._client.ProjectSpec(workspace_id=w.id),
models=streams._client.ModelSpec(workspace_id=w.id),
)
event = next(stream)
assert event == streams._client.Sync(syncId, False)
event = next(stream)
assert isinstance(event, streams.wire.ProjectMsg)
assert event.id == p.id
assert event.name == projectName
seq = event.seq
event = next(stream)
assert event == streams._client.Sync(syncId, True)
findProject, findModel, finish = False, False, False
for event in stream:
if isinstance(event, streams.wire.ProjectMsg):
assert event.id == p.id
assert event.name == projectName
pSeq = event.seq
findProject = True
if isinstance(event, streams.wire.ModelMsg):
assert event.id == m.id
assert event.name == modelName
mSeq = event.seq
findModel = True
if event == streams._client.Sync(syncId, True):
finish = True
break
assert (
findProject and findModel and finish
), f"Project found: {findProject}\n Model found: {findModel}\n Sync finished: {finish}"

bindings.patch_PatchProject(sess, body=bindings.v1PatchProject(name=newProjectName), id=p.id)
event = next(stream)
assert isinstance(event, streams.wire.ProjectMsg)
assert event.id == p.id
assert event.name == newProjectName
assert event.seq > seq
assert event.seq > pSeq

bindings.patch_PatchModel(
sess, body=bindings.v1PatchModel(name=newModelName), modelName=modelName
)
event = next(stream)
assert isinstance(event, streams.wire.ModelMsg)
assert event.id == m.id
assert event.name == newModelName
assert event.seq > mSeq

bindings.delete_DeleteProject(sess, id=p.id)
deleted = False
for event in stream:
if isinstance(event, streams.wire.ProjectMsg):
assert event.state == "DELETING"
elif isinstance(event, streams.wire.ProjectsDeleted):
assert event == streams.wire.ProjectsDeleted(str(p.id))
deleted = True
break
else:
raise ValueError(f"Unexpected message from stream. {event}")
assert deleted

bindings.delete_DeleteModel(sess, modelName=newModelName)
deleted = False
for event in stream:
if isinstance(event, streams.wire.ModelsDeleted):
assert event == streams.wire.ModelsDeleted(str(m.id))
deleted = True
break
else:
raise ValueError(f"Unexpected message from stream. {event}")
assert deleted
35 changes: 31 additions & 4 deletions harness/determined/common/streams/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,28 @@ def _to_wire(self, since: int) -> Dict[str, Any]:
).to_json()


class ModelSpec:
    """Selection criteria for the models part of a stream subscription.

    Every criterion is optional and may be a single integer id or a sequence of
    integer ids. The values are stored exactly as given; no validation or
    normalization happens at construction time.
    """

    def __init__(
        self,
        workspace_id: Optional[Union[int, Sequence[int]]] = None,
        model_id: Optional[Union[int, Sequence[int]]] = None,
        user_id: Optional[Union[int, Sequence[int]]] = None,
    ) -> None:
        self.workspace_id, self.model_id, self.user_id = workspace_id, model_id, user_id

    def _copy(self) -> "ModelSpec":
        """Return a new ModelSpec that carries the same three criteria.

        The criteria objects themselves are passed along, not duplicated.
        """
        return ModelSpec(
            workspace_id=self.workspace_id,
            model_id=self.model_id,
            user_id=self.user_id,
        )

    def _to_wire(self, since: int) -> Dict[str, Any]:
        """Build the JSON form of this spec through wire.ModelSubscriptionSpec.

        Each criterion is passed through int_or_list before being handed to the
        wire type.

        NOTE(review): `since` is accepted but is not forwarded to
        wire.ModelSubscriptionSpec in this method -- confirm that is intended.
        """
        wire_kwargs = {
            "workspace_ids": int_or_list(self.workspace_id),
            "model_ids": int_or_list(self.model_id),
            "user_ids": int_or_list(self.user_id),
        }
        return wire.ModelSubscriptionSpec(**wire_kwargs).to_json()


class Sync:
def __init__(self, sync_id: Any, complete: bool) -> None:
self.sync_id = sync_id
Expand Down Expand Up @@ -228,6 +250,7 @@ def __init__(self, ws: StreamWebSocket) -> None:
self._ws = ws
# Our stream-level in-memory cache: just enough to handle automatic reconnects.
self._projects = KeyCache()
self._models = KeyCache()
# The websocket events. We'll connect (and reconnect) lazily.
self._ws_iter: Optional[Iterable] = None
self._closed = False
Expand Down Expand Up @@ -259,6 +282,8 @@ def __init__(self, ws: StreamWebSocket) -> None:
self.handlers: Dict[str, MsgHandler] = {
"project": self._make_upsertion_handler(wire.ProjectMsg, self._projects),
"projects_deleted": self._make_deletion_handler(wire.ProjectsDeleted, self._projects),
"model": self._make_upsertion_handler(wire.ModelMsg, self._models),
"models_deleted": self._make_deletion_handler(wire.ModelsDeleted, self._models),
}

self._retries = 0
Expand Down Expand Up @@ -295,7 +320,7 @@ def _send_spec(self, spec: Dict[str, Any]) -> None:
# build our startup message
since = {
"projects": self._projects.maxseq,
# "experiments": self._experiments.maxseq,
"models": self._models.maxseq,
}
subscribe = {k: v._to_wire(since[k]) for k, v in spec.items()}
# add since info to our initial subscriptions
Expand All @@ -317,7 +342,7 @@ def _send_spec(self, spec: Dict[str, Any]) -> None:
k: v
for k, v in {
"projects": self._projects.known(),
# "experiments": self._experiments.known(),
"models": self._models.known(),
}.items()
if v
},
Expand Down Expand Up @@ -466,13 +491,15 @@ def subscribe(
sync_id: Any = None,
*,
projects: Optional[ProjectSpec] = None,
# experiments: Optional[ExperimentSpec] = None,
models: Optional[ModelSpec] = None,
) -> "Stream":
# Capture what the user asked for immediately, but we won't fill since or known values until
# we send it.
spec = {}
spec: Dict[str, Any] = {}
if projects:
spec["projects"] = projects._copy()
if models:
spec["models"] = models._copy()
self._specs.append((sync_id, spec))
# Adding a spec can trigger sending a subscription.
self._advance_subscription()
Expand Down
48 changes: 48 additions & 0 deletions harness/determined/common/streams/wire.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions master/cmd/stream-gen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ func genPython(streamables []Streamable) ([]byte, error) {
"int": "int",
"int64": "int",
"[]int": "typing.List[int]",
"[]string": "typing.List[str]",
"time.Time": "float",
"*time.Time": "typing.Optional[float]",
"model.TaskID": "str",
Expand Down
8 changes: 8 additions & 0 deletions master/internal/stream/authz_basic_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ func (a *StreamAuthZBasic) GetProjectStreamableScopes(
return model.AccessScopeSet{model.GlobalAccessScopeID: true}, nil
}

// GetModelStreamableScopes ignores both of its arguments and always returns an
// AccessScopeSet in which model.GlobalAccessScopeID is set to true, together
// with a nil error.
func (a *StreamAuthZBasic) GetModelStreamableScopes(
	_ context.Context,
	_ model.User,
) (model.AccessScopeSet, error) {
	globalOnly := model.AccessScopeSet{model.GlobalAccessScopeID: true}
	return globalOnly, nil
}

// GetPermissionChangeListener always returns a nil pointer and a nil error.
func (a *StreamAuthZBasic) GetPermissionChangeListener() (*pq.Listener, error) {
return nil, nil
Expand Down
3 changes: 3 additions & 0 deletions master/internal/stream/authz_iface.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ type StreamAuthZ interface {
// GetProjectStreamableScopes returns an AccessScopeSet where the user has permission to view projects.
GetProjectStreamableScopes(ctx context.Context, curUser model.User) (model.AccessScopeSet, error)

// GetModelStreamableScopes returns an AccessScopeSet where the user has permission to view models.
GetModelStreamableScopes(ctx context.Context, curUser model.User) (model.AccessScopeSet, error)

// GetPermissionChangeListener returns a pointer to a listener that
// receives permission change notifications, if applicable.
GetPermissionChangeListener() (*pq.Listener, error)
Expand Down
1 change: 1 addition & 0 deletions master/internal/stream/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type StartupMsg struct {
// Each field of a KnownKeySet is a comma-separated list of int64s and ranges like "a,b-c,d".
type KnownKeySet struct {
// Projects is the key-list string for projects; it is encoded under the
// JSON key "projects".
Projects string `json:"projects"`
// Models is the key-list string for models; it is encoded under the
// JSON key "models".
Models string `json:"models"`
}

// prepareWebsocketMessage converts the MarshallableMsg into a websocket.PreparedMessage.
Expand Down
Loading

0 comments on commit 137bfcd

Please sign in to comment.