Skip to content

Commit

Permalink
chore: avoid payload limitation (#9164)
Browse files Browse the repository at this point in the history
  • Loading branch information
jgongd authored Jun 10, 2024
1 parent dde6362 commit 439734b
Show file tree
Hide file tree
Showing 10 changed files with 949 additions and 185 deletions.
44 changes: 33 additions & 11 deletions master/internal/stream/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,27 @@ type ModelVersionMsg struct {
}

// SeqNum gets the SeqNum from a ModelVersionMsg.
func (pm *ModelVersionMsg) SeqNum() int64 {
return pm.Seq
func (mm *ModelVersionMsg) SeqNum() int64 {
return mm.Seq
}

// GetID gets the ID from a ModelVersionMsg.
func (mm *ModelVersionMsg) GetID() int {
return mm.ID
}

// UpsertMsg creates a ModelVersion stream upsert message.
func (pm *ModelVersionMsg) UpsertMsg() stream.UpsertMsg {
return stream.UpsertMsg{
func (mm *ModelVersionMsg) UpsertMsg() *stream.UpsertMsg {
return &stream.UpsertMsg{
JSONKey: ModelVersionsUpsertKey,
Msg: pm,
Msg: mm,
}
}

// DeleteMsg creates a ModelVersion stream delete message.
func (pm *ModelVersionMsg) DeleteMsg() stream.DeleteMsg {
deleted := strconv.FormatInt(int64(pm.ID), 10)
return stream.DeleteMsg{
func (mm *ModelVersionMsg) DeleteMsg() *stream.DeleteMsg {
deleted := strconv.Itoa(mm.ID)
return &stream.DeleteMsg{
Key: ModelVersionsDeleteKey,
Deleted: deleted,
}
Expand Down Expand Up @@ -158,7 +163,7 @@ func ModelVersionCollectStartupMsgs(
}
missing, appeared, err := processQuery(ctx, createQuery, spec.Since, known, "m")
if err != nil {
return nil, err
return nil, fmt.Errorf("processing known: %w", err)
}

// step 2: hydrate appeared IDs into full ModelVersionMsgs
Expand All @@ -171,14 +176,14 @@ func ModelVersionCollectStartupMsgs(
query = modelVersionPermFilterQuery(query, accessScopes)
}
err := query.Scan(ctx, &mvMsgs)
if err != nil && errors.Cause(err) != sql.ErrNoRows {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
log.Errorf("error: %v\n", err)
return nil, err
}
}

// step 3: emit deletions and updates to the client
out = append(out, stream.DeleteMsg{
out = append(out, &stream.DeleteMsg{
Key: ModelVersionsDeleteKey,
Deleted: missing,
})
Expand Down Expand Up @@ -258,3 +263,20 @@ func ModelVersionMakePermissionFilter(ctx context.Context, user model.User) (fun
}, nil
}
}

// ModelVersionMakeHydrator returns a function that gets properties of a model version by
// its id.
func ModelVersionMakeHydrator() func(*ModelVersionMsg) (*ModelVersionMsg, error) {
return func(msg *ModelVersionMsg) (*ModelVersionMsg, error) {
var saturatedMsg ModelVersionMsg
query := db.Bun().NewSelect().Model(&saturatedMsg).Where("id = ?", msg.GetID()).ExcludeColumn("workspace_id")
err := query.Scan(context.Background(), &saturatedMsg)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, err
} else if err != nil {
return nil, fmt.Errorf("error in model version hydrator: %w", err)
}
saturatedMsg.WorkspaceID = msg.WorkspaceID
return &saturatedMsg, nil
}
}
43 changes: 32 additions & 11 deletions master/internal/stream/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,27 @@ type ModelMsg struct {
}

// SeqNum gets the SeqNum from a ModelMsg.
func (pm *ModelMsg) SeqNum() int64 {
return pm.Seq
func (mm *ModelMsg) SeqNum() int64 {
return mm.Seq
}

// GetID gets the ID from a ModelMsg.
func (mm *ModelMsg) GetID() int {
return mm.ID
}

// UpsertMsg creates a model stream upsert message.
func (pm *ModelMsg) UpsertMsg() stream.UpsertMsg {
return stream.UpsertMsg{
func (mm *ModelMsg) UpsertMsg() *stream.UpsertMsg {
return &stream.UpsertMsg{
JSONKey: ModelsUpsertKey,
Msg: pm,
Msg: mm,
}
}

// DeleteMsg creates a model stream delete message.
func (pm *ModelMsg) DeleteMsg() stream.DeleteMsg {
deleted := strconv.FormatInt(int64(pm.ID), 10)
return stream.DeleteMsg{
func (mm *ModelMsg) DeleteMsg() *stream.DeleteMsg {
deleted := strconv.Itoa(mm.ID)
return &stream.DeleteMsg{
Key: ModelsDeleteKey,
Deleted: deleted,
}
Expand Down Expand Up @@ -152,7 +157,7 @@ func ModelCollectStartupMsgs(
}
missing, appeared, err := processQuery(ctx, createQuery, spec.Since, known, "m")
if err != nil {
return nil, err
return nil, fmt.Errorf("processing known: %w", err)
}

// step 2: hydrate appeared IDs into full ModelMsgs
Expand All @@ -163,14 +168,14 @@ func ModelCollectStartupMsgs(
query = permFilterQuery(query, accessScopes)
}
err := query.Scan(ctx, &modelMsgs)
if err != nil && errors.Cause(err) != sql.ErrNoRows {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
log.Errorf("error: %v\n", err)
return nil, err
}
}

// step 3: emit deletions and updates to the client
out = append(out, stream.DeleteMsg{
out = append(out, &stream.DeleteMsg{
Key: ModelsDeleteKey,
Deleted: missing,
})
Expand Down Expand Up @@ -246,3 +251,19 @@ func ModelMakePermissionFilter(ctx context.Context, user model.User) (func(*Mode
}, nil
}
}

// ModelMakeHydrator returns a function that gets properties of a model by
// its id.
func ModelMakeHydrator() func(*ModelMsg) (*ModelMsg, error) {
return func(msg *ModelMsg) (*ModelMsg, error) {
var saturatedMsg ModelMsg
query := db.Bun().NewSelect().Model(&saturatedMsg).Where("id = ?", msg.GetID())
err := query.Scan(context.Background(), &saturatedMsg)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, err
} else if err != nil {
return nil, fmt.Errorf("error in model hydrator: %w", err)
}
return &saturatedMsg, nil
}
}
37 changes: 29 additions & 8 deletions master/internal/stream/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,23 @@ func (pm *ProjectMsg) SeqNum() int64 {
return pm.Seq
}

// GetID gets the ID from a ProjectMsg.
func (pm *ProjectMsg) GetID() int {
return pm.ID
}

// UpsertMsg creates a Project stream upsert message.
func (pm *ProjectMsg) UpsertMsg() stream.UpsertMsg {
return stream.UpsertMsg{
func (pm *ProjectMsg) UpsertMsg() *stream.UpsertMsg {
return &stream.UpsertMsg{
JSONKey: ProjectsUpsertKey,
Msg: pm,
}
}

// DeleteMsg creates a Project stream delete message.
func (pm *ProjectMsg) DeleteMsg() stream.DeleteMsg {
deleted := strconv.FormatInt(int64(pm.ID), 10)
return stream.DeleteMsg{
func (pm *ProjectMsg) DeleteMsg() *stream.DeleteMsg {
deleted := strconv.Itoa(pm.ID)
return &stream.DeleteMsg{
Key: ProjectsDeleteKey,
Deleted: deleted,
}
Expand Down Expand Up @@ -151,7 +156,7 @@ func ProjectCollectStartupMsgs(
}
missing, appeared, err := processQuery(ctx, createQuery, spec.Since, known, "p")
if err != nil {
return nil, err
return nil, fmt.Errorf("processing known: %w", err)
}

// step 2: hydrate appeared IDs into full ProjectMsgs
Expand All @@ -162,14 +167,14 @@ func ProjectCollectStartupMsgs(
query = permFilterQuery(query, accessScopes)
}
err := query.Scan(ctx, &projMsgs)
if err != nil && errors.Cause(err) != sql.ErrNoRows {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
log.Errorf("error: %v\n", err)
return nil, err
}
}

// step 3: emit deletions and updates to the client
out = append(out, stream.DeleteMsg{
out = append(out, &stream.DeleteMsg{
Key: ProjectsDeleteKey,
Deleted: missing,
})
Expand Down Expand Up @@ -234,3 +239,19 @@ func ProjectMakePermissionFilter(ctx context.Context, user model.User) (func(*Pr
}, nil
}
}

// ProjectMakeHydrator returns a function that gets properties of a project by
// its id.
func ProjectMakeHydrator() func(*ProjectMsg) (*ProjectMsg, error) {
return func(msg *ProjectMsg) (*ProjectMsg, error) {
var saturatedMsg ProjectMsg
query := db.Bun().NewSelect().Model(&saturatedMsg).Where("project_msg.id = ?", msg.GetID())
err := query.Scan(context.Background(), &saturatedMsg)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, err
} else if err != nil {
return nil, fmt.Errorf("error in project hydrator: %w", err)
}
return &saturatedMsg, nil
}
}
19 changes: 15 additions & 4 deletions master/internal/stream/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"github.com/determined-ai/determined/master/pkg/syncx/errgroupx"
)

const maxEventCount = 120

// PublisherSet contains all publishers, and handles all websockets. It will connect each websocket
// with the appropriate set of publishers, based on that websocket's subscriptions.
//
Expand All @@ -38,9 +40,9 @@ func NewPublisherSet(dbAddress string) *PublisherSet {
lock := sync.Mutex{}
return &PublisherSet{
DBAddress: dbAddress,
Projects: stream.NewPublisher[*ProjectMsg](),
Models: stream.NewPublisher[*ModelMsg](),
ModelVersions: stream.NewPublisher[*ModelVersionMsg](),
Projects: stream.NewPublisher[*ProjectMsg](ProjectMakeHydrator()),
Models: stream.NewPublisher[*ModelMsg](ModelMakeHydrator()),
ModelVersions: stream.NewPublisher[*ModelVersionMsg](ModelVersionMakeHydrator()),
bootemChan: make(chan struct{}),
readyCond: *sync.NewCond(&lock),
}
Expand Down Expand Up @@ -414,6 +416,7 @@ func publishLoop[T stream.Msg](
events = append(events, event)
// Collect all available notifications before proceeding.
keepGoing := true
eventCount := 0
for keepGoing {
select {
case notification = <-listener.Notify:
Expand All @@ -423,12 +426,20 @@ func publishLoop[T stream.Msg](
return err
}
events = append(events, event)
eventCount++
keepGoing = eventCount < maxEventCount
default:
keepGoing = false
}
}

idToSaturatedMsg := map[int]*stream.UpsertMsg{}
// TODO: MD-434 improve performance by batch hydrating the messages.
for _, ev := range events {
publisher.HydrateMsg(ev.After, idToSaturatedMsg)
}
// Broadcast all the events.
publisher.Broadcast(events)
publisher.Broadcast(events, idToSaturatedMsg)
}
}
}
Expand Down
4 changes: 1 addition & 3 deletions master/internal/stream/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package stream
import (
"context"

"github.com/labstack/echo/v4"

"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/stream"
)
Expand Down Expand Up @@ -156,7 +154,7 @@ func startup[T stream.Msg, S any](
// Scan for historical msgs matching newly-added subscriptions.
newmsgs, err := state.CollectStartupMsgs(ctx, user, known, *spec)
if err != nil {
return echo.ErrCookieNotFound
return err
}
for _, msg := range newmsgs {
*msgs = append(*msgs, prepare(msg))
Expand Down
4 changes: 2 additions & 2 deletions master/internal/stream/test_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (
// otherwise, returns the MarshallableMsg that the streamer sends.
func testPrepareFunc(i stream.MarshallableMsg) interface{} {
switch msg := i.(type) {
case stream.UpsertMsg:
case *stream.UpsertMsg:
switch typedMsg := msg.Msg.(type) {
case *ProjectMsg:
return fmt.Sprintf(
Expand All @@ -50,7 +50,7 @@ func testPrepareFunc(i stream.MarshallableMsg) interface{} {
typedMsg.WorkspaceID,
)
}
case stream.DeleteMsg:
case *stream.DeleteMsg:
return fmt.Sprintf("key: %s, deleted: %s", msg.Key, msg.Deleted)
case stream.SyncMsg:
return fmt.Sprintf("key: %s, sync_id: %s, complete: %t", syncKey, msg.SyncID, msg.Complete)
Expand Down
Loading

0 comments on commit 439734b

Please sign in to comment.