From 0a31f76300dcfbeb50591afc461623fd5d0634b0 Mon Sep 17 00:00:00 2001 From: Denis Gukov Date: Sun, 28 Jan 2024 16:24:37 +0500 Subject: [PATCH 1/2] refactor(vault): getObjects -> getProjectObjects for sql db --- db/sql/SqlDb.go | 16 +++++++++++----- db/sql/access_key.go | 41 ++++++++++++++++++++++++++++++++++++++--- db/sql/environment.go | 2 +- db/sql/inventory.go | 2 +- db/sql/runner.go | 2 +- db/sql/view.go | 2 +- 6 files changed, 53 insertions(+), 12 deletions(-) diff --git a/db/sql/SqlDb.go b/db/sql/SqlDb.go index 6a391f30f..a2a079f52 100644 --- a/db/sql/SqlDb.go +++ b/db/sql/SqlDb.go @@ -209,14 +209,16 @@ func (d *SqlDb) getObject(projectID int, props db.ObjectProps, objectID int, obj return } -func (d *SqlDb) getObjects(projectID int, props db.ObjectProps, params db.RetrieveQueryParams, objects interface{}) (err error) { +func (d *SqlDb) getObjects(projectID int, props db.ObjectProps, params db.RetrieveQueryParams, objects interface{}, ignoreProjectId bool) (err error) { q := squirrel.Select("*"). From(props.TableName + " pe") - if props.IsGlobal { - q = q.Where("pe.project_id is null") - } else { - q = q.Where("pe.project_id=?", projectID) + if !ignoreProjectId { + if props.IsGlobal { + q = q.Where("pe.project_id is null") + } else { + q = q.Where("pe.project_id=?", projectID) + } } orderDirection := "ASC" @@ -244,6 +246,10 @@ func (d *SqlDb) getObjects(projectID int, props db.ObjectProps, params db.Retrie return } +func (d *SqlDb) getProjectObjects(projectID int, props db.ObjectProps, params db.RetrieveQueryParams, objects interface{}) (err error) { + return d.getObjects(projectID, props, params, objects, false) +} + func (d *SqlDb) deleteObject(projectID int, props db.ObjectProps, objectID int) error { if props.IsGlobal { return validateMutationResult( diff --git a/db/sql/access_key.go b/db/sql/access_key.go index 00d361823..8f0c4a13d 100644 --- a/db/sql/access_key.go +++ b/db/sql/access_key.go @@ -17,7 +17,7 @@ func (d *SqlDb) GetAccessKeyRefs(projectID int, keyID int) (db.ObjectReferrers, func (d *SqlDb) GetAccessKeys(projectID int, params db.RetrieveQueryParams) ([]db.AccessKey, error) { var keys []db.AccessKey - err := d.getObjects(projectID, db.AccessKeyProps, params, &keys) + err := d.getProjectObjects(projectID, db.AccessKeyProps, params, &keys) return keys, err } @@ -84,7 +84,42 @@ func (d *SqlDb) DeleteAccessKey(projectID int, accessKeyID int) error { return d.deleteObject(projectID, db.AccessKeyProps, accessKeyID) } -func (d *SqlDb) RekeyAccessKeys(oldKey string) error { +const RekeyBatchSize = 100 - return nil +func (d *SqlDb) RekeyAccessKeys(oldKey string) (err error) { + + var globalProps = db.AccessKeyProps + globalProps.IsGlobal = true + + for i := 0; ; i++ { + + var keys []db.AccessKey + err = d.getObjects(-1, globalProps, db.RetrieveQueryParams{Count: RekeyBatchSize, Offset: i * RekeyBatchSize}, &keys, true) + + if err != nil { + return + } + + if len(keys) == 0 { + break + } + + for _, key := range keys { + + err = key.DeserializeSecret2(oldKey) + + if err != nil { + return err + } + + key.OverrideSecret = true + err = d.UpdateAccessKey(key) + + if err != nil { + return err + } + } + } + + return } diff --git a/db/sql/environment.go b/db/sql/environment.go index 31b745a2a..6ad9f2fbf 100644 --- a/db/sql/environment.go +++ b/db/sql/environment.go @@ -16,7 +16,7 @@ func (d *SqlDb) GetEnvironmentRefs(projectID int, environmentID int) (db.ObjectR func (d *SqlDb) GetEnvironments(projectID int, params db.RetrieveQueryParams) ([]db.Environment, error) { var environment []db.Environment - err := d.getObjects(projectID, db.EnvironmentProps, params, &environment) + err := d.getProjectObjects(projectID, db.EnvironmentProps, params, &environment) return environment, err } diff --git a/db/sql/inventory.go b/db/sql/inventory.go index 7ea3fcd65..656e2d301 100644 --- a/db/sql/inventory.go +++ b/db/sql/inventory.go @@ -14,7 +14,7 @@ func (d *SqlDb) GetInventory(projectID int, inventoryID int) (inventory db.Inven func (d *SqlDb) GetInventories(projectID int, params db.RetrieveQueryParams) ([]db.Inventory, error) { var inventories []db.Inventory - err := d.getObjects(projectID, db.InventoryProps, params, &inventories) + err := d.getProjectObjects(projectID, db.InventoryProps, params, &inventories) return inventories, err } diff --git a/db/sql/runner.go b/db/sql/runner.go index 535ce27a3..1eeb53054 100644 --- a/db/sql/runner.go +++ b/db/sql/runner.go @@ -24,7 +24,7 @@ func (d *SqlDb) GetGlobalRunner(runnerID int) (runner db.Runner, err error) { } func (d *SqlDb) GetGlobalRunners() (runners []db.Runner, err error) { - err = d.getObjects(0, db.GlobalRunnerProps, db.RetrieveQueryParams{}, &runners) + err = d.getProjectObjects(0, db.GlobalRunnerProps, db.RetrieveQueryParams{}, &runners) return } diff --git a/db/sql/view.go b/db/sql/view.go index 691805a72..b677e802e 100644 --- a/db/sql/view.go +++ b/db/sql/view.go @@ -8,7 +8,7 @@ func (d *SqlDb) GetView(projectID int, viewID int) (view db.View, err error) { } func (d *SqlDb) GetViews(projectID int) (views []db.View, err error) { - err = d.getObjects(projectID, db.ViewProps, db.RetrieveQueryParams{}, &views) + err = d.getProjectObjects(projectID, db.ViewProps, db.RetrieveQueryParams{}, &views) return } From 84fdfa4623ae14e4467a94ea606d332c1f932afd Mon Sep 17 00:00:00 2001 From: Denis Gukov Date: Sun, 28 Jan 2024 17:18:07 +0500 Subject: [PATCH 2/2] fix(vault): offset in sql query --- db/sql/SqlDb.go | 8 ++++++++ db/sql/access_key.go | 3 ++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/db/sql/SqlDb.go b/db/sql/SqlDb.go index a2a079f52..313d7a62d 100644 --- a/db/sql/SqlDb.go +++ b/db/sql/SqlDb.go @@ -235,6 +235,14 @@ func (d *SqlDb) getObjects(projectID int, props db.ObjectProps, params db.Retrie q = q.OrderBy("pe." + orderColumn + " " + orderDirection) } + if params.Count > 0 { + q = q.Limit(uint64(params.Count)) + } + + if params.Offset > 0 { + q = q.Offset(uint64(params.Offset)) + } + query, args, err := q.ToSql() if err != nil { diff --git a/db/sql/access_key.go b/db/sql/access_key.go index 8f0c4a13d..cbc504ecf 100644 --- a/db/sql/access_key.go +++ b/db/sql/access_key.go @@ -2,6 +2,7 @@ package sql import ( "database/sql" + "errors" "github.com/ansible-semaphore/semaphore/db" ) @@ -115,7 +116,7 @@ func (d *SqlDb) RekeyAccessKeys(oldKey string) (err error) { key.OverrideSecret = true err = d.UpdateAccessKey(key) - if err != nil { + if err != nil && !errors.Is(err, db.ErrNotFound) { return err } }