Skip to content

Commit

Permalink
stream: add websocket transport layer support
Browse files Browse the repository at this point in the history
  • Loading branch information
criyle committed Feb 6, 2024
1 parent a52f136 commit aa41950
Show file tree
Hide file tree
Showing 9 changed files with 303 additions and 20 deletions.
23 changes: 23 additions & 0 deletions README.cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,29 @@ reboot
- `-force-gc-target` 默认 `20m`, 堆内存使用超过该值是强制垃圾收集和归还内存
- `-force-gc-interval` 默认 `5s`, 为后台线程检查的频繁程度

### WebSocket 流接口

WebSocket 流接口是用于运行一个程序,同时和它的输入输出进行交互。所有的消息都应该使用 WebSocket 的 binary 格式来发送来避免兼容性问题。

```text
+--------+--------+---...
| 类型 | 载荷 ...
+--------|--------+---...
请求:
请求类型 =
1 - 运行请求 (载荷 = JSON 编码的请求体)
2 - 设置终端窗口大小 (载荷 = JSON 编码的请求体)
3 - 输入 (载荷 = 1 字节 (4 位的 命令下标 + 4 位的 文件描述符) + 输入内容)
4 - 取消 (没有载荷)
响应:
响应类型 =
1 - 运行结果 (载荷 = JSON 编码的运行结果)
2 - 输出 (载荷 = 1 字节 (4 位的 命令下标 + 4 位的 文件描述符) + 输入内容)
```

任何的不完整,或者不合法的消息会被认为是错误,并终止运行。

### 压力测试

使用 `wrk``t.lua`: `wrk -s t.lua -c 1 -t 1 -d 30s --latency http://localhost:5050/run`.
Expand Down
4 changes: 2 additions & 2 deletions cmd/go-judge-shell/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ type grpcWrapper struct {
sc pb.Executor_ExecStreamClient
}

func newGrpc(args []string, srvAddr *string) Stream {
func newGrpc(args []string, srvAddr string) Stream {
token := os.Getenv("TOKEN")
opts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
if token != "" {
opts = append(opts, grpc.WithPerRPCCredentials(newTokenAuth(token)))
}
conn, err := grpc.Dial(*srvAddr, opts...)
conn, err := grpc.Dial(srvAddr, opts...)
if err != nil {
log.Fatalln("client", err)
}
Expand Down
17 changes: 14 additions & 3 deletions cmd/go-judge-shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ import (
)

var (
srvAddr = flag.String("srv-addr", "localhost:5051", "GRPC server addr")
transport = flag.String("transport", "websocket", "defines transport layer (websocket / grpc)")
wsURL = flag.String("ws-url", "ws://localhost:5050/stream", "HTTP server url")
grpcAddr = flag.String("grpc-addr", "localhost:5051", "GRPC server addr")
)

const (
Expand All @@ -33,6 +35,7 @@ var env = []string{
"TERM=" + os.Getenv("TERM"),
}

// Stream defines the transport layer for stream execution
type Stream interface {
Send(*stream.Request) error
Recv() (*stream.Response, error)
Expand All @@ -44,8 +47,16 @@ func main() {
if len(args) == 0 {
args = []string{"/bin/bash"}
}
w := newGrpc(args, srvAddr)
r, err := run(w, args)
var s Stream
switch *transport {
case "websocket":
s = newWebsocket(args, *wsURL)
case "grpc":
s = newGrpc(args, *grpcAddr)
default:
log.Fatalln("invalid transport: ", *transport)
}
r, err := run(s, args)
log.Printf("finished: %+v %v", r, err)
}

Expand Down
104 changes: 104 additions & 0 deletions cmd/go-judge-shell/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package main

import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"

"github.com/criyle/go-judge/cmd/go-judge/model"
"github.com/criyle/go-judge/cmd/go-judge/stream"
"github.com/gorilla/websocket"
)

var _ Stream = &websocketStream{}

type websocketStream struct {
conn *websocket.Conn
}

func newWebsocket(args []string, wsURL string) Stream {
header := make(http.Header)
token := os.Getenv("TOKEN")
if token != "" {
header.Add("Authorization", "Bearer "+token)
}
conn, _, err := websocket.DefaultDialer.Dial(wsURL, header)
if err != nil {
log.Fatalln("ws connect: ", err)
}
log.Println("start", args)
return &websocketStream{conn: conn}
}

// Recv implements Stream.
func (s *websocketStream) Recv() (*stream.Response, error) {
_, r, err := s.conn.ReadMessage()
if err != nil {
return nil, err
}
if len(r) == 0 {
return nil, io.ErrUnexpectedEOF
}
resp := new(stream.Response)
switch r[0] {
case 1:
resp.Response = new(model.Response)
if err := json.Unmarshal(r[1:], resp.Response); err != nil {
return nil, err
}
case 2:
if len(r) < 2 {
return nil, io.ErrUnexpectedEOF
}
resp.Output = new(stream.OutputResponse)
resp.Output.Index = int(r[1]>>4) & 0xf
resp.Output.Fd = int(r[1]) & 0xf
resp.Output.Content = r[2:]
default:
return nil, fmt.Errorf("invalid type code: %d", r[0])
}
return resp, nil
}

// Send implements Stream.
func (s *websocketStream) Send(req *stream.Request) error {
w, err := s.conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return err
}
defer w.Close()

switch {
case req.Request != nil:
if _, err := w.Write([]byte{1}); err != nil {
return err
}
if err := json.NewEncoder(w).Encode(req.Request); err != nil {
return err
}
case req.Resize != nil:
if _, err := w.Write([]byte{2}); err != nil {
return err
}
if err := json.NewEncoder(w).Encode(req.Resize); err != nil {
return err
}
case req.Input != nil:
if _, err := w.Write([]byte{3, byte(req.Input.Index<<4 | req.Input.Fd)}); err != nil {
return err
}
if _, err := w.Write(req.Input.Content); err != nil {
return err
}
case req.Cancel != nil:
if _, err := w.Write([]byte{4}); err != nil {
return err
}
default:
return fmt.Errorf("invalid request")
}
return nil
}
2 changes: 2 additions & 0 deletions cmd/go-judge/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ func generateHandleVersion(conf *config.Config, builderParam map[string]any) fun
"pipeProxy": true,
"symlink": true,
"addressSpaceLimit": true,
"stream": true,
})
}
}
Expand All @@ -521,6 +522,7 @@ func generateHandleConfig(conf *config.Config, builderParam map[string]any) func
"pipeProxy": true,
"symlink": true,
"addressSpaceLimit": true,
"stream": true,
"fileStorePath": conf.Dir,
"runnerConfig": builderParam,
})
Expand Down
3 changes: 3 additions & 0 deletions cmd/go-judge/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ import (
"github.com/criyle/go-judge/worker"
)

