Skip to content

Commit

Permalink
feat: streaming update code generation for typescript (#8988)
Browse files Browse the repository at this point in the history
  • Loading branch information
gt2345 authored Mar 20, 2024
1 parent 39afa3c commit 0518785
Show file tree
Hide file tree
Showing 13 changed files with 317 additions and 96 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
proto/pkg/**/* -diff -merge linguist-generated=true
master/pkg/schemas/expconf/zgen_* -diff -merge linguist-generated=true
webui/react/src/services/api-ts-sdk/**/* -diff -merge linguist-generated=true
webui/react/src/services/stream/wire.ts -diff -merge linguist-generated=true
harness/determined/common/api/bindings.py -diff -merge linguist-generated=true
harness/determined/common/streams/wire.py -diff -merge linguist-generated=true
docs/swagger-ui/swagger-ui*js* -diff -merge
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ require (
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect
golang.org/x/sys v0.15.0
golang.org/x/term v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/text v0.14.0
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20211223182754-3ac035c7e7cb
Expand Down
9 changes: 7 additions & 2 deletions master/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ LICENSE := packaging/LICENSE
SCHEMA_INPUTS = ../schemas/gen.py $(shell find ./pkg/schemas/ -name 'zgen_*.go' -prune -o -name '*.go' -print) $(shell find ../schemas/expconf -name '*.json')
STREAM_INPUTS = $(shell find ./internal/stream/ -name '*_test.go' -prune -o -name '*.go' -print)
STREAM_PYTHON_CLIENT = ../harness/determined/common/streams/wire.py
STREAM_TS_CLIENT = ../webui/react/src/services/stream/wire.ts
MOCK_INPUTS = ./internal/sproto/task.go ./internal/db/database.go ./internal/command/authz_iface.go ../go.mod ../go.sum ./internal/rm/resource_manager_iface.go ./internal/task/allocation_service_iface.go
GORELEASER = goreleaser

Expand Down Expand Up @@ -45,12 +46,13 @@ ungen:
rm -f `find ./internal/mocks -name '*.go'` build/mock_gen.stamp

.PHONY: gen
gen: $(LICENSE) build/schema_gen.stamp $(STREAM_PYTHON_CLIENT)
gen: $(LICENSE) build/schema_gen.stamp $(STREAM_PYTHON_CLIENT) $(STREAM_TS_CLIENT)

.PHONY: force-gen
force-gen:
rm -f build/schema_gen.stamp
rm -f $(STREAM_PYTHON_CLIENT)
rm -f $(STREAM_TS_CLIENT)

build/schema_gen.stamp: $(SCHEMA_INPUTS)
go generate ./pkg/schemas/...
Expand All @@ -60,8 +62,11 @@ build/schema_gen.stamp: $(SCHEMA_INPUTS)
$(STREAM_PYTHON_CLIENT): build/stream-gen $(STREAM_INPUTS)
build/stream-gen $(STREAM_INPUTS) --python --output $@

$(STREAM_TS_CLIENT): build/stream-gen $(STREAM_INPUTS)
build/stream-gen $(STREAM_INPUTS) --ts --output $@

.PHONY: stream-gen
stream-gen: $(STREAM_PYTHON_CLIENT)
stream-gen: $(STREAM_PYTHON_CLIENT) $(STREAM_TS_CLIENT)

.PHONY: mocks
mocks: build/mock_gen.stamp
Expand Down
237 changes: 192 additions & 45 deletions master/cmd/stream-gen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,36 @@ import (
"strings"

"github.com/pkg/errors"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)

const keystr = "determined:stream-gen"

type streamType string

const (
json streamType = "JSONB"
text streamType = "string"
integer streamType = "int"
integer64 streamType = "int64"
intArr streamType = "[]int"
boolean streamType = "bool"
time streamType = "time.Time"
timePtr streamType = "*time.Time"
taskID streamType = "model.TaskID"
requestID streamType = "model.RequestID"
requestIDPtr streamType = "*model.RequestID"
workspaceState streamType = "model.WorkspaceState"
)

const (
server = "server"
client = "client"
python = "python"
typescript = "typescript"
)

// Streamable represents the struct under a determined:stream-gen comment.
type Streamable struct {
Name string
Expand All @@ -28,7 +54,7 @@ type Streamable struct {
// Field is a member of a Streamable.
type Field struct {
Name string
Type string
Type streamType
JSONTag string
}

Expand Down Expand Up @@ -188,7 +214,7 @@ func (x *StreamableFinder) Visit(node ast.Node) ast.Visitor {
// Get the string representing this type. We use the string because the ast
// representation of the type is a PITA to work with.
typestr := string(x.src[field.Type.Pos()-1 : field.Type.End()-1])
result.Fields = append(result.Fields, Field{field.Names[0].String(), typestr, v})
result.Fields = append(result.Fields, Field{field.Names[0].String(), streamType(typestr), v})
}
}

Expand Down Expand Up @@ -238,23 +264,123 @@ func (b *Builder) String() string {
return b.builder.String()
}

func genTypescript(streamables []Streamable) ([]byte, error) {
b := Builder{}
typeAnno := func(f Field) ([2]string, error) {
x := map[streamType]([2]string){
json: {"any", "{}"},
text: {"string", ""},
boolean: {"bool", "false"},
integer: {"number", "0"},
integer64: {"number", "0"},
intArr: {"Array<number>", "[]"},
time: {"string", ""},
timePtr: {"string | undefined", "undefined"},
taskID: {"string", ""},
requestID: {"number", "0"},
requestIDPtr: {"number | undefined", "undefined"},
workspaceState: {"types.WorkspaceState", "types.WorkspaceState.Unspecified"},
}
out, ok := x[f.Type]
if !ok {
return [2]string{"", ""}, fmt.Errorf("no type annotation matches %q", f.Type)
}
return out, nil
}
b.Writef("// Code generated by stream-gen. DO NOT EDIT.\n")
b.Writef("\n")
b.Writef("import { isEqual } from 'lodash';\n")
b.Writef("\n")
b.Writef("import { Streamable, StreamSpec } from '.';\n")
b.Writef("\n")
typesImported := false
for _, s := range streamables {
source := s.Args["source"]
entity := strings.ToLower(strings.TrimSuffix(s.Name, "SubscriptionSpec"))
caser := cases.Title(language.English)

switch source {
case server:
continue
case client:
for _, f := range s.Fields {
anno, err := typeAnno(f)
if err != nil {
return nil, fmt.Errorf("struct %v, field %v: %v", s.Name, f.Name, err)
}
if strings.Contains(anno[0], "types") && !typesImported {
typesImported = true
b.Writef("import * as types from 'types';\n\n")
}
}
b.Writef("export class %vSpec extends StreamSpec {\n", caser.String(entity))
b.Writef(" readonly #id: Streamable = '%vs';\n", entity)
for _, f := range s.Fields {
anno, _ := typeAnno(f)
b.Writef(" #%v: %v;\n", f.JSONTag, anno[0])
}
b.Writef("\n")
b.Writef(" constructor(\n")
for _, f := range s.Fields {
anno, _ := typeAnno(f)
b.Writef(" %v?: %v,\n", f.JSONTag, anno[0])
}
b.Writef(" ) {\n")
b.Writef(" super();\n")
for _, f := range s.Fields {
anno, _ := typeAnno(f)
b.Writef(" this.#%v = %v || %v;\n", f.JSONTag, f.JSONTag, anno[1])
}
b.Writef(" }\n")
b.Writef("\n")
b.Writef(" public equals = (sp?: StreamSpec): boolean => {\n")
b.Writef(" if (!sp) return false;\n")
b.Writef(" if (sp instanceof %vSpec) {\n", caser.String(entity))
b.Writef(" return (\n")
for i, f := range s.Fields {
if i > 0 {
b.Writef(" &&\n")
}
b.Writef(" isEqual(sp.#%v, this.#%v)\n", f.JSONTag, f.JSONTag)
}
b.Writef(" );\n")
b.Writef(" }\n")
b.Writef(" return false;\n")
b.Writef(" };\n")
b.Writef("\n")
b.Writef(" public id = (): Streamable => {\n")
b.Writef(" return this.#id;\n")
b.Writef(" };\n")
b.Writef("\n")
b.Writef(" public toWire = (): Record<string, unknown> => {\n")
b.Writef(" return {\n")
for _, f := range s.Fields {
b.Writef(" %v: this.#%v,\n", f.JSONTag, f.JSONTag)
}
b.Writef(" };\n")
b.Writef(" };\n")
b.Writef("}\n\n")
}
}
return []byte(b.String()), nil
}

func genPython(streamables []Streamable) ([]byte, error) {
b := Builder{}
typeAnno := func(f Field) (string, error) {
x := map[string]string{
"JSONB": "typing.Any",
"string": "str",
"bool": "bool",
"int": "int",
"int64": "int",
"[]int": "typing.List[int]",
"[]string": "typing.List[str]",
"time.Time": "float",
"*time.Time": "typing.Optional[float]",
"model.TaskID": "str",
"model.RequestID": "int",
"*model.RequestID": "typing.Optional[int]",
"model.State": "str",
x := map[streamType]string{
json: "typing.Any",
text: "str",
boolean: "bool",
integer: "int",
integer64: "int",
intArr: "typing.List[int]",
time: "float",
timePtr: "typing.Optional[float]",
taskID: "str",
requestID: "int",
requestIDPtr: "typing.Optional[int]",
workspaceState: "str",
}
out, ok := x[f.Type]
if !ok {
Expand Down Expand Up @@ -320,31 +446,11 @@ func genPython(streamables []Streamable) ([]byte, error) {
b.Writef(" def __eq__(self, other: object) -> bool:\n")
b.Writef(" return isinstance(other, type(self)) and self.to_json() == other.to_json()\n")

allowedArgs := map[string]bool{
"delete_msg": true,
"source": true,
}
requiredArgs := []string{"source"}

for _, s := range streamables {
// verify args
for k, v := range s.Args {
if !allowedArgs[k] {
fmt.Fprintf(os.Stderr, "unrecognized arg %q (%v=%v) @ %v\n", k, k, v, s.Position)
os.Exit(1)
}
}
for _, k := range requiredArgs {
if _, ok := s.Args[k]; !ok {
fmt.Fprintf(os.Stderr, "missing required arg %q @ %v\n", k, s.Position)
os.Exit(1)
}
}

source := s.Args["source"]

switch source {
case "server":
case server:
// Generate a subclass of a ServerMsg, all fields are always filled.
b.Writef("\n\n")
b.Writef("class %v(ServerMsg):\n", s.Name)
Expand All @@ -367,7 +473,7 @@ func genPython(streamables []Streamable) ([]byte, error) {
b.Writef("class %v(DeleteMsg):\n", deleter)
b.Writef(" pass\n")
}
case "client":
case client:
// Generate a subclass of a ClientMsg, all fields are always optional.
b.Writef("\n\n")
b.Writef("class %v(ClientMsg):\n", s.Name)
Expand Down Expand Up @@ -397,18 +503,42 @@ func printHelp(output io.Writer) {
output,
`stream-gen generates bindings for determined streaming updates.
usage: stream-gen IN.GO... --python [--output OUTPUT]
usage: stream-gen IN.GO... --python/ts [--output OUTPUT]
All structs in the input files IN.GO... which contain special 'determined:stream-gen' comments will
be included in the generated output.
Presently the only output language is --python.
Presently the only output languages are python and typescript.
Output will be written to stdout, or a location specified by --output. The OUTPUT will only be
overwritten if it would be modified.
`)
}

func verifyArgs(streamables []Streamable) {
allowedArgs := map[string]bool{
"delete_msg": true,
"source": true,
}
requiredArgs := []string{"source"}

for _, s := range streamables {
// verify args
for k, v := range s.Args {
if !allowedArgs[k] {
fmt.Fprintf(os.Stderr, "unrecognized arg %q (%v=%v) @ %v\n", k, k, v, s.Position)
os.Exit(1)
}
}
for _, k := range requiredArgs {
if _, ok := s.Args[k]; !ok {
fmt.Fprintf(os.Stderr, "missing required arg %q @ %v\n", k, s.Position)
os.Exit(1)
}
}
}
}

func main() {
// Parse commandline options manually because built-in flag library is junk.
if len(os.Args) == 1 {
Expand All @@ -426,7 +556,11 @@ func main() {
os.Exit(0)
}
if arg == "--python" {
lang = "python"
lang = python
continue
}
if arg == "--ts" {
lang = typescript
continue
}
if arg == "-o" || arg == "--output" {
Expand Down Expand Up @@ -460,11 +594,24 @@ func main() {
os.Exit(1)
}

// verify args will exit with code 1 in case of error.
verifyArgs(results)

// generate the language bindings
content, err := genPython(results)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
var content []byte
switch lang {
case python:
content, err = genPython(results)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
case typescript:
content, err = genTypescript(results)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
}

// write to output
Expand Down
18 changes: 9 additions & 9 deletions master/internal/stream/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ type ProjectMsg struct {
ID int `bun:"id,pk" json:"id"`

// mutable attributes
Name string `bun:"name" json:"name"`
Description string `bun:"description" json:"description"`
Archived bool `bun:"archived" json:"archived"`
CreatedAt time.Time `bun:"created_at" json:"created_at"`
Notes JSONB `bun:"notes" json:"notes"`
WorkspaceID int `bun:"workspace_id" json:"workspace_id"`
UserID int `bun:"user_id" json:"user_id"`
Immutable bool `bun:"immutable" json:"immutable"`
State model.State `bun:"state" json:"state"`
Name string `bun:"name" json:"name"`
Description string `bun:"description" json:"description"`
Archived bool `bun:"archived" json:"archived"`
CreatedAt time.Time `bun:"created_at" json:"created_at"`
Notes JSONB `bun:"notes" json:"notes"`
WorkspaceID int `bun:"workspace_id" json:"workspace_id"`
UserID int `bun:"user_id" json:"user_id"`
Immutable bool `bun:"immutable" json:"immutable"`
State model.WorkspaceState `bun:"state" json:"state"`

// metadata
Seq int64 `bun:"seq" json:"seq"`
Expand Down
Loading

0 comments on commit 0518785

Please sign in to comment.