Skip to content

Commit

Permalink
Added checks for dataset registration (#200)
Browse files Browse the repository at this point in the history
This PR maps the realm in the input dataset JSON to the registered computes. The realm in the dataset JSON specifies either a region OR a computeId. If the dataset is private, the specified region or computeId must be correct; if it is invalid, an error is thrown and creation of the dataset fails. In the case of public datasets, if the region or computeId does not match, the defaultRealm is selected, which also specifies a default computeId. The dataset realm is used at the time of task creation to assign computeIds, and the job tasks get assigned to that compute cluster.
  • Loading branch information
dhruvsgarg committed Aug 4, 2022
1 parent 0d3b13e commit 65f073b
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 31 deletions.
5 changes: 4 additions & 1 deletion api/dataset_components.partials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ DatasetInfo:
# helps launching an ML workload in a cluster or machine associated with the launcher).
realm:
type: string
computeId:
type: string
# if it is not public, the dataset meta info is filtered when search is done by other users
isPublic:
type: boolean
Expand All @@ -61,5 +63,6 @@ DatasetInfo:
description: dataset containing handwritten digits
url: https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
dataFormat: npy
realm: "us|west|org1|cluster1"
realm: "us/west/org1/cluster1"
computeId: "cluster1"
isPublic: true
2 changes: 2 additions & 0 deletions cmd/controller/app/database/db_interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,6 @@ type TaskService interface {
// ComputeService is an interface that defines a collection of APIs related to computes
type ComputeService interface {
	// RegisterCompute registers a new compute cluster spec and returns its status.
	RegisterCompute(openapi.ComputeSpec) (openapi.ComputeStatus, error)
	// GetComputeIdsByRegion returns the IDs of all computes registered in the given region.
	GetComputeIdsByRegion(string) ([]string, error)
	// GetComputeById returns the compute spec registered under the given compute ID.
	GetComputeById(string) (openapi.ComputeSpec, error)
}
70 changes: 66 additions & 4 deletions cmd/controller/app/database/mongodb/compute.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,26 @@ import (
"go.uber.org/zap"

"github.com/cisco-open/flame/pkg/openapi"
"github.com/cisco-open/flame/pkg/util"
)

// RegisterCompute creates a new cluster compute specification and returns ComputeStatus
func (db *MongoService) RegisterCompute(computeSpec openapi.ComputeSpec) (openapi.ComputeStatus, error) {
// First check if the compute was previously registered
filter := bson.M{"computeid": computeSpec.ComputeId}
filter := bson.M{util.DBFieldComputeId: computeSpec.ComputeId}
checkResult := db.computeCollection.FindOne(context.TODO(), filter)
if (checkResult.Err() != nil) && (checkResult.Err() != mongo.ErrNoDocuments) {
errMsg := fmt.Sprintf("Failed to register compute : %v", checkResult.Err())
zap.S().Errorf(errMsg)

return openapi.ComputeStatus{}, ErrorCheck(checkResult.Err())
}
if checkResult.Err() == mongo.ErrNoDocuments {
// If it was not registered previously, need to register
result, err := db.computeCollection.InsertOne(context.TODO(), computeSpec)
if err != nil {
zap.S().Errorf("Failed to register new compute in database: result: %v, error: %v", result, err)

errMsg := fmt.Sprintf("Failed to register new compute in database: result: %v, error: %v", result, err)
zap.S().Errorf(errMsg)
return openapi.ComputeStatus{}, ErrorCheck(err)
}

Expand Down Expand Up @@ -100,7 +107,7 @@ func (db *MongoService) UpdateComputeStatus(computeId string, computeStatus open
setElements[dateKey] = updateTime
}

filter := bson.M{"computeid": computeId}
filter := bson.M{util.DBFieldComputeId: computeId}
update := bson.M{"$set": setElements}

updatedDoc := openapi.ComputeStatus{}
Expand All @@ -111,3 +118,58 @@ func (db *MongoService) UpdateComputeStatus(computeId string, computeStatus open

return updateTime, nil
}

// GetComputeIdsByRegion returns the IDs of all computes registered in the
// given region. It returns an error when the database lookup fails or when
// no compute is registered for the region.
func (db *MongoService) GetComputeIdsByRegion(region string) ([]string, error) {
	zap.S().Infof("get all computes in the region: %s", region)

	filter := bson.M{util.DBFieldComputeRegion: region}
	cursor, err := db.computeCollection.Find(context.TODO(), filter)
	if err != nil {
		zap.S().Errorf("failed to fetch computes in the region %s: %v", region, err)

		return nil, fmt.Errorf("failed to fetch computes in the region %s: %w", region, err)
	}
	defer cursor.Close(context.TODO())

	var computeIdList []string
	for cursor.Next(context.TODO()) {
		var computeSpec openapi.ComputeSpec
		if err = cursor.Decode(&computeSpec); err != nil {
			zap.S().Errorf("failed to decode compute spec: %v", err)

			return nil, ErrorCheck(err)
		}

		computeIdList = append(computeIdList, computeSpec.ComputeId)
	}

	// surface any error that terminated the cursor iteration early
	if err = cursor.Err(); err != nil {
		zap.S().Errorf("cursor iteration failed for region %s: %v", region, err)

		return nil, ErrorCheck(err)
	}

	if len(computeIdList) == 0 {
		zap.S().Errorf("could not find any computes for the region: %s", region)

		return nil, fmt.Errorf("could not find any computes for the region: %s", region)
	}

	return computeIdList, nil
}

// GetComputeById fetches the compute spec registered under the given
// computeId. It returns an error when no matching compute exists or when the
// stored document cannot be decoded.
func (db *MongoService) GetComputeById(computeId string) (openapi.ComputeSpec, error) {
	filter := bson.M{util.DBFieldComputeId: computeId}

	checkResult := db.computeCollection.FindOne(context.TODO(), filter)
	if err := checkResult.Err(); err != nil {
		// include the underlying error so callers can distinguish
		// "not found" from a real database failure
		zap.S().Errorf("failed to find a compute with computeId %s: %v", computeId, err)

		return openapi.ComputeSpec{}, fmt.Errorf("failed to find a compute with computeId %s: %w", computeId, err)
	}

	var currentDocument openapi.ComputeSpec
	if err := checkResult.Decode(&currentDocument); err != nil {
		zap.S().Errorf("failed to parse compute document: %v", err)

		return openapi.ComputeSpec{}, fmt.Errorf("failed to parse compute document: %w", err)
	}

	return currentDocument, nil
}
2 changes: 2 additions & 0 deletions cmd/controller/app/database/mongodb/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ func (db *MongoService) CreateTasks(tasks []objects.Task, dirty bool) error {
util.DBFieldTaskType: task.Type,
"config": cfgData,
"code": task.ZippedCode,
util.DBFieldComputeId: task.ComputeId,
util.DBFieldTaskDirty: dirty,
util.DBFieldTaskKey: task.Key,
util.DBFieldState: openapi.READY,
util.DBFieldTimestamp: time.Now(),
},
}
zap.S().Debugf("taskID %s is assigned compute %s", task.TaskId, task.ComputeId)

after := options.After
upsert := true
Expand Down
7 changes: 4 additions & 3 deletions cmd/controller/app/job/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
)

const (
realmSep = "/"
defaultGroup = "default"
groupByTypeTag = "tag"
taskKeyLen = 32
Expand Down Expand Up @@ -350,7 +349,7 @@ func _walkForGroupByCheck(templates map[string]*taskTemplate, prevTmpl *taskTemp
minLen := 0
tmpLen := math.MaxInt32
for _, val := range channel.GroupBy.Value {
length := len(strings.Split(val, realmSep))
length := len(strings.Split(val, util.RealmSep))
tmpLen = funcMin(tmpLen, length)
}

Expand Down Expand Up @@ -466,6 +465,7 @@ func (tmpl *taskTemplate) buildTasks(prevPeer string, templates map[string]*task

for i, dataset := range datasets {
task := tmpl.Task
task.ComputeId = dataset.ComputeId
task.Configure(openapi.SYSTEM, util.RandString(taskKeyLen), dataset.Realm, dataset.Url, i)
tasks = append(tasks, task)
}
Expand All @@ -481,7 +481,8 @@ func (tmpl *taskTemplate) buildTasks(prevPeer string, templates map[string]*task

for i := 0; i < len(channel.GroupBy.Value); i++ {
task := tmpl.Task
realm := channel.GroupBy.Value[i] + realmSep + util.ProjectName
realm := channel.GroupBy.Value[i] + util.RealmSep + util.ProjectName
task.ComputeId = util.DefaultRealm
task.Configure(openapi.SYSTEM, util.RandString(taskKeyLen), realm, emptyDatasetUrl, i)

tasks = append(tasks, task)
Expand Down
3 changes: 2 additions & 1 deletion cmd/controller/app/job/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/cisco-open/flame/cmd/controller/config"
"github.com/cisco-open/flame/pkg/openapi"
"github.com/cisco-open/flame/pkg/util"
)

var (
Expand Down Expand Up @@ -123,7 +124,7 @@ var (
)

func composeGroup(tokens ...string) string {
return strings.Join(tokens, realmSep)
return strings.Join(tokens, util.RealmSep)
}

func TestGetTaskTemplates(t *testing.T) {
Expand Down
11 changes: 6 additions & 5 deletions cmd/controller/app/objects/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ import (
)

type Task struct {
JobId string `json:"jobid"`
TaskId string `json:"taskid"`
Role string `json:"role"`
Type openapi.TaskType `json:"type"`
Key string `json:"key"`
JobId string `json:"jobid"`
TaskId string `json:"taskid"`
Role string `json:"role"`
Type openapi.TaskType `json:"type"`
Key string `json:"key"`
ComputeId string `json:"computeid"`

// the following are config and code
JobConfig JobConfig
Expand Down
2 changes: 1 addition & 1 deletion fiab/helm-chart/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ mlflow:

deployer:
adminId: "admin-1"
region: "us-east"
region: "default/us/west"
computeId: "compute-1"
apiKey: "apiKey-1"

Expand Down
69 changes: 69 additions & 0 deletions pkg/openapi/controller/api_datasets_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ import (
"errors"
"fmt"
"net/http"
"path"
"strings"

"go.uber.org/zap"

"github.com/cisco-open/flame/cmd/controller/app/database"
"github.com/cisco-open/flame/pkg/openapi"
"github.com/cisco-open/flame/pkg/util"
)

// DatasetsApiService is a service that implements the logic for the DatasetsApiServicer
Expand All @@ -52,6 +57,32 @@ func NewDatasetsApiService(dbService database.DBService) openapi.DatasetsApiServ
// CreateDataset - Create meta info for a new dataset.
func (s *DatasetsApiService) CreateDataset(ctx context.Context, user string,
datasetInfo openapi.DatasetInfo) (openapi.ImplResponse, error) {
// get datasetRealm field and identify if it is computeId _or_ region
datasetRealm := datasetInfo.Realm
datasetInfo.ComputeId = util.DefaultRealm
isRegion := strings.HasPrefix(datasetRealm, util.DefaultRealm+util.RealmSep)
// TODO enforce compute registration starting with defaultRealm by pre-pending region.

// Public dataset: The realm may not map correctly to a region or compute. For a valid datasetInfo.Realm, expansion will take place.
// If datasetInfo.Realm represents an invalid region, it is retained (for groupBy) and computeId will be defaultRealm.
// If datasetInfo.Realm represents an invalid computeId, an error is thrown. The assumption is that the user
// specified a computeId for some specific reason(s) and it may not be correct to re-assign it to another cluster.
// Private dataset: The realm must map to a correct region or registered compute. For a valid datasetInfo.Realm, expansion will take place.

if isRegion {
err := s.checkComputeForDatasetByRegion(datasetRealm, &datasetInfo)
if err != nil {
return openapi.Response(http.StatusInternalServerError, nil), err
}
} else {
err := s.checkComputeForDatasetByComputeId(datasetRealm, &datasetInfo)
if err != nil {
return openapi.Response(http.StatusInternalServerError, nil), err
}
}

zap.S().Infof("after checks for dataset creation, datasetInfo.Realm: %v and datasetInfo.computeId: %v",
datasetInfo.Realm, datasetInfo.ComputeId)
datasetId, err := s.dbService.CreateDataset(user, datasetInfo)
if err != nil {
return openapi.Response(http.StatusInternalServerError, nil), fmt.Errorf("failed to create new dataset: %v", err)
Expand Down Expand Up @@ -110,3 +141,41 @@ func (s *DatasetsApiService) UpdateDataset(ctx context.Context, user string, dat

return openapi.Response(http.StatusNotImplemented, nil), errors.New("UpdateDataset method not implemented")
}

// checkComputeForDatasetByRegion validates the dataset realm as a region and
// performs realm expansion against the registered computes.
//   - If at least one compute is registered for the region, the realm is
//     expanded with the first eligible compute and that compute is assigned.
//   - If no compute is found, a private dataset is rejected with an error,
//     while a public dataset keeps the realm specified in dataset.json and
//     falls back to the default realm compute.
func (s *DatasetsApiService) checkComputeForDatasetByRegion(datasetRealm string, datasetInfo *openapi.DatasetInfo) error {
	eligibleComputes, err := s.dbService.GetComputeIdsByRegion(datasetRealm)
	if len(eligibleComputes) == 0 {
		// No compute found for the given region.
		if !datasetInfo.IsPublic {
			// Private datasets must map to a registered compute.
			zap.S().Errorf("failed to create new private dataset for region %v: %v", datasetRealm, err)

			return fmt.Errorf("failed to create new private dataset for region: %v, error: %v", datasetRealm, err)
		}

		// Public dataset: retain the realm from dataset.json and assign
		// the default realm compute.
		zap.S().Infof("couldn't find compute for public dataset in region %v, assigning defaultRealm", datasetRealm)
		datasetInfo.ComputeId = util.DefaultRealm

		return nil
	}

	// A compute is registered for the region (public or private dataset):
	// expand the realm and assign the first eligible compute directly.
	datasetInfo.Realm = path.Join(datasetRealm, eligibleComputes[0])
	datasetInfo.ComputeId = eligibleComputes[0]
	zap.S().Infof("found compute: %s for dataset in region %v", eligibleComputes[0], datasetRealm)

	return nil
}

// checkComputeForDatasetByComputeId validates the dataset realm as a
// computeId. The compute must be registered for both public and private
// datasets; on success the realm is expanded to region/computeId and the
// compute is assigned to the dataset.
func (s *DatasetsApiService) checkComputeForDatasetByComputeId(datasetRealm string, datasetInfo *openapi.DatasetInfo) error {
	computeSpec, err := s.dbService.GetComputeById(datasetRealm)
	if err != nil {
		// No compute registered under this computeId (public or private
		// dataset): reject the dataset. The assumption is that the user
		// specified a computeId for a specific reason, so it should not be
		// silently re-assigned to another cluster.
		zap.S().Errorf("computeId %s not found while trying to create new dataset: %v", datasetRealm, err)

		return fmt.Errorf("computeId %s not found while trying to create new dataset. Err: %v", datasetRealm, err)
	}

	// The compute is registered: expand the realm with the compute's region
	// and assign the validated computeId directly.
	datasetInfo.Realm = path.Join(computeSpec.Region, datasetRealm)
	datasetInfo.ComputeId = datasetRealm
	zap.S().Infof("compute: %s for dataset is valid, region returned: %v", datasetRealm, computeSpec.Region)

	return nil
}
2 changes: 2 additions & 0 deletions pkg/openapi/model_dataset_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,7 @@ type DatasetInfo struct {

Realm string `json:"realm"`

ComputeId string `json:"computeId,omitempty"`

IsPublic bool `json:"isPublic,omitempty"`
}
36 changes: 20 additions & 16 deletions pkg/util/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,23 @@ const (

// Database Fields
//TODO append Field to distinguish the fields
DBFieldMongoID = "_id"
DBFieldUserId = "userid"
DBFieldId = "id"
DBFieldDesignId = "designid"
DBFieldSchemaId = "schemaid"
DBFieldJobId = "jobid"
DBFieldTaskId = "taskid"
DBFieldState = "state"
DBFieldRole = "role"
DBFieldTaskLog = "log"
DBFieldTaskType = "type"
DBFieldIsPublic = "ispublic"
DBFieldTaskDirty = "dirty"
DBFieldTaskKey = "key"
DBFieldTimestamp = "timestamp"
DBFieldComputeId = "computeid"
DBFieldMongoID = "_id"
DBFieldUserId = "userid"
DBFieldId = "id"
DBFieldDesignId = "designid"
DBFieldSchemaId = "schemaid"
DBFieldJobId = "jobid"
DBFieldTaskId = "taskid"
DBFieldState = "state"
DBFieldRole = "role"
DBFieldTaskLog = "log"
DBFieldTaskType = "type"
DBFieldIsPublic = "ispublic"
DBFieldTaskDirty = "dirty"
DBFieldTaskKey = "key"
DBFieldTimestamp = "timestamp"
DBFieldComputeId = "computeid"
DBFieldComputeRegion = "region"

// Port numbers
ApiServerRestApiPort = 10100 // REST API port
Expand Down Expand Up @@ -82,4 +83,7 @@ const (
TaskCodeFile = "code.zip"

LogDirPath = "/var/log/" + ProjectName

DefaultRealm = "default"
RealmSep = "/"
)

0 comments on commit 65f073b

Please sign in to comment.