diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index 2ad21741b0e1..a7c9518dc473 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -1204,7 +1204,7 @@ func (f *Forwarder) join(ctx *authContext, w http.ResponseWriter, req *http.Requ return trace.Wrap(err) } - client := &websocketClientStreams{stream} + client := &websocketClientStreams{uuid.New(), stream} party := newParty(*ctx, stream.Mode, client) err = session.join(party, true /* emitSessionJoinEvent */) diff --git a/lib/kube/proxy/sess.go b/lib/kube/proxy/sess.go index 2a7865f5dd3c..f09770c5380d 100644 --- a/lib/kube/proxy/sess.go +++ b/lib/kube/proxy/sess.go @@ -68,10 +68,11 @@ const ( // remoteClient is either a kubectl or websocket client. type remoteClient interface { + queueID() uuid.UUID stdinStream() io.Reader stdoutStream() io.Writer stderrStream() io.Writer - resizeQueue() <-chan *remotecommand.TerminalSize + resizeQueue() <-chan terminalResizeMessage resize(size *remotecommand.TerminalSize) error forceTerminate() <-chan struct{} sendStatus(error) error @@ -79,9 +80,14 @@ type remoteClient interface { } type websocketClientStreams struct { + id uuid.UUID stream *streamproto.SessionStream } +func (p *websocketClientStreams) queueID() uuid.UUID { + return p.id +} + func (p *websocketClientStreams) stdinStream() io.Reader { return p.stream } @@ -94,8 +100,26 @@ func (p *websocketClientStreams) stderrStream() io.Writer { return p.stream } -func (p *websocketClientStreams) resizeQueue() <-chan *remotecommand.TerminalSize { - return p.stream.ResizeQueue() +func (p *websocketClientStreams) resizeQueue() <-chan terminalResizeMessage { + ch := make(chan terminalResizeMessage) + go func() { + defer close(ch) + for { + select { + case <-p.stream.Done(): + return + case size := <-p.stream.ResizeQueue(): + if size == nil { + return + } + ch <- terminalResizeMessage{ + size: size, + source: p.id, + } + } + } + }() + return ch } func (p *websocketClientStreams) resize(size *remotecommand.TerminalSize) error { @@ -115,6 +139,7 @@ func (p *websocketClientStreams) Close() error { } type kubeProxyClientStreams struct { + id uuid.UUID proxy *remoteCommandProxy sizeQueue *termQueue stdin io.Reader @@ -128,6 +153,7 @@ func newKubeProxyClientStreams(proxy *remoteCommandProxy) *kubeProxyClientStream options := proxy.options() return &kubeProxyClientStreams{ + id: uuid.New(), proxy: proxy, stdin: options.Stdin, stdout: options.Stdout, @@ -137,6 +163,10 @@ func newKubeProxyClientStreams(proxy *remoteCommandProxy) *kubeProxyClientStream } } +func (p *kubeProxyClientStreams) queueID() uuid.UUID { + return p.id +} + func (p *kubeProxyClientStreams) stdinStream() io.Reader { return p.stdin } @@ -149,8 +179,8 @@ func (p *kubeProxyClientStreams) stderrStream() io.Writer { return p.stderr } -func (p *kubeProxyClientStreams) resizeQueue() <-chan *remotecommand.TerminalSize { - ch := make(chan *remotecommand.TerminalSize) +func (p *kubeProxyClientStreams) resizeQueue() <-chan terminalResizeMessage { + ch := make(chan terminalResizeMessage) p.wg.Add(1) go func() { defer p.wg.Done() @@ -159,8 +189,9 @@ func (p *kubeProxyClientStreams) resizeQueue() <-chan *remotecommand.TerminalSiz if size == nil { return } + select { - case ch <- size: + case ch <- terminalResizeMessage{size, p.id}: // Check if the sizeQueue was already terminated. case <-p.sizeQueue.done.Done(): return @@ -193,21 +224,28 @@ func (p *kubeProxyClientStreams) Close() error { return nil } +// terminalResizeMessage is a message that contains the terminal size and the source of the resize event. +type terminalResizeMessage struct { + size *remotecommand.TerminalSize + source uuid.UUID +} + // multiResizeQueue is a merged queue of multiple terminal size queues. type multiResizeQueue struct { - queues map[string]<-chan *remotecommand.TerminalSize + queues map[string]<-chan terminalResizeMessage cases []reflect.SelectCase - callback func(*remotecommand.TerminalSize) + callback func(terminalResizeMessage) mutex sync.Mutex parentCtx context.Context reloadCtx context.Context reloadCancel context.CancelFunc + lastSize *remotecommand.TerminalSize } func newMultiResizeQueue(parentCtx context.Context) *multiResizeQueue { ctx, cancel := context.WithCancel(parentCtx) return &multiResizeQueue{ - queues: make(map[string]<-chan *remotecommand.TerminalSize), + queues: make(map[string]<-chan terminalResizeMessage), parentCtx: parentCtx, reloadCtx: ctx, reloadCancel: cancel, @@ -234,11 +272,17 @@ func (r *multiResizeQueue) rebuild() { } } +func (r *multiResizeQueue) getLastSize() *remotecommand.TerminalSize { + r.mutex.Lock() + defer r.mutex.Unlock() + return r.lastSize +} + func (r *multiResizeQueue) close() { r.reloadCancel() } -func (r *multiResizeQueue) add(id string, queue <-chan *remotecommand.TerminalSize) { +func (r *multiResizeQueue) add(id string, queue <-chan terminalResizeMessage) { r.mutex.Lock() defer r.mutex.Unlock() r.queues[id] = queue @@ -270,9 +314,12 @@ loop: } } - size := value.Interface().(*remotecommand.TerminalSize) + size := value.Interface().(terminalResizeMessage) r.callback(size) - return size + r.mutex.Lock() + r.lastSize = size.size + r.mutex.Unlock() + return size.size } } @@ -703,20 +750,24 @@ func (s *session) lockedSetupLaunch(request *remoteCommandRequest, eventPodMeta sessionStart := s.forwarder.cfg.Clock.Now().UTC() if !s.sess.noAuditEvents { - s.terminalSizeQueue.callback = func(resize *remotecommand.TerminalSize) { + s.terminalSizeQueue.callback = func(termSize terminalResizeMessage) { s.mu.Lock() defer s.mu.Unlock() for id, p := range s.parties { - err := p.Client.resize(resize) + // Skip the party that sent the resize event to avoid a resize loop. + if p.Client.queueID() == termSize.source { + continue + } + err := p.Client.resize(termSize.size) if err != nil { s.log.WithError(err).Errorf("Failed to resize client: %v", id.String()) } } params := tsession.TerminalParams{ - W: int(resize.Width), - H: int(resize.Height), + W: int(termSize.size.Width), + H: int(termSize.size.Height), } resizeEvent, err := s.recorder.PrepareSessionEvent(&apievents.Resize{ @@ -747,7 +798,7 @@ func (s *session) lockedSetupLaunch(request *remoteCommandRequest, eventPodMeta } } } else { - s.terminalSizeQueue.callback = func(resize *remotecommand.TerminalSize) {} + s.terminalSizeQueue.callback = func(resize terminalResizeMessage) {} } // If we get here, it means we are going to have a session.end event. @@ -953,6 +1004,16 @@ func (s *session) join(p *party, emitJoinEvent bool) error { s.partiesHistorical[p.ID] = p s.terminalSizeQueue.add(stringID, p.Client.resizeQueue()) + // If the session is already running, we need to resize the new party's terminal + // to match the last terminal size. + // This is done to ensure that the new party's terminal is the same size as the + // other parties' terminals and no discrepancies are present. + if lastQueueSize := s.terminalSizeQueue.getLastSize(); lastQueueSize != nil { + if err := p.Client.resize(lastQueueSize); err != nil { + s.log.WithError(err).Errorf("Failed to resize client: %v", stringID) + } + } + if p.Mode == types.SessionPeerMode { s.io.AddReader(stringID, p.Client.stdinStream()) }