Skip to content

Commit

Permalink
feat: enable token auth for Jupyter notebooks [MD-404] (#9452)
Browse files Browse the repository at this point in the history
Adds notebook sessions table to persist and fetch Jupyter tokens
  • Loading branch information
azhou-determined authored Jun 20, 2024
1 parent ea929fc commit 553521e
Show file tree
Hide file tree
Showing 13 changed files with 271 additions and 35 deletions.
8 changes: 5 additions & 3 deletions e2e_tests/tests/cluster/test_master_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,11 +555,13 @@ def test_master_restart_notebook_k8s(

def _test_master_restart_notebook(managed_cluster: abstract_cluster.Cluster, downtime: int) -> None:
sess = api_utils.user_session()
with cmd.interactive_command(sess, ["notebook", "start", "--detach"]) as notebook:
task_id = notebook.task_id
with cmd.interactive_command(sess, ["notebook", "start", "--detach"]) as notebook_cmd:
task_id = notebook_cmd.task_id
assert task_id is not None
utils.wait_for_task_state(sess, "notebook", task_id, "RUNNING")
notebook_path = f"proxy/{task_id}/"
notebook = bindings.get_GetNotebook(session=sess, notebookId=task_id).notebook
assert notebook and notebook.serviceAddress
notebook_path = notebook.serviceAddress.lstrip("/")
_check_notebook_url(sess, notebook_path)

if downtime >= 0:
Expand Down
7 changes: 6 additions & 1 deletion harness/determined/cli/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def start_notebook(args: argparse.Namespace) -> None:

cli.wait_ntsc_ready(sess, api.NTSC_Kind.notebook, nb.id)

assert nb.serviceAddress is not None, "missing tensorboard serviceAddress"
assert nb.serviceAddress is not None, "missing Jupyter serviceAddress"
nb_path = ntsc.make_interactive_task_url(
task_id=nb.id,
service_address=nb.serviceAddress,
Expand All @@ -62,6 +62,11 @@ def start_notebook(args: argparse.Namespace) -> None:
if not args.no_browser:
webbrowser.open(url)
print(termcolor.colored(f"Jupyter Notebook is running at: {url}", "green"))
print(
termcolor.colored(
f"Connect to remote Jupyter server: " f"{args.master}{nb.serviceAddress}", "blue"
)
)


def open_notebook(args: argparse.Namespace) -> None:
Expand Down
48 changes: 42 additions & 6 deletions master/internal/api_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,35 @@ func redirectToLogin(c echo.Context) error {
// processProxyAuthentication is a middleware processing function that attempts
// to authenticate incoming HTTP requests coming through proxies.
func processProxyAuthentication(c echo.Context) (done bool, err error) {
user, _, err := user.GetService().UserAndSessionFromRequest(c.Request())
taskID := model.TaskID(strings.SplitN(c.Param("service"), ":", 2)[0])

// Notebooks may authenticate with a dedicated token, passed either as a URL query parameter or in an Authorization header.
token := extractNotebookTokenFromRequest(c.Request())
var usr *model.User
var notebookSession *model.NotebookSession

if token != "" {
// Notebooks go through special token param auth.
usr, notebookSession, err = user.GetService().UserAndNotebookSessionFromToken(token)
if err != nil {
return true, err
}
if notebookSession.TaskID != taskID {
return true, fmt.Errorf("invalid notebook session token for task (%v)", taskID)
}
} else {
usr, _, err = user.GetService().UserAndSessionFromRequest(c.Request())
}

if errors.Is(err, db.ErrNotFound) {
return true, redirectToLogin(c)
} else if err != nil {
return true, err
}
if !user.Active {
if !usr.Active {
return true, redirectToLogin(c)
}

taskID := model.TaskID(strings.SplitN(c.Param("service"), ":", 2)[0])
var ctx context.Context

if c.Request() == nil || c.Request().Context() == nil {
Expand All @@ -141,7 +159,7 @@ func processProxyAuthentication(c echo.Context) (done bool, err error) {
return true, fmt.Errorf("error looking up task experiment: %w", err)
}

err = expauth.AuthZProvider.Get().CanGetExperiment(ctx, *user, e)
err = expauth.AuthZProvider.Get().CanGetExperiment(ctx, *usr, e)
return err != nil, authz.SubIfUnauthorized(err, serviceNotFoundErr)
}

Expand All @@ -152,14 +170,32 @@ func processProxyAuthentication(c echo.Context) (done bool, err error) {
// Continue NTSC task checks.
if spec.TaskType == model.TaskTypeTensorboard {
err = command.AuthZProvider.Get().CanGetTensorboard(
ctx, *user, spec.WorkspaceID, spec.ExperimentIDs, spec.TrialIDs)
ctx, *usr, spec.WorkspaceID, spec.ExperimentIDs, spec.TrialIDs)
} else {
err = command.AuthZProvider.Get().CanGetNSC(
ctx, *user, spec.WorkspaceID)
ctx, *usr, spec.WorkspaceID)
}
return err != nil, authz.SubIfUnauthorized(err, serviceNotFoundErr)
}

// extractNotebookTokenFromRequest returns the Jupyter notebook auth token
// carried by the request, or an empty string when none is present. Sources
// are consulted in priority order:
//  1. The "token" query parameter of the request URL.
//  2. An Authorization header of the form "token <value>".
//
// Nothing is rejected here; an empty result only means that no notebook
// token was supplied with the request.
func extractNotebookTokenFromRequest(r *http.Request) string {
	const headerScheme = "token "

	// A non-empty query parameter always wins over the header.
	if queryToken := r.URL.Query().Get("token"); queryToken != "" {
		return queryToken
	}

	// The scheme match is exact (case-sensitive, single trailing space).
	header := r.Header.Get("Authorization")
	if !strings.HasPrefix(header, headerScheme) {
		return ""
	}
	return header[len(headerScheme):]
}

// processAuthWithRedirect is an auth middleware that redirects browser requests
// to login page for a set of given paths in case of authentication errors.
func processAuthWithRedirect(redirectPaths []string) echo.MiddlewareFunc {
Expand Down
7 changes: 3 additions & 4 deletions master/internal/api_notebook.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,9 @@ func (a *apiServer) LaunchNotebook(
}

// Launch a Notebook.
genericCmd, err := command.DefaultCmdService.LaunchGenericCommand(
model.TaskTypeNotebook,
model.JobTypeNotebook,
launchReq)
genericCmd, err := command.DefaultCmdService.LaunchNotebookCommand(
launchReq,
session)
if err != nil {
return nil, err
}
Expand Down
30 changes: 18 additions & 12 deletions master/internal/command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,15 @@ type Command struct {

tasks.GenericCommandSpec

registeredTime time.Time
taskID model.TaskID
taskType model.TaskType
jobType model.JobType
jobID model.JobID
allocationID model.AllocationID
lastState task.AllocationState
exitStatus *task.AllocationExited
restored bool

registeredTime time.Time
taskID model.TaskID
taskType model.TaskType
jobType model.JobType
jobID model.JobID
allocationID model.AllocationID
lastState task.AllocationState
exitStatus *task.AllocationExited
restored bool
contextDirectory []byte // Don't rely on this being set outside of the PreStart non-restore case.

logCtx logger.Context
Expand Down Expand Up @@ -237,6 +236,12 @@ func (c *Command) OnExit(ae *task.AllocationExited) {
c.syslog.WithError(err).Errorf(
"failure to delete user session for task: %v", c.taskID)
}
if c.TaskType == model.TaskTypeNotebook {
if err := internaldb.DeleteNotebookSessionByTask(context.TODO(), c.taskID); err != nil {
c.syslog.WithError(err).Errorf(
"failure to delete notebook session for task: %v", c.taskID)
}
}

go func() {
time.Sleep(terminatedDuration)
Expand Down Expand Up @@ -323,14 +328,15 @@ func (c *Command) ToV1Command() *commandv1.Command {
func (c *Command) ToV1Notebook() *notebookv1.Notebook {
c.mu.Lock()
defer c.mu.Unlock()

allo := c.refreshAllocationState()
notebookToken := c.Base.ExtraEnvVars[model.NotebookSessionEnvVar]
notebookAddress := fmt.Sprintf("%s?token=%s", c.serviceAddress(), notebookToken)
return &notebookv1.Notebook{
Id: c.stringID(),
State: enrichState(allo.State),
Description: c.Config.Description,
Container: allo.SingleContainer().ToProto(),
ServiceAddress: c.serviceAddress(),
ServiceAddress: notebookAddress,
StartTime: protoutils.ToTimestamp(c.registeredTime),
Username: c.Base.Owner.Username,
UserId: int32(c.Base.Owner.ID),
Expand Down
54 changes: 53 additions & 1 deletion master/internal/command/command_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ func (cs *CommandService) LaunchGenericCommand(
) (*Command, error) {
cs.mu.Lock()
defer cs.mu.Unlock()

taskID := model.NewTaskID()
jobID := model.NewJobID()
req.Spec.CommandID = string(taskID)
Expand Down Expand Up @@ -131,6 +130,59 @@ func (cs *CommandService) LaunchGenericCommand(
return cmd, nil
}

// LaunchNotebookCommand creates a notebook command, starts it, persists its
// notebook session to the database, and adds it to the command registry.
//
// A notebook session token is generated from the caller's user session and
// the new task ID. The token is handed to the task through the
// model.NotebookSessionEnvVar entry of the spec's extra environment variables
// and is then stored with db.StartNotebookSession.
//
// NOTE(review): when persisting the session fails after cmd.Start succeeded,
// the error is returned without registering the command — confirm whether the
// already-started command needs cleanup on that path.
func (cs *CommandService) LaunchNotebookCommand(
	req *CreateGeneric,
	session *model.UserSession,
) (*Command, error) {
	cs.mu.Lock()
	defer cs.mu.Unlock()

	newTaskID := model.NewTaskID()
	newJobID := model.NewJobID()

	// Stamp the request spec with the identity of the task being launched.
	spec := req.Spec
	spec.CommandID = string(newTaskID)
	spec.TaskType = model.TaskTypeNotebook

	notebookToken, err := db.GenerateNotebookSessionToken(session.ID, newTaskID)
	if err != nil {
		return nil, err
	}
	// The task receives its token through its environment.
	spec.Base.ExtraEnvVars[model.NotebookSessionEnvVar] = notebookToken

	logCtx := logger.Context{
		"job-id":    newJobID,
		"task-id":   newTaskID,
		"task-type": model.TaskTypeNotebook,
	}
	componentLog := logrus.WithFields(logrus.Fields{"component": "command"})

	cmd := &Command{
		db: cs.db,
		rm: cs.rm,

		GenericCommandSpec: *spec,

		taskID:           newTaskID,
		taskType:         model.TaskTypeNotebook,
		jobType:          model.JobTypeNotebook,
		jobID:            newJobID,
		contextDirectory: req.ContextDirectory,
		logCtx:           logCtx,
		syslog:           componentLog.WithFields(logCtx.Fields()),
	}

	if err = cmd.Start(context.TODO()); err != nil {
		return nil, err
	}

	err = db.StartNotebookSession(context.TODO(), session.ID, newTaskID, &notebookToken)
	if err != nil {
		return nil, err
	}

	// Only a started command with a persisted session enters the registry.
	cs.commands[cmd.taskID] = cmd

	return cmd, nil
}

func (cs *CommandService) unregisterCommand(id model.TaskID) {
cs.mu.Lock()
defer cs.mu.Unlock()
Expand Down
63 changes: 63 additions & 0 deletions master/internal/db/postgres_notebook_sessions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package db

import (
"context"
"fmt"

"github.com/o1egl/paseto"

"github.com/determined-ai/determined/master/pkg/model"
)

// StartNotebookSession inserts a notebook session row that ties the given
// task to the given user session, together with the session token.
//
// The generated row ID is read back through a RETURNING clause into the local
// model; it is not handed back to the caller. Any insert failure is wrapped
// with the task ID for context.
func StartNotebookSession(
	ctx context.Context,
	userSessionID model.SessionID,
	taskID model.TaskID,
	token *string,
) error {
	row := model.NotebookSession{
		UserSessionID: userSessionID,
		TaskID:        taskID,
		Token:         token,
	}

	insert := Bun().NewInsert().Model(&row).Returning("id")
	_, err := insert.Exec(ctx, &row.ID)
	if err != nil {
		return fmt.Errorf("failed to create notebook session for task (%s): %w", taskID, err)
	}
	return nil
}

// GenerateNotebookSessionToken builds a signed PASETO v2 token whose payload
// is a NotebookSession carrying only the user session ID and the task ID.
// The token is signed with the private key from GetTokenKeys.
//
// NOTE(review): the payload set here contains no expiry of its own — confirm
// that token lifetime is meant to follow the referenced user session.
func GenerateNotebookSessionToken(
	userSessionID model.SessionID,
	taskID model.TaskID,
) (string, error) {
	claims := model.NotebookSession{
		UserSessionID: userSessionID,
		TaskID:        taskID,
	}

	signed, err := paseto.NewV2().Sign(GetTokenKeys().PrivateKey, &claims, nil)
	if err != nil {
		return "", fmt.Errorf("failed to generate task authentication token: %w", err)
	}
	return signed, nil
}

// DeleteNotebookSessionByTask removes every row of the notebook_sessions
// table whose task_id matches the given task. A failed delete is wrapped with
// the task ID for context.
func DeleteNotebookSessionByTask(
	ctx context.Context,
	taskID model.TaskID,
) error {
	del := Bun().NewDelete().
		Table("notebook_sessions").
		Where("task_id = ?", taskID)

	_, err := del.Exec(ctx)
	if err != nil {
		return fmt.Errorf("failed to delete notebook session for task (%s): %w", taskID, err)
	}
	return nil
}
37 changes: 37 additions & 0 deletions master/internal/user/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"sync"
"time"

"github.com/o1egl/paseto"

log "github.com/sirupsen/logrus"

"github.com/labstack/echo/v4"
Expand Down Expand Up @@ -158,6 +160,41 @@ func (s *Service) UserAndSessionFromRequest(
return ByToken(context.TODO(), token, s.extConfig)
}

// UserAndNotebookSessionFromToken resolves a notebook session token to the
// user that owns it and the notebook session decoded from the token.
//
// The token is verified against the public key from db.GetTokenKeys; a token
// that fails verification yields db.ErrNotFound. The user session referenced
// by the token is then loaded, and an expired one also yields db.ErrNotFound.
// Errors from either database lookup are returned unchanged.
func (s *Service) UserAndNotebookSessionFromToken(
	token string,
) (*model.User, *model.NotebookSession, error) {
	// Check the signature and decode the payload in one step.
	claims := new(model.NotebookSession)
	verifyErr := paseto.NewV2().Verify(token, db.GetTokenKeys().PublicKey, claims, nil)
	if verifyErr != nil {
		return nil, nil, db.ErrNotFound
	}

	ctx := context.TODO()

	// Load the user session the token was issued for.
	userSession := model.UserSession{}
	err := db.Bun().NewSelect().
		Model(&userSession).
		Where("id = ?", claims.UserSessionID).
		Scan(ctx)
	if err != nil {
		return nil, nil, err
	}

	// An expired user session invalidates the notebook token as well.
	if userSession.Expiry.Before(time.Now()) {
		return nil, nil, db.ErrNotFound
	}

	// Resolve the owning user through the session row.
	owner := new(model.User)
	err = db.Bun().NewSelect().
		Table("users").
		ColumnExpr("users.*").
		Join("JOIN user_sessions ON user_sessions.user_id = users.id").
		Where("user_sessions.id = ?", userSession.ID).
		Scan(ctx, owner)
	if err != nil {
		return nil, nil, err
	}

	return owner, claims, nil
}

// getAuthLevel returns what level of authentication a request needs.
func (s *Service) getAuthLevel(c echo.Context) int {
switch {
Expand Down
15 changes: 15 additions & 0 deletions master/pkg/model/notebook_session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package model

import "github.com/uptrace/bun"

// NotebookSession corresponds to a row in the "notebook_sessions" DB table.
// It links a notebook task to the user session it was launched from and holds
// the token associated with that notebook.
type NotebookSession struct {
bun.BaseModel `bun:"table:notebook_sessions"`
// ID is the auto-incrementing primary key of the row.
ID SessionID `db:"id" bun:"id,pk,autoincrement" json:"id"`
// TaskID identifies the notebook task this session belongs to.
TaskID TaskID `db:"task_id" bun:"task_id" json:"task_id"`
// UserSessionID references the user session tied to this notebook session.
UserSessionID SessionID `db:"user_session_id" bun:"user_session_id" json:"user_session_id"`
// Token is the notebook session token; it is a pointer, so it may be unset.
Token *string `db:"token" bun:"token" json:"token"`
}

// NotebookSessionEnvVar is the name of the environment variable that carries
// the notebook session token for a notebook task.
const NotebookSessionEnvVar = "DET_NOTEBOOK_TOKEN"
Loading

0 comments on commit 553521e

Please sign in to comment.