Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(agent): Format agent code #4784

Merged
merged 15 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 129 additions & 34 deletions scheduler/pkg/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,16 @@ func (c *Client) WaitReadySubServices(isStartup bool) error {

func (c *Client) UnloadAllModels() error {
logger := c.logger.WithField("func", "UnloadAllModels")

models, err := c.stateManager.v2Client.GetModels()
if err != nil {
return err
}

for _, model := range models {
if model.State == interfaces.ServerModelState_READY || model.State == interfaces.ServerModelState_LOADING {
logger.Infof("Unloading existing model %s", model)

v2Err := c.stateManager.v2Client.UnloadModel(model.Name)
if v2Err != nil {
if !v2Err.IsNotFound() {
Expand All @@ -369,25 +372,31 @@ func (c *Client) UnloadAllModels() error {
}
}
}

err := c.ModelRepository.RemoveModelVersion(model.Name)
if err != nil {
c.logger.WithError(err).Errorf("Model %s could not be removed from repository", model)
}
}

return nil
}

func (c *Client) getConnection(host string, plainTxtPort int, tlsPort int) (*grpc.ClientConn, error) {
logger := c.logger.WithField("func", "getConnection")

var err error
protocol := seldontls.GetSecurityProtocolFromEnv(seldontls.EnvSecurityPrefixControlPlane)
if protocol == seldontls.SecurityProtocolSSL {
c.certificateStore, err = seldontls.NewCertificateStore(seldontls.Prefix(seldontls.EnvSecurityPrefixControlPlaneClient),
seldontls.ValidationPrefix(seldontls.EnvSecurityPrefixControlPlaneServer))
c.certificateStore, err = seldontls.NewCertificateStore(
seldontls.Prefix(seldontls.EnvSecurityPrefixControlPlaneClient),
seldontls.ValidationPrefix(seldontls.EnvSecurityPrefixControlPlaneServer),
)
if err != nil {
return nil, err
}
}

var transCreds credentials.TransportCredentials
var port int
if c.certificateStore == nil {
Expand All @@ -399,41 +408,51 @@ func (c *Client) getConnection(host string, plainTxtPort int, tlsPort int) (*grp
transCreds = c.certificateStore.CreateClientTransportCredentials()
port = tlsPort
}

logger.Infof("Connecting (non-blocking) to scheduler at %s:%d", host, port)

opts := []grpc.DialOption{
grpc.WithTransportCredentials(transCreds),
grpc.WithUnaryInterceptor(otelgrpc.UnaryClientInterceptor()),
}
logger.Infof("Connecting (non-blocking) to scheduler at %s:%d", host, port)
conn, err := grpc.Dial(fmt.Sprintf("%s:%d", host, port), opts...)
if err != nil {
return nil, err
}

return conn, nil
}

func (c *Client) StartService() error {
logger := c.logger.WithField("func", "StartService")
logger.Info("Call subscribe to scheduler")

grpcClient := agent.NewAgentServiceClient(c.conn)

stream, err := grpcClient.Subscribe(context.Background(), &agent.AgentSubscribeRequest{
ServerName: c.serverName,
ReplicaIdx: c.replicaIdx,
ReplicaConfig: c.replicaConfig,
LoadedModels: c.stateManager.modelVersions.getVersionsForAllModels(),
Shared: true,
AvailableMemoryBytes: c.stateManager.GetAvailableMemoryBytesWithOverCommit(),
}, grpc_retry.WithMax(1)) //TODO make configurable
stream, err := grpcClient.Subscribe(
context.Background(),
&agent.AgentSubscribeRequest{
ServerName: c.serverName,
ReplicaIdx: c.replicaIdx,
ReplicaConfig: c.replicaConfig,
LoadedModels: c.stateManager.modelVersions.getVersionsForAllModels(),
Shared: true,
AvailableMemoryBytes: c.stateManager.GetAvailableMemoryBytesWithOverCommit(),
},
grpc_retry.WithMax(1),
) //TODO make configurable
if err != nil {
return err
}

logger.Info("Subscribed to scheduler")

// start model scaling events consumer
clientStream, err := grpcClient.ModelScalingTrigger(context.Background())
if err != nil {
return err
}

c.modelScalingClientStream = clientStream
defer func() {
_, _ = clientStream.CloseAndRecv()
Expand All @@ -445,19 +464,27 @@ func (c *Client) StartService() error {
logger.Info("Stopping")
break
}

operation, err := stream.Recv()
if err != nil {
logger.WithError(err).Error("event recv failed")
break
}

c.logger.Infof("Received operation")

switch operation.Operation {
case agent.ModelOperationMessage_LOAD_MODEL:
c.logger.Infof("calling load model")

go func() {
err := c.LoadModel(operation)
if err != nil {
c.logger.WithError(err).Errorf("Failed to handle load model %s:%d", operation.GetModelVersion().GetModel().GetMeta().GetName(), operation.GetModelVersion().GetVersion())
c.logger.WithError(err).Errorf(
"Failed to handle load model %s:%d",
operation.GetModelVersion().GetModel().GetMeta().GetName(),
operation.GetModelVersion().GetVersion(),
)
}
}()

Expand All @@ -466,7 +493,11 @@ func (c *Client) StartService() error {
go func() {
err := c.UnloadModel(operation)
if err != nil {
c.logger.WithError(err).Errorf("Failed to handle unload model %s:%d", operation.GetModelVersion().GetModel().GetMeta().GetName(), operation.GetModelVersion().GetVersion())
c.logger.WithError(err).Errorf(
"Failed to handle unload model %s:%d",
operation.GetModelVersion().GetModel().GetMeta().GetName(),
operation.GetModelVersion().GetVersion(),
)
}
}()
}
Expand All @@ -481,12 +512,16 @@ func (c *Client) StartService() error {
}

func (c *Client) getArtifactConfig(request *agent.ModelOperationMessage) ([]byte, error) {
if request.GetModelVersion().GetModel().GetModelSpec().StorageConfig == nil {
model := request.GetModelVersion().GetModel()

if model.GetModelSpec().StorageConfig == nil {
return nil, nil
}

logger := c.logger.WithField("func", "getArtifactConfig")
logger.Infof("Getting Rclone configuration")
switch x := request.GetModelVersion().GetModel().GetModelSpec().StorageConfig.Config.(type) {

switch x := model.GetModelSpec().GetStorageConfig().GetConfig().(type) {
case *pbs.StorageConfig_StorageRcloneConfig:
return []byte(x.StorageRcloneConfig), nil
case *pbs.StorageConfig_StorageSecretName:
Expand All @@ -495,27 +530,41 @@ func (c *Client) getArtifactConfig(request *agent.ModelOperationMessage) ([]byte
if err != nil {
return nil, err
}
if request.GetModelVersion().GetModel().GetMeta().GetKubernetesMeta() != nil {
c.KubernetesOptions.secretsHandler = k8s.NewSecretsHandler(secretClientSet, request.GetModelVersion().GetModel().GetMeta().GetKubernetesMeta().GetNamespace())

if model.GetMeta().GetKubernetesMeta() != nil {
c.KubernetesOptions.secretsHandler = k8s.NewSecretsHandler(
secretClientSet,
model.GetMeta().GetKubernetesMeta().GetNamespace(),
)
} else {
return nil, fmt.Errorf("Can't load model %s:%dwith k8s secret %s when namespace not set", request.GetModelVersion().GetModel().GetMeta().GetName(), request.GetModelVersion().GetVersion(), x.StorageSecretName)
return nil, fmt.Errorf(
"Can't load model %s:%dwith k8s secret %s when namespace not set",
model.GetMeta().GetName(),
request.GetModelVersion().GetVersion(),
x.StorageSecretName,
)
}

}

config, err := c.secretsHandler.GetSecretConfig(x.StorageSecretName)
if err != nil {
return nil, err
}

return config, nil
}

return nil, nil
}

func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
logger := c.logger.WithField("func", "LoadModel")
if request == nil || request.ModelVersion == nil {
return fmt.Errorf("Empty request received for load model")
}

logger := c.logger.WithField("func", "LoadModel")

modelName := request.GetModelVersion().GetModel().GetMeta().GetName()
modelVersion := request.GetModelVersion().GetVersion()
modelWithVersion := util.GetVersionedModelName(modelName, modelVersion)
Expand All @@ -532,17 +581,26 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_LOAD_FAILED, err)
return err
}

// Copy model artifact
chosenVersionPath, err := c.ModelRepository.DownloadModelVersion(
modelWithVersion, pinnedModelVersion, request.GetModelVersion().GetModel().GetModelSpec(), config)
modelWithVersion,
pinnedModelVersion,
request.GetModelVersion().GetModel().GetModelSpec(),
config,
)
if err != nil {
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_LOAD_FAILED, err)
return err
}
logger.Infof("Chose path %s for model %s:%d", *chosenVersionPath, modelName, modelVersion)

// TODO: consider whether we need the actual protos being sent to `LoadModelVersion`?
modifiedModelVersionRequest := getModifiedModelVersion(modelWithVersion, pinnedModelVersion, request.GetModelVersion())
modifiedModelVersionRequest := getModifiedModelVersion(
modelWithVersion,
pinnedModelVersion,
request.GetModelVersion(),
)
loaderFn := func() error {
return c.stateManager.LoadModelVersion(modifiedModelVersionRequest)
}
Expand All @@ -552,8 +610,8 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
}

// if scheduler ask for autoscaling, add pointers in model scaling stats
// we have done it via the scaling service as not to expose here all the model scaling stats that we have and then call Add on
// each one of them
// we have done it via the scaling service as not to expose here all the model scaling stats
// that we have and then call Add on each one of them
if request.AutoscalingEnabled {
logger.Debugf("Enabling autoscaling checks for model %s", modelWithVersion)
if err := c.modelScalingService.(*modelscaling.StatsAnalyserService).AddModel(modelWithVersion); err != nil {
Expand All @@ -566,10 +624,12 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
}

func (c *Client) UnloadModel(request *agent.ModelOperationMessage) error {
logger := c.logger.WithField("func", "UnloadModel")
if request == nil || request.GetModelVersion() == nil {
return fmt.Errorf("Empty request received for unload model")
}

logger := c.logger.WithField("func", "UnloadModel")

modelName := request.GetModelVersion().GetModel().GetMeta().GetName()
modelVersion := request.GetModelVersion().GetVersion()
modelWithVersion := util.GetVersionedModelName(modelName, modelVersion)
Expand All @@ -596,7 +656,10 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage) error {
// each one of them
// note that we do not check if the model is already enabled for autoscaling, should we?
if err := c.modelScalingService.(*modelscaling.StatsAnalyserService).DeleteModel(modelWithVersion); err != nil {
logger.WithError(err).Warnf("Cannot delete model %s from scaling service, likely that it was not enabled in the first place", modelWithVersion)
logger.WithError(err).Warnf(
"Cannot delete model %s from scaling service, likely that it was not enabled in the first place",
modelWithVersion,
)
}

err := c.ModelRepository.RemoveModelVersion(modelWithVersion)
Expand All @@ -609,7 +672,12 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage) error {
return c.sendAgentEvent(modelName, modelVersion, agent.ModelEventMessage_UNLOADED)
}

func (c *Client) sendModelEventError(modelName string, modelVersion uint32, event agent.ModelEventMessage_Event, err error) {
func (c *Client) sendModelEventError(
modelName string,
modelVersion uint32,
event agent.ModelEventMessage_Event,
err error,
) {
grpcClient := agent.NewAgentServiceClient(c.conn)
_, err = grpcClient.AgentEvent(context.Background(), &agent.ModelEventMessage{
ServerName: c.serverName,
Expand All @@ -625,11 +693,20 @@ func (c *Client) sendModelEventError(modelName string, modelVersion uint32, even
}
}

func (c *Client) sendAgentEvent(modelName string, modelVersion uint32, event agent.ModelEventMessage_Event) error {
func (c *Client) sendAgentEvent(
modelName string,
modelVersion uint32,
event agent.ModelEventMessage_Event,
) error {
// if the server is draining and the model load has succeeded, we need to "cancel"
if c.isDraining.Load() {
if event == agent.ModelEventMessage_LOADED {
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_LOAD_FAILED, fmt.Errorf("server replica is draining"))
c.sendModelEventError(
modelName,
modelVersion,
agent.ModelEventMessage_LOAD_FAILED,
fmt.Errorf("server replica is draining"),
)
return nil
}
}
Expand All @@ -649,10 +726,12 @@ func (c *Client) sendAgentEvent(modelName string, modelVersion uint32, event age
// drainOnRequest waits on the drainer's trigger, flags this client as
// draining, and then sends the agent drain event.
//
// A failure to send the drain event is logged as a warning and is also
// returned to the caller. drainer.SetSchedulerDone is invoked whether or not
// the event was sent successfully.
func (c *Client) drainOnRequest(drainer *drainservice.DrainerService) error {
	drainer.WaitOnTrigger()
	c.isDraining.Store(true)

	sendErr := c.sendAgentDrainEvent()
	if sendErr != nil {
		c.logger.WithError(sendErr).Warn("Could not drain agent / server")
	}

	// Signal completion on the drainer even when the event could not be sent.
	drainer.SetSchedulerDone()

	return sendErr
}
Expand All @@ -670,11 +749,17 @@ func (c *Client) sendAgentDrainEvent() error {
}

func (c *Client) sendModelScalingTriggerEvent(
modelName string, modelVersion uint32, scalingType modelscaling.ModelScalingEventType, amount uint32, data map[string]uint32) error {
modelName string,
modelVersion uint32,
scalingType modelscaling.ModelScalingEventType,
amount uint32,
data map[string]uint32,
) error {
triggerType := agent.ModelScalingTriggerMessage_SCALE_UP
if scalingType == modelscaling.ScaleDownEvent {
triggerType = agent.ModelScalingTriggerMessage_SCALE_DOWN
}

err := c.modelScalingClientStream.Send(&agent.ModelScalingTriggerMessage{
ServerName: c.serverName,
ReplicaIdx: c.replicaIdx,
Expand All @@ -695,19 +780,29 @@ func (c *Client) modelScalingEventsConsumer() {
if err != nil {
c.logger.WithError(err).Warnf(
"Trigger model scaling event %d for model %s failed",
e.EventType, e.StatsData.ModelName)
e.EventType,
e.StatsData.ModelName,
)
continue
} else {
c.logger.Debugf("Trigger model scaling event %d for model %s:%d with value %d",
e.EventType, modelName, modelVersion, e.StatsData.Value)
}

c.logger.Debugf(
"Trigger model scaling event %d for model %s:%d with value %d",
e.EventType,
modelName,
modelVersion,
e.StatsData.Value,
)

err = c.sendModelScalingTriggerEvent(
modelName, modelVersion, e.EventType, e.StatsData.Value, nil,
)
if err != nil {
c.logger.WithError(err).Warnf(
"Sending model scaling event %d for model %s failed",
e.EventType, e.StatsData.ModelName)
e.EventType,
e.StatsData.ModelName,
)
continue
}
}
Expand Down
Loading