Skip to content

Commit

Permalink
fix: properly merge resource configs (#9233)
Browse files Browse the repository at this point in the history
  • Loading branch information
eecsliu authored Apr 24, 2024
1 parent 3b39d3c commit c18ac83
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 37 deletions.
21 changes: 16 additions & 5 deletions master/internal/rm/agentrm/agent_resource_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,18 +510,29 @@ func (a *ResourceManager) SetGroupWeight(msg sproto.SetGroupWeight) error {
// TaskContainerDefaults implements rm.ResourceManager.
//
// It starts from defaultConfig and scans a.poolsConfig for the pool whose
// PoolName equals resourcePoolName.String(). When that pool sets its own
// TaskContainerDefaults, they are combined with the defaults through
// result.Merge and the merged config is returned; otherwise defaultConfig is
// returned unchanged. If Merge fails, a zero-value config and the error are
// returned.
//
// NOTE(review): this listing is a diff rendered without +/- markers, so the
// replaced lines (the fallbackConfig parameter and assignment, the == nil
// check with its break, and the direct assignment from
// *pool.TaskContainerDefaults) sit next to their replacements. The description
// above follows the merge-based lines; confirm against the checked-in file.
func (a *ResourceManager) TaskContainerDefaults(
resourcePoolName rm.ResourcePoolName,
fallbackConfig model.TaskContainerDefaultsConfig,
defaultConfig model.TaskContainerDefaultsConfig,
) (model.TaskContainerDefaultsConfig, error) {
result := fallbackConfig
result := defaultConfig

// Iterate through configured pools looking for a TaskContainerDefaults setting.
// poolConfigOverrides stays nil when no pool matches or the matching pool has
// no TaskContainerDefaults of its own.
var poolConfigOverrides *model.TaskContainerDefaultsConfig
for _, pool := range a.poolsConfig {
if resourcePoolName.String() == pool.PoolName {
if pool.TaskContainerDefaults == nil {
break
if pool.TaskContainerDefaults != nil {
poolConfigOverrides = pool.TaskContainerDefaults
}
result = *pool.TaskContainerDefaults
// Only the first pool with a matching name is considered.
break
}
}

// Layer the pool's settings onto the defaults. Which side wins for each field
// is decided by Merge, which is not shown in full here.
if poolConfigOverrides != nil {
tmp, err := result.Merge(*poolConfigOverrides)
if err != nil {
return model.TaskContainerDefaultsConfig{}, err
}
result = tmp
}

return result, nil
}

Expand Down
53 changes: 25 additions & 28 deletions master/internal/rm/kubernetesrm/kubernetes_resource_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,13 +488,31 @@ func (k ResourceManager) IsReattachableOnlyAfterStarted() bool {

// TaskContainerDefaults returns TaskContainerDefaults for the specified pool.
//
// It starts from defaultConfig and scans k.poolsConfig for the pool whose
// PoolName equals resourcePoolName.String(). When that pool sets its own
// TaskContainerDefaults, they are combined with the defaults through
// result.Merge and the merged config is returned; otherwise defaultConfig is
// returned unchanged. If Merge fails, a zero-value config and the error are
// returned.
//
// NOTE(review): this listing is a diff rendered without +/- markers. The
// pool/fallbackConfig parameters, the named results, and the delegation to
// k.getTaskContainerDefaults are the replaced lines; the lines from
// resourcePoolName onward are their replacement. Confirm against the
// checked-in file.
func (k ResourceManager) TaskContainerDefaults(
pool rm.ResourcePoolName,
fallbackConfig model.TaskContainerDefaultsConfig,
) (result model.TaskContainerDefaultsConfig, err error) {
return k.getTaskContainerDefaults(taskContainerDefaults{
fallbackDefault: fallbackConfig,
resourcePool: pool.String(),
}), nil
resourcePoolName rm.ResourcePoolName,
defaultConfig model.TaskContainerDefaultsConfig,
) (model.TaskContainerDefaultsConfig, error) {
result := defaultConfig

// Iterate through configured pools looking for a TaskContainerDefaults setting.
// poolConfigOverrides stays nil when no pool matches or the matching pool has
// no TaskContainerDefaults of its own.
var poolConfigOverrides *model.TaskContainerDefaultsConfig
for _, pool := range k.poolsConfig {
if resourcePoolName.String() == pool.PoolName {
if pool.TaskContainerDefaults != nil {
poolConfigOverrides = pool.TaskContainerDefaults
}
// Only the first pool with a matching name is considered.
break
}
}

// Layer the pool's settings onto the defaults. Which side wins for each field
// is decided by Merge, which is not shown in full here.
if poolConfigOverrides != nil {
tmp, err := result.Merge(*poolConfigOverrides)
if err != nil {
return model.TaskContainerDefaultsConfig{}, err
}
result = tmp
}

return result, nil
}

func (k *ResourceManager) podStatusUpdateCallback(msg sproto.UpdatePodStatus) {
Expand Down Expand Up @@ -613,27 +631,6 @@ func (k *ResourceManager) getResourcePoolConfig(poolName string) (
return config.ResourcePoolConfig{}, errors.Errorf("cannot find resource pool %s", poolName)
}

// taskContainerDefaults bundles the arguments passed to
// getTaskContainerDefaults.
type taskContainerDefaults struct {
// fallbackDefault is the config used when the pool supplies none of its own.
fallbackDefault model.TaskContainerDefaultsConfig
// resourcePool is the name compared against each configured pool's PoolName.
resourcePool string
}

// getTaskContainerDefaults returns the TaskContainerDefaults of a configured
// pool whose PoolName equals msg.resourcePool, or msg.fallbackDefault when no
// pool matches or the first match has none set.
//
// A pool's config replaces the fallback wholesale; the two are never merged,
// so any field set only in the fallback is absent from the result.
func (k *ResourceManager) getTaskContainerDefaults(
msg taskContainerDefaults,
) model.TaskContainerDefaultsConfig {
result := msg.fallbackDefault
// Iterate through configured pools looking for a TaskContainerDefaults setting.
for _, pool := range k.poolsConfig {
if msg.resourcePool == pool.PoolName {
if pool.TaskContainerDefaults == nil {
// Stop scanning and keep whatever result currently holds.
break
}
result = *pool.TaskContainerDefaults
// No break after a successful match: the scan continues, so a later pool
// with the same name and a non-nil TaskContainerDefaults overwrites result.
}
}
return result
}

// EnableAgent allows scheduling on a node that has been disabled.
func (k *ResourceManager) EnableAgent(
req *apiv1.EnableAgentRequest,
Expand Down
19 changes: 15 additions & 4 deletions master/pkg/model/task_container_defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,13 @@ func (c TaskContainerDefaultsConfig) Merge(
}

if otherEnvVars := other.EnvironmentVariables; otherEnvVars != nil {
otherEnvs := other.EnvironmentVariables
res.EnvironmentVariables.CPU = mergeEnvVars(res.EnvironmentVariables.CPU, otherEnvs.CPU)
res.EnvironmentVariables.CUDA = mergeEnvVars(res.EnvironmentVariables.CUDA, otherEnvs.CUDA)
res.EnvironmentVariables.ROCM = mergeEnvVars(res.EnvironmentVariables.ROCM, otherEnvs.ROCM)
if res.EnvironmentVariables == nil {
res.EnvironmentVariables = other.EnvironmentVariables
} else {
res.EnvironmentVariables.CPU = mergeEnvVars(res.EnvironmentVariables.CPU, otherEnvVars.CPU)
res.EnvironmentVariables.CUDA = mergeEnvVars(res.EnvironmentVariables.CUDA, otherEnvVars.CUDA)
res.EnvironmentVariables.ROCM = mergeEnvVars(res.EnvironmentVariables.ROCM, otherEnvVars.ROCM)
}
}

if other.AddCapabilities != nil {
Expand Down Expand Up @@ -301,6 +304,14 @@ func (c TaskContainerDefaultsConfig) Merge(
res.Pbs.SetSbatchArgs(tmp)
}

if other.LogPolicies != nil {
if res.LogPolicies == nil {
res.LogPolicies = other.LogPolicies
} else {
res.LogPolicies = res.LogPolicies.Merge(other.LogPolicies)
}
}

return res, nil
}

Expand Down

0 comments on commit c18ac83

Please sign in to comment.