// FileError defines the location, file name and the detailed message for a failed file operation
type FileError = envexec.FileError

// FileErrorType defines the location that file operation fails
type FileErrorType = envexec.FileErrorType

// CmdFile defines file from multiple source including local / memory / cached or pipe collector
Expand Down
28 changes: 13 additions & 15 deletions cmd/go-judge/stream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ type Response struct {

// ResizeRequest defines resize operation to the virtual terminal
type ResizeRequest struct {
Index int
Fd int
Rows int
Cols int
X int
Y int
Index int `json:"index,omitempty"`
Fd int `json:"fd,omitempty"`
Rows int `json:"rows,omitempty"`
Cols int `json:"cols,omitempty"`
X int `json:"x,omitempty"`
Y int `json:"y,omitempty"`
}

// InputRequest defines input operation from the remote
Expand Down Expand Up @@ -97,15 +97,13 @@ func Start(baseCtx context.Context, s Stream, w worker.Worker, srcPrefix []strin
defer cancel()

// stream in
if len(streamIn) > 0 {
wg.Go(func() error {
if err := streamInput(ctx, s, streamIn, execCancel); err != nil {
cancel()
return err
}
return nil
})
}
wg.Go(func() error {
if err := streamInput(ctx, s, streamIn, execCancel); err != nil {
cancel()
return err
}
return nil
})

// stream out
outCh := make(chan *OutputResponse, len(streamOut))
Expand Down
120 changes: 120 additions & 0 deletions cmd/go-judge/ws_executor/stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package wsexecutor

import (
"context"
"encoding/json"
"fmt"
"io"
"time"

"github.com/criyle/go-judge/cmd/go-judge/model"
"github.com/criyle/go-judge/cmd/go-judge/stream"
"github.com/gorilla/websocket"
)

var _ stream.Stream = &streamWrapper{}

type streamWrapper struct {
ctx context.Context
conn *websocket.Conn
sendCh chan stream.Response
}

func (w *streamWrapper) sendLoop() {
conn := w.conn
defer conn.Close()

ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case r := <-w.sendCh:
conn.SetWriteDeadline(time.Now().Add(writeWait))
switch {
case r.Response != nil:
w, err := conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return
}
if _, err := w.Write([]byte{1}); err != nil {
return
}
if err := json.NewEncoder(w).Encode(r.Response); err != nil {
return
}
if err := w.Close(); err != nil {
return
}
conn.Close()
return
case r.Output != nil:
w, err := conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return
}
if _, err := w.Write([]byte{2, byte(r.Output.Index<<4 | r.Output.Fd)}); err != nil {
return
}
if _, err := w.Write(r.Output.Content); err != nil {
return
}
if err := w.Close(); err != nil {
return
}
}
case <-ticker.C:
conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}

func (w *streamWrapper) Send(resp stream.Response) error {
select {
case <-w.ctx.Done():
return w.ctx.Err()
case w.sendCh <- resp:
return nil
}
}

func (w *streamWrapper) Recv() (*stream.Request, error) {
conn := w.conn
_, buf, err := conn.ReadMessage()
if err != nil {
return nil, err
}
if len(buf) == 0 {
return nil, io.ErrUnexpectedEOF
}
var req stream.Request
switch buf[0] {
case 1:
req.Request = new(model.Request)
if err := json.Unmarshal(buf[1:], req.Request); err != nil {
return nil, err
}
case 2:
req.Resize = new(stream.ResizeRequest)
if err := json.Unmarshal(buf[1:], req.Resize); err != nil {
return nil, err
}
case 3:
if len(buf) < 2 {
return nil, io.ErrUnexpectedEOF
}
req.Input = new(stream.InputRequest)
req.Input.Index = int(buf[1]>>4) & 0xf
req.Input.Fd = int(buf[1]) & 0xf
req.Input.Content = buf[2:]
case 4:
req.Cancel = new(struct{})
default:
return nil, fmt.Errorf("invalid type code: %d", buf[0])
}
return &req, nil
}
Loading

0 comments on commit aa41950

Please sign in to comment.