Skip to content

Commit

Permalink
refactor: DET-9976 remove agentID type from agentrm (#9040)
Browse files Browse the repository at this point in the history
  • Loading branch information
jesse-amano-hpe authored Mar 26, 2024
1 parent 0710c58 commit 1202d5c
Show file tree
Hide file tree
Showing 17 changed files with 78 additions and 73 deletions.
18 changes: 9 additions & 9 deletions master/internal/rm/agentrm/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type (

mu sync.Mutex

id agentID
id aproto.ID
registeredTime time.Time
address string
agentUpdates *queue.Queue[agentUpdatedEvent]
Expand Down Expand Up @@ -115,7 +115,7 @@ type (
)

func newAgent(
id agentID,
id aproto.ID,
agentUpdates *queue.Queue[agentUpdatedEvent],
resourcePoolName string,
rpConfig *config.ResourcePoolConfig,
Expand Down Expand Up @@ -150,7 +150,7 @@ func newAgent(
// agentReconnectWait if it never reconnects.
// Ensure RP is aware of the agent.
a.syslog.Infof("adding agent: %s", a.agentState.agentID())
err := a.updateAgentStartStats(a.resourcePoolName, string(a.id), a.agentState.numSlots())
err := a.updateAgentStartStats(a.resourcePoolName, a.id, a.agentState.numSlots())
if err != nil {
a.syslog.WithError(err).Error("failed to update agent start stats")
}
Expand Down Expand Up @@ -298,7 +298,7 @@ func (a *agent) stop(cause error) {
}

a.syslog.Infof("removing agent: %s", a.id)
err := a.updateAgentEndStats(string(a.id))
err := a.updateAgentEndStats(a.id)
if err != nil {
a.syslog.WithError(err).Error("failed to update agent end stats")
}
Expand Down Expand Up @@ -680,7 +680,7 @@ func (a *agent) agentStarted(agentStarted *aproto.AgentStarted) {
a.agentState.agentStarted(agentStarted)

a.syslog.Infof("adding agent: %s", a.agentState.agentID())
err := a.updateAgentStartStats(a.resourcePoolName, string(a.id), a.agentState.numSlots())
err := a.updateAgentStartStats(a.resourcePoolName, a.id, a.agentState.numSlots())
if err != nil {
a.syslog.WithError(err).Error("failed to update agent start stats")
}
Expand Down Expand Up @@ -915,17 +915,17 @@ func (a *agent) notifyListeners() {
}

func (a *agent) updateAgentStartStats(
poolName string, agentID string, slots int,
poolName string, agentID aproto.ID, slots int,
) error {
return db.SingleDB().RecordAgentStats(&model.AgentStats{
ResourcePool: poolName,
AgentID: agentID,
AgentID: string(agentID),
Slots: slots,
})
}

func (a *agent) updateAgentEndStats(agentID string) error {
func (a *agent) updateAgentEndStats(agentID aproto.ID) error {
return db.EndAgentStats(&model.AgentStats{
AgentID: agentID,
AgentID: string(agentID),
})
}
5 changes: 1 addition & 4 deletions master/internal/rm/agentrm/agent_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@ type slotData struct {
ContainerID *cproto.ID
}

// agentID is the agent id type.
type agentID string

// agentSnapshot is a database representation of `agentState`.
type agentSnapshot struct {
bun.BaseModel `bun:"table:resourcemanagers_agent_agentstate,alias:rmas"`

ID int64 `bun:"id,pk,autoincrement"`
AgentID agentID `bun:"agent_id,notnull,unique"`
AgentID aproto.ID `bun:"agent_id,notnull,unique"`
UUID string `bun:"uuid,notnull,unique"`
ResourcePoolName string `bun:"resource_pool_name,notnull"`
Label string `bun:"label"`
Expand Down
16 changes: 8 additions & 8 deletions master/internal/rm/agentrm/agent_resource_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (*ResourceManager) DeleteJob(sproto.DeleteJob) (sproto.DeleteJobResponse, e

// DisableAgent implements rm.ResourceManager.
func (a *ResourceManager) DisableAgent(msg *apiv1.DisableAgentRequest) (*apiv1.DisableAgentResponse, error) {
agent, ok := a.agentService.get(agentID(msg.AgentId))
agent, ok := a.agentService.get(aproto.ID(msg.AgentId))
if !ok {
return nil, api.NotFoundErrs("agent", msg.AgentId, true)
}
Expand All @@ -155,7 +155,7 @@ func (a *ResourceManager) DisableSlot(req *apiv1.DisableSlotRequest) (*apiv1.Dis
deviceID := device.ID(deviceIDStr)

enabled := false
result, err := a.handlePatchSlotState(agentID(req.AgentId), patchSlotState{
result, err := a.handlePatchSlotState(aproto.ID(req.AgentId), patchSlotState{
id: deviceID,
enabled: &enabled,
drain: &req.Drain,
Expand All @@ -168,7 +168,7 @@ func (a *ResourceManager) DisableSlot(req *apiv1.DisableSlotRequest) (*apiv1.Dis

// EnableAgent implements rm.ResourceManager.
func (a *ResourceManager) EnableAgent(msg *apiv1.EnableAgentRequest) (*apiv1.EnableAgentResponse, error) {
agent, ok := a.agentService.get(agentID(msg.AgentId))
agent, ok := a.agentService.get(aproto.ID(msg.AgentId))
if !ok {
return nil, api.NotFoundErrs("agent", msg.AgentId, true)
}
Expand All @@ -184,15 +184,15 @@ func (a *ResourceManager) EnableSlot(req *apiv1.EnableSlotRequest) (*apiv1.Enabl
deviceID := device.ID(deviceIDStr)

enabled := true
result, err := a.handlePatchSlotState(agentID(req.AgentId), patchSlotState{id: deviceID, enabled: &enabled})
result, err := a.handlePatchSlotState(aproto.ID(req.AgentId), patchSlotState{id: deviceID, enabled: &enabled})
if err != nil {
return nil, err
}
return &apiv1.EnableSlotResponse{Slot: result.ToProto()}, nil
}

func (a *ResourceManager) handlePatchSlotState(
agentID agentID, msg patchSlotState,
agentID aproto.ID, msg patchSlotState,
) (*model.SlotSummary, error) {
agent, ok := a.agentService.get(agentID)
if !ok {
Expand Down Expand Up @@ -225,7 +225,7 @@ func (*ResourceManager) ExternalPreemptionPending(sproto.PendingPreemption) erro

// GetAgent implements rm.ResourceManager.
func (a *ResourceManager) GetAgent(msg *apiv1.GetAgentRequest) (*apiv1.GetAgentResponse, error) {
agent, ok := a.agentService.get(agentID(msg.AgentId))
agent, ok := a.agentService.get(aproto.ID(msg.AgentId))
if !ok {
return nil, api.NotFoundErrs("agent", msg.AgentId, true)
}
Expand Down Expand Up @@ -343,7 +343,7 @@ func (a *ResourceManager) GetSlot(req *apiv1.GetSlotRequest) (*apiv1.GetSlotResp
}
deviceID := device.ID(deviceIDStr)

result, err := a.handlePatchSlotState(agentID(req.AgentId), patchSlotState{id: deviceID})
result, err := a.handlePatchSlotState(aproto.ID(req.AgentId), patchSlotState{id: deviceID})
if err != nil {
return nil, err
}
Expand All @@ -352,7 +352,7 @@ func (a *ResourceManager) GetSlot(req *apiv1.GetSlotRequest) (*apiv1.GetSlotResp

// GetSlots implements rm.ResourceManager.
func (a *ResourceManager) GetSlots(msg *apiv1.GetSlotsRequest) (*apiv1.GetSlotsResponse, error) {
agent, ok := a.agentService.get(agentID(msg.AgentId))
agent, ok := a.agentService.get(aproto.ID(msg.AgentId))
if !ok {
return nil, api.NotFoundErrs("agent", msg.AgentId, true)
}
Expand Down
12 changes: 6 additions & 6 deletions master/internal/rm/agentrm/agent_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type slot struct {
type agentState struct {
syslog *log.Entry

id agentID // TODO(DET-9976): Why agentID and aproto.ID? Let's just have one or the other.
id aproto.ID
handler *agent
Devices map[device.Device]*cproto.ID
resourcePoolName string
Expand All @@ -59,7 +59,7 @@ type agentState struct {

// newAgentState returns a new, empty agent state backed by the handler.
// TODO(DET-9977): It is error-prone that we can new up an agentState that is invalid / would cause panics.
func newAgentState(id agentID, maxZeroSlotContainers int) *agentState {
func newAgentState(id aproto.ID, maxZeroSlotContainers int) *agentState {
return &agentState{
syslog: log.WithField("component", "agent-state-state").WithField("id", id),
id: id,
Expand All @@ -77,7 +77,7 @@ func (a *agentState) string() string {
return string(a.id)
}

func (a *agentState) agentID() agentID {
func (a *agentState) agentID() aproto.ID {
return a.id
}

Expand Down Expand Up @@ -553,13 +553,13 @@ func (a *agentState) clearUnlessRecovered(

// retrieveAgentStates reconstructs AgentStates from the database for all resource pools that
// have agent_container_reattachment enabled.
func retrieveAgentStates() (map[agentID]agentState, error) {
func retrieveAgentStates() (map[aproto.ID]agentState, error) {
var snapshots []agentSnapshot
if err := db.Bun().NewSelect().Model(&snapshots).Scan(context.TODO()); err != nil {
return nil, fmt.Errorf("selecting agent snapshost: %w", err)
}

result := make(map[agentID]agentState, len(snapshots))
result := make(map[aproto.ID]agentState, len(snapshots))
for _, s := range snapshots {
state, err := newAgentStateFromSnapshot(s)
if err != nil {
Expand Down Expand Up @@ -644,7 +644,7 @@ func (a *agentState) restoreContainersField() error {
return nil
}

func clearAgentStates(agentIds []agentID) error {
func clearAgentStates(agentIds []aproto.ID) error {
if _, err := db.Bun().NewDelete().Model((*agentSnapshot)(nil)).
Where("agent_id in (?)", bun.In(agentIds)).
Exec(context.TODO()); err != nil {
Expand Down
6 changes: 3 additions & 3 deletions master/internal/rm/agentrm/agent_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestAgentStatePersistence(t *testing.T) {
require.NoError(t, err)

// Fake an agent, test adding it to the db.
state := newAgentState(agentID(uuid.NewString()), 64)
state := newAgentState(aproto.ID(uuid.NewString()), 64)
state.handler = &agent{}
state.resourcePoolName = "compute"
devices := []device.Device{
Expand Down Expand Up @@ -187,7 +187,7 @@ func TestAgentStatePersistence(t *testing.T) {

func TestClearAgentStates(t *testing.T) {
ctx := context.Background()
agentIDs := []agentID{agentID(uuid.NewString()), agentID(uuid.NewString())}
agentIDs := []aproto.ID{aproto.ID(uuid.NewString()), aproto.ID(uuid.NewString())}
for _, agentID := range agentIDs {
_, err := db.Bun().NewInsert().Model(&agentSnapshot{
AgentID: agentID,
Expand Down Expand Up @@ -326,7 +326,7 @@ func Test_agentState_checkAgentStartedDevicesMatch(t *testing.T) {

func TestSlotStates(t *testing.T) {
rpName := "test"
state := newAgentState(agentID(uuid.NewString()), 64)
state := newAgentState(aproto.ID(uuid.NewString()), 64)
state.handler = &agent{}
state.resourcePoolName = rpName
devices := []device.Device{
Expand Down
22 changes: 11 additions & 11 deletions master/internal/rm/agentrm/agents.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type agents struct {
syslog *logrus.Entry
mu sync.Mutex

agents *tasklist.Registry[agentID, *agent]
agents *tasklist.Registry[aproto.ID, *agent]
agentUpdates *queue.Queue[agentUpdatedEvent]
poolConfigs []config.ResourcePoolConfig
opts *aproto.MasterSetAgentOptions
Expand All @@ -51,7 +51,7 @@ func newAgentService(
agentUpdates := queue.New[agentUpdatedEvent]()
a := &agents{
syslog: logrus.WithField("component", "agents"),
agents: tasklist.NewRegistry[agentID, *agent](),
agents: tasklist.NewRegistry[aproto.ID, *agent](),
agentUpdates: agentUpdates,
poolConfigs: poolConfigs,
opts: opts,
Expand All @@ -66,7 +66,7 @@ func newAgentService(
}

a.syslog.Debugf("agent states to restore: %d", len(agentStates))
badAgentIds := []agentID{}
badAgentIds := []aproto.ID{}

for agentID, state := range agentStates {
state := state
Expand Down Expand Up @@ -100,9 +100,9 @@ func newAgentService(
}

// list implements agentService.
func (a *agents) list(resourcePoolName string) map[agentID]*agentState {
func (a *agents) list(resourcePoolName string) map[aproto.ID]*agentState {
agents := a.agents.Snapshot()
result := make(map[agentID]*agentState, len(agents))
result := make(map[aproto.ID]*agentState, len(agents))
for id, a := range agents {
state, err := a.State()
if err != nil {
Expand All @@ -117,7 +117,7 @@ func (a *agents) list(resourcePoolName string) map[agentID]*agentState {
return result
}

func (a *agents) get(id agentID) (*agent, bool) {
func (a *agents) get(id aproto.ID) (*agent, bool) {
return a.agents.Load(id)
}

Expand All @@ -143,7 +143,7 @@ func (a *agents) HandleWebsocketConnection(msg webSocketRequest) error {
}
}

id := msg.echoCtx.QueryParam("id")
id := aproto.ID(msg.echoCtx.QueryParam("id"))
reconnect, err := msg.isReconnect()
if err != nil {
return errors.Wrapf(err, "parsing reconnect query param")
Expand All @@ -153,7 +153,7 @@ func (a *agents) HandleWebsocketConnection(msg webSocketRequest) error {
// accept it. Whether it is a network failure or a crash/restart, we will just try
// to reattach whatever containers still exist.
// That logic is located in agent.receive(ws.WebSocketRequest).
existingRef, ok := a.agents.Load(agentID(id))
existingRef, ok := a.agents.Load(id)
if ok {
a.syslog.WithField("reconnect", reconnect).Infof("restoring agent id: %s", id)
return existingRef.HandleWebsocketConnection(msg)
Expand All @@ -168,12 +168,12 @@ func (a *agents) HandleWebsocketConnection(msg webSocketRequest) error {

// Finally, this must not be a recovery flow, so just create the agent actor.
resourcePool := msg.echoCtx.QueryParam("resource_pool")
ref, err := a.createAgent(agentID(id), resourcePool, a.opts, nil, func() { _ = a.agents.Delete(agentID(id)) })
ref, err := a.createAgent(id, resourcePool, a.opts, nil, func() { _ = a.agents.Delete(id) })
if err != nil {
return err
}

err = a.agents.Add(agentID(id), ref)
err = a.agents.Add(id, ref)
if err != nil {
return fmt.Errorf("adding agent because of incoming websocket: %w", err)
}
Expand All @@ -189,7 +189,7 @@ func (a *agents) getAgents() *apiv1.GetAgentsResponse {
}

func (a *agents) createAgent(
id agentID,
id aproto.ID,
resourcePool string,
opts *aproto.MasterSetAgentOptions,
restoredAgentState *agentState,
Expand Down
9 changes: 5 additions & 4 deletions master/internal/rm/agentrm/fair_share.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/determined-ai/determined/master/internal/rm/tasklist"
"github.com/determined-ai/determined/master/internal/sproto"
"github.com/determined-ai/determined/master/pkg/aproto"
"github.com/determined-ai/determined/master/pkg/check"
"github.com/determined-ai/determined/master/pkg/mathx"
"github.com/determined-ai/determined/master/pkg/model"
Expand Down Expand Up @@ -81,7 +82,7 @@ func (f *fairShare) JobQInfo(rp *resourcePool) map[model.JobID]*sproto.RMJobInfo
func fairshareSchedule(
taskList *tasklist.TaskList,
groups map[model.JobID]*tasklist.Group,
agents map[agentID]*agentState,
agents map[aproto.ID]*agentState,
fittingMethod SoftConstraint,
allowHeterogeneousAgentFits bool,
) ([]*sproto.AllocateRequest, []model.AllocationID) {
Expand Down Expand Up @@ -133,7 +134,7 @@ func fairshareSchedule(
return allToAllocate, allToRelease
}

func totalCapacity(agents map[agentID]*agentState) int {
func totalCapacity(agents map[aproto.ID]*agentState) int {
result := 0

for _, agent := range agents {
Expand All @@ -147,7 +148,7 @@ func calculateGroupStates(
taskList *tasklist.TaskList,
groups map[model.JobID]*tasklist.Group,
capacity int,
agents map[agentID]*agentState,
agents map[aproto.ID]*agentState,
fittingMethod SoftConstraint,
allowHeterogeneousAgentFits bool,
) []*groupState {
Expand Down Expand Up @@ -358,7 +359,7 @@ func calculateSmallestAllocatableTask(state *groupState) (smallest *sproto.Alloc
}

func assignTasks(
agents map[agentID]*agentState, states []*groupState, fittingMethod SoftConstraint,
agents map[aproto.ID]*agentState, states []*groupState, fittingMethod SoftConstraint,
allowHetergenousAgentFits bool,
) ([]*sproto.AllocateRequest, []model.AllocationID) {
toAllocate := make([]*sproto.AllocateRequest, 0)
Expand Down
Loading

0 comments on commit 1202d5c

Please sign in to comment.