Skip to content

Commit

Permalink
chore: refactor ResourceManager interface for multirm (#8847)
Browse files: browse the repository at this point in the history
  • Loading branch information
carolinaecalderon authored Feb 28, 2024
1 parent 9f40603 commit 6857ecf
Show file tree
Hide file tree
Showing 41 changed files with 421 additions and 461 deletions.
16 changes: 8 additions & 8 deletions master/internal/api_agents.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
func (a *apiServer) GetAgents(
ctx context.Context, req *apiv1.GetAgentsRequest,
) (*apiv1.GetAgentsResponse, error) {
resp, err := a.m.rm.GetAgents(req)
resp, err := a.m.rm.GetAgents()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -53,7 +53,7 @@ func (a *apiServer) GetAgent(
return nil, err
}

resp, err := a.m.rm.GetAgent(req)
resp, err := a.m.rm.GetAgent(req.ResourceManager, req)
if err != nil {
return nil, err
}
Expand All @@ -78,7 +78,7 @@ func (a *apiServer) GetSlots(
return nil, err
}

resp, err := a.m.rm.GetSlots(req)
resp, err := a.m.rm.GetSlots(req.ResourceManager, req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -106,7 +106,7 @@ func (a *apiServer) GetSlot(
return nil, err
}

resp, err := a.m.rm.GetSlot(req)
resp, err := a.m.rm.GetSlot(req.ResourceManager, req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -144,7 +144,7 @@ func (a *apiServer) EnableAgent(
if err := a.canUpdateAgents(ctx); err != nil {
return nil, err
}
return a.m.rm.EnableAgent(req)
return a.m.rm.EnableAgent(req.ResourceManager, req)
}

func (a *apiServer) DisableAgent(
Expand All @@ -153,7 +153,7 @@ func (a *apiServer) DisableAgent(
if err := a.canUpdateAgents(ctx); err != nil {
return nil, err
}
return a.m.rm.DisableAgent(req)
return a.m.rm.DisableAgent(req.ResourceManager, req)
}

func (a *apiServer) EnableSlot(
Expand All @@ -163,7 +163,7 @@ func (a *apiServer) EnableSlot(
return nil, err
}

resp, err = a.m.rm.EnableSlot(req)
resp, err = a.m.rm.EnableSlot(req.ResourceManager, req)
switch {
case errors.Is(err, rmerrors.ErrNotSupported):
return resp, status.Error(codes.Unimplemented, err.Error())
Expand All @@ -181,7 +181,7 @@ func (a *apiServer) DisableSlot(
return nil, err
}

resp, err = a.m.rm.DisableSlot(req)
resp, err = a.m.rm.DisableSlot(req.ResourceManager, req)
switch {
case errors.Is(err, rmerrors.ErrNotSupported):
return resp, status.Error(codes.Unimplemented, err.Error())
Expand Down
5 changes: 3 additions & 2 deletions master/internal/api_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ func (a *apiServer) getCommandLaunchParams(ctx context.Context, req *protoComman
resources.Slots = 0
}

poolName, launchWarnings, err := a.m.ResolveResources(
managerName, poolName, launchWarnings, err := a.m.ResolveResources(
resources.ResourceManager,
resources.ResourcePool,
resources.Slots,
int(cmdSpec.Metadata.WorkspaceID),
Expand All @@ -107,7 +108,7 @@ func (a *apiServer) getCommandLaunchParams(ctx context.Context, req *protoComman
}

// Get the base TaskSpec.
taskSpec, err := a.m.fillTaskSpec(poolName, agentUserGroup, userModel)
taskSpec, err := a.m.fillTaskSpec(managerName, poolName, agentUserGroup, userModel)
if err != nil {
return nil, nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1907,6 +1907,8 @@ func TestDeleteExperimentsFiltered(t *testing.T) {
errC <- errors.New("something real bad")
return sproto.DeleteJobResponse{Err: errC}
}, nil)
mockRM.On("ResolveResourcePool", mock.Anything, mock.Anything).Return(
mock.Anything, mock.Anything, nil)

api, curUser, ctx := setupAPITest(t, nil, &mockRM)

Expand Down
13 changes: 9 additions & 4 deletions master/internal/api_generic_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,17 @@ func (a *apiServer) getGenericTaskLaunchParameters(
return nil, nil, nil, fmt.Errorf("resource slots must be >= 0")
}
isSingleNode := resources.IsSingleNode != nil && *resources.IsSingleNode
poolName, launchWarnings, err := a.m.ResolveResources(resources.ResourcePool,
managerName, poolName, launchWarnings, err := a.m.ResolveResources(
resources.ResourceManager,
resources.ResourcePool,
resources.Slots,
int(proj.WorkspaceId),
isSingleNode)
if err != nil {
return nil, nil, nil, err
}
// Get the base TaskSpec.
taskSpec, err := a.m.fillTaskSpec(poolName, agentUserGroup, userModel)
taskSpec, err := a.m.fillTaskSpec(managerName, poolName, agentUserGroup, userModel)
if err != nil {
return nil, nil, nil, err
}
Expand All @@ -126,6 +128,7 @@ func (a *apiServer) getGenericTaskLaunchParameters(
// Copy discovered (default) resource pool name and slot count.

fillTaskConfig(resources.Slots, taskSpec, &taskConfig.Environment)
taskConfig.Resources.RawResourceManager = &resources.ResourceManager
taskConfig.Resources.RawResourcePool = &poolName
taskConfig.Resources.RawSlots = &resources.Slots

Expand Down Expand Up @@ -356,8 +359,9 @@ func (a *apiServer) CreateGenericTask(
IsUserVisible: true,
Name: fmt.Sprintf("Generic Task %s", taskID),

SlotsNeeded: *genericTaskSpec.GenericTaskConfig.Resources.Slots(),
ResourcePool: genericTaskSpec.GenericTaskConfig.Resources.ResourcePool(),
SlotsNeeded: *genericTaskSpec.GenericTaskConfig.Resources.Slots(),
ResourceManager: genericTaskSpec.GenericTaskConfig.Resources.ResourceManager(),
ResourcePool: genericTaskSpec.GenericTaskConfig.Resources.ResourcePool(),
FittingRequirements: sproto.FittingRequirements{
SingleAgent: isSingleNode,
},
Expand Down Expand Up @@ -660,6 +664,7 @@ func (a *apiServer) UnpauseGenericTask(
IsUserVisible: true,
Name: fmt.Sprintf("Generic Task %s", resumingTask.TaskID),
SlotsNeeded: *genericTaskSpec.GenericTaskConfig.Resources.Slots(),
ResourceManager: genericTaskSpec.GenericTaskConfig.Resources.ResourceManager(),
ResourcePool: genericTaskSpec.GenericTaskConfig.Resources.ResourcePool(),
FittingRequirements: sproto.FittingRequirements{
SingleAgent: isSingleNode,
Expand Down
2 changes: 1 addition & 1 deletion master/internal/api_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (a *apiServer) GetJobsV2(
func (a *apiServer) GetJobQueueStats(
_ context.Context, req *apiv1.GetJobQueueStatsRequest,
) (*apiv1.GetJobQueueStatsResponse, error) {
resp, err := a.m.rm.GetJobQueueStatsRequest(req)
resp, err := a.m.rm.GetJobQueueStatsRequest(req.ResourceManager, req)
if err != nil {
return nil, err
}
Expand Down
18 changes: 7 additions & 11 deletions master/internal/api_resourcepool.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/rm"
"github.com/determined-ai/determined/master/internal/sproto"
workspaceauth "github.com/determined-ai/determined/master/internal/workspace"
"github.com/determined-ai/determined/master/pkg/set"
"github.com/determined-ai/determined/proto/pkg/apiv1"
Expand Down Expand Up @@ -50,7 +49,7 @@ func (a *apiServer) GetResourcePools(
if err != nil {
return nil, err
}
resp, err := a.m.rm.GetResourcePools(req)
resp, err := a.m.rm.GetResourcePools()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -90,7 +89,7 @@ func (a *apiServer) GetResourcePools(
func (a *apiServer) BindRPToWorkspace(
ctx context.Context, req *apiv1.BindRPToWorkspaceRequest,
) (*apiv1.BindRPToWorkspaceResponse, error) {
err := a.checkIfPoolIsDefault(req.ResourcePoolName)
err := a.checkIfPoolIsDefault(req.ResourceManagerName, req.ResourcePoolName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -121,7 +120,7 @@ func (a *apiServer) BindRPToWorkspace(
func (a *apiServer) OverwriteRPWorkspaceBindings(
ctx context.Context, req *apiv1.OverwriteRPWorkspaceBindingsRequest,
) (*apiv1.OverwriteRPWorkspaceBindingsResponse, error) {
err := a.checkIfPoolIsDefault(req.ResourcePoolName)
err := a.checkIfPoolIsDefault(req.ResourceManagerName, req.ResourcePoolName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -212,16 +211,13 @@ func (a *apiServer) ListWorkspacesBoundToRP(
}, nil
}

func (a *apiServer) checkIfPoolIsDefault(poolName string) error {
defaultComputePool, err := a.m.rm.GetDefaultComputeResourcePool(
sproto.GetDefaultComputeResourcePoolRequest{})
func (a *apiServer) checkIfPoolIsDefault(managerName string, poolName string) error {
defaultComputePool, err := a.m.rm.GetDefaultComputeResourcePool(managerName)
if err != nil {
return err
}

defaultAuxPool, err := a.m.rm.GetDefaultAuxResourcePool(
sproto.GetDefaultAuxResourcePoolRequest{},
)
defaultAuxPool, err := a.m.rm.GetDefaultAuxResourcePool(managerName)
if err != nil {
return err
}
Expand Down Expand Up @@ -254,7 +250,7 @@ func (a *apiServer) canUserModifyWorkspaces(ctx context.Context, ids []int32) er
}

func (a *apiServer) resourcePoolsAsConfigs() ([]config.ResourcePoolConfig, error) {
resp, err := a.m.rm.GetResourcePools(&apiv1.GetResourcePoolsRequest{})
resp, err := a.m.rm.GetResourcePools()
if err != nil {
return nil, err
}
Expand Down
52 changes: 26 additions & 26 deletions master/internal/api_resourcepool_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ func TestPostBindingFails(t *testing.T) {

// TODO (eliu): add some tests for workspaceIDs
// test resource pools on workspaces that do not exist
mockRM.On("GetDefaultComputeResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultComputeResourcePool", mock.Anything).
Return(sproto.GetDefaultComputeResourcePoolResponse{}, nil).Once()
mockRM.On("GetDefaultAuxResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultAuxResourcePool", mock.Anything).
Return(sproto.GetDefaultAuxResourcePoolResponse{}, nil).Once()
_, err := api.BindRPToWorkspace(ctx, &apiv1.BindRPToWorkspaceRequest{
ResourcePoolName: testPoolName,
Expand All @@ -69,11 +69,11 @@ func TestPostBindingFails(t *testing.T) {
require.ErrorContains(t, err, "the following workspaces do not exist: [nonexistent_workspace]")

// test resource pool doesn't exist
mockRM.On("GetDefaultComputeResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultComputeResourcePool", mock.Anything).
Return(sproto.GetDefaultComputeResourcePoolResponse{}, nil).Once()
mockRM.On("GetDefaultAuxResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultAuxResourcePool", mock.Anything).
Return(sproto.GetDefaultAuxResourcePoolResponse{}, nil).Twice()
mockRM.On("GetResourcePools", mock.Anything, mock.Anything).
mockRM.On("GetResourcePools").
Return(&apiv1.GetResourcePoolsResponse{
ResourcePools: []*resourcepoolv1.ResourcePool{},
}, nil).Once()
Expand All @@ -85,7 +85,7 @@ func TestPostBindingFails(t *testing.T) {
require.ErrorContains(t, err, "pool with name testRP doesn't exist")

// test resource pool is a default resource pool
mockRM.On("GetDefaultComputeResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultComputeResourcePool", mock.Anything).
Return(sproto.GetDefaultComputeResourcePoolResponse{PoolName: testPoolName}, nil).Twice()

_, err = api.BindRPToWorkspace(ctx, &apiv1.BindRPToWorkspaceRequest{
Expand All @@ -95,7 +95,7 @@ func TestPostBindingFails(t *testing.T) {

require.ErrorContains(t, err, "default resource pool testRP cannot be bound to any workspace")

mockRM.On("GetDefaultAuxResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultAuxResourcePool", mock.Anything).
Return(sproto.GetDefaultAuxResourcePoolResponse{PoolName: testPoolName}, nil).Once()

_, err = api.BindRPToWorkspace(ctx, &apiv1.BindRPToWorkspaceRequest{
Expand All @@ -106,11 +106,11 @@ func TestPostBindingFails(t *testing.T) {
require.ErrorContains(t, err, "default resource pool testRP cannot be bound to any workspace")

// test no resource pool specified
mockRM.On("GetDefaultComputeResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultComputeResourcePool", mock.Anything).
Return(sproto.GetDefaultComputeResourcePoolResponse{PoolName: testPoolName}, nil).Once()
mockRM.On("GetDefaultAuxResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultAuxResourcePool", mock.Anything).
Return(sproto.GetDefaultAuxResourcePoolResponse{PoolName: testPoolName}, nil).Once()
mockRM.On("GetResourcePools", mock.Anything, mock.Anything).
mockRM.On("GetResourcePools").
Return(&apiv1.GetResourcePoolsResponse{
ResourcePools: []*resourcepoolv1.ResourcePool{{Name: testPoolName}},
}, nil).Once()
Expand All @@ -134,11 +134,11 @@ func TestPostBindingSucceeds(t *testing.T) {
_ = setupWorkspaces(ctx, t, api)

// bind first resource pool
mockRM.On("GetDefaultComputeResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultComputeResourcePool", mock.Anything).
Return(sproto.GetDefaultComputeResourcePoolResponse{}, nil).Twice()
mockRM.On("GetDefaultAuxResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultAuxResourcePool", mock.Anything).
Return(sproto.GetDefaultAuxResourcePoolResponse{}, nil).Twice()
mockRM.On("GetResourcePools", mock.Anything, mock.Anything).
mockRM.On("GetResourcePools").
Return(&apiv1.GetResourcePoolsResponse{
ResourcePools: []*resourcepoolv1.ResourcePool{{Name: testPoolName}},
}, nil).Twice()
Expand Down Expand Up @@ -170,11 +170,11 @@ func TestListWorkspacesBoundToRPFails(t *testing.T) {
_ = setupWorkspaces(ctx, t, api)

// bind first workspace
mockRM.On("GetDefaultComputeResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultComputeResourcePool", mock.Anything).
Return(sproto.GetDefaultComputeResourcePoolResponse{}, nil).Once()
mockRM.On("GetDefaultAuxResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultAuxResourcePool", mock.Anything).
Return(sproto.GetDefaultAuxResourcePoolResponse{}, nil).Once()
mockRM.On("GetResourcePools", mock.Anything, mock.Anything).
mockRM.On("GetResourcePools").
Return(&apiv1.GetResourcePoolsResponse{
ResourcePools: []*resourcepoolv1.ResourcePool{{Name: testPoolName}},
}, nil).Times(3)
Expand Down Expand Up @@ -207,11 +207,11 @@ func TestListWorkspacesBoundToRPSucceeds(t *testing.T) {
workspaceIDs := setupWorkspaces(ctx, t, api)

// test bind resource pool to workspace
mockRM.On("GetDefaultComputeResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultComputeResourcePool", mock.Anything).
Return(sproto.GetDefaultComputeResourcePoolResponse{}, nil).Once()
mockRM.On("GetDefaultAuxResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultAuxResourcePool", mock.Anything).
Return(sproto.GetDefaultAuxResourcePoolResponse{}, nil).Once()
mockRM.On("GetResourcePools", mock.Anything, mock.Anything).
mockRM.On("GetResourcePools").
Return(&apiv1.GetResourcePoolsResponse{
ResourcePools: []*resourcepoolv1.ResourcePool{{Name: testPoolName}},
}, nil).Twice()
Expand All @@ -231,7 +231,7 @@ func TestListWorkspacesBoundToRPSucceeds(t *testing.T) {
require.Equal(t, workspaceIDs[0], resp.WorkspaceIds[0])

// test listing on resource pool that has no bindings
mockRM.On("GetResourcePools", mock.Anything, mock.Anything).
mockRM.On("GetResourcePools").
Return(&apiv1.GetResourcePoolsResponse{
ResourcePools: []*resourcepoolv1.ResourcePool{
{Name: testPoolName}, {Name: testPool2Name},
Expand All @@ -256,11 +256,11 @@ func TestPatchBindingsSucceeds(t *testing.T) {
workspaceIDs := setupWorkspaces(ctx, t, api)

// setup first binding
mockRM.On("GetDefaultComputeResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultComputeResourcePool", mock.Anything).
Return(sproto.GetDefaultComputeResourcePoolResponse{}, nil).Times(4)
mockRM.On("GetDefaultAuxResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultAuxResourcePool", mock.Anything).
Return(sproto.GetDefaultAuxResourcePoolResponse{}, nil).Times(4)
mockRM.On("GetResourcePools", mock.Anything, mock.Anything).
mockRM.On("GetResourcePools").
Return(&apiv1.GetResourcePoolsResponse{
ResourcePools: []*resourcepoolv1.ResourcePool{{Name: testPoolName}},
}, nil).Times(7)
Expand Down Expand Up @@ -328,11 +328,11 @@ func TestDeleteBindingsSucceeds(t *testing.T) {

// TODO: fix all comments
// setup first binding
mockRM.On("GetDefaultComputeResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultComputeResourcePool", mock.Anything).
Return(sproto.GetDefaultComputeResourcePoolResponse{}, nil).Times(1)
mockRM.On("GetDefaultAuxResourcePool", mock.Anything, mock.Anything).
mockRM.On("GetDefaultAuxResourcePool", mock.Anything).
Return(sproto.GetDefaultAuxResourcePoolResponse{}, nil).Times(1)
mockRM.On("GetResourcePools", mock.Anything, mock.Anything).
mockRM.On("GetResourcePools").
Return(&apiv1.GetResourcePoolsResponse{
ResourcePools: []*resourcepoolv1.ResourcePool{{Name: testPoolName}},
}, nil).Times(3)
Expand Down
3 changes: 1 addition & 2 deletions master/internal/api_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
expauth "github.com/determined-ai/determined/master/internal/experiment"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/logpattern"
"github.com/determined-ai/determined/master/internal/sproto"
"github.com/determined-ai/determined/master/internal/task"
"github.com/determined-ai/determined/master/internal/webhooks"
"github.com/determined-ai/determined/master/pkg/model"
Expand Down Expand Up @@ -551,7 +550,7 @@ func (a *apiServer) GetActiveTasksCount(
func (a *apiServer) GetTasks(
ctx context.Context, req *apiv1.GetTasksRequest,
) (resp *apiv1.GetTasksResponse, err error) {
summary, err := a.m.rm.GetAllocationSummaries(sproto.GetAllocationSummaries{})
summary, err := a.m.rm.GetAllocationSummaries()
if err != nil {
return nil, err
}
Expand Down
4 changes: 1 addition & 3 deletions master/internal/api_tasks_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ func TestGetTasksAuthZ(t *testing.T) {
var allocations map[model.AllocationID]sproto.AllocationSummary

mockRM := MockRM()
mockRM.On("GetAllocationSummaries", mock.Anything).Return(func(
_ sproto.GetAllocationSummaries,
) map[model.AllocationID]sproto.AllocationSummary {
mockRM.On("GetAllocationSummaries").Return(func() map[model.AllocationID]sproto.AllocationSummary {
return allocations
}, nil)

Expand Down
Loading

0 comments on commit 6857ecf

Please sign in to comment.