From d9e237e3f268d786af028ba8e7cc24f70a6d6a36 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 23 Sep 2021 04:09:29 +0800 Subject: [PATCH 01/13] Fix problem when database id is not increment as expected (#17124) --- models/action.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/action.go b/models/action.go index f957e2b0cd9a..f17f75779bd6 100644 --- a/models/action.go +++ b/models/action.go @@ -317,7 +317,7 @@ func GetFeeds(opts GetFeedsOptions) ([]*Action, error) { actions := make([]*Action, 0, setting.UI.FeedPagingNum) - if err := db.DefaultContext().Engine().Limit(setting.UI.FeedPagingNum).Desc("id").Where(cond).Find(&actions); err != nil { + if err := db.DefaultContext().Engine().Limit(setting.UI.FeedPagingNum).Desc("created_unix").Where(cond).Find(&actions); err != nil { return nil, fmt.Errorf("Find: %v", err) } From d9c69596fff1a1482cbc15ac220f9d5e1829a5ea Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 23 Sep 2021 18:50:06 +0800 Subject: [PATCH 02/13] Fix commit status index problem (#17061) * Fix commit status index problem * remove unused functions * Add fixture and test for migration * Fix lint * Fix fixture * Fix lint * Fix test * Fix bug * Fix bug --- models/commit_status.go | 107 ++++++++++++++++++++---- models/db/index.go | 5 +- models/fixtures/commit_status_index.yml | 5 ++ models/migrations/migrations.go | 2 + models/migrations/v195.go | 47 +++++++++++ models/migrations/v195_test.go | 62 ++++++++++++++ models/pull.go | 2 +- 7 files changed, 211 insertions(+), 19 deletions(-) create mode 100644 models/fixtures/commit_status_index.yml create mode 100644 models/migrations/v195.go create mode 100644 models/migrations/v195_test.go diff --git a/models/commit_status.go b/models/commit_status.go index f3639e819ea7..7ec233e80d02 100644 --- a/models/commit_status.go +++ b/models/commit_status.go @@ -40,6 +40,82 @@ type CommitStatus struct { func init() { db.RegisterModel(new(CommitStatus)) + db.RegisterModel(new(CommitStatusIndex)) +} + +// upsertCommitStatusIndex the function will not return until it acquires the lock or receives an error. +func upsertCommitStatusIndex(e db.Engine, repoID int64, sha string) (err error) { + // An atomic UPSERT operation (INSERT/UPDATE) is the only operation + // that ensures that the key is actually locked. + switch { + case setting.Database.UseSQLite3 || setting.Database.UsePostgreSQL: + _, err = e.Exec("INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+ + "VALUES (?,?,1) ON CONFLICT (repo_id,sha) DO UPDATE SET max_index = `commit_status_index`.max_index+1", + repoID, sha) + case setting.Database.UseMySQL: + _, err = e.Exec("INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+ + "VALUES (?,?,1) ON DUPLICATE KEY UPDATE max_index = max_index+1", + repoID, sha) + case setting.Database.UseMSSQL: + // https://weblogs.sqlteam.com/dang/2009/01/31/upsert-race-condition-with-merge/ + _, err = e.Exec("MERGE `commit_status_index` WITH (HOLDLOCK) as target "+ + "USING (SELECT ? AS repo_id, ? AS sha) AS src "+ + "ON src.repo_id = target.repo_id AND src.sha = target.sha "+ + "WHEN MATCHED THEN UPDATE SET target.max_index = target.max_index+1 "+ + "WHEN NOT MATCHED THEN INSERT (repo_id, sha, max_index) "+ + "VALUES (src.repo_id, src.sha, 1);", + repoID, sha) + default: + return fmt.Errorf("database type not supported") + } + return +} + +// GetNextCommitStatusIndex retried 3 times to generate a resource index +func GetNextCommitStatusIndex(repoID int64, sha string) (int64, error) { + for i := 0; i < db.MaxDupIndexAttempts; i++ { + idx, err := getNextCommitStatusIndex(repoID, sha) + if err == db.ErrResouceOutdated { + continue + } + if err != nil { + return 0, err + } + return idx, nil + } + return 0, db.ErrGetResourceIndexFailed +} + +// getNextCommitStatusIndex return the next index +func getNextCommitStatusIndex(repoID int64, sha string) (int64, error) { + ctx, commiter, err := db.TxContext() + if err != nil { + return 0, err + } + defer commiter.Close() + + var preIdx int64 + _, err = ctx.Engine().SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ?", repoID, sha).Get(&preIdx) + if err != nil { + return 0, err + } + + if err := upsertCommitStatusIndex(ctx.Engine(), repoID, sha); err != nil { + return 0, err + } + + var curIdx int64 + has, err := ctx.Engine().SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ? AND max_index=?", repoID, sha, preIdx+1).Get(&curIdx) + if err != nil { + return 0, err + } + if !has { + return 0, db.ErrResouceOutdated + } + if err := commiter.Commit(); err != nil { + return 0, err + } + return curIdx, nil } func (status *CommitStatus) loadAttributes(e db.Engine) (err error) { @@ -142,6 +218,14 @@ func sortCommitStatusesSession(sess *xorm.Session, sortType string) { } } +// CommitStatusIndex represents a table for commit status index +type CommitStatusIndex struct { + ID int64 + RepoID int64 `xorm:"unique(repo_sha)"` + SHA string `xorm:"unique(repo_sha)"` + MaxIndex int64 `xorm:"index"` +} + // GetLatestCommitStatus returns all statuses with a unique context for a given commit. func GetLatestCommitStatus(repoID int64, sha string, listOptions ListOptions) ([]*CommitStatus, error) { return getLatestCommitStatus(db.DefaultContext().Engine(), repoID, sha, listOptions) @@ -206,6 +290,12 @@ func NewCommitStatus(opts NewCommitStatusOptions) error { return fmt.Errorf("NewCommitStatus[%s, %s]: no user specified", repoPath, opts.SHA) } + // Get the next Status Index + idx, err := GetNextCommitStatusIndex(opts.Repo.ID, opts.SHA) + if err != nil { + return fmt.Errorf("generate commit status index failed: %v", err) + } + ctx, committer, err := db.TxContext() if err != nil { return fmt.Errorf("NewCommitStatus[repo_id: %d, user_id: %d, sha: %s]: %v", opts.Repo.ID, opts.Creator.ID, opts.SHA, err) @@ -218,22 +308,7 @@ func NewCommitStatus(opts NewCommitStatusOptions) error { opts.CommitStatus.SHA = opts.SHA opts.CommitStatus.CreatorID = opts.Creator.ID opts.CommitStatus.RepoID = opts.Repo.ID - - // Get the next Status Index - var nextIndex int64 - lastCommitStatus := &CommitStatus{ - SHA: opts.SHA, - RepoID: opts.Repo.ID, - } - has, err := ctx.Engine().Desc("index").Limit(1).Get(lastCommitStatus) - if err != nil { - return fmt.Errorf("NewCommitStatus[%s, %s]: %v", repoPath, opts.SHA, err) - } - if has { - log.Debug("NewCommitStatus[%s, %s]: found", repoPath, opts.SHA) - nextIndex = lastCommitStatus.Index - } - opts.CommitStatus.Index = nextIndex + 1 + opts.CommitStatus.Index = idx log.Debug("NewCommitStatus[%s, %s]: %d", repoPath, opts.SHA, opts.CommitStatus.Index) opts.CommitStatus.ContextHash = hashCommitStatusContext(opts.CommitStatus.Context) diff --git a/models/db/index.go b/models/db/index.go index 873289db54e1..0086a8f54806 100644 --- a/models/db/index.go +++ b/models/db/index.go @@ -54,12 +54,13 @@ var ( ) const ( - maxDupIndexAttempts = 3 + // MaxDupIndexAttempts max retry times to create index + MaxDupIndexAttempts = 3 ) // GetNextResourceIndex retried 3 times to generate a resource index func GetNextResourceIndex(tableName string, groupID int64) (int64, error) { - for i := 0; i < maxDupIndexAttempts; i++ { + for i := 0; i < MaxDupIndexAttempts; i++ { idx, err := getNextResourceIndex(tableName, groupID) if err == ErrResouceOutdated { continue diff --git a/models/fixtures/commit_status_index.yml b/models/fixtures/commit_status_index.yml new file mode 100644 index 000000000000..3f252e87ef09 --- /dev/null +++ b/models/fixtures/commit_status_index.yml @@ -0,0 +1,5 @@ +- + id: 1 + repo_id: 1 + sha: "1234123412341234123412341234123412341234" + max_index: 5 \ No newline at end of file diff --git a/models/migrations/migrations.go b/models/migrations/migrations.go index fb6958f2da61..3f90e5e74a3f 100644 --- a/models/migrations/migrations.go +++ b/models/migrations/migrations.go @@ -342,6 +342,8 @@ var migrations = []Migration{ NewMigration("Add repo id column for attachment table", addRepoIDForAttachment), // v194 -> v195 NewMigration("Add Branch Protection Unprotected Files Column", addBranchProtectionUnprotectedFilesColumn), + // v196 -> v197 + NewMigration("Add table commit_status_index", addTableCommitStatusIndex), } // GetCurrentDBVersion returns the current db version diff --git a/models/migrations/v195.go b/models/migrations/v195.go new file mode 100644 index 000000000000..06694eb57df9 --- /dev/null +++ b/models/migrations/v195.go @@ -0,0 +1,47 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package migrations + +import ( + "fmt" + + "xorm.io/xorm" +) + +func addTableCommitStatusIndex(x *xorm.Engine) error { + // CommitStatusIndex represents a table for commit status index + type CommitStatusIndex struct { + ID int64 + RepoID int64 `xorm:"unique(repo_sha)"` + SHA string `xorm:"unique(repo_sha)"` + MaxIndex int64 `xorm:"index"` + } + + if err := x.Sync2(new(CommitStatusIndex)); err != nil { + return fmt.Errorf("Sync2: %v", err) + } + + sess := x.NewSession() + defer sess.Close() + + if err := sess.Begin(); err != nil { + return err + } + + // Remove data we're goint to rebuild + if _, err := sess.Table("commit_status_index").Where("1=1").Delete(&CommitStatusIndex{}); err != nil { + return err + } + + // Create current data for all repositories with issues and PRs + if _, err := sess.Exec("INSERT INTO commit_status_index (repo_id, sha, max_index) " + + "SELECT max_data.repo_id, max_data.sha, max_data.max_index " + + "FROM ( SELECT commit_status.repo_id AS repo_id, commit_status.sha AS sha, max(commit_status.`index`) AS max_index " + + "FROM commit_status GROUP BY commit_status.repo_id, commit_status.sha) AS max_data"); err != nil { + return err + } + + return sess.Commit() +} diff --git a/models/migrations/v195_test.go b/models/migrations/v195_test.go new file mode 100644 index 000000000000..baf9cb61c2a0 --- /dev/null +++ b/models/migrations/v195_test.go @@ -0,0 +1,62 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package migrations + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_addTableCommitStatusIndex(t *testing.T) { + // Create the models used in the migration + type CommitStatus struct { + ID int64 `xorm:"pk autoincr"` + Index int64 `xorm:"INDEX UNIQUE(repo_sha_index)"` + RepoID int64 `xorm:"INDEX UNIQUE(repo_sha_index)"` + SHA string `xorm:"VARCHAR(64) NOT NULL INDEX UNIQUE(repo_sha_index)"` + } + + // Prepare and load the testing database + x, deferable := prepareTestEnv(t, 0, new(CommitStatus)) + if x == nil || t.Failed() { + defer deferable() + return + } + defer deferable() + + // Run the migration + if err := addTableCommitStatusIndex(x); err != nil { + assert.NoError(t, err) + return + } + + type CommitStatusIndex struct { + ID int64 + RepoID int64 `xorm:"unique(repo_sha)"` + SHA string `xorm:"unique(repo_sha)"` + MaxIndex int64 `xorm:"index"` + } + + var start = 0 + const batchSize = 1000 + for { + var indexes = make([]CommitStatusIndex, 0, batchSize) + err := x.Table("commit_status_index").Limit(batchSize, start).Find(&indexes) + assert.NoError(t, err) + + for _, idx := range indexes { + var maxIndex int + has, err := x.SQL("SELECT max(`index`) FROM commit_status WHERE repo_id = ? AND sha = ?", idx.RepoID, idx.SHA).Get(&maxIndex) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, maxIndex, idx.MaxIndex) + } + if len(indexes) < batchSize { + break + } + start += len(indexes) + } +} diff --git a/models/pull.go b/models/pull.go index 92516735761d..5cb7b57286f6 100644 --- a/models/pull.go +++ b/models/pull.go @@ -450,7 +450,7 @@ func (pr *PullRequest) SetMerged() (bool, error) { func NewPullRequest(repo *Repository, issue *Issue, labelIDs []int64, uuids []string, pr *PullRequest) (err error) { idx, err := db.GetNextResourceIndex("issue_index", repo.ID) if err != nil { - return fmt.Errorf("generate issue index failed: %v", err) + return fmt.Errorf("generate pull request index failed: %v", err) } issue.Index = idx From b22be7f594401d7bd81196750456ce52185bd391 Mon Sep 17 00:00:00 2001 From: delvh Date: Thu, 23 Sep 2021 14:42:42 +0200 Subject: [PATCH 03/13] Fix typo skipping a migration (#17130) --- models/migrations/migrations.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/migrations/migrations.go b/models/migrations/migrations.go index 3f90e5e74a3f..753ca063d95b 100644 --- a/models/migrations/migrations.go +++ b/models/migrations/migrations.go @@ -342,7 +342,7 @@ var migrations = []Migration{ NewMigration("Add repo id column for attachment table", addRepoIDForAttachment), // v194 -> v195 NewMigration("Add Branch Protection Unprotected Files Column", addBranchProtectionUnprotectedFilesColumn), - // v196 -> v197 + // v195 -> v196 NewMigration("Add table commit_status_index", addTableCommitStatusIndex), } From 9302eba971611601c3ebf6024e22a11c63f4e151 Mon Sep 17 00:00:00 2001 From: zeripath Date: Thu, 23 Sep 2021 16:45:36 +0100 Subject: [PATCH 04/13] DBContext is just a Context (#17100) * DBContext is just a Context This PR removes some of the specialness from the DBContext and makes it context This allows us to simplify the GetEngine code to wrap around any context in future and means that we can change our loadRepo(e Engine) functions to simply take contexts. Signed-off-by: Andrew Thornton * fix unit tests Signed-off-by: Andrew Thornton * another place that needs to set the initial context Signed-off-by: Andrew Thornton * avoid race Signed-off-by: Andrew Thornton * change attachment error Signed-off-by: Andrew Thornton --- models/access.go | 6 +- models/access_test.go | 12 +- models/action.go | 6 +- models/action_list.go | 6 +- models/admin.go | 22 ++-- models/attachment.go | 45 +++---- models/attachment_test.go | 2 +- models/avatar.go | 9 +- models/branches.go | 32 ++--- models/commit_status.go | 12 +- models/consistency.go | 72 +++++----- models/consistency_test.go | 6 +- models/db/context.go | 85 ++++++++++-- models/db/engine.go | 5 + models/db/unit_tests.go | 6 + models/engine_test.go | 2 +- models/external_login_user.go | 18 +-- models/fixture_generation.go | 4 +- models/gpg_key.go | 10 +- models/gpg_key_add.go | 4 +- models/gpg_key_import.go | 2 +- models/gpg_key_verify.go | 4 +- models/issue.go | 147 ++++++++++----------- models/issue_assignees.go | 10 +- models/issue_comment.go | 58 ++++----- models/issue_comment_list.go | 8 +- models/issue_dependency.go | 8 +- models/issue_label.go | 53 ++++---- models/issue_list.go | 14 +- models/issue_lock.go | 2 +- models/issue_milestone.go | 40 +++--- models/issue_milestone_test.go | 12 +- models/issue_reaction.go | 12 +- models/issue_reaction_test.go | 2 +- models/issue_stopwatch.go | 12 +- models/issue_test.go | 14 +- models/issue_tracked_time.go | 16 +-- models/issue_user.go | 11 +- models/issue_user_test.go | 4 +- models/issue_watch.go | 16 +-- models/issue_xref.go | 4 +- models/issue_xref_test.go | 4 +- models/lfs.go | 18 +-- models/lfs_lock.go | 12 +- models/list_options.go | 2 +- models/login_source.go | 28 ++-- models/migrate.go | 12 +- models/notification.go | 36 ++--- models/oauth2.go | 4 +- models/oauth2_application.go | 36 ++--- models/org.go | 62 ++++----- models/org_team.go | 56 ++++---- models/org_test.go | 2 +- models/project.go | 14 +- models/project_board.go | 16 +-- models/project_issue.go | 16 +-- models/protected_tag.go | 10 +- models/pull.go | 40 +++--- models/pull_list.go | 12 +- models/release.go | 43 +++--- models/repo.go | 166 ++++++++++++------------ models/repo_activity.go | 6 +- models/repo_archiver.go | 17 +-- models/repo_avatar.go | 10 +- models/repo_collaboration.go | 18 +-- models/repo_collaboration_test.go | 2 +- models/repo_generate.go | 21 +-- models/repo_indexer.go | 6 +- models/repo_language_stats.go | 8 +- models/repo_list.go | 8 +- models/repo_mirror.go | 10 +- models/repo_permission.go | 16 +-- models/repo_permission_test.go | 8 +- models/repo_pushmirror.go | 14 +- models/repo_redirect.go | 2 +- models/repo_redirect_test.go | 6 +- models/repo_test.go | 2 +- models/repo_transfer.go | 8 +- models/repo_watch.go | 20 +-- models/review.go | 60 ++++----- models/session.go | 14 +- models/ssh_key.go | 24 ++-- models/ssh_key_authorized_keys.go | 6 +- models/ssh_key_authorized_principals.go | 6 +- models/ssh_key_deploy.go | 18 +-- models/ssh_key_principals.go | 4 +- models/star.go | 12 +- models/statistic.go | 32 ++--- models/task.go | 18 +-- models/token.go | 16 +-- models/topic.go | 14 +- models/twofactor.go | 8 +- models/u2f.go | 10 +- models/update.go | 13 +- models/upload.go | 8 +- models/user.go | 86 ++++++------ models/user_avatar.go | 6 +- models/user_follow.go | 6 +- models/user_heatmap.go | 2 +- models/user_mail.go | 28 ++-- models/user_openid.go | 12 +- models/user_redirect.go | 2 +- models/user_redirect_test.go | 6 +- models/user_test.go | 8 +- models/userlist.go | 4 +- models/webhook.go | 40 +++--- modules/doctor/mergebase.go | 2 +- modules/doctor/misc.go | 2 +- modules/repository/adopt.go | 3 +- modules/repository/check.go | 8 +- modules/repository/create.go | 3 +- modules/repository/fork.go | 7 +- modules/repository/generate.go | 8 +- modules/repository/hooks.go | 2 +- modules/repository/init.go | 8 +- modules/repository/repo.go | 2 +- modules/repository/update.go | 5 +- routers/web/org/org_labels.go | 2 +- routers/web/repo/issue_label.go | 2 +- routers/web/repo/repo.go | 8 +- services/attachment/attachment.go | 3 +- services/comments/comments.go | 2 +- services/issue/issue.go | 2 +- services/mirror/mirror_pull.go | 2 +- services/pull/pull.go | 2 +- services/pull/review.go | 4 +- services/release/release.go | 4 +- services/repository/generate.go | 4 +- services/repository/push.go | 2 +- 129 files changed, 1112 insertions(+), 1022 deletions(-) diff --git a/models/access.go b/models/access.go index e230629040b6..88fbe8189fa8 100644 --- a/models/access.go +++ b/models/access.go @@ -116,7 +116,7 @@ func (repoAccess) TableName() string { // GetRepositoryAccesses finds all repositories with their access mode where a user has access but does not own. func (user *User) GetRepositoryAccesses() (map[*Repository]AccessMode, error) { - rows, err := db.DefaultContext().Engine(). + rows, err := db.GetEngine(db.DefaultContext). Join("INNER", "repository", "repository.id = access.repo_id"). Where("access.user_id = ?", user.ID). And("repository.owner_id <> ?", user.ID). @@ -151,7 +151,7 @@ func (user *User) GetRepositoryAccesses() (map[*Repository]AccessMode, error) { // GetAccessibleRepositories finds repositories which the user has access but does not own. // If limit is smaller than 1 means returns all found results. func (user *User) GetAccessibleRepositories(limit int) (repos []*Repository, _ error) { - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Where("owner_id !=? ", user.ID). Desc("updated_unix") if limit > 0 { @@ -342,5 +342,5 @@ func (repo *Repository) recalculateAccesses(e db.Engine) error { // RecalculateAccesses recalculates all accesses for repository. func (repo *Repository) RecalculateAccesses() error { - return repo.recalculateAccesses(db.DefaultContext().Engine()) + return repo.recalculateAccesses(db.GetEngine(db.DefaultContext)) } diff --git a/models/access_test.go b/models/access_test.go index 875b2a0c1ab5..2f641bb9b573 100644 --- a/models/access_test.go +++ b/models/access_test.go @@ -127,12 +127,12 @@ func TestRepository_RecalculateAccesses(t *testing.T) { repo1 := db.AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository) assert.NoError(t, repo1.GetOwner()) - _, err := db.DefaultContext().Engine().Delete(&Collaboration{UserID: 2, RepoID: 3}) + _, err := db.GetEngine(db.DefaultContext).Delete(&Collaboration{UserID: 2, RepoID: 3}) assert.NoError(t, err) assert.NoError(t, repo1.RecalculateAccesses()) access := &Access{UserID: 2, RepoID: 3} - has, err := db.DefaultContext().Engine().Get(access) + has, err := db.GetEngine(db.DefaultContext).Get(access) assert.NoError(t, err) assert.True(t, has) assert.Equal(t, AccessModeOwner, access.Mode) @@ -144,11 +144,11 @@ func TestRepository_RecalculateAccesses2(t *testing.T) { repo1 := db.AssertExistsAndLoadBean(t, &Repository{ID: 4}).(*Repository) assert.NoError(t, repo1.GetOwner()) - _, err := db.DefaultContext().Engine().Delete(&Collaboration{UserID: 4, RepoID: 4}) + _, err := db.GetEngine(db.DefaultContext).Delete(&Collaboration{UserID: 4, RepoID: 4}) assert.NoError(t, err) assert.NoError(t, repo1.RecalculateAccesses()) - has, err := db.DefaultContext().Engine().Get(&Access{UserID: 4, RepoID: 4}) + has, err := db.GetEngine(db.DefaultContext).Get(&Access{UserID: 4, RepoID: 4}) assert.NoError(t, err) assert.False(t, has) } @@ -158,7 +158,7 @@ func TestRepository_RecalculateAccesses3(t *testing.T) { team5 := db.AssertExistsAndLoadBean(t, &Team{ID: 5}).(*Team) user29 := db.AssertExistsAndLoadBean(t, &User{ID: 29}).(*User) - has, err := db.DefaultContext().Engine().Get(&Access{UserID: 29, RepoID: 23}) + has, err := db.GetEngine(db.DefaultContext).Get(&Access{UserID: 29, RepoID: 23}) assert.NoError(t, err) assert.False(t, has) @@ -166,7 +166,7 @@ func TestRepository_RecalculateAccesses3(t *testing.T) { // even though repo 23 is public assert.NoError(t, AddTeamMember(team5, user29.ID)) - has, err = db.DefaultContext().Engine().Get(&Access{UserID: 29, RepoID: 23}) + has, err = db.GetEngine(db.DefaultContext).Get(&Access{UserID: 29, RepoID: 23}) assert.NoError(t, err) assert.True(t, has) } diff --git a/models/action.go b/models/action.go index f17f75779bd6..7c970e1fdb6f 100644 --- a/models/action.go +++ b/models/action.go @@ -208,7 +208,7 @@ func GetRepositoryFromMatch(ownerName, repoName string) (*Repository, error) { // GetCommentLink returns link to action comment. func (a *Action) GetCommentLink() string { - return a.getCommentLink(db.DefaultContext().Engine()) + return a.getCommentLink(db.GetEngine(db.DefaultContext)) } func (a *Action) getCommentLink(e db.Engine) string { @@ -317,7 +317,7 @@ func GetFeeds(opts GetFeedsOptions) ([]*Action, error) { actions := make([]*Action, 0, setting.UI.FeedPagingNum) - if err := db.DefaultContext().Engine().Limit(setting.UI.FeedPagingNum).Desc("created_unix").Where(cond).Find(&actions); err != nil { + if err := db.GetEngine(db.DefaultContext).Limit(setting.UI.FeedPagingNum).Desc("created_unix").Where(cond).Find(&actions); err != nil { return nil, fmt.Errorf("Find: %v", err) } @@ -408,6 +408,6 @@ func DeleteOldActions(olderThan time.Duration) (err error) { return nil } - _, err = db.DefaultContext().Engine().Where("created_unix < ?", time.Now().Add(-olderThan).Unix()).Delete(&Action{}) + _, err = db.GetEngine(db.DefaultContext).Where("created_unix < ?", time.Now().Add(-olderThan).Unix()).Delete(&Action{}) return } diff --git a/models/action_list.go b/models/action_list.go index a0ceef6bee52..69e6aa73121f 100644 --- a/models/action_list.go +++ b/models/action_list.go @@ -45,7 +45,7 @@ func (actions ActionList) loadUsers(e db.Engine) ([]*User, error) { // LoadUsers loads actions' all users func (actions ActionList) LoadUsers() ([]*User, error) { - return actions.loadUsers(db.DefaultContext().Engine()) + return actions.loadUsers(db.GetEngine(db.DefaultContext)) } func (actions ActionList) getRepoIDs() []int64 { @@ -80,7 +80,7 @@ func (actions ActionList) loadRepositories(e db.Engine) ([]*Repository, error) { // LoadRepositories loads actions' all repositories func (actions ActionList) LoadRepositories() ([]*Repository, error) { - return actions.loadRepositories(db.DefaultContext().Engine()) + return actions.loadRepositories(db.GetEngine(db.DefaultContext)) } // loadAttributes loads all attributes @@ -98,5 +98,5 @@ func (actions ActionList) loadAttributes(e db.Engine) (err error) { // LoadAttributes loads attributes of the actions func (actions ActionList) LoadAttributes() error { - return actions.loadAttributes(db.DefaultContext().Engine()) + return actions.loadAttributes(db.GetEngine(db.DefaultContext)) } diff --git a/models/admin.go b/models/admin.go index 084942ff5e4b..27a2032e2cf1 100644 --- a/models/admin.go +++ b/models/admin.go @@ -44,7 +44,7 @@ func (n *Notice) TrStr() string { // CreateNotice creates new system notice. func CreateNotice(tp NoticeType, desc string, args ...interface{}) error { - return createNotice(db.DefaultContext().Engine(), tp, desc, args...) + return createNotice(db.GetEngine(db.DefaultContext), tp, desc, args...) } func createNotice(e db.Engine, tp NoticeType, desc string, args ...interface{}) error { @@ -61,19 +61,19 @@ func createNotice(e db.Engine, tp NoticeType, desc string, args ...interface{}) // CreateRepositoryNotice creates new system notice with type NoticeRepository. func CreateRepositoryNotice(desc string, args ...interface{}) error { - return createNotice(db.DefaultContext().Engine(), NoticeRepository, desc, args...) + return createNotice(db.GetEngine(db.DefaultContext), NoticeRepository, desc, args...) } // RemoveAllWithNotice removes all directories in given path and // creates a system notice when error occurs. func RemoveAllWithNotice(title, path string) { - removeAllWithNotice(db.DefaultContext().Engine(), title, path) + removeAllWithNotice(db.GetEngine(db.DefaultContext), title, path) } // RemoveStorageWithNotice removes a file from the storage and // creates a system notice when error occurs. func RemoveStorageWithNotice(bucket storage.ObjectStorage, title, path string) { - removeStorageWithNotice(db.DefaultContext().Engine(), bucket, title, path) + removeStorageWithNotice(db.GetEngine(db.DefaultContext), bucket, title, path) } func removeStorageWithNotice(e db.Engine, bucket storage.ObjectStorage, title, path string) { @@ -98,14 +98,14 @@ func removeAllWithNotice(e db.Engine, title, path string) { // CountNotices returns number of notices. func CountNotices() int64 { - count, _ := db.DefaultContext().Engine().Count(new(Notice)) + count, _ := db.GetEngine(db.DefaultContext).Count(new(Notice)) return count } // Notices returns notices in given page. func Notices(page, pageSize int) ([]*Notice, error) { notices := make([]*Notice, 0, pageSize) - return notices, db.DefaultContext().Engine(). + return notices, db.GetEngine(db.DefaultContext). Limit(pageSize, (page-1)*pageSize). Desc("id"). Find(¬ices) @@ -113,18 +113,18 @@ func Notices(page, pageSize int) ([]*Notice, error) { // DeleteNotice deletes a system notice by given ID. func DeleteNotice(id int64) error { - _, err := db.DefaultContext().Engine().ID(id).Delete(new(Notice)) + _, err := db.GetEngine(db.DefaultContext).ID(id).Delete(new(Notice)) return err } // DeleteNotices deletes all notices with ID from start to end (inclusive). func DeleteNotices(start, end int64) error { if start == 0 && end == 0 { - _, err := db.DefaultContext().Engine().Exec("DELETE FROM notice") + _, err := db.GetEngine(db.DefaultContext).Exec("DELETE FROM notice") return err } - sess := db.DefaultContext().Engine().Where("id >= ?", start) + sess := db.GetEngine(db.DefaultContext).Where("id >= ?", start) if end > 0 { sess.And("id <= ?", end) } @@ -137,7 +137,7 @@ func DeleteNoticesByIDs(ids []int64) error { if len(ids) == 0 { return nil } - _, err := db.DefaultContext().Engine(). + _, err := db.GetEngine(db.DefaultContext). In("id", ids). Delete(new(Notice)) return err @@ -146,7 +146,7 @@ func DeleteNoticesByIDs(ids []int64) error { // GetAdminUser returns the first administrator func GetAdminUser() (*User, error) { var admin User - has, err := db.DefaultContext().Engine().Where("is_admin=?", true).Get(&admin) + has, err := db.GetEngine(db.DefaultContext).Where("is_admin=?", true).Get(&admin) if err != nil { return nil, err } else if !has { diff --git a/models/attachment.go b/models/attachment.go index 36f318db6db9..f06b389dc697 100644 --- a/models/attachment.go +++ b/models/attachment.go @@ -5,6 +5,7 @@ package models import ( + "context" "fmt" "path" @@ -38,7 +39,7 @@ func init() { // IncreaseDownloadCount is update download count + 1 func (a *Attachment) IncreaseDownloadCount() error { // Update download count. - if _, err := db.DefaultContext().Engine().Exec("UPDATE `attachment` SET download_count=download_count+1 WHERE id=?", a.ID); err != nil { + if _, err := db.GetEngine(db.DefaultContext).Exec("UPDATE `attachment` SET download_count=download_count+1 WHERE id=?", a.ID); err != nil { return fmt.Errorf("increase attachment count: %v", err) } @@ -86,7 +87,7 @@ func (a *Attachment) LinkedRepository() (*Repository, UnitType, error) { // GetAttachmentByID returns attachment by given id func GetAttachmentByID(id int64) (*Attachment, error) { - return getAttachmentByID(db.DefaultContext().Engine(), id) + return getAttachmentByID(db.GetEngine(db.DefaultContext), id) } func getAttachmentByID(e db.Engine, id int64) (*Attachment, error) { @@ -111,8 +112,8 @@ func getAttachmentByUUID(e db.Engine, uuid string) (*Attachment, error) { } // GetAttachmentsByUUIDs returns attachment by given UUID list. -func GetAttachmentsByUUIDs(ctx *db.Context, uuids []string) ([]*Attachment, error) { - return getAttachmentsByUUIDs(ctx.Engine(), uuids) +func GetAttachmentsByUUIDs(ctx context.Context, uuids []string) ([]*Attachment, error) { + return getAttachmentsByUUIDs(db.GetEngine(ctx), uuids) } func getAttachmentsByUUIDs(e db.Engine, uuids []string) ([]*Attachment, error) { @@ -127,17 +128,17 @@ func getAttachmentsByUUIDs(e db.Engine, uuids []string) ([]*Attachment, error) { // GetAttachmentByUUID returns attachment by given UUID. func GetAttachmentByUUID(uuid string) (*Attachment, error) { - return getAttachmentByUUID(db.DefaultContext().Engine(), uuid) + return getAttachmentByUUID(db.GetEngine(db.DefaultContext), uuid) } // ExistAttachmentsByUUID returns true if attachment is exist by given UUID func ExistAttachmentsByUUID(uuid string) (bool, error) { - return db.DefaultContext().Engine().Where("`uuid`=?", uuid).Exist(new(Attachment)) + return db.GetEngine(db.DefaultContext).Where("`uuid`=?", uuid).Exist(new(Attachment)) } // GetAttachmentByReleaseIDFileName returns attachment by given releaseId and fileName. func GetAttachmentByReleaseIDFileName(releaseID int64, fileName string) (*Attachment, error) { - return getAttachmentByReleaseIDFileName(db.DefaultContext().Engine(), releaseID, fileName) + return getAttachmentByReleaseIDFileName(db.GetEngine(db.DefaultContext), releaseID, fileName) } func getAttachmentsByIssueID(e db.Engine, issueID int64) ([]*Attachment, error) { @@ -147,12 +148,12 @@ func getAttachmentsByIssueID(e db.Engine, issueID int64) ([]*Attachment, error) // GetAttachmentsByIssueID returns all attachments of an issue. func GetAttachmentsByIssueID(issueID int64) ([]*Attachment, error) { - return getAttachmentsByIssueID(db.DefaultContext().Engine(), issueID) + return getAttachmentsByIssueID(db.GetEngine(db.DefaultContext), issueID) } // GetAttachmentsByCommentID returns all attachments if comment by given ID. func GetAttachmentsByCommentID(commentID int64) ([]*Attachment, error) { - return getAttachmentsByCommentID(db.DefaultContext().Engine(), commentID) + return getAttachmentsByCommentID(db.GetEngine(db.DefaultContext), commentID) } func getAttachmentsByCommentID(e db.Engine, commentID int64) ([]*Attachment, error) { @@ -174,12 +175,12 @@ func getAttachmentByReleaseIDFileName(e db.Engine, releaseID int64, fileName str // DeleteAttachment deletes the given attachment and optionally the associated file. func DeleteAttachment(a *Attachment, remove bool) error { - _, err := DeleteAttachments(db.DefaultContext(), []*Attachment{a}, remove) + _, err := DeleteAttachments(db.DefaultContext, []*Attachment{a}, remove) return err } // DeleteAttachments deletes the given attachments and optionally the associated files. -func DeleteAttachments(ctx *db.Context, attachments []*Attachment, remove bool) (int, error) { +func DeleteAttachments(ctx context.Context, attachments []*Attachment, remove bool) (int, error) { if len(attachments) == 0 { return 0, nil } @@ -189,7 +190,7 @@ func DeleteAttachments(ctx *db.Context, attachments []*Attachment, remove bool) ids = append(ids, a.ID) } - cnt, err := ctx.Engine().In("id", ids).NoAutoCondition().Delete(attachments[0]) + cnt, err := db.GetEngine(ctx).In("id", ids).NoAutoCondition().Delete(attachments[0]) if err != nil { return 0, err } @@ -211,7 +212,7 @@ func DeleteAttachmentsByIssue(issueID int64, remove bool) (int, error) { return 0, err } - return DeleteAttachments(db.DefaultContext(), attachments, remove) + return DeleteAttachments(db.DefaultContext, attachments, remove) } // DeleteAttachmentsByComment deletes all attachments associated with the given comment. @@ -221,20 +222,20 @@ func DeleteAttachmentsByComment(commentID int64, remove bool) (int, error) { return 0, err } - return DeleteAttachments(db.DefaultContext(), attachments, remove) + return DeleteAttachments(db.DefaultContext, attachments, remove) } // UpdateAttachment updates the given attachment in database func UpdateAttachment(atta *Attachment) error { - return updateAttachment(db.DefaultContext().Engine(), atta) + return updateAttachment(db.GetEngine(db.DefaultContext), atta) } // UpdateAttachmentByUUID Updates attachment via uuid -func UpdateAttachmentByUUID(ctx *db.Context, attach *Attachment, cols ...string) error { +func UpdateAttachmentByUUID(ctx context.Context, attach *Attachment, cols ...string) error { if attach.UUID == "" { - return fmt.Errorf("Attachement uuid should not blank") + return fmt.Errorf("attachment uuid should be not blank") } - _, err := ctx.Engine().Where("uuid=?", attach.UUID).Cols(cols...).Update(attach) + _, err := db.GetEngine(ctx).Where("uuid=?", attach.UUID).Cols(cols...).Update(attach) return err } @@ -252,7 +253,7 @@ func updateAttachment(e db.Engine, atta *Attachment) error { // DeleteAttachmentsByRelease deletes all attachments associated with the given release. func DeleteAttachmentsByRelease(releaseID int64) error { - _, err := db.DefaultContext().Engine().Where("release_id = ?", releaseID).Delete(&Attachment{}) + _, err := db.GetEngine(db.DefaultContext).Where("release_id = ?", releaseID).Delete(&Attachment{}) return err } @@ -262,7 +263,7 @@ func IterateAttachment(f func(attach *Attachment) error) error { const batchSize = 100 for { attachments := make([]*Attachment, 0, batchSize) - if err := db.DefaultContext().Engine().Limit(batchSize, start).Find(&attachments); err != nil { + if err := db.GetEngine(db.DefaultContext).Limit(batchSize, start).Find(&attachments); err != nil { return err } if len(attachments) == 0 { @@ -280,13 +281,13 @@ func IterateAttachment(f func(attach *Attachment) error) error { // CountOrphanedAttachments returns the number of bad attachments func CountOrphanedAttachments() (int64, error) { - return db.DefaultContext().Engine().Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))"). + return db.GetEngine(db.DefaultContext).Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))"). Count(new(Attachment)) } // DeleteOrphanedAttachments delete all bad attachments func DeleteOrphanedAttachments() error { - _, err := db.DefaultContext().Engine().Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))"). + _, err := db.GetEngine(db.DefaultContext).Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))"). Delete(new(Attachment)) return err } diff --git a/models/attachment_test.go b/models/attachment_test.go index 3e8e78a0a319..725d5a40c0b2 100644 --- a/models/attachment_test.go +++ b/models/attachment_test.go @@ -92,7 +92,7 @@ func TestUpdateAttachment(t *testing.T) { func TestGetAttachmentsByUUIDs(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) - attachList, err := GetAttachmentsByUUIDs(db.DefaultContext(), []string{"a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a17", "not-existing-uuid"}) + attachList, err := GetAttachmentsByUUIDs(db.DefaultContext, []string{"a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a17", "not-existing-uuid"}) assert.NoError(t, err) assert.Len(t, attachList, 2) assert.Equal(t, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", attachList[0].UUID) diff --git a/models/avatar.go b/models/avatar.go index e81b876667ab..71b14ad915b5 100644 --- a/models/avatar.go +++ b/models/avatar.go @@ -5,6 +5,7 @@ package models import ( + "context" "crypto/md5" "fmt" "net/url" @@ -64,7 +65,7 @@ func GetEmailForHash(md5Sum string) (string, error) { Hash: strings.ToLower(strings.TrimSpace(md5Sum)), } - _, err := db.DefaultContext().Engine().Get(&emailHash) + _, err := db.GetEngine(db.DefaultContext).Get(&emailHash) return emailHash.Email, err }) } @@ -95,13 +96,13 @@ func HashedAvatarLink(email string, size int) string { Hash: sum, } // OK we're going to open a session just because I think that that might hide away any problems with postgres reporting errors - if err := db.WithTx(func(ctx *db.Context) error { - has, err := ctx.Engine().Where("email = ? AND hash = ?", emailHash.Email, emailHash.Hash).Get(new(EmailHash)) + if err := db.WithTx(func(ctx context.Context) error { + has, err := db.GetEngine(ctx).Where("email = ? AND hash = ?", emailHash.Email, emailHash.Hash).Get(new(EmailHash)) if has || err != nil { // Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time return nil } - _, _ = ctx.Engine().Insert(emailHash) + _, _ = db.GetEngine(ctx).Insert(emailHash) return nil }); err != nil { // Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time diff --git a/models/branches.go b/models/branches.go index 617dfe488d89..8eaa4b6fd74b 100644 --- a/models/branches.go +++ b/models/branches.go @@ -122,7 +122,7 @@ func (protectBranch *ProtectedBranch) IsUserMergeWhitelisted(userID int64, permi // IsUserOfficialReviewer check if user is official reviewer for the branch (counts towards required approvals) func (protectBranch *ProtectedBranch) IsUserOfficialReviewer(user *User) (bool, error) { - return protectBranch.isUserOfficialReviewer(db.DefaultContext().Engine(), user) + return protectBranch.isUserOfficialReviewer(db.GetEngine(db.DefaultContext), user) } func (protectBranch *ProtectedBranch) isUserOfficialReviewer(e db.Engine, user *User) (bool, error) { @@ -162,7 +162,7 @@ func (protectBranch *ProtectedBranch) HasEnoughApprovals(pr *PullRequest) bool { // GetGrantedApprovalsCount returns the number of granted approvals for pr. A granted approval must be authored by a user in an approval whitelist. func (protectBranch *ProtectedBranch) GetGrantedApprovalsCount(pr *PullRequest) int64 { - sess := db.DefaultContext().Engine().Where("issue_id = ?", pr.IssueID). + sess := db.GetEngine(db.DefaultContext).Where("issue_id = ?", pr.IssueID). And("type = ?", ReviewTypeApprove). And("official = ?", true). And("dismissed = ?", false) @@ -183,7 +183,7 @@ func (protectBranch *ProtectedBranch) MergeBlockedByRejectedReview(pr *PullReque if !protectBranch.BlockOnRejectedReviews { return false } - rejectExist, err := db.DefaultContext().Engine().Where("issue_id = ?", pr.IssueID). + rejectExist, err := db.GetEngine(db.DefaultContext).Where("issue_id = ?", pr.IssueID). And("type = ?", ReviewTypeReject). And("official = ?", true). And("dismissed = ?", false). @@ -202,7 +202,7 @@ func (protectBranch *ProtectedBranch) MergeBlockedByOfficialReviewRequests(pr *P if !protectBranch.BlockOnOfficialReviewRequests { return false } - has, err := db.DefaultContext().Engine().Where("issue_id = ?", pr.IssueID). + has, err := db.GetEngine(db.DefaultContext).Where("issue_id = ?", pr.IssueID). And("type = ?", ReviewTypeRequest). And("official = ?", true). Exist(new(Review)) @@ -300,7 +300,7 @@ func (protectBranch *ProtectedBranch) IsUnprotectedFile(patterns []glob.Glob, pa // GetProtectedBranchBy getting protected branch by ID/Name func GetProtectedBranchBy(repoID int64, branchName string) (*ProtectedBranch, error) { - return getProtectedBranchBy(db.DefaultContext().Engine(), repoID, branchName) + return getProtectedBranchBy(db.GetEngine(db.DefaultContext), repoID, branchName) } func getProtectedBranchBy(e db.Engine, repoID int64, branchName string) (*ProtectedBranch, error) { @@ -375,13 +375,13 @@ func UpdateProtectBranch(repo *Repository, protectBranch *ProtectedBranch, opts // Make sure protectBranch.ID is not 0 for whitelists if protectBranch.ID == 0 { - if _, err = db.DefaultContext().Engine().Insert(protectBranch); err != nil { + if _, err = db.GetEngine(db.DefaultContext).Insert(protectBranch); err != nil { return fmt.Errorf("Insert: %v", err) } return nil } - if _, err = db.DefaultContext().Engine().ID(protectBranch.ID).AllCols().Update(protectBranch); err != nil { + if _, err = db.GetEngine(db.DefaultContext).ID(protectBranch.ID).AllCols().Update(protectBranch); err != nil { return fmt.Errorf("Update: %v", err) } @@ -391,7 +391,7 @@ func UpdateProtectBranch(repo *Repository, protectBranch *ProtectedBranch, opts // GetProtectedBranches get all protected branches func (repo *Repository) GetProtectedBranches() ([]*ProtectedBranch, error) { protectedBranches := make([]*ProtectedBranch, 0) - return protectedBranches, db.DefaultContext().Engine().Find(&protectedBranches, &ProtectedBranch{RepoID: repo.ID}) + return protectedBranches, db.GetEngine(db.DefaultContext).Find(&protectedBranches, &ProtectedBranch{RepoID: repo.ID}) } // GetBranchProtection get the branch protection of a branch @@ -406,7 +406,7 @@ func (repo *Repository) IsProtectedBranch(branchName string) (bool, error) { BranchName: branchName, } - has, err := db.DefaultContext().Engine().Exist(protectedBranch) + has, err := db.GetEngine(db.DefaultContext).Exist(protectedBranch) if err != nil { return true, err } @@ -493,7 +493,7 @@ func (repo *Repository) DeleteProtectedBranch(id int64) (err error) { ID: id, } - if affected, err := db.DefaultContext().Engine().Delete(protectedBranch); err != nil { + if affected, err := db.GetEngine(db.DefaultContext).Delete(protectedBranch); err != nil { return err } else if affected != 1 { return fmt.Errorf("delete protected branch ID(%v) failed", id) @@ -522,20 +522,20 @@ func (repo *Repository) AddDeletedBranch(branchName, commit string, deletedByID DeletedByID: deletedByID, } - _, err := db.DefaultContext().Engine().InsertOne(deletedBranch) + _, err := db.GetEngine(db.DefaultContext).InsertOne(deletedBranch) return err } // GetDeletedBranches returns all the deleted branches func (repo *Repository) GetDeletedBranches() ([]*DeletedBranch, error) { deletedBranches := make([]*DeletedBranch, 0) - return deletedBranches, db.DefaultContext().Engine().Where("repo_id = ?", repo.ID).Desc("deleted_unix").Find(&deletedBranches) + return deletedBranches, db.GetEngine(db.DefaultContext).Where("repo_id = ?", repo.ID).Desc("deleted_unix").Find(&deletedBranches) } // GetDeletedBranchByID get a deleted branch by its ID func (repo *Repository) GetDeletedBranchByID(id int64) (*DeletedBranch, error) { deletedBranch := &DeletedBranch{} - has, err := db.DefaultContext().Engine().ID(id).Get(deletedBranch) + has, err := db.GetEngine(db.DefaultContext).ID(id).Get(deletedBranch) if err != nil { return nil, err } @@ -552,7 +552,7 @@ func (repo *Repository) RemoveDeletedBranch(id int64) (err error) { ID: id, } - if affected, err := db.DefaultContext().Engine().Delete(deletedBranch); err != nil { + if affected, err := db.GetEngine(db.DefaultContext).Delete(deletedBranch); err != nil { return err } else if affected != 1 { return fmt.Errorf("remove deleted branch ID(%v) failed", id) @@ -573,7 +573,7 @@ func (deletedBranch *DeletedBranch) LoadUser() { // RemoveDeletedBranch removes all deleted branches func RemoveDeletedBranch(repoID int64, branch string) error { - _, err := db.DefaultContext().Engine().Where("repo_id=? AND name=?", repoID, branch).Delete(new(DeletedBranch)) + _, err := db.GetEngine(db.DefaultContext).Where("repo_id=? AND name=?", repoID, branch).Delete(new(DeletedBranch)) return err } @@ -583,7 +583,7 @@ func RemoveOldDeletedBranches(ctx context.Context, olderThan time.Duration) { log.Trace("Doing: DeletedBranchesCleanup") deleteBefore := time.Now().Add(-olderThan) - _, err := db.DefaultContext().Engine().Where("deleted_unix < ?", deleteBefore.Unix()).Delete(new(DeletedBranch)) + _, err := db.GetEngine(db.DefaultContext).Where("deleted_unix < ?", deleteBefore.Unix()).Delete(new(DeletedBranch)) if err != nil { log.Error("DeletedBranchesCleanup: %v", err) } diff --git a/models/commit_status.go b/models/commit_status.go index 7ec233e80d02..ada94667cccc 100644 --- a/models/commit_status.go +++ b/models/commit_status.go @@ -136,7 +136,7 @@ func (status *CommitStatus) loadAttributes(e db.Engine) (err error) { // APIURL returns the absolute APIURL to this commit-status. func (status *CommitStatus) APIURL() string { - _ = status.loadAttributes(db.DefaultContext().Engine()) + _ = status.loadAttributes(db.GetEngine(db.DefaultContext)) return fmt.Sprintf("%sapi/v1/repos/%s/statuses/%s", setting.AppURL, status.Repo.FullName(), status.SHA) } @@ -193,7 +193,7 @@ func GetCommitStatuses(repo *Repository, sha string, opts *CommitStatusOptions) } func listCommitStatusesStatement(repo *Repository, sha string, opts *CommitStatusOptions) *xorm.Session { - sess := db.DefaultContext().Engine().Where("repo_id = ?", repo.ID).And("sha = ?", sha) + sess := db.GetEngine(db.DefaultContext).Where("repo_id = ?", repo.ID).And("sha = ?", sha) switch opts.State { case "pending", "success", "error", "failure", "warning": sess.And("state = ?", opts.State) @@ -228,7 +228,7 @@ type CommitStatusIndex struct { // GetLatestCommitStatus returns all statuses with a unique context for a given commit. func GetLatestCommitStatus(repoID int64, sha string, listOptions ListOptions) ([]*CommitStatus, error) { - return getLatestCommitStatus(db.DefaultContext().Engine(), repoID, sha, listOptions) + return getLatestCommitStatus(db.GetEngine(db.DefaultContext), repoID, sha, listOptions) } func getLatestCommitStatus(e db.Engine, repoID int64, sha string, listOptions ListOptions) ([]*CommitStatus, error) { @@ -255,7 +255,7 @@ func getLatestCommitStatus(e db.Engine, repoID int64, sha string, listOptions Li func FindRepoRecentCommitStatusContexts(repoID int64, before time.Duration) ([]string, error) { start := timeutil.TimeStampNow().AddDuration(-before) ids := make([]int64, 0, 10) - if err := db.DefaultContext().Engine().Table("commit_status"). + if err := db.GetEngine(db.DefaultContext).Table("commit_status"). Where("repo_id = ?", repoID). And("updated_unix >= ?", start). Select("max( id ) as id"). @@ -268,7 +268,7 @@ func FindRepoRecentCommitStatusContexts(repoID int64, before time.Duration) ([]s if len(ids) == 0 { return contexts, nil } - return contexts, db.DefaultContext().Engine().Select("context").Table("commit_status").In("id", ids).Find(&contexts) + return contexts, db.GetEngine(db.DefaultContext).Select("context").Table("commit_status").In("id", ids).Find(&contexts) } // NewCommitStatusOptions holds options for creating a CommitStatus @@ -314,7 +314,7 @@ func NewCommitStatus(opts NewCommitStatusOptions) error { opts.CommitStatus.ContextHash = hashCommitStatusContext(opts.CommitStatus.Context) // Insert new CommitStatus - if _, err = ctx.Engine().Insert(opts.CommitStatus); err != nil { + if _, err = db.GetEngine(ctx).Insert(opts.CommitStatus); err != nil { return fmt.Errorf("Insert CommitStatus[%s, %s]: %v", repoPath, opts.SHA, err) } diff --git a/models/consistency.go b/models/consistency.go index cc02e32de441..8af884365ebc 100644 --- a/models/consistency.go +++ b/models/consistency.go @@ -41,7 +41,7 @@ func CheckConsistencyFor(t *testing.T, beansToCheck ...interface{}) { ptrToSliceValue := reflect.New(sliceType) ptrToSliceValue.Elem().Set(sliceValue) - assert.NoError(t, db.DefaultContext().Engine().Table(bean).Find(ptrToSliceValue.Interface())) + assert.NoError(t, db.GetEngine(db.DefaultContext).Table(bean).Find(ptrToSliceValue.Interface())) sliceValue = ptrToSliceValue.Elem() for i := 0; i < sliceValue.Len(); i++ { @@ -66,7 +66,7 @@ func getCount(t *testing.T, e db.Engine, bean interface{}) int64 { // assertCount test the count of database entries matching bean func assertCount(t *testing.T, bean interface{}, expected int) { - assert.EqualValues(t, expected, getCount(t, db.DefaultContext().Engine(), bean), + assert.EqualValues(t, expected, getCount(t, db.GetEngine(db.DefaultContext), bean), "Failed consistency test, the counted bean (of type %T) was %+v", bean, bean) } @@ -92,33 +92,33 @@ func (repo *Repository) checkForConsistency(t *testing.T) { db.AssertExistsAndLoadBean(t, &Repository{ID: repo.ForkID}) } - actual := getCount(t, db.DefaultContext().Engine().Where("Mode<>?", RepoWatchModeDont), &Watch{RepoID: repo.ID}) + actual := getCount(t, db.GetEngine(db.DefaultContext).Where("Mode<>?", RepoWatchModeDont), &Watch{RepoID: repo.ID}) assert.EqualValues(t, repo.NumWatches, actual, "Unexpected number of watches for repo %+v", repo) - actual = getCount(t, db.DefaultContext().Engine().Where("is_pull=?", false), &Issue{RepoID: repo.ID}) + actual = getCount(t, db.GetEngine(db.DefaultContext).Where("is_pull=?", false), &Issue{RepoID: repo.ID}) assert.EqualValues(t, repo.NumIssues, actual, "Unexpected number of issues for repo %+v", repo) - actual = getCount(t, db.DefaultContext().Engine().Where("is_pull=? AND is_closed=?", false, true), &Issue{RepoID: repo.ID}) + actual = getCount(t, db.GetEngine(db.DefaultContext).Where("is_pull=? AND is_closed=?", false, true), &Issue{RepoID: repo.ID}) assert.EqualValues(t, repo.NumClosedIssues, actual, "Unexpected number of closed issues for repo %+v", repo) - actual = getCount(t, db.DefaultContext().Engine().Where("is_pull=?", true), &Issue{RepoID: repo.ID}) + actual = getCount(t, db.GetEngine(db.DefaultContext).Where("is_pull=?", true), &Issue{RepoID: repo.ID}) assert.EqualValues(t, repo.NumPulls, actual, "Unexpected number of pulls for repo %+v", repo) - actual = getCount(t, db.DefaultContext().Engine().Where("is_pull=? AND is_closed=?", true, true), &Issue{RepoID: repo.ID}) + actual = getCount(t, db.GetEngine(db.DefaultContext).Where("is_pull=? AND is_closed=?", true, true), &Issue{RepoID: repo.ID}) assert.EqualValues(t, repo.NumClosedPulls, actual, "Unexpected number of closed pulls for repo %+v", repo) - actual = getCount(t, db.DefaultContext().Engine().Where("is_closed=?", true), &Milestone{RepoID: repo.ID}) + actual = getCount(t, db.GetEngine(db.DefaultContext).Where("is_closed=?", true), &Milestone{RepoID: repo.ID}) assert.EqualValues(t, repo.NumClosedMilestones, actual, "Unexpected number of closed milestones for repo %+v", repo) } func (issue *Issue) checkForConsistency(t *testing.T) { - actual := getCount(t, db.DefaultContext().Engine().Where("type=?", CommentTypeComment), &Comment{IssueID: issue.ID}) + actual := getCount(t, db.GetEngine(db.DefaultContext).Where("type=?", CommentTypeComment), &Comment{IssueID: issue.ID}) assert.EqualValues(t, issue.NumComments, actual, "Unexpected number of comments for issue %+v", issue) if issue.IsPull { @@ -136,7 +136,7 @@ func (pr *PullRequest) checkForConsistency(t *testing.T) { func (milestone *Milestone) checkForConsistency(t *testing.T) { assertCount(t, &Issue{MilestoneID: milestone.ID}, milestone.NumIssues) - actual := getCount(t, db.DefaultContext().Engine().Where("is_closed=?", true), &Issue{MilestoneID: milestone.ID}) + actual := getCount(t, db.GetEngine(db.DefaultContext).Where("is_closed=?", true), &Issue{MilestoneID: milestone.ID}) assert.EqualValues(t, milestone.NumClosedIssues, actual, "Unexpected number of closed issues for milestone %+v", milestone) @@ -149,7 +149,7 @@ func (milestone *Milestone) checkForConsistency(t *testing.T) { func (label *Label) checkForConsistency(t *testing.T) { issueLabels := make([]*IssueLabel, 0, 10) - assert.NoError(t, db.DefaultContext().Engine().Find(&issueLabels, &IssueLabel{LabelID: label.ID})) + assert.NoError(t, db.GetEngine(db.DefaultContext).Find(&issueLabels, &IssueLabel{LabelID: label.ID})) assert.EqualValues(t, label.NumIssues, len(issueLabels), "Unexpected number of issue for label %+v", label) @@ -160,7 +160,7 @@ func (label *Label) checkForConsistency(t *testing.T) { expected := int64(0) if len(issueIDs) > 0 { - expected = getCount(t, db.DefaultContext().Engine().In("id", issueIDs).Where("is_closed=?", true), &Issue{}) + expected = getCount(t, db.GetEngine(db.DefaultContext).In("id", issueIDs).Where("is_closed=?", true), &Issue{}) } assert.EqualValues(t, expected, label.NumClosedIssues, "Unexpected number of closed issues for label %+v", label) @@ -178,12 +178,12 @@ func (action *Action) checkForConsistency(t *testing.T) { // CountOrphanedLabels return count of labels witch are broken and not accessible via ui anymore func CountOrphanedLabels() (int64, error) { - noref, err := db.DefaultContext().Engine().Table("label").Where("repo_id=? AND org_id=?", 0, 0).Count("label.id") + noref, err := db.GetEngine(db.DefaultContext).Table("label").Where("repo_id=? AND org_id=?", 0, 0).Count("label.id") if err != nil { return 0, err } - norepo, err := db.DefaultContext().Engine().Table("label"). + norepo, err := db.GetEngine(db.DefaultContext).Table("label"). Where(builder.And( builder.Gt{"repo_id": 0}, builder.NotIn("repo_id", builder.Select("id").From("repository")), @@ -193,7 +193,7 @@ func CountOrphanedLabels() (int64, error) { return 0, err } - noorg, err := db.DefaultContext().Engine().Table("label"). + noorg, err := db.GetEngine(db.DefaultContext).Table("label"). Where(builder.And( builder.Gt{"org_id": 0}, builder.NotIn("org_id", builder.Select("id").From("user")), @@ -209,12 +209,12 @@ func CountOrphanedLabels() (int64, error) { // DeleteOrphanedLabels delete labels witch are broken and not accessible via ui anymore func DeleteOrphanedLabels() error { // delete labels with no reference - if _, err := db.DefaultContext().Engine().Table("label").Where("repo_id=? AND org_id=?", 0, 0).Delete(new(Label)); err != nil { + if _, err := db.GetEngine(db.DefaultContext).Table("label").Where("repo_id=? AND org_id=?", 0, 0).Delete(new(Label)); err != nil { return err } // delete labels with none existing repos - if _, err := db.DefaultContext().Engine(). + if _, err := db.GetEngine(db.DefaultContext). Where(builder.And( builder.Gt{"repo_id": 0}, builder.NotIn("repo_id", builder.Select("id").From("repository")), @@ -224,7 +224,7 @@ func DeleteOrphanedLabels() error { } // delete labels with none existing orgs - if _, err := db.DefaultContext().Engine(). + if _, err := db.GetEngine(db.DefaultContext). Where(builder.And( builder.Gt{"org_id": 0}, builder.NotIn("org_id", builder.Select("id").From("user")), @@ -238,14 +238,14 @@ func DeleteOrphanedLabels() error { // CountOrphanedIssueLabels return count of IssueLabels witch have no label behind anymore func CountOrphanedIssueLabels() (int64, error) { - return db.DefaultContext().Engine().Table("issue_label"). + return db.GetEngine(db.DefaultContext).Table("issue_label"). NotIn("label_id", builder.Select("id").From("label")). Count() } // DeleteOrphanedIssueLabels delete IssueLabels witch have no label behind anymore func DeleteOrphanedIssueLabels() error { - _, err := db.DefaultContext().Engine(). + _, err := db.GetEngine(db.DefaultContext). NotIn("label_id", builder.Select("id").From("label")). Delete(IssueLabel{}) @@ -254,7 +254,7 @@ func DeleteOrphanedIssueLabels() error { // CountOrphanedIssues count issues without a repo func CountOrphanedIssues() (int64, error) { - return db.DefaultContext().Engine().Table("issue"). + return db.GetEngine(db.DefaultContext).Table("issue"). Join("LEFT", "repository", "issue.repo_id=repository.id"). Where(builder.IsNull{"repository.id"}). Count("id") @@ -270,7 +270,7 @@ func DeleteOrphanedIssues() error { var ids []int64 - if err := ctx.Engine().Table("issue").Distinct("issue.repo_id"). + if err := db.GetEngine(ctx).Table("issue").Distinct("issue.repo_id"). Join("LEFT", "repository", "issue.repo_id=repository.id"). Where(builder.IsNull{"repository.id"}).GroupBy("issue.repo_id"). Find(&ids); err != nil { @@ -279,7 +279,7 @@ func DeleteOrphanedIssues() error { var attachmentPaths []string for i := range ids { - paths, err := deleteIssuesByRepoID(ctx.Engine(), ids[i]) + paths, err := deleteIssuesByRepoID(db.GetEngine(ctx), ids[i]) if err != nil { return err } @@ -293,14 +293,14 @@ func DeleteOrphanedIssues() error { // Remove issue attachment files. for i := range attachmentPaths { - removeAllWithNotice(db.DefaultContext().Engine(), "Delete issue attachment", attachmentPaths[i]) + removeAllWithNotice(db.GetEngine(db.DefaultContext), "Delete issue attachment", attachmentPaths[i]) } return nil } // CountOrphanedObjects count subjects with have no existing refobject anymore func CountOrphanedObjects(subject, refobject, joinCond string) (int64, error) { - return db.DefaultContext().Engine().Table("`"+subject+"`"). + return db.GetEngine(db.DefaultContext).Table("`"+subject+"`"). Join("LEFT", refobject, joinCond). Where(builder.IsNull{"`" + refobject + "`.id"}). Count("id") @@ -316,45 +316,45 @@ func DeleteOrphanedObjects(subject, refobject, joinCond string) error { if err != nil { return err } - _, err = db.DefaultContext().Engine().Exec(append([]interface{}{sql}, args...)...) + _, err = db.GetEngine(db.DefaultContext).Exec(append([]interface{}{sql}, args...)...) return err } // CountNullArchivedRepository counts the number of repositories with is_archived is null func CountNullArchivedRepository() (int64, error) { - return db.DefaultContext().Engine().Where(builder.IsNull{"is_archived"}).Count(new(Repository)) + return db.GetEngine(db.DefaultContext).Where(builder.IsNull{"is_archived"}).Count(new(Repository)) } // FixNullArchivedRepository sets is_archived to false where it is null func FixNullArchivedRepository() (int64, error) { - return db.DefaultContext().Engine().Where(builder.IsNull{"is_archived"}).Cols("is_archived").NoAutoTime().Update(&Repository{ + return db.GetEngine(db.DefaultContext).Where(builder.IsNull{"is_archived"}).Cols("is_archived").NoAutoTime().Update(&Repository{ IsArchived: false, }) } // CountWrongUserType count OrgUser who have wrong type func CountWrongUserType() (int64, error) { - return db.DefaultContext().Engine().Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Count(new(User)) + return db.GetEngine(db.DefaultContext).Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Count(new(User)) } // FixWrongUserType fix OrgUser who have wrong type func FixWrongUserType() (int64, error) { - return db.DefaultContext().Engine().Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Cols("type").NoAutoTime().Update(&User{Type: 1}) + return db.GetEngine(db.DefaultContext).Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Cols("type").NoAutoTime().Update(&User{Type: 1}) } // CountCommentTypeLabelWithEmptyLabel count label comments with empty label func CountCommentTypeLabelWithEmptyLabel() (int64, error) { - return db.DefaultContext().Engine().Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Count(new(Comment)) + return db.GetEngine(db.DefaultContext).Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Count(new(Comment)) } // FixCommentTypeLabelWithEmptyLabel count label comments with empty label func FixCommentTypeLabelWithEmptyLabel() (int64, error) { - return db.DefaultContext().Engine().Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Delete(new(Comment)) + return db.GetEngine(db.DefaultContext).Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Delete(new(Comment)) } // CountCommentTypeLabelWithOutsideLabels count label comments with outside label func CountCommentTypeLabelWithOutsideLabels() (int64, error) { - return db.DefaultContext().Engine().Where("comment.type = ? AND ((label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id))", CommentTypeLabel). + return db.GetEngine(db.DefaultContext).Where("comment.type = ? AND ((label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id))", CommentTypeLabel). Table("comment"). Join("inner", "label", "label.id = comment.label_id"). Join("inner", "issue", "issue.id = comment.issue_id "). @@ -364,7 +364,7 @@ func CountCommentTypeLabelWithOutsideLabels() (int64, error) { // FixCommentTypeLabelWithOutsideLabels count label comments with outside label func FixCommentTypeLabelWithOutsideLabels() (int64, error) { - res, err := db.DefaultContext().Engine().Exec(`DELETE FROM comment WHERE comment.id IN ( + res, err := db.GetEngine(db.DefaultContext).Exec(`DELETE FROM comment WHERE comment.id IN ( SELECT il_too.id FROM ( SELECT com.id FROM comment AS com @@ -383,7 +383,7 @@ func FixCommentTypeLabelWithOutsideLabels() (int64, error) { // CountIssueLabelWithOutsideLabels count label comments with outside label func CountIssueLabelWithOutsideLabels() (int64, error) { - return db.DefaultContext().Engine().Where(builder.Expr("(label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id)")). + return db.GetEngine(db.DefaultContext).Where(builder.Expr("(label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id)")). Table("issue_label"). Join("inner", "label", "issue_label.label_id = label.id "). Join("inner", "issue", "issue.id = issue_label.issue_id "). @@ -393,7 +393,7 @@ func CountIssueLabelWithOutsideLabels() (int64, error) { // FixIssueLabelWithOutsideLabels fix label comments with outside label func FixIssueLabelWithOutsideLabels() (int64, error) { - res, err := db.DefaultContext().Engine().Exec(`DELETE FROM issue_label WHERE issue_label.id IN ( + res, err := db.GetEngine(db.DefaultContext).Exec(`DELETE FROM issue_label WHERE issue_label.id IN ( SELECT il_too.id FROM ( SELECT il_too_too.id FROM issue_label AS il_too_too diff --git a/models/consistency_test.go b/models/consistency_test.go index 25acb467b61e..8332b5d76191 100644 --- a/models/consistency_test.go +++ b/models/consistency_test.go @@ -14,10 +14,10 @@ import ( func TestDeleteOrphanedObjects(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) - countBefore, err := db.DefaultContext().Engine().Count(&PullRequest{}) + countBefore, err := db.GetEngine(db.DefaultContext).Count(&PullRequest{}) assert.NoError(t, err) - _, err = db.DefaultContext().Engine().Insert(&PullRequest{IssueID: 1000}, &PullRequest{IssueID: 1001}, &PullRequest{IssueID: 1003}) + _, err = db.GetEngine(db.DefaultContext).Insert(&PullRequest{IssueID: 1000}, &PullRequest{IssueID: 1001}, &PullRequest{IssueID: 1003}) assert.NoError(t, err) orphaned, err := CountOrphanedObjects("pull_request", "issue", "pull_request.issue_id=issue.id") @@ -27,7 +27,7 @@ func TestDeleteOrphanedObjects(t *testing.T) { err = DeleteOrphanedObjects("pull_request", "issue", "pull_request.issue_id=issue.id") assert.NoError(t, err) - countAfter, err := db.DefaultContext().Engine().Count(&PullRequest{}) + countAfter, err := db.GetEngine(db.DefaultContext).Count(&PullRequest{}) assert.NoError(t, err) assert.EqualValues(t, countBefore, countAfter) } diff --git a/models/db/context.go b/models/db/context.go index 9b1a3010b6fc..0037bb198dde 100644 --- a/models/db/context.go +++ b/models/db/context.go @@ -5,14 +5,29 @@ package db import ( + "context" + "code.gitea.io/gitea/modules/setting" "xorm.io/builder" "xorm.io/xorm" ) +// DefaultContext is the default context to run xorm queries in +// will be overwritten by Init with HammerContext +var DefaultContext context.Context + +// contextKey is a value for use with context.WithValue. +type contextKey struct { + name string +} + +// EnginedContextKey is a context key. It is used with context.Value() to get the current Engined for the context +var EnginedContextKey = &contextKey{"engined"} + // Context represents a db context type Context struct { + context.Context e Engine } @@ -30,9 +45,48 @@ func (ctx *Context) NewSession() *xorm.Session { return nil } -// DefaultContext represents a Context with default Engine -func DefaultContext() *Context { - return &Context{x} +// Value shadows Value for context.Context but allows us to get ourselves and an Engined object +func (ctx *Context) Value(key interface{}) interface{} { + if key == EnginedContextKey { + return ctx + } + return ctx.Context.Value(key) +} + +// Engined structs provide an Engine +type Engined interface { + Engine() Engine + NewSession() *xorm.Session +} + +// GetEngine will get a db Engine from this context or return an Engine restricted to this context +func GetEngine(ctx context.Context) Engine { + if engined, ok := ctx.(Engined); ok { + return engined.Engine() + } + enginedInterface := ctx.Value(EnginedContextKey) + if enginedInterface != nil { + return enginedInterface.(Engined).Engine() + } + return x.Context(ctx) +} + +// NewSession will get a db Session from this context or return a session restricted to this context +func NewSession(ctx context.Context) *xorm.Session { + if engined, ok := ctx.(Engined); ok { + return engined.NewSession() + } + + enginedInterface := ctx.Value(EnginedContextKey) + if enginedInterface != nil { + sess := enginedInterface.(Engined).NewSession() + if sess != nil { + return sess.Context(ctx) + } + return nil + } + + return x.NewSession().Context(ctx) } // Committer represents an interface to Commit or Close the Context @@ -49,23 +103,32 @@ func TxContext() (*Context, Committer, error) { return nil, nil, err } - return &Context{sess}, sess, nil + return &Context{ + Context: DefaultContext, + e: sess, + }, sess, nil } // WithContext represents executing database operations func WithContext(f func(ctx *Context) error) error { - return f(&Context{x}) + return f(&Context{ + Context: DefaultContext, + e: x, + }) } // WithTx represents executing database operations on a transaction -func WithTx(f func(ctx *Context) error) error { +func WithTx(f func(ctx context.Context) error) error { sess := x.NewSession() defer sess.Close() if err := sess.Begin(); err != nil { return err } - if err := f(&Context{sess}); err != nil { + if err := f(&Context{ + Context: DefaultContext, + e: sess, + }); err != nil { return err } @@ -73,14 +136,14 @@ func WithTx(f func(ctx *Context) error) error { } // Iterate iterates the databases and doing something -func Iterate(ctx *Context, tableBean interface{}, cond builder.Cond, fun func(idx int, bean interface{}) error) error { - return ctx.e.Where(cond). +func Iterate(ctx context.Context, tableBean interface{}, cond builder.Cond, fun func(idx int, bean interface{}) error) error { + return GetEngine(ctx).Where(cond). BufferSize(setting.Database.IterateBufferSize). Iterate(tableBean, fun) } // Insert inserts records into database -func Insert(ctx *Context, beans ...interface{}) error { - _, err := ctx.e.Insert(beans...) +func Insert(ctx context.Context, beans ...interface{}) error { + _, err := GetEngine(ctx).Insert(beans...) return err } diff --git a/models/db/engine.go b/models/db/engine.go index e71b6fdc59a8..256eb2f3fc72 100755 --- a/models/db/engine.go +++ b/models/db/engine.go @@ -169,6 +169,11 @@ func NewEngine(ctx context.Context, migrateFunc func(*xorm.Engine) error) (err e return err } + DefaultContext = &Context{ + Context: ctx, + e: x, + } + x.SetDefaultContext(ctx) if err = x.Ping(); err != nil { diff --git a/models/db/unit_tests.go b/models/db/unit_tests.go index 0540c9ba8ac0..781f0ecca20f 100644 --- a/models/db/unit_tests.go +++ b/models/db/unit_tests.go @@ -5,6 +5,7 @@ package db import ( + "context" "fmt" "math" "net/url" @@ -117,6 +118,11 @@ func CreateTestEngine(fixturesDir string) error { x.ShowSQL(true) } + DefaultContext = &Context{ + Context: context.Background(), + e: x, + } + return InitFixtures(fixturesDir) } diff --git a/models/engine_test.go b/models/engine_test.go index 66e3ef099f01..d97fc3cc197a 100644 --- a/models/engine_test.go +++ b/models/engine_test.go @@ -25,7 +25,7 @@ func TestDumpDatabase(t *testing.T) { ID int64 `xorm:"pk autoincr"` Version int64 } - assert.NoError(t, db.DefaultContext().Engine().Sync2(new(Version))) + assert.NoError(t, db.GetEngine(db.DefaultContext).Sync2(new(Version))) for _, dbName := range setting.SupportedDatabases { dbType := setting.GetDBTypeByName(dbName) diff --git a/models/external_login_user.go b/models/external_login_user.go index 5bfdad69935f..c6a4b71b53e2 100644 --- a/models/external_login_user.go +++ b/models/external_login_user.go @@ -41,13 +41,13 @@ func init() { // GetExternalLogin checks if a externalID in loginSourceID scope already exists func GetExternalLogin(externalLoginUser *ExternalLoginUser) (bool, error) { - return db.DefaultContext().Engine().Get(externalLoginUser) + return db.GetEngine(db.DefaultContext).Get(externalLoginUser) } // ListAccountLinks returns a map with the ExternalLoginUser and its LoginSource func ListAccountLinks(user *User) ([]*ExternalLoginUser, error) { externalAccounts := make([]*ExternalLoginUser, 0, 5) - err := db.DefaultContext().Engine().Where("user_id=?", user.ID). + err := db.GetEngine(db.DefaultContext).Where("user_id=?", user.ID). Desc("login_source_id"). Find(&externalAccounts) if err != nil { @@ -59,7 +59,7 @@ func ListAccountLinks(user *User) ([]*ExternalLoginUser, error) { // LinkExternalToUser link the external user to the user func LinkExternalToUser(user *User, externalLoginUser *ExternalLoginUser) error { - has, err := db.DefaultContext().Engine().Where("external_id=? AND login_source_id=?", externalLoginUser.ExternalID, externalLoginUser.LoginSourceID). + has, err := db.GetEngine(db.DefaultContext).Where("external_id=? AND login_source_id=?", externalLoginUser.ExternalID, externalLoginUser.LoginSourceID). NoAutoCondition(). Exist(externalLoginUser) if err != nil { @@ -68,13 +68,13 @@ func LinkExternalToUser(user *User, externalLoginUser *ExternalLoginUser) error return ErrExternalLoginUserAlreadyExist{externalLoginUser.ExternalID, user.ID, externalLoginUser.LoginSourceID} } - _, err = db.DefaultContext().Engine().Insert(externalLoginUser) + _, err = db.GetEngine(db.DefaultContext).Insert(externalLoginUser) return err } // RemoveAccountLink will remove all external login sources for the given user func RemoveAccountLink(user *User, loginSourceID int64) (int64, error) { - deleted, err := db.DefaultContext().Engine().Delete(&ExternalLoginUser{UserID: user.ID, LoginSourceID: loginSourceID}) + deleted, err := db.GetEngine(db.DefaultContext).Delete(&ExternalLoginUser{UserID: user.ID, LoginSourceID: loginSourceID}) if err != nil { return deleted, err } @@ -93,7 +93,7 @@ func removeAllAccountLinks(e db.Engine, user *User) error { // GetUserIDByExternalUserID get user id according to provider and userID func GetUserIDByExternalUserID(provider, userID string) (int64, error) { var id int64 - _, err := db.DefaultContext().Engine().Table("external_login_user"). + _, err := db.GetEngine(db.DefaultContext).Table("external_login_user"). Select("user_id"). Where("provider=?", provider). And("external_id=?", userID). @@ -130,7 +130,7 @@ func UpdateExternalUser(user *User, gothUser goth.User) error { ExpiresAt: gothUser.ExpiresAt, } - has, err := db.DefaultContext().Engine().Where("external_id=? AND login_source_id=?", gothUser.UserID, loginSource.ID). + has, err := db.GetEngine(db.DefaultContext).Where("external_id=? AND login_source_id=?", gothUser.UserID, loginSource.ID). NoAutoCondition(). Exist(externalLoginUser) if err != nil { @@ -139,7 +139,7 @@ func UpdateExternalUser(user *User, gothUser goth.User) error { return ErrExternalLoginUserNotExist{user.ID, loginSource.ID} } - _, err = db.DefaultContext().Engine().Where("external_id=? AND login_source_id=?", gothUser.UserID, loginSource.ID).AllCols().Update(externalLoginUser) + _, err = db.GetEngine(db.DefaultContext).Where("external_id=? AND login_source_id=?", gothUser.UserID, loginSource.ID).AllCols().Update(externalLoginUser) return err } @@ -161,7 +161,7 @@ func (opts FindExternalUserOptions) toConds() builder.Cond { // FindExternalUsersByProvider represents external users via provider func FindExternalUsersByProvider(opts FindExternalUserOptions) ([]ExternalLoginUser, error) { var users []ExternalLoginUser - err := db.DefaultContext().Engine().Where(opts.toConds()). + err := db.GetEngine(db.DefaultContext).Where(opts.toConds()). Limit(opts.Limit, opts.Start). OrderBy("login_source_id ASC, external_id ASC"). Find(&users) diff --git a/models/fixture_generation.go b/models/fixture_generation.go index 49cf5ad6a605..c87909d01ba5 100644 --- a/models/fixture_generation.go +++ b/models/fixture_generation.go @@ -15,7 +15,7 @@ import ( // for the access table, as recalculated using repo.RecalculateAccesses() func GetYamlFixturesAccess() (string, error) { repos := make([]*Repository, 0, 50) - if err := db.DefaultContext().Engine().Find(&repos); err != nil { + if err := db.GetEngine(db.DefaultContext).Find(&repos); err != nil { return "", err } @@ -29,7 +29,7 @@ func GetYamlFixturesAccess() (string, error) { var b strings.Builder accesses := make([]*Access, 0, 200) - if err := db.DefaultContext().Engine().OrderBy("user_id, repo_id").Find(&accesses); err != nil { + if err := db.GetEngine(db.DefaultContext).OrderBy("user_id, repo_id").Find(&accesses); err != nil { return "", err } diff --git a/models/gpg_key.go b/models/gpg_key.go index fabb3d5c913d..d8dd79c28538 100644 --- a/models/gpg_key.go +++ b/models/gpg_key.go @@ -63,7 +63,7 @@ func (key *GPGKey) AfterLoad(session *xorm.Session) { // ListGPGKeys returns a list of public keys belongs to given user. func ListGPGKeys(uid int64, listOptions ListOptions) ([]*GPGKey, error) { - return listGPGKeys(db.DefaultContext().Engine(), uid, listOptions) + return listGPGKeys(db.GetEngine(db.DefaultContext), uid, listOptions) } func listGPGKeys(e db.Engine, uid int64, listOptions ListOptions) ([]*GPGKey, error) { @@ -78,13 +78,13 @@ func listGPGKeys(e db.Engine, uid int64, listOptions ListOptions) ([]*GPGKey, er // CountUserGPGKeys return number of gpg keys a user own func CountUserGPGKeys(userID int64) (int64, error) { - return db.DefaultContext().Engine().Where("owner_id=? AND primary_key_id=''", userID).Count(&GPGKey{}) + return db.GetEngine(db.DefaultContext).Where("owner_id=? AND primary_key_id=''", userID).Count(&GPGKey{}) } // GetGPGKeyByID returns public key by given ID. func GetGPGKeyByID(keyID int64) (*GPGKey, error) { key := new(GPGKey) - has, err := db.DefaultContext().Engine().ID(keyID).Get(key) + has, err := db.GetEngine(db.DefaultContext).ID(keyID).Get(key) if err != nil { return nil, err } else if !has { @@ -96,7 +96,7 @@ func GetGPGKeyByID(keyID int64) (*GPGKey, error) { // GetGPGKeysByKeyID returns public key by given ID. func GetGPGKeysByKeyID(keyID string) ([]*GPGKey, error) { keys := make([]*GPGKey, 0, 1) - return keys, db.DefaultContext().Engine().Where("key_id=?", keyID).Find(&keys) + return keys, db.GetEngine(db.DefaultContext).Where("key_id=?", keyID).Find(&keys) } // GPGKeyToEntity retrieve the imported key and the traducted entity @@ -233,7 +233,7 @@ func DeleteGPGKey(doer *User, id int64) (err error) { } defer committer.Close() - if _, err = deleteGPGKey(ctx.Engine(), key.KeyID); err != nil { + if _, err = deleteGPGKey(db.GetEngine(ctx), key.KeyID); err != nil { return err } diff --git a/models/gpg_key_add.go b/models/gpg_key_add.go index 91a30120a619..635872c920f5 100644 --- a/models/gpg_key_add.go +++ b/models/gpg_key_add.go @@ -103,7 +103,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro for _, ekey := range ekeys { // Key ID cannot be duplicated. - has, err := ctx.Engine().Where("key_id=?", ekey.PrimaryKey.KeyIdString()). + has, err := db.GetEngine(ctx).Where("key_id=?", ekey.PrimaryKey.KeyIdString()). Get(new(GPGKey)) if err != nil { return nil, err @@ -118,7 +118,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro return nil, err } - if err = addGPGKey(ctx.Engine(), key, content); err != nil { + if err = addGPGKey(db.GetEngine(ctx), key, content); err != nil { return nil, err } keys = append(keys, key) diff --git a/models/gpg_key_import.go b/models/gpg_key_import.go index 73d46b4cb423..1eed9296272b 100644 --- a/models/gpg_key_import.go +++ b/models/gpg_key_import.go @@ -34,7 +34,7 @@ func init() { // GetGPGImportByKeyID returns the import public armored key by given KeyID. func GetGPGImportByKeyID(keyID string) (*GPGKeyImport, error) { key := new(GPGKeyImport) - has, err := db.DefaultContext().Engine().ID(keyID).Get(key) + has, err := db.GetEngine(db.DefaultContext).ID(keyID).Get(key) if err != nil { return nil, err } else if !has { diff --git a/models/gpg_key_verify.go b/models/gpg_key_verify.go index 050bee0bca71..1c6b79ec5f6a 100644 --- a/models/gpg_key_verify.go +++ b/models/gpg_key_verify.go @@ -38,7 +38,7 @@ func VerifyGPGKey(ownerID int64, keyID, token, signature string) (string, error) key := new(GPGKey) - has, err := ctx.Engine().Where("owner_id = ? AND key_id = ?", ownerID, keyID).Get(key) + has, err := db.GetEngine(ctx).Where("owner_id = ? AND key_id = ?", ownerID, keyID).Get(key) if err != nil { return "", err } else if !has { @@ -92,7 +92,7 @@ func VerifyGPGKey(ownerID int64, keyID, token, signature string) (string, error) } key.Verified = true - if _, err := ctx.Engine().ID(key.ID).SetExpr("verified", true).Update(new(GPGKey)); err != nil { + if _, err := db.GetEngine(ctx).ID(key.ID).SetExpr("verified", true).Update(new(GPGKey)); err != nil { return "", err } diff --git a/models/issue.go b/models/issue.go index a1f5373583a4..cafd996ac51d 100644 --- a/models/issue.go +++ b/models/issue.go @@ -6,6 +6,7 @@ package models import ( + "context" "fmt" "regexp" "sort" @@ -113,7 +114,7 @@ func (issue *Issue) IsOverdue() bool { // LoadRepo loads issue's repository func (issue *Issue) LoadRepo() error { - return issue.loadRepo(db.DefaultContext().Engine()) + return issue.loadRepo(db.GetEngine(db.DefaultContext)) } func (issue *Issue) loadRepo(e db.Engine) (err error) { @@ -128,7 +129,7 @@ func (issue *Issue) loadRepo(e db.Engine) (err error) { // IsTimetrackerEnabled returns true if the repo enables timetracking func (issue *Issue) IsTimetrackerEnabled() bool { - return issue.isTimetrackerEnabled(db.DefaultContext().Engine()) + return issue.isTimetrackerEnabled(db.GetEngine(db.DefaultContext)) } func (issue *Issue) isTimetrackerEnabled(e db.Engine) bool { @@ -145,7 +146,7 @@ func (issue *Issue) GetPullRequest() (pr *PullRequest, err error) { return nil, fmt.Errorf("Issue is not a pull request") } - pr, err = getPullRequestByIssueID(db.DefaultContext().Engine(), issue.ID) + pr, err = getPullRequestByIssueID(db.GetEngine(db.DefaultContext), issue.ID) if err != nil { return nil, err } @@ -155,7 +156,7 @@ func (issue *Issue) GetPullRequest() (pr *PullRequest, err error) { // LoadLabels loads labels func (issue *Issue) LoadLabels() error { - return issue.loadLabels(db.DefaultContext().Engine()) + return issue.loadLabels(db.GetEngine(db.DefaultContext)) } func (issue *Issue) loadLabels(e db.Engine) (err error) { @@ -170,7 +171,7 @@ func (issue *Issue) loadLabels(e db.Engine) (err error) { // LoadPoster loads poster func (issue *Issue) LoadPoster() error { - return issue.loadPoster(db.DefaultContext().Engine()) + return issue.loadPoster(db.GetEngine(db.DefaultContext)) } func (issue *Issue) loadPoster(e db.Engine) (err error) { @@ -205,7 +206,7 @@ func (issue *Issue) loadPullRequest(e db.Engine) (err error) { // LoadPullRequest loads pull request info func (issue *Issue) LoadPullRequest() error { - return issue.loadPullRequest(db.DefaultContext().Engine()) + return issue.loadPullRequest(db.GetEngine(db.DefaultContext)) } func (issue *Issue) loadComments(e db.Engine) (err error) { @@ -214,7 +215,7 @@ func (issue *Issue) loadComments(e db.Engine) (err error) { // LoadDiscussComments loads discuss comments func (issue *Issue) LoadDiscussComments() error { - return issue.loadCommentsByType(db.DefaultContext().Engine(), CommentTypeComment) + return issue.loadCommentsByType(db.GetEngine(db.DefaultContext), CommentTypeComment) } func (issue *Issue) loadCommentsByType(e db.Engine, tp CommentType) (err error) { @@ -327,18 +328,18 @@ func (issue *Issue) loadAttributes(e db.Engine) (err error) { // LoadAttributes loads the attribute of this issue. func (issue *Issue) LoadAttributes() error { - return issue.loadAttributes(db.DefaultContext().Engine()) + return issue.loadAttributes(db.GetEngine(db.DefaultContext)) } // LoadMilestone load milestone of this issue. func (issue *Issue) LoadMilestone() error { - return issue.loadMilestone(db.DefaultContext().Engine()) + return issue.loadMilestone(db.GetEngine(db.DefaultContext)) } // GetIsRead load the `IsRead` field of the issue func (issue *Issue) GetIsRead(userID int64) error { issueUser := &IssueUser{IssueID: issue.ID, UID: userID} - if has, err := db.DefaultContext().Engine().Get(issueUser); err != nil { + if has, err := db.GetEngine(db.DefaultContext).Get(issueUser); err != nil { return err } else if !has { issue.IsRead = false @@ -411,7 +412,7 @@ func (issue *Issue) hasLabel(e db.Engine, labelID int64) bool { // HasLabel returns true if issue has been labeled by given ID. func (issue *Issue) HasLabel(labelID int64) bool { - return issue.hasLabel(db.DefaultContext().Engine(), labelID) + return issue.hasLabel(db.GetEngine(db.DefaultContext), labelID) } // ReplyReference returns tokenized address to use for email reply headers @@ -473,13 +474,13 @@ func (issue *Issue) ClearLabels(doer *User) (err error) { } defer committer.Close() - if err := issue.loadRepo(ctx.Engine()); err != nil { + if err := issue.loadRepo(db.GetEngine(ctx)); err != nil { return err - } else if err = issue.loadPullRequest(ctx.Engine()); err != nil { + } else if err = issue.loadPullRequest(db.GetEngine(ctx)); err != nil { return err } - perm, err := getUserRepoPermission(ctx.Engine(), issue.Repo, doer) + perm, err := getUserRepoPermission(db.GetEngine(ctx), issue.Repo, doer) if err != nil { return err } @@ -487,7 +488,7 @@ func (issue *Issue) ClearLabels(doer *User) (err error) { return ErrRepoLabelNotExist{} } - if err = issue.clearLabels(ctx.Engine(), doer); err != nil { + if err = issue.clearLabels(db.GetEngine(ctx), doer); err != nil { return err } @@ -521,11 +522,11 @@ func (issue *Issue) ReplaceLabels(labels []*Label, doer *User) (err error) { } defer committer.Close() - if err = issue.loadRepo(ctx.Engine()); err != nil { + if err = issue.loadRepo(db.GetEngine(ctx)); err != nil { return err } - if err = issue.loadLabels(ctx.Engine()); err != nil { + if err = issue.loadLabels(db.GetEngine(ctx)); err != nil { return err } @@ -561,19 +562,19 @@ func (issue *Issue) ReplaceLabels(labels []*Label, doer *User) (err error) { toRemove = append(toRemove, issue.Labels[removeIndex:]...) if len(toAdd) > 0 { - if err = issue.addLabels(ctx.Engine(), toAdd, doer); err != nil { + if err = issue.addLabels(db.GetEngine(ctx), toAdd, doer); err != nil { return fmt.Errorf("addLabels: %v", err) } } for _, l := range toRemove { - if err = issue.removeLabel(ctx.Engine(), doer, l); err != nil { + if err = issue.removeLabel(db.GetEngine(ctx), doer, l); err != nil { return fmt.Errorf("removeLabel: %v", err) } } issue.Labels = nil - if err = issue.loadLabels(ctx.Engine()); err != nil { + if err = issue.loadLabels(db.GetEngine(ctx)); err != nil { return err } @@ -586,7 +587,7 @@ func (issue *Issue) ReadBy(userID int64) error { return err } - return setIssueNotificationStatusReadIfUnread(db.DefaultContext().Engine(), userID, issue.ID) + return setIssueNotificationStatusReadIfUnread(db.GetEngine(db.DefaultContext), userID, issue.ID) } func updateIssueCols(e db.Engine, issue *Issue, cols ...string) error { @@ -688,14 +689,14 @@ func (issue *Issue) ChangeStatus(doer *User, isClosed bool) (*Comment, error) { } defer committer.Close() - if err := issue.loadRepo(ctx.Engine()); err != nil { + if err := issue.loadRepo(db.GetEngine(ctx)); err != nil { return nil, err } - if err := issue.loadPoster(ctx.Engine()); err != nil { + if err := issue.loadPoster(db.GetEngine(ctx)); err != nil { return nil, err } - comment, err := issue.changeStatus(ctx.Engine(), doer, isClosed, false) + comment, err := issue.changeStatus(db.GetEngine(ctx), doer, isClosed, false) if err != nil { return nil, err } @@ -715,11 +716,11 @@ func (issue *Issue) ChangeTitle(doer *User, oldTitle string) (err error) { } defer committer.Close() - if err = updateIssueCols(ctx.Engine(), issue, "name"); err != nil { + if err = updateIssueCols(db.GetEngine(ctx), issue, "name"); err != nil { return fmt.Errorf("updateIssueCols: %v", err) } - if err = issue.loadRepo(ctx.Engine()); err != nil { + if err = issue.loadRepo(db.GetEngine(ctx)); err != nil { return fmt.Errorf("loadRepo: %v", err) } @@ -731,10 +732,10 @@ func (issue *Issue) ChangeTitle(doer *User, oldTitle string) (err error) { OldTitle: oldTitle, NewTitle: issue.Title, } - if _, err = createComment(ctx.Engine(), opts); err != nil { + if _, err = createComment(db.GetEngine(ctx), opts); err != nil { return fmt.Errorf("createComment: %v", err) } - if err = issue.addCrossReferences(ctx.Engine(), doer, true); err != nil { + if err = issue.addCrossReferences(db.GetEngine(ctx), doer, true); err != nil { return err } @@ -749,7 +750,7 @@ func (issue *Issue) ChangeRef(doer *User, oldRef string) (err error) { } defer committer.Close() - if err = updateIssueCols(ctx.Engine(), issue, "ref"); err != nil { + if err = updateIssueCols(db.GetEngine(ctx), issue, "ref"); err != nil { return fmt.Errorf("updateIssueCols: %v", err) } @@ -758,7 +759,7 @@ func (issue *Issue) ChangeRef(doer *User, oldRef string) (err error) { // AddDeletePRBranchComment adds delete branch comment for pull request issue func AddDeletePRBranchComment(doer *User, repo *Repository, issueID int64, branchName string) error { - issue, err := getIssueByID(db.DefaultContext().Engine(), issueID) + issue, err := getIssueByID(db.GetEngine(db.DefaultContext), issueID) if err != nil { return err } @@ -774,7 +775,7 @@ func AddDeletePRBranchComment(doer *User, repo *Repository, issueID int64, branc Issue: issue, OldRef: branchName, } - if _, err = createComment(ctx.Engine(), opts); err != nil { + if _, err = createComment(db.GetEngine(ctx), opts); err != nil { return err } @@ -788,13 +789,13 @@ func (issue *Issue) UpdateAttachments(uuids []string) (err error) { return err } defer committer.Close() - attachments, err := getAttachmentsByUUIDs(ctx.Engine(), uuids) + attachments, err := getAttachmentsByUUIDs(db.GetEngine(ctx), uuids) if err != nil { return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %v", uuids, err) } for i := 0; i < len(attachments); i++ { attachments[i].IssueID = issue.ID - if err := updateAttachment(ctx.Engine(), attachments[i]); err != nil { + if err := updateAttachment(db.GetEngine(ctx), attachments[i]); err != nil { return fmt.Errorf("update attachment [id: %d]: %v", attachments[i].ID, err) } } @@ -811,11 +812,11 @@ func (issue *Issue) ChangeContent(doer *User, content string) (err error) { } defer committer.Close() - if err = updateIssueCols(ctx.Engine(), issue, "content"); err != nil { + if err = updateIssueCols(db.GetEngine(ctx), issue, "content"); err != nil { return fmt.Errorf("UpdateIssueCols: %v", err) } - if err = issue.addCrossReferences(ctx.Engine(), doer, true); err != nil { + if err = issue.addCrossReferences(db.GetEngine(ctx), doer, true); err != nil { return err } @@ -854,7 +855,7 @@ func (issue *Issue) GetLastEventLabel() string { // GetLastComment return last comment for the current issue. func (issue *Issue) GetLastComment() (*Comment, error) { var c Comment - exist, err := db.DefaultContext().Engine().Where("type = ?", CommentTypeComment). + exist, err := db.GetEngine(db.DefaultContext).Where("type = ?", CommentTypeComment). And("issue_id = ?", issue.ID).Desc("id").Get(&c) if err != nil { return nil, err @@ -996,16 +997,16 @@ func RecalculateIssueIndexForRepo(repoID int64) error { } defer committer.Close() - if err := db.UpsertResourceIndex(ctx.Engine(), "issue_index", repoID); err != nil { + if err := db.UpsertResourceIndex(db.GetEngine(ctx), "issue_index", repoID); err != nil { return err } var max int64 - if _, err := ctx.Engine().Select(" MAX(`index`)").Table("issue").Where("repo_id=?", repoID).Get(&max); err != nil { + if _, err := db.GetEngine(ctx).Select(" MAX(`index`)").Table("issue").Where("repo_id=?", repoID).Get(&max); err != nil { return err } - if _, err := ctx.Engine().Exec("UPDATE `issue_index` SET max_index=? WHERE group_id=?", max, repoID); err != nil { + if _, err := db.GetEngine(ctx).Exec("UPDATE `issue_index` SET max_index=? WHERE group_id=?", max, repoID); err != nil { return err } @@ -1027,7 +1028,7 @@ func NewIssue(repo *Repository, issue *Issue, labelIDs []int64, uuids []string) } defer committer.Close() - if err = newIssue(ctx.Engine(), issue.Poster, NewIssueOptions{ + if err = newIssue(db.GetEngine(ctx), issue.Poster, NewIssueOptions{ Repo: repo, Issue: issue, LabelIDs: labelIDs, @@ -1055,7 +1056,7 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) { RepoID: repoID, Index: index, } - has, err := db.DefaultContext().Engine().Get(issue) + has, err := db.GetEngine(db.DefaultContext).Get(issue) if err != nil { return nil, err } else if !has { @@ -1086,16 +1087,16 @@ func getIssueByID(e db.Engine, id int64) (*Issue, error) { // GetIssueWithAttrsByID returns an issue with attributes by given ID. func GetIssueWithAttrsByID(id int64) (*Issue, error) { - issue, err := getIssueByID(db.DefaultContext().Engine(), id) + issue, err := getIssueByID(db.GetEngine(db.DefaultContext), id) if err != nil { return nil, err } - return issue, issue.loadAttributes(db.DefaultContext().Engine()) + return issue, issue.loadAttributes(db.GetEngine(db.DefaultContext)) } // GetIssueByID returns an issue by given ID. func GetIssueByID(id int64) (*Issue, error) { - return getIssueByID(db.DefaultContext().Engine(), id) + return getIssueByID(db.GetEngine(db.DefaultContext), id) } func getIssuesByIDs(e db.Engine, issueIDs []int64) ([]*Issue, error) { @@ -1111,12 +1112,12 @@ func getIssueIDsByRepoID(e db.Engine, repoID int64) ([]int64, error) { // GetIssueIDsByRepoID returns all issue ids by repo id func GetIssueIDsByRepoID(repoID int64) ([]int64, error) { - return getIssueIDsByRepoID(db.DefaultContext().Engine(), repoID) + return getIssueIDsByRepoID(db.GetEngine(db.DefaultContext), repoID) } // GetIssuesByIDs return issues with the given IDs. func GetIssuesByIDs(issueIDs []int64) ([]*Issue, error) { - return getIssuesByIDs(db.DefaultContext().Engine(), issueIDs) + return getIssuesByIDs(db.GetEngine(db.DefaultContext), issueIDs) } // IssuesOptions represents options of an issue. @@ -1316,7 +1317,7 @@ func applyReviewRequestedCondition(sess *xorm.Session, reviewRequestedID int64) // CountIssuesByRepo map from repoID to number of issues matching the options func CountIssuesByRepo(opts *IssuesOptions) (map[int64]int64, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() sess.Join("INNER", "repository", "`issue`.repo_id = `repository`.id") @@ -1344,7 +1345,7 @@ func CountIssuesByRepo(opts *IssuesOptions) (map[int64]int64, error) { // GetRepoIDsForIssuesOptions find all repo ids for the given options func GetRepoIDsForIssuesOptions(opts *IssuesOptions, user *User) ([]int64, error) { repoIDs := make([]int64, 0, 5) - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() sess.Join("INNER", "repository", "`issue`.repo_id = `repository`.id") @@ -1364,7 +1365,7 @@ func GetRepoIDsForIssuesOptions(opts *IssuesOptions, user *User) ([]int64, error // Issues returns a list of issues by given conditions. func Issues(opts *IssuesOptions) ([]*Issue, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() sess.Join("INNER", "repository", "`issue`.repo_id = `repository`.id") @@ -1386,7 +1387,7 @@ func Issues(opts *IssuesOptions) ([]*Issue, error) { // CountIssues number return of issues by given conditions. func CountIssues(opts *IssuesOptions) (int64, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() countsSlice := make([]*struct { @@ -1411,7 +1412,7 @@ func CountIssues(opts *IssuesOptions) (int64, error) { // User permissions must be verified elsewhere if required. func GetParticipantsIDsByIssueID(issueID int64) ([]int64, error) { userIDs := make([]int64, 0, 5) - return userIDs, db.DefaultContext().Engine().Table("comment"). + return userIDs, db.GetEngine(db.DefaultContext).Table("comment"). Cols("poster_id"). Where("issue_id = ?", issueID). And("type in (?,?,?)", CommentTypeComment, CommentTypeCode, CommentTypeReview). @@ -1421,7 +1422,7 @@ func GetParticipantsIDsByIssueID(issueID int64) ([]int64, error) { // IsUserParticipantsOfIssue return true if user is participants of an issue func IsUserParticipantsOfIssue(user *User, issue *Issue) bool { - userIDs, err := issue.getParticipantIDsByIssue(db.DefaultContext().Engine()) + userIDs, err := issue.getParticipantIDsByIssue(db.GetEngine(db.DefaultContext)) if err != nil { log.Error(err.Error()) return false @@ -1430,7 +1431,7 @@ func IsUserParticipantsOfIssue(user *User, issue *Issue) bool { } // UpdateIssueMentions updates issue-user relations for mentioned users. -func UpdateIssueMentions(ctx *db.Context, issueID int64, mentions []*User) error { +func UpdateIssueMentions(ctx context.Context, issueID int64, mentions []*User) error { if len(mentions) == 0 { return nil } @@ -1529,7 +1530,7 @@ func getIssueStatsChunk(opts *IssueStatsOptions, issueIDs []int64) (*IssueStats, stats := &IssueStats{} countSession := func(opts *IssueStatsOptions) *xorm.Session { - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Where("issue.repo_id = ?", opts.RepoID) if len(opts.IssueIDs) > 0 { @@ -1623,7 +1624,7 @@ func GetUserIssueStats(opts UserIssueStatsOptions) (*IssueStats, error) { } sess := func(cond builder.Cond) *xorm.Session { - s := db.DefaultContext().Engine().Where(cond) + s := db.GetEngine(db.DefaultContext).Where(cond) if len(opts.LabelIDs) > 0 { s.Join("INNER", "issue_label", "issue_label.issue_id = issue.id"). In("issue_label.label_id", opts.LabelIDs) @@ -1735,7 +1736,7 @@ func GetUserIssueStats(opts UserIssueStatsOptions) (*IssueStats, error) { // GetRepoIssueStats returns number of open and closed repository issues by given filter mode. func GetRepoIssueStats(repoID, uid int64, filterMode int, isPull bool) (numOpen, numClosed int64) { countSession := func(isClosed, isPull bool, repoID int64) *xorm.Session { - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Where("is_closed = ?", isClosed). And("is_pull = ?", isPull). And("repo_id = ?", repoID) @@ -1787,7 +1788,7 @@ func SearchIssueIDsByKeyword(kw string, repoIDs []int64, limit, start int) (int6 ID int64 UpdatedUnix int64 }, 0, limit) - err := db.DefaultContext().Engine().Distinct("id", "updated_unix").Table("issue").Where(cond). + err := db.GetEngine(db.DefaultContext).Distinct("id", "updated_unix").Table("issue").Where(cond). OrderBy("`updated_unix` DESC").Limit(limit, start). Find(&res) if err != nil { @@ -1797,7 +1798,7 @@ func SearchIssueIDsByKeyword(kw string, repoIDs []int64, limit, start int) (int6 ids = append(ids, r.ID) } - total, err := db.DefaultContext().Engine().Distinct("id").Table("issue").Where(cond).Count() + total, err := db.GetEngine(db.DefaultContext).Distinct("id").Table("issue").Where(cond).Count() if err != nil { return 0, nil, err } @@ -1809,7 +1810,7 @@ func SearchIssueIDsByKeyword(kw string, repoIDs []int64, limit, start int) (int6 // If the issue status is changed a statusChangeComment is returned // similarly if the title is changed the titleChanged bool is set to true func UpdateIssueByAPI(issue *Issue, doer *User) (statusChangeComment *Comment, titleChanged bool, err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, false, err @@ -1868,7 +1869,7 @@ func UpdateIssueDeadline(issue *Issue, deadlineUnix timeutil.TimeStamp, doer *Us return nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -1941,12 +1942,12 @@ func (issue *Issue) getBlockingDependencies(e db.Engine) (issueDeps []*Dependenc // BlockedByDependencies finds all Dependencies an issue is blocked by func (issue *Issue) BlockedByDependencies() ([]*DependencyInfo, error) { - return issue.getBlockedByDependencies(db.DefaultContext().Engine()) + return issue.getBlockedByDependencies(db.GetEngine(db.DefaultContext)) } // BlockingDependencies returns all blocking dependencies, aka all other issues a given issue blocks func (issue *Issue) BlockingDependencies() ([]*DependencyInfo, error) { - return issue.getBlockingDependencies(db.DefaultContext().Engine()) + return issue.getBlockingDependencies(db.GetEngine(db.DefaultContext)) } func (issue *Issue) updateClosedNum(e db.Engine) (err error) { @@ -1969,7 +1970,7 @@ func (issue *Issue) updateClosedNum(e db.Engine) (err error) { } // FindAndUpdateIssueMentions finds users mentioned in the given content string, and saves them in the database. -func (issue *Issue) FindAndUpdateIssueMentions(ctx *db.Context, doer *User, content string) (mentions []*User, err error) { +func (issue *Issue) FindAndUpdateIssueMentions(ctx context.Context, doer *User, content string) (mentions []*User, err error) { rawMentions := references.FindAllMentionsMarkdown(content) mentions, err = issue.ResolveMentionsByVisibility(ctx, doer, rawMentions) if err != nil { @@ -1983,18 +1984,18 @@ func (issue *Issue) FindAndUpdateIssueMentions(ctx *db.Context, doer *User, cont // ResolveMentionsByVisibility returns the users mentioned in an issue, removing those that // don't have access to reading it. Teams are expanded into their users, but organizations are ignored. -func (issue *Issue) ResolveMentionsByVisibility(ctx *db.Context, doer *User, mentions []string) (users []*User, err error) { +func (issue *Issue) ResolveMentionsByVisibility(ctx context.Context, doer *User, mentions []string) (users []*User, err error) { if len(mentions) == 0 { return } - if err = issue.loadRepo(ctx.Engine()); err != nil { + if err = issue.loadRepo(db.GetEngine(ctx)); err != nil { return } resolved := make(map[string]bool, 10) var mentionTeams []string - if err := issue.Repo.getOwner(ctx.Engine()); err != nil { + if err := issue.Repo.getOwner(db.GetEngine(ctx)); err != nil { return nil, err } @@ -2023,7 +2024,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx *db.Context, doer *User, men if issue.Repo.Owner.IsOrganization() && len(mentionTeams) > 0 { teams := make([]*Team, 0, len(mentionTeams)) - if err := ctx.Engine(). + if err := db.GetEngine(ctx). Join("INNER", "team_repo", "team_repo.team_id = team.id"). Where("team_repo.repo_id=?", issue.Repo.ID). In("team.lower_name", mentionTeams). @@ -2042,7 +2043,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx *db.Context, doer *User, men resolved[issue.Repo.Owner.LowerName+"/"+team.LowerName] = true continue } - has, err := ctx.Engine().Get(&TeamUnit{OrgID: issue.Repo.Owner.ID, TeamID: team.ID, Type: unittype}) + has, err := db.GetEngine(ctx).Get(&TeamUnit{OrgID: issue.Repo.Owner.ID, TeamID: team.ID, Type: unittype}) if err != nil { return nil, fmt.Errorf("get team units (%d): %v", team.ID, err) } @@ -2053,7 +2054,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx *db.Context, doer *User, men } if len(checked) != 0 { teamusers := make([]*User, 0, 20) - if err := ctx.Engine(). + if err := db.GetEngine(ctx). Join("INNER", "team_user", "team_user.uid = `user`.id"). In("`team_user`.team_id", checked). And("`user`.is_active = ?", true). @@ -2090,7 +2091,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx *db.Context, doer *User, men } unchecked := make([]*User, 0, len(mentionUsers)) - if err := ctx.Engine(). + if err := db.GetEngine(ctx). Where("`user`.is_active = ?", true). And("`user`.prohibit_login = ?", false). In("`user`.lower_name", mentionUsers). @@ -2102,7 +2103,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx *db.Context, doer *User, men continue } // Normal users must have read access to the referencing issue - perm, err := getUserRepoPermission(ctx.Engine(), issue.Repo, user) + perm, err := getUserRepoPermission(db.GetEngine(ctx), issue.Repo, user) if err != nil { return nil, fmt.Errorf("getUserRepoPermission [%d]: %v", user.ID, err) } @@ -2117,7 +2118,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx *db.Context, doer *User, men // UpdateIssuesMigrationsByType updates all migrated repositories' issues from gitServiceType to replace originalAuthorID to posterID func UpdateIssuesMigrationsByType(gitServiceType structs.GitServiceType, originalAuthorID string, posterID int64) error { - _, err := db.DefaultContext().Engine().Table("issue"). + _, err := db.GetEngine(db.DefaultContext).Table("issue"). Where("repo_id IN (SELECT id FROM repository WHERE original_service_type = ?)", gitServiceType). And("original_author_id = ?", originalAuthorID). Update(map[string]interface{}{ @@ -2130,7 +2131,7 @@ func UpdateIssuesMigrationsByType(gitServiceType structs.GitServiceType, origina // UpdateReactionsMigrationsByType updates all migrated repositories' reactions from gitServiceType to replace originalAuthorID to posterID func UpdateReactionsMigrationsByType(gitServiceType structs.GitServiceType, originalAuthorID string, userID int64) error { - _, err := db.DefaultContext().Engine().Table("reaction"). + _, err := db.GetEngine(db.DefaultContext).Table("reaction"). Where("original_author_id = ?", originalAuthorID). And(migratedIssueCond(gitServiceType)). Update(map[string]interface{}{ diff --git a/models/issue_assignees.go b/models/issue_assignees.go index 6b2d8bb3479f..0f7ba2d7022c 100644 --- a/models/issue_assignees.go +++ b/models/issue_assignees.go @@ -26,7 +26,7 @@ func init() { // LoadAssignees load assignees of this issue. func (issue *Issue) LoadAssignees() error { - return issue.loadAssignees(db.DefaultContext().Engine()) + return issue.loadAssignees(db.GetEngine(db.DefaultContext)) } // This loads all assignees of an issue @@ -56,7 +56,7 @@ func (issue *Issue) loadAssignees(e db.Engine) (err error) { // User permissions must be verified elsewhere if required. func GetAssigneeIDsByIssue(issueID int64) ([]int64, error) { userIDs := make([]int64, 0, 5) - return userIDs, db.DefaultContext().Engine().Table("issue_assignees"). + return userIDs, db.GetEngine(db.DefaultContext).Table("issue_assignees"). Cols("assignee_id"). Where("issue_id = ?", issueID). Distinct("assignee_id"). @@ -65,7 +65,7 @@ func GetAssigneeIDsByIssue(issueID int64) ([]int64, error) { // GetAssigneesByIssue returns everyone assigned to that issue func GetAssigneesByIssue(issue *Issue) (assignees []*User, err error) { - return getAssigneesByIssue(db.DefaultContext().Engine(), issue) + return getAssigneesByIssue(db.GetEngine(db.DefaultContext), issue) } func getAssigneesByIssue(e db.Engine, issue *Issue) (assignees []*User, err error) { @@ -79,7 +79,7 @@ func getAssigneesByIssue(e db.Engine, issue *Issue) (assignees []*User, err erro // IsUserAssignedToIssue returns true when the user is assigned to the issue func IsUserAssignedToIssue(issue *Issue, user *User) (isAssigned bool, err error) { - return isUserAssignedToIssue(db.DefaultContext().Engine(), issue, user) + return isUserAssignedToIssue(db.GetEngine(db.DefaultContext), issue, user) } func isUserAssignedToIssue(e db.Engine, issue *Issue, user *User) (isAssigned bool, err error) { @@ -94,7 +94,7 @@ func clearAssigneeByUserID(sess db.Engine, userID int64) (err error) { // ToggleAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it. func (issue *Issue) ToggleAssignee(doer *User, assigneeID int64) (removed bool, comment *Comment, err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { diff --git a/models/issue_comment.go b/models/issue_comment.go index 059d5b08a76f..d8f8e36df288 100644 --- a/models/issue_comment.go +++ b/models/issue_comment.go @@ -210,7 +210,7 @@ type PushActionContent struct { // LoadIssue loads issue from database func (c *Comment) LoadIssue() (err error) { - return c.loadIssue(db.DefaultContext().Engine()) + return c.loadIssue(db.GetEngine(db.DefaultContext)) } func (c *Comment) loadIssue(e db.Engine) (err error) { @@ -284,7 +284,7 @@ func (c *Comment) HTMLURL() string { log.Error("LoadIssue(%d): %v", c.IssueID, err) return "" } - err = c.Issue.loadRepo(db.DefaultContext().Engine()) + err = c.Issue.loadRepo(db.GetEngine(db.DefaultContext)) if err != nil { // Silently dropping errors :unamused: log.Error("loadRepo(%d): %v", c.Issue.RepoID, err) return "" @@ -313,7 +313,7 @@ func (c *Comment) APIURL() string { log.Error("LoadIssue(%d): %v", c.IssueID, err) return "" } - err = c.Issue.loadRepo(db.DefaultContext().Engine()) + err = c.Issue.loadRepo(db.GetEngine(db.DefaultContext)) if err != nil { // Silently dropping errors :unamused: log.Error("loadRepo(%d): %v", c.Issue.RepoID, err) return "" @@ -334,7 +334,7 @@ func (c *Comment) IssueURL() string { return "" } - err = c.Issue.loadRepo(db.DefaultContext().Engine()) + err = c.Issue.loadRepo(db.GetEngine(db.DefaultContext)) if err != nil { // Silently dropping errors :unamused: log.Error("loadRepo(%d): %v", c.Issue.RepoID, err) return "" @@ -350,7 +350,7 @@ func (c *Comment) PRURL() string { return "" } - err = c.Issue.loadRepo(db.DefaultContext().Engine()) + err = c.Issue.loadRepo(db.GetEngine(db.DefaultContext)) if err != nil { // Silently dropping errors :unamused: log.Error("loadRepo(%d): %v", c.Issue.RepoID, err) return "" @@ -380,7 +380,7 @@ func (c *Comment) EventTag() string { // LoadLabel if comment.Type is CommentTypeLabel, then load Label func (c *Comment) LoadLabel() error { var label Label - has, err := db.DefaultContext().Engine().ID(c.LabelID).Get(&label) + has, err := db.GetEngine(db.DefaultContext).ID(c.LabelID).Get(&label) if err != nil { return err } else if has { @@ -397,7 +397,7 @@ func (c *Comment) LoadLabel() error { func (c *Comment) LoadProject() error { if c.OldProjectID > 0 { var oldProject Project - has, err := db.DefaultContext().Engine().ID(c.OldProjectID).Get(&oldProject) + has, err := db.GetEngine(db.DefaultContext).ID(c.OldProjectID).Get(&oldProject) if err != nil { return err } else if has { @@ -407,7 +407,7 @@ func (c *Comment) LoadProject() error { if c.ProjectID > 0 { var project Project - has, err := db.DefaultContext().Engine().ID(c.ProjectID).Get(&project) + has, err := db.GetEngine(db.DefaultContext).ID(c.ProjectID).Get(&project) if err != nil { return err } else if has { @@ -422,7 +422,7 @@ func (c *Comment) LoadProject() error { func (c *Comment) LoadMilestone() error { if c.OldMilestoneID > 0 { var oldMilestone Milestone - has, err := db.DefaultContext().Engine().ID(c.OldMilestoneID).Get(&oldMilestone) + has, err := db.GetEngine(db.DefaultContext).ID(c.OldMilestoneID).Get(&oldMilestone) if err != nil { return err } else if has { @@ -432,7 +432,7 @@ func (c *Comment) LoadMilestone() error { if c.MilestoneID > 0 { var milestone Milestone - has, err := db.DefaultContext().Engine().ID(c.MilestoneID).Get(&milestone) + has, err := db.GetEngine(db.DefaultContext).ID(c.MilestoneID).Get(&milestone) if err != nil { return err } else if has { @@ -444,7 +444,7 @@ func (c *Comment) LoadMilestone() error { // LoadPoster loads comment poster func (c *Comment) LoadPoster() error { - return c.loadPoster(db.DefaultContext().Engine()) + return c.loadPoster(db.GetEngine(db.DefaultContext)) } // LoadAttachments loads attachments @@ -454,7 +454,7 @@ func (c *Comment) LoadAttachments() error { } var err error - c.Attachments, err = getAttachmentsByCommentID(db.DefaultContext().Engine(), c.ID) + c.Attachments, err = getAttachmentsByCommentID(db.GetEngine(db.DefaultContext), c.ID) if err != nil { log.Error("getAttachmentsByCommentID[%d]: %v", c.ID, err) } @@ -463,7 +463,7 @@ func (c *Comment) LoadAttachments() error { // UpdateAttachments update attachments by UUIDs for the comment func (c *Comment) UpdateAttachments(uuids []string) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -487,7 +487,7 @@ func (c *Comment) LoadAssigneeUserAndTeam() error { var err error if c.AssigneeID > 0 && c.Assignee == nil { - c.Assignee, err = getUserByID(db.DefaultContext().Engine(), c.AssigneeID) + c.Assignee, err = getUserByID(db.GetEngine(db.DefaultContext), c.AssigneeID) if err != nil { if !IsErrUserNotExist(err) { return err @@ -522,7 +522,7 @@ func (c *Comment) LoadResolveDoer() (err error) { if c.ResolveDoerID == 0 || c.Type != CommentTypeCode { return nil } - c.ResolveDoer, err = getUserByID(db.DefaultContext().Engine(), c.ResolveDoerID) + c.ResolveDoer, err = getUserByID(db.GetEngine(db.DefaultContext), c.ResolveDoerID) if err != nil { if IsErrUserNotExist(err) { c.ResolveDoer = NewGhostUser() @@ -542,7 +542,7 @@ func (c *Comment) LoadDepIssueDetails() (err error) { if c.DependentIssueID <= 0 || c.DependentIssue != nil { return nil } - c.DependentIssue, err = getIssueByID(db.DefaultContext().Engine(), c.DependentIssueID) + c.DependentIssue, err = getIssueByID(db.GetEngine(db.DefaultContext), c.DependentIssueID) return err } @@ -576,7 +576,7 @@ func (c *Comment) loadReactions(e db.Engine, repo *Repository) (err error) { // LoadReactions loads comment reactions func (c *Comment) LoadReactions(repo *Repository) error { - return c.loadReactions(db.DefaultContext().Engine(), repo) + return c.loadReactions(db.GetEngine(db.DefaultContext), repo) } func (c *Comment) loadReview(e db.Engine) (err error) { @@ -591,7 +591,7 @@ func (c *Comment) loadReview(e db.Engine) (err error) { // LoadReview loads the associated review func (c *Comment) LoadReview() error { - return c.loadReview(db.DefaultContext().Engine()) + return c.loadReview(db.GetEngine(db.DefaultContext)) } var notEnoughLines = regexp.MustCompile(`fatal: file .* has only \d+ lines?`) @@ -642,7 +642,7 @@ func (c *Comment) CodeCommentURL() string { log.Error("LoadIssue(%d): %v", c.IssueID, err) return "" } - err = c.Issue.loadRepo(db.DefaultContext().Engine()) + err = c.Issue.loadRepo(db.GetEngine(db.DefaultContext)) if err != nil { // Silently dropping errors :unamused: log.Error("loadRepo(%d): %v", c.Issue.RepoID, err) return "" @@ -899,7 +899,7 @@ type CreateCommentOptions struct { // CreateComment creates comment of issue or commit. func CreateComment(opts *CreateCommentOptions) (comment *Comment, err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return nil, err @@ -924,7 +924,7 @@ func CreateRefComment(doer *User, repo *Repository, issue *Issue, content, commi } // Check if same reference from same commit has already existed. - has, err := db.DefaultContext().Engine().Get(&Comment{ + has, err := db.GetEngine(db.DefaultContext).Get(&Comment{ Type: CommentTypeCommitRef, IssueID: issue.ID, CommitSHA: commitSHA, @@ -948,7 +948,7 @@ func CreateRefComment(doer *User, repo *Repository, issue *Issue, content, commi // GetCommentByID returns the comment by given ID. func GetCommentByID(id int64) (*Comment, error) { - return getCommentByID(db.DefaultContext().Engine(), id) + return getCommentByID(db.GetEngine(db.DefaultContext), id) } func getCommentByID(e db.Engine, id int64) (*Comment, error) { @@ -1025,12 +1025,12 @@ func findComments(e db.Engine, opts *FindCommentsOptions) ([]*Comment, error) { // FindComments returns all comments according options func FindComments(opts *FindCommentsOptions) ([]*Comment, error) { - return findComments(db.DefaultContext().Engine(), opts) + return findComments(db.GetEngine(db.DefaultContext), opts) } // CountComments count all comments according options by ignoring pagination func CountComments(opts *FindCommentsOptions) (int64, error) { - sess := db.DefaultContext().Engine().Where(opts.toConds()) + sess := db.GetEngine(db.DefaultContext).Where(opts.toConds()) if opts.RepoID > 0 { sess.Join("INNER", "issue", "issue.id = comment.issue_id") } @@ -1039,7 +1039,7 @@ func CountComments(opts *FindCommentsOptions) (int64, error) { // UpdateComment updates information of comment. func UpdateComment(c *Comment, doer *User) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -1063,7 +1063,7 @@ func UpdateComment(c *Comment, doer *User) error { // DeleteComment deletes the comment func DeleteComment(comment *Comment) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -1207,17 +1207,17 @@ func FetchCodeCommentsByLine(issue *Issue, currentUser *User, treePath string, l TreePath: treePath, Line: line, } - return findCodeComments(db.DefaultContext().Engine(), opts, issue, currentUser, nil) + return findCodeComments(db.GetEngine(db.DefaultContext), opts, issue, currentUser, nil) } // FetchCodeComments will return a 2d-map: ["Path"]["Line"] = Comments at line func FetchCodeComments(issue *Issue, currentUser *User) (CodeComments, error) { - return fetchCodeComments(db.DefaultContext().Engine(), issue, currentUser) + return fetchCodeComments(db.GetEngine(db.DefaultContext), issue, currentUser) } // UpdateCommentsMigrationsByType updates comments' migrations information via given git service type and original id and poster id func UpdateCommentsMigrationsByType(tp structs.GitServiceType, originalAuthorID string, posterID int64) error { - _, err := db.DefaultContext().Engine().Table("comment"). + _, err := db.GetEngine(db.DefaultContext).Table("comment"). Where(builder.In("issue_id", builder.Select("issue.id"). From("issue"). diff --git a/models/issue_comment_list.go b/models/issue_comment_list.go index b80fa129e2b8..bd1b48f8e5bf 100644 --- a/models/issue_comment_list.go +++ b/models/issue_comment_list.go @@ -526,20 +526,20 @@ func (comments CommentList) loadAttributes(e db.Engine) (err error) { // LoadAttributes loads attributes of the comments, except for attachments and // comments func (comments CommentList) LoadAttributes() error { - return comments.loadAttributes(db.DefaultContext().Engine()) + return comments.loadAttributes(db.GetEngine(db.DefaultContext)) } // LoadAttachments loads attachments func (comments CommentList) LoadAttachments() error { - return comments.loadAttachments(db.DefaultContext().Engine()) + return comments.loadAttachments(db.GetEngine(db.DefaultContext)) } // LoadPosters loads posters func (comments CommentList) LoadPosters() error { - return comments.loadPosters(db.DefaultContext().Engine()) + return comments.loadPosters(db.GetEngine(db.DefaultContext)) } // LoadIssues loads issues of comments func (comments CommentList) LoadIssues() error { - return comments.loadIssues(db.DefaultContext().Engine()) + return comments.loadIssues(db.GetEngine(db.DefaultContext)) } diff --git a/models/issue_dependency.go b/models/issue_dependency.go index 42ec5715f8aa..0dfb99ca014d 100644 --- a/models/issue_dependency.go +++ b/models/issue_dependency.go @@ -36,7 +36,7 @@ const ( // CreateIssueDependency creates a new dependency for an issue func CreateIssueDependency(user *User, issue, dep *Issue) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -77,7 +77,7 @@ func CreateIssueDependency(user *User, issue, dep *Issue) error { // RemoveIssueDependency removes a dependency from an issue func RemoveIssueDependency(user *User, issue, dep *Issue, depType DependencyType) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -118,7 +118,7 @@ func issueDepExists(e db.Engine, issueID, depID int64) (bool, error) { // IssueNoDependenciesLeft checks if issue can be closed func IssueNoDependenciesLeft(issue *Issue) (bool, error) { - return issueNoDependenciesLeft(db.DefaultContext().Engine(), issue) + return issueNoDependenciesLeft(db.GetEngine(db.DefaultContext), issue) } func issueNoDependenciesLeft(e db.Engine, issue *Issue) (bool, error) { @@ -135,7 +135,7 @@ func issueNoDependenciesLeft(e db.Engine, issue *Issue) (bool, error) { // IsDependenciesEnabled returns if dependencies are enabled and returns the default setting if not set. func (repo *Repository) IsDependenciesEnabled() bool { - return repo.isDependenciesEnabled(db.DefaultContext().Engine()) + return repo.isDependenciesEnabled(db.GetEngine(db.DefaultContext)) } func (repo *Repository) isDependenciesEnabled(e db.Engine) bool { diff --git a/models/issue_label.go b/models/issue_label.go index 5d7555e2bf24..87d7eb922123 100644 --- a/models/issue_label.go +++ b/models/issue_label.go @@ -6,6 +6,7 @@ package models import ( + "context" "fmt" "html/template" "math" @@ -242,8 +243,8 @@ func initializeLabels(e db.Engine, id int64, labelTemplate string, isOrg bool) e } // InitializeLabels adds a label set to a repository using a template -func InitializeLabels(ctx *db.Context, repoID int64, labelTemplate string, isOrg bool) error { - return initializeLabels(ctx.Engine(), repoID, labelTemplate, isOrg) +func InitializeLabels(ctx context.Context, repoID int64, labelTemplate string, isOrg bool) error { + return initializeLabels(db.GetEngine(ctx), repoID, labelTemplate, isOrg) } func newLabel(e db.Engine, label *Label) error { @@ -256,7 +257,7 @@ func NewLabel(label *Label) error { if !LabelColorPattern.MatchString(label.Color) { return fmt.Errorf("bad color code: %s", label.Color) } - return newLabel(db.DefaultContext().Engine(), label) + return newLabel(db.GetEngine(db.DefaultContext), label) } // NewLabels creates new labels @@ -271,7 +272,7 @@ func NewLabels(labels ...*Label) error { if !LabelColorPattern.MatchString(label.Color) { return fmt.Errorf("bad color code: %s", label.Color) } - if err := newLabel(ctx.Engine(), label); err != nil { + if err := newLabel(db.GetEngine(ctx), label); err != nil { return err } } @@ -283,7 +284,7 @@ func UpdateLabel(l *Label) error { if !LabelColorPattern.MatchString(l.Color) { return fmt.Errorf("bad color code: %s", l.Color) } - return updateLabelCols(db.DefaultContext().Engine(), l, "name", "description", "color") + return updateLabelCols(db.GetEngine(db.DefaultContext), l, "name", "description", "color") } // DeleteLabel delete a label @@ -296,7 +297,7 @@ func DeleteLabel(id, labelID int64) error { return err } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -343,13 +344,13 @@ func getLabelByID(e db.Engine, labelID int64) (*Label, error) { // GetLabelByID returns a label by given ID. func GetLabelByID(id int64) (*Label, error) { - return getLabelByID(db.DefaultContext().Engine(), id) + return getLabelByID(db.GetEngine(db.DefaultContext), id) } // GetLabelsByIDs returns a list of labels by IDs func GetLabelsByIDs(labelIDs []int64) ([]*Label, error) { labels := make([]*Label, 0, len(labelIDs)) - return labels, db.DefaultContext().Engine().Table("label"). + return labels, db.GetEngine(db.DefaultContext).Table("label"). In("id", labelIDs). Asc("name"). Cols("id", "repo_id", "org_id"). @@ -403,7 +404,7 @@ func getLabelInRepoByID(e db.Engine, repoID, labelID int64) (*Label, error) { // GetLabelInRepoByName returns a label by name in given repository. func GetLabelInRepoByName(repoID int64, labelName string) (*Label, error) { - return getLabelInRepoByName(db.DefaultContext().Engine(), repoID, labelName) + return getLabelInRepoByName(db.GetEngine(db.DefaultContext), repoID, labelName) } // GetLabelIDsInRepoByNames returns a list of labelIDs by names in a given @@ -411,7 +412,7 @@ func GetLabelInRepoByName(repoID int64, labelName string) (*Label, error) { // it silently ignores label names that do not belong to the repository. func GetLabelIDsInRepoByNames(repoID int64, labelNames []string) ([]int64, error) { labelIDs := make([]int64, 0, len(labelNames)) - return labelIDs, db.DefaultContext().Engine().Table("label"). + return labelIDs, db.GetEngine(db.DefaultContext).Table("label"). Where("repo_id = ?", repoID). In("name", labelNames). Asc("name"). @@ -432,14 +433,14 @@ func BuildLabelNamesIssueIDsCondition(labelNames []string) *builder.Builder { // GetLabelInRepoByID returns a label by ID in given repository. func GetLabelInRepoByID(repoID, labelID int64) (*Label, error) { - return getLabelInRepoByID(db.DefaultContext().Engine(), repoID, labelID) + return getLabelInRepoByID(db.GetEngine(db.DefaultContext), repoID, labelID) } // GetLabelsInRepoByIDs returns a list of labels by IDs in given repository, // it silently ignores label IDs that do not belong to the repository. func GetLabelsInRepoByIDs(repoID int64, labelIDs []int64) ([]*Label, error) { labels := make([]*Label, 0, len(labelIDs)) - return labels, db.DefaultContext().Engine(). + return labels, db.GetEngine(db.DefaultContext). Where("repo_id = ?", repoID). In("id", labelIDs). Asc("name"). @@ -473,12 +474,12 @@ func getLabelsByRepoID(e db.Engine, repoID int64, sortType string, listOptions L // GetLabelsByRepoID returns all labels that belong to given repository by ID. func GetLabelsByRepoID(repoID int64, sortType string, listOptions ListOptions) ([]*Label, error) { - return getLabelsByRepoID(db.DefaultContext().Engine(), repoID, sortType, listOptions) + return getLabelsByRepoID(db.GetEngine(db.DefaultContext), repoID, sortType, listOptions) } // CountLabelsByRepoID count number of all labels that belong to given repository by ID. func CountLabelsByRepoID(repoID int64) (int64, error) { - return db.DefaultContext().Engine().Where("repo_id = ?", repoID).Count(&Label{}) + return db.GetEngine(db.DefaultContext).Where("repo_id = ?", repoID).Count(&Label{}) } // ________ @@ -528,7 +529,7 @@ func getLabelInOrgByID(e db.Engine, orgID, labelID int64) (*Label, error) { // GetLabelInOrgByName returns a label by name in given organization. func GetLabelInOrgByName(orgID int64, labelName string) (*Label, error) { - return getLabelInOrgByName(db.DefaultContext().Engine(), orgID, labelName) + return getLabelInOrgByName(db.GetEngine(db.DefaultContext), orgID, labelName) } // GetLabelIDsInOrgByNames returns a list of labelIDs by names in a given @@ -539,7 +540,7 @@ func GetLabelIDsInOrgByNames(orgID int64, labelNames []string) ([]int64, error) } labelIDs := make([]int64, 0, len(labelNames)) - return labelIDs, db.DefaultContext().Engine().Table("label"). + return labelIDs, db.GetEngine(db.DefaultContext).Table("label"). Where("org_id = ?", orgID). In("name", labelNames). Asc("name"). @@ -549,14 +550,14 @@ func GetLabelIDsInOrgByNames(orgID int64, labelNames []string) ([]int64, error) // GetLabelInOrgByID returns a label by ID in given organization. func GetLabelInOrgByID(orgID, labelID int64) (*Label, error) { - return getLabelInOrgByID(db.DefaultContext().Engine(), orgID, labelID) + return getLabelInOrgByID(db.GetEngine(db.DefaultContext), orgID, labelID) } // GetLabelsInOrgByIDs returns a list of labels by IDs in given organization, // it silently ignores label IDs that do not belong to the organization. func GetLabelsInOrgByIDs(orgID int64, labelIDs []int64) ([]*Label, error) { labels := make([]*Label, 0, len(labelIDs)) - return labels, db.DefaultContext().Engine(). + return labels, db.GetEngine(db.DefaultContext). Where("org_id = ?", orgID). In("id", labelIDs). Asc("name"). @@ -590,12 +591,12 @@ func getLabelsByOrgID(e db.Engine, orgID int64, sortType string, listOptions Lis // GetLabelsByOrgID returns all labels that belong to given organization by ID. func GetLabelsByOrgID(orgID int64, sortType string, listOptions ListOptions) ([]*Label, error) { - return getLabelsByOrgID(db.DefaultContext().Engine(), orgID, sortType, listOptions) + return getLabelsByOrgID(db.GetEngine(db.DefaultContext), orgID, sortType, listOptions) } // CountLabelsByOrgID count all labels that belong to given organization by ID. func CountLabelsByOrgID(orgID int64) (int64, error) { - return db.DefaultContext().Engine().Where("org_id = ?", orgID).Count(&Label{}) + return db.GetEngine(db.DefaultContext).Where("org_id = ?", orgID).Count(&Label{}) } // .___ @@ -615,7 +616,7 @@ func getLabelsByIssueID(e db.Engine, issueID int64) ([]*Label, error) { // GetLabelsByIssueID returns all labels that belong to given issue by ID. func GetLabelsByIssueID(issueID int64) ([]*Label, error) { - return getLabelsByIssueID(db.DefaultContext().Engine(), issueID) + return getLabelsByIssueID(db.GetEngine(db.DefaultContext), issueID) } func updateLabelCols(e db.Engine, l *Label, cols ...string) error { @@ -657,7 +658,7 @@ func hasIssueLabel(e db.Engine, issueID, labelID int64) bool { // HasIssueLabel returns true if issue has been labeled. func HasIssueLabel(issueID, labelID int64) bool { - return hasIssueLabel(db.DefaultContext().Engine(), issueID, labelID) + return hasIssueLabel(db.GetEngine(db.DefaultContext), issueID, labelID) } // newIssueLabel this function creates a new label it does not check if the label is valid for the issue @@ -695,7 +696,7 @@ func NewIssueLabel(issue *Issue, label *Label, doer *User) (err error) { return nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -750,12 +751,12 @@ func NewIssueLabels(issue *Issue, labels []*Label, doer *User) (err error) { } defer committer.Close() - if err = newIssueLabels(ctx.Engine(), issue, labels, doer); err != nil { + if err = newIssueLabels(db.GetEngine(ctx), issue, labels, doer); err != nil { return err } issue.Labels = nil - if err = issue.loadLabels(ctx.Engine()); err != nil { + if err = issue.loadLabels(db.GetEngine(ctx)); err != nil { return err } @@ -792,7 +793,7 @@ func deleteIssueLabel(e db.Engine, issue *Issue, label *Label, doer *User) (err // DeleteIssueLabel deletes issue-label relation. func DeleteIssueLabel(issue *Issue, label *Label, doer *User) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err diff --git a/models/issue_list.go b/models/issue_list.go index bac0803d8a73..ac7ec7ccbf24 100644 --- a/models/issue_list.go +++ b/models/issue_list.go @@ -63,7 +63,7 @@ func (issues IssueList) loadRepositories(e db.Engine) ([]*Repository, error) { // LoadRepositories loads issues' all repositories func (issues IssueList) LoadRepositories() ([]*Repository, error) { - return issues.loadRepositories(db.DefaultContext().Engine()) + return issues.loadRepositories(db.GetEngine(db.DefaultContext)) } func (issues IssueList) getPosterIDs() []int64 { @@ -502,33 +502,33 @@ func (issues IssueList) loadAttributes(e db.Engine) error { // LoadAttributes loads attributes of the issues, except for attachments and // comments func (issues IssueList) LoadAttributes() error { - return issues.loadAttributes(db.DefaultContext().Engine()) + return issues.loadAttributes(db.GetEngine(db.DefaultContext)) } // LoadAttachments loads attachments func (issues IssueList) LoadAttachments() error { - return issues.loadAttachments(db.DefaultContext().Engine()) + return issues.loadAttachments(db.GetEngine(db.DefaultContext)) } // LoadComments loads comments func (issues IssueList) LoadComments() error { - return issues.loadComments(db.DefaultContext().Engine(), builder.NewCond()) + return issues.loadComments(db.GetEngine(db.DefaultContext), builder.NewCond()) } // LoadDiscussComments loads discuss comments func (issues IssueList) LoadDiscussComments() error { - return issues.loadComments(db.DefaultContext().Engine(), builder.Eq{"comment.type": CommentTypeComment}) + return issues.loadComments(db.GetEngine(db.DefaultContext), builder.Eq{"comment.type": CommentTypeComment}) } // LoadPullRequests loads pull requests func (issues IssueList) LoadPullRequests() error { - return issues.loadPullRequests(db.DefaultContext().Engine()) + return issues.loadPullRequests(db.GetEngine(db.DefaultContext)) } // GetApprovalCounts returns a map of issue ID to slice of approval counts // FIXME: only returns official counts due to double counting of non-official approvals func (issues IssueList) GetApprovalCounts() (map[int64][]*ReviewCount, error) { - return issues.getApprovalCounts(db.DefaultContext().Engine()) + return issues.getApprovalCounts(db.GetEngine(db.DefaultContext)) } func (issues IssueList) getApprovalCounts(e db.Engine) (map[int64][]*ReviewCount, error) { diff --git a/models/issue_lock.go b/models/issue_lock.go index f7474bfd1d7c..d8e3b4c0aba8 100644 --- a/models/issue_lock.go +++ b/models/issue_lock.go @@ -37,7 +37,7 @@ func updateIssueLock(opts *IssueLockOptions, lock bool) error { commentType = CommentTypeUnlock } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err diff --git a/models/issue_milestone.go b/models/issue_milestone.go index da321911f66d..fb6ced5b41a3 100644 --- a/models/issue_milestone.go +++ b/models/issue_milestone.go @@ -85,7 +85,7 @@ func (m *Milestone) State() api.StateType { // NewMilestone creates new milestone of repository. func NewMilestone(m *Milestone) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -116,13 +116,13 @@ func getMilestoneByRepoID(e db.Engine, repoID, id int64) (*Milestone, error) { // GetMilestoneByRepoID returns the milestone in a repository. func GetMilestoneByRepoID(repoID, id int64) (*Milestone, error) { - return getMilestoneByRepoID(db.DefaultContext().Engine(), repoID, id) + return getMilestoneByRepoID(db.GetEngine(db.DefaultContext), repoID, id) } // GetMilestoneByRepoIDANDName return a milestone if one exist by name and repo func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error) { var mile Milestone - has, err := db.DefaultContext().Engine().Where("repo_id=? AND name=?", repoID, name).Get(&mile) + has, err := db.GetEngine(db.DefaultContext).Where("repo_id=? AND name=?", repoID, name).Get(&mile) if err != nil { return nil, err } @@ -134,7 +134,7 @@ func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error) // GetMilestoneByID returns the milestone via id . func GetMilestoneByID(id int64) (*Milestone, error) { - return getMilestoneByID(db.DefaultContext().Engine(), id) + return getMilestoneByID(db.GetEngine(db.DefaultContext), id) } func getMilestoneByID(e db.Engine, id int64) (*Milestone, error) { @@ -150,7 +150,7 @@ func getMilestoneByID(e db.Engine, id int64) (*Milestone, error) { // UpdateMilestone updates information of given milestone. func UpdateMilestone(m *Milestone, oldIsClosed bool) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -207,7 +207,7 @@ func updateMilestoneCounters(e db.Engine, id int64) error { // ChangeMilestoneStatusByRepoIDAndID changes a milestone open/closed status if the milestone ID is in the repo. func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -234,7 +234,7 @@ func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool // ChangeMilestoneStatus changes the milestone open/closed status. func ChangeMilestoneStatus(m *Milestone, isClosed bool) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -303,7 +303,7 @@ func changeMilestoneAssign(e *xorm.Session, doer *User, issue *Issue, oldMilesto // ChangeMilestoneAssign changes assignment of milestone for issue. func ChangeMilestoneAssign(issue *Issue, doer *User, oldMilestoneID int64) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -334,7 +334,7 @@ func DeleteMilestoneByRepoID(repoID, id int64) error { return err } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -410,7 +410,7 @@ func (opts GetMilestonesOption) toCond() builder.Cond { // GetMilestones returns milestones filtered by GetMilestonesOption's func GetMilestones(opts GetMilestonesOption) (MilestoneList, int64, error) { - sess := db.DefaultContext().Engine().Where(opts.toCond()) + sess := db.GetEngine(db.DefaultContext).Where(opts.toCond()) if opts.Page != 0 { sess = setSessionPagination(sess, &opts) @@ -441,7 +441,7 @@ func GetMilestones(opts GetMilestonesOption) (MilestoneList, int64, error) { // SearchMilestones search milestones func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType string, keyword string) (MilestoneList, error) { miles := make([]*Milestone, 0, setting.UI.IssuePagingNum) - sess := db.DefaultContext().Engine().Where("is_closed = ?", isClosed) + sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } @@ -502,7 +502,7 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro var err error stats := &MilestonesStats{} - sess := db.DefaultContext().Engine().Where("is_closed = ?", false) + sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", false) if repoCond.IsValid() { sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond))) } @@ -511,7 +511,7 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro return nil, err } - sess = db.DefaultContext().Engine().Where("is_closed = ?", true) + sess = db.GetEngine(db.DefaultContext).Where("is_closed = ?", true) if repoCond.IsValid() { sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond))) } @@ -528,7 +528,7 @@ func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (* var err error stats := &MilestonesStats{} - sess := db.DefaultContext().Engine().Where("is_closed = ?", false) + sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", false) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } @@ -540,7 +540,7 @@ func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (* return nil, err } - sess = db.DefaultContext().Engine().Where("is_closed = ?", true) + sess = db.GetEngine(db.DefaultContext).Where("is_closed = ?", true) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } @@ -569,12 +569,12 @@ func countRepoClosedMilestones(e db.Engine, repoID int64) (int64, error) { // CountRepoClosedMilestones returns number of closed milestones in given repository. func CountRepoClosedMilestones(repoID int64) (int64, error) { - return countRepoClosedMilestones(db.DefaultContext().Engine(), repoID) + return countRepoClosedMilestones(db.GetEngine(db.DefaultContext), repoID) } // CountMilestonesByRepoCond map from repo conditions to number of milestones matching the options` func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64]int64, error) { - sess := db.DefaultContext().Engine().Where("is_closed = ?", isClosed) + sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) if repoCond.IsValid() { sess.In("repo_id", builder.Select("id").From("repository").Where(repoCond)) } @@ -599,7 +599,7 @@ func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64] // CountMilestonesByRepoCondAndKw map from repo conditions and the keyword of milestones' name to number of milestones matching the options` func CountMilestonesByRepoCondAndKw(repoCond builder.Cond, keyword string, isClosed bool) (map[int64]int64, error) { - sess := db.DefaultContext().Engine().Where("is_closed = ?", isClosed) + sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } @@ -707,10 +707,10 @@ func (m *Milestone) loadTotalTrackedTime(e db.Engine) error { // LoadTotalTrackedTimes loads for every milestone in the list the TotalTrackedTime by a batch request func (milestones MilestoneList) LoadTotalTrackedTimes() error { - return milestones.loadTotalTrackedTimes(db.DefaultContext().Engine()) + return milestones.loadTotalTrackedTimes(db.GetEngine(db.DefaultContext)) } // LoadTotalTrackedTime loads the tracked time for the milestone func (m *Milestone) LoadTotalTrackedTime() error { - return m.loadTotalTrackedTime(db.DefaultContext().Engine()) + return m.loadTotalTrackedTime(db.GetEngine(db.DefaultContext)) } diff --git a/models/issue_milestone_test.go b/models/issue_milestone_test.go index fd12efb1467a..519b65715d15 100644 --- a/models/issue_milestone_test.go +++ b/models/issue_milestone_test.go @@ -173,7 +173,7 @@ func TestCountRepoMilestones(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) test := func(repoID int64) { repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository) - count, err := countRepoMilestones(db.DefaultContext().Engine(), repoID) + count, err := countRepoMilestones(db.GetEngine(db.DefaultContext), repoID) assert.NoError(t, err) assert.EqualValues(t, repo.NumMilestones, count) } @@ -181,7 +181,7 @@ func TestCountRepoMilestones(t *testing.T) { test(2) test(3) - count, err := countRepoMilestones(db.DefaultContext().Engine(), db.NonexistentID) + count, err := countRepoMilestones(db.GetEngine(db.DefaultContext), db.NonexistentID) assert.NoError(t, err) assert.EqualValues(t, 0, count) } @@ -223,16 +223,16 @@ func TestUpdateMilestoneCounters(t *testing.T) { issue.IsClosed = true issue.ClosedUnix = timeutil.TimeStampNow() - _, err := db.DefaultContext().Engine().ID(issue.ID).Cols("is_closed", "closed_unix").Update(issue) + _, err := db.GetEngine(db.DefaultContext).ID(issue.ID).Cols("is_closed", "closed_unix").Update(issue) assert.NoError(t, err) - assert.NoError(t, updateMilestoneCounters(db.DefaultContext().Engine(), issue.MilestoneID)) + assert.NoError(t, updateMilestoneCounters(db.GetEngine(db.DefaultContext), issue.MilestoneID)) CheckConsistencyFor(t, &Milestone{}) issue.IsClosed = false issue.ClosedUnix = 0 - _, err = db.DefaultContext().Engine().ID(issue.ID).Cols("is_closed", "closed_unix").Update(issue) + _, err = db.GetEngine(db.DefaultContext).ID(issue.ID).Cols("is_closed", "closed_unix").Update(issue) assert.NoError(t, err) - assert.NoError(t, updateMilestoneCounters(db.DefaultContext().Engine(), issue.MilestoneID)) + assert.NoError(t, updateMilestoneCounters(db.GetEngine(db.DefaultContext), issue.MilestoneID)) CheckConsistencyFor(t, &Milestone{}) } diff --git a/models/issue_reaction.go b/models/issue_reaction.go index 8fd22f6ca8c6..4e49add5c2be 100644 --- a/models/issue_reaction.go +++ b/models/issue_reaction.go @@ -71,7 +71,7 @@ func (opts *FindReactionsOptions) toConds() builder.Cond { // FindCommentReactions returns a ReactionList of all reactions from an comment func FindCommentReactions(comment *Comment) (ReactionList, error) { - return findReactions(db.DefaultContext().Engine(), FindReactionsOptions{ + return findReactions(db.GetEngine(db.DefaultContext), FindReactionsOptions{ IssueID: comment.IssueID, CommentID: comment.ID, }) @@ -79,7 +79,7 @@ func FindCommentReactions(comment *Comment) (ReactionList, error) { // FindIssueReactions returns a ReactionList of all reactions from an issue func FindIssueReactions(issue *Issue, listOptions ListOptions) (ReactionList, error) { - return findReactions(db.DefaultContext().Engine(), FindReactionsOptions{ + return findReactions(db.GetEngine(db.DefaultContext), FindReactionsOptions{ ListOptions: listOptions, IssueID: issue.ID, CommentID: -1, @@ -148,7 +148,7 @@ func CreateReaction(opts *ReactionOptions) (*Reaction, error) { return nil, ErrForbiddenIssueReaction{opts.Type} } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, err @@ -203,7 +203,7 @@ func deleteReaction(e db.Engine, opts *ReactionOptions) error { // DeleteReaction deletes reaction for issue or comment. func DeleteReaction(opts *ReactionOptions) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -240,7 +240,7 @@ func (r *Reaction) LoadUser() (*User, error) { if r.User != nil { return r.User, nil } - user, err := getUserByID(db.DefaultContext().Engine(), r.UserID) + user, err := getUserByID(db.GetEngine(db.DefaultContext), r.UserID) if err != nil { return nil, err } @@ -314,7 +314,7 @@ func (list ReactionList) loadUsers(e db.Engine, repo *Repository) ([]*User, erro // LoadUsers loads reactions' all users func (list ReactionList) LoadUsers(repo *Repository) ([]*User, error) { - return list.loadUsers(db.DefaultContext().Engine(), repo) + return list.loadUsers(db.GetEngine(db.DefaultContext), repo) } // GetFirstUsers returns first reacted user display names separated by comma diff --git a/models/issue_reaction_test.go b/models/issue_reaction_test.go index b74d0ce9b880..dd15b816c73f 100644 --- a/models/issue_reaction_test.go +++ b/models/issue_reaction_test.go @@ -93,7 +93,7 @@ func TestIssueReactionCount(t *testing.T) { addReaction(t, user4, issue, nil, "heart") addReaction(t, ghost, issue, nil, "-1") - err := issue.loadReactions(db.DefaultContext().Engine()) + err := issue.loadReactions(db.GetEngine(db.DefaultContext)) assert.NoError(t, err) assert.Len(t, issue.Reactions, 7) diff --git a/models/issue_stopwatch.go b/models/issue_stopwatch.go index 1e8cf4c6a65d..157658e182df 100644 --- a/models/issue_stopwatch.go +++ b/models/issue_stopwatch.go @@ -48,7 +48,7 @@ func getStopwatch(e db.Engine, userID, issueID int64) (sw *Stopwatch, exists boo // GetUserStopwatches return list of all stopwatches of a user func GetUserStopwatches(userID int64, listOptions ListOptions) ([]*Stopwatch, error) { sws := make([]*Stopwatch, 0, 8) - sess := db.DefaultContext().Engine().Where("stopwatch.user_id = ?", userID) + sess := db.GetEngine(db.DefaultContext).Where("stopwatch.user_id = ?", userID) if listOptions.Page != 0 { sess = setSessionPagination(sess, &listOptions) } @@ -62,18 +62,18 @@ func GetUserStopwatches(userID int64, listOptions ListOptions) ([]*Stopwatch, er // CountUserStopwatches return count of all stopwatches of a user func CountUserStopwatches(userID int64) (int64, error) { - return db.DefaultContext().Engine().Where("user_id = ?", userID).Count(&Stopwatch{}) + return db.GetEngine(db.DefaultContext).Where("user_id = ?", userID).Count(&Stopwatch{}) } // StopwatchExists returns true if the stopwatch exists func StopwatchExists(userID, issueID int64) bool { - _, exists, _ := getStopwatch(db.DefaultContext().Engine(), userID, issueID) + _, exists, _ := getStopwatch(db.GetEngine(db.DefaultContext), userID, issueID) return exists } // HasUserStopwatch returns true if the user has a stopwatch func HasUserStopwatch(userID int64) (exists bool, sw *Stopwatch, err error) { - return hasUserStopwatch(db.DefaultContext().Engine(), userID) + return hasUserStopwatch(db.GetEngine(db.DefaultContext), userID) } func hasUserStopwatch(e db.Engine, userID int64) (exists bool, sw *Stopwatch, err error) { @@ -86,7 +86,7 @@ func hasUserStopwatch(e db.Engine, userID int64) (exists bool, sw *Stopwatch, er // CreateOrStopIssueStopwatch will create or remove a stopwatch and will log it into issue's timeline. func CreateOrStopIssueStopwatch(user *User, issue *Issue) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -175,7 +175,7 @@ func createOrStopIssueStopwatch(e *xorm.Session, user *User, issue *Issue) error // CancelStopwatch removes the given stopwatch and logs it into issue's timeline. func CancelStopwatch(user *User, issue *Issue) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err diff --git a/models/issue_test.go b/models/issue_test.go index a0ced7755846..d5f6f36e9c6e 100644 --- a/models/issue_test.go +++ b/models/issue_test.go @@ -77,7 +77,7 @@ func TestGetParticipantIDsByIssue(t *testing.T) { checkParticipants := func(issueID int64, userIDs []int) { issue, err := GetIssueByID(issueID) assert.NoError(t, err) - participants, err := issue.getParticipantIDsByIssue(db.DefaultContext().Engine()) + participants, err := issue.getParticipantIDsByIssue(db.GetEngine(db.DefaultContext)) if assert.NoError(t, err) { participantsIDs := make([]int, len(participants)) for i, uid := range participants { @@ -125,7 +125,7 @@ func TestUpdateIssueCols(t *testing.T) { issue.Content = "This should have no effect" now := time.Now().Unix() - assert.NoError(t, updateIssueCols(db.DefaultContext().Engine(), issue, "name")) + assert.NoError(t, updateIssueCols(db.GetEngine(db.DefaultContext), issue, "name")) then := time.Now().Unix() updatedIssue := db.AssertExistsAndLoadBean(t, &Issue{ID: issue.ID}).(*Issue) @@ -290,7 +290,7 @@ func TestIssue_loadTotalTimes(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) ms, err := GetIssueByID(2) assert.NoError(t, err) - assert.NoError(t, ms.loadTotalTimes(db.DefaultContext().Engine())) + assert.NoError(t, ms.loadTotalTimes(db.GetEngine(db.DefaultContext))) assert.Equal(t, int64(3682), ms.TotalTrackedTime) } @@ -363,7 +363,7 @@ func testInsertIssue(t *testing.T, title, content string, expectIndex int64) *Is err := NewIssue(repo, &issue, nil, nil) assert.NoError(t, err) - has, err := db.DefaultContext().Engine().ID(issue.ID).Get(&newIssue) + has, err := db.GetEngine(db.DefaultContext).ID(issue.ID).Get(&newIssue) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, issue.Title, newIssue.Title) @@ -380,11 +380,11 @@ func TestIssue_InsertIssue(t *testing.T) { // there are 5 issues and max index is 5 on repository 1, so this one should 6 issue := testInsertIssue(t, "my issue1", "special issue's comments?", 6) - _, err := db.DefaultContext().Engine().ID(issue.ID).Delete(new(Issue)) + _, err := db.GetEngine(db.DefaultContext).ID(issue.ID).Delete(new(Issue)) assert.NoError(t, err) issue = testInsertIssue(t, `my issue2, this is my son's love \n \r \ `, "special issue's '' comments?", 7) - _, err = db.DefaultContext().Engine().ID(issue.ID).Delete(new(Issue)) + _, err = db.GetEngine(db.DefaultContext).ID(issue.ID).Delete(new(Issue)) assert.NoError(t, err) } @@ -397,7 +397,7 @@ func TestIssue_ResolveMentions(t *testing.T) { r := db.AssertExistsAndLoadBean(t, &Repository{OwnerID: o.ID, LowerName: repo}).(*Repository) issue := &Issue{RepoID: r.ID} d := db.AssertExistsAndLoadBean(t, &User{LowerName: doer}).(*User) - resolved, err := issue.ResolveMentionsByVisibility(db.DefaultContext(), d, mentions) + resolved, err := issue.ResolveMentionsByVisibility(db.DefaultContext, d, mentions) assert.NoError(t, err) ids := make([]int64, len(resolved)) for i, user := range resolved { diff --git a/models/issue_tracked_time.go b/models/issue_tracked_time.go index 77b44df1ee39..d024c6896f62 100644 --- a/models/issue_tracked_time.go +++ b/models/issue_tracked_time.go @@ -40,7 +40,7 @@ func (t *TrackedTime) AfterLoad() { // LoadAttributes load Issue, User func (t *TrackedTime) LoadAttributes() (err error) { - return t.loadAttributes(db.DefaultContext().Engine()) + return t.loadAttributes(db.GetEngine(db.DefaultContext)) } func (t *TrackedTime) loadAttributes(e db.Engine) (err error) { @@ -131,12 +131,12 @@ func getTrackedTimes(e db.Engine, options *FindTrackedTimesOptions) (trackedTime // GetTrackedTimes returns all tracked times that fit to the given options. func GetTrackedTimes(opts *FindTrackedTimesOptions) (TrackedTimeList, error) { - return getTrackedTimes(db.DefaultContext().Engine(), opts) + return getTrackedTimes(db.GetEngine(db.DefaultContext), opts) } // CountTrackedTimes returns count of tracked times that fit to the given options. func CountTrackedTimes(opts *FindTrackedTimesOptions) (int64, error) { - sess := db.DefaultContext().Engine().Where(opts.toCond()) + sess := db.GetEngine(db.DefaultContext).Where(opts.toCond()) if opts.RepositoryID > 0 || opts.MilestoneID > 0 { sess = sess.Join("INNER", "issue", "issue.id = tracked_time.issue_id") } @@ -149,12 +149,12 @@ func getTrackedSeconds(e db.Engine, opts FindTrackedTimesOptions) (trackedSecond // GetTrackedSeconds return sum of seconds func GetTrackedSeconds(opts FindTrackedTimesOptions) (int64, error) { - return getTrackedSeconds(db.DefaultContext().Engine(), opts) + return getTrackedSeconds(db.GetEngine(db.DefaultContext), opts) } // AddTime will add the given time (in seconds) to the issue func AddTime(user *User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { @@ -230,7 +230,7 @@ func TotalTimes(options *FindTrackedTimesOptions) (map[*User]string, error) { // DeleteIssueUserTimes deletes times for issue func DeleteIssueUserTimes(issue *Issue, user *User) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { @@ -268,7 +268,7 @@ func DeleteIssueUserTimes(issue *Issue, user *User) error { // DeleteTime delete a specific Time func DeleteTime(t *TrackedTime) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { @@ -318,7 +318,7 @@ func deleteTime(e db.Engine, t *TrackedTime) error { // GetTrackedTimeByID returns raw TrackedTime without loading attributes by id func GetTrackedTimeByID(id int64) (*TrackedTime, error) { time := new(TrackedTime) - has, err := db.DefaultContext().Engine().ID(id).Get(time) + has, err := db.GetEngine(db.DefaultContext).ID(id).Get(time) if err != nil { return nil, err } else if !has { diff --git a/models/issue_user.go b/models/issue_user.go index 6f9b7591cd26..b112441e5b16 100644 --- a/models/issue_user.go +++ b/models/issue_user.go @@ -5,6 +5,7 @@ package models import ( + "context" "fmt" "code.gitea.io/gitea/models/db" @@ -57,27 +58,27 @@ func newIssueUsers(e db.Engine, repo *Repository, issue *Issue) error { // UpdateIssueUserByRead updates issue-user relation for reading. func UpdateIssueUserByRead(uid, issueID int64) error { - _, err := db.DefaultContext().Engine().Exec("UPDATE `issue_user` SET is_read=? WHERE uid=? AND issue_id=?", true, uid, issueID) + _, err := db.GetEngine(db.DefaultContext).Exec("UPDATE `issue_user` SET is_read=? WHERE uid=? AND issue_id=?", true, uid, issueID) return err } // UpdateIssueUsersByMentions updates issue-user pairs by mentioning. -func UpdateIssueUsersByMentions(ctx *db.Context, issueID int64, uids []int64) error { +func UpdateIssueUsersByMentions(ctx context.Context, issueID int64, uids []int64) error { for _, uid := range uids { iu := &IssueUser{ UID: uid, IssueID: issueID, } - has, err := ctx.Engine().Get(iu) + has, err := db.GetEngine(ctx).Get(iu) if err != nil { return err } iu.IsMentioned = true if has { - _, err = ctx.Engine().ID(iu.ID).Cols("is_mentioned").Update(iu) + _, err = db.GetEngine(ctx).ID(iu.ID).Cols("is_mentioned").Update(iu) } else { - _, err = ctx.Engine().Insert(iu) + _, err = db.GetEngine(ctx).Insert(iu) } if err != nil { return err diff --git a/models/issue_user_test.go b/models/issue_user_test.go index ca45d065635a..d4e504719fc6 100644 --- a/models/issue_user_test.go +++ b/models/issue_user_test.go @@ -26,7 +26,7 @@ func Test_newIssueUsers(t *testing.T) { // artificially insert new issue db.AssertSuccessfulInsert(t, newIssue) - assert.NoError(t, newIssueUsers(db.DefaultContext().Engine(), repo, newIssue)) + assert.NoError(t, newIssueUsers(db.GetEngine(db.DefaultContext), repo, newIssue)) // issue_user table should now have entries for new issue db.AssertExistsAndLoadBean(t, &IssueUser{IssueID: newIssue.ID, UID: newIssue.PosterID}) @@ -51,7 +51,7 @@ func TestUpdateIssueUsersByMentions(t *testing.T) { issue := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue) uids := []int64{2, 5} - assert.NoError(t, UpdateIssueUsersByMentions(db.DefaultContext(), issue.ID, uids)) + assert.NoError(t, UpdateIssueUsersByMentions(db.DefaultContext, issue.ID, uids)) for _, uid := range uids { db.AssertExistsAndLoadBean(t, &IssueUser{IssueID: issue.ID, UID: uid}, "is_mentioned=1") } diff --git a/models/issue_watch.go b/models/issue_watch.go index dd693f82f58c..cc1edcba1b66 100644 --- a/models/issue_watch.go +++ b/models/issue_watch.go @@ -28,7 +28,7 @@ type IssueWatchList []*IssueWatch // CreateOrUpdateIssueWatch set watching for a user and issue func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error { - iw, exists, err := getIssueWatch(db.DefaultContext().Engine(), userID, issueID) + iw, exists, err := getIssueWatch(db.GetEngine(db.DefaultContext), userID, issueID) if err != nil { return err } @@ -40,13 +40,13 @@ func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error { IsWatching: isWatching, } - if _, err := db.DefaultContext().Engine().Insert(iw); err != nil { + if _, err := db.GetEngine(db.DefaultContext).Insert(iw); err != nil { return err } } else { iw.IsWatching = isWatching - if _, err := db.DefaultContext().Engine().ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil { + if _, err := db.GetEngine(db.DefaultContext).ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil { return err } } @@ -56,7 +56,7 @@ func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error { // GetIssueWatch returns all IssueWatch objects from db by user and issue // the current Web-UI need iw object for watchers AND explicit non-watchers func GetIssueWatch(userID, issueID int64) (iw *IssueWatch, exists bool, err error) { - return getIssueWatch(db.DefaultContext().Engine(), userID, issueID) + return getIssueWatch(db.GetEngine(db.DefaultContext), userID, issueID) } // Return watcher AND explicit non-watcher if entry in db exist @@ -72,14 +72,14 @@ func getIssueWatch(e db.Engine, userID, issueID int64) (iw *IssueWatch, exists b // CheckIssueWatch check if an user is watching an issue // it takes participants and repo watch into account func CheckIssueWatch(user *User, issue *Issue) (bool, error) { - iw, exist, err := getIssueWatch(db.DefaultContext().Engine(), user.ID, issue.ID) + iw, exist, err := getIssueWatch(db.GetEngine(db.DefaultContext), user.ID, issue.ID) if err != nil { return false, err } if exist { return iw.IsWatching, nil } - w, err := getWatch(db.DefaultContext().Engine(), user.ID, issue.RepoID) + w, err := getWatch(db.GetEngine(db.DefaultContext), user.ID, issue.RepoID) if err != nil { return false, err } @@ -90,7 +90,7 @@ func CheckIssueWatch(user *User, issue *Issue) (bool, error) { // but avoids joining with `user` for performance reasons // User permissions must be verified elsewhere if required func GetIssueWatchersIDs(issueID int64, watching bool) ([]int64, error) { - return getIssueWatchersIDs(db.DefaultContext().Engine(), issueID, watching) + return getIssueWatchersIDs(db.GetEngine(db.DefaultContext), issueID, watching) } func getIssueWatchersIDs(e db.Engine, issueID int64, watching bool) ([]int64, error) { @@ -104,7 +104,7 @@ func getIssueWatchersIDs(e db.Engine, issueID int64, watching bool) ([]int64, er // GetIssueWatchers returns watchers/unwatchers of a given issue func GetIssueWatchers(issueID int64, listOptions ListOptions) (IssueWatchList, error) { - return getIssueWatchers(db.DefaultContext().Engine(), issueID, listOptions) + return getIssueWatchers(db.GetEngine(db.DefaultContext), issueID, listOptions) } func getIssueWatchers(e db.Engine, issueID int64, listOptions ListOptions) (IssueWatchList, error) { diff --git a/models/issue_xref.go b/models/issue_xref.go index 3f9c1e7f5ca2..4630f4d3a483 100644 --- a/models/issue_xref.go +++ b/models/issue_xref.go @@ -277,7 +277,7 @@ func (comment *Comment) LoadRefIssue() (err error) { } comment.RefIssue, err = GetIssueByID(comment.RefIssueID) if err == nil { - err = comment.RefIssue.loadRepo(db.DefaultContext().Engine()) + err = comment.RefIssue.loadRepo(db.GetEngine(db.DefaultContext)) } return } @@ -337,7 +337,7 @@ func (comment *Comment) RefIssueIdent() string { // ResolveCrossReferences will return the list of references to close/reopen by this PR func (pr *PullRequest) ResolveCrossReferences() ([]*Comment, error) { unfiltered := make([]*Comment, 0, 5) - if err := db.DefaultContext().Engine(). + if err := db.GetEngine(db.DefaultContext). Where("ref_repo_id = ? AND ref_issue_id = ?", pr.Issue.RepoID, pr.Issue.ID). In("ref_action", []references.XRefAction{references.XRefActionCloses, references.XRefActionReopens}). OrderBy("id"). diff --git a/models/issue_xref_test.go b/models/issue_xref_test.go index 1ef9e347fa36..bf498e471073 100644 --- a/models/issue_xref_test.go +++ b/models/issue_xref_test.go @@ -139,7 +139,7 @@ func testCreateIssue(t *testing.T, repo, doer int64, title, content string, ispu Index: idx, } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() assert.NoError(t, sess.Begin()) @@ -170,7 +170,7 @@ func testCreateComment(t *testing.T, repo, doer, issue int64, content string) *C i := db.AssertExistsAndLoadBean(t, &Issue{ID: issue}).(*Issue) c := &Comment{Type: CommentTypeComment, PosterID: doer, Poster: d, IssueID: issue, Issue: i, Content: content} - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() assert.NoError(t, sess.Begin()) _, err := sess.Insert(c) diff --git a/models/lfs.go b/models/lfs.go index a856873e7548..87f7a2871f8a 100644 --- a/models/lfs.go +++ b/models/lfs.go @@ -44,7 +44,7 @@ var ErrLFSObjectNotExist = errors.New("LFS Meta object does not exist") func NewLFSMetaObject(m *LFSMetaObject) (*LFSMetaObject, error) { var err error - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return nil, err @@ -76,7 +76,7 @@ func (repo *Repository) GetLFSMetaObjectByOid(oid string) (*LFSMetaObject, error } m := &LFSMetaObject{Pointer: lfs.Pointer{Oid: oid}, RepositoryID: repo.ID} - has, err := db.DefaultContext().Engine().Get(m) + has, err := db.GetEngine(db.DefaultContext).Get(m) if err != nil { return nil, err } else if !has { @@ -92,7 +92,7 @@ func (repo *Repository) RemoveLFSMetaObjectByOid(oid string) (int64, error) { return 0, ErrLFSObjectNotExist } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return -1, err @@ -113,7 +113,7 @@ func (repo *Repository) RemoveLFSMetaObjectByOid(oid string) (int64, error) { // GetLFSMetaObjects returns all LFSMetaObjects associated with a repository func (repo *Repository) GetLFSMetaObjects(page, pageSize int) ([]*LFSMetaObject, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if page >= 0 && pageSize > 0 { @@ -129,23 +129,23 @@ func (repo *Repository) GetLFSMetaObjects(page, pageSize int) ([]*LFSMetaObject, // CountLFSMetaObjects returns a count of all LFSMetaObjects associated with a repository func (repo *Repository) CountLFSMetaObjects() (int64, error) { - return db.DefaultContext().Engine().Count(&LFSMetaObject{RepositoryID: repo.ID}) + return db.GetEngine(db.DefaultContext).Count(&LFSMetaObject{RepositoryID: repo.ID}) } // LFSObjectAccessible checks if a provided Oid is accessible to the user func LFSObjectAccessible(user *User, oid string) (bool, error) { if user.IsAdmin { - count, err := db.DefaultContext().Engine().Count(&LFSMetaObject{Pointer: lfs.Pointer{Oid: oid}}) + count, err := db.GetEngine(db.DefaultContext).Count(&LFSMetaObject{Pointer: lfs.Pointer{Oid: oid}}) return count > 0, err } cond := accessibleRepositoryCondition(user) - count, err := db.DefaultContext().Engine().Where(cond).Join("INNER", "repository", "`lfs_meta_object`.repository_id = `repository`.id").Count(&LFSMetaObject{Pointer: lfs.Pointer{Oid: oid}}) + count, err := db.GetEngine(db.DefaultContext).Where(cond).Join("INNER", "repository", "`lfs_meta_object`.repository_id = `repository`.id").Count(&LFSMetaObject{Pointer: lfs.Pointer{Oid: oid}}) return count > 0, err } // LFSAutoAssociate auto associates accessible LFSMetaObjects func LFSAutoAssociate(metas []*LFSMetaObject, user *User, repoID int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -184,7 +184,7 @@ func IterateLFS(f func(mo *LFSMetaObject) error) error { const batchSize = 100 for { mos := make([]*LFSMetaObject, 0, batchSize) - if err := db.DefaultContext().Engine().Limit(batchSize, start).Find(&mos); err != nil { + if err := db.GetEngine(db.DefaultContext).Limit(batchSize, start).Find(&mos); err != nil { return err } if len(mos) == 0 { diff --git a/models/lfs_lock.go b/models/lfs_lock.go index d7efdf4440a9..ca49ab8a6a04 100644 --- a/models/lfs_lock.go +++ b/models/lfs_lock.go @@ -72,7 +72,7 @@ func CreateLFSLock(lock *LFSLock) (*LFSLock, error) { return nil, err } - _, err = db.DefaultContext().Engine().InsertOne(lock) + _, err = db.GetEngine(db.DefaultContext).InsertOne(lock) return lock, err } @@ -80,7 +80,7 @@ func CreateLFSLock(lock *LFSLock) (*LFSLock, error) { func GetLFSLock(repo *Repository, path string) (*LFSLock, error) { path = cleanPath(path) rel := &LFSLock{RepoID: repo.ID} - has, err := db.DefaultContext().Engine().Where("lower(path) = ?", strings.ToLower(path)).Get(rel) + has, err := db.GetEngine(db.DefaultContext).Where("lower(path) = ?", strings.ToLower(path)).Get(rel) if err != nil { return nil, err } @@ -93,7 +93,7 @@ func GetLFSLock(repo *Repository, path string) (*LFSLock, error) { // GetLFSLockByID returns release by given id. func GetLFSLockByID(id int64) (*LFSLock, error) { lock := new(LFSLock) - has, err := db.DefaultContext().Engine().ID(id).Get(lock) + has, err := db.GetEngine(db.DefaultContext).ID(id).Get(lock) if err != nil { return nil, err } else if !has { @@ -104,7 +104,7 @@ func GetLFSLockByID(id int64) (*LFSLock, error) { // GetLFSLockByRepoID returns a list of locks of repository. func GetLFSLockByRepoID(repoID int64, page, pageSize int) ([]*LFSLock, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if page >= 0 && pageSize > 0 { @@ -120,7 +120,7 @@ func GetLFSLockByRepoID(repoID int64, page, pageSize int) ([]*LFSLock, error) { // CountLFSLockByRepoID returns a count of all LFSLocks associated with a repository. func CountLFSLockByRepoID(repoID int64) (int64, error) { - return db.DefaultContext().Engine().Count(&LFSLock{RepoID: repoID}) + return db.GetEngine(db.DefaultContext).Count(&LFSLock{RepoID: repoID}) } // DeleteLFSLockByID deletes a lock by given ID. @@ -139,7 +139,7 @@ func DeleteLFSLockByID(id int64, u *User, force bool) (*LFSLock, error) { return nil, fmt.Errorf("user doesn't own lock and force flag is not set") } - _, err = db.DefaultContext().Engine().ID(id).Delete(new(LFSLock)) + _, err = db.GetEngine(db.DefaultContext).ID(id).Delete(new(LFSLock)) return lock, err } diff --git a/models/list_options.go b/models/list_options.go index 59bfa91678bd..25b9a05f16e8 100644 --- a/models/list_options.go +++ b/models/list_options.go @@ -21,7 +21,7 @@ type Paginator interface { func getPaginatedSession(p Paginator) *xorm.Session { skip, take := p.GetSkipTake() - return db.DefaultContext().Engine().Limit(take, skip) + return db.GetEngine(db.DefaultContext).Limit(take, skip) } // setSessionPagination sets pagination for a database session diff --git a/models/login_source.go b/models/login_source.go index 77a43fe4aec1..e1f7a7e08e51 100644 --- a/models/login_source.go +++ b/models/login_source.go @@ -208,7 +208,7 @@ func (source *LoginSource) SkipVerify() bool { // CreateLoginSource inserts a LoginSource in the DB if not already // existing with the given name. func CreateLoginSource(source *LoginSource) error { - has, err := db.DefaultContext().Engine().Where("name=?", source.Name).Exist(new(LoginSource)) + has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(LoginSource)) if err != nil { return err } else if has { @@ -219,7 +219,7 @@ func CreateLoginSource(source *LoginSource) error { source.IsSyncEnabled = false } - _, err = db.DefaultContext().Engine().Insert(source) + _, err = db.GetEngine(db.DefaultContext).Insert(source) if err != nil { return err } @@ -240,7 +240,7 @@ func CreateLoginSource(source *LoginSource) error { err = registerableSource.RegisterSource() if err != nil { // remove the LoginSource in case of errors while registering configuration - if _, err := db.DefaultContext().Engine().Delete(source); err != nil { + if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil { log.Error("CreateLoginSource: Error while wrapOpenIDConnectInitializeError: %v", err) } } @@ -250,13 +250,13 @@ func CreateLoginSource(source *LoginSource) error { // LoginSources returns a slice of all login sources found in DB. func LoginSources() ([]*LoginSource, error) { auths := make([]*LoginSource, 0, 6) - return auths, db.DefaultContext().Engine().Find(&auths) + return auths, db.GetEngine(db.DefaultContext).Find(&auths) } // LoginSourcesByType returns all sources of the specified type func LoginSourcesByType(loginType LoginType) ([]*LoginSource, error) { sources := make([]*LoginSource, 0, 1) - if err := db.DefaultContext().Engine().Where("type = ?", loginType).Find(&sources); err != nil { + if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil { return nil, err } return sources, nil @@ -265,7 +265,7 @@ func LoginSourcesByType(loginType LoginType) ([]*LoginSource, error) { // AllActiveLoginSources returns all active sources func AllActiveLoginSources() ([]*LoginSource, error) { sources := make([]*LoginSource, 0, 5) - if err := db.DefaultContext().Engine().Where("is_active = ?", true).Find(&sources); err != nil { + if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil { return nil, err } return sources, nil @@ -274,7 +274,7 @@ func AllActiveLoginSources() ([]*LoginSource, error) { // ActiveLoginSources returns all active sources of the specified type func ActiveLoginSources(loginType LoginType) ([]*LoginSource, error) { sources := make([]*LoginSource, 0, 1) - if err := db.DefaultContext().Engine().Where("is_active = ? and type = ?", true, loginType).Find(&sources); err != nil { + if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, loginType).Find(&sources); err != nil { return nil, err } return sources, nil @@ -305,7 +305,7 @@ func GetLoginSourceByID(id int64) (*LoginSource, error) { return source, nil } - has, err := db.DefaultContext().Engine().ID(id).Get(source) + has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source) if err != nil { return nil, err } else if !has { @@ -325,7 +325,7 @@ func UpdateSource(source *LoginSource) error { } } - _, err := db.DefaultContext().Engine().ID(source.ID).AllCols().Update(source) + _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source) if err != nil { return err } @@ -346,7 +346,7 @@ func UpdateSource(source *LoginSource) error { err = registerableSource.RegisterSource() if err != nil { // restore original values since we cannot update the provider it self - if _, err := db.DefaultContext().Engine().ID(source.ID).AllCols().Update(originalLoginSource); err != nil { + if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalLoginSource); err != nil { log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err) } } @@ -355,14 +355,14 @@ func UpdateSource(source *LoginSource) error { // DeleteSource deletes a LoginSource record in DB. func DeleteSource(source *LoginSource) error { - count, err := db.DefaultContext().Engine().Count(&User{LoginSource: source.ID}) + count, err := db.GetEngine(db.DefaultContext).Count(&User{LoginSource: source.ID}) if err != nil { return err } else if count > 0 { return ErrLoginSourceInUse{source.ID} } - count, err = db.DefaultContext().Engine().Count(&ExternalLoginUser{LoginSourceID: source.ID}) + count, err = db.GetEngine(db.DefaultContext).Count(&ExternalLoginUser{LoginSourceID: source.ID}) if err != nil { return err } else if count > 0 { @@ -375,12 +375,12 @@ func DeleteSource(source *LoginSource) error { } } - _, err = db.DefaultContext().Engine().ID(source.ID).Delete(new(LoginSource)) + _, err = db.GetEngine(db.DefaultContext).ID(source.ID).Delete(new(LoginSource)) return err } // CountLoginSources returns number of login sources. func CountLoginSources() int64 { - count, _ := db.DefaultContext().Engine().Count(new(LoginSource)) + count, _ := db.GetEngine(db.DefaultContext).Count(new(LoginSource)) return count } diff --git a/models/migrate.go b/models/migrate.go index 163573a16619..18b1b11e468a 100644 --- a/models/migrate.go +++ b/models/migrate.go @@ -18,7 +18,7 @@ func InsertMilestones(ms ...*Milestone) (err error) { return nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -39,7 +39,7 @@ func InsertMilestones(ms ...*Milestone) (err error) { // InsertIssues insert issues to database func InsertIssues(issues ...*Issue) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -144,7 +144,7 @@ func InsertIssueComments(comments []*Comment) error { issueIDs[comment.IssueID] = true } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -175,7 +175,7 @@ func InsertIssueComments(comments []*Comment) error { // InsertPullRequests inserted pull requests func InsertPullRequests(prs ...*PullRequest) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -195,7 +195,7 @@ func InsertPullRequests(prs ...*PullRequest) error { // InsertReleases migrates release func InsertReleases(rels ...*Release) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -233,7 +233,7 @@ func migratedIssueCond(tp structs.GitServiceType) builder.Cond { // UpdateReviewsMigrationsByType updates reviews' migrations information via given git service type and original id and poster id func UpdateReviewsMigrationsByType(tp structs.GitServiceType, originalAuthorID string, posterID int64) error { - _, err := db.DefaultContext().Engine().Table("review"). + _, err := db.GetEngine(db.DefaultContext).Table("review"). Where("original_author_id = ?", originalAuthorID). And(migratedIssueCond(tp)). Update(map[string]interface{}{ diff --git a/models/notification.go b/models/notification.go index 2f34def07227..af24a6cf5a7a 100644 --- a/models/notification.go +++ b/models/notification.go @@ -127,17 +127,17 @@ func getNotifications(e db.Engine, options *FindNotificationOptions) (nl Notific // GetNotifications returns all notifications that fit to the given options. func GetNotifications(opts *FindNotificationOptions) (NotificationList, error) { - return getNotifications(db.DefaultContext().Engine(), opts) + return getNotifications(db.GetEngine(db.DefaultContext), opts) } // CountNotifications count all notifications that fit to the given options and ignore pagination. func CountNotifications(opts *FindNotificationOptions) (int64, error) { - return db.DefaultContext().Engine().Where(opts.ToCond()).Count(&Notification{}) + return db.GetEngine(db.DefaultContext).Where(opts.ToCond()).Count(&Notification{}) } // CreateRepoTransferNotification creates notification for the user a repository was transferred to func CreateRepoTransferNotification(doer, newOwner *User, repo *Repository) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -179,7 +179,7 @@ func CreateRepoTransferNotification(doer, newOwner *User, repo *Repository) erro // for each watcher, or updates it if already exists // receiverID > 0 just send to reciver, else send to all watcher func CreateOrUpdateIssueNotifications(issueID, commentID, notificationAuthorID, receiverID int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -352,7 +352,7 @@ func getIssueNotification(e db.Engine, userID, issueID int64) (*Notification, er // NotificationsForUser returns notifications for a given user and status func NotificationsForUser(user *User, statuses []NotificationStatus, page, perPage int) (NotificationList, error) { - return notificationsForUser(db.DefaultContext().Engine(), user, statuses, page, perPage) + return notificationsForUser(db.GetEngine(db.DefaultContext), user, statuses, page, perPage) } func notificationsForUser(e db.Engine, user *User, statuses []NotificationStatus, page, perPage int) (notifications []*Notification, err error) { @@ -375,7 +375,7 @@ func notificationsForUser(e db.Engine, user *User, statuses []NotificationStatus // CountUnread count unread notifications for a user func CountUnread(user *User) int64 { - return countUnread(db.DefaultContext().Engine(), user.ID) + return countUnread(db.GetEngine(db.DefaultContext), user.ID) } func countUnread(e db.Engine, userID int64) int64 { @@ -389,7 +389,7 @@ func countUnread(e db.Engine, userID int64) int64 { // LoadAttributes load Repo Issue User and Comment if not loaded func (n *Notification) LoadAttributes() (err error) { - return n.loadAttributes(db.DefaultContext().Engine()) + return n.loadAttributes(db.GetEngine(db.DefaultContext)) } func (n *Notification) loadAttributes(e db.Engine) (err error) { @@ -451,12 +451,12 @@ func (n *Notification) loadUser(e db.Engine) (err error) { // GetRepo returns the repo of the notification func (n *Notification) GetRepo() (*Repository, error) { - return n.Repository, n.loadRepo(db.DefaultContext().Engine()) + return n.Repository, n.loadRepo(db.GetEngine(db.DefaultContext)) } // GetIssue returns the issue of the notification func (n *Notification) GetIssue() (*Issue, error) { - return n.Issue, n.loadIssue(db.DefaultContext().Engine()) + return n.Issue, n.loadIssue(db.GetEngine(db.DefaultContext)) } // HTMLURL formats a URL-string to the notification @@ -521,7 +521,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, []int, error) { if left < limit { limit = left } - rows, err := db.DefaultContext().Engine(). + rows, err := db.GetEngine(db.DefaultContext). In("id", repoIDs[:limit]). Rows(new(Repository)) if err != nil { @@ -597,7 +597,7 @@ func (nl NotificationList) LoadIssues() ([]int, error) { if left < limit { limit = left } - rows, err := db.DefaultContext().Engine(). + rows, err := db.GetEngine(db.DefaultContext). In("id", issueIDs[:limit]). Rows(new(Issue)) if err != nil { @@ -683,7 +683,7 @@ func (nl NotificationList) LoadComments() ([]int, error) { if left < limit { limit = left } - rows, err := db.DefaultContext().Engine(). + rows, err := db.GetEngine(db.DefaultContext). In("id", commentIDs[:limit]). Rows(new(Comment)) if err != nil { @@ -723,7 +723,7 @@ func (nl NotificationList) LoadComments() ([]int, error) { // GetNotificationCount returns the notification count for user func GetNotificationCount(user *User, status NotificationStatus) (int64, error) { - return getNotificationCount(db.DefaultContext().Engine(), user, status) + return getNotificationCount(db.GetEngine(db.DefaultContext), user, status) } func getNotificationCount(e db.Engine, user *User, status NotificationStatus) (count int64, err error) { @@ -746,7 +746,7 @@ func GetUIDsAndNotificationCounts(since, until timeutil.TimeStamp) ([]UserIDCoun `WHERE user_id IN (SELECT user_id FROM notification WHERE updated_unix >= ? AND ` + `updated_unix < ?) AND status = ? GROUP BY user_id` var res []UserIDCount - return res, db.DefaultContext().Engine().SQL(sql, since, until, NotificationStatusUnread).Find(&res) + return res, db.GetEngine(db.DefaultContext).SQL(sql, since, until, NotificationStatusUnread).Find(&res) } func setIssueNotificationStatusReadIfUnread(e db.Engine, userID, issueID int64) error { @@ -778,7 +778,7 @@ func setRepoNotificationStatusReadIfUnread(e db.Engine, userID, repoID int64) er // SetNotificationStatus change the notification status func SetNotificationStatus(notificationID int64, user *User, status NotificationStatus) (*Notification, error) { - notification, err := getNotificationByID(db.DefaultContext().Engine(), notificationID) + notification, err := getNotificationByID(db.GetEngine(db.DefaultContext), notificationID) if err != nil { return notification, err } @@ -789,13 +789,13 @@ func SetNotificationStatus(notificationID int64, user *User, status Notification notification.Status = status - _, err = db.DefaultContext().Engine().ID(notificationID).Update(notification) + _, err = db.GetEngine(db.DefaultContext).ID(notificationID).Update(notification) return notification, err } // GetNotificationByID return notification by ID func GetNotificationByID(notificationID int64) (*Notification, error) { - return getNotificationByID(db.DefaultContext().Engine(), notificationID) + return getNotificationByID(db.GetEngine(db.DefaultContext), notificationID) } func getNotificationByID(e db.Engine, notificationID int64) (*Notification, error) { @@ -817,7 +817,7 @@ func getNotificationByID(e db.Engine, notificationID int64) (*Notification, erro // UpdateNotificationStatuses updates the statuses of all of a user's notifications that are of the currentStatus type to the desiredStatus func UpdateNotificationStatuses(user *User, currentStatus, desiredStatus NotificationStatus) error { n := &Notification{Status: desiredStatus, UpdatedBy: user.ID} - _, err := db.DefaultContext().Engine(). + _, err := db.GetEngine(db.DefaultContext). Where("user_id = ? AND status = ?", user.ID, currentStatus). Cols("status", "updated_by", "updated_unix"). Update(n) diff --git a/models/oauth2.go b/models/oauth2.go index ac14813ddd13..7fdd5309fb92 100644 --- a/models/oauth2.go +++ b/models/oauth2.go @@ -9,7 +9,7 @@ import "code.gitea.io/gitea/models/db" // GetActiveOAuth2ProviderLoginSources returns all actived LoginOAuth2 sources func GetActiveOAuth2ProviderLoginSources() ([]*LoginSource, error) { sources := make([]*LoginSource, 0, 1) - if err := db.DefaultContext().Engine().Where("is_active = ? and type = ?", true, LoginOAuth2).Find(&sources); err != nil { + if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, LoginOAuth2).Find(&sources); err != nil { return nil, err } return sources, nil @@ -18,7 +18,7 @@ func GetActiveOAuth2ProviderLoginSources() ([]*LoginSource, error) { // GetActiveOAuth2LoginSourceByName returns a OAuth2 LoginSource based on the given name func GetActiveOAuth2LoginSourceByName(name string) (*LoginSource, error) { loginSource := new(LoginSource) - has, err := db.DefaultContext().Engine().Where("name = ? and type = ? and is_active = ?", name, LoginOAuth2, true).Get(loginSource) + has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, LoginOAuth2, true).Get(loginSource) if !has || err != nil { return nil, err } diff --git a/models/oauth2_application.go b/models/oauth2_application.go index 850b09234c0f..0fd2e38472e0 100644 --- a/models/oauth2_application.go +++ b/models/oauth2_application.go @@ -81,7 +81,7 @@ func (app *OAuth2Application) GenerateClientSecret() (string, error) { return "", err } app.ClientSecret = string(hashedSecret) - if _, err := db.DefaultContext().Engine().ID(app.ID).Cols("client_secret").Update(app); err != nil { + if _, err := db.GetEngine(db.DefaultContext).ID(app.ID).Cols("client_secret").Update(app); err != nil { return "", err } return clientSecret, nil @@ -94,7 +94,7 @@ func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool { // GetGrantByUserID returns a OAuth2Grant by its user and application ID func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) { - return app.getGrantByUserID(db.DefaultContext().Engine(), userID) + return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID) } func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) { @@ -109,7 +109,7 @@ func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant // CreateGrant generates a grant for an user func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) { - return app.createGrant(db.DefaultContext().Engine(), userID, scope) + return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope) } func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) { @@ -127,7 +127,7 @@ func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope strin // GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found. func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) { - return getOAuth2ApplicationByClientID(db.DefaultContext().Engine(), clientID) + return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID) } func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) { @@ -141,7 +141,7 @@ func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Ap // GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found. func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) { - return getOAuth2ApplicationByID(db.DefaultContext().Engine(), id) + return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id) } func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) { @@ -158,7 +158,7 @@ func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, er // GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) { - return getOAuth2ApplicationsByUserID(db.DefaultContext().Engine(), userID) + return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID) } func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) { @@ -176,7 +176,7 @@ type CreateOAuth2ApplicationOptions struct { // CreateOAuth2Application inserts a new oauth2 application func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) { - return createOAuth2Application(db.DefaultContext().Engine(), opts) + return createOAuth2Application(db.GetEngine(db.DefaultContext), opts) } func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) { @@ -203,7 +203,7 @@ type UpdateOAuth2ApplicationOptions struct { // UpdateOAuth2Application updates an oauth2 application func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) if err := sess.Begin(); err != nil { return nil, err } @@ -264,7 +264,7 @@ func deleteOAuth2Application(sess *xorm.Session, id, userid int64) error { // DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app. func DeleteOAuth2Application(id, userid int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -277,7 +277,7 @@ func DeleteOAuth2Application(id, userid int64) error { // ListOAuth2Applications returns a list of oauth2 applications belongs to given user. func ListOAuth2Applications(uid int64, listOptions ListOptions) ([]*OAuth2Application, int64, error) { - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Where("uid=?", uid). Desc("id") @@ -329,7 +329,7 @@ func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect // Invalidate deletes the auth code from the database to invalidate this code func (code *OAuth2AuthorizationCode) Invalidate() error { - return code.invalidate(db.DefaultContext().Engine()) + return code.invalidate(db.GetEngine(db.DefaultContext)) } func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error { @@ -361,7 +361,7 @@ func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool // GetOAuth2AuthorizationByCode returns an authorization by its code func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) { - return getOAuth2AuthorizationByCode(db.DefaultContext().Engine(), code) + return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code) } func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) { @@ -402,7 +402,7 @@ func (grant *OAuth2Grant) TableName() string { // GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) { - return grant.generateNewAuthorizationCode(db.DefaultContext().Engine(), redirectURI, codeChallenge, codeChallengeMethod) + return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod) } func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) { @@ -426,7 +426,7 @@ func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, // IncreaseCounter increases the counter and updates the grant func (grant *OAuth2Grant) IncreaseCounter() error { - return grant.increaseCount(db.DefaultContext().Engine()) + return grant.increaseCount(db.GetEngine(db.DefaultContext)) } func (grant *OAuth2Grant) increaseCount(e db.Engine) error { @@ -454,7 +454,7 @@ func (grant *OAuth2Grant) ScopeContains(scope string) bool { // SetNonce updates the current nonce value of a grant func (grant *OAuth2Grant) SetNonce(nonce string) error { - return grant.setNonce(db.DefaultContext().Engine(), nonce) + return grant.setNonce(db.GetEngine(db.DefaultContext), nonce) } func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error { @@ -468,7 +468,7 @@ func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error { // GetOAuth2GrantByID returns the grant with the given ID func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) { - return getOAuth2GrantByID(db.DefaultContext().Engine(), id) + return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id) } func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) { @@ -483,7 +483,7 @@ func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) { // GetOAuth2GrantsByUserID lists all grants of a certain user func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) { - return getOAuth2GrantsByUserID(db.DefaultContext().Engine(), uid) + return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid) } func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) { @@ -515,7 +515,7 @@ func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) { // RevokeOAuth2Grant deletes the grant with grantID and userID func RevokeOAuth2Grant(grantID, userID int64) error { - return revokeOAuth2Grant(db.DefaultContext().Engine(), grantID, userID) + return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID) } func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error { diff --git a/models/org.go b/models/org.go index 3b1d9f50ece2..bc6c47fd456c 100644 --- a/models/org.go +++ b/models/org.go @@ -41,7 +41,7 @@ func (org *User) getTeam(e db.Engine, name string) (*Team, error) { // GetTeam returns named team of organization. func (org *User) GetTeam(name string) (*Team, error) { - return org.getTeam(db.DefaultContext().Engine(), name) + return org.getTeam(db.GetEngine(db.DefaultContext), name) } func (org *User) getOwnerTeam(e db.Engine) (*Team, error) { @@ -50,7 +50,7 @@ func (org *User) getOwnerTeam(e db.Engine) (*Team, error) { // GetOwnerTeam returns owner team of organization. func (org *User) GetOwnerTeam() (*Team, error) { - return org.getOwnerTeam(db.DefaultContext().Engine()) + return org.getOwnerTeam(db.GetEngine(db.DefaultContext)) } func (org *User) loadTeams(e db.Engine) error { @@ -65,7 +65,7 @@ func (org *User) loadTeams(e db.Engine) error { // LoadTeams load teams if not loaded. func (org *User) LoadTeams() error { - return org.loadTeams(db.DefaultContext().Engine()) + return org.loadTeams(db.GetEngine(db.DefaultContext)) } // GetMembers returns all members of organization. @@ -85,7 +85,7 @@ type FindOrgMembersOpts struct { // CountOrgMembers counts the organization's members func CountOrgMembers(opts *FindOrgMembersOpts) (int64, error) { - sess := db.DefaultContext().Engine().Where("org_id=?", opts.OrgID) + sess := db.GetEngine(db.DefaultContext).Where("org_id=?", opts.OrgID) if opts.PublicOnly { sess.And("is_public = ?", true) } @@ -129,7 +129,7 @@ func (org *User) removeOrgRepo(e db.Engine, repoID int64) error { // RemoveOrgRepo removes all team-repository relations of organization. func (org *User) RemoveOrgRepo(repoID int64) error { - return org.removeOrgRepo(db.DefaultContext().Engine(), repoID) + return org.removeOrgRepo(db.GetEngine(db.DefaultContext), repoID) } // CreateOrganization creates record of a new organization. @@ -162,7 +162,7 @@ func CreateOrganization(org, owner *User) (err error) { org.NumMembers = 1 org.Type = UserTypeOrganization - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -238,7 +238,7 @@ func GetOrgByName(name string) (*User, error) { LowerName: strings.ToLower(name), Type: UserTypeOrganization, } - has, err := db.DefaultContext().Engine().Get(u) + has, err := db.GetEngine(db.DefaultContext).Get(u) if err != nil { return nil, err } else if !has { @@ -249,7 +249,7 @@ func GetOrgByName(name string) (*User, error) { // CountOrganizations returns number of organizations. func CountOrganizations() int64 { - count, _ := db.DefaultContext().Engine(). + count, _ := db.GetEngine(db.DefaultContext). Where("type=1"). Count(new(User)) return count @@ -261,7 +261,7 @@ func DeleteOrganization(org *User) (err error) { return fmt.Errorf("%s is a user not an organization", org.Name) } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { @@ -353,12 +353,12 @@ func isOrganizationOwner(e db.Engine, orgID, uid int64) (bool, error) { // IsOrganizationOwner returns true if given user is in the owner team. func IsOrganizationOwner(orgID, uid int64) (bool, error) { - return isOrganizationOwner(db.DefaultContext().Engine(), orgID, uid) + return isOrganizationOwner(db.GetEngine(db.DefaultContext), orgID, uid) } // IsOrganizationMember returns true if given user is member of organization. func IsOrganizationMember(orgID, uid int64) (bool, error) { - return isOrganizationMember(db.DefaultContext().Engine(), orgID, uid) + return isOrganizationMember(db.GetEngine(db.DefaultContext), orgID, uid) } func isOrganizationMember(e db.Engine, orgID, uid int64) (bool, error) { @@ -371,7 +371,7 @@ func isOrganizationMember(e db.Engine, orgID, uid int64) (bool, error) { // IsPublicMembership returns true if given user public his/her membership. func IsPublicMembership(orgID, uid int64) (bool, error) { - return db.DefaultContext().Engine(). + return db.GetEngine(db.DefaultContext). Where("uid=?", uid). And("org_id=?", orgID). And("is_public=?", true). @@ -384,7 +384,7 @@ func CanCreateOrgRepo(orgID, uid int64) (bool, error) { if owner, err := IsOrganizationOwner(orgID, uid); owner || err != nil { return owner, err } - return db.DefaultContext().Engine(). + return db.GetEngine(db.DefaultContext). Where(builder.Eq{"team.can_create_org_repo": true}). Join("INNER", "team_user", "team_user.team_id = team.id"). And("team_user.uid = ?", uid). @@ -394,12 +394,12 @@ func CanCreateOrgRepo(orgID, uid int64) (bool, error) { // GetUsersWhoCanCreateOrgRepo returns users which are able to create repo in organization func GetUsersWhoCanCreateOrgRepo(orgID int64) ([]*User, error) { - return getUsersWhoCanCreateOrgRepo(db.DefaultContext().Engine(), orgID) + return getUsersWhoCanCreateOrgRepo(db.GetEngine(db.DefaultContext), orgID) } func getUsersWhoCanCreateOrgRepo(e db.Engine, orgID int64) ([]*User, error) { users := make([]*User, 0, 10) - return users, db.DefaultContext().Engine(). + return users, db.GetEngine(db.DefaultContext). Join("INNER", "`team_user`", "`team_user`.uid=`user`.id"). Join("INNER", "`team`", "`team`.id=`team_user`.team_id"). Where(builder.Eq{"team.can_create_org_repo": true}.Or(builder.Eq{"team.authorize": AccessModeOwner})). @@ -421,7 +421,7 @@ func getOrgsByUserID(sess *xorm.Session, userID int64, showAll bool) ([]*User, e // GetOrgsByUserID returns a list of organizations that the given user ID // has joined. func GetOrgsByUserID(userID int64, showAll bool) ([]*User, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() return getOrgsByUserID(sess, userID, showAll) } @@ -431,7 +431,7 @@ type MinimalOrg = User // GetUserOrgsList returns one user's all orgs list func GetUserOrgsList(user *User) ([]*MinimalOrg, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() schema, err := db.TableInfo(new(User)) @@ -502,7 +502,7 @@ func getOwnedOrgsByUserID(sess *xorm.Session, userID int64) ([]*User, error) { // HasOrgOrUserVisible tells if the given user can see the given org or user func HasOrgOrUserVisible(org, user *User) bool { - return hasOrgOrUserVisible(db.DefaultContext().Engine(), org, user) + return hasOrgOrUserVisible(db.GetEngine(db.DefaultContext), org, user) } func hasOrgOrUserVisible(e db.Engine, orgOrUser, user *User) bool { @@ -537,7 +537,7 @@ func HasOrgsVisible(orgs []*User, user *User) bool { // GetOwnedOrgsByUserID returns a list of organizations are owned by given user ID. func GetOwnedOrgsByUserID(userID int64) ([]*User, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() return getOwnedOrgsByUserID(sess, userID) } @@ -545,7 +545,7 @@ func GetOwnedOrgsByUserID(userID int64) ([]*User, error) { // GetOwnedOrgsByUserIDDesc returns a list of organizations are owned by // given user ID, ordered descending by the given condition. func GetOwnedOrgsByUserIDDesc(userID int64, desc string) ([]*User, error) { - return getOwnedOrgsByUserID(db.DefaultContext().Engine().Desc(desc), userID) + return getOwnedOrgsByUserID(db.GetEngine(db.DefaultContext).Desc(desc), userID) } // GetOrgsCanCreateRepoByUserID returns a list of organizations where given user ID @@ -553,7 +553,7 @@ func GetOwnedOrgsByUserIDDesc(userID int64, desc string) ([]*User, error) { func GetOrgsCanCreateRepoByUserID(userID int64) ([]*User, error) { orgs := make([]*User, 0, 10) - return orgs, db.DefaultContext().Engine().Where(builder.In("id", builder.Select("`user`.id").From("`user`"). + return orgs, db.GetEngine(db.DefaultContext).Where(builder.In("id", builder.Select("`user`.id").From("`user`"). Join("INNER", "`team_user`", "`team_user`.org_id = `user`.id"). Join("INNER", "`team`", "`team`.id = `team_user`.team_id"). Where(builder.Eq{"`team_user`.uid": userID}). @@ -565,7 +565,7 @@ func GetOrgsCanCreateRepoByUserID(userID int64) ([]*User, error) { // GetOrgUsersByUserID returns all organization-user relations by user ID. func GetOrgUsersByUserID(uid int64, opts *SearchOrganizationsOptions) ([]*OrgUser, error) { ous := make([]*OrgUser, 0, 10) - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Join("LEFT", "`user`", "`org_user`.org_id=`user`.id"). Where("`org_user`.uid=?", uid) if !opts.All { @@ -585,7 +585,7 @@ func GetOrgUsersByUserID(uid int64, opts *SearchOrganizationsOptions) ([]*OrgUse // GetOrgUsersByOrgID returns all organization-user relations by organization ID. func GetOrgUsersByOrgID(opts *FindOrgMembersOpts) ([]*OrgUser, error) { - return getOrgUsersByOrgID(db.DefaultContext().Engine(), opts) + return getOrgUsersByOrgID(db.GetEngine(db.DefaultContext), opts) } func getOrgUsersByOrgID(e db.Engine, opts *FindOrgMembersOpts) ([]*OrgUser, error) { @@ -607,7 +607,7 @@ func getOrgUsersByOrgID(e db.Engine, opts *FindOrgMembersOpts) ([]*OrgUser, erro // ChangeOrgUserStatus changes public or private membership status. func ChangeOrgUserStatus(orgID, uid int64, public bool) error { ou := new(OrgUser) - has, err := db.DefaultContext().Engine(). + has, err := db.GetEngine(db.DefaultContext). Where("uid=?", uid). And("org_id=?", orgID). Get(ou) @@ -618,7 +618,7 @@ func ChangeOrgUserStatus(orgID, uid int64, public bool) error { } ou.IsPublic = public - _, err = db.DefaultContext().Engine().ID(ou.ID).Cols("is_public").Update(ou) + _, err = db.GetEngine(db.DefaultContext).ID(ou.ID).Cols("is_public").Update(ou) return err } @@ -629,7 +629,7 @@ func AddOrgUser(orgID, uid int64) error { return err } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -738,7 +738,7 @@ func removeOrgUser(sess *xorm.Session, orgID, userID int64) error { // RemoveOrgUser removes user from given organization. func RemoveOrgUser(orgID, userID int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -805,13 +805,13 @@ func (org *User) TeamsWithAccessToRepo(repoID int64, mode AccessMode) ([]*Team, // GetUserTeamIDs returns of all team IDs of the organization that user is member of. func (org *User) GetUserTeamIDs(userID int64) ([]int64, error) { - return org.getUserTeamIDs(db.DefaultContext().Engine(), userID) + return org.getUserTeamIDs(db.GetEngine(db.DefaultContext), userID) } // GetUserTeams returns all teams that belong to user, // and that the user has joined. func (org *User) GetUserTeams(userID int64) ([]*Team, error) { - return org.getUserTeams(db.DefaultContext().Engine(), userID) + return org.getUserTeams(db.GetEngine(db.DefaultContext), userID) } // AccessibleReposEnvironment operations involving the repositories that are @@ -838,7 +838,7 @@ type accessibleReposEnv struct { // AccessibleReposEnv builds an AccessibleReposEnvironment for the repositories in `org` // that are accessible to the specified user. func (org *User) AccessibleReposEnv(userID int64) (AccessibleReposEnvironment, error) { - return org.accessibleReposEnv(db.DefaultContext().Engine(), userID) + return org.accessibleReposEnv(db.GetEngine(db.DefaultContext), userID) } func (org *User) accessibleReposEnv(e db.Engine, userID int64) (AccessibleReposEnvironment, error) { @@ -871,7 +871,7 @@ func (org *User) AccessibleTeamReposEnv(team *Team) AccessibleReposEnvironment { return &accessibleReposEnv{ org: org, team: team, - e: db.DefaultContext().Engine(), + e: db.GetEngine(db.DefaultContext), orderBy: SearchOrderByRecentUpdated, } } diff --git a/models/org_team.go b/models/org_team.go index 280d814eefec..7ca715bb7899 100644 --- a/models/org_team.go +++ b/models/org_team.go @@ -82,7 +82,7 @@ func SearchTeam(opts *SearchTeamOptions) ([]*Team, int64, error) { cond = cond.And(builder.Eq{"org_id": opts.OrgID}) - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() count, err := sess. @@ -120,7 +120,7 @@ func (t *Team) ColorFormat(s fmt.State) { // GetUnits return a list of available units for a team func (t *Team) GetUnits() error { - return t.getUnits(db.DefaultContext().Engine()) + return t.getUnits(db.GetEngine(db.DefaultContext)) } func (t *Team) getUnits(e db.Engine) (err error) { @@ -173,7 +173,7 @@ func (t *Team) getRepositories(e db.Engine) error { // GetRepositories returns paginated repositories in team of organization. func (t *Team) GetRepositories(opts *SearchTeamOptions) error { if opts.Page == 0 { - return t.getRepositories(db.DefaultContext().Engine()) + return t.getRepositories(db.GetEngine(db.DefaultContext)) } return t.getRepositories(getPaginatedSession(opts)) @@ -187,7 +187,7 @@ func (t *Team) getMembers(e db.Engine) (err error) { // GetMembers returns paginated members in team of organization. func (t *Team) GetMembers(opts *SearchMembersOptions) (err error) { if opts.Page == 0 { - return t.getMembers(db.DefaultContext().Engine()) + return t.getMembers(db.GetEngine(db.DefaultContext)) } return t.getMembers(getPaginatedSession(opts)) @@ -210,7 +210,7 @@ func (t *Team) hasRepository(e db.Engine, repoID int64) bool { // HasRepository returns true if given repository belong to team. func (t *Team) HasRepository(repoID int64) bool { - return t.hasRepository(db.DefaultContext().Engine(), repoID) + return t.hasRepository(db.GetEngine(db.DefaultContext), repoID) } func (t *Team) addRepository(e db.Engine, repo *Repository) (err error) { @@ -264,7 +264,7 @@ func (t *Team) addAllRepositories(e db.Engine) error { // AddAllRepositories adds all repositories to the team func (t *Team) AddAllRepositories() (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -285,7 +285,7 @@ func (t *Team) AddRepository(repo *Repository) (err error) { return nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -304,7 +304,7 @@ func (t *Team) RemoveAllRepositories() (err error) { return nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -421,7 +421,7 @@ func (t *Team) RemoveRepository(repoID int64) error { return err } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -436,7 +436,7 @@ func (t *Team) RemoveRepository(repoID int64) error { // UnitEnabled returns if the team has the given unit type enabled func (t *Team) UnitEnabled(tp UnitType) bool { - return t.unitEnabled(db.DefaultContext().Engine(), tp) + return t.unitEnabled(db.GetEngine(db.DefaultContext), tp) } func (t *Team) unitEnabled(e db.Engine, tp UnitType) bool { @@ -473,7 +473,7 @@ func NewTeam(t *Team) (err error) { return err } - has, err := db.DefaultContext().Engine().ID(t.OrgID).Get(new(User)) + has, err := db.GetEngine(db.DefaultContext).ID(t.OrgID).Get(new(User)) if err != nil { return err } @@ -482,7 +482,7 @@ func NewTeam(t *Team) (err error) { } t.LowerName = strings.ToLower(t.Name) - has, err = db.DefaultContext().Engine(). + has, err = db.GetEngine(db.DefaultContext). Where("org_id=?", t.OrgID). And("lower_name=?", t.LowerName). Get(new(Team)) @@ -493,7 +493,7 @@ func NewTeam(t *Team) (err error) { return ErrTeamAlreadyExist{t.OrgID, t.LowerName} } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -556,7 +556,7 @@ func getTeam(e db.Engine, orgID int64, name string) (*Team, error) { // GetTeam returns team by given team name and organization. func GetTeam(orgID int64, name string) (*Team, error) { - return getTeam(db.DefaultContext().Engine(), orgID, name) + return getTeam(db.GetEngine(db.DefaultContext), orgID, name) } // GetTeamIDsByNames returns a slice of team ids corresponds to names. @@ -594,7 +594,7 @@ func getTeamByID(e db.Engine, teamID int64) (*Team, error) { // GetTeamByID returns team by given ID. func GetTeamByID(teamID int64) (*Team, error) { - return getTeamByID(db.DefaultContext().Engine(), teamID) + return getTeamByID(db.GetEngine(db.DefaultContext), teamID) } // GetTeamNamesByID returns team's lower name from a list of team ids. @@ -604,7 +604,7 @@ func GetTeamNamesByID(teamIDs []int64) ([]string, error) { } var teamNames []string - err := db.DefaultContext().Engine().Table("team"). + err := db.GetEngine(db.DefaultContext).Table("team"). Select("lower_name"). In("id", teamIDs). Asc("name"). @@ -623,7 +623,7 @@ func UpdateTeam(t *Team, authChanged, includeAllChanged bool) (err error) { t.Description = t.Description[:255] } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -697,7 +697,7 @@ func DeleteTeam(t *Team) error { return err } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -764,7 +764,7 @@ func isTeamMember(e db.Engine, orgID, teamID, userID int64) (bool, error) { // IsTeamMember returns true if given user is a member of team. func IsTeamMember(orgID, teamID, userID int64) (bool, error) { - return isTeamMember(db.DefaultContext().Engine(), orgID, teamID, userID) + return isTeamMember(db.GetEngine(db.DefaultContext), orgID, teamID, userID) } func getTeamUsersByTeamID(e db.Engine, teamID int64) ([]*TeamUser, error) { @@ -795,7 +795,7 @@ func getTeamMembers(e db.Engine, teamID int64) (_ []*User, err error) { // GetTeamMembers returns all members in given team of organization. func GetTeamMembers(teamID int64) ([]*User, error) { - return getTeamMembers(db.DefaultContext().Engine(), teamID) + return getTeamMembers(db.GetEngine(db.DefaultContext), teamID) } func getUserOrgTeams(e db.Engine, orgID, userID int64) (teams []*Team, err error) { @@ -818,7 +818,7 @@ func getUserRepoTeams(e db.Engine, orgID, userID, repoID int64) (teams []*Team, // GetUserOrgTeams returns all teams that user belongs to in given organization. func GetUserOrgTeams(orgID, userID int64) ([]*Team, error) { - return getUserOrgTeams(db.DefaultContext().Engine(), orgID, userID) + return getUserOrgTeams(db.GetEngine(db.DefaultContext), orgID, userID) } // AddTeamMember adds new membership of given team to given organization, @@ -838,7 +838,7 @@ func AddTeamMember(team *Team, userID int64) error { return err } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -933,7 +933,7 @@ func removeTeamMember(e *xorm.Session, team *Team, userID int64) error { // RemoveTeamMember removes member from given team of given organization. func RemoveTeamMember(team *Team, userID int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -946,7 +946,7 @@ func RemoveTeamMember(team *Team, userID int64) error { // IsUserInTeams returns if a user in some teams func IsUserInTeams(userID int64, teamIDs []int64) (bool, error) { - return isUserInTeams(db.DefaultContext().Engine(), userID, teamIDs) + return isUserInTeams(db.GetEngine(db.DefaultContext), userID, teamIDs) } func isUserInTeams(e db.Engine, userID int64, teamIDs []int64) (bool, error) { @@ -956,7 +956,7 @@ func isUserInTeams(e db.Engine, userID int64, teamIDs []int64) (bool, error) { // UsersInTeamsCount counts the number of users which are in userIDs and teamIDs func UsersInTeamsCount(userIDs, teamIDs []int64) (int64, error) { var ids []int64 - if err := db.DefaultContext().Engine().In("uid", userIDs).In("team_id", teamIDs). + if err := db.GetEngine(db.DefaultContext).In("uid", userIDs).In("team_id", teamIDs). Table("team_user"). Cols("uid").GroupBy("uid").Find(&ids); err != nil { return 0, err @@ -990,7 +990,7 @@ func hasTeamRepo(e db.Engine, orgID, teamID, repoID int64) bool { // HasTeamRepo returns true if given repository belongs to team. func HasTeamRepo(orgID, teamID, repoID int64) bool { - return hasTeamRepo(db.DefaultContext().Engine(), orgID, teamID, repoID) + return hasTeamRepo(db.GetEngine(db.DefaultContext), orgID, teamID, repoID) } func addTeamRepo(e db.Engine, orgID, teamID, repoID int64) error { @@ -1013,7 +1013,7 @@ func removeTeamRepo(e db.Engine, teamID, repoID int64) error { // GetTeamsWithAccessToRepo returns all teams in an organization that have given access level to the repository. func GetTeamsWithAccessToRepo(orgID, repoID int64, mode AccessMode) ([]*Team, error) { teams := make([]*Team, 0, 5) - return teams, db.DefaultContext().Engine().Where("team.authorize >= ?", mode). + return teams, db.GetEngine(db.DefaultContext).Where("team.authorize >= ?", mode). Join("INNER", "team_repo", "team_repo.team_id = team.id"). And("team_repo.org_id = ?", orgID). And("team_repo.repo_id = ?", repoID). @@ -1046,7 +1046,7 @@ func getUnitsByTeamID(e db.Engine, teamID int64) (units []*TeamUnit, err error) // UpdateTeamUnits updates a teams's units func UpdateTeamUnits(team *Team, units []TeamUnit) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err diff --git a/models/org_test.go b/models/org_test.go index 9990a9d9f1a6..75dfc4262d5c 100644 --- a/models/org_test.go +++ b/models/org_test.go @@ -255,7 +255,7 @@ func TestGetOrgByName(t *testing.T) { func TestCountOrganizations(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) - expected, err := db.DefaultContext().Engine().Where("type=?", UserTypeOrganization).Count(&User{}) + expected, err := db.GetEngine(db.DefaultContext).Where("type=?", UserTypeOrganization).Count(&User{}) assert.NoError(t, err) assert.Equal(t, expected, CountOrganizations()) } diff --git a/models/project.go b/models/project.go index 1e520b877eff..8aaff50e151b 100644 --- a/models/project.go +++ b/models/project.go @@ -90,7 +90,7 @@ type ProjectSearchOptions struct { // GetProjects returns a list of all projects that have been created in the repository func GetProjects(opts ProjectSearchOptions) ([]*Project, int64, error) { - return getProjects(db.DefaultContext().Engine(), opts) + return getProjects(db.GetEngine(db.DefaultContext), opts) } func getProjects(e db.Engine, opts ProjectSearchOptions) ([]*Project, int64, error) { @@ -143,7 +143,7 @@ func NewProject(p *Project) error { return errors.New("project type is not valid") } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { @@ -167,7 +167,7 @@ func NewProject(p *Project) error { // GetProjectByID returns the projects in a repository func GetProjectByID(id int64) (*Project, error) { - return getProjectByID(db.DefaultContext().Engine(), id) + return getProjectByID(db.GetEngine(db.DefaultContext), id) } func getProjectByID(e db.Engine, id int64) (*Project, error) { @@ -185,7 +185,7 @@ func getProjectByID(e db.Engine, id int64) (*Project, error) { // UpdateProject updates project properties func UpdateProject(p *Project) error { - return updateProject(db.DefaultContext().Engine(), p) + return updateProject(db.GetEngine(db.DefaultContext), p) } func updateProject(e db.Engine, p *Project) error { @@ -220,7 +220,7 @@ func updateRepositoryProjectCount(e db.Engine, repoID int64) error { // ChangeProjectStatusByRepoIDAndID toggles a project between opened and closed func ChangeProjectStatusByRepoIDAndID(repoID, projectID int64, isClosed bool) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -244,7 +244,7 @@ func ChangeProjectStatusByRepoIDAndID(repoID, projectID int64, isClosed bool) er // ChangeProjectStatus toggle a project between opened and closed func ChangeProjectStatus(p *Project, isClosed bool) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -273,7 +273,7 @@ func changeProjectStatus(e db.Engine, p *Project, isClosed bool) error { // DeleteProjectByID deletes a project from a repository. func DeleteProjectByID(id int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err diff --git a/models/project_board.go b/models/project_board.go index 901850e51109..6a358685113b 100644 --- a/models/project_board.go +++ b/models/project_board.go @@ -100,13 +100,13 @@ func createBoardsForProjectsType(sess *xorm.Session, project *Project) error { // NewProjectBoard adds a new project board to a given project func NewProjectBoard(board *ProjectBoard) error { - _, err := db.DefaultContext().Engine().Insert(board) + _, err := db.GetEngine(db.DefaultContext).Insert(board) return err } // DeleteProjectBoardByID removes all issues references to the project board. func DeleteProjectBoardByID(boardID int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -146,7 +146,7 @@ func deleteProjectBoardByProjectID(e db.Engine, projectID int64) error { // GetProjectBoard fetches the current board of a project func GetProjectBoard(boardID int64) (*ProjectBoard, error) { - return getProjectBoard(db.DefaultContext().Engine(), boardID) + return getProjectBoard(db.GetEngine(db.DefaultContext), boardID) } func getProjectBoard(e db.Engine, boardID int64) (*ProjectBoard, error) { @@ -164,7 +164,7 @@ func getProjectBoard(e db.Engine, boardID int64) (*ProjectBoard, error) { // UpdateProjectBoard updates a project board func UpdateProjectBoard(board *ProjectBoard) error { - return updateProjectBoard(db.DefaultContext().Engine(), board) + return updateProjectBoard(db.GetEngine(db.DefaultContext), board) } func updateProjectBoard(e db.Engine, board *ProjectBoard) error { @@ -186,7 +186,7 @@ func updateProjectBoard(e db.Engine, board *ProjectBoard) error { // GetProjectBoards fetches all boards related to a project // if no default board set, first board is a temporary "Uncategorized" board func GetProjectBoards(projectID int64) (ProjectBoardList, error) { - return getProjectBoards(db.DefaultContext().Engine(), projectID) + return getProjectBoards(db.GetEngine(db.DefaultContext), projectID) } func getProjectBoards(e db.Engine, projectID int64) ([]*ProjectBoard, error) { @@ -226,7 +226,7 @@ func getDefaultBoard(e db.Engine, projectID int64) (*ProjectBoard, error) { // SetDefaultBoard represents a board for issues not assigned to one // if boardID is 0 unset default func SetDefaultBoard(projectID, boardID int64) error { - _, err := db.DefaultContext().Engine().Where(builder.Eq{ + _, err := db.GetEngine(db.DefaultContext).Where(builder.Eq{ "project_id": projectID, "`default`": true, }).Cols("`default`").Update(&ProjectBoard{Default: false}) @@ -235,7 +235,7 @@ func SetDefaultBoard(projectID, boardID int64) error { } if boardID > 0 { - _, err = db.DefaultContext().Engine().ID(boardID).Where(builder.Eq{"project_id": projectID}). + _, err = db.GetEngine(db.DefaultContext).ID(boardID).Where(builder.Eq{"project_id": projectID}). Cols("`default`").Update(&ProjectBoard{Default: true}) } @@ -293,7 +293,7 @@ func (bs ProjectBoardList) LoadIssues() (IssueList, error) { // UpdateProjectBoardSorting update project board sorting func UpdateProjectBoardSorting(bs ProjectBoardList) error { for i := range bs { - _, err := db.DefaultContext().Engine().ID(bs[i].ID).Cols( + _, err := db.GetEngine(db.DefaultContext).ID(bs[i].ID).Cols( "sorting", ).Update(bs[i]) if err != nil { diff --git a/models/project_issue.go b/models/project_issue.go index acc0a31e1eb2..a3179507dc9a 100644 --- a/models/project_issue.go +++ b/models/project_issue.go @@ -39,7 +39,7 @@ func deleteProjectIssuesByProjectID(e db.Engine, projectID int64) error { // LoadProject load the project the issue was assigned to func (i *Issue) LoadProject() (err error) { - return i.loadProject(db.DefaultContext().Engine()) + return i.loadProject(db.GetEngine(db.DefaultContext)) } func (i *Issue) loadProject(e db.Engine) (err error) { @@ -58,7 +58,7 @@ func (i *Issue) loadProject(e db.Engine) (err error) { // ProjectID return project id if issue was assigned to one func (i *Issue) ProjectID() int64 { - return i.projectID(db.DefaultContext().Engine()) + return i.projectID(db.GetEngine(db.DefaultContext)) } func (i *Issue) projectID(e db.Engine) int64 { @@ -72,7 +72,7 @@ func (i *Issue) projectID(e db.Engine) int64 { // ProjectBoardID return project board id if issue was assigned to one func (i *Issue) ProjectBoardID() int64 { - return i.projectBoardID(db.DefaultContext().Engine()) + return i.projectBoardID(db.GetEngine(db.DefaultContext)) } func (i *Issue) projectBoardID(e db.Engine) int64 { @@ -93,7 +93,7 @@ func (i *Issue) projectBoardID(e db.Engine) int64 { // NumIssues return counter of all issues assigned to a project func (p *Project) NumIssues() int { - c, err := db.DefaultContext().Engine().Table("project_issue"). + c, err := db.GetEngine(db.DefaultContext).Table("project_issue"). Where("project_id=?", p.ID). GroupBy("issue_id"). Cols("issue_id"). @@ -106,7 +106,7 @@ func (p *Project) NumIssues() int { // NumClosedIssues return counter of closed issues assigned to a project func (p *Project) NumClosedIssues() int { - c, err := db.DefaultContext().Engine().Table("project_issue"). + c, err := db.GetEngine(db.DefaultContext).Table("project_issue"). Join("INNER", "issue", "project_issue.issue_id=issue.id"). Where("project_issue.project_id=? AND issue.is_closed=?", p.ID, true). Cols("issue_id"). @@ -119,7 +119,7 @@ func (p *Project) NumClosedIssues() int { // NumOpenIssues return counter of open issues assigned to a project func (p *Project) NumOpenIssues() int { - c, err := db.DefaultContext().Engine().Table("project_issue"). + c, err := db.GetEngine(db.DefaultContext).Table("project_issue"). Join("INNER", "issue", "project_issue.issue_id=issue.id"). Where("project_issue.project_id=? AND issue.is_closed=?", p.ID, false).Count("issue.id") if err != nil { @@ -130,7 +130,7 @@ func (p *Project) NumOpenIssues() int { // ChangeProjectAssign changes the project associated with an issue func ChangeProjectAssign(issue *Issue, doer *User, newProjectID int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -183,7 +183,7 @@ func addUpdateIssueProject(e *xorm.Session, issue *Issue, doer *User, newProject // MoveIssueAcrossProjectBoards move a card from one board to another func MoveIssueAcrossProjectBoards(issue *Issue, board *ProjectBoard) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err diff --git a/models/protected_tag.go b/models/protected_tag.go index 0ce997e99a14..93318300e884 100644 --- a/models/protected_tag.go +++ b/models/protected_tag.go @@ -35,19 +35,19 @@ func init() { // InsertProtectedTag inserts a protected tag to database func InsertProtectedTag(pt *ProtectedTag) error { - _, err := db.DefaultContext().Engine().Insert(pt) + _, err := db.GetEngine(db.DefaultContext).Insert(pt) return err } // UpdateProtectedTag updates the protected tag func UpdateProtectedTag(pt *ProtectedTag) error { - _, err := db.DefaultContext().Engine().ID(pt.ID).AllCols().Update(pt) + _, err := db.GetEngine(db.DefaultContext).ID(pt.ID).AllCols().Update(pt) return err } // DeleteProtectedTag deletes a protected tag by ID func DeleteProtectedTag(pt *ProtectedTag) error { - _, err := db.DefaultContext().Engine().ID(pt.ID).Delete(&ProtectedTag{}) + _, err := db.GetEngine(db.DefaultContext).ID(pt.ID).Delete(&ProtectedTag{}) return err } @@ -86,13 +86,13 @@ func (pt *ProtectedTag) IsUserAllowed(userID int64) (bool, error) { // GetProtectedTags gets all protected tags of the repository func (repo *Repository) GetProtectedTags() ([]*ProtectedTag, error) { tags := make([]*ProtectedTag, 0) - return tags, db.DefaultContext().Engine().Find(&tags, &ProtectedTag{RepoID: repo.ID}) + return tags, db.GetEngine(db.DefaultContext).Find(&tags, &ProtectedTag{RepoID: repo.ID}) } // GetProtectedTagByID gets the protected tag with the specific id func GetProtectedTagByID(id int64) (*ProtectedTag, error) { tag := new(ProtectedTag) - has, err := db.DefaultContext().Engine().ID(id).Get(tag) + has, err := db.GetEngine(db.DefaultContext).ID(id).Get(tag) if err != nil { return nil, err } diff --git a/models/pull.go b/models/pull.go index 5cb7b57286f6..004af62f035c 100644 --- a/models/pull.go +++ b/models/pull.go @@ -122,7 +122,7 @@ func (pr *PullRequest) loadAttributes(e db.Engine) (err error) { // LoadAttributes loads pull request attributes from database func (pr *PullRequest) LoadAttributes() error { - return pr.loadAttributes(db.DefaultContext().Engine()) + return pr.loadAttributes(db.GetEngine(db.DefaultContext)) } func (pr *PullRequest) loadHeadRepo(e db.Engine) (err error) { @@ -148,12 +148,12 @@ func (pr *PullRequest) loadHeadRepo(e db.Engine) (err error) { // LoadHeadRepo loads the head repository func (pr *PullRequest) LoadHeadRepo() error { - return pr.loadHeadRepo(db.DefaultContext().Engine()) + return pr.loadHeadRepo(db.GetEngine(db.DefaultContext)) } // LoadBaseRepo loads the target repository func (pr *PullRequest) LoadBaseRepo() error { - return pr.loadBaseRepo(db.DefaultContext().Engine()) + return pr.loadBaseRepo(db.GetEngine(db.DefaultContext)) } func (pr *PullRequest) loadBaseRepo(e db.Engine) (err error) { @@ -180,7 +180,7 @@ func (pr *PullRequest) loadBaseRepo(e db.Engine) (err error) { // LoadIssue loads issue information from database func (pr *PullRequest) LoadIssue() (err error) { - return pr.loadIssue(db.DefaultContext().Engine()) + return pr.loadIssue(db.GetEngine(db.DefaultContext)) } func (pr *PullRequest) loadIssue(e db.Engine) (err error) { @@ -197,7 +197,7 @@ func (pr *PullRequest) loadIssue(e db.Engine) (err error) { // LoadProtectedBranch loads the protected branch of the base branch func (pr *PullRequest) LoadProtectedBranch() (err error) { - return pr.loadProtectedBranch(db.DefaultContext().Engine()) + return pr.loadProtectedBranch(db.GetEngine(db.DefaultContext)) } func (pr *PullRequest) loadProtectedBranch(e db.Engine) (err error) { @@ -257,7 +257,7 @@ type ReviewCount struct { // GetApprovalCounts returns the approval counts by type // FIXME: Only returns official counts due to double counting of non-official counts func (pr *PullRequest) GetApprovalCounts() ([]*ReviewCount, error) { - return pr.getApprovalCounts(db.DefaultContext().Engine()) + return pr.getApprovalCounts(db.GetEngine(db.DefaultContext)) } func (pr *PullRequest) getApprovalCounts(e db.Engine) ([]*ReviewCount, error) { @@ -284,7 +284,7 @@ func (pr *PullRequest) getReviewedByLines(writer io.Writer) error { return nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -393,7 +393,7 @@ func (pr *PullRequest) SetMerged() (bool, error) { pr.HasMerged = true - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return false, err @@ -455,7 +455,7 @@ func NewPullRequest(repo *Repository, issue *Issue, labelIDs []int64, uuids []st issue.Index = idx - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -492,7 +492,7 @@ func NewPullRequest(repo *Repository, issue *Issue, labelIDs []int64, uuids []st // by given head/base and repo/branch. func GetUnmergedPullRequest(headRepoID, baseRepoID int64, headBranch, baseBranch string, flow PullRequestFlow) (*PullRequest, error) { pr := new(PullRequest) - has, err := db.DefaultContext().Engine(). + has, err := db.GetEngine(db.DefaultContext). Where("head_repo_id=? AND head_branch=? AND base_repo_id=? AND base_branch=? AND has_merged=? AND flow = ? AND issue.is_closed=?", headRepoID, headBranch, baseRepoID, baseBranch, false, flow, false). Join("INNER", "issue", "issue.id=pull_request.issue_id"). @@ -510,7 +510,7 @@ func GetUnmergedPullRequest(headRepoID, baseRepoID int64, headBranch, baseBranch // by given head information (repo and branch). func GetLatestPullRequestByHeadInfo(repoID int64, branch string) (*PullRequest, error) { pr := new(PullRequest) - has, err := db.DefaultContext().Engine(). + has, err := db.GetEngine(db.DefaultContext). Where("head_repo_id = ? AND head_branch = ? AND flow = ?", repoID, branch, PullRequestFlowGithub). OrderBy("id DESC"). Get(pr) @@ -527,7 +527,7 @@ func GetPullRequestByIndex(repoID, index int64) (*PullRequest, error) { Index: index, } - has, err := db.DefaultContext().Engine().Get(pr) + has, err := db.GetEngine(db.DefaultContext).Get(pr) if err != nil { return nil, err } else if !has { @@ -557,13 +557,13 @@ func getPullRequestByID(e db.Engine, id int64) (*PullRequest, error) { // GetPullRequestByID returns a pull request by given ID. func GetPullRequestByID(id int64) (*PullRequest, error) { - return getPullRequestByID(db.DefaultContext().Engine(), id) + return getPullRequestByID(db.GetEngine(db.DefaultContext), id) } // GetPullRequestByIssueIDWithNoAttributes returns pull request with no attributes loaded by given issue ID. func GetPullRequestByIssueIDWithNoAttributes(issueID int64) (*PullRequest, error) { var pr PullRequest - has, err := db.DefaultContext().Engine().Where("issue_id = ?", issueID).Get(&pr) + has, err := db.GetEngine(db.DefaultContext).Where("issue_id = ?", issueID).Get(&pr) if err != nil { return nil, err } @@ -591,7 +591,7 @@ func getPullRequestByIssueID(e db.Engine, issueID int64) (*PullRequest, error) { func GetAllUnmergedAgitPullRequestByPoster(uid int64) ([]*PullRequest, error) { pulls := make([]*PullRequest, 0, 10) - err := db.DefaultContext().Engine(). + err := db.GetEngine(db.DefaultContext). Where("has_merged=? AND flow = ? AND issue.is_closed=? AND issue.poster_id=?", false, PullRequestFlowAGit, false, uid). Join("INNER", "issue", "issue.id=pull_request.issue_id"). @@ -602,24 +602,24 @@ func GetAllUnmergedAgitPullRequestByPoster(uid int64) ([]*PullRequest, error) { // GetPullRequestByIssueID returns pull request by given issue ID. func GetPullRequestByIssueID(issueID int64) (*PullRequest, error) { - return getPullRequestByIssueID(db.DefaultContext().Engine(), issueID) + return getPullRequestByIssueID(db.GetEngine(db.DefaultContext), issueID) } // Update updates all fields of pull request. func (pr *PullRequest) Update() error { - _, err := db.DefaultContext().Engine().ID(pr.ID).AllCols().Update(pr) + _, err := db.GetEngine(db.DefaultContext).ID(pr.ID).AllCols().Update(pr) return err } // UpdateCols updates specific fields of pull request. func (pr *PullRequest) UpdateCols(cols ...string) error { - _, err := db.DefaultContext().Engine().ID(pr.ID).Cols(cols...).Update(pr) + _, err := db.GetEngine(db.DefaultContext).ID(pr.ID).Cols(cols...).Update(pr) return err } // UpdateColsIfNotMerged updates specific fields of a pull request if it has not been merged func (pr *PullRequest) UpdateColsIfNotMerged(cols ...string) error { - _, err := db.DefaultContext().Engine().Where("id = ? AND has_merged = ?", pr.ID, false).Cols(cols...).Update(pr) + _, err := db.GetEngine(db.DefaultContext).Where("id = ? AND has_merged = ?", pr.ID, false).Cols(cols...).Update(pr) return err } @@ -665,7 +665,7 @@ func (pr *PullRequest) GetWorkInProgressPrefix() string { // UpdateCommitDivergence update Divergence of a pull request func (pr *PullRequest) UpdateCommitDivergence(ahead, behind int) error { - return pr.updateCommitDivergence(db.DefaultContext().Engine(), ahead, behind) + return pr.updateCommitDivergence(db.GetEngine(db.DefaultContext), ahead, behind) } func (pr *PullRequest) updateCommitDivergence(e db.Engine, ahead, behind int) error { diff --git a/models/pull_list.go b/models/pull_list.go index ed8372658d36..57e2f9c85f7e 100644 --- a/models/pull_list.go +++ b/models/pull_list.go @@ -25,7 +25,7 @@ type PullRequestsOptions struct { } func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xorm.Session, error) { - sess := db.DefaultContext().Engine().Where("pull_request.base_repo_id=?", baseRepoID) + sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", baseRepoID) sess.Join("INNER", "issue", "pull_request.issue_id = issue.id") switch opts.State { @@ -51,7 +51,7 @@ func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xor // by given head information (repo and branch). func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequest, error) { prs := make([]*PullRequest, 0, 2) - return prs, db.DefaultContext().Engine(). + return prs, db.GetEngine(db.DefaultContext). Where("head_repo_id = ? AND head_branch = ? AND has_merged = ? AND issue.is_closed = ? AND flow = ?", repoID, branch, false, false, PullRequestFlowGithub). Join("INNER", "issue", "issue.id = pull_request.issue_id"). @@ -62,7 +62,7 @@ func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequ // by given base information (repo and branch). func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequest, error) { prs := make([]*PullRequest, 0, 2) - return prs, db.DefaultContext().Engine(). + return prs, db.GetEngine(db.DefaultContext). Where("base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?", repoID, branch, false, false). Join("INNER", "issue", "issue.id=pull_request.issue_id"). @@ -72,7 +72,7 @@ func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequ // GetPullRequestIDsByCheckStatus returns all pull requests according the special checking status. func GetPullRequestIDsByCheckStatus(status PullRequestStatus) ([]int64, error) { prs := make([]int64, 0, 10) - return prs, db.DefaultContext().Engine().Table("pull_request"). + return prs, db.GetEngine(db.DefaultContext).Table("pull_request"). Where("status=?", status). Cols("pull_request.id"). Find(&prs) @@ -144,7 +144,7 @@ func (prs PullRequestList) getIssueIDs() []int64 { // LoadAttributes load all the prs attributes func (prs PullRequestList) LoadAttributes() error { - return prs.loadAttributes(db.DefaultContext().Engine()) + return prs.loadAttributes(db.GetEngine(db.DefaultContext)) } func (prs PullRequestList) invalidateCodeComments(e db.Engine, doer *User, repo *git.Repository, branch string) error { @@ -169,5 +169,5 @@ func (prs PullRequestList) invalidateCodeComments(e db.Engine, doer *User, repo // InvalidateCodeComments will lookup the prs for code comments which got invalidated by change func (prs PullRequestList) InvalidateCodeComments(doer *User, repo *git.Repository, branch string) error { - return prs.invalidateCodeComments(db.DefaultContext().Engine(), doer, repo, branch) + return prs.invalidateCodeComments(db.GetEngine(db.DefaultContext), doer, repo, branch) } diff --git a/models/release.go b/models/release.go index 2a6e9352f3b0..d6b629cfe839 100644 --- a/models/release.go +++ b/models/release.go @@ -6,6 +6,7 @@ package models import ( + "context" "errors" "fmt" "sort" @@ -72,7 +73,7 @@ func (r *Release) loadAttributes(e db.Engine) error { // LoadAttributes load repo and publisher attributes for a release func (r *Release) LoadAttributes() error { - return r.loadAttributes(db.DefaultContext().Engine()) + return r.loadAttributes(db.GetEngine(db.DefaultContext)) } // APIURL the api url for a release. release must have attributes loaded @@ -102,31 +103,31 @@ func IsReleaseExist(repoID int64, tagName string) (bool, error) { return false, nil } - return db.DefaultContext().Engine().Get(&Release{RepoID: repoID, LowerTagName: strings.ToLower(tagName)}) + return db.GetEngine(db.DefaultContext).Get(&Release{RepoID: repoID, LowerTagName: strings.ToLower(tagName)}) } // InsertRelease inserts a release func InsertRelease(rel *Release) error { - _, err := db.DefaultContext().Engine().Insert(rel) + _, err := db.GetEngine(db.DefaultContext).Insert(rel) return err } // InsertReleasesContext insert releases -func InsertReleasesContext(ctx *db.Context, rels []*Release) error { - _, err := ctx.Engine().Insert(rels) +func InsertReleasesContext(ctx context.Context, rels []*Release) error { + _, err := db.GetEngine(ctx).Insert(rels) return err } // UpdateRelease updates all columns of a release -func UpdateRelease(ctx *db.Context, rel *Release) error { - _, err := ctx.Engine().ID(rel.ID).AllCols().Update(rel) +func UpdateRelease(ctx context.Context, rel *Release) error { + _, err := db.GetEngine(ctx).ID(rel.ID).AllCols().Update(rel) return err } // AddReleaseAttachments adds a release attachments -func AddReleaseAttachments(ctx *db.Context, releaseID int64, attachmentUUIDs []string) (err error) { +func AddReleaseAttachments(ctx context.Context, releaseID int64, attachmentUUIDs []string) (err error) { // Check attachments - attachments, err := getAttachmentsByUUIDs(ctx.Engine(), attachmentUUIDs) + attachments, err := getAttachmentsByUUIDs(db.GetEngine(ctx), attachmentUUIDs) if err != nil { return fmt.Errorf("GetAttachmentsByUUIDs [uuids: %v]: %v", attachmentUUIDs, err) } @@ -137,7 +138,7 @@ func AddReleaseAttachments(ctx *db.Context, releaseID int64, attachmentUUIDs []s } attachments[i].ReleaseID = releaseID // No assign value could be 0, so ignore AllCols(). - if _, err = ctx.Engine().ID(attachments[i].ID).Update(attachments[i]); err != nil { + if _, err = db.GetEngine(ctx).ID(attachments[i].ID).Update(attachments[i]); err != nil { return fmt.Errorf("update attachment [%d]: %v", attachments[i].ID, err) } } @@ -155,14 +156,14 @@ func GetRelease(repoID int64, tagName string) (*Release, error) { } rel := &Release{RepoID: repoID, LowerTagName: strings.ToLower(tagName)} - _, err = db.DefaultContext().Engine().Get(rel) + _, err = db.GetEngine(db.DefaultContext).Get(rel) return rel, err } // GetReleaseByID returns release with given ID. func GetReleaseByID(id int64) (*Release, error) { rel := new(Release) - has, err := db.DefaultContext().Engine(). + has, err := db.GetEngine(db.DefaultContext). ID(id). Get(rel) if err != nil { @@ -208,7 +209,7 @@ func (opts *FindReleasesOptions) toConds(repoID int64) builder.Cond { // GetReleasesByRepoID returns a list of releases of repository. func GetReleasesByRepoID(repoID int64, opts FindReleasesOptions) ([]*Release, error) { - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Desc("created_unix", "id"). Where(opts.toConds(repoID)) @@ -222,7 +223,7 @@ func GetReleasesByRepoID(repoID int64, opts FindReleasesOptions) ([]*Release, er // CountReleasesByRepoID returns a number of releases matching FindReleaseOptions and RepoID. func CountReleasesByRepoID(repoID int64, opts FindReleasesOptions) (int64, error) { - return db.DefaultContext().Engine().Where(opts.toConds(repoID)).Count(new(Release)) + return db.GetEngine(db.DefaultContext).Where(opts.toConds(repoID)).Count(new(Release)) } // GetLatestReleaseByRepoID returns the latest release for a repository @@ -234,7 +235,7 @@ func GetLatestReleaseByRepoID(repoID int64) (*Release, error) { And(builder.Eq{"is_tag": false}) rel := new(Release) - has, err := db.DefaultContext().Engine(). + has, err := db.GetEngine(db.DefaultContext). Desc("created_unix", "id"). Where(cond). Get(rel) @@ -248,8 +249,8 @@ func GetLatestReleaseByRepoID(repoID int64) (*Release, error) { } // GetReleasesByRepoIDAndNames returns a list of releases of repository according repoID and tagNames. -func GetReleasesByRepoIDAndNames(ctx *db.Context, repoID int64, tagNames []string) (rels []*Release, err error) { - err = ctx.Engine(). +func GetReleasesByRepoIDAndNames(ctx context.Context, repoID int64, tagNames []string) (rels []*Release, err error) { + err = db.GetEngine(ctx). In("tag_name", tagNames). Desc("created_unix"). Find(&rels, Release{RepoID: repoID}) @@ -258,7 +259,7 @@ func GetReleasesByRepoIDAndNames(ctx *db.Context, repoID int64, tagNames []strin // GetReleaseCountByRepoID returns the count of releases of repository func GetReleaseCountByRepoID(repoID int64, opts FindReleasesOptions) (int64, error) { - return db.DefaultContext().Engine().Where(opts.toConds(repoID)).Count(&Release{}) + return db.GetEngine(db.DefaultContext).Where(opts.toConds(repoID)).Count(&Release{}) } type releaseMetaSearch struct { @@ -281,7 +282,7 @@ func (s releaseMetaSearch) Less(i, j int) bool { // GetReleaseAttachments retrieves the attachments for releases func GetReleaseAttachments(rels ...*Release) (err error) { - return getReleaseAttachments(db.DefaultContext().Engine(), rels...) + return getReleaseAttachments(db.GetEngine(db.DefaultContext), rels...) } func getReleaseAttachments(e db.Engine, rels ...*Release) (err error) { @@ -352,13 +353,13 @@ func SortReleases(rels []*Release) { // DeleteReleaseByID deletes a release from database by given ID. func DeleteReleaseByID(id int64) error { - _, err := db.DefaultContext().Engine().ID(id).Delete(new(Release)) + _, err := db.GetEngine(db.DefaultContext).ID(id).Delete(new(Release)) return err } // UpdateReleasesMigrationsByType updates all migrated repositories' releases from gitServiceType to replace originalAuthorID to posterID func UpdateReleasesMigrationsByType(gitServiceType structs.GitServiceType, originalAuthorID string, posterID int64) error { - _, err := db.DefaultContext().Engine().Table("release"). + _, err := db.GetEngine(db.DefaultContext).Table("release"). Where("repo_id IN (SELECT id FROM repository WHERE original_service_type = ?)", gitServiceType). And("original_author_id = ?", originalAuthorID). Update(map[string]interface{}{ diff --git a/models/repo.go b/models/repo.go index 39e5a089ebfa..ae149f467d6d 100644 --- a/models/repo.go +++ b/models/repo.go @@ -302,7 +302,7 @@ func (repo *Repository) AfterLoad() { // It creates a fake object that contains error details // when error occurs. func (repo *Repository) MustOwner() *User { - return repo.mustOwner(db.DefaultContext().Engine()) + return repo.mustOwner(db.GetEngine(db.DefaultContext)) } // FullName returns the repository full name @@ -354,7 +354,7 @@ func (repo *Repository) getUnits(e db.Engine) (err error) { // CheckUnitUser check whether user could visit the unit of this repository func (repo *Repository) CheckUnitUser(user *User, unitType UnitType) bool { - return repo.checkUnitUser(db.DefaultContext().Engine(), user, unitType) + return repo.checkUnitUser(db.GetEngine(db.DefaultContext), user, unitType) } func (repo *Repository) checkUnitUser(e db.Engine, user *User, unitType UnitType) bool { @@ -372,7 +372,7 @@ func (repo *Repository) checkUnitUser(e db.Engine, user *User, unitType UnitType // UnitEnabled if this repository has the given unit enabled func (repo *Repository) UnitEnabled(tp UnitType) bool { - if err := repo.getUnits(db.DefaultContext().Engine()); err != nil { + if err := repo.getUnits(db.GetEngine(db.DefaultContext)); err != nil { log.Warn("Error loading repository (ID: %d) units: %s", repo.ID, err.Error()) } for _, unit := range repo.Units { @@ -434,7 +434,7 @@ func (repo *Repository) MustGetUnit(tp UnitType) *RepoUnit { // GetUnit returns a RepoUnit object func (repo *Repository) GetUnit(tp UnitType) (*RepoUnit, error) { - return repo.getUnit(db.DefaultContext().Engine(), tp) + return repo.getUnit(db.GetEngine(db.DefaultContext), tp) } func (repo *Repository) getUnit(e db.Engine, tp UnitType) (*RepoUnit, error) { @@ -460,7 +460,7 @@ func (repo *Repository) getOwner(e db.Engine) (err error) { // GetOwner returns the repository owner func (repo *Repository) GetOwner() error { - return repo.getOwner(db.DefaultContext().Engine()) + return repo.getOwner(db.GetEngine(db.DefaultContext)) } func (repo *Repository) mustOwner(e db.Engine) *User { @@ -498,7 +498,7 @@ func (repo *Repository) ComposeMetas() map[string]string { repo.MustOwner() if repo.Owner.IsOrganization() { teams := make([]string, 0, 5) - _ = db.DefaultContext().Engine().Table("team_repo"). + _ = db.GetEngine(db.DefaultContext).Table("team_repo"). Join("INNER", "team", "team.id = team_repo.team_id"). Where("team_repo.repo_id = ?", repo.ID). Select("team.lower_name"). @@ -561,7 +561,7 @@ func (repo *Repository) getAssignees(e db.Engine) (_ []*User, err error) { // GetAssignees returns all users that have write access and can be assigned to issues // of the repository, func (repo *Repository) GetAssignees() (_ []*User, err error) { - return repo.getAssignees(db.DefaultContext().Engine()) + return repo.getAssignees(db.GetEngine(db.DefaultContext)) } func (repo *Repository) getReviewers(e db.Engine, doerID, posterID int64) ([]*User, error) { @@ -613,7 +613,7 @@ func (repo *Repository) getReviewers(e db.Engine, doerID, posterID int64) ([]*Us // all repo watchers and all organization members. // TODO: may be we should have a busy choice for users to block review request to them. func (repo *Repository) GetReviewers(doerID, posterID int64) ([]*User, error) { - return repo.getReviewers(db.DefaultContext().Engine(), doerID, posterID) + return repo.getReviewers(db.GetEngine(db.DefaultContext), doerID, posterID) } // GetReviewerTeams get all teams can be requested to review @@ -659,7 +659,7 @@ func (repo *Repository) LoadPushMirrors() (err error) { // returns an error on failure (NOTE: no error is returned for // non-fork repositories, and BaseRepo will be left untouched) func (repo *Repository) GetBaseRepo() (err error) { - return repo.getBaseRepo(db.DefaultContext().Engine()) + return repo.getBaseRepo(db.GetEngine(db.DefaultContext)) } func (repo *Repository) getBaseRepo(e db.Engine) (err error) { @@ -680,7 +680,7 @@ func (repo *Repository) IsGenerated() bool { // returns an error on failure (NOTE: no error is returned for // non-generated repositories, and TemplateRepo will be left untouched) func (repo *Repository) GetTemplateRepo() (err error) { - return repo.getTemplateRepo(db.DefaultContext().Engine()) + return repo.getTemplateRepo(db.GetEngine(db.DefaultContext)) } func (repo *Repository) getTemplateRepo(e db.Engine) (err error) { @@ -724,7 +724,7 @@ func (repo *Repository) ComposeCompareURL(oldCommitID, newCommitID string) strin // UpdateDefaultBranch updates the default branch func (repo *Repository) UpdateDefaultBranch() error { - _, err := db.DefaultContext().Engine().ID(repo.ID).Cols("default_branch").Update(repo) + _, err := db.GetEngine(db.DefaultContext).ID(repo.ID).Cols("default_branch").Update(repo) return err } @@ -750,8 +750,8 @@ func (repo *Repository) updateSize(e db.Engine) error { } // UpdateSize updates the repository size, calculating it using util.GetDirectorySize -func (repo *Repository) UpdateSize(ctx *db.Context) error { - return repo.updateSize(ctx.Engine()) +func (repo *Repository) UpdateSize(ctx context.Context) error { + return repo.updateSize(db.GetEngine(ctx)) } // CanUserFork returns true if specified user can fork repository. @@ -812,12 +812,12 @@ func (repo *Repository) CanEnableEditor() bool { // GetReaders returns all users that have explicit read access or higher to the repository. func (repo *Repository) GetReaders() (_ []*User, err error) { - return repo.getUsersWithAccessMode(db.DefaultContext().Engine(), AccessModeRead) + return repo.getUsersWithAccessMode(db.GetEngine(db.DefaultContext), AccessModeRead) } // GetWriters returns all users that have write access to the repository. func (repo *Repository) GetWriters() (_ []*User, err error) { - return repo.getUsersWithAccessMode(db.DefaultContext().Engine(), AccessModeWrite) + return repo.getUsersWithAccessMode(db.GetEngine(db.DefaultContext), AccessModeWrite) } // IsReader returns true if user has explicit read access or higher to the repository. @@ -825,7 +825,7 @@ func (repo *Repository) IsReader(userID int64) (bool, error) { if repo.OwnerID == userID { return true, nil } - return db.DefaultContext().Engine().Where("repo_id = ? AND user_id = ? AND mode >= ?", repo.ID, userID, AccessModeRead).Get(&Access{}) + return db.GetEngine(db.DefaultContext).Where("repo_id = ? AND user_id = ? AND mode >= ?", repo.ID, userID, AccessModeRead).Get(&Access{}) } // getUsersWithAccessMode returns users that have at least given access mode to the repository. @@ -874,7 +874,7 @@ func (repo *Repository) DescriptionHTML() template.HTML { // ReadBy sets repo to be visited by given user. func (repo *Repository) ReadBy(userID int64) error { - return setRepoNotificationStatusReadIfUnread(db.DefaultContext().Engine(), userID, repo.ID) + return setRepoNotificationStatusReadIfUnread(db.GetEngine(db.DefaultContext), userID, repo.ID) } func isRepositoryExist(e db.Engine, u *User, repoName string) (bool, error) { @@ -891,7 +891,7 @@ func isRepositoryExist(e db.Engine, u *User, repoName string) (bool, error) { // IsRepositoryExist returns true if the repository with given name under user has already existed. func IsRepositoryExist(u *User, repoName string) (bool, error) { - return isRepositoryExist(db.DefaultContext().Engine(), u, repoName) + return isRepositoryExist(db.GetEngine(db.DefaultContext), u, repoName) } // CloneLink represents different types of clone URLs of repository. @@ -953,7 +953,7 @@ func CheckCreateRepository(doer, u *User, name string, overwriteOrAdopt bool) er return err } - has, err := isRepositoryExist(db.DefaultContext().Engine(), u, name) + has, err := isRepositoryExist(db.GetEngine(db.DefaultContext), u, name) if err != nil { return fmt.Errorf("IsRepositoryExist: %v", err) } else if has { @@ -1042,12 +1042,12 @@ func IsUsableRepoName(name string) error { } // CreateRepository creates a repository for the user/organization. -func CreateRepository(ctx *db.Context, doer, u *User, repo *Repository, overwriteOrAdopt bool) (err error) { +func CreateRepository(ctx context.Context, doer, u *User, repo *Repository, overwriteOrAdopt bool) (err error) { if err = IsUsableRepoName(repo.Name); err != nil { return err } - has, err := isRepositoryExist(ctx.Engine(), u, repo.Name) + has, err := isRepositoryExist(db.GetEngine(ctx), u, repo.Name) if err != nil { return fmt.Errorf("IsRepositoryExist: %v", err) } else if has { @@ -1068,10 +1068,10 @@ func CreateRepository(ctx *db.Context, doer, u *User, repo *Repository, overwrit } } - if _, err = ctx.Engine().Insert(repo); err != nil { + if _, err = db.GetEngine(ctx).Insert(repo); err != nil { return err } - if err = deleteRepoRedirect(ctx.Engine(), u.ID, repo.Name); err != nil { + if err = deleteRepoRedirect(db.GetEngine(ctx), u.ID, repo.Name); err != nil { return err } @@ -1102,46 +1102,46 @@ func CreateRepository(ctx *db.Context, doer, u *User, repo *Repository, overwrit } } - if _, err = ctx.Engine().Insert(&units); err != nil { + if _, err = db.GetEngine(ctx).Insert(&units); err != nil { return err } // Remember visibility preference. u.LastRepoVisibility = repo.IsPrivate - if err = updateUserCols(ctx.Engine(), u, "last_repo_visibility"); err != nil { + if err = updateUserCols(db.GetEngine(ctx), u, "last_repo_visibility"); err != nil { return fmt.Errorf("updateUser: %v", err) } - if _, err = ctx.Engine().Incr("num_repos").ID(u.ID).Update(new(User)); err != nil { + if _, err = db.GetEngine(ctx).Incr("num_repos").ID(u.ID).Update(new(User)); err != nil { return fmt.Errorf("increment user total_repos: %v", err) } u.NumRepos++ // Give access to all members in teams with access to all repositories. if u.IsOrganization() { - if err := u.loadTeams(ctx.Engine()); err != nil { + if err := u.loadTeams(db.GetEngine(ctx)); err != nil { return fmt.Errorf("loadTeams: %v", err) } for _, t := range u.Teams { if t.IncludesAllRepositories { - if err := t.addRepository(ctx.Engine(), repo); err != nil { + if err := t.addRepository(db.GetEngine(ctx), repo); err != nil { return fmt.Errorf("addRepository: %v", err) } } } - if isAdmin, err := isUserRepoAdmin(ctx.Engine(), repo, doer); err != nil { + if isAdmin, err := isUserRepoAdmin(db.GetEngine(ctx), repo, doer); err != nil { return fmt.Errorf("isUserRepoAdmin: %v", err) } else if !isAdmin { // Make creator repo admin if it wan't assigned automatically - if err = repo.addCollaborator(ctx.Engine(), doer); err != nil { + if err = repo.addCollaborator(db.GetEngine(ctx), doer); err != nil { return fmt.Errorf("AddCollaborator: %v", err) } - if err = repo.changeCollaborationAccessMode(ctx.Engine(), doer.ID, AccessModeAdmin); err != nil { + if err = repo.changeCollaborationAccessMode(db.GetEngine(ctx), doer.ID, AccessModeAdmin); err != nil { return fmt.Errorf("ChangeCollaborationAccessMode: %v", err) } } - } else if err = repo.recalculateAccesses(ctx.Engine()); err != nil { + } else if err = repo.recalculateAccesses(db.GetEngine(ctx)); err != nil { // Organization automatically called this in addRepository method. return fmt.Errorf("recalculateAccesses: %v", err) } @@ -1157,12 +1157,12 @@ func CreateRepository(ctx *db.Context, doer, u *User, repo *Repository, overwrit } if setting.Service.AutoWatchNewRepos { - if err = watchRepo(ctx.Engine(), doer.ID, repo.ID, true); err != nil { + if err = watchRepo(db.GetEngine(ctx), doer.ID, repo.ID, true); err != nil { return fmt.Errorf("watchRepo: %v", err) } } - if err = copyDefaultWebhooksToRepo(ctx.Engine(), repo.ID); err != nil { + if err = copyDefaultWebhooksToRepo(db.GetEngine(ctx), repo.ID); err != nil { return fmt.Errorf("copyDefaultWebhooksToRepo: %v", err) } @@ -1170,7 +1170,7 @@ func CreateRepository(ctx *db.Context, doer, u *User, repo *Repository, overwrit } func countRepositories(userID int64, private bool) int64 { - sess := db.DefaultContext().Engine().Where("id > 0") + sess := db.GetEngine(db.DefaultContext).Where("id > 0") if userID > 0 { sess.And("owner_id = ?", userID) @@ -1206,14 +1206,14 @@ func RepoPath(userName, repoName string) string { } // IncrementRepoForkNum increment repository fork number -func IncrementRepoForkNum(ctx *db.Context, repoID int64) error { - _, err := ctx.Engine().Exec("UPDATE `repository` SET num_forks=num_forks+1 WHERE id=?", repoID) +func IncrementRepoForkNum(ctx context.Context, repoID int64) error { + _, err := db.GetEngine(ctx).Exec("UPDATE `repository` SET num_forks=num_forks+1 WHERE id=?", repoID) return err } // DecrementRepoForkNum decrement repository fork number -func DecrementRepoForkNum(ctx *db.Context, repoID int64) error { - _, err := ctx.Engine().Exec("UPDATE `repository` SET num_forks=num_forks-1 WHERE id=?", repoID) +func DecrementRepoForkNum(ctx context.Context, repoID int64) error { + _, err := db.GetEngine(ctx).Exec("UPDATE `repository` SET num_forks=num_forks-1 WHERE id=?", repoID) return err } @@ -1253,7 +1253,7 @@ func ChangeRepositoryName(doer *User, repo *Repository, newRepoName string) (err } } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return fmt.Errorf("sess.Begin: %v", err) @@ -1275,7 +1275,7 @@ func getRepositoriesByForkID(e db.Engine, forkID int64) ([]*Repository, error) { // GetRepositoriesByForkID returns all repositories with given fork ID. func GetRepositoriesByForkID(forkID int64) ([]*Repository, error) { - return getRepositoriesByForkID(db.DefaultContext().Engine(), forkID) + return getRepositoriesByForkID(db.GetEngine(db.DefaultContext), forkID) } func updateRepository(e db.Engine, repo *Repository, visibilityChanged bool) (err error) { @@ -1353,13 +1353,13 @@ func updateRepository(e db.Engine, repo *Repository, visibilityChanged bool) (er } // UpdateRepositoryCtx updates a repository with db context -func UpdateRepositoryCtx(ctx *db.Context, repo *Repository, visibilityChanged bool) error { - return updateRepository(ctx.Engine(), repo, visibilityChanged) +func UpdateRepositoryCtx(ctx context.Context, repo *Repository, visibilityChanged bool) error { + return updateRepository(db.GetEngine(ctx), repo, visibilityChanged) } // UpdateRepository updates a repository func UpdateRepository(repo *Repository, visibilityChanged bool) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -1377,7 +1377,7 @@ func UpdateRepositoryOwnerNames(ownerID int64, ownerName string) error { if ownerID == 0 { return nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -1394,13 +1394,13 @@ func UpdateRepositoryOwnerNames(ownerID int64, ownerName string) error { // UpdateRepositoryUpdatedTime updates a repository's updated time func UpdateRepositoryUpdatedTime(repoID int64, updateTime time.Time) error { - _, err := db.DefaultContext().Engine().Exec("UPDATE repository SET updated_unix = ? WHERE id = ?", updateTime.Unix(), repoID) + _, err := db.GetEngine(db.DefaultContext).Exec("UPDATE repository SET updated_unix = ? WHERE id = ?", updateTime.Unix(), repoID) return err } // UpdateRepositoryUnits updates a repository's units func UpdateRepositoryUnits(repo *Repository, units []RepoUnit, deleteUnitTypes []UnitType) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -1427,7 +1427,7 @@ func UpdateRepositoryUnits(repo *Repository, units []RepoUnit, deleteUnitTypes [ // DeleteRepository deletes a repository for a user or organization. // make sure if you call this func to close open sessions (sqlite will otherwise get a deadlock) func DeleteRepository(doer *User, uid, repoID int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -1643,21 +1643,21 @@ func DeleteRepository(doer *User, uid, repoID int64) error { // Remove repository files. repoPath := repo.RepoPath() - removeAllWithNotice(db.DefaultContext().Engine(), "Delete repository files", repoPath) + removeAllWithNotice(db.GetEngine(db.DefaultContext), "Delete repository files", repoPath) // Remove wiki files if repo.HasWiki() { - removeAllWithNotice(db.DefaultContext().Engine(), "Delete repository wiki", repo.WikiPath()) + removeAllWithNotice(db.GetEngine(db.DefaultContext), "Delete repository wiki", repo.WikiPath()) } // Remove archives for i := range archivePaths { - removeStorageWithNotice(db.DefaultContext().Engine(), storage.RepoArchives, "Delete repo archive file", archivePaths[i]) + removeStorageWithNotice(db.GetEngine(db.DefaultContext), storage.RepoArchives, "Delete repo archive file", archivePaths[i]) } // Remove lfs objects for i := range lfsPaths { - removeStorageWithNotice(db.DefaultContext().Engine(), storage.LFS, "Delete orphaned LFS file", lfsPaths[i]) + removeStorageWithNotice(db.GetEngine(db.DefaultContext), storage.LFS, "Delete orphaned LFS file", lfsPaths[i]) } // Remove issue attachment files. @@ -1686,7 +1686,7 @@ func DeleteRepository(doer *User, uid, repoID int64) error { // GetRepositoryByOwnerAndName returns the repository by given ownername and reponame. func GetRepositoryByOwnerAndName(ownerName, repoName string) (*Repository, error) { - return getRepositoryByOwnerAndName(db.DefaultContext().Engine(), ownerName, repoName) + return getRepositoryByOwnerAndName(db.GetEngine(db.DefaultContext), ownerName, repoName) } func getRepositoryByOwnerAndName(e db.Engine, ownerName, repoName string) (*Repository, error) { @@ -1710,7 +1710,7 @@ func GetRepositoryByName(ownerID int64, name string) (*Repository, error) { OwnerID: ownerID, LowerName: strings.ToLower(name), } - has, err := db.DefaultContext().Engine().Get(repo) + has, err := db.GetEngine(db.DefaultContext).Get(repo) if err != nil { return nil, err } else if !has { @@ -1732,18 +1732,18 @@ func getRepositoryByID(e db.Engine, id int64) (*Repository, error) { // GetRepositoryByID returns the repository by given id if exists. func GetRepositoryByID(id int64) (*Repository, error) { - return getRepositoryByID(db.DefaultContext().Engine(), id) + return getRepositoryByID(db.GetEngine(db.DefaultContext), id) } // GetRepositoryByIDCtx returns the repository by given id if exists. -func GetRepositoryByIDCtx(ctx *db.Context, id int64) (*Repository, error) { - return getRepositoryByID(ctx.Engine(), id) +func GetRepositoryByIDCtx(ctx context.Context, id int64) (*Repository, error) { + return getRepositoryByID(db.GetEngine(ctx), id) } // GetRepositoriesMapByIDs returns the repositories by given id slice. func GetRepositoriesMapByIDs(ids []int64) (map[int64]*Repository, error) { repos := make(map[int64]*Repository, len(ids)) - return repos, db.DefaultContext().Engine().In("id", ids).Find(&repos) + return repos, db.GetEngine(db.DefaultContext).In("id", ids).Find(&repos) } // GetUserRepositories returns a list of repositories of given user. @@ -1762,7 +1762,7 @@ func GetUserRepositories(opts *SearchRepoOptions) ([]*Repository, int64, error) cond = cond.And(builder.In("lower_name", opts.LowerNames)) } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() count, err := sess.Where(cond).Count(new(Repository)) @@ -1778,7 +1778,7 @@ func GetUserRepositories(opts *SearchRepoOptions) ([]*Repository, int64, error) // GetUserMirrorRepositories returns a list of mirror repositories of given user. func GetUserMirrorRepositories(userID int64) ([]*Repository, error) { repos := make([]*Repository, 0, 10) - return repos, db.DefaultContext().Engine(). + return repos, db.GetEngine(db.DefaultContext). Where("owner_id = ?", userID). And("is_mirror = ?", true). Find(&repos) @@ -1798,17 +1798,17 @@ func getPrivateRepositoryCount(e db.Engine, u *User) (int64, error) { // GetRepositoryCount returns the total number of repositories of user. func GetRepositoryCount(u *User) (int64, error) { - return getRepositoryCount(db.DefaultContext().Engine(), u) + return getRepositoryCount(db.GetEngine(db.DefaultContext), u) } // GetPublicRepositoryCount returns the total number of public repositories of user. func GetPublicRepositoryCount(u *User) (int64, error) { - return getPublicRepositoryCount(db.DefaultContext().Engine(), u) + return getPublicRepositoryCount(db.GetEngine(db.DefaultContext), u) } // GetPrivateRepositoryCount returns the total number of private repositories of user. func GetPrivateRepositoryCount(u *User) (int64, error) { - return getPrivateRepositoryCount(db.DefaultContext().Engine(), u) + return getPrivateRepositoryCount(db.GetEngine(db.DefaultContext), u) } // DeleteOldRepositoryArchives deletes old repository archives. @@ -1817,7 +1817,7 @@ func DeleteOldRepositoryArchives(ctx context.Context, olderThan time.Duration) e for { var archivers []RepoArchiver - err := db.DefaultContext().Engine().Where("created_unix < ?", time.Now().Add(-olderThan).Unix()). + err := db.GetEngine(db.DefaultContext).Where("created_unix < ?", time.Now().Add(-olderThan).Unix()). Asc("created_unix"). Limit(100). Find(&archivers) @@ -1847,7 +1847,7 @@ func deleteOldRepoArchiver(ctx context.Context, archiver *RepoArchiver) error { if err != nil { return err } - _, err = db.DefaultContext().Engine().ID(archiver.ID).Delete(delRepoArchiver) + _, err = db.GetEngine(db.DefaultContext).ID(archiver.ID).Delete(delRepoArchiver) if err != nil { return err } @@ -1863,7 +1863,7 @@ type repoChecker struct { } func repoStatsCheck(ctx context.Context, checker *repoChecker) { - results, err := db.DefaultContext().Engine().Query(checker.querySQL) + results, err := db.GetEngine(db.DefaultContext).Query(checker.querySQL) if err != nil { log.Error("Select %s: %v", checker.desc, err) return @@ -1877,7 +1877,7 @@ func repoStatsCheck(ctx context.Context, checker *repoChecker) { default: } log.Trace("Updating %s: %d", checker.desc, id) - _, err = db.DefaultContext().Engine().Exec(checker.correctSQL, id, id) + _, err = db.GetEngine(db.DefaultContext).Exec(checker.correctSQL, id, id) if err != nil { log.Error("Update %s[%d]: %v", checker.desc, id, err) } @@ -1932,7 +1932,7 @@ func CheckRepoStats(ctx context.Context) error { // ***** START: Repository.NumClosedIssues ***** desc := "repository count 'num_closed_issues'" - results, err := db.DefaultContext().Engine().Query("SELECT repo.id FROM `repository` repo WHERE repo.num_closed_issues!=(SELECT COUNT(*) FROM `issue` WHERE repo_id=repo.id AND is_closed=? AND is_pull=?)", true, false) + results, err := db.GetEngine(db.DefaultContext).Query("SELECT repo.id FROM `repository` repo WHERE repo.num_closed_issues!=(SELECT COUNT(*) FROM `issue` WHERE repo_id=repo.id AND is_closed=? AND is_pull=?)", true, false) if err != nil { log.Error("Select %s: %v", desc, err) } else { @@ -1945,7 +1945,7 @@ func CheckRepoStats(ctx context.Context) error { default: } log.Trace("Updating %s: %d", desc, id) - _, err = db.DefaultContext().Engine().Exec("UPDATE `repository` SET num_closed_issues=(SELECT COUNT(*) FROM `issue` WHERE repo_id=? AND is_closed=? AND is_pull=?) WHERE id=?", id, true, false, id) + _, err = db.GetEngine(db.DefaultContext).Exec("UPDATE `repository` SET num_closed_issues=(SELECT COUNT(*) FROM `issue` WHERE repo_id=? AND is_closed=? AND is_pull=?) WHERE id=?", id, true, false, id) if err != nil { log.Error("Update %s[%d]: %v", desc, id, err) } @@ -1955,7 +1955,7 @@ func CheckRepoStats(ctx context.Context) error { // ***** START: Repository.NumClosedPulls ***** desc = "repository count 'num_closed_pulls'" - results, err = db.DefaultContext().Engine().Query("SELECT repo.id FROM `repository` repo WHERE repo.num_closed_pulls!=(SELECT COUNT(*) FROM `issue` WHERE repo_id=repo.id AND is_closed=? AND is_pull=?)", true, true) + results, err = db.GetEngine(db.DefaultContext).Query("SELECT repo.id FROM `repository` repo WHERE repo.num_closed_pulls!=(SELECT COUNT(*) FROM `issue` WHERE repo_id=repo.id AND is_closed=? AND is_pull=?)", true, true) if err != nil { log.Error("Select %s: %v", desc, err) } else { @@ -1968,7 +1968,7 @@ func CheckRepoStats(ctx context.Context) error { default: } log.Trace("Updating %s: %d", desc, id) - _, err = db.DefaultContext().Engine().Exec("UPDATE `repository` SET num_closed_pulls=(SELECT COUNT(*) FROM `issue` WHERE repo_id=? AND is_closed=? AND is_pull=?) WHERE id=?", id, true, true, id) + _, err = db.GetEngine(db.DefaultContext).Exec("UPDATE `repository` SET num_closed_pulls=(SELECT COUNT(*) FROM `issue` WHERE repo_id=? AND is_closed=? AND is_pull=?) WHERE id=?", id, true, true, id) if err != nil { log.Error("Update %s[%d]: %v", desc, id, err) } @@ -1978,7 +1978,7 @@ func CheckRepoStats(ctx context.Context) error { // FIXME: use checker when stop supporting old fork repo format. // ***** START: Repository.NumForks ***** - results, err = db.DefaultContext().Engine().Query("SELECT repo.id FROM `repository` repo WHERE repo.num_forks!=(SELECT COUNT(*) FROM `repository` WHERE fork_id=repo.id)") + results, err = db.GetEngine(db.DefaultContext).Query("SELECT repo.id FROM `repository` repo WHERE repo.num_forks!=(SELECT COUNT(*) FROM `repository` WHERE fork_id=repo.id)") if err != nil { log.Error("Select repository count 'num_forks': %v", err) } else { @@ -1998,7 +1998,7 @@ func CheckRepoStats(ctx context.Context) error { continue } - rawResult, err := db.DefaultContext().Engine().Query("SELECT COUNT(*) FROM `repository` WHERE fork_id=?", repo.ID) + rawResult, err := db.GetEngine(db.DefaultContext).Query("SELECT COUNT(*) FROM `repository` WHERE fork_id=?", repo.ID) if err != nil { log.Error("Select count of forks[%d]: %v", repo.ID, err) continue @@ -2018,7 +2018,7 @@ func CheckRepoStats(ctx context.Context) error { // SetArchiveRepoState sets if a repo is archived func (repo *Repository) SetArchiveRepoState(isArchived bool) (err error) { repo.IsArchived = isArchived - _, err = db.DefaultContext().Engine().Where("id = ?", repo.ID).Cols("is_archived").NoAutoTime().Update(repo) + _, err = db.GetEngine(db.DefaultContext).Where("id = ?", repo.ID).Cols("is_archived").NoAutoTime().Update(repo) return } @@ -2032,23 +2032,23 @@ func (repo *Repository) SetArchiveRepoState(isArchived bool) (err error) { // HasForkedRepo checks if given user has already forked a repository with given ID. func HasForkedRepo(ownerID, repoID int64) (*Repository, bool) { repo := new(Repository) - has, _ := db.DefaultContext().Engine(). + has, _ := db.GetEngine(db.DefaultContext). Where("owner_id=? AND fork_id=?", ownerID, repoID). Get(repo) return repo, has } // CopyLFS copies LFS data from one repo to another -func CopyLFS(ctx *db.Context, newRepo, oldRepo *Repository) error { +func CopyLFS(ctx context.Context, newRepo, oldRepo *Repository) error { var lfsObjects []*LFSMetaObject - if err := ctx.Engine().Where("repository_id=?", oldRepo.ID).Find(&lfsObjects); err != nil { + if err := db.GetEngine(ctx).Where("repository_id=?", oldRepo.ID).Find(&lfsObjects); err != nil { return err } for _, v := range lfsObjects { v.ID = 0 v.RepositoryID = newRepo.ID - if _, err := ctx.Engine().Insert(v); err != nil { + if _, err := db.GetEngine(ctx).Insert(v); err != nil { return err } } @@ -2060,7 +2060,7 @@ func CopyLFS(ctx *db.Context, newRepo, oldRepo *Repository) error { func (repo *Repository) GetForks(listOptions ListOptions) ([]*Repository, error) { if listOptions.Page == 0 { forks := make([]*Repository, 0, repo.NumForks) - return forks, db.DefaultContext().Engine().Find(&forks, &Repository{ForkID: repo.ID}) + return forks, db.GetEngine(db.DefaultContext).Find(&forks, &Repository{ForkID: repo.ID}) } sess := getPaginatedSession(&listOptions) @@ -2071,7 +2071,7 @@ func (repo *Repository) GetForks(listOptions ListOptions) ([]*Repository, error) // GetUserFork return user forked repository from this repository, if not forked return nil func (repo *Repository) GetUserFork(userID int64) (*Repository, error) { var forkedRepo Repository - has, err := db.DefaultContext().Engine().Where("fork_id = ?", repo.ID).And("owner_id = ?", userID).Get(&forkedRepo) + has, err := db.GetEngine(db.DefaultContext).Where("fork_id = ?", repo.ID).And("owner_id = ?", userID).Get(&forkedRepo) if err != nil { return nil, err } @@ -2114,7 +2114,7 @@ func updateRepositoryCols(e db.Engine, repo *Repository, cols ...string) error { // UpdateRepositoryCols updates repository's columns func UpdateRepositoryCols(repo *Repository, cols ...string) error { - return updateRepositoryCols(db.DefaultContext().Engine(), repo, cols...) + return updateRepositoryCols(db.GetEngine(db.DefaultContext), repo, cols...) } // GetTrustModel will get the TrustModel for the repo or the default trust model @@ -2132,7 +2132,7 @@ func (repo *Repository) GetTrustModel() TrustModelType { // DoctorUserStarNum recalculate Stars number for all user func DoctorUserStarNum() (err error) { const batchSize = 100 - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() for start := 0; ; start += batchSize { @@ -2170,7 +2170,7 @@ func IterateRepository(f func(repo *Repository) error) error { batchSize := setting.Database.IterateBufferSize for { repos := make([]*Repository, 0, batchSize) - if err := db.DefaultContext().Engine().Limit(batchSize, start).Find(&repos); err != nil { + if err := db.GetEngine(db.DefaultContext).Limit(batchSize, start).Find(&repos); err != nil { return err } if len(repos) == 0 { diff --git a/models/repo_activity.go b/models/repo_activity.go index cfbda21411d8..5986da7e77ae 100644 --- a/models/repo_activity.go +++ b/models/repo_activity.go @@ -246,7 +246,7 @@ func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) e } func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged bool) *xorm.Session { - sess := db.DefaultContext().Engine().Where("pull_request.base_repo_id=?", repoID). + sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", repoID). Join("INNER", "issue", "pull_request.issue_id = issue.id") if merged { @@ -314,7 +314,7 @@ func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Tim } func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session { - sess := db.DefaultContext().Engine().Where("issue.repo_id = ?", repoID). + sess := db.GetEngine(db.DefaultContext).Where("issue.repo_id = ?", repoID). And("issue.is_closed = ?", closed) if !unresolved { @@ -356,7 +356,7 @@ func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error } func releasesForActivityStatement(repoID int64, fromTime time.Time) *xorm.Session { - return db.DefaultContext().Engine().Where("release.repo_id = ?", repoID). + return db.GetEngine(db.DefaultContext).Where("release.repo_id = ?", repoID). And("release.is_draft = ?", false). And("release.created_unix >= ?", fromTime.Unix()) } diff --git a/models/repo_archiver.go b/models/repo_archiver.go index b268e65e0dce..647a3b47be34 100644 --- a/models/repo_archiver.go +++ b/models/repo_archiver.go @@ -5,6 +5,7 @@ package models import ( + "context" "fmt" "code.gitea.io/gitea/models/db" @@ -43,7 +44,7 @@ func (archiver *RepoArchiver) LoadRepo() (*Repository, error) { } var repo Repository - has, err := db.DefaultContext().Engine().ID(archiver.RepoID).Get(&repo) + has, err := db.GetEngine(db.DefaultContext).ID(archiver.RepoID).Get(&repo) if err != nil { return nil, err } @@ -61,9 +62,9 @@ func (archiver *RepoArchiver) RelativePath() (string, error) { } // GetRepoArchiver get an archiver -func GetRepoArchiver(ctx *db.Context, repoID int64, tp git.ArchiveType, commitID string) (*RepoArchiver, error) { +func GetRepoArchiver(ctx context.Context, repoID int64, tp git.ArchiveType, commitID string) (*RepoArchiver, error) { var archiver RepoArchiver - has, err := ctx.Engine().Where("repo_id=?", repoID).And("`type`=?", tp).And("commit_id=?", commitID).Get(&archiver) + has, err := db.GetEngine(ctx).Where("repo_id=?", repoID).And("`type`=?", tp).And("commit_id=?", commitID).Get(&archiver) if err != nil { return nil, err } @@ -74,19 +75,19 @@ func GetRepoArchiver(ctx *db.Context, repoID int64, tp git.ArchiveType, commitID } // AddRepoArchiver adds an archiver -func AddRepoArchiver(ctx *db.Context, archiver *RepoArchiver) error { - _, err := ctx.Engine().Insert(archiver) +func AddRepoArchiver(ctx context.Context, archiver *RepoArchiver) error { + _, err := db.GetEngine(ctx).Insert(archiver) return err } // UpdateRepoArchiverStatus updates archiver's status -func UpdateRepoArchiverStatus(ctx *db.Context, archiver *RepoArchiver) error { - _, err := ctx.Engine().ID(archiver.ID).Cols("status").Update(archiver) +func UpdateRepoArchiverStatus(ctx context.Context, archiver *RepoArchiver) error { + _, err := db.GetEngine(ctx).ID(archiver.ID).Cols("status").Update(archiver) return err } // DeleteAllRepoArchives deletes all repo archives records func DeleteAllRepoArchives() error { - _, err := db.DefaultContext().Engine().Where("1=1").Delete(new(RepoArchiver)) + _, err := db.GetEngine(db.DefaultContext).Where("1=1").Delete(new(RepoArchiver)) return err } diff --git a/models/repo_avatar.go b/models/repo_avatar.go index 8133cdc70653..bb5f083dd5e2 100644 --- a/models/repo_avatar.go +++ b/models/repo_avatar.go @@ -57,7 +57,7 @@ func (repo *Repository) generateRandomAvatar(e db.Engine) error { // RemoveRandomAvatars removes the randomly generated avatars that were created for repositories func RemoveRandomAvatars(ctx context.Context) error { - return db.DefaultContext().Engine(). + return db.GetEngine(db.DefaultContext). Where("id > 0").BufferSize(setting.Database.IterateBufferSize). Iterate(new(Repository), func(idx int, bean interface{}) error { @@ -77,7 +77,7 @@ func RemoveRandomAvatars(ctx context.Context) error { // RelAvatarLink returns a relative link to the repository's avatar. func (repo *Repository) RelAvatarLink() string { - return repo.relAvatarLink(db.DefaultContext().Engine()) + return repo.relAvatarLink(db.GetEngine(db.DefaultContext)) } func (repo *Repository) relAvatarLink(e db.Engine) string { @@ -101,7 +101,7 @@ func (repo *Repository) relAvatarLink(e db.Engine) string { // AvatarLink returns a link to the repository's avatar. func (repo *Repository) AvatarLink() string { - return repo.avatarLink(db.DefaultContext().Engine()) + return repo.avatarLink(db.GetEngine(db.DefaultContext)) } // avatarLink returns user avatar absolute link. @@ -129,7 +129,7 @@ func (repo *Repository) UploadAvatar(data []byte) error { return nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -172,7 +172,7 @@ func (repo *Repository) DeleteAvatar() error { avatarPath := repo.CustomAvatarRelativePath() log.Trace("DeleteAvatar[%d]: %s", repo.ID, avatarPath) - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err diff --git a/models/repo_collaboration.go b/models/repo_collaboration.go index 1199a5640491..08d2062dbba0 100644 --- a/models/repo_collaboration.go +++ b/models/repo_collaboration.go @@ -51,7 +51,7 @@ func (repo *Repository) addCollaborator(e db.Engine, u *User) error { // AddCollaborator adds new collaboration to a repository with default access mode. func (repo *Repository) AddCollaborator(u *User) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -104,12 +104,12 @@ func (repo *Repository) getCollaborators(e db.Engine, listOptions ListOptions) ( // GetCollaborators returns the collaborators for a repository func (repo *Repository) GetCollaborators(listOptions ListOptions) ([]*Collaborator, error) { - return repo.getCollaborators(db.DefaultContext().Engine(), listOptions) + return repo.getCollaborators(db.GetEngine(db.DefaultContext), listOptions) } // CountCollaborators returns total number of collaborators for a repository func (repo *Repository) CountCollaborators() (int64, error) { - return db.DefaultContext().Engine().Where("repo_id = ? ", repo.ID).Count(&Collaboration{}) + return db.GetEngine(db.DefaultContext).Where("repo_id = ? ", repo.ID).Count(&Collaboration{}) } func (repo *Repository) getCollaboration(e db.Engine, uid int64) (*Collaboration, error) { @@ -130,7 +130,7 @@ func (repo *Repository) isCollaborator(e db.Engine, userID int64) (bool, error) // IsCollaborator check if a user is a collaborator of a repository func (repo *Repository) IsCollaborator(userID int64) (bool, error) { - return repo.isCollaborator(db.DefaultContext().Engine(), userID) + return repo.isCollaborator(db.GetEngine(db.DefaultContext), userID) } func (repo *Repository) changeCollaborationAccessMode(e db.Engine, uid int64, mode AccessMode) error { @@ -169,7 +169,7 @@ func (repo *Repository) changeCollaborationAccessMode(e db.Engine, uid int64, mo // ChangeCollaborationAccessMode sets new access mode for the collaboration. func (repo *Repository) ChangeCollaborationAccessMode(uid int64, mode AccessMode) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -189,7 +189,7 @@ func (repo *Repository) DeleteCollaboration(uid int64) (err error) { UserID: uid, } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -259,7 +259,7 @@ func (repo *Repository) getRepoTeams(e db.Engine) (teams []*Team, err error) { // GetRepoTeams gets the list of teams that has access to the repository func (repo *Repository) GetRepoTeams() ([]*Team, error) { - return repo.getRepoTeams(db.DefaultContext().Engine()) + return repo.getRepoTeams(db.GetEngine(db.DefaultContext)) } // IsOwnerMemberCollaborator checks if a provided user is the owner, a collaborator or a member of a team in a repository @@ -267,7 +267,7 @@ func (repo *Repository) IsOwnerMemberCollaborator(userID int64) (bool, error) { if repo.OwnerID == userID { return true, nil } - teamMember, err := db.DefaultContext().Engine().Join("INNER", "team_repo", "team_repo.team_id = team_user.team_id"). + teamMember, err := db.GetEngine(db.DefaultContext).Join("INNER", "team_repo", "team_repo.team_id = team_user.team_id"). Join("INNER", "team_unit", "team_unit.team_id = team_user.team_id"). Where("team_repo.repo_id = ?", repo.ID). And("team_unit.`type` = ?", UnitTypeCode). @@ -279,5 +279,5 @@ func (repo *Repository) IsOwnerMemberCollaborator(userID int64) (bool, error) { return true, nil } - return db.DefaultContext().Engine().Get(&Collaboration{RepoID: repo.ID, UserID: userID}) + return db.GetEngine(db.DefaultContext).Get(&Collaboration{RepoID: repo.ID, UserID: userID}) } diff --git a/models/repo_collaboration_test.go b/models/repo_collaboration_test.go index 32dd2ef5a8f4..5a3ffef5fae1 100644 --- a/models/repo_collaboration_test.go +++ b/models/repo_collaboration_test.go @@ -32,7 +32,7 @@ func TestRepository_GetCollaborators(t *testing.T) { repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository) collaborators, err := repo.GetCollaborators(ListOptions{}) assert.NoError(t, err) - expectedLen, err := db.DefaultContext().Engine().Count(&Collaboration{RepoID: repoID}) + expectedLen, err := db.GetEngine(db.DefaultContext).Count(&Collaboration{RepoID: repoID}) assert.NoError(t, err) assert.Len(t, collaborators, int(expectedLen)) for _, collaborator := range collaborators { diff --git a/models/repo_generate.go b/models/repo_generate.go index e5f5c7e7fe4e..cb8bf45184f9 100644 --- a/models/repo_generate.go +++ b/models/repo_generate.go @@ -7,6 +7,7 @@ package models import ( "bufio" "bytes" + "context" "strconv" "strings" @@ -68,9 +69,9 @@ func (gt GiteaTemplate) Globs() []glob.Glob { } // GenerateTopics generates topics from a template repository -func GenerateTopics(ctx *db.Context, templateRepo, generateRepo *Repository) error { +func GenerateTopics(ctx context.Context, templateRepo, generateRepo *Repository) error { for _, topic := range templateRepo.Topics { - if _, err := addTopicByNameToRepo(ctx.Engine(), generateRepo.ID, topic); err != nil { + if _, err := addTopicByNameToRepo(db.GetEngine(ctx), generateRepo.ID, topic); err != nil { return err } } @@ -78,7 +79,7 @@ func GenerateTopics(ctx *db.Context, templateRepo, generateRepo *Repository) err } // GenerateGitHooks generates git hooks from a template repository -func GenerateGitHooks(ctx *db.Context, templateRepo, generateRepo *Repository) error { +func GenerateGitHooks(ctx context.Context, templateRepo, generateRepo *Repository) error { generateGitRepo, err := git.OpenRepository(generateRepo.RepoPath()) if err != nil { return err @@ -111,7 +112,7 @@ func GenerateGitHooks(ctx *db.Context, templateRepo, generateRepo *Repository) e } // GenerateWebhooks generates webhooks from a template repository -func GenerateWebhooks(ctx *db.Context, templateRepo, generateRepo *Repository) error { +func GenerateWebhooks(ctx context.Context, templateRepo, generateRepo *Repository) error { templateWebhooks, err := ListWebhooksByOpts(&ListWebhookOptions{RepoID: templateRepo.ID}) if err != nil { return err @@ -131,7 +132,7 @@ func GenerateWebhooks(ctx *db.Context, templateRepo, generateRepo *Repository) e Events: templateWebhook.Events, Meta: templateWebhook.Meta, } - if err := createWebhook(ctx.Engine(), generateWebhook); err != nil { + if err := createWebhook(db.GetEngine(ctx), generateWebhook); err != nil { return err } } @@ -139,18 +140,18 @@ func GenerateWebhooks(ctx *db.Context, templateRepo, generateRepo *Repository) e } // GenerateAvatar generates the avatar from a template repository -func GenerateAvatar(ctx *db.Context, templateRepo, generateRepo *Repository) error { +func GenerateAvatar(ctx context.Context, templateRepo, generateRepo *Repository) error { generateRepo.Avatar = strings.Replace(templateRepo.Avatar, strconv.FormatInt(templateRepo.ID, 10), strconv.FormatInt(generateRepo.ID, 10), 1) if _, err := storage.Copy(storage.RepoAvatars, generateRepo.CustomAvatarRelativePath(), storage.RepoAvatars, templateRepo.CustomAvatarRelativePath()); err != nil { return err } - return updateRepositoryCols(ctx.Engine(), generateRepo, "avatar") + return updateRepositoryCols(db.GetEngine(ctx), generateRepo, "avatar") } // GenerateIssueLabels generates issue labels from a template repository -func GenerateIssueLabels(ctx *db.Context, templateRepo, generateRepo *Repository) error { - templateLabels, err := getLabelsByRepoID(ctx.Engine(), templateRepo.ID, "", ListOptions{}) +func GenerateIssueLabels(ctx context.Context, templateRepo, generateRepo *Repository) error { + templateLabels, err := getLabelsByRepoID(db.GetEngine(ctx), templateRepo.ID, "", ListOptions{}) if err != nil { return err } @@ -162,7 +163,7 @@ func GenerateIssueLabels(ctx *db.Context, templateRepo, generateRepo *Repository Description: templateLabel.Description, Color: templateLabel.Color, } - if err := newLabel(ctx.Engine(), generateLabel); err != nil { + if err := newLabel(db.GetEngine(ctx), generateLabel); err != nil { return err } } diff --git a/models/repo_indexer.go b/models/repo_indexer.go index 4d3d18cac3fd..7029b0922b07 100644 --- a/models/repo_indexer.go +++ b/models/repo_indexer.go @@ -42,7 +42,7 @@ func GetUnindexedRepos(indexerType RepoIndexerType, maxRepoID int64, page, pageS }).And(builder.Eq{ "repository.is_empty": false, }) - sess := db.DefaultContext().Engine().Table("repository").Join("LEFT OUTER", "repo_indexer_status", "repository.id = repo_indexer_status.repo_id AND repo_indexer_status.indexer_type = ?", indexerType) + sess := db.GetEngine(db.DefaultContext).Table("repository").Join("LEFT OUTER", "repo_indexer_status", "repository.id = repo_indexer_status.repo_id AND repo_indexer_status.indexer_type = ?", indexerType) if maxRepoID > 0 { cond = builder.And(cond, builder.Lte{ "repository.id": maxRepoID, @@ -91,7 +91,7 @@ func (repo *Repository) getIndexerStatus(e db.Engine, indexerType RepoIndexerTyp // GetIndexerStatus loads repo codes indxer status func (repo *Repository) GetIndexerStatus(indexerType RepoIndexerType) (*RepoIndexerStatus, error) { - return repo.getIndexerStatus(db.DefaultContext().Engine(), indexerType) + return repo.getIndexerStatus(db.GetEngine(db.DefaultContext), indexerType) } // updateIndexerStatus updates indexer status @@ -120,5 +120,5 @@ func (repo *Repository) updateIndexerStatus(e db.Engine, indexerType RepoIndexer // UpdateIndexerStatus updates indexer status func (repo *Repository) UpdateIndexerStatus(indexerType RepoIndexerType, sha string) error { - return repo.updateIndexerStatus(db.DefaultContext().Engine(), indexerType, sha) + return repo.updateIndexerStatus(db.GetEngine(db.DefaultContext), indexerType, sha) } diff --git a/models/repo_language_stats.go b/models/repo_language_stats.go index c7af7dc9b5da..2f126aa3d2ac 100644 --- a/models/repo_language_stats.go +++ b/models/repo_language_stats.go @@ -75,12 +75,12 @@ func (repo *Repository) getLanguageStats(e db.Engine) (LanguageStatList, error) // GetLanguageStats returns the language statistics for a repository func (repo *Repository) GetLanguageStats() (LanguageStatList, error) { - return repo.getLanguageStats(db.DefaultContext().Engine()) + return repo.getLanguageStats(db.GetEngine(db.DefaultContext)) } // GetTopLanguageStats returns the top language statistics for a repository func (repo *Repository) GetTopLanguageStats(limit int) (LanguageStatList, error) { - stats, err := repo.getLanguageStats(db.DefaultContext().Engine()) + stats, err := repo.getLanguageStats(db.GetEngine(db.DefaultContext)) if err != nil { return nil, err } @@ -112,7 +112,7 @@ func (repo *Repository) GetTopLanguageStats(limit int) (LanguageStatList, error) // UpdateLanguageStats updates the language statistics for repository func (repo *Repository) UpdateLanguageStats(commitID string, stats map[string]int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) if err := sess.Begin(); err != nil { return err } @@ -183,7 +183,7 @@ func (repo *Repository) UpdateLanguageStats(commitID string, stats map[string]in // CopyLanguageStat Copy originalRepo language stat information to destRepo (use for forked repo) func CopyLanguageStat(originalRepo, destRepo *Repository) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err diff --git a/models/repo_list.go b/models/repo_list.go index 686fd27b77d6..7179114f4672 100644 --- a/models/repo_list.go +++ b/models/repo_list.go @@ -90,7 +90,7 @@ func (repos RepositoryList) loadAttributes(e db.Engine) error { // LoadAttributes loads the attributes for the given RepositoryList func (repos RepositoryList) LoadAttributes() error { - return repos.loadAttributes(db.DefaultContext().Engine()) + return repos.loadAttributes(db.GetEngine(db.DefaultContext)) } // MirrorRepositoryList contains the mirror repositories @@ -130,7 +130,7 @@ func (repos MirrorRepositoryList) loadAttributes(e db.Engine) error { // LoadAttributes loads the attributes for the given MirrorRepositoryList func (repos MirrorRepositoryList) LoadAttributes() error { - return repos.loadAttributes(db.DefaultContext().Engine()) + return repos.loadAttributes(db.GetEngine(db.DefaultContext)) } // SearchRepoOptions holds the search options @@ -410,7 +410,7 @@ func searchRepositoryByCondition(opts *SearchRepoOptions, cond builder.Cond) (*x opts.OrderBy = SearchOrderBy(fmt.Sprintf("CASE WHEN owner_id = %d THEN 0 ELSE owner_id END, %s", opts.PriorityOwnerID, opts.OrderBy)) } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) var count int64 if opts.PageSize > 0 { @@ -520,7 +520,7 @@ func AccessibleRepoIDsQuery(user *User) *builder.Builder { // FindUserAccessibleRepoIDs find all accessible repositories' ID by user's id func FindUserAccessibleRepoIDs(user *User) ([]int64, error) { repoIDs := make([]int64, 0, 10) - if err := db.DefaultContext().Engine(). + if err := db.GetEngine(db.DefaultContext). Table("repository"). Cols("id"). Where(accessibleRepositoryCondition(user)). diff --git a/models/repo_mirror.go b/models/repo_mirror.go index b086b87cf131..35685b322096 100644 --- a/models/repo_mirror.go +++ b/models/repo_mirror.go @@ -95,7 +95,7 @@ func getMirrorByRepoID(e db.Engine, repoID int64) (*Mirror, error) { // GetMirrorByRepoID returns mirror information of a repository. func GetMirrorByRepoID(repoID int64) (*Mirror, error) { - return getMirrorByRepoID(db.DefaultContext().Engine(), repoID) + return getMirrorByRepoID(db.GetEngine(db.DefaultContext), repoID) } func updateMirror(e db.Engine, m *Mirror) error { @@ -105,18 +105,18 @@ func updateMirror(e db.Engine, m *Mirror) error { // UpdateMirror updates the mirror func UpdateMirror(m *Mirror) error { - return updateMirror(db.DefaultContext().Engine(), m) + return updateMirror(db.GetEngine(db.DefaultContext), m) } // DeleteMirrorByRepoID deletes a mirror by repoID func DeleteMirrorByRepoID(repoID int64) error { - _, err := db.DefaultContext().Engine().Delete(&Mirror{RepoID: repoID}) + _, err := db.GetEngine(db.DefaultContext).Delete(&Mirror{RepoID: repoID}) return err } // MirrorsIterate iterates all mirror repositories. func MirrorsIterate(f func(idx int, bean interface{}) error) error { - return db.DefaultContext().Engine(). + return db.GetEngine(db.DefaultContext). Where("next_update_unix<=?", time.Now().Unix()). And("next_update_unix!=0"). Iterate(new(Mirror), f) @@ -124,6 +124,6 @@ func MirrorsIterate(f func(idx int, bean interface{}) error) error { // InsertMirror inserts a mirror to database func InsertMirror(mirror *Mirror) error { - _, err := db.DefaultContext().Engine().Insert(mirror) + _, err := db.GetEngine(db.DefaultContext).Insert(mirror) return err } diff --git a/models/repo_permission.go b/models/repo_permission.go index 4f518e1746f4..5ec933aa0fce 100644 --- a/models/repo_permission.go +++ b/models/repo_permission.go @@ -140,7 +140,7 @@ func (p *Permission) ColorFormat(s fmt.State) { // GetUserRepoPermission returns the user permissions to the repository func GetUserRepoPermission(repo *Repository, user *User) (Permission, error) { - return getUserRepoPermission(db.DefaultContext().Engine(), repo, user) + return getUserRepoPermission(db.GetEngine(db.DefaultContext), repo, user) } func getUserRepoPermission(e db.Engine, repo *Repository, user *User) (perm Permission, err error) { @@ -278,7 +278,7 @@ func IsUserRealRepoAdmin(repo *Repository, user *User) (bool, error) { return true, nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := repo.getOwner(sess); err != nil { @@ -295,7 +295,7 @@ func IsUserRealRepoAdmin(repo *Repository, user *User) (bool, error) { // IsUserRepoAdmin return true if user has admin right of a repo func IsUserRepoAdmin(repo *Repository, user *User) (bool, error) { - return isUserRepoAdmin(db.DefaultContext().Engine(), repo, user) + return isUserRepoAdmin(db.GetEngine(db.DefaultContext), repo, user) } func isUserRepoAdmin(e db.Engine, repo *Repository, user *User) (bool, error) { @@ -330,13 +330,13 @@ func isUserRepoAdmin(e db.Engine, repo *Repository, user *User) (bool, error) { // AccessLevel returns the Access a user has to a repository. Will return NoneAccess if the // user does not have access. func AccessLevel(user *User, repo *Repository) (AccessMode, error) { - return accessLevelUnit(db.DefaultContext().Engine(), user, repo, UnitTypeCode) + return accessLevelUnit(db.GetEngine(db.DefaultContext), user, repo, UnitTypeCode) } // AccessLevelUnit returns the Access a user has to a repository's. Will return NoneAccess if the // user does not have access. func AccessLevelUnit(user *User, repo *Repository, unitType UnitType) (AccessMode, error) { - return accessLevelUnit(db.DefaultContext().Engine(), user, repo, unitType) + return accessLevelUnit(db.GetEngine(db.DefaultContext), user, repo, unitType) } func accessLevelUnit(e db.Engine, user *User, repo *Repository, unitType UnitType) (AccessMode, error) { @@ -354,14 +354,14 @@ func hasAccessUnit(e db.Engine, user *User, repo *Repository, unitType UnitType, // HasAccessUnit returns true if user has testMode to the unit of the repository func HasAccessUnit(user *User, repo *Repository, unitType UnitType, testMode AccessMode) (bool, error) { - return hasAccessUnit(db.DefaultContext().Engine(), user, repo, unitType, testMode) + return hasAccessUnit(db.GetEngine(db.DefaultContext), user, repo, unitType, testMode) } // CanBeAssigned return true if user can be assigned to issue or pull requests in repo // Currently any write access (code, issues or pr's) is assignable, to match assignee list in user interface. // FIXME: user could send PullRequest also could be assigned??? func CanBeAssigned(user *User, repo *Repository, isPull bool) (bool, error) { - return canBeAssigned(db.DefaultContext().Engine(), user, repo, isPull) + return canBeAssigned(db.GetEngine(db.DefaultContext), user, repo, isPull) } func canBeAssigned(e db.Engine, user *User, repo *Repository, _ bool) (bool, error) { @@ -393,7 +393,7 @@ func hasAccess(e db.Engine, userID int64, repo *Repository) (bool, error) { // HasAccess returns true if user has access to repo func HasAccess(userID int64, repo *Repository) (bool, error) { - return hasAccess(db.DefaultContext().Engine(), userID, repo) + return hasAccess(db.GetEngine(db.DefaultContext), userID, repo) } // FilterOutRepoIdsWithoutUnitAccess filter out repos where user has no access to repositories diff --git a/models/repo_permission_test.go b/models/repo_permission_test.go index 625e0dd205b4..1fbf1b9f8fa3 100644 --- a/models/repo_permission_test.go +++ b/models/repo_permission_test.go @@ -16,7 +16,7 @@ func TestRepoPermissionPublicNonOrgRepo(t *testing.T) { // public non-organization repo repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 4}).(*Repository) - assert.NoError(t, repo.getUnits(db.DefaultContext().Engine())) + assert.NoError(t, repo.getUnits(db.GetEngine(db.DefaultContext))) // plain user user := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User) @@ -69,7 +69,7 @@ func TestRepoPermissionPrivateNonOrgRepo(t *testing.T) { // private non-organization repo repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 2}).(*Repository) - assert.NoError(t, repo.getUnits(db.DefaultContext().Engine())) + assert.NoError(t, repo.getUnits(db.GetEngine(db.DefaultContext))) // plain user user := db.AssertExistsAndLoadBean(t, &User{ID: 4}).(*User) @@ -121,7 +121,7 @@ func TestRepoPermissionPublicOrgRepo(t *testing.T) { // public organization repo repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 32}).(*Repository) - assert.NoError(t, repo.getUnits(db.DefaultContext().Engine())) + assert.NoError(t, repo.getUnits(db.GetEngine(db.DefaultContext))) // plain user user := db.AssertExistsAndLoadBean(t, &User{ID: 5}).(*User) @@ -183,7 +183,7 @@ func TestRepoPermissionPrivateOrgRepo(t *testing.T) { // private organization repo repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 24}).(*Repository) - assert.NoError(t, repo.getUnits(db.DefaultContext().Engine())) + assert.NoError(t, repo.getUnits(db.GetEngine(db.DefaultContext))) // plain user user := db.AssertExistsAndLoadBean(t, &User{ID: 5}).(*User) diff --git a/models/repo_pushmirror.go b/models/repo_pushmirror.go index b0e285216227..c6207bae6df6 100644 --- a/models/repo_pushmirror.go +++ b/models/repo_pushmirror.go @@ -62,32 +62,32 @@ func (m *PushMirror) GetRemoteName() string { // InsertPushMirror inserts a push-mirror to database func InsertPushMirror(m *PushMirror) error { - _, err := db.DefaultContext().Engine().Insert(m) + _, err := db.GetEngine(db.DefaultContext).Insert(m) return err } // UpdatePushMirror updates the push-mirror func UpdatePushMirror(m *PushMirror) error { - _, err := db.DefaultContext().Engine().ID(m.ID).AllCols().Update(m) + _, err := db.GetEngine(db.DefaultContext).ID(m.ID).AllCols().Update(m) return err } // DeletePushMirrorByID deletes a push-mirrors by ID func DeletePushMirrorByID(ID int64) error { - _, err := db.DefaultContext().Engine().ID(ID).Delete(&PushMirror{}) + _, err := db.GetEngine(db.DefaultContext).ID(ID).Delete(&PushMirror{}) return err } // DeletePushMirrorsByRepoID deletes all push-mirrors by repoID func DeletePushMirrorsByRepoID(repoID int64) error { - _, err := db.DefaultContext().Engine().Delete(&PushMirror{RepoID: repoID}) + _, err := db.GetEngine(db.DefaultContext).Delete(&PushMirror{RepoID: repoID}) return err } // GetPushMirrorByID returns push-mirror information. func GetPushMirrorByID(ID int64) (*PushMirror, error) { m := &PushMirror{} - has, err := db.DefaultContext().Engine().ID(ID).Get(m) + has, err := db.GetEngine(db.DefaultContext).ID(ID).Get(m) if err != nil { return nil, err } else if !has { @@ -99,12 +99,12 @@ func GetPushMirrorByID(ID int64) (*PushMirror, error) { // GetPushMirrorsByRepoID returns push-mirror information of a repository. func GetPushMirrorsByRepoID(repoID int64) ([]*PushMirror, error) { mirrors := make([]*PushMirror, 0, 10) - return mirrors, db.DefaultContext().Engine().Where("repo_id=?", repoID).Find(&mirrors) + return mirrors, db.GetEngine(db.DefaultContext).Where("repo_id=?", repoID).Find(&mirrors) } // PushMirrorsIterate iterates all push-mirror repositories. func PushMirrorsIterate(f func(idx int, bean interface{}) error) error { - return db.DefaultContext().Engine(). + return db.GetEngine(db.DefaultContext). Where("last_update + (`interval` / ?) <= ?", time.Second, time.Now().Unix()). And("`interval` != 0"). Iterate(new(PushMirror), f) diff --git a/models/repo_redirect.go b/models/repo_redirect.go index 95196d2a0f2e..18422f9d18e2 100644 --- a/models/repo_redirect.go +++ b/models/repo_redirect.go @@ -26,7 +26,7 @@ func init() { func LookupRepoRedirect(ownerID int64, repoName string) (int64, error) { repoName = strings.ToLower(repoName) redirect := &RepoRedirect{OwnerID: ownerID, LowerName: repoName} - if has, err := db.DefaultContext().Engine().Get(redirect); err != nil { + if has, err := db.GetEngine(db.DefaultContext).Get(redirect); err != nil { return 0, err } else if !has { return 0, ErrRepoRedirectNotExist{OwnerID: ownerID, RepoName: repoName} diff --git a/models/repo_redirect_test.go b/models/repo_redirect_test.go index 9e99fdae37d2..9400422752cb 100644 --- a/models/repo_redirect_test.go +++ b/models/repo_redirect_test.go @@ -27,7 +27,7 @@ func TestNewRepoRedirect(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository) - assert.NoError(t, newRepoRedirect(db.DefaultContext().Engine(), repo.OwnerID, repo.ID, repo.Name, "newreponame")) + assert.NoError(t, newRepoRedirect(db.GetEngine(db.DefaultContext), repo.OwnerID, repo.ID, repo.Name, "newreponame")) db.AssertExistsAndLoadBean(t, &RepoRedirect{ OwnerID: repo.OwnerID, @@ -46,7 +46,7 @@ func TestNewRepoRedirect2(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository) - assert.NoError(t, newRepoRedirect(db.DefaultContext().Engine(), repo.OwnerID, repo.ID, repo.Name, "oldrepo1")) + assert.NoError(t, newRepoRedirect(db.GetEngine(db.DefaultContext), repo.OwnerID, repo.ID, repo.Name, "oldrepo1")) db.AssertExistsAndLoadBean(t, &RepoRedirect{ OwnerID: repo.OwnerID, @@ -65,7 +65,7 @@ func TestNewRepoRedirect3(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 2}).(*Repository) - assert.NoError(t, newRepoRedirect(db.DefaultContext().Engine(), repo.OwnerID, repo.ID, repo.Name, "newreponame")) + assert.NoError(t, newRepoRedirect(db.GetEngine(db.DefaultContext), repo.OwnerID, repo.ID, repo.Name, "newreponame")) db.AssertExistsAndLoadBean(t, &RepoRedirect{ OwnerID: repo.OwnerID, diff --git a/models/repo_test.go b/models/repo_test.go index 2467b63840b1..8073a9cd2f2a 100644 --- a/models/repo_test.go +++ b/models/repo_test.go @@ -109,7 +109,7 @@ func TestUpdateRepositoryVisibilityChanged(t *testing.T) { // Check visibility of action has become private act := Action{} - _, err = db.DefaultContext().Engine().ID(3).Get(&act) + _, err = db.GetEngine(db.DefaultContext).ID(3).Get(&act) assert.NoError(t, err) assert.True(t, act.IsPrivate) diff --git a/models/repo_transfer.go b/models/repo_transfer.go index 2ede0fbbe71b..e3eb756eb405 100644 --- a/models/repo_transfer.go +++ b/models/repo_transfer.go @@ -98,7 +98,7 @@ func (r *RepoTransfer) CanUserAcceptTransfer(u *User) bool { func GetPendingRepositoryTransfer(repo *Repository) (*RepoTransfer, error) { transfer := new(RepoTransfer) - has, err := db.DefaultContext().Engine().Where("repo_id = ? ", repo.ID).Get(transfer) + has, err := db.GetEngine(db.DefaultContext).Where("repo_id = ? ", repo.ID).Get(transfer) if err != nil { return nil, err } @@ -118,7 +118,7 @@ func deleteRepositoryTransfer(e db.Engine, repoID int64) error { // CancelRepositoryTransfer marks the repository as ready and remove pending transfer entry, // thus cancel the transfer process. func CancelRepositoryTransfer(repo *Repository) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -150,7 +150,7 @@ func TestRepositoryReadyForTransfer(status RepositoryStatus) error { // CreatePendingRepositoryTransfer transfer a repo from one owner to a new one. // it marks the repository transfer as "pending" func CreatePendingRepositoryTransfer(doer, newOwner *User, repoID int64, teams []*Team) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -232,7 +232,7 @@ func TransferOwnership(doer *User, newOwnerName string, repo *Repository) (err e } }() - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return fmt.Errorf("sess.Begin: %v", err) diff --git a/models/repo_watch.go b/models/repo_watch.go index 635507561e43..d3720fe857a4 100644 --- a/models/repo_watch.go +++ b/models/repo_watch.go @@ -60,7 +60,7 @@ func isWatchMode(mode RepoWatchMode) bool { // IsWatching checks if user has watched given repository. func IsWatching(userID, repoID int64) bool { - watch, err := getWatch(db.DefaultContext().Engine(), userID, repoID) + watch, err := getWatch(db.GetEngine(db.DefaultContext), userID, repoID) return err == nil && isWatchMode(watch.Mode) } @@ -107,10 +107,10 @@ func watchRepoMode(e db.Engine, watch Watch, mode RepoWatchMode) (err error) { // WatchRepoMode watch repository in specific mode. func WatchRepoMode(userID, repoID int64, mode RepoWatchMode) (err error) { var watch Watch - if watch, err = getWatch(db.DefaultContext().Engine(), userID, repoID); err != nil { + if watch, err = getWatch(db.GetEngine(db.DefaultContext), userID, repoID); err != nil { return err } - return watchRepoMode(db.DefaultContext().Engine(), watch, mode) + return watchRepoMode(db.GetEngine(db.DefaultContext), watch, mode) } func watchRepo(e db.Engine, userID, repoID int64, doWatch bool) (err error) { @@ -130,7 +130,7 @@ func watchRepo(e db.Engine, userID, repoID int64, doWatch bool) (err error) { // WatchRepo watch or unwatch repository. func WatchRepo(userID, repoID int64, watch bool) (err error) { - return watchRepo(db.DefaultContext().Engine(), userID, repoID, watch) + return watchRepo(db.GetEngine(db.DefaultContext), userID, repoID, watch) } func getWatchers(e db.Engine, repoID int64) ([]*Watch, error) { @@ -145,14 +145,14 @@ func getWatchers(e db.Engine, repoID int64) ([]*Watch, error) { // GetWatchers returns all watchers of given repository. func GetWatchers(repoID int64) ([]*Watch, error) { - return getWatchers(db.DefaultContext().Engine(), repoID) + return getWatchers(db.GetEngine(db.DefaultContext), repoID) } // GetRepoWatchersIDs returns IDs of watchers for a given repo ID // but avoids joining with `user` for performance reasons // User permissions must be verified elsewhere if required func GetRepoWatchersIDs(repoID int64) ([]int64, error) { - return getRepoWatchersIDs(db.DefaultContext().Engine(), repoID) + return getRepoWatchersIDs(db.GetEngine(db.DefaultContext), repoID) } func getRepoWatchersIDs(e db.Engine, repoID int64) ([]int64, error) { @@ -166,7 +166,7 @@ func getRepoWatchersIDs(e db.Engine, repoID int64) ([]int64, error) { // GetWatchers returns range of users watching given repository. func (repo *Repository) GetWatchers(opts ListOptions) ([]*User, error) { - sess := db.DefaultContext().Engine().Where("watch.repo_id=?", repo.ID). + sess := db.GetEngine(db.DefaultContext).Where("watch.repo_id=?", repo.ID). Join("LEFT", "watch", "`user`.id=`watch`.user_id"). And("`watch`.mode<>?", RepoWatchModeDont) if opts.Page > 0 { @@ -284,12 +284,12 @@ func notifyWatchers(e db.Engine, actions ...*Action) error { // NotifyWatchers creates batch of actions for every watcher. func NotifyWatchers(actions ...*Action) error { - return notifyWatchers(db.DefaultContext().Engine(), actions...) + return notifyWatchers(db.GetEngine(db.DefaultContext), actions...) } // NotifyWatchersActions creates batch of actions for every watcher. func NotifyWatchersActions(acts []*Action) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -318,5 +318,5 @@ func watchIfAuto(e db.Engine, userID, repoID int64, isWrite bool) error { // WatchIfAuto subscribes to repo if AutoWatchOnChanges is set func WatchIfAuto(userID, repoID int64, isWrite bool) error { - return watchIfAuto(db.DefaultContext().Engine(), userID, repoID, isWrite) + return watchIfAuto(db.GetEngine(db.DefaultContext), userID, repoID, isWrite) } diff --git a/models/review.go b/models/review.go index ed656f49737a..3f81c39b2fc8 100644 --- a/models/review.go +++ b/models/review.go @@ -95,7 +95,7 @@ func (r *Review) loadCodeComments(e db.Engine) (err error) { // LoadCodeComments loads CodeComments func (r *Review) LoadCodeComments() error { - return r.loadCodeComments(db.DefaultContext().Engine()) + return r.loadCodeComments(db.GetEngine(db.DefaultContext)) } func (r *Review) loadIssue(e db.Engine) (err error) { @@ -125,12 +125,12 @@ func (r *Review) loadReviewerTeam(e db.Engine) (err error) { // LoadReviewer loads reviewer func (r *Review) LoadReviewer() error { - return r.loadReviewer(db.DefaultContext().Engine()) + return r.loadReviewer(db.GetEngine(db.DefaultContext)) } // LoadReviewerTeam loads reviewer team func (r *Review) LoadReviewerTeam() error { - return r.loadReviewerTeam(db.DefaultContext().Engine()) + return r.loadReviewerTeam(db.GetEngine(db.DefaultContext)) } func (r *Review) loadAttributes(e db.Engine) (err error) { @@ -151,7 +151,7 @@ func (r *Review) loadAttributes(e db.Engine) (err error) { // LoadAttributes loads all attributes except CodeComments func (r *Review) LoadAttributes() error { - return r.loadAttributes(db.DefaultContext().Engine()) + return r.loadAttributes(db.GetEngine(db.DefaultContext)) } func getReviewByID(e db.Engine, id int64) (*Review, error) { @@ -167,7 +167,7 @@ func getReviewByID(e db.Engine, id int64) (*Review, error) { // GetReviewByID returns the review by the given ID func GetReviewByID(id int64) (*Review, error) { - return getReviewByID(db.DefaultContext().Engine(), id) + return getReviewByID(db.GetEngine(db.DefaultContext), id) } // FindReviewOptions represent possible filters to find reviews @@ -210,12 +210,12 @@ func findReviews(e db.Engine, opts FindReviewOptions) ([]*Review, error) { // FindReviews returns reviews passing FindReviewOptions func FindReviews(opts FindReviewOptions) ([]*Review, error) { - return findReviews(db.DefaultContext().Engine(), opts) + return findReviews(db.GetEngine(db.DefaultContext), opts) } // CountReviews returns count of reviews passing FindReviewOptions func CountReviews(opts FindReviewOptions) (int64, error) { - return db.DefaultContext().Engine().Where(opts.toCond()).Count(&Review{}) + return db.GetEngine(db.DefaultContext).Where(opts.toCond()).Count(&Review{}) } // CreateReviewOptions represent the options to create a review. Type, Issue and Reviewer are required. @@ -232,7 +232,7 @@ type CreateReviewOptions struct { // IsOfficialReviewer check if at least one of the provided reviewers can make official reviews in issue (counts towards required approvals) func IsOfficialReviewer(issue *Issue, reviewers ...*User) (bool, error) { - return isOfficialReviewer(db.DefaultContext().Engine(), issue, reviewers...) + return isOfficialReviewer(db.GetEngine(db.DefaultContext), issue, reviewers...) } func isOfficialReviewer(e db.Engine, issue *Issue, reviewers ...*User) (bool, error) { @@ -259,7 +259,7 @@ func isOfficialReviewer(e db.Engine, issue *Issue, reviewers ...*User) (bool, er // IsOfficialReviewerTeam check if reviewer in this team can make official reviews in issue (counts towards required approvals) func IsOfficialReviewerTeam(issue *Issue, team *Team) (bool, error) { - return isOfficialReviewerTeam(db.DefaultContext().Engine(), issue, team) + return isOfficialReviewerTeam(db.GetEngine(db.DefaultContext), issue, team) } func isOfficialReviewerTeam(e db.Engine, issue *Issue, team *Team) (bool, error) { @@ -310,7 +310,7 @@ func createReview(e db.Engine, opts CreateReviewOptions) (*Review, error) { // CreateReview creates a new review based on opts func CreateReview(opts CreateReviewOptions) (*Review, error) { - return createReview(db.DefaultContext().Engine(), opts) + return createReview(db.GetEngine(db.DefaultContext), opts) } func getCurrentReview(e db.Engine, reviewer *User, issue *Issue) (*Review, error) { @@ -335,12 +335,12 @@ func getCurrentReview(e db.Engine, reviewer *User, issue *Issue) (*Review, error // ReviewExists returns whether a review exists for a particular line of code in the PR func ReviewExists(issue *Issue, treePath string, line int64) (bool, error) { - return db.DefaultContext().Engine().Cols("id").Exist(&Comment{IssueID: issue.ID, TreePath: treePath, Line: line, Type: CommentTypeCode}) + return db.GetEngine(db.DefaultContext).Cols("id").Exist(&Comment{IssueID: issue.ID, TreePath: treePath, Line: line, Type: CommentTypeCode}) } // GetCurrentReview returns the current pending review of reviewer for given issue func GetCurrentReview(reviewer *User, issue *Issue) (*Review, error) { - return getCurrentReview(db.DefaultContext().Engine(), reviewer, issue) + return getCurrentReview(db.GetEngine(db.DefaultContext), reviewer, issue) } // ContentEmptyErr represents an content empty error @@ -358,7 +358,7 @@ func IsContentEmptyErr(err error) bool { // SubmitReview creates a review out of the existing pending review or creates a new one if no pending review exist func SubmitReview(doer *User, issue *Issue, reviewType ReviewType, content, commitID string, stale bool, attachmentUUIDs []string) (*Review, *Comment, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, nil, err @@ -470,7 +470,7 @@ func SubmitReview(doer *User, issue *Issue, reviewType ReviewType, content, comm func GetReviewersByIssueID(issueID int64) ([]*Review, error) { reviews := make([]*Review, 0, 10) - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, err @@ -502,7 +502,7 @@ func GetReviewersFromOriginalAuthorsByIssueID(issueID int64) ([]*Review, error) reviews := make([]*Review, 0, 10) // Get latest review of each reviewer, sorted in order they were made - if err := db.DefaultContext().Engine().SQL("SELECT * FROM review WHERE id IN (SELECT max(id) as id FROM review WHERE issue_id = ? AND reviewer_team_id = 0 AND type in (?, ?, ?) AND original_author_id <> 0 GROUP BY issue_id, original_author_id) ORDER BY review.updated_unix ASC", + if err := db.GetEngine(db.DefaultContext).SQL("SELECT * FROM review WHERE id IN (SELECT max(id) as id FROM review WHERE issue_id = ? AND reviewer_team_id = 0 AND type in (?, ?, ?) AND original_author_id <> 0 GROUP BY issue_id, original_author_id) ORDER BY review.updated_unix ASC", issueID, ReviewTypeApprove, ReviewTypeReject, ReviewTypeRequest). Find(&reviews); err != nil { return nil, err @@ -513,7 +513,7 @@ func GetReviewersFromOriginalAuthorsByIssueID(issueID int64) ([]*Review, error) // GetReviewByIssueIDAndUserID get the latest review of reviewer for a pull request func GetReviewByIssueIDAndUserID(issueID, userID int64) (*Review, error) { - return getReviewByIssueIDAndUserID(db.DefaultContext().Engine(), issueID, userID) + return getReviewByIssueIDAndUserID(db.GetEngine(db.DefaultContext), issueID, userID) } func getReviewByIssueIDAndUserID(e db.Engine, issueID, userID int64) (*Review, error) { @@ -535,7 +535,7 @@ func getReviewByIssueIDAndUserID(e db.Engine, issueID, userID int64) (*Review, e // GetTeamReviewerByIssueIDAndTeamID get the latest review requst of reviewer team for a pull request func GetTeamReviewerByIssueIDAndTeamID(issueID, teamID int64) (review *Review, err error) { - return getTeamReviewerByIssueIDAndTeamID(db.DefaultContext().Engine(), issueID, teamID) + return getTeamReviewerByIssueIDAndTeamID(db.GetEngine(db.DefaultContext), issueID, teamID) } func getTeamReviewerByIssueIDAndTeamID(e db.Engine, issueID, teamID int64) (review *Review, err error) { @@ -557,14 +557,14 @@ func getTeamReviewerByIssueIDAndTeamID(e db.Engine, issueID, teamID int64) (revi // MarkReviewsAsStale marks existing reviews as stale func MarkReviewsAsStale(issueID int64) (err error) { - _, err = db.DefaultContext().Engine().Exec("UPDATE `review` SET stale=? WHERE issue_id=?", true, issueID) + _, err = db.GetEngine(db.DefaultContext).Exec("UPDATE `review` SET stale=? WHERE issue_id=?", true, issueID) return } // MarkReviewsAsNotStale marks existing reviews as not stale for a giving commit SHA func MarkReviewsAsNotStale(issueID int64, commitID string) (err error) { - _, err = db.DefaultContext().Engine().Exec("UPDATE `review` SET stale=? WHERE issue_id=? AND commit_id=?", false, issueID, commitID) + _, err = db.GetEngine(db.DefaultContext).Exec("UPDATE `review` SET stale=? WHERE issue_id=? AND commit_id=?", false, issueID, commitID) return } @@ -581,14 +581,14 @@ func DismissReview(review *Review, isDismiss bool) (err error) { return ErrReviewNotExist{} } - _, err = db.DefaultContext().Engine().ID(review.ID).Cols("dismissed").Update(review) + _, err = db.GetEngine(db.DefaultContext).ID(review.ID).Cols("dismissed").Update(review) return } // InsertReviews inserts review and review comments func InsertReviews(reviews []*Review) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { @@ -630,7 +630,7 @@ func InsertReviews(reviews []*Review) error { // AddReviewRequest add a review request from one reviewer func AddReviewRequest(issue *Issue, reviewer, doer *User) (*Comment, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, err @@ -684,7 +684,7 @@ func AddReviewRequest(issue *Issue, reviewer, doer *User) (*Comment, error) { // RemoveReviewRequest remove a review request from one reviewer func RemoveReviewRequest(issue *Issue, reviewer, doer *User) (*Comment, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, err @@ -737,7 +737,7 @@ func RemoveReviewRequest(issue *Issue, reviewer, doer *User) (*Comment, error) { // AddTeamReviewRequest add a review request from one team func AddTeamReviewRequest(issue *Issue, reviewer *Team, doer *User) (*Comment, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, err @@ -796,7 +796,7 @@ func AddTeamReviewRequest(issue *Issue, reviewer *Team, doer *User) (*Comment, e // RemoveTeamReviewRequest remove a review request from one team func RemoveTeamReviewRequest(issue *Issue, reviewer *Team, doer *User) (*Comment, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, err @@ -864,7 +864,7 @@ func MarkConversation(comment *Comment, doer *User, isResolve bool) (err error) return nil } - if _, err = db.DefaultContext().Engine().Exec("UPDATE `comment` SET resolve_doer_id=? WHERE id=?", doer.ID, comment.ID); err != nil { + if _, err = db.GetEngine(db.DefaultContext).Exec("UPDATE `comment` SET resolve_doer_id=? WHERE id=?", doer.ID, comment.ID); err != nil { return err } } else { @@ -872,7 +872,7 @@ func MarkConversation(comment *Comment, doer *User, isResolve bool) (err error) return nil } - if _, err = db.DefaultContext().Engine().Exec("UPDATE `comment` SET resolve_doer_id=? WHERE id=?", 0, comment.ID); err != nil { + if _, err = db.GetEngine(db.DefaultContext).Exec("UPDATE `comment` SET resolve_doer_id=? WHERE id=?", 0, comment.ID); err != nil { return err } } @@ -914,7 +914,7 @@ func CanMarkConversation(issue *Issue, doer *User) (permResult bool, err error) // DeleteReview delete a review and it's code comments func DeleteReview(r *Review) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { @@ -968,7 +968,7 @@ func (r *Review) GetCodeCommentsCount() int { conds = conds.And(builder.Eq{"invalidated": false}) } - count, err := db.DefaultContext().Engine().Where(conds).Count(new(Comment)) + count, err := db.GetEngine(db.DefaultContext).Where(conds).Count(new(Comment)) if err != nil { return 0 } @@ -983,7 +983,7 @@ func (r *Review) HTMLURL() string { ReviewID: r.ID, } comment := new(Comment) - has, err := db.DefaultContext().Engine().Where(opts.toConds()).Get(comment) + has, err := db.GetEngine(db.DefaultContext).Where(opts.toConds()).Get(comment) if err != nil || !has { return "" } diff --git a/models/session.go b/models/session.go index 9a0c71fdd97f..65fe2bef4f6c 100644 --- a/models/session.go +++ b/models/session.go @@ -24,7 +24,7 @@ func init() { // UpdateSession updates the session with provided id func UpdateSession(key string, data []byte) error { - _, err := db.DefaultContext().Engine().ID(key).Update(&Session{ + _, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{ Data: data, Expiry: timeutil.TimeStampNow(), }) @@ -36,7 +36,7 @@ func ReadSession(key string) (*Session, error) { session := Session{ Key: key, } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, err @@ -60,12 +60,12 @@ func ExistSession(key string) (bool, error) { session := Session{ Key: key, } - return db.DefaultContext().Engine().Get(&session) + return db.GetEngine(db.DefaultContext).Get(&session) } // DestroySession destroys a session func DestroySession(key string) error { - _, err := db.DefaultContext().Engine().Delete(&Session{ + _, err := db.GetEngine(db.DefaultContext).Delete(&Session{ Key: key, }) return err @@ -73,7 +73,7 @@ func DestroySession(key string) error { // RegenerateSession regenerates a session from the old id func RegenerateSession(oldKey, newKey string) (*Session, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, err @@ -117,11 +117,11 @@ func RegenerateSession(oldKey, newKey string) (*Session, error) { // CountSessions returns the number of sessions func CountSessions() (int64, error) { - return db.DefaultContext().Engine().Count(&Session{}) + return db.GetEngine(db.DefaultContext).Count(&Session{}) } // CleanupSessions cleans up expired sessions func CleanupSessions(maxLifetime int64) error { - _, err := db.DefaultContext().Engine().Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{}) + _, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{}) return err } diff --git a/models/ssh_key.go b/models/ssh_key.go index 04be4ce605e2..41016537eb65 100644 --- a/models/ssh_key.go +++ b/models/ssh_key.go @@ -95,7 +95,7 @@ func AddPublicKey(ownerID int64, name, content string, loginSourceID int64) (*Pu return nil, err } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return nil, err @@ -134,7 +134,7 @@ func AddPublicKey(ownerID int64, name, content string, loginSourceID int64) (*Pu // GetPublicKeyByID returns public key by given ID. func GetPublicKeyByID(keyID int64) (*PublicKey, error) { key := new(PublicKey) - has, err := db.DefaultContext().Engine(). + has, err := db.GetEngine(db.DefaultContext). ID(keyID). Get(key) if err != nil { @@ -161,7 +161,7 @@ func searchPublicKeyByContentWithEngine(e db.Engine, content string) (*PublicKey // SearchPublicKeyByContent searches content as prefix (leak e-mail part) // and returns public key found. func SearchPublicKeyByContent(content string) (*PublicKey, error) { - return searchPublicKeyByContentWithEngine(db.DefaultContext().Engine(), content) + return searchPublicKeyByContentWithEngine(db.GetEngine(db.DefaultContext), content) } func searchPublicKeyByContentExactWithEngine(e db.Engine, content string) (*PublicKey, error) { @@ -180,7 +180,7 @@ func searchPublicKeyByContentExactWithEngine(e db.Engine, content string) (*Publ // SearchPublicKeyByContentExact searches content // and returns public key found. func SearchPublicKeyByContentExact(content string) (*PublicKey, error) { - return searchPublicKeyByContentExactWithEngine(db.DefaultContext().Engine(), content) + return searchPublicKeyByContentExactWithEngine(db.GetEngine(db.DefaultContext), content) } // SearchPublicKey returns a list of public keys matching the provided arguments. @@ -193,12 +193,12 @@ func SearchPublicKey(uid int64, fingerprint string) ([]*PublicKey, error) { if fingerprint != "" { cond = cond.And(builder.Eq{"fingerprint": fingerprint}) } - return keys, db.DefaultContext().Engine().Where(cond).Find(&keys) + return keys, db.GetEngine(db.DefaultContext).Where(cond).Find(&keys) } // ListPublicKeys returns a list of public keys belongs to given user. func ListPublicKeys(uid int64, listOptions ListOptions) ([]*PublicKey, error) { - sess := db.DefaultContext().Engine().Where("owner_id = ? AND type != ?", uid, KeyTypePrincipal) + sess := db.GetEngine(db.DefaultContext).Where("owner_id = ? AND type != ?", uid, KeyTypePrincipal) if listOptions.Page != 0 { sess = setSessionPagination(sess, &listOptions) @@ -212,14 +212,14 @@ func ListPublicKeys(uid int64, listOptions ListOptions) ([]*PublicKey, error) { // CountPublicKeys count public keys a user has func CountPublicKeys(userID int64) (int64, error) { - sess := db.DefaultContext().Engine().Where("owner_id = ? AND type != ?", userID, KeyTypePrincipal) + sess := db.GetEngine(db.DefaultContext).Where("owner_id = ? AND type != ?", userID, KeyTypePrincipal) return sess.Count(&PublicKey{}) } // ListPublicKeysBySource returns a list of synchronized public keys for a given user and login source. func ListPublicKeysBySource(uid, loginSourceID int64) ([]*PublicKey, error) { keys := make([]*PublicKey, 0, 5) - return keys, db.DefaultContext().Engine(). + return keys, db.GetEngine(db.DefaultContext). Where("owner_id = ? AND login_source_id = ?", uid, loginSourceID). Find(&keys) } @@ -228,13 +228,13 @@ func ListPublicKeysBySource(uid, loginSourceID int64) ([]*PublicKey, error) { func UpdatePublicKeyUpdated(id int64) error { // Check if key exists before update as affected rows count is unreliable // and will return 0 affected rows if two updates are made at the same time - if cnt, err := db.DefaultContext().Engine().ID(id).Count(&PublicKey{}); err != nil { + if cnt, err := db.GetEngine(db.DefaultContext).ID(id).Count(&PublicKey{}); err != nil { return err } else if cnt != 1 { return ErrKeyNotExist{id} } - _, err := db.DefaultContext().Engine().ID(id).Cols("updated_unix").Update(&PublicKey{ + _, err := db.GetEngine(db.DefaultContext).ID(id).Cols("updated_unix").Update(&PublicKey{ UpdatedUnix: timeutil.TimeStampNow(), }) if err != nil { @@ -333,7 +333,7 @@ func DeletePublicKey(doer *User, id int64) (err error) { return ErrKeyAccessDenied{doer.ID, key.ID, "public"} } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -358,7 +358,7 @@ func DeletePublicKey(doer *User, id int64) (err error) { // deleteKeysMarkedForDeletion returns true if ssh keys needs update func deleteKeysMarkedForDeletion(keys []string) (bool, error) { // Start session - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return false, err diff --git a/models/ssh_key_authorized_keys.go b/models/ssh_key_authorized_keys.go index e490f2118049..ed17a12e9a8a 100644 --- a/models/ssh_key_authorized_keys.go +++ b/models/ssh_key_authorized_keys.go @@ -115,10 +115,10 @@ func appendAuthorizedKeysToFile(keys ...*PublicKey) error { } // RewriteAllPublicKeys removes any authorized key and rewrite all keys from database again. -// Note: db.DefaultContext().Engine().Iterate does not get latest data after insert/delete, so we have to call this function +// Note: db.GetEngine(db.DefaultContext).Iterate does not get latest data after insert/delete, so we have to call this function // outside any session scope independently. func RewriteAllPublicKeys() error { - return rewriteAllPublicKeys(db.DefaultContext().Engine()) + return rewriteAllPublicKeys(db.GetEngine(db.DefaultContext)) } func rewriteAllPublicKeys(e db.Engine) error { @@ -179,7 +179,7 @@ func rewriteAllPublicKeys(e db.Engine) error { // RegeneratePublicKeys regenerates the authorized_keys file func RegeneratePublicKeys(t io.StringWriter) error { - return regeneratePublicKeys(db.DefaultContext().Engine(), t) + return regeneratePublicKeys(db.GetEngine(db.DefaultContext), t) } func regeneratePublicKeys(e db.Engine, t io.StringWriter) error { diff --git a/models/ssh_key_authorized_principals.go b/models/ssh_key_authorized_principals.go index c5f7b2178588..c053b4b6d55c 100644 --- a/models/ssh_key_authorized_principals.go +++ b/models/ssh_key_authorized_principals.go @@ -40,10 +40,10 @@ import ( const authorizedPrincipalsFile = "authorized_principals" // RewriteAllPrincipalKeys removes any authorized principal and rewrite all keys from database again. -// Note: db.DefaultContext().Engine().Iterate does not get latest data after insert/delete, so we have to call this function +// Note: db.GetEngine(db.DefaultContext).Iterate does not get latest data after insert/delete, so we have to call this function // outside any session scope independently. func RewriteAllPrincipalKeys() error { - return rewriteAllPrincipalKeys(db.DefaultContext().Engine()) + return rewriteAllPrincipalKeys(db.GetEngine(db.DefaultContext)) } func rewriteAllPrincipalKeys(e db.Engine) error { @@ -102,7 +102,7 @@ func rewriteAllPrincipalKeys(e db.Engine) error { // RegeneratePrincipalKeys regenerates the authorized_principals file func RegeneratePrincipalKeys(t io.StringWriter) error { - return regeneratePrincipalKeys(db.DefaultContext().Engine(), t) + return regeneratePrincipalKeys(db.GetEngine(db.DefaultContext), t) } func regeneratePrincipalKeys(e db.Engine, t io.StringWriter) error { diff --git a/models/ssh_key_deploy.go b/models/ssh_key_deploy.go index 7aa764522ab3..3b9a16828074 100644 --- a/models/ssh_key_deploy.go +++ b/models/ssh_key_deploy.go @@ -107,7 +107,7 @@ func addDeployKey(e *xorm.Session, keyID, repoID int64, name, fingerprint string // HasDeployKey returns true if public key is a deploy key of given repository. func HasDeployKey(keyID, repoID int64) bool { - has, _ := db.DefaultContext().Engine(). + has, _ := db.GetEngine(db.DefaultContext). Where("key_id = ? AND repo_id = ?", keyID, repoID). Get(new(DeployKey)) return has @@ -125,7 +125,7 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey accessMode = AccessModeWrite } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return nil, err @@ -164,7 +164,7 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey // GetDeployKeyByID returns deploy key by given ID. func GetDeployKeyByID(id int64) (*DeployKey, error) { - return getDeployKeyByID(db.DefaultContext().Engine(), id) + return getDeployKeyByID(db.GetEngine(db.DefaultContext), id) } func getDeployKeyByID(e db.Engine, id int64) (*DeployKey, error) { @@ -180,7 +180,7 @@ func getDeployKeyByID(e db.Engine, id int64) (*DeployKey, error) { // GetDeployKeyByRepo returns deploy key by given public key ID and repository ID. func GetDeployKeyByRepo(keyID, repoID int64) (*DeployKey, error) { - return getDeployKeyByRepo(db.DefaultContext().Engine(), keyID, repoID) + return getDeployKeyByRepo(db.GetEngine(db.DefaultContext), keyID, repoID) } func getDeployKeyByRepo(e db.Engine, keyID, repoID int64) (*DeployKey, error) { @@ -199,19 +199,19 @@ func getDeployKeyByRepo(e db.Engine, keyID, repoID int64) (*DeployKey, error) { // UpdateDeployKeyCols updates deploy key information in the specified columns. func UpdateDeployKeyCols(key *DeployKey, cols ...string) error { - _, err := db.DefaultContext().Engine().ID(key.ID).Cols(cols...).Update(key) + _, err := db.GetEngine(db.DefaultContext).ID(key.ID).Cols(cols...).Update(key) return err } // UpdateDeployKey updates deploy key information. func UpdateDeployKey(key *DeployKey) error { - _, err := db.DefaultContext().Engine().ID(key.ID).AllCols().Update(key) + _, err := db.GetEngine(db.DefaultContext).ID(key.ID).AllCols().Update(key) return err } // DeleteDeployKey deletes deploy key from its repository authorized_keys file if needed. func DeleteDeployKey(doer *User, id int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -293,7 +293,7 @@ func (opt ListDeployKeysOptions) toCond() builder.Cond { // ListDeployKeys returns a list of deploy keys matching the provided arguments. func ListDeployKeys(opts *ListDeployKeysOptions) ([]*DeployKey, error) { - return listDeployKeys(db.DefaultContext().Engine(), opts) + return listDeployKeys(db.GetEngine(db.DefaultContext), opts) } func listDeployKeys(e db.Engine, opts *ListDeployKeysOptions) ([]*DeployKey, error) { @@ -312,5 +312,5 @@ func listDeployKeys(e db.Engine, opts *ListDeployKeysOptions) ([]*DeployKey, err // CountDeployKeys returns count deploy keys matching the provided arguments. func CountDeployKeys(opts *ListDeployKeysOptions) (int64, error) { - return db.DefaultContext().Engine().Where(opts.toCond()).Count(&DeployKey{}) + return db.GetEngine(db.DefaultContext).Where(opts.toCond()).Count(&DeployKey{}) } diff --git a/models/ssh_key_principals.go b/models/ssh_key_principals.go index 5f95cf74df09..383693e14ed7 100644 --- a/models/ssh_key_principals.go +++ b/models/ssh_key_principals.go @@ -24,7 +24,7 @@ import ( // AddPrincipalKey adds new principal to database and authorized_principals file. func AddPrincipalKey(ownerID int64, content string, loginSourceID int64) (*PublicKey, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, err @@ -113,7 +113,7 @@ func CheckPrincipalKeyString(user *User, content string) (_ string, err error) { // ListPrincipalKeys returns a list of principals belongs to given user. func ListPrincipalKeys(uid int64, listOptions ListOptions) ([]*PublicKey, error) { - sess := db.DefaultContext().Engine().Where("owner_id = ? AND type = ?", uid, KeyTypePrincipal) + sess := db.GetEngine(db.DefaultContext).Where("owner_id = ? AND type = ?", uid, KeyTypePrincipal) if listOptions.Page != 0 { sess = setSessionPagination(sess, &listOptions) diff --git a/models/star.go b/models/star.go index 2d50a002d7e4..ad583f19985a 100644 --- a/models/star.go +++ b/models/star.go @@ -23,7 +23,7 @@ func init() { // StarRepo or unstar repository. func StarRepo(userID, repoID int64, star bool) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { @@ -65,7 +65,7 @@ func StarRepo(userID, repoID int64, star bool) error { // IsStaring checks if user has starred given repository. func IsStaring(userID, repoID int64) bool { - return isStaring(db.DefaultContext().Engine(), userID, repoID) + return isStaring(db.GetEngine(db.DefaultContext), userID, repoID) } func isStaring(e db.Engine, userID, repoID int64) bool { @@ -75,7 +75,7 @@ func isStaring(e db.Engine, userID, repoID int64) bool { // GetStargazers returns the users that starred the repo. func (repo *Repository) GetStargazers(opts ListOptions) ([]*User, error) { - sess := db.DefaultContext().Engine().Where("star.repo_id = ?", repo.ID). + sess := db.GetEngine(db.DefaultContext).Where("star.repo_id = ?", repo.ID). Join("LEFT", "star", "`user`.id = star.uid") if opts.Page > 0 { sess = setSessionPagination(sess, &opts) @@ -93,7 +93,7 @@ func (u *User) GetStarredRepos(private bool, page, pageSize int, orderBy string) if len(orderBy) == 0 { orderBy = "updated_unix DESC" } - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Join("INNER", "star", "star.repo_id = repository.id"). Where("star.uid = ?", u.ID). OrderBy(orderBy) @@ -113,7 +113,7 @@ func (u *User) GetStarredRepos(private bool, page, pageSize int, orderBy string) return } - if err = repos.loadAttributes(db.DefaultContext().Engine()); err != nil { + if err = repos.loadAttributes(db.GetEngine(db.DefaultContext)); err != nil { return } @@ -122,7 +122,7 @@ func (u *User) GetStarredRepos(private bool, page, pageSize int, orderBy string) // GetStarredRepoCount returns the numbers of repo the user starred. func (u *User) GetStarredRepoCount(private bool) (int64, error) { - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Join("INNER", "star", "star.repo_id = repository.id"). Where("star.uid = ?", u.ID) diff --git a/models/statistic.go b/models/statistic.go index 715a10e70d17..d192a971f59f 100644 --- a/models/statistic.go +++ b/models/statistic.go @@ -23,12 +23,12 @@ type Statistic struct { func GetStatistic() (stats Statistic) { stats.Counter.User = CountUsers() stats.Counter.Org = CountOrganizations() - stats.Counter.PublicKey, _ = db.DefaultContext().Engine().Count(new(PublicKey)) + stats.Counter.PublicKey, _ = db.GetEngine(db.DefaultContext).Count(new(PublicKey)) stats.Counter.Repo = CountRepositories(true) - stats.Counter.Watch, _ = db.DefaultContext().Engine().Count(new(Watch)) - stats.Counter.Star, _ = db.DefaultContext().Engine().Count(new(Star)) - stats.Counter.Action, _ = db.DefaultContext().Engine().Count(new(Action)) - stats.Counter.Access, _ = db.DefaultContext().Engine().Count(new(Access)) + stats.Counter.Watch, _ = db.GetEngine(db.DefaultContext).Count(new(Watch)) + stats.Counter.Star, _ = db.GetEngine(db.DefaultContext).Count(new(Star)) + stats.Counter.Action, _ = db.GetEngine(db.DefaultContext).Count(new(Action)) + stats.Counter.Access, _ = db.GetEngine(db.DefaultContext).Count(new(Access)) type IssueCount struct { Count int64 @@ -36,7 +36,7 @@ func GetStatistic() (stats Statistic) { } issueCounts := []IssueCount{} - _ = db.DefaultContext().Engine().Select("COUNT(*) AS count, is_closed").Table("issue").GroupBy("is_closed").Find(&issueCounts) + _ = db.GetEngine(db.DefaultContext).Select("COUNT(*) AS count, is_closed").Table("issue").GroupBy("is_closed").Find(&issueCounts) for _, c := range issueCounts { if c.IsClosed { stats.Counter.IssueClosed = c.Count @@ -47,17 +47,17 @@ func GetStatistic() (stats Statistic) { stats.Counter.Issue = stats.Counter.IssueClosed + stats.Counter.IssueOpen - stats.Counter.Comment, _ = db.DefaultContext().Engine().Count(new(Comment)) + stats.Counter.Comment, _ = db.GetEngine(db.DefaultContext).Count(new(Comment)) stats.Counter.Oauth = 0 - stats.Counter.Follow, _ = db.DefaultContext().Engine().Count(new(Follow)) - stats.Counter.Mirror, _ = db.DefaultContext().Engine().Count(new(Mirror)) - stats.Counter.Release, _ = db.DefaultContext().Engine().Count(new(Release)) + stats.Counter.Follow, _ = db.GetEngine(db.DefaultContext).Count(new(Follow)) + stats.Counter.Mirror, _ = db.GetEngine(db.DefaultContext).Count(new(Mirror)) + stats.Counter.Release, _ = db.GetEngine(db.DefaultContext).Count(new(Release)) stats.Counter.LoginSource = CountLoginSources() - stats.Counter.Webhook, _ = db.DefaultContext().Engine().Count(new(Webhook)) - stats.Counter.Milestone, _ = db.DefaultContext().Engine().Count(new(Milestone)) - stats.Counter.Label, _ = db.DefaultContext().Engine().Count(new(Label)) - stats.Counter.HookTask, _ = db.DefaultContext().Engine().Count(new(HookTask)) - stats.Counter.Team, _ = db.DefaultContext().Engine().Count(new(Team)) - stats.Counter.Attachment, _ = db.DefaultContext().Engine().Count(new(Attachment)) + stats.Counter.Webhook, _ = db.GetEngine(db.DefaultContext).Count(new(Webhook)) + stats.Counter.Milestone, _ = db.GetEngine(db.DefaultContext).Count(new(Milestone)) + stats.Counter.Label, _ = db.GetEngine(db.DefaultContext).Count(new(Label)) + stats.Counter.HookTask, _ = db.GetEngine(db.DefaultContext).Count(new(HookTask)) + stats.Counter.Team, _ = db.GetEngine(db.DefaultContext).Count(new(Team)) + stats.Counter.Attachment, _ = db.GetEngine(db.DefaultContext).Count(new(Attachment)) return } diff --git a/models/task.go b/models/task.go index a943834b9592..7da9307c9596 100644 --- a/models/task.go +++ b/models/task.go @@ -49,7 +49,7 @@ type TranslatableMessage struct { // LoadRepo loads repository of the task func (task *Task) LoadRepo() error { - return task.loadRepo(db.DefaultContext().Engine()) + return task.loadRepo(db.GetEngine(db.DefaultContext)) } func (task *Task) loadRepo(e db.Engine) error { @@ -76,7 +76,7 @@ func (task *Task) LoadDoer() error { } var doer User - has, err := db.DefaultContext().Engine().ID(task.DoerID).Get(&doer) + has, err := db.GetEngine(db.DefaultContext).ID(task.DoerID).Get(&doer) if err != nil { return err } else if !has { @@ -96,7 +96,7 @@ func (task *Task) LoadOwner() error { } var owner User - has, err := db.DefaultContext().Engine().ID(task.OwnerID).Get(&owner) + has, err := db.GetEngine(db.DefaultContext).ID(task.OwnerID).Get(&owner) if err != nil { return err } else if !has { @@ -111,7 +111,7 @@ func (task *Task) LoadOwner() error { // UpdateCols updates some columns func (task *Task) UpdateCols(cols ...string) error { - _, err := db.DefaultContext().Engine().ID(task.ID).Cols(cols...).Update(task) + _, err := db.GetEngine(db.DefaultContext).ID(task.ID).Cols(cols...).Update(task) return err } @@ -170,7 +170,7 @@ func GetMigratingTask(repoID int64) (*Task, error) { RepoID: repoID, Type: structs.TaskTypeMigrateRepo, } - has, err := db.DefaultContext().Engine().Get(&task) + has, err := db.GetEngine(db.DefaultContext).Get(&task) if err != nil { return nil, err } else if !has { @@ -186,7 +186,7 @@ func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, e DoerID: doerID, Type: structs.TaskTypeMigrateRepo, } - has, err := db.DefaultContext().Engine().Get(&task) + has, err := db.GetEngine(db.DefaultContext).Get(&task) if err != nil { return nil, nil, err } else if !has { @@ -217,13 +217,13 @@ func (opts FindTaskOptions) ToConds() builder.Cond { // FindTasks find all tasks func FindTasks(opts FindTaskOptions) ([]*Task, error) { tasks := make([]*Task, 0, 10) - err := db.DefaultContext().Engine().Where(opts.ToConds()).Find(&tasks) + err := db.GetEngine(db.DefaultContext).Where(opts.ToConds()).Find(&tasks) return tasks, err } // CreateTask creates a task on database func CreateTask(task *Task) error { - return createTask(db.DefaultContext().Engine(), task) + return createTask(db.GetEngine(db.DefaultContext), task) } func createTask(e db.Engine, task *Task) error { @@ -253,7 +253,7 @@ func FinishMigrateTask(task *Task) error { } task.PayloadContent = string(confBytes) - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err diff --git a/models/token.go b/models/token.go index 4e1e8f1e88aa..48ae79542461 100644 --- a/models/token.go +++ b/models/token.go @@ -69,7 +69,7 @@ func NewAccessToken(t *AccessToken) error { t.Token = base.EncodeSha1(gouuid.New().String()) t.TokenHash = hashToken(t.Token, t.TokenSalt) t.TokenLastEight = t.Token[len(t.Token)-8:] - _, err = db.DefaultContext().Engine().Insert(t) + _, err = db.GetEngine(db.DefaultContext).Insert(t) return err } @@ -110,7 +110,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) { TokenLastEight: lastEight, } // Re-get the token from the db in case it has been deleted in the intervening period - has, err := db.DefaultContext().Engine().ID(id).Get(token) + has, err := db.GetEngine(db.DefaultContext).ID(id).Get(token) if err != nil { return nil, err } @@ -121,7 +121,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) { } var tokens []AccessToken - err := db.DefaultContext().Engine().Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens) + err := db.GetEngine(db.DefaultContext).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens) if err != nil { return nil, err } else if len(tokens) == 0 { @@ -142,7 +142,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) { // AccessTokenByNameExists checks if a token name has been used already by a user. func AccessTokenByNameExists(token *AccessToken) (bool, error) { - return db.DefaultContext().Engine().Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist() + return db.GetEngine(db.DefaultContext).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist() } // ListAccessTokensOptions contain filter options @@ -154,7 +154,7 @@ type ListAccessTokensOptions struct { // ListAccessTokens returns a list of access tokens belongs to given user. func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) { - sess := db.DefaultContext().Engine().Where("uid=?", opts.UserID) + sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID) if len(opts.Name) != 0 { sess = sess.Where("name=?", opts.Name) @@ -175,13 +175,13 @@ func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) { // UpdateAccessToken updates information of access token. func UpdateAccessToken(t *AccessToken) error { - _, err := db.DefaultContext().Engine().ID(t.ID).AllCols().Update(t) + _, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t) return err } // CountAccessTokens count access tokens belongs to given user by options func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) { - sess := db.DefaultContext().Engine().Where("uid=?", opts.UserID) + sess := db.GetEngine(db.DefaultContext).Where("uid=?", opts.UserID) if len(opts.Name) != 0 { sess = sess.Where("name=?", opts.Name) } @@ -190,7 +190,7 @@ func CountAccessTokens(opts ListAccessTokensOptions) (int64, error) { // DeleteAccessTokenByID deletes access token by given ID. func DeleteAccessTokenByID(id, userID int64) error { - cnt, err := db.DefaultContext().Engine().ID(id).Delete(&AccessToken{ + cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&AccessToken{ UID: userID, }) if err != nil { diff --git a/models/topic.go b/models/topic.go index 4b219cc0eba1..cf563e9b11c5 100644 --- a/models/topic.go +++ b/models/topic.go @@ -88,7 +88,7 @@ func SanitizeAndValidateTopics(topics []string) (validTopics, invalidTopics []st // GetTopicByName retrieves topic by name func GetTopicByName(name string) (*Topic, error) { var topic Topic - if has, err := db.DefaultContext().Engine().Where("name = ?", name).Get(&topic); err != nil { + if has, err := db.GetEngine(db.DefaultContext).Where("name = ?", name).Get(&topic); err != nil { return nil, err } else if !has { return nil, ErrTopicNotExist{name} @@ -184,7 +184,7 @@ func (opts *FindTopicOptions) toConds() builder.Cond { // FindTopics retrieves the topics via FindTopicOptions func FindTopics(opts *FindTopicOptions) ([]*Topic, int64, error) { - sess := db.DefaultContext().Engine().Select("topic.*").Where(opts.toConds()) + sess := db.GetEngine(db.DefaultContext).Select("topic.*").Where(opts.toConds()) if opts.RepoID > 0 { sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") } @@ -198,7 +198,7 @@ func FindTopics(opts *FindTopicOptions) ([]*Topic, int64, error) { // CountTopics counts the number of topics matching the FindTopicOptions func CountTopics(opts *FindTopicOptions) (int64, error) { - sess := db.DefaultContext().Engine().Where(opts.toConds()) + sess := db.GetEngine(db.DefaultContext).Where(opts.toConds()) if opts.RepoID > 0 { sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") } @@ -207,7 +207,7 @@ func CountTopics(opts *FindTopicOptions) (int64, error) { // GetRepoTopicByName retrieves topic from name for a repo if it exist func GetRepoTopicByName(repoID int64, topicName string) (*Topic, error) { - return getRepoTopicByName(db.DefaultContext().Engine(), repoID, topicName) + return getRepoTopicByName(db.GetEngine(db.DefaultContext), repoID, topicName) } func getRepoTopicByName(e db.Engine, repoID int64, topicName string) (*Topic, error) { @@ -225,7 +225,7 @@ func getRepoTopicByName(e db.Engine, repoID int64, topicName string) (*Topic, er // AddTopic adds a topic name to a repository (if it does not already have it) func AddTopic(repoID int64, topicName string) (*Topic, error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return nil, err @@ -272,7 +272,7 @@ func DeleteTopic(repoID int64, topicName string) (*Topic, error) { return nil, nil } - err = removeTopicFromRepo(db.DefaultContext().Engine(), repoID, topic) + err = removeTopicFromRepo(db.GetEngine(db.DefaultContext), repoID, topic) return topic, err } @@ -286,7 +286,7 @@ func SaveTopics(repoID int64, topicNames ...string) error { return err } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { diff --git a/models/twofactor.go b/models/twofactor.go index f2443e04cece..dd7fde77e21e 100644 --- a/models/twofactor.go +++ b/models/twofactor.go @@ -93,13 +93,13 @@ func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) { // NewTwoFactor creates a new two-factor authentication token. func NewTwoFactor(t *TwoFactor) error { - _, err := db.DefaultContext().Engine().Insert(t) + _, err := db.GetEngine(db.DefaultContext).Insert(t) return err } // UpdateTwoFactor updates a two-factor authentication token. func UpdateTwoFactor(t *TwoFactor) error { - _, err := db.DefaultContext().Engine().ID(t.ID).AllCols().Update(t) + _, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t) return err } @@ -107,7 +107,7 @@ func UpdateTwoFactor(t *TwoFactor) error { // the user, if any. func GetTwoFactorByUID(uid int64) (*TwoFactor, error) { twofa := &TwoFactor{} - has, err := db.DefaultContext().Engine().Where("uid=?", uid).Get(twofa) + has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa) if err != nil { return nil, err } else if !has { @@ -118,7 +118,7 @@ func GetTwoFactorByUID(uid int64) (*TwoFactor, error) { // DeleteTwoFactorByID deletes two-factor authentication token by given ID. func DeleteTwoFactorByID(id, userID int64) error { - cnt, err := db.DefaultContext().Engine().ID(id).Delete(&TwoFactor{ + cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{ UID: userID, }) if err != nil { diff --git a/models/u2f.go b/models/u2f.go index 66943f8dd2c2..17b829562634 100644 --- a/models/u2f.go +++ b/models/u2f.go @@ -45,7 +45,7 @@ func (reg *U2FRegistration) updateCounter(e db.Engine) error { // UpdateCounter will update the database value of counter func (reg *U2FRegistration) UpdateCounter() error { - return reg.updateCounter(db.DefaultContext().Engine()) + return reg.updateCounter(db.GetEngine(db.DefaultContext)) } // U2FRegistrationList is a list of *U2FRegistration @@ -73,7 +73,7 @@ func getU2FRegistrationsByUID(e db.Engine, uid int64) (U2FRegistrationList, erro // GetU2FRegistrationByID returns U2F registration by id func GetU2FRegistrationByID(id int64) (*U2FRegistration, error) { - return getU2FRegistrationByID(db.DefaultContext().Engine(), id) + return getU2FRegistrationByID(db.GetEngine(db.DefaultContext), id) } func getU2FRegistrationByID(e db.Engine, id int64) (*U2FRegistration, error) { @@ -88,7 +88,7 @@ func getU2FRegistrationByID(e db.Engine, id int64) (*U2FRegistration, error) { // GetU2FRegistrationsByUID returns all U2F registrations of the given user func GetU2FRegistrationsByUID(uid int64) (U2FRegistrationList, error) { - return getU2FRegistrationsByUID(db.DefaultContext().Engine(), uid) + return getU2FRegistrationsByUID(db.GetEngine(db.DefaultContext), uid) } func createRegistration(e db.Engine, user *User, name string, reg *u2f.Registration) (*U2FRegistration, error) { @@ -111,12 +111,12 @@ func createRegistration(e db.Engine, user *User, name string, reg *u2f.Registrat // CreateRegistration will create a new U2FRegistration from the given Registration func CreateRegistration(user *User, name string, reg *u2f.Registration) (*U2FRegistration, error) { - return createRegistration(db.DefaultContext().Engine(), user, name, reg) + return createRegistration(db.GetEngine(db.DefaultContext), user, name, reg) } // DeleteRegistration will delete U2FRegistration func DeleteRegistration(reg *U2FRegistration) error { - return deleteRegistration(db.DefaultContext().Engine(), reg) + return deleteRegistration(db.GetEngine(db.DefaultContext), reg) } func deleteRegistration(e db.Engine, reg *U2FRegistration) error { diff --git a/models/update.go b/models/update.go index 0fb73ac89197..0898ab54c125 100644 --- a/models/update.go +++ b/models/update.go @@ -5,6 +5,7 @@ package models import ( + "context" "fmt" "strings" @@ -12,8 +13,8 @@ import ( ) // PushUpdateDeleteTagsContext updates a number of delete tags with context -func PushUpdateDeleteTagsContext(ctx *db.Context, repo *Repository, tags []string) error { - return pushUpdateDeleteTags(ctx.Engine(), repo, tags) +func PushUpdateDeleteTagsContext(ctx context.Context, repo *Repository, tags []string) error { + return pushUpdateDeleteTags(db.GetEngine(ctx), repo, tags) } func pushUpdateDeleteTags(e db.Engine, repo *Repository, tags []string) error { @@ -55,14 +56,14 @@ func PushUpdateDeleteTag(repo *Repository, tagName string) error { return fmt.Errorf("GetRelease: %v", err) } if rel.IsTag { - if _, err = db.DefaultContext().Engine().ID(rel.ID).Delete(new(Release)); err != nil { + if _, err = db.GetEngine(db.DefaultContext).ID(rel.ID).Delete(new(Release)); err != nil { return fmt.Errorf("Delete: %v", err) } } else { rel.IsDraft = true rel.NumCommits = 0 rel.Sha1 = "" - if _, err = db.DefaultContext().Engine().ID(rel.ID).AllCols().Update(rel); err != nil { + if _, err = db.GetEngine(db.DefaultContext).ID(rel.ID).AllCols().Update(rel); err != nil { return fmt.Errorf("Update: %v", err) } } @@ -79,7 +80,7 @@ func SaveOrUpdateTag(repo *Repository, newRel *Release) error { if rel == nil { rel = newRel - if _, err = db.DefaultContext().Engine().Insert(rel); err != nil { + if _, err = db.GetEngine(db.DefaultContext).Insert(rel); err != nil { return fmt.Errorf("InsertOne: %v", err) } } else { @@ -90,7 +91,7 @@ func SaveOrUpdateTag(repo *Repository, newRel *Release) error { if rel.IsTag && newRel.PublisherID > 0 { rel.PublisherID = newRel.PublisherID } - if _, err = db.DefaultContext().Engine().ID(rel.ID).AllCols().Update(rel); err != nil { + if _, err = db.GetEngine(db.DefaultContext).ID(rel.ID).AllCols().Update(rel); err != nil { return fmt.Errorf("Update: %v", err) } } diff --git a/models/upload.go b/models/upload.go index ca88c6a393eb..503220db577b 100644 --- a/models/upload.go +++ b/models/upload.go @@ -73,7 +73,7 @@ func NewUpload(name string, buf []byte, file multipart.File) (_ *Upload, err err return nil, fmt.Errorf("Copy: %v", err) } - if _, err := db.DefaultContext().Engine().Insert(upload); err != nil { + if _, err := db.GetEngine(db.DefaultContext).Insert(upload); err != nil { return nil, err } @@ -83,7 +83,7 @@ func NewUpload(name string, buf []byte, file multipart.File) (_ *Upload, err err // GetUploadByUUID returns the Upload by UUID func GetUploadByUUID(uuid string) (*Upload, error) { upload := &Upload{} - has, err := db.DefaultContext().Engine().Where("uuid=?", uuid).Get(upload) + has, err := db.GetEngine(db.DefaultContext).Where("uuid=?", uuid).Get(upload) if err != nil { return nil, err } else if !has { @@ -100,7 +100,7 @@ func GetUploadsByUUIDs(uuids []string) ([]*Upload, error) { // Silently drop invalid uuids. uploads := make([]*Upload, 0, len(uuids)) - return uploads, db.DefaultContext().Engine().In("uuid", uuids).Find(&uploads) + return uploads, db.GetEngine(db.DefaultContext).In("uuid", uuids).Find(&uploads) } // DeleteUploads deletes multiple uploads @@ -109,7 +109,7 @@ func DeleteUploads(uploads ...*Upload) (err error) { return nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err diff --git a/models/user.go b/models/user.go index a4a3d83166aa..fc5d417d3609 100644 --- a/models/user.go +++ b/models/user.go @@ -236,7 +236,7 @@ func (u *User) GetEmail() string { // GetAllUsers returns a slice of all individual users found in DB. func GetAllUsers() ([]*User, error) { users := make([]*User, 0) - return users, db.DefaultContext().Engine().OrderBy("id").Where("type = ?", UserTypeIndividual).Find(&users) + return users, db.GetEngine(db.DefaultContext).OrderBy("id").Where("type = ?", UserTypeIndividual).Find(&users) } // IsLocal returns true if user login type is LoginPlain. @@ -332,7 +332,7 @@ func (u *User) GenerateEmailActivateCode(email string) string { // GetFollowers returns range of user's followers. func (u *User) GetFollowers(listOptions ListOptions) ([]*User, error) { - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Where("follow.follow_id=?", u.ID). Join("LEFT", "follow", "`user`.id=follow.user_id") @@ -354,7 +354,7 @@ func (u *User) IsFollowing(followID int64) bool { // GetFollowing returns range of user's following. func (u *User) GetFollowing(listOptions ListOptions) ([]*User, error) { - sess := db.DefaultContext().Engine(). + sess := db.GetEngine(db.DefaultContext). Where("follow.user_id=?", u.ID). Join("LEFT", "follow", "`user`.id=follow.follow_id") @@ -437,7 +437,7 @@ func (u *User) IsPasswordSet() bool { // IsVisibleToUser check if viewer is able to see user profile func (u *User) IsVisibleToUser(viewer *User) bool { - return u.isVisibleToUser(db.DefaultContext().Engine(), viewer) + return u.isVisibleToUser(db.GetEngine(db.DefaultContext), viewer) } func (u *User) isVisibleToUser(e db.Engine, viewer *User) bool { @@ -465,7 +465,7 @@ func (u *User) isVisibleToUser(e db.Engine, viewer *User) bool { } // Now we need to check if they in some organization together - count, err := db.DefaultContext().Engine().Table("team_user"). + count, err := db.GetEngine(db.DefaultContext).Table("team_user"). Where( builder.And( builder.Eq{"uid": viewer.ID}, @@ -508,7 +508,7 @@ func (u *User) IsUserOrgOwner(orgID int64) bool { // HasMemberWithUserID returns true if user with userID is part of the u organisation. func (u *User) HasMemberWithUserID(userID int64) bool { - return u.hasMemberWithUserID(db.DefaultContext().Engine(), userID) + return u.hasMemberWithUserID(db.GetEngine(db.DefaultContext), userID) } func (u *User) hasMemberWithUserID(e db.Engine, userID int64) bool { @@ -538,7 +538,7 @@ func (u *User) getOrganizationCount(e db.Engine) (int64, error) { // GetOrganizationCount returns count of membership of organization of user. func (u *User) GetOrganizationCount() (int64, error) { - return u.getOrganizationCount(db.DefaultContext().Engine()) + return u.getOrganizationCount(db.GetEngine(db.DefaultContext)) } // GetRepositories returns repositories that user owns, including private repositories. @@ -552,7 +552,7 @@ func (u *User) GetRepositories(listOpts ListOptions, names ...string) (err error func (u *User) GetRepositoryIDs(units ...UnitType) ([]int64, error) { var ids []int64 - sess := db.DefaultContext().Engine().Table("repository").Cols("repository.id") + sess := db.GetEngine(db.DefaultContext).Table("repository").Cols("repository.id") if len(units) > 0 { sess = sess.Join("INNER", "repo_unit", "repository.id = repo_unit.repo_id") @@ -567,7 +567,7 @@ func (u *User) GetRepositoryIDs(units ...UnitType) ([]int64, error) { func (u *User) GetActiveRepositoryIDs(units ...UnitType) ([]int64, error) { var ids []int64 - sess := db.DefaultContext().Engine().Table("repository").Cols("repository.id") + sess := db.GetEngine(db.DefaultContext).Table("repository").Cols("repository.id") if len(units) > 0 { sess = sess.Join("INNER", "repo_unit", "repository.id = repo_unit.repo_id") @@ -584,7 +584,7 @@ func (u *User) GetActiveRepositoryIDs(units ...UnitType) ([]int64, error) { func (u *User) GetOrgRepositoryIDs(units ...UnitType) ([]int64, error) { var ids []int64 - if err := db.DefaultContext().Engine().Table("repository"). + if err := db.GetEngine(db.DefaultContext).Table("repository"). Cols("repository.id"). Join("INNER", "team_user", "repository.owner_id = team_user.org_id"). Join("INNER", "team_repo", "(? != ? and repository.is_private != ?) OR (team_user.team_id = team_repo.team_id AND repository.id = team_repo.repo_id)", true, u.IsRestricted, true). @@ -605,7 +605,7 @@ func (u *User) GetOrgRepositoryIDs(units ...UnitType) ([]int64, error) { func (u *User) GetActiveOrgRepositoryIDs(units ...UnitType) ([]int64, error) { var ids []int64 - if err := db.DefaultContext().Engine().Table("repository"). + if err := db.GetEngine(db.DefaultContext).Table("repository"). Cols("repository.id"). Join("INNER", "team_user", "repository.owner_id = team_user.org_id"). Join("INNER", "team_repo", "(? != ? and repository.is_private != ?) OR (team_user.team_id = team_repo.team_id AND repository.id = team_repo.repo_id)", true, u.IsRestricted, true). @@ -743,7 +743,7 @@ func isUserExist(e db.Engine, uid int64, name string) (bool, error) { // If uid is presented, then check will rule out that one, // it is used when update a user name in settings page. func IsUserExist(uid int64, name string) (bool, error) { - return isUserExist(db.DefaultContext().Engine(), uid, name) + return isUserExist(db.GetEngine(db.DefaultContext), uid, name) } // GetUserSalt returns a random user salt token. @@ -879,7 +879,7 @@ func CreateUser(u *User, overwriteDefault ...*CreateUserOverwriteOptions) (err e u.Visibility = overwriteDefault[0].Visibility } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -949,7 +949,7 @@ func countUsers(e db.Engine) int64 { // CountUsers returns number of users. func CountUsers() int64 { - return countUsers(db.DefaultContext().Engine()) + return countUsers(db.GetEngine(db.DefaultContext)) } // get user by verify code @@ -997,7 +997,7 @@ func VerifyActiveEmailCode(code, email string) *EmailAddress { if base.VerifyTimeLimitCode(data, minutes, prefix) { emailAddress := &EmailAddress{UID: user.ID, Email: email} - if has, _ := db.DefaultContext().Engine().Get(emailAddress); has { + if has, _ := db.GetEngine(db.DefaultContext).Get(emailAddress); has { return emailAddress } } @@ -1012,7 +1012,7 @@ func ChangeUserName(u *User, newUserName string) (err error) { return err } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -1086,12 +1086,12 @@ func updateUser(e db.Engine, u *User) error { // UpdateUser updates user's information. func UpdateUser(u *User) error { - return updateUser(db.DefaultContext().Engine(), u) + return updateUser(db.GetEngine(db.DefaultContext), u) } // UpdateUserCols update user according special columns func UpdateUserCols(u *User, cols ...string) error { - return updateUserCols(db.DefaultContext().Engine(), u, cols...) + return updateUserCols(db.GetEngine(db.DefaultContext), u, cols...) } func updateUserCols(e db.Engine, u *User, cols ...string) error { @@ -1105,7 +1105,7 @@ func updateUserCols(e db.Engine, u *User, cols ...string) error { // UpdateUserSetting updates user's settings. func UpdateUserSetting(u *User) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -1311,7 +1311,7 @@ func DeleteUser(u *User) (err error) { return fmt.Errorf("%s is an organization not a user", u.Name) } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -1329,13 +1329,13 @@ func DeleteUser(u *User) (err error) { func DeleteInactiveUsers(ctx context.Context, olderThan time.Duration) (err error) { users := make([]*User, 0, 10) if olderThan > 0 { - if err = db.DefaultContext().Engine(). + if err = db.GetEngine(db.DefaultContext). Where("is_active = ? and created_unix < ?", false, time.Now().Add(-olderThan).Unix()). Find(&users); err != nil { return fmt.Errorf("get all inactive users: %v", err) } } else { - if err = db.DefaultContext().Engine(). + if err = db.GetEngine(db.DefaultContext). Where("is_active = ?", false). Find(&users); err != nil { return fmt.Errorf("get all inactive users: %v", err) @@ -1357,7 +1357,7 @@ func DeleteInactiveUsers(ctx context.Context, olderThan time.Duration) (err erro } } - _, err = db.DefaultContext().Engine(). + _, err = db.GetEngine(db.DefaultContext). Where("is_activated = ?", false). Delete(new(EmailAddress)) return err @@ -1381,12 +1381,12 @@ func getUserByID(e db.Engine, id int64) (*User, error) { // GetUserByID returns the user object by given ID if exists. func GetUserByID(id int64) (*User, error) { - return getUserByID(db.DefaultContext().Engine(), id) + return getUserByID(db.GetEngine(db.DefaultContext), id) } // GetUserByName returns user by given name. func GetUserByName(name string) (*User, error) { - return getUserByName(db.DefaultContext().Engine(), name) + return getUserByName(db.GetEngine(db.DefaultContext), name) } func getUserByName(e db.Engine, name string) (*User, error) { @@ -1406,7 +1406,7 @@ func getUserByName(e db.Engine, name string) (*User, error) { // GetUserEmailsByNames returns a list of e-mails corresponds to names of users // that have their email notifications set to enabled or onmention. func GetUserEmailsByNames(names []string) []string { - return getUserEmailsByNames(db.DefaultContext().Engine(), names) + return getUserEmailsByNames(db.GetEngine(db.DefaultContext), names) } func getUserEmailsByNames(e db.Engine, names []string) []string { @@ -1431,7 +1431,7 @@ func GetMaileableUsersByIDs(ids []int64, isMention bool) ([]*User, error) { ous := make([]*User, 0, len(ids)) if isMention { - return ous, db.DefaultContext().Engine().In("id", ids). + return ous, db.GetEngine(db.DefaultContext).In("id", ids). Where("`type` = ?", UserTypeIndividual). And("`prohibit_login` = ?", false). And("`is_active` = ?", true). @@ -1439,7 +1439,7 @@ func GetMaileableUsersByIDs(ids []int64, isMention bool) ([]*User, error) { Find(&ous) } - return ous, db.DefaultContext().Engine().In("id", ids). + return ous, db.GetEngine(db.DefaultContext).In("id", ids). Where("`type` = ?", UserTypeIndividual). And("`prohibit_login` = ?", false). And("`is_active` = ?", true). @@ -1450,7 +1450,7 @@ func GetMaileableUsersByIDs(ids []int64, isMention bool) ([]*User, error) { // GetUserNamesByIDs returns usernames for all resolved users from a list of Ids. func GetUserNamesByIDs(ids []int64) ([]string, error) { unames := make([]string, 0, len(ids)) - err := db.DefaultContext().Engine().In("id", ids). + err := db.GetEngine(db.DefaultContext).In("id", ids). Table("user"). Asc("name"). Cols("name"). @@ -1464,7 +1464,7 @@ func GetUsersByIDs(ids []int64) (UserList, error) { if len(ids) == 0 { return ous, nil } - err := db.DefaultContext().Engine().In("id", ids). + err := db.GetEngine(db.DefaultContext).In("id", ids). Asc("name"). Find(&ous) return ous, err @@ -1490,7 +1490,7 @@ func GetUserIDsByNames(names []string, ignoreNonExistent bool) ([]int64, error) // GetUsersBySource returns a list of Users for a login source func GetUsersBySource(s *LoginSource) ([]*User, error) { var users []*User - err := db.DefaultContext().Engine().Where("login_type = ? AND login_source = ?", s.Type, s.ID).Find(&users) + err := db.GetEngine(db.DefaultContext).Where("login_type = ? AND login_source = ?", s.Type, s.ID).Find(&users) return users, err } @@ -1539,11 +1539,11 @@ func ValidateCommitsWithEmails(oldCommits []*git.Commit) []*UserCommit { // GetUserByEmail returns the user object by given e-mail if exists. func GetUserByEmail(email string) (*User, error) { - return GetUserByEmailContext(db.DefaultContext(), email) + return GetUserByEmailContext(db.DefaultContext, email) } // GetUserByEmailContext returns the user object by given e-mail if exists with db context -func GetUserByEmailContext(ctx *db.Context, email string) (*User, error) { +func GetUserByEmailContext(ctx context.Context, email string) (*User, error) { if len(email) == 0 { return nil, ErrUserNotExist{0, email, 0} } @@ -1551,7 +1551,7 @@ func GetUserByEmailContext(ctx *db.Context, email string) (*User, error) { email = strings.ToLower(email) // First try to find the user by primary email user := &User{Email: email} - has, err := ctx.Engine().Get(user) + has, err := db.GetEngine(ctx).Get(user) if err != nil { return nil, err } @@ -1561,19 +1561,19 @@ func GetUserByEmailContext(ctx *db.Context, email string) (*User, error) { // Otherwise, check in alternative list for activated email addresses emailAddress := &EmailAddress{Email: email, IsActivated: true} - has, err = ctx.Engine().Get(emailAddress) + has, err = db.GetEngine(ctx).Get(emailAddress) if err != nil { return nil, err } if has { - return getUserByID(ctx.Engine(), emailAddress.UID) + return getUserByID(db.GetEngine(ctx), emailAddress.UID) } // Finally, if email address is the protected email address: if strings.HasSuffix(email, fmt.Sprintf("@%s", setting.Service.NoReplyAddress)) { username := strings.TrimSuffix(email, fmt.Sprintf("@%s", setting.Service.NoReplyAddress)) user := &User{} - has, err := ctx.Engine().Where("lower_name=?", username).Get(user) + has, err := db.GetEngine(ctx).Where("lower_name=?", username).Get(user) if err != nil { return nil, err } @@ -1587,7 +1587,7 @@ func GetUserByEmailContext(ctx *db.Context, email string) (*User, error) { // GetUser checks if a user already exists func GetUser(user *User) (bool, error) { - return db.DefaultContext().Engine().Get(user) + return db.GetEngine(db.DefaultContext).Get(user) } // SearchUserOptions contains the options for searching @@ -1664,7 +1664,7 @@ func (opts *SearchUserOptions) toConds() builder.Cond { // it returns results in given range and number of total results. func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { cond := opts.toConds() - count, err := db.DefaultContext().Engine().Where(cond).Count(new(User)) + count, err := db.GetEngine(db.DefaultContext).Where(cond).Count(new(User)) if err != nil { return nil, 0, fmt.Errorf("Count: %v", err) } @@ -1673,7 +1673,7 @@ func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { opts.OrderBy = SearchOrderByAlphabetically } - sess := db.DefaultContext().Engine().Where(cond).OrderBy(opts.OrderBy.String()) + sess := db.GetEngine(db.DefaultContext).Where(cond).OrderBy(opts.OrderBy.String()) if opts.Page != 0 { sess = setSessionPagination(sess, opts) } @@ -1684,7 +1684,7 @@ func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { // GetStarredRepos returns the repos starred by a particular user func GetStarredRepos(userID int64, private bool, listOptions ListOptions) ([]*Repository, error) { - sess := db.DefaultContext().Engine().Where("star.uid=?", userID). + sess := db.GetEngine(db.DefaultContext).Where("star.uid=?", userID). Join("LEFT", "star", "`repository`.id=`star`.repo_id") if !private { sess = sess.And("is_private=?", false) @@ -1703,7 +1703,7 @@ func GetStarredRepos(userID int64, private bool, listOptions ListOptions) ([]*Re // GetWatchedRepos returns the repos watched by a particular user func GetWatchedRepos(userID int64, private bool, listOptions ListOptions) ([]*Repository, int64, error) { - sess := db.DefaultContext().Engine().Where("watch.user_id=?", userID). + sess := db.GetEngine(db.DefaultContext).Where("watch.user_id=?", userID). And("`watch`.mode<>?", RepoWatchModeDont). Join("LEFT", "watch", "`repository`.id=`watch`.repo_id") if !private { @@ -1729,7 +1729,7 @@ func IterateUser(f func(user *User) error) error { batchSize := setting.Database.IterateBufferSize for { users := make([]*User, 0, batchSize) - if err := db.DefaultContext().Engine().Limit(batchSize, start).Find(&users); err != nil { + if err := db.GetEngine(db.DefaultContext).Limit(batchSize, start).Find(&users); err != nil { return err } if len(users) == 0 { diff --git a/models/user_avatar.go b/models/user_avatar.go index 99e533758841..65e59eb326ae 100644 --- a/models/user_avatar.go +++ b/models/user_avatar.go @@ -26,7 +26,7 @@ func (u *User) CustomAvatarRelativePath() string { // GenerateRandomAvatar generates a random avatar for user. func (u *User) GenerateRandomAvatar() error { - return u.generateRandomAvatar(db.DefaultContext().Engine()) + return u.generateRandomAvatar(db.GetEngine(db.DefaultContext)) } func (u *User) generateRandomAvatar(e db.Engine) error { @@ -125,7 +125,7 @@ func (u *User) UploadAvatar(data []byte) error { return err } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -165,7 +165,7 @@ func (u *User) DeleteAvatar() error { u.UseCustomAvatar = false u.Avatar = "" - if _, err := db.DefaultContext().Engine().ID(u.ID).Cols("avatar, use_custom_avatar").Update(u); err != nil { + if _, err := db.GetEngine(db.DefaultContext).ID(u.ID).Cols("avatar, use_custom_avatar").Update(u); err != nil { return fmt.Errorf("UpdateUser: %v", err) } return nil diff --git a/models/user_follow.go b/models/user_follow.go index ceaaaf9103d9..8832aa2f18dc 100644 --- a/models/user_follow.go +++ b/models/user_follow.go @@ -23,7 +23,7 @@ func init() { // IsFollowing returns true if user is following followID. func IsFollowing(userID, followID int64) bool { - has, _ := db.DefaultContext().Engine().Get(&Follow{UserID: userID, FollowID: followID}) + has, _ := db.GetEngine(db.DefaultContext).Get(&Follow{UserID: userID, FollowID: followID}) return has } @@ -33,7 +33,7 @@ func FollowUser(userID, followID int64) (err error) { return nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -59,7 +59,7 @@ func UnfollowUser(userID, followID int64) (err error) { return nil } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err diff --git a/models/user_heatmap.go b/models/user_heatmap.go index e713612635aa..3e94a6f9b703 100644 --- a/models/user_heatmap.go +++ b/models/user_heatmap.go @@ -59,7 +59,7 @@ func getUserHeatmapData(user *User, team *Team, doer *User) ([]*UserHeatmapData, return nil, err } - return hdata, db.DefaultContext().Engine(). + return hdata, db.GetEngine(db.DefaultContext). Select(groupBy+" AS timestamp, count(user_id) as contributions"). Table("action"). Where(cond). diff --git a/models/user_mail.go b/models/user_mail.go index 169399bc7143..51d34d26826a 100644 --- a/models/user_mail.go +++ b/models/user_mail.go @@ -58,7 +58,7 @@ func ValidateEmail(email string) error { // GetEmailAddresses returns all email addresses belongs to given user. func GetEmailAddresses(uid int64) ([]*EmailAddress, error) { emails := make([]*EmailAddress, 0, 5) - if err := db.DefaultContext().Engine(). + if err := db.GetEngine(db.DefaultContext). Where("uid=?", uid). Asc("id"). Find(&emails); err != nil { @@ -71,7 +71,7 @@ func GetEmailAddresses(uid int64) ([]*EmailAddress, error) { func GetEmailAddressByID(uid, id int64) (*EmailAddress, error) { // User ID is required for security reasons email := &EmailAddress{UID: uid} - if has, err := db.DefaultContext().Engine().ID(id).Get(email); err != nil { + if has, err := db.GetEngine(db.DefaultContext).ID(id).Get(email); err != nil { return nil, err } else if !has { return nil, nil @@ -114,7 +114,7 @@ func isEmailUsed(e db.Engine, email string) (bool, error) { // IsEmailUsed returns true if the email has been used. func IsEmailUsed(email string) (bool, error) { - return isEmailUsed(db.DefaultContext().Engine(), email) + return isEmailUsed(db.GetEngine(db.DefaultContext), email) } func addEmailAddress(e db.Engine, email *EmailAddress) error { @@ -136,7 +136,7 @@ func addEmailAddress(e db.Engine, email *EmailAddress) error { // AddEmailAddress adds an email address to given user. func AddEmailAddress(email *EmailAddress) error { - return addEmailAddress(db.DefaultContext().Engine(), email) + return addEmailAddress(db.GetEngine(db.DefaultContext), email) } // AddEmailAddresses adds an email address to given user. @@ -159,7 +159,7 @@ func AddEmailAddresses(emails []*EmailAddress) error { } } - if _, err := db.DefaultContext().Engine().Insert(emails); err != nil { + if _, err := db.GetEngine(db.DefaultContext).Insert(emails); err != nil { return fmt.Errorf("Insert: %v", err) } @@ -168,7 +168,7 @@ func AddEmailAddresses(emails []*EmailAddress) error { // Activate activates the email address to given user. func (email *EmailAddress) Activate() error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -206,12 +206,12 @@ func DeleteEmailAddress(email *EmailAddress) (err error) { UID: email.UID, } if email.ID > 0 { - deleted, err = db.DefaultContext().Engine().ID(email.ID).Delete(&address) + deleted, err = db.GetEngine(db.DefaultContext).ID(email.ID).Delete(&address) } else { if email.Email != "" && email.LowerEmail == "" { email.LowerEmail = strings.ToLower(email.Email) } - deleted, err = db.DefaultContext().Engine(). + deleted, err = db.GetEngine(db.DefaultContext). Where("lower_email=?", email.LowerEmail). Delete(&address) } @@ -237,7 +237,7 @@ func DeleteEmailAddresses(emails []*EmailAddress) (err error) { // MakeEmailPrimary sets primary email address of given user. func MakeEmailPrimary(email *EmailAddress) error { - has, err := db.DefaultContext().Engine().Get(email) + has, err := db.GetEngine(db.DefaultContext).Get(email) if err != nil { return err } else if !has { @@ -249,14 +249,14 @@ func MakeEmailPrimary(email *EmailAddress) error { } user := &User{} - has, err = db.DefaultContext().Engine().ID(email.UID).Get(user) + has, err = db.GetEngine(db.DefaultContext).ID(email.UID).Get(user) if err != nil { return err } else if !has { return ErrUserNotExist{email.UID, "", 0} } - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -346,7 +346,7 @@ func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) cond = cond.And(builder.Eq{"email_address.is_activated": false}) } - count, err := db.DefaultContext().Engine().Join("INNER", "`user`", "`user`.ID = email_address.uid"). + count, err := db.GetEngine(db.DefaultContext).Join("INNER", "`user`", "`user`.ID = email_address.uid"). Where(cond).Count(new(EmailAddress)) if err != nil { return nil, 0, fmt.Errorf("Count: %v", err) @@ -360,7 +360,7 @@ func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) opts.setDefaultValues() emails := make([]*SearchEmailResult, 0, opts.PageSize) - err = db.DefaultContext().Engine().Table("email_address"). + err = db.GetEngine(db.DefaultContext).Table("email_address"). Select("email_address.*, `user`.name, `user`.full_name"). Join("INNER", "`user`", "`user`.ID = email_address.uid"). Where(cond). @@ -374,7 +374,7 @@ func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) // ActivateUserEmail will change the activated state of an email address, // either primary or secondary (all in the email_address table) func ActivateUserEmail(userID int64, email string, activate bool) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err diff --git a/models/user_openid.go b/models/user_openid.go index 4844b68c2993..17a58536a238 100644 --- a/models/user_openid.go +++ b/models/user_openid.go @@ -30,7 +30,7 @@ func init() { // GetUserOpenIDs returns all openid addresses that belongs to given user. func GetUserOpenIDs(uid int64) ([]*UserOpenID, error) { openids := make([]*UserOpenID, 0, 5) - if err := db.DefaultContext().Engine(). + if err := db.GetEngine(db.DefaultContext). Where("uid=?", uid). Asc("id"). Find(&openids); err != nil { @@ -64,7 +64,7 @@ func addUserOpenID(e db.Engine, openid *UserOpenID) error { // AddUserOpenID adds an pre-verified/normalized OpenID URI to given user. func AddUserOpenID(openid *UserOpenID) error { - return addUserOpenID(db.DefaultContext().Engine(), openid) + return addUserOpenID(db.GetEngine(db.DefaultContext), openid) } // DeleteUserOpenID deletes an openid address of given user. @@ -75,9 +75,9 @@ func DeleteUserOpenID(openid *UserOpenID) (err error) { UID: openid.UID, } if openid.ID > 0 { - deleted, err = db.DefaultContext().Engine().ID(openid.ID).Delete(&address) + deleted, err = db.GetEngine(db.DefaultContext).ID(openid.ID).Delete(&address) } else { - deleted, err = db.DefaultContext().Engine(). + deleted, err = db.GetEngine(db.DefaultContext). Where("openid=?", openid.URI). Delete(&address) } @@ -92,7 +92,7 @@ func DeleteUserOpenID(openid *UserOpenID) (err error) { // ToggleUserOpenIDVisibility toggles visibility of an openid address of given user. func ToggleUserOpenIDVisibility(id int64) (err error) { - _, err = db.DefaultContext().Engine().Exec("update `user_open_id` set `show` = not `show` where `id` = ?", id) + _, err = db.GetEngine(db.DefaultContext).Exec("update `user_open_id` set `show` = not `show` where `id` = ?", id) return err } @@ -111,7 +111,7 @@ func GetUserByOpenID(uri string) (*User, error) { // Otherwise, check in openid table oid := &UserOpenID{} - has, err := db.DefaultContext().Engine().Where("uri=?", uri).Get(oid) + has, err := db.GetEngine(db.DefaultContext).Where("uri=?", uri).Get(oid) if err != nil { return nil, err } diff --git a/models/user_redirect.go b/models/user_redirect.go index 4380eb5984e0..fdc730775f70 100644 --- a/models/user_redirect.go +++ b/models/user_redirect.go @@ -25,7 +25,7 @@ func init() { func LookupUserRedirect(userName string) (int64, error) { userName = strings.ToLower(userName) redirect := &UserRedirect{LowerName: userName} - if has, err := db.DefaultContext().Engine().Get(redirect); err != nil { + if has, err := db.GetEngine(db.DefaultContext).Get(redirect); err != nil { return 0, err } else if !has { return 0, ErrUserRedirectNotExist{Name: userName} diff --git a/models/user_redirect_test.go b/models/user_redirect_test.go index 190778cfbcbb..346bf98b846e 100644 --- a/models/user_redirect_test.go +++ b/models/user_redirect_test.go @@ -27,7 +27,7 @@ func TestNewUserRedirect(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) user := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User) - assert.NoError(t, newUserRedirect(db.DefaultContext().Engine(), user.ID, user.Name, "newusername")) + assert.NoError(t, newUserRedirect(db.GetEngine(db.DefaultContext), user.ID, user.Name, "newusername")) db.AssertExistsAndLoadBean(t, &UserRedirect{ LowerName: user.LowerName, @@ -44,7 +44,7 @@ func TestNewUserRedirect2(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) user := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User) - assert.NoError(t, newUserRedirect(db.DefaultContext().Engine(), user.ID, user.Name, "olduser1")) + assert.NoError(t, newUserRedirect(db.GetEngine(db.DefaultContext), user.ID, user.Name, "olduser1")) db.AssertExistsAndLoadBean(t, &UserRedirect{ LowerName: user.LowerName, @@ -61,7 +61,7 @@ func TestNewUserRedirect3(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) user := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User) - assert.NoError(t, newUserRedirect(db.DefaultContext().Engine(), user.ID, user.Name, "newusername")) + assert.NoError(t, newUserRedirect(db.GetEngine(db.DefaultContext), user.ID, user.Name, "newusername")) db.AssertExistsAndLoadBean(t, &UserRedirect{ LowerName: user.LowerName, diff --git a/models/user_test.go b/models/user_test.go index aadf7e4c5eb8..6c616a60a902 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -160,7 +160,7 @@ func TestDeleteUser(t *testing.T) { user := db.AssertExistsAndLoadBean(t, &User{ID: userID}).(*User) ownedRepos := make([]*Repository, 0, 10) - assert.NoError(t, db.DefaultContext().Engine().Find(&ownedRepos, &Repository{OwnerID: userID})) + assert.NoError(t, db.GetEngine(db.DefaultContext).Find(&ownedRepos, &Repository{OwnerID: userID})) if len(ownedRepos) > 0 { err := DeleteUser(user) assert.Error(t, err) @@ -169,7 +169,7 @@ func TestDeleteUser(t *testing.T) { } orgUsers := make([]*OrgUser, 0, 10) - assert.NoError(t, db.DefaultContext().Engine().Find(&orgUsers, &OrgUser{UID: userID})) + assert.NoError(t, db.GetEngine(db.DefaultContext).Find(&orgUsers, &OrgUser{UID: userID})) for _, orgUser := range orgUsers { if err := RemoveOrgUser(orgUser.OrgID, orgUser.UID); err != nil { assert.True(t, IsErrLastOrgOwner(err)) @@ -281,7 +281,7 @@ func TestGetOrgRepositoryIDs(t *testing.T) { func TestNewGitSig(t *testing.T) { users := make([]*User, 0, 20) - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() sess.Find(&users) @@ -296,7 +296,7 @@ func TestNewGitSig(t *testing.T) { func TestDisplayName(t *testing.T) { users := make([]*User, 0, 20) - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() sess.Find(&users) diff --git a/models/userlist.go b/models/userlist.go index 9741c90ccfa0..bfa7ea1e2ea2 100644 --- a/models/userlist.go +++ b/models/userlist.go @@ -29,7 +29,7 @@ func (users UserList) IsUserOrgOwner(orgID int64) map[int64]bool { for _, user := range users { results[user.ID] = false // Set default to false } - ownerMaps, err := users.loadOrganizationOwners(db.DefaultContext().Engine(), orgID) + ownerMaps, err := users.loadOrganizationOwners(db.GetEngine(db.DefaultContext), orgID) if err == nil { for _, owner := range ownerMaps { results[owner.UID] = true @@ -69,7 +69,7 @@ func (users UserList) GetTwoFaStatus() map[int64]bool { for _, user := range users { results[user.ID] = false // Set default to false } - tokenMaps, err := users.loadTwoFactorStatus(db.DefaultContext().Engine()) + tokenMaps, err := users.loadTwoFactorStatus(db.GetEngine(db.DefaultContext)) if err == nil { for _, token := range tokenMaps { results[token.UID] = true diff --git a/models/webhook.go b/models/webhook.go index fca6eec67e93..034b37263aad 100644 --- a/models/webhook.go +++ b/models/webhook.go @@ -351,7 +351,7 @@ func (w *Webhook) EventsArray() []string { // CreateWebhook creates a new web hook. func CreateWebhook(w *Webhook) error { - return createWebhook(db.DefaultContext().Engine(), w) + return createWebhook(db.GetEngine(db.DefaultContext), w) } func createWebhook(e db.Engine, w *Webhook) error { @@ -363,7 +363,7 @@ func createWebhook(e db.Engine, w *Webhook) error { // getWebhook uses argument bean as query condition, // ID must be specified and do not assign unnecessary fields. func getWebhook(bean *Webhook) (*Webhook, error) { - has, err := db.DefaultContext().Engine().Get(bean) + has, err := db.GetEngine(db.DefaultContext).Get(bean) if err != nil { return nil, err } else if !has { @@ -434,17 +434,17 @@ func listWebhooksByOpts(e db.Engine, opts *ListWebhookOptions) ([]*Webhook, erro // ListWebhooksByOpts return webhooks based on options func ListWebhooksByOpts(opts *ListWebhookOptions) ([]*Webhook, error) { - return listWebhooksByOpts(db.DefaultContext().Engine(), opts) + return listWebhooksByOpts(db.GetEngine(db.DefaultContext), opts) } // CountWebhooksByOpts count webhooks based on options and ignore pagination func CountWebhooksByOpts(opts *ListWebhookOptions) (int64, error) { - return db.DefaultContext().Engine().Where(opts.toCond()).Count(&Webhook{}) + return db.GetEngine(db.DefaultContext).Where(opts.toCond()).Count(&Webhook{}) } // GetDefaultWebhooks returns all admin-default webhooks. func GetDefaultWebhooks() ([]*Webhook, error) { - return getDefaultWebhooks(db.DefaultContext().Engine()) + return getDefaultWebhooks(db.GetEngine(db.DefaultContext)) } func getDefaultWebhooks(e db.Engine) ([]*Webhook, error) { @@ -457,7 +457,7 @@ func getDefaultWebhooks(e db.Engine) ([]*Webhook, error) { // GetSystemOrDefaultWebhook returns admin system or default webhook by given ID. func GetSystemOrDefaultWebhook(id int64) (*Webhook, error) { webhook := &Webhook{ID: id} - has, err := db.DefaultContext().Engine(). + has, err := db.GetEngine(db.DefaultContext). Where("repo_id=? AND org_id=?", 0, 0). Get(webhook) if err != nil { @@ -470,7 +470,7 @@ func GetSystemOrDefaultWebhook(id int64) (*Webhook, error) { // GetSystemWebhooks returns all admin system webhooks. func GetSystemWebhooks() ([]*Webhook, error) { - return getSystemWebhooks(db.DefaultContext().Engine()) + return getSystemWebhooks(db.GetEngine(db.DefaultContext)) } func getSystemWebhooks(e db.Engine) ([]*Webhook, error) { @@ -482,20 +482,20 @@ func getSystemWebhooks(e db.Engine) ([]*Webhook, error) { // UpdateWebhook updates information of webhook. func UpdateWebhook(w *Webhook) error { - _, err := db.DefaultContext().Engine().ID(w.ID).AllCols().Update(w) + _, err := db.GetEngine(db.DefaultContext).ID(w.ID).AllCols().Update(w) return err } // UpdateWebhookLastStatus updates last status of webhook. func UpdateWebhookLastStatus(w *Webhook) error { - _, err := db.DefaultContext().Engine().ID(w.ID).Cols("last_status").Update(w) + _, err := db.GetEngine(db.DefaultContext).ID(w.ID).Cols("last_status").Update(w) return err } // deleteWebhook uses argument bean as query condition, // ID must be specified and do not assign unnecessary fields. func deleteWebhook(bean *Webhook) (err error) { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err = sess.Begin(); err != nil { return err @@ -530,7 +530,7 @@ func DeleteWebhookByOrgID(orgID, id int64) error { // DeleteDefaultSystemWebhook deletes an admin-configured default or system webhook (where Org and Repo ID both 0) func DeleteDefaultSystemWebhook(id int64) error { - sess := db.DefaultContext().NewSession() + sess := db.NewSession(db.DefaultContext) defer sess.Close() if err := sess.Begin(); err != nil { return err @@ -713,7 +713,7 @@ func (t *HookTask) simpleMarshalJSON(v interface{}) string { // HookTasks returns a list of hook tasks by given conditions. func HookTasks(hookID int64, page int) ([]*HookTask, error) { tasks := make([]*HookTask, 0, setting.Webhook.PagingNum) - return tasks, db.DefaultContext().Engine(). + return tasks, db.GetEngine(db.DefaultContext). Limit(setting.Webhook.PagingNum, (page-1)*setting.Webhook.PagingNum). Where("hook_id=?", hookID). Desc("id"). @@ -723,7 +723,7 @@ func HookTasks(hookID int64, page int) ([]*HookTask, error) { // CreateHookTask creates a new hook task, // it handles conversion from Payload to PayloadContent. func CreateHookTask(t *HookTask) error { - return createHookTask(db.DefaultContext().Engine(), t) + return createHookTask(db.GetEngine(db.DefaultContext), t) } func createHookTask(e db.Engine, t *HookTask) error { @@ -739,14 +739,14 @@ func createHookTask(e db.Engine, t *HookTask) error { // UpdateHookTask updates information of hook task. func UpdateHookTask(t *HookTask) error { - _, err := db.DefaultContext().Engine().ID(t.ID).AllCols().Update(t) + _, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t) return err } // FindUndeliveredHookTasks represents find the undelivered hook tasks func FindUndeliveredHookTasks() ([]*HookTask, error) { tasks := make([]*HookTask, 0, 10) - if err := db.DefaultContext().Engine().Where("is_delivered=?", false).Find(&tasks); err != nil { + if err := db.GetEngine(db.DefaultContext).Where("is_delivered=?", false).Find(&tasks); err != nil { return nil, err } return tasks, nil @@ -755,7 +755,7 @@ func FindUndeliveredHookTasks() ([]*HookTask, error) { // FindRepoUndeliveredHookTasks represents find the undelivered hook tasks of one repository func FindRepoUndeliveredHookTasks(repoID int64) ([]*HookTask, error) { tasks := make([]*HookTask, 0, 5) - if err := db.DefaultContext().Engine().Where("repo_id=? AND is_delivered=?", repoID, false).Find(&tasks); err != nil { + if err := db.GetEngine(db.DefaultContext).Where("repo_id=? AND is_delivered=?", repoID, false).Find(&tasks); err != nil { return nil, err } return tasks, nil @@ -767,7 +767,7 @@ func CleanupHookTaskTable(ctx context.Context, cleanupType HookTaskCleanupType, if cleanupType == OlderThan { deleteOlderThan := time.Now().Add(-olderThan).UnixNano() - deletes, err := db.DefaultContext().Engine(). + deletes, err := db.GetEngine(db.DefaultContext). Where("is_delivered = ? and delivered < ?", true, deleteOlderThan). Delete(new(HookTask)) if err != nil { @@ -776,7 +776,7 @@ func CleanupHookTaskTable(ctx context.Context, cleanupType HookTaskCleanupType, log.Trace("Deleted %d rows from hook_task", deletes) } else if cleanupType == PerWebhook { hookIDs := make([]int64, 0, 10) - err := db.DefaultContext().Engine().Table("webhook"). + err := db.GetEngine(db.DefaultContext).Table("webhook"). Where("id > 0"). Cols("id"). Find(&hookIDs) @@ -801,7 +801,7 @@ func CleanupHookTaskTable(ctx context.Context, cleanupType HookTaskCleanupType, func deleteDeliveredHookTasksByWebhook(hookID int64, numberDeliveriesToKeep int) error { log.Trace("Deleting hook_task rows for webhook %d, keeping the most recent %d deliveries", hookID, numberDeliveriesToKeep) deliveryDates := make([]int64, 0, 10) - err := db.DefaultContext().Engine().Table("hook_task"). + err := db.GetEngine(db.DefaultContext).Table("hook_task"). Where("hook_task.hook_id = ? AND hook_task.is_delivered = ? AND hook_task.delivered is not null", hookID, true). Cols("hook_task.delivered"). Join("INNER", "webhook", "hook_task.hook_id = webhook.id"). @@ -813,7 +813,7 @@ func deleteDeliveredHookTasksByWebhook(hookID int64, numberDeliveriesToKeep int) } if len(deliveryDates) > 0 { - deletes, err := db.DefaultContext().Engine(). + deletes, err := db.GetEngine(db.DefaultContext). Where("hook_id = ? and is_delivered = ? and delivered <= ?", hookID, true, deliveryDates[0]). Delete(new(HookTask)) if err != nil { diff --git a/modules/doctor/mergebase.go b/modules/doctor/mergebase.go index e792c2b2fddb..c959da8d7f5f 100644 --- a/modules/doctor/mergebase.go +++ b/modules/doctor/mergebase.go @@ -18,7 +18,7 @@ import ( func iteratePRs(repo *models.Repository, each func(*models.Repository, *models.PullRequest) error) error { return db.Iterate( - db.DefaultContext(), + db.DefaultContext, new(models.PullRequest), builder.Eq{"base_repo_id": repo.ID}, func(idx int, bean interface{}) error { diff --git a/modules/doctor/misc.go b/modules/doctor/misc.go index 25bc3c3a7267..2f748bcb7187 100644 --- a/modules/doctor/misc.go +++ b/modules/doctor/misc.go @@ -26,7 +26,7 @@ import ( func iterateRepositories(each func(*models.Repository) error) error { err := db.Iterate( - db.DefaultContext(), + db.DefaultContext, new(models.Repository), builder.Gt{"id": 0}, func(idx int, bean interface{}) error { diff --git a/modules/repository/adopt.go b/modules/repository/adopt.go index 9371822fbcc1..daefee9c7460 100644 --- a/modules/repository/adopt.go +++ b/modules/repository/adopt.go @@ -5,6 +5,7 @@ package repository import ( + "context" "fmt" "os" "path/filepath" @@ -47,7 +48,7 @@ func AdoptRepository(doer, u *models.User, opts models.CreateRepoOptions) (*mode IsEmpty: !opts.AutoInit, } - if err := db.WithTx(func(ctx *db.Context) error { + if err := db.WithTx(func(ctx context.Context) error { repoPath := models.RepoPath(u.Name, repo.Name) isExist, err := util.IsExist(repoPath) if err != nil { diff --git a/modules/repository/check.go b/modules/repository/check.go index 1b550ad4f08f..78e6f8a90135 100644 --- a/modules/repository/check.go +++ b/modules/repository/check.go @@ -24,7 +24,7 @@ func GitFsck(ctx context.Context, timeout time.Duration, args []string) error { log.Trace("Doing: GitFsck") if err := db.Iterate( - db.DefaultContext(), + db.DefaultContext, new(models.Repository), builder.Expr("id>0 AND is_fsck_enabled=?", true), func(idx int, bean interface{}) error { @@ -59,7 +59,7 @@ func GitGcRepos(ctx context.Context, timeout time.Duration, args ...string) erro args = append([]string{"gc"}, args...) if err := db.Iterate( - db.DefaultContext(), + db.DefaultContext, new(models.Repository), builder.Gt{"id": 0}, func(idx int, bean interface{}) error { @@ -94,7 +94,7 @@ func GitGcRepos(ctx context.Context, timeout time.Duration, args ...string) erro } // Now update the size of the repository - if err := repo.UpdateSize(db.DefaultContext()); err != nil { + if err := repo.UpdateSize(db.DefaultContext); err != nil { log.Error("Updating size as part of garbage collection failed for %v. Stdout: %s\nError: %v", repo, stdout, err) desc := fmt.Sprintf("Updating size as part of garbage collection failed for %s. Stdout: %s\nError: %v", repo.RepoPath(), stdout, err) if err = models.CreateRepositoryNotice(desc); err != nil { @@ -116,7 +116,7 @@ func GitGcRepos(ctx context.Context, timeout time.Duration, args ...string) erro func gatherMissingRepoRecords(ctx context.Context) ([]*models.Repository, error) { repos := make([]*models.Repository, 0, 10) if err := db.Iterate( - db.DefaultContext(), + db.DefaultContext, new(models.Repository), builder.Gt{"id": 0}, func(idx int, bean interface{}) error { diff --git a/modules/repository/create.go b/modules/repository/create.go index 80f446e83f05..0e91a73b8359 100644 --- a/modules/repository/create.go +++ b/modules/repository/create.go @@ -5,6 +5,7 @@ package repository import ( + "context" "fmt" "strings" @@ -55,7 +56,7 @@ func CreateRepository(doer, u *models.User, opts models.CreateRepoOptions) (*mod var rollbackRepo *models.Repository - if err := db.WithTx(func(ctx *db.Context) error { + if err := db.WithTx(func(ctx context.Context) error { if err := models.CreateRepository(ctx, doer, u, repo, false); err != nil { return err } diff --git a/modules/repository/fork.go b/modules/repository/fork.go index ff69f75b32a8..59c07271a646 100644 --- a/modules/repository/fork.go +++ b/modules/repository/fork.go @@ -5,6 +5,7 @@ package repository import ( + "context" "fmt" "strings" "time" @@ -79,7 +80,7 @@ func ForkRepository(doer, owner *models.User, opts models.ForkRepoOptions) (_ *m panic(panicErr) }() - err = db.WithTx(func(ctx *db.Context) error { + err = db.WithTx(func(ctx context.Context) error { if err = models.CreateRepository(ctx, doer, owner, repo, false); err != nil { return err } @@ -123,7 +124,7 @@ func ForkRepository(doer, owner *models.User, opts models.ForkRepoOptions) (_ *m } // even if below operations failed, it could be ignored. And they will be retried - ctx := db.DefaultContext() + ctx := db.DefaultContext if err := repo.UpdateSize(ctx); err != nil { log.Error("Failed to update size for repository: %v", err) } @@ -136,7 +137,7 @@ func ForkRepository(doer, owner *models.User, opts models.ForkRepoOptions) (_ *m // ConvertForkToNormalRepository convert the provided repo from a forked repo to normal repo func ConvertForkToNormalRepository(repo *models.Repository) error { - err := db.WithTx(func(ctx *db.Context) error { + err := db.WithTx(func(ctx context.Context) error { repo, err := models.GetRepositoryByIDCtx(ctx, repo.ID) if err != nil { return err diff --git a/modules/repository/generate.go b/modules/repository/generate.go index 8ab518add503..4fcebc06dc3e 100644 --- a/modules/repository/generate.go +++ b/modules/repository/generate.go @@ -5,6 +5,7 @@ package repository import ( + "context" "fmt" "os" "path" @@ -13,7 +14,6 @@ import ( "time" "code.gitea.io/gitea/models" - "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/util" @@ -185,7 +185,7 @@ func generateRepoCommit(repo, templateRepo, generateRepo *models.Repository, tmp return initRepoCommit(tmpDir, repo, repo.Owner, templateRepo.DefaultBranch) } -func generateGitContent(ctx *db.Context, repo, templateRepo, generateRepo *models.Repository) (err error) { +func generateGitContent(ctx context.Context, repo, templateRepo, generateRepo *models.Repository) (err error) { tmpDir, err := os.MkdirTemp(os.TempDir(), "gitea-"+repo.Name) if err != nil { return fmt.Errorf("Failed to create temp dir for repository %s: %v", repo.RepoPath(), err) @@ -223,7 +223,7 @@ func generateGitContent(ctx *db.Context, repo, templateRepo, generateRepo *model } // GenerateGitContent generates git content from a template repository -func GenerateGitContent(ctx *db.Context, templateRepo, generateRepo *models.Repository) error { +func GenerateGitContent(ctx context.Context, templateRepo, generateRepo *models.Repository) error { if err := generateGitContent(ctx, generateRepo, templateRepo, generateRepo); err != nil { return err } @@ -239,7 +239,7 @@ func GenerateGitContent(ctx *db.Context, templateRepo, generateRepo *models.Repo } // GenerateRepository generates a repository from a template -func GenerateRepository(ctx *db.Context, doer, owner *models.User, templateRepo *models.Repository, opts models.GenerateRepoOptions) (_ *models.Repository, err error) { +func GenerateRepository(ctx context.Context, doer, owner *models.User, templateRepo *models.Repository, opts models.GenerateRepoOptions) (_ *models.Repository, err error) { generateRepo := &models.Repository{ OwnerID: owner.ID, Owner: owner, diff --git a/modules/repository/hooks.go b/modules/repository/hooks.go index e219903f75ae..6072dda0163f 100644 --- a/modules/repository/hooks.go +++ b/modules/repository/hooks.go @@ -221,7 +221,7 @@ func SyncRepositoryHooks(ctx context.Context) error { log.Trace("Doing: SyncRepositoryHooks") if err := db.Iterate( - db.DefaultContext(), + db.DefaultContext, new(models.Repository), builder.Gt{"id": 0}, func(idx int, bean interface{}) error { diff --git a/modules/repository/init.go b/modules/repository/init.go index 2a86e964cac4..5a1ff7e98bc6 100644 --- a/modules/repository/init.go +++ b/modules/repository/init.go @@ -6,6 +6,7 @@ package repository import ( "bytes" + "context" "fmt" "os" "path/filepath" @@ -13,7 +14,6 @@ import ( "time" "code.gitea.io/gitea/models" - "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -22,7 +22,7 @@ import ( "github.com/unknwon/com" ) -func prepareRepoCommit(ctx *db.Context, repo *models.Repository, tmpDir, repoPath string, opts models.CreateRepoOptions) error { +func prepareRepoCommit(ctx context.Context, repo *models.Repository, tmpDir, repoPath string, opts models.CreateRepoOptions) error { commitTimeStr := time.Now().Format(time.RFC3339) authorSig := repo.Owner.NewGitSig() @@ -196,7 +196,7 @@ func checkInitRepository(owner, name string) (err error) { return nil } -func adoptRepository(ctx *db.Context, repoPath string, u *models.User, repo *models.Repository, opts models.CreateRepoOptions) (err error) { +func adoptRepository(ctx context.Context, repoPath string, u *models.User, repo *models.Repository, opts models.CreateRepoOptions) (err error) { isExist, err := util.IsExist(repoPath) if err != nil { log.Error("Unable to check if %s exists. Error: %v", repoPath, err) @@ -283,7 +283,7 @@ func adoptRepository(ctx *db.Context, repoPath string, u *models.User, repo *mod } // InitRepository initializes README and .gitignore if needed. -func initRepository(ctx *db.Context, repoPath string, u *models.User, repo *models.Repository, opts models.CreateRepoOptions) (err error) { +func initRepository(ctx context.Context, repoPath string, u *models.User, repo *models.Repository, opts models.CreateRepoOptions) (err error) { if err = checkInitRepository(repo.OwnerName, repo.Name); err != nil { return err } diff --git a/modules/repository/repo.go b/modules/repository/repo.go index b59a80ee2f22..ee970fd711ee 100644 --- a/modules/repository/repo.go +++ b/modules/repository/repo.go @@ -133,7 +133,7 @@ func MigrateRepositoryGitData(ctx context.Context, u *models.User, repo *models. } } - if err = repo.UpdateSize(db.DefaultContext()); err != nil { + if err = repo.UpdateSize(db.DefaultContext); err != nil { log.Error("Failed to update size for repository: %v", err) } diff --git a/modules/repository/update.go b/modules/repository/update.go index d9ff12e1ad56..b9a5db2a6a13 100644 --- a/modules/repository/update.go +++ b/modules/repository/update.go @@ -5,6 +5,7 @@ package repository import ( + "context" "fmt" "strings" "time" @@ -17,7 +18,7 @@ import ( // PushUpdateAddDeleteTags updates a number of added and delete tags func PushUpdateAddDeleteTags(repo *models.Repository, gitRepo *git.Repository, addTags, delTags []string) error { - return db.WithTx(func(ctx *db.Context) error { + return db.WithTx(func(ctx context.Context) error { if err := models.PushUpdateDeleteTagsContext(ctx, repo, delTags); err != nil { return err } @@ -26,7 +27,7 @@ func PushUpdateAddDeleteTags(repo *models.Repository, gitRepo *git.Repository, a } // pushUpdateAddTags updates a number of add tags -func pushUpdateAddTags(ctx *db.Context, repo *models.Repository, gitRepo *git.Repository, tags []string) error { +func pushUpdateAddTags(ctx context.Context, repo *models.Repository, gitRepo *git.Repository, tags []string) error { if len(tags) == 0 { return nil } diff --git a/routers/web/org/org_labels.go b/routers/web/org/org_labels.go index 17509c50fe0e..13728a31b30a 100644 --- a/routers/web/org/org_labels.go +++ b/routers/web/org/org_labels.go @@ -99,7 +99,7 @@ func InitializeLabels(ctx *context.Context) { return } - if err := models.InitializeLabels(db.DefaultContext(), ctx.Org.Organization.ID, form.TemplateName, true); err != nil { + if err := models.InitializeLabels(db.DefaultContext, ctx.Org.Organization.ID, form.TemplateName, true); err != nil { if models.IsErrIssueLabelTemplateLoad(err) { originalErr := err.(models.ErrIssueLabelTemplateLoad).OriginalError ctx.Flash.Error(ctx.Tr("repo.issues.label_templates.fail_to_load_file", form.TemplateName, originalErr)) diff --git a/routers/web/repo/issue_label.go b/routers/web/repo/issue_label.go index daed302dcaa7..0ce511448547 100644 --- a/routers/web/repo/issue_label.go +++ b/routers/web/repo/issue_label.go @@ -39,7 +39,7 @@ func InitializeLabels(ctx *context.Context) { return } - if err := models.InitializeLabels(db.DefaultContext(), ctx.Repo.Repository.ID, form.TemplateName, false); err != nil { + if err := models.InitializeLabels(db.DefaultContext, ctx.Repo.Repository.ID, form.TemplateName, false); err != nil { if models.IsErrIssueLabelTemplateLoad(err) { originalErr := err.(models.ErrIssueLabelTemplateLoad).OriginalError ctx.Flash.Error(ctx.Tr("repo.issues.label_templates.fail_to_load_file", form.TemplateName, originalErr)) diff --git a/routers/web/repo/repo.go b/routers/web/repo/repo.go index 46a80524edd5..735bf4fe9f06 100644 --- a/routers/web/repo/repo.go +++ b/routers/web/repo/repo.go @@ -343,7 +343,7 @@ func RedirectDownload(ctx *context.Context) { ) tagNames := []string{vTag} curRepo := ctx.Repo.Repository - releases, err := models.GetReleasesByRepoIDAndNames(db.DefaultContext(), curRepo.ID, tagNames) + releases, err := models.GetReleasesByRepoIDAndNames(db.DefaultContext, curRepo.ID, tagNames) if err != nil { if models.IsErrAttachmentNotExist(err) { ctx.Error(http.StatusNotFound) @@ -380,7 +380,7 @@ func Download(ctx *context.Context) { return } - archiver, err := models.GetRepoArchiver(db.DefaultContext(), aReq.RepoID, aReq.Type, aReq.CommitID) + archiver, err := models.GetRepoArchiver(db.DefaultContext, aReq.RepoID, aReq.Type, aReq.CommitID) if err != nil { ctx.ServerError("models.GetRepoArchiver", err) return @@ -410,7 +410,7 @@ func Download(ctx *context.Context) { return } times++ - archiver, err = models.GetRepoArchiver(db.DefaultContext(), aReq.RepoID, aReq.Type, aReq.CommitID) + archiver, err = models.GetRepoArchiver(db.DefaultContext, aReq.RepoID, aReq.Type, aReq.CommitID) if err != nil { ctx.ServerError("archiver_service.StartArchive", err) return @@ -466,7 +466,7 @@ func InitiateDownload(ctx *context.Context) { return } - archiver, err := models.GetRepoArchiver(db.DefaultContext(), aReq.RepoID, aReq.Type, aReq.CommitID) + archiver, err := models.GetRepoArchiver(db.DefaultContext, aReq.RepoID, aReq.Type, aReq.CommitID) if err != nil { ctx.ServerError("archiver_service.StartArchive", err) return diff --git a/services/attachment/attachment.go b/services/attachment/attachment.go index 06f79be01b9e..7500a8ac3a65 100644 --- a/services/attachment/attachment.go +++ b/services/attachment/attachment.go @@ -6,6 +6,7 @@ package attachment import ( "bytes" + "context" "fmt" "io" @@ -23,7 +24,7 @@ func NewAttachment(attach *models.Attachment, file io.Reader) (*models.Attachmen return nil, fmt.Errorf("attachment %s should belong to a repository", attach.Name) } - err := db.WithTx(func(ctx *db.Context) error { + err := db.WithTx(func(ctx context.Context) error { attach.UUID = uuid.New().String() size, err := storage.Attachments.Save(attach.RelativePath(), file, -1) if err != nil { diff --git a/services/comments/comments.go b/services/comments/comments.go index 901b82e38013..d65c66aef26b 100644 --- a/services/comments/comments.go +++ b/services/comments/comments.go @@ -23,7 +23,7 @@ func CreateIssueComment(doer *models.User, repo *models.Repository, issue *model if err != nil { return nil, err } - mentions, err := issue.FindAndUpdateIssueMentions(db.DefaultContext(), doer, comment.Content) + mentions, err := issue.FindAndUpdateIssueMentions(db.DefaultContext, doer, comment.Content) if err != nil { return nil, err } diff --git a/services/issue/issue.go b/services/issue/issue.go index b2ac24e088e6..e3571bd396f6 100644 --- a/services/issue/issue.go +++ b/services/issue/issue.go @@ -24,7 +24,7 @@ func NewIssue(repo *models.Repository, issue *models.Issue, labelIDs []int64, uu } } - mentions, err := issue.FindAndUpdateIssueMentions(db.DefaultContext(), issue.Poster, issue.Content) + mentions, err := issue.FindAndUpdateIssueMentions(db.DefaultContext, issue.Poster, issue.Content) if err != nil { return err } diff --git a/services/mirror/mirror_pull.go b/services/mirror/mirror_pull.go index 1c0bf2b59511..c2b413131d19 100644 --- a/services/mirror/mirror_pull.go +++ b/services/mirror/mirror_pull.go @@ -204,7 +204,7 @@ func runSync(ctx context.Context, m *models.Mirror) ([]*mirrorSyncResult, bool) gitRepo.Close() log.Trace("SyncMirrors [repo: %-v]: updating size of repository", m.Repo) - if err := m.Repo.UpdateSize(db.DefaultContext()); err != nil { + if err := m.Repo.UpdateSize(db.DefaultContext); err != nil { log.Error("Failed to update size for mirror repository: %v", err) } diff --git a/services/pull/pull.go b/services/pull/pull.go index d78d9b1bd004..bd5551b6dcc0 100644 --- a/services/pull/pull.go +++ b/services/pull/pull.go @@ -59,7 +59,7 @@ func NewPullRequest(repo *models.Repository, pull *models.Issue, labelIDs []int6 return err } - mentions, err := pull.FindAndUpdateIssueMentions(db.DefaultContext(), pull.Poster, pull.Content) + mentions, err := pull.FindAndUpdateIssueMentions(db.DefaultContext, pull.Poster, pull.Content) if err != nil { return err } diff --git a/services/pull/review.go b/services/pull/review.go index ce34cc59dfd6..f65314c45d10 100644 --- a/services/pull/review.go +++ b/services/pull/review.go @@ -59,7 +59,7 @@ func CreateCodeComment(doer *models.User, gitRepo *git.Repository, issue *models return nil, err } - mentions, err := issue.FindAndUpdateIssueMentions(db.DefaultContext(), doer, comment.Content) + mentions, err := issue.FindAndUpdateIssueMentions(db.DefaultContext, doer, comment.Content) if err != nil { return nil, err } @@ -246,7 +246,7 @@ func SubmitReview(doer *models.User, gitRepo *git.Repository, issue *models.Issu return nil, nil, err } - ctx := db.DefaultContext() + ctx := db.DefaultContext mentions, err := issue.FindAndUpdateIssueMentions(ctx, doer, comm.Content) if err != nil { return nil, nil, err diff --git a/services/release/release.go b/services/release/release.go index a9181398b59a..f6f456e8fa37 100644 --- a/services/release/release.go +++ b/services/release/release.go @@ -123,7 +123,7 @@ func CreateRelease(gitRepo *git.Repository, rel *models.Release, attachmentUUIDs return err } - if err = models.AddReleaseAttachments(db.DefaultContext(), rel.ID, attachmentUUIDs); err != nil { + if err = models.AddReleaseAttachments(db.DefaultContext, rel.ID, attachmentUUIDs); err != nil { return err } @@ -311,7 +311,7 @@ func DeleteReleaseByID(id int64, doer *models.User, delTag bool) error { } else { rel.IsTag = true - if err = models.UpdateRelease(db.DefaultContext(), rel); err != nil { + if err = models.UpdateRelease(db.DefaultContext, rel); err != nil { return fmt.Errorf("Update: %v", err) } } diff --git a/services/repository/generate.go b/services/repository/generate.go index 5a9f6aa6584b..fe38723dea35 100644 --- a/services/repository/generate.go +++ b/services/repository/generate.go @@ -5,6 +5,8 @@ package repository import ( + "context" + "code.gitea.io/gitea/models" "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/log" @@ -21,7 +23,7 @@ func GenerateRepository(doer, owner *models.User, templateRepo *models.Repositor } var generateRepo *models.Repository - if err = db.WithTx(func(ctx *db.Context) error { + if err = db.WithTx(func(ctx context.Context) error { generateRepo, err = repo_module.GenerateRepository(ctx, doer, owner, templateRepo, opts) if err != nil { return err diff --git a/services/repository/push.go b/services/repository/push.go index f7590d5787dc..4d86667539c4 100644 --- a/services/repository/push.go +++ b/services/repository/push.go @@ -83,7 +83,7 @@ func pushUpdates(optsList []*repo_module.PushUpdateOptions) error { } defer gitRepo.Close() - if err = repo.UpdateSize(db.DefaultContext()); err != nil { + if err = repo.UpdateSize(db.DefaultContext); err != nil { log.Error("Failed to update size for repository: %v", err) } From 4a2655098fd1a594c7d33a144932bb5ec2fd7cd9 Mon Sep 17 00:00:00 2001 From: crapStone Date: Thu, 23 Sep 2021 18:57:52 +0200 Subject: [PATCH 05/13] Unify issue and pr subtitles (#17133) --- options/locale/locale_en-US.ini | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/options/locale/locale_en-US.ini b/options/locale/locale_en-US.ini index b38b249b2cf1..2ee0f5b56664 100644 --- a/options/locale/locale_en-US.ini +++ b/options/locale/locale_en-US.ini @@ -1197,11 +1197,11 @@ issues.action_milestone_no_select = No milestone issues.action_assignee = Assignee issues.action_assignee_no_select = No assignee issues.opened_by = opened %[1]s by %[3]s -pulls.merged_by = by %[3]s merged %[1]s -pulls.merged_by_fake = by %[2]s merged %[1]s -issues.closed_by = by %[3]s closed %[1]s -issues.opened_by_fake = by %[2]s opened %[1]s -issues.closed_by_fake = by %[2]s closed %[1]s +pulls.merged_by = merged %[1]s by %[3]s +pulls.merged_by_fake = merged %[1]s by %[2]s +issues.closed_by = closed %[1]s by %[3]s +issues.opened_by_fake = opened %[1]s by %[2]s +issues.closed_by_fake = closed %[1]s by %[2]s issues.previous = Previous issues.next = Next issues.open_title = Open From 5842a55b3103d3f09751eb7b3b049415197debad Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 24 Sep 2021 19:32:56 +0800 Subject: [PATCH 06/13] Move login related structs and functions to models/login (#17093) * Move login related structs and functions to models/login * Fix test * Fix lint * Fix lint * Fix lint of windows * Fix lint * Fix test * Fix test * Only load necessary fixtures when preparing unit tests envs * Fix lint * Fix test * Fix test * Fix error log * Fix error log * Fix error log * remove unnecessary change * fix error log * merge main branch --- .golangci.yml | 1 - cmd/admin.go | 21 +- cmd/admin_auth_ldap.go | 32 +-- cmd/admin_auth_ldap_test.go | 242 ++++++++--------- contrib/fixtures/fixture_generation.go | 4 +- contrib/pr/checkout.go | 4 +- integrations/api_oauth2_apps_test.go | 27 +- integrations/integration_test.go | 4 +- models/access.go | 2 +- models/commit_status.go | 14 +- models/commit_status_test.go | 2 +- models/{ => db}/list_options.go | 22 +- models/{ => db}/list_options_test.go | 2 +- models/db/test_fixtures.go | 9 +- models/db/unit_tests.go | 24 +- models/error.go | 91 ------- models/external_login_user.go | 3 +- models/gpg_key.go | 6 +- models/gpg_key_commit_verification.go | 3 +- models/issue.go | 2 +- models/issue_comment.go | 4 +- models/issue_label.go | 12 +- models/issue_label_test.go | 8 +- models/issue_milestone.go | 4 +- models/issue_milestone_test.go | 4 +- models/issue_reaction.go | 6 +- models/issue_stopwatch.go | 4 +- models/issue_test.go | 6 +- models/issue_tracked_time.go | 4 +- models/issue_watch.go | 6 +- models/issue_watch_test.go | 8 +- models/login/main_test.go | 21 ++ models/login/oauth2.go | 70 +++++ models/{ => login}/oauth2_application.go | 31 +-- models/{ => login}/oauth2_application_test.go | 10 +- models/{login_source.go => login/source.go} | 255 ++++++++++-------- .../source_test.go} | 10 +- models/migrations/migrations_test.go | 5 +- models/notification.go | 4 +- models/oauth2.go | 27 -- models/org.go | 6 +- models/org_team.go | 8 +- models/org_test.go | 4 +- models/pull_list.go | 4 +- models/pull_sign.go | 3 +- models/pull_test.go | 4 +- models/release.go | 4 +- models/repo.go | 6 +- models/repo_collaboration.go | 8 +- models/repo_collaboration_test.go | 2 +- models/repo_generate.go | 2 +- models/repo_list.go | 2 +- models/repo_list_test.go | 68 ++--- models/repo_sign.go | 7 +- models/repo_transfer.go | 2 +- models/repo_unit.go | 3 +- models/repo_watch.go | 4 +- models/repo_watch_test.go | 20 +- models/review.go | 4 +- models/ssh_key.go | 27 +- models/ssh_key_deploy.go | 4 +- models/ssh_key_principals.go | 4 +- models/star.go | 4 +- models/star_test.go | 4 +- models/statistic.go | 7 +- models/token.go | 4 +- models/topic.go | 4 +- models/topic_test.go | 2 +- models/user.go | 35 +-- models/user_mail.go | 4 +- models/user_mail_test.go | 4 +- models/user_test.go | 31 ++- models/webhook.go | 4 +- modules/convert/convert.go | 5 +- modules/gitgraph/graph_models.go | 3 +- modules/indexer/issues/indexer.go | 3 +- modules/migrations/gitea_uploader_test.go | 6 +- modules/repository/adopt.go | 35 ++- modules/repository/repo.go | 6 +- routers/api/v1/admin/user.go | 9 +- routers/api/v1/repo/issue.go | 7 +- routers/api/v1/user/app.go | 17 +- routers/api/v1/user/gpg_key.go | 3 +- routers/api/v1/user/star.go | 3 +- routers/api/v1/user/watch.go | 3 +- routers/api/v1/utils/utils.go | 6 +- routers/web/admin/auths.go | 82 +++--- routers/web/admin/emails.go | 3 +- routers/web/admin/orgs.go | 3 +- routers/web/admin/repos.go | 3 +- routers/web/admin/users.go | 30 ++- routers/web/events/events.go | 3 +- routers/web/explore/org.go | 3 +- routers/web/explore/repo.go | 3 +- routers/web/explore/user.go | 3 +- routers/web/org/home.go | 5 +- routers/web/org/org_labels.go | 2 +- routers/web/org/setting.go | 3 +- routers/web/repo/commit.go | 3 +- routers/web/repo/issue.go | 19 +- routers/web/repo/issue_label.go | 4 +- routers/web/repo/milestone.go | 3 +- routers/web/repo/pull.go | 7 +- routers/web/repo/release.go | 3 +- routers/web/repo/setting.go | 3 +- routers/web/repo/view.go | 9 +- routers/web/user/auth.go | 21 +- routers/web/user/home.go | 5 +- routers/web/user/oauth.go | 29 +- routers/web/user/oauth_test.go | 7 +- routers/web/user/profile.go | 11 +- routers/web/user/setting/applications.go | 5 +- routers/web/user/setting/keys.go | 7 +- routers/web/user/setting/oauth2.go | 18 +- routers/web/user/setting/profile.go | 5 +- routers/web/user/setting/security.go | 5 +- services/auth/login_source.go | 41 +++ services/auth/oauth2.go | 5 +- services/auth/signin.go | 7 +- .../auth/source/db/assert_interface_test.go | 4 +- services/auth/source/db/source.go | 9 +- .../auth/source/ldap/assert_interface_test.go | 14 +- services/auth/source/ldap/source.go | 9 +- .../auth/source/ldap/source_authenticate.go | 11 +- .../source/oauth2/assert_interface_test.go | 8 +- services/auth/source/oauth2/init.go | 4 +- services/auth/source/oauth2/providers.go | 3 +- services/auth/source/oauth2/source.go | 7 +- .../auth/source/pam/assert_interface_test.go | 6 +- services/auth/source/pam/source.go | 7 +- .../auth/source/pam/source_authenticate.go | 11 +- .../auth/source/smtp/assert_interface_test.go | 12 +- services/auth/source/smtp/source.go | 7 +- .../auth/source/smtp/source_authenticate.go | 33 +-- .../auth/source/sspi/assert_interface_test.go | 4 +- services/auth/source/sspi/source.go | 3 +- services/auth/sspi_windows.go | 5 +- services/auth/sync.go | 3 +- services/externalaccount/user.go | 3 +- services/pull/commit_status.go | 3 +- services/pull/pull.go | 2 +- services/pull/review.go | 2 +- 142 files changed, 1046 insertions(+), 903 deletions(-) rename models/{ => db}/list_options.go (78%) rename models/{ => db}/list_options_test.go (98%) create mode 100644 models/login/main_test.go create mode 100644 models/login/oauth2.go rename models/{ => login}/oauth2_application.go (96%) rename models/{ => login}/oauth2_application_test.go (96%) rename models/{login_source.go => login/source.go} (55%) rename models/{login_source_test.go => login/source_test.go} (86%) delete mode 100644 models/oauth2.go create mode 100644 services/auth/login_source.go diff --git a/.golangci.yml b/.golangci.yml index c3dd47ec29da..2d66e01ffaf8 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -111,4 +111,3 @@ issues: linters: - staticcheck text: "svc.IsAnInteractiveSession is deprecated: Use IsWindowsService instead." - diff --git a/cmd/admin.go b/cmd/admin.go index cfc297c47464..099083ae9100 100644 --- a/cmd/admin.go +++ b/cmd/admin.go @@ -14,6 +14,8 @@ import ( "text/tabwriter" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/graceful" "code.gitea.io/gitea/modules/log" @@ -21,6 +23,7 @@ import ( repo_module "code.gitea.io/gitea/modules/repository" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/storage" + auth_service "code.gitea.io/gitea/services/auth" "code.gitea.io/gitea/services/auth/source/oauth2" "github.com/urfave/cli" @@ -529,7 +532,7 @@ func runRepoSyncReleases(_ *cli.Context) error { log.Trace("Synchronizing repository releases (this may take a while)") for page := 1; ; page++ { repos, count, err := models.SearchRepositoryByName(&models.SearchRepoOptions{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ PageSize: models.RepositoryListDefaultPageSize, Page: page, }, @@ -629,8 +632,8 @@ func runAddOauth(c *cli.Context) error { return err } - return models.CreateLoginSource(&models.LoginSource{ - Type: models.LoginOAuth2, + return login.CreateSource(&login.Source{ + Type: login.OAuth2, Name: c.String("name"), IsActive: true, Cfg: parseOAuth2Config(c), @@ -646,7 +649,7 @@ func runUpdateOauth(c *cli.Context) error { return err } - source, err := models.GetLoginSourceByID(c.Int64("id")) + source, err := login.GetSourceByID(c.Int64("id")) if err != nil { return err } @@ -705,7 +708,7 @@ func runUpdateOauth(c *cli.Context) error { oAuth2Config.CustomURLMapping = customURLMapping source.Cfg = oAuth2Config - return models.UpdateSource(source) + return login.UpdateSource(source) } func runListAuth(c *cli.Context) error { @@ -713,7 +716,7 @@ func runListAuth(c *cli.Context) error { return err } - loginSources, err := models.LoginSources() + loginSources, err := login.Sources() if err != nil { return err @@ -733,7 +736,7 @@ func runListAuth(c *cli.Context) error { w := tabwriter.NewWriter(os.Stdout, c.Int("min-width"), c.Int("tab-width"), c.Int("padding"), padChar, flags) fmt.Fprintf(w, "ID\tName\tType\tEnabled\n") for _, source := range loginSources { - fmt.Fprintf(w, "%d\t%s\t%s\t%t\n", source.ID, source.Name, models.LoginNames[source.Type], source.IsActive) + fmt.Fprintf(w, "%d\t%s\t%s\t%t\n", source.ID, source.Name, source.Type.String(), source.IsActive) } w.Flush() @@ -749,10 +752,10 @@ func runDeleteAuth(c *cli.Context) error { return err } - source, err := models.GetLoginSourceByID(c.Int64("id")) + source, err := login.GetSourceByID(c.Int64("id")) if err != nil { return err } - return models.DeleteSource(source) + return auth_service.DeleteLoginSource(source) } diff --git a/cmd/admin_auth_ldap.go b/cmd/admin_auth_ldap.go index feeaf17661f7..e95e1d15c64c 100644 --- a/cmd/admin_auth_ldap.go +++ b/cmd/admin_auth_ldap.go @@ -8,7 +8,7 @@ import ( "fmt" "strings" - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/services/auth/source/ldap" "github.com/urfave/cli" @@ -17,9 +17,9 @@ import ( type ( authService struct { initDB func() error - createLoginSource func(loginSource *models.LoginSource) error - updateLoginSource func(loginSource *models.LoginSource) error - getLoginSourceByID func(id int64) (*models.LoginSource, error) + createLoginSource func(loginSource *login.Source) error + updateLoginSource func(loginSource *login.Source) error + getLoginSourceByID func(id int64) (*login.Source, error) } ) @@ -164,14 +164,14 @@ var ( func newAuthService() *authService { return &authService{ initDB: initDB, - createLoginSource: models.CreateLoginSource, - updateLoginSource: models.UpdateSource, - getLoginSourceByID: models.GetLoginSourceByID, + createLoginSource: login.CreateSource, + updateLoginSource: login.UpdateSource, + getLoginSourceByID: login.GetSourceByID, } } // parseLoginSource assigns values on loginSource according to command line flags. -func parseLoginSource(c *cli.Context, loginSource *models.LoginSource) { +func parseLoginSource(c *cli.Context, loginSource *login.Source) { if c.IsSet("name") { loginSource.Name = c.String("name") } @@ -269,7 +269,7 @@ func findLdapSecurityProtocolByName(name string) (ldap.SecurityProtocol, bool) { // getLoginSource gets the login source by its id defined in the command line flags. // It returns an error if the id is not set, does not match any source or if the source is not of expected type. -func (a *authService) getLoginSource(c *cli.Context, loginType models.LoginType) (*models.LoginSource, error) { +func (a *authService) getLoginSource(c *cli.Context, loginType login.Type) (*login.Source, error) { if err := argsSet(c, "id"); err != nil { return nil, err } @@ -280,7 +280,7 @@ func (a *authService) getLoginSource(c *cli.Context, loginType models.LoginType) } if loginSource.Type != loginType { - return nil, fmt.Errorf("Invalid authentication type. expected: %s, actual: %s", models.LoginNames[loginType], models.LoginNames[loginSource.Type]) + return nil, fmt.Errorf("Invalid authentication type. expected: %s, actual: %s", loginType.String(), loginSource.Type.String()) } return loginSource, nil @@ -296,8 +296,8 @@ func (a *authService) addLdapBindDn(c *cli.Context) error { return err } - loginSource := &models.LoginSource{ - Type: models.LoginLDAP, + loginSource := &login.Source{ + Type: login.LDAP, IsActive: true, // active by default Cfg: &ldap.Source{ Enabled: true, // always true @@ -318,7 +318,7 @@ func (a *authService) updateLdapBindDn(c *cli.Context) error { return err } - loginSource, err := a.getLoginSource(c, models.LoginLDAP) + loginSource, err := a.getLoginSource(c, login.LDAP) if err != nil { return err } @@ -341,8 +341,8 @@ func (a *authService) addLdapSimpleAuth(c *cli.Context) error { return err } - loginSource := &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource := &login.Source{ + Type: login.DLDAP, IsActive: true, // active by default Cfg: &ldap.Source{ Enabled: true, // always true @@ -363,7 +363,7 @@ func (a *authService) updateLdapSimpleAuth(c *cli.Context) error { return err } - loginSource, err := a.getLoginSource(c, models.LoginDLDAP) + loginSource, err := a.getLoginSource(c, login.DLDAP) if err != nil { return err } diff --git a/cmd/admin_auth_ldap_test.go b/cmd/admin_auth_ldap_test.go index 692b11e3f422..c26cbdaf39c6 100644 --- a/cmd/admin_auth_ldap_test.go +++ b/cmd/admin_auth_ldap_test.go @@ -7,7 +7,7 @@ package cmd import ( "testing" - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/services/auth/source/ldap" "github.com/stretchr/testify/assert" @@ -23,7 +23,7 @@ func TestAddLdapBindDn(t *testing.T) { // Test cases var cases = []struct { args []string - loginSource *models.LoginSource + loginSource *login.Source errMsg string }{ // case 0 @@ -51,8 +51,8 @@ func TestAddLdapBindDn(t *testing.T) { "--synchronize-users", "--page-size", "99", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Name: "ldap (via Bind DN) source full", IsActive: false, IsSyncEnabled: true, @@ -91,8 +91,8 @@ func TestAddLdapBindDn(t *testing.T) { "--user-filter", "(memberOf=cn=user-group,ou=example,dc=min-domain-bind,dc=org)", "--email-attribute", "mail-bind min", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Name: "ldap (via Bind DN) source min", IsActive: true, Cfg: &ldap.Source{ @@ -203,20 +203,20 @@ func TestAddLdapBindDn(t *testing.T) { for n, c := range cases { // Mock functions. - var createdLoginSource *models.LoginSource + var createdLoginSource *login.Source service := &authService{ initDB: func() error { return nil }, - createLoginSource: func(loginSource *models.LoginSource) error { + createLoginSource: func(loginSource *login.Source) error { createdLoginSource = loginSource return nil }, - updateLoginSource: func(loginSource *models.LoginSource) error { + updateLoginSource: func(loginSource *login.Source) error { assert.FailNow(t, "case %d: should not call updateLoginSource", n) return nil }, - getLoginSourceByID: func(id int64) (*models.LoginSource, error) { + getLoginSourceByID: func(id int64) (*login.Source, error) { assert.FailNow(t, "case %d: should not call getLoginSourceByID", n) return nil, nil }, @@ -247,7 +247,7 @@ func TestAddLdapSimpleAuth(t *testing.T) { // Test cases var cases = []struct { args []string - loginSource *models.LoginSource + loginSource *login.Source errMsg string }{ // case 0 @@ -271,8 +271,8 @@ func TestAddLdapSimpleAuth(t *testing.T) { "--public-ssh-key-attribute", "publickey-simple full", "--user-dn", "cn=%s,ou=Users,dc=full-domain-simple,dc=org", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Name: "ldap (simple auth) source full", IsActive: false, Cfg: &ldap.Source{ @@ -307,8 +307,8 @@ func TestAddLdapSimpleAuth(t *testing.T) { "--email-attribute", "mail-simple min", "--user-dn", "cn=%s,ou=Users,dc=min-domain-simple,dc=org", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Name: "ldap (simple auth) source min", IsActive: true, Cfg: &ldap.Source{ @@ -432,20 +432,20 @@ func TestAddLdapSimpleAuth(t *testing.T) { for n, c := range cases { // Mock functions. - var createdLoginSource *models.LoginSource + var createdLoginSource *login.Source service := &authService{ initDB: func() error { return nil }, - createLoginSource: func(loginSource *models.LoginSource) error { + createLoginSource: func(loginSource *login.Source) error { createdLoginSource = loginSource return nil }, - updateLoginSource: func(loginSource *models.LoginSource) error { + updateLoginSource: func(loginSource *login.Source) error { assert.FailNow(t, "case %d: should not call updateLoginSource", n) return nil }, - getLoginSourceByID: func(id int64) (*models.LoginSource, error) { + getLoginSourceByID: func(id int64) (*login.Source, error) { assert.FailNow(t, "case %d: should not call getLoginSourceByID", n) return nil, nil }, @@ -477,8 +477,8 @@ func TestUpdateLdapBindDn(t *testing.T) { var cases = []struct { args []string id int64 - existingLoginSource *models.LoginSource - loginSource *models.LoginSource + existingLoginSource *login.Source + loginSource *login.Source errMsg string }{ // case 0 @@ -507,15 +507,15 @@ func TestUpdateLdapBindDn(t *testing.T) { "--page-size", "99", }, id: 23, - existingLoginSource: &models.LoginSource{ - Type: models.LoginLDAP, + existingLoginSource: &login.Source{ + Type: login.LDAP, IsActive: true, Cfg: &ldap.Source{ Enabled: true, }, }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Name: "ldap (via Bind DN) source full", IsActive: false, IsSyncEnabled: true, @@ -548,8 +548,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "ldap-test", "--id", "1", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{}, }, }, @@ -560,8 +560,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--name", "ldap (via Bind DN) source", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Name: "ldap (via Bind DN) source", Cfg: &ldap.Source{ Name: "ldap (via Bind DN) source", @@ -575,13 +575,13 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--not-active", }, - existingLoginSource: &models.LoginSource{ - Type: models.LoginLDAP, + existingLoginSource: &login.Source{ + Type: login.LDAP, IsActive: true, Cfg: &ldap.Source{}, }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, IsActive: false, Cfg: &ldap.Source{}, }, @@ -593,8 +593,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--security-protocol", "LDAPS", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ SecurityProtocol: ldap.SecurityProtocol(1), }, @@ -607,8 +607,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--skip-tls-verify", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ SkipVerify: true, }, @@ -621,8 +621,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--host", "ldap-server", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ Host: "ldap-server", }, @@ -635,8 +635,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--port", "389", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ Port: 389, }, @@ -649,8 +649,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--user-search-base", "ou=Users,dc=domain,dc=org", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ UserBase: "ou=Users,dc=domain,dc=org", }, @@ -663,8 +663,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--user-filter", "(memberOf=cn=user-group,ou=example,dc=domain,dc=org)", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ Filter: "(memberOf=cn=user-group,ou=example,dc=domain,dc=org)", }, @@ -677,8 +677,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--admin-filter", "(memberOf=cn=admin-group,ou=example,dc=domain,dc=org)", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ AdminFilter: "(memberOf=cn=admin-group,ou=example,dc=domain,dc=org)", }, @@ -691,8 +691,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--username-attribute", "uid", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ AttributeUsername: "uid", }, @@ -705,8 +705,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--firstname-attribute", "givenName", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ AttributeName: "givenName", }, @@ -719,8 +719,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--surname-attribute", "sn", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ AttributeSurname: "sn", }, @@ -733,8 +733,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--email-attribute", "mail", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ AttributeMail: "mail", }, @@ -747,8 +747,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--attributes-in-bind", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ AttributesInBind: true, }, @@ -761,8 +761,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--public-ssh-key-attribute", "publickey", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ AttributeSSHPublicKey: "publickey", }, @@ -775,8 +775,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--bind-dn", "cn=readonly,dc=domain,dc=org", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ BindDN: "cn=readonly,dc=domain,dc=org", }, @@ -789,8 +789,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--bind-password", "secret", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ BindPassword: "secret", }, @@ -803,8 +803,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--synchronize-users", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, IsSyncEnabled: true, Cfg: &ldap.Source{}, }, @@ -816,8 +816,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "--id", "1", "--page-size", "12", }, - loginSource: &models.LoginSource{ - Type: models.LoginLDAP, + loginSource: &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{ SearchPageSize: 12, }, @@ -845,8 +845,8 @@ func TestUpdateLdapBindDn(t *testing.T) { "ldap-test", "--id", "1", }, - existingLoginSource: &models.LoginSource{ - Type: models.LoginOAuth2, + existingLoginSource: &login.Source{ + Type: login.OAuth2, Cfg: &ldap.Source{}, }, errMsg: "Invalid authentication type. expected: LDAP (via BindDN), actual: OAuth2", @@ -855,28 +855,28 @@ func TestUpdateLdapBindDn(t *testing.T) { for n, c := range cases { // Mock functions. - var updatedLoginSource *models.LoginSource + var updatedLoginSource *login.Source service := &authService{ initDB: func() error { return nil }, - createLoginSource: func(loginSource *models.LoginSource) error { + createLoginSource: func(loginSource *login.Source) error { assert.FailNow(t, "case %d: should not call createLoginSource", n) return nil }, - updateLoginSource: func(loginSource *models.LoginSource) error { + updateLoginSource: func(loginSource *login.Source) error { updatedLoginSource = loginSource return nil }, - getLoginSourceByID: func(id int64) (*models.LoginSource, error) { + getLoginSourceByID: func(id int64) (*login.Source, error) { if c.id != 0 { assert.Equal(t, c.id, id, "case %d: wrong id", n) } if c.existingLoginSource != nil { return c.existingLoginSource, nil } - return &models.LoginSource{ - Type: models.LoginLDAP, + return &login.Source{ + Type: login.LDAP, Cfg: &ldap.Source{}, }, nil }, @@ -908,8 +908,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { var cases = []struct { args []string id int64 - existingLoginSource *models.LoginSource - loginSource *models.LoginSource + existingLoginSource *login.Source + loginSource *login.Source errMsg string }{ // case 0 @@ -935,8 +935,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--user-dn", "cn=%s,ou=Users,dc=full-domain-simple,dc=org", }, id: 7, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Name: "ldap (simple auth) source full", IsActive: false, Cfg: &ldap.Source{ @@ -964,8 +964,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "ldap-test", "--id", "1", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{}, }, }, @@ -976,8 +976,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--name", "ldap (simple auth) source", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Name: "ldap (simple auth) source", Cfg: &ldap.Source{ Name: "ldap (simple auth) source", @@ -991,13 +991,13 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--not-active", }, - existingLoginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + existingLoginSource: &login.Source{ + Type: login.DLDAP, IsActive: true, Cfg: &ldap.Source{}, }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, IsActive: false, Cfg: &ldap.Source{}, }, @@ -1009,8 +1009,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--security-protocol", "starttls", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ SecurityProtocol: ldap.SecurityProtocol(2), }, @@ -1023,8 +1023,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--skip-tls-verify", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ SkipVerify: true, }, @@ -1037,8 +1037,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--host", "ldap-server", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ Host: "ldap-server", }, @@ -1051,8 +1051,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--port", "987", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ Port: 987, }, @@ -1065,8 +1065,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--user-search-base", "ou=Users,dc=domain,dc=org", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ UserBase: "ou=Users,dc=domain,dc=org", }, @@ -1079,8 +1079,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--user-filter", "(&(objectClass=posixAccount)(cn=%s))", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ Filter: "(&(objectClass=posixAccount)(cn=%s))", }, @@ -1093,8 +1093,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--admin-filter", "(memberOf=cn=admin-group,ou=example,dc=domain,dc=org)", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ AdminFilter: "(memberOf=cn=admin-group,ou=example,dc=domain,dc=org)", }, @@ -1107,8 +1107,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--username-attribute", "uid", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ AttributeUsername: "uid", }, @@ -1121,8 +1121,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--firstname-attribute", "givenName", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ AttributeName: "givenName", }, @@ -1135,8 +1135,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--surname-attribute", "sn", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ AttributeSurname: "sn", }, @@ -1149,8 +1149,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--email-attribute", "mail", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ AttributeMail: "mail", @@ -1164,8 +1164,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--public-ssh-key-attribute", "publickey", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ AttributeSSHPublicKey: "publickey", }, @@ -1178,8 +1178,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "--id", "1", "--user-dn", "cn=%s,ou=Users,dc=domain,dc=org", }, - loginSource: &models.LoginSource{ - Type: models.LoginDLDAP, + loginSource: &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{ UserDN: "cn=%s,ou=Users,dc=domain,dc=org", }, @@ -1207,8 +1207,8 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { "ldap-test", "--id", "1", }, - existingLoginSource: &models.LoginSource{ - Type: models.LoginPAM, + existingLoginSource: &login.Source{ + Type: login.PAM, Cfg: &ldap.Source{}, }, errMsg: "Invalid authentication type. expected: LDAP (simple auth), actual: PAM", @@ -1217,28 +1217,28 @@ func TestUpdateLdapSimpleAuth(t *testing.T) { for n, c := range cases { // Mock functions. - var updatedLoginSource *models.LoginSource + var updatedLoginSource *login.Source service := &authService{ initDB: func() error { return nil }, - createLoginSource: func(loginSource *models.LoginSource) error { + createLoginSource: func(loginSource *login.Source) error { assert.FailNow(t, "case %d: should not call createLoginSource", n) return nil }, - updateLoginSource: func(loginSource *models.LoginSource) error { + updateLoginSource: func(loginSource *login.Source) error { updatedLoginSource = loginSource return nil }, - getLoginSourceByID: func(id int64) (*models.LoginSource, error) { + getLoginSourceByID: func(id int64) (*login.Source, error) { if c.id != 0 { assert.Equal(t, c.id, id, "case %d: wrong id", n) } if c.existingLoginSource != nil { return c.existingLoginSource, nil } - return &models.LoginSource{ - Type: models.LoginDLDAP, + return &login.Source{ + Type: login.DLDAP, Cfg: &ldap.Source{}, }, nil }, diff --git a/contrib/fixtures/fixture_generation.go b/contrib/fixtures/fixture_generation.go index 5408a005c663..5e7dd39a78fb 100644 --- a/contrib/fixtures/fixture_generation.go +++ b/contrib/fixtures/fixture_generation.go @@ -31,7 +31,9 @@ var ( func main() { pathToGiteaRoot := "." fixturesDir = filepath.Join(pathToGiteaRoot, "models", "fixtures") - if err := db.CreateTestEngine(fixturesDir); err != nil { + if err := db.CreateTestEngine(db.FixturesOptions{ + Dir: fixturesDir, + }); err != nil { fmt.Printf("CreateTestEngine: %+v", err) os.Exit(1) } diff --git a/contrib/pr/checkout.go b/contrib/pr/checkout.go index cba6d4d372da..d831ebdabdb2 100644 --- a/contrib/pr/checkout.go +++ b/contrib/pr/checkout.go @@ -101,7 +101,9 @@ func runPR() { db.HasEngine = true //x.ShowSQL(true) err = db.InitFixtures( - path.Join(curDir, "models/fixtures/"), + db.FixturesOptions{ + Dir: path.Join(curDir, "models/fixtures/"), + }, ) if err != nil { fmt.Printf("Error initializing test database: %v\n", err) diff --git a/integrations/api_oauth2_apps_test.go b/integrations/api_oauth2_apps_test.go index 4cda41755a6c..6f0a46249f51 100644 --- a/integrations/api_oauth2_apps_test.go +++ b/integrations/api_oauth2_apps_test.go @@ -11,6 +11,7 @@ import ( "code.gitea.io/gitea/models" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" api "code.gitea.io/gitea/modules/structs" "github.com/stretchr/testify/assert" @@ -46,7 +47,7 @@ func testAPICreateOAuth2Application(t *testing.T) { assert.Len(t, createdApp.ClientID, 36) assert.NotEmpty(t, createdApp.Created) assert.EqualValues(t, appBody.RedirectURIs[0], createdApp.RedirectURIs[0]) - db.AssertExistsAndLoadBean(t, &models.OAuth2Application{UID: user.ID, Name: createdApp.Name}) + db.AssertExistsAndLoadBean(t, &login.OAuth2Application{UID: user.ID, Name: createdApp.Name}) } func testAPIListOAuth2Applications(t *testing.T) { @@ -54,13 +55,13 @@ func testAPIListOAuth2Applications(t *testing.T) { session := loginUser(t, user.Name) token := getTokenForLoggedInUser(t, session) - existApp := db.AssertExistsAndLoadBean(t, &models.OAuth2Application{ + existApp := db.AssertExistsAndLoadBean(t, &login.OAuth2Application{ UID: user.ID, Name: "test-app-1", RedirectURIs: []string{ "http://www.google.com", }, - }).(*models.OAuth2Application) + }).(*login.OAuth2Application) urlStr := fmt.Sprintf("/api/v1/user/applications/oauth2?token=%s", token) req := NewRequest(t, "GET", urlStr) @@ -75,7 +76,7 @@ func testAPIListOAuth2Applications(t *testing.T) { assert.Len(t, expectedApp.ClientID, 36) assert.Empty(t, expectedApp.ClientSecret) assert.EqualValues(t, existApp.RedirectURIs[0], expectedApp.RedirectURIs[0]) - db.AssertExistsAndLoadBean(t, &models.OAuth2Application{ID: expectedApp.ID, Name: expectedApp.Name}) + db.AssertExistsAndLoadBean(t, &login.OAuth2Application{ID: expectedApp.ID, Name: expectedApp.Name}) } func testAPIDeleteOAuth2Application(t *testing.T) { @@ -83,16 +84,16 @@ func testAPIDeleteOAuth2Application(t *testing.T) { session := loginUser(t, user.Name) token := getTokenForLoggedInUser(t, session) - oldApp := db.AssertExistsAndLoadBean(t, &models.OAuth2Application{ + oldApp := db.AssertExistsAndLoadBean(t, &login.OAuth2Application{ UID: user.ID, Name: "test-app-1", - }).(*models.OAuth2Application) + }).(*login.OAuth2Application) urlStr := fmt.Sprintf("/api/v1/user/applications/oauth2/%d?token=%s", oldApp.ID, token) req := NewRequest(t, "DELETE", urlStr) session.MakeRequest(t, req, http.StatusNoContent) - db.AssertNotExistsBean(t, &models.OAuth2Application{UID: oldApp.UID, Name: oldApp.Name}) + db.AssertNotExistsBean(t, &login.OAuth2Application{UID: oldApp.UID, Name: oldApp.Name}) // Delete again will return not found req = NewRequest(t, "DELETE", urlStr) @@ -104,13 +105,13 @@ func testAPIGetOAuth2Application(t *testing.T) { session := loginUser(t, user.Name) token := getTokenForLoggedInUser(t, session) - existApp := db.AssertExistsAndLoadBean(t, &models.OAuth2Application{ + existApp := db.AssertExistsAndLoadBean(t, &login.OAuth2Application{ UID: user.ID, Name: "test-app-1", RedirectURIs: []string{ "http://www.google.com", }, - }).(*models.OAuth2Application) + }).(*login.OAuth2Application) urlStr := fmt.Sprintf("/api/v1/user/applications/oauth2/%d?token=%s", existApp.ID, token) req := NewRequest(t, "GET", urlStr) @@ -126,19 +127,19 @@ func testAPIGetOAuth2Application(t *testing.T) { assert.Empty(t, expectedApp.ClientSecret) assert.Len(t, expectedApp.RedirectURIs, 1) assert.EqualValues(t, existApp.RedirectURIs[0], expectedApp.RedirectURIs[0]) - db.AssertExistsAndLoadBean(t, &models.OAuth2Application{ID: expectedApp.ID, Name: expectedApp.Name}) + db.AssertExistsAndLoadBean(t, &login.OAuth2Application{ID: expectedApp.ID, Name: expectedApp.Name}) } func testAPIUpdateOAuth2Application(t *testing.T) { user := db.AssertExistsAndLoadBean(t, &models.User{ID: 2}).(*models.User) - existApp := db.AssertExistsAndLoadBean(t, &models.OAuth2Application{ + existApp := db.AssertExistsAndLoadBean(t, &login.OAuth2Application{ UID: user.ID, Name: "test-app-1", RedirectURIs: []string{ "http://www.google.com", }, - }).(*models.OAuth2Application) + }).(*login.OAuth2Application) appBody := api.CreateOAuth2ApplicationOptions{ Name: "test-app-1", @@ -160,5 +161,5 @@ func testAPIUpdateOAuth2Application(t *testing.T) { assert.Len(t, expectedApp.RedirectURIs, 2) assert.EqualValues(t, expectedApp.RedirectURIs[0], appBody.RedirectURIs[0]) assert.EqualValues(t, expectedApp.RedirectURIs[1], appBody.RedirectURIs[1]) - db.AssertExistsAndLoadBean(t, &models.OAuth2Application{ID: expectedApp.ID, Name: expectedApp.Name}) + db.AssertExistsAndLoadBean(t, &login.OAuth2Application{ID: expectedApp.ID, Name: expectedApp.Name}) } diff --git a/integrations/integration_test.go b/integrations/integration_test.go index fac36320cf88..1429893270b0 100644 --- a/integrations/integration_test.go +++ b/integrations/integration_test.go @@ -113,7 +113,9 @@ func TestMain(m *testing.M) { } err := db.InitFixtures( - path.Join(filepath.Dir(setting.AppPath), "models/fixtures/"), + db.FixturesOptions{ + Dir: filepath.Join(filepath.Dir(setting.AppPath), "models/fixtures/"), + }, ) if err != nil { fmt.Printf("Error initializing test database: %v\n", err) diff --git a/models/access.go b/models/access.go index 88fbe8189fa8..560234aae807 100644 --- a/models/access.go +++ b/models/access.go @@ -225,7 +225,7 @@ func (repo *Repository) refreshAccesses(e db.Engine, accessMap map[int64]*userAc // refreshCollaboratorAccesses retrieves repository collaborations with their access modes. func (repo *Repository) refreshCollaboratorAccesses(e db.Engine, accessMap map[int64]*userAccess) error { - collaborators, err := repo.getCollaborators(e, ListOptions{}) + collaborators, err := repo.getCollaborators(e, db.ListOptions{}) if err != nil { return fmt.Errorf("getCollaborations: %v", err) } diff --git a/models/commit_status.go b/models/commit_status.go index ada94667cccc..a6ded049c31c 100644 --- a/models/commit_status.go +++ b/models/commit_status.go @@ -163,7 +163,7 @@ func CalcCommitStatus(statuses []*CommitStatus) *CommitStatus { // CommitStatusOptions holds the options for query commit statuses type CommitStatusOptions struct { - ListOptions + db.ListOptions State string SortType string } @@ -178,7 +178,7 @@ func GetCommitStatuses(repo *Repository, sha string, opts *CommitStatusOptions) } countSession := listCommitStatusesStatement(repo, sha, opts) - countSession = setSessionPagination(countSession, opts) + countSession = db.SetSessionPagination(countSession, opts) maxResults, err := countSession.Count(new(CommitStatus)) if err != nil { log.Error("Count PRs: %v", err) @@ -187,7 +187,7 @@ func GetCommitStatuses(repo *Repository, sha string, opts *CommitStatusOptions) statuses := make([]*CommitStatus, 0, opts.PageSize) findSession := listCommitStatusesStatement(repo, sha, opts) - findSession = setSessionPagination(findSession, opts) + findSession = db.SetSessionPagination(findSession, opts) sortCommitStatusesSession(findSession, opts.SortType) return statuses, maxResults, findSession.Find(&statuses) } @@ -227,18 +227,18 @@ type CommitStatusIndex struct { } // GetLatestCommitStatus returns all statuses with a unique context for a given commit. -func GetLatestCommitStatus(repoID int64, sha string, listOptions ListOptions) ([]*CommitStatus, error) { +func GetLatestCommitStatus(repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, error) { return getLatestCommitStatus(db.GetEngine(db.DefaultContext), repoID, sha, listOptions) } -func getLatestCommitStatus(e db.Engine, repoID int64, sha string, listOptions ListOptions) ([]*CommitStatus, error) { +func getLatestCommitStatus(e db.Engine, repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, error) { ids := make([]int64, 0, 10) sess := e.Table(&CommitStatus{}). Where("repo_id = ?", repoID).And("sha = ?", sha). Select("max( id ) as id"). GroupBy("context_hash").OrderBy("max( id ) desc") - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) err := sess.Find(&ids) if err != nil { @@ -336,7 +336,7 @@ func ParseCommitsWithStatus(oldCommits []*SignCommit, repo *Repository) []*SignC commit := &SignCommitWithStatuses{ SignCommit: c, } - statuses, err := GetLatestCommitStatus(repo.ID, commit.ID.String(), ListOptions{}) + statuses, err := GetLatestCommitStatus(repo.ID, commit.ID.String(), db.ListOptions{}) if err != nil { log.Error("GetLatestCommitStatus: %v", err) } else { diff --git a/models/commit_status_test.go b/models/commit_status_test.go index 0d8dbf264686..7f4709144ceb 100644 --- a/models/commit_status_test.go +++ b/models/commit_status_test.go @@ -19,7 +19,7 @@ func TestGetCommitStatuses(t *testing.T) { sha1 := "1234123412341234123412341234123412341234" - statuses, maxResults, err := GetCommitStatuses(repo1, sha1, &CommitStatusOptions{ListOptions: ListOptions{Page: 1, PageSize: 50}}) + statuses, maxResults, err := GetCommitStatuses(repo1, sha1, &CommitStatusOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 50}}) assert.NoError(t, err) assert.Equal(t, int(maxResults), 5) assert.Len(t, statuses, 5) diff --git a/models/list_options.go b/models/db/list_options.go similarity index 78% rename from models/list_options.go rename to models/db/list_options.go index 25b9a05f16e8..f31febfe25a1 100644 --- a/models/list_options.go +++ b/models/db/list_options.go @@ -2,10 +2,9 @@ // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. -package models +package db import ( - "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/setting" "xorm.io/xorm" @@ -17,22 +16,22 @@ type Paginator interface { GetStartEnd() (start, end int) } -// getPaginatedSession creates a paginated database session -func getPaginatedSession(p Paginator) *xorm.Session { +// GetPaginatedSession creates a paginated database session +func GetPaginatedSession(p Paginator) *xorm.Session { skip, take := p.GetSkipTake() - return db.GetEngine(db.DefaultContext).Limit(take, skip) + return x.Limit(take, skip) } -// setSessionPagination sets pagination for a database session -func setSessionPagination(sess *xorm.Session, p Paginator) *xorm.Session { +// SetSessionPagination sets pagination for a database session +func SetSessionPagination(sess *xorm.Session, p Paginator) *xorm.Session { skip, take := p.GetSkipTake() return sess.Limit(take, skip) } -// setSessionPagination sets pagination for a database engine -func setEnginePagination(e db.Engine, p Paginator) db.Engine { +// SetEnginePagination sets pagination for a database engine +func SetEnginePagination(e Engine, p Paginator) Engine { skip, take := p.GetSkipTake() return e.Limit(take, skip) @@ -46,7 +45,7 @@ type ListOptions struct { // GetSkipTake returns the skip and take values func (opts *ListOptions) GetSkipTake() (skip, take int) { - opts.setDefaultValues() + opts.SetDefaultValues() return (opts.Page - 1) * opts.PageSize, opts.PageSize } @@ -57,7 +56,8 @@ func (opts *ListOptions) GetStartEnd() (start, end int) { return } -func (opts *ListOptions) setDefaultValues() { +// SetDefaultValues sets default values +func (opts *ListOptions) SetDefaultValues() { if opts.PageSize <= 0 { opts.PageSize = setting.API.DefaultPagingNum } diff --git a/models/list_options_test.go b/models/db/list_options_test.go similarity index 98% rename from models/list_options_test.go rename to models/db/list_options_test.go index 3145aa7c162c..2c860afdfbdd 100644 --- a/models/list_options_test.go +++ b/models/db/list_options_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. -package models +package db import ( "testing" diff --git a/models/db/test_fixtures.go b/models/db/test_fixtures.go index 172701513309..2715b688ea35 100644 --- a/models/db/test_fixtures.go +++ b/models/db/test_fixtures.go @@ -17,13 +17,18 @@ import ( var fixtures *testfixtures.Loader // InitFixtures initialize test fixtures for a test database -func InitFixtures(dir string, engine ...*xorm.Engine) (err error) { +func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) { e := x if len(engine) == 1 { e = engine[0] } - testfiles := testfixtures.Directory(dir) + var testfiles func(*testfixtures.Loader) error + if opts.Dir != "" { + testfiles = testfixtures.Directory(opts.Dir) + } else { + testfiles = testfixtures.Files(opts.Files...) + } dialect := "unknown" switch e.Dialect().URI().DBType { case schemas.POSTGRES: diff --git a/models/db/unit_tests.go b/models/db/unit_tests.go index 781f0ecca20f..d81610df6b6b 100644 --- a/models/db/unit_tests.go +++ b/models/db/unit_tests.go @@ -44,11 +44,21 @@ func fatalTestError(fmtStr string, args ...interface{}) { // MainTest a reusable TestMain(..) function for unit tests that need to use a // test database. Creates the test database, and sets necessary settings. -func MainTest(m *testing.M, pathToGiteaRoot string) { +func MainTest(m *testing.M, pathToGiteaRoot string, fixtureFiles ...string) { var err error giteaRoot = pathToGiteaRoot fixturesDir = filepath.Join(pathToGiteaRoot, "models", "fixtures") - if err = CreateTestEngine(fixturesDir); err != nil { + + var opts FixturesOptions + if len(fixtureFiles) == 0 { + opts.Dir = fixturesDir + } else { + for _, f := range fixtureFiles { + opts.Files = append(opts.Files, filepath.Join(fixturesDir, f)) + } + } + + if err = CreateTestEngine(opts); err != nil { fatalTestError("Error creating test engine: %v\n", err) } @@ -102,8 +112,14 @@ func MainTest(m *testing.M, pathToGiteaRoot string) { os.Exit(exitStatus) } +// FixturesOptions fixtures needs to be loaded options +type FixturesOptions struct { + Dir string + Files []string +} + // CreateTestEngine creates a memory database and loads the fixture data from fixturesDir -func CreateTestEngine(fixturesDir string) error { +func CreateTestEngine(opts FixturesOptions) error { var err error x, err = xorm.NewEngine("sqlite3", "file::memory:?cache=shared&_txlock=immediate") if err != nil { @@ -123,7 +139,7 @@ func CreateTestEngine(fixturesDir string) error { e: x, } - return InitFixtures(fixturesDir) + return InitFixtures(opts) } // PrepareTestDatabase load test fixtures into test database diff --git a/models/error.go b/models/error.go index fd8f2771ae25..956b24009735 100644 --- a/models/error.go +++ b/models/error.go @@ -1836,58 +1836,6 @@ func (err ErrAttachmentNotExist) Error() string { return fmt.Sprintf("attachment does not exist [id: %d, uuid: %s]", err.ID, err.UUID) } -// .____ .__ _________ -// | | ____ ____ |__| ____ / _____/ ____ __ _________ ____ ____ -// | | / _ \ / ___\| |/ \ \_____ \ / _ \| | \_ __ \_/ ___\/ __ \ -// | |__( <_> ) /_/ > | | \ / ( <_> ) | /| | \/\ \__\ ___/ -// |_______ \____/\___ /|__|___| / /_______ /\____/|____/ |__| \___ >___ > -// \/ /_____/ \/ \/ \/ \/ - -// ErrLoginSourceNotExist represents a "LoginSourceNotExist" kind of error. -type ErrLoginSourceNotExist struct { - ID int64 -} - -// IsErrLoginSourceNotExist checks if an error is a ErrLoginSourceNotExist. -func IsErrLoginSourceNotExist(err error) bool { - _, ok := err.(ErrLoginSourceNotExist) - return ok -} - -func (err ErrLoginSourceNotExist) Error() string { - return fmt.Sprintf("login source does not exist [id: %d]", err.ID) -} - -// ErrLoginSourceAlreadyExist represents a "LoginSourceAlreadyExist" kind of error. -type ErrLoginSourceAlreadyExist struct { - Name string -} - -// IsErrLoginSourceAlreadyExist checks if an error is a ErrLoginSourceAlreadyExist. -func IsErrLoginSourceAlreadyExist(err error) bool { - _, ok := err.(ErrLoginSourceAlreadyExist) - return ok -} - -func (err ErrLoginSourceAlreadyExist) Error() string { - return fmt.Sprintf("login source already exists [name: %s]", err.Name) -} - -// ErrLoginSourceInUse represents a "LoginSourceInUse" kind of error. -type ErrLoginSourceInUse struct { - ID int64 -} - -// IsErrLoginSourceInUse checks if an error is a ErrLoginSourceInUse. -func IsErrLoginSourceInUse(err error) bool { - _, ok := err.(ErrLoginSourceInUse) - return ok -} - -func (err ErrLoginSourceInUse) Error() string { - return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID) -} - // ___________ // \__ ___/___ _____ _____ // | |_/ __ \\__ \ / \ @@ -2159,42 +2107,3 @@ func (err ErrNotValidReviewRequest) Error() string { err.UserID, err.RepoID) } - -// ________ _____ __ .__ -// \_____ \ / _ \ __ ___/ |_| |__ -// / | \ / /_\ \| | \ __\ | \ -// / | \/ | \ | /| | | Y \ -// \_______ /\____|__ /____/ |__| |___| / -// \/ \/ \/ - -// ErrOAuthClientIDInvalid will be thrown if client id cannot be found -type ErrOAuthClientIDInvalid struct { - ClientID string -} - -// IsErrOauthClientIDInvalid checks if an error is a ErrReviewNotExist. -func IsErrOauthClientIDInvalid(err error) bool { - _, ok := err.(ErrOAuthClientIDInvalid) - return ok -} - -// Error returns the error message -func (err ErrOAuthClientIDInvalid) Error() string { - return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID) -} - -// ErrOAuthApplicationNotFound will be thrown if id cannot be found -type ErrOAuthApplicationNotFound struct { - ID int64 -} - -// IsErrOAuthApplicationNotFound checks if an error is a ErrReviewNotExist. -func IsErrOAuthApplicationNotFound(err error) bool { - _, ok := err.(ErrOAuthApplicationNotFound) - return ok -} - -// Error returns the error message -func (err ErrOAuthApplicationNotFound) Error() string { - return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID) -} diff --git a/models/external_login_user.go b/models/external_login_user.go index c6a4b71b53e2..6b023a4cb2a9 100644 --- a/models/external_login_user.go +++ b/models/external_login_user.go @@ -8,6 +8,7 @@ import ( "time" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/structs" "github.com/markbates/goth" @@ -106,7 +107,7 @@ func GetUserIDByExternalUserID(provider, userID string) (int64, error) { // UpdateExternalUser updates external user's information func UpdateExternalUser(user *User, gothUser goth.User) error { - loginSource, err := GetActiveOAuth2LoginSourceByName(gothUser.Provider) + loginSource, err := login.GetActiveOAuth2LoginSourceByName(gothUser.Provider) if err != nil { return err } diff --git a/models/gpg_key.go b/models/gpg_key.go index d8dd79c28538..a62ed61966ed 100644 --- a/models/gpg_key.go +++ b/models/gpg_key.go @@ -62,14 +62,14 @@ func (key *GPGKey) AfterLoad(session *xorm.Session) { } // ListGPGKeys returns a list of public keys belongs to given user. -func ListGPGKeys(uid int64, listOptions ListOptions) ([]*GPGKey, error) { +func ListGPGKeys(uid int64, listOptions db.ListOptions) ([]*GPGKey, error) { return listGPGKeys(db.GetEngine(db.DefaultContext), uid, listOptions) } -func listGPGKeys(e db.Engine, uid int64, listOptions ListOptions) ([]*GPGKey, error) { +func listGPGKeys(e db.Engine, uid int64, listOptions db.ListOptions) ([]*GPGKey, error) { sess := e.Table(&GPGKey{}).Where("owner_id=? AND primary_key_id=''", uid) if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) } keys := make([]*GPGKey, 0, 2) diff --git a/models/gpg_key_commit_verification.go b/models/gpg_key_commit_verification.go index a4c7d702850f..f508303a0965 100644 --- a/models/gpg_key_commit_verification.go +++ b/models/gpg_key_commit_verification.go @@ -9,6 +9,7 @@ import ( "hash" "strings" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -156,7 +157,7 @@ func ParseCommitWithSignature(c *git.Commit) *CommitVerification { // Now try to associate the signature with the committer, if present if committer.ID != 0 { - keys, err := ListGPGKeys(committer.ID, ListOptions{}) + keys, err := ListGPGKeys(committer.ID, db.ListOptions{}) if err != nil { // Skipping failed to get gpg keys of user log.Error("ListGPGKeys: %v", err) return &CommitVerification{ diff --git a/models/issue.go b/models/issue.go index cafd996ac51d..b8c7053b2d2a 100644 --- a/models/issue.go +++ b/models/issue.go @@ -1122,7 +1122,7 @@ func GetIssuesByIDs(issueIDs []int64) ([]*Issue, error) { // IssuesOptions represents options of an issue. type IssuesOptions struct { - ListOptions + db.ListOptions RepoIDs []int64 // include all repos if empty AssigneeID int64 PosterID int64 diff --git a/models/issue_comment.go b/models/issue_comment.go index d8f8e36df288..01e41814a474 100644 --- a/models/issue_comment.go +++ b/models/issue_comment.go @@ -964,7 +964,7 @@ func getCommentByID(e db.Engine, id int64) (*Comment, error) { // FindCommentsOptions describes the conditions to Find comments type FindCommentsOptions struct { - ListOptions + db.ListOptions RepoID int64 IssueID int64 ReviewID int64 @@ -1012,7 +1012,7 @@ func findComments(e db.Engine, opts *FindCommentsOptions) ([]*Comment, error) { } if opts.Page != 0 { - sess = setSessionPagination(sess, opts) + sess = db.SetSessionPagination(sess, opts) } // WARNING: If you change this order you will need to fix createCodeComment diff --git a/models/issue_label.go b/models/issue_label.go index 87d7eb922123..293b7140f777 100644 --- a/models/issue_label.go +++ b/models/issue_label.go @@ -447,7 +447,7 @@ func GetLabelsInRepoByIDs(repoID int64, labelIDs []int64) ([]*Label, error) { Find(&labels) } -func getLabelsByRepoID(e db.Engine, repoID int64, sortType string, listOptions ListOptions) ([]*Label, error) { +func getLabelsByRepoID(e db.Engine, repoID int64, sortType string, listOptions db.ListOptions) ([]*Label, error) { if repoID <= 0 { return nil, ErrRepoLabelNotExist{0, repoID} } @@ -466,14 +466,14 @@ func getLabelsByRepoID(e db.Engine, repoID int64, sortType string, listOptions L } if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) } return labels, sess.Find(&labels) } // GetLabelsByRepoID returns all labels that belong to given repository by ID. -func GetLabelsByRepoID(repoID int64, sortType string, listOptions ListOptions) ([]*Label, error) { +func GetLabelsByRepoID(repoID int64, sortType string, listOptions db.ListOptions) ([]*Label, error) { return getLabelsByRepoID(db.GetEngine(db.DefaultContext), repoID, sortType, listOptions) } @@ -564,7 +564,7 @@ func GetLabelsInOrgByIDs(orgID int64, labelIDs []int64) ([]*Label, error) { Find(&labels) } -func getLabelsByOrgID(e db.Engine, orgID int64, sortType string, listOptions ListOptions) ([]*Label, error) { +func getLabelsByOrgID(e db.Engine, orgID int64, sortType string, listOptions db.ListOptions) ([]*Label, error) { if orgID <= 0 { return nil, ErrOrgLabelNotExist{0, orgID} } @@ -583,14 +583,14 @@ func getLabelsByOrgID(e db.Engine, orgID int64, sortType string, listOptions Lis } if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) } return labels, sess.Find(&labels) } // GetLabelsByOrgID returns all labels that belong to given organization by ID. -func GetLabelsByOrgID(orgID int64, sortType string, listOptions ListOptions) ([]*Label, error) { +func GetLabelsByOrgID(orgID int64, sortType string, listOptions db.ListOptions) ([]*Label, error) { return getLabelsByOrgID(db.GetEngine(db.DefaultContext), orgID, sortType, listOptions) } diff --git a/models/issue_label_test.go b/models/issue_label_test.go index 384965b846c2..93807a326f80 100644 --- a/models/issue_label_test.go +++ b/models/issue_label_test.go @@ -123,7 +123,7 @@ func TestGetLabelsInRepoByIDs(t *testing.T) { func TestGetLabelsByRepoID(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) testSuccess := func(repoID int64, sortType string, expectedIssueIDs []int64) { - labels, err := GetLabelsByRepoID(repoID, sortType, ListOptions{}) + labels, err := GetLabelsByRepoID(repoID, sortType, db.ListOptions{}) assert.NoError(t, err) assert.Len(t, labels, len(expectedIssueIDs)) for i, label := range labels { @@ -214,7 +214,7 @@ func TestGetLabelsInOrgByIDs(t *testing.T) { func TestGetLabelsByOrgID(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) testSuccess := func(orgID int64, sortType string, expectedIssueIDs []int64) { - labels, err := GetLabelsByOrgID(orgID, sortType, ListOptions{}) + labels, err := GetLabelsByOrgID(orgID, sortType, db.ListOptions{}) assert.NoError(t, err) assert.Len(t, labels, len(expectedIssueIDs)) for i, label := range labels { @@ -227,10 +227,10 @@ func TestGetLabelsByOrgID(t *testing.T) { testSuccess(3, "default", []int64{3, 4}) var err error - _, err = GetLabelsByOrgID(0, "leastissues", ListOptions{}) + _, err = GetLabelsByOrgID(0, "leastissues", db.ListOptions{}) assert.True(t, IsErrOrgLabelNotExist(err)) - _, err = GetLabelsByOrgID(-1, "leastissues", ListOptions{}) + _, err = GetLabelsByOrgID(-1, "leastissues", db.ListOptions{}) assert.True(t, IsErrOrgLabelNotExist(err)) } diff --git a/models/issue_milestone.go b/models/issue_milestone.go index fb6ced5b41a3..3898e5b39785 100644 --- a/models/issue_milestone.go +++ b/models/issue_milestone.go @@ -378,7 +378,7 @@ func (milestones MilestoneList) getMilestoneIDs() []int64 { // GetMilestonesOption contain options to get milestones type GetMilestonesOption struct { - ListOptions + db.ListOptions RepoID int64 State api.StateType Name string @@ -413,7 +413,7 @@ func GetMilestones(opts GetMilestonesOption) (MilestoneList, int64, error) { sess := db.GetEngine(db.DefaultContext).Where(opts.toCond()) if opts.Page != 0 { - sess = setSessionPagination(sess, &opts) + sess = db.SetSessionPagination(sess, &opts) } switch opts.SortType { diff --git a/models/issue_milestone_test.go b/models/issue_milestone_test.go index 519b65715d15..099fe47c7c1c 100644 --- a/models/issue_milestone_test.go +++ b/models/issue_milestone_test.go @@ -102,7 +102,7 @@ func TestGetMilestones(t *testing.T) { test := func(sortType string, sortCond func(*Milestone) int) { for _, page := range []int{0, 1} { milestones, _, err := GetMilestones(GetMilestonesOption{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: page, PageSize: setting.UI.IssuePagingNum, }, @@ -119,7 +119,7 @@ func TestGetMilestones(t *testing.T) { assert.True(t, sort.IntsAreSorted(values)) milestones, _, err = GetMilestones(GetMilestonesOption{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: page, PageSize: setting.UI.IssuePagingNum, }, diff --git a/models/issue_reaction.go b/models/issue_reaction.go index 4e49add5c2be..423eb8b96cf6 100644 --- a/models/issue_reaction.go +++ b/models/issue_reaction.go @@ -35,7 +35,7 @@ func init() { // FindReactionsOptions describes the conditions to Find reactions type FindReactionsOptions struct { - ListOptions + db.ListOptions IssueID int64 CommentID int64 UserID int64 @@ -78,7 +78,7 @@ func FindCommentReactions(comment *Comment) (ReactionList, error) { } // FindIssueReactions returns a ReactionList of all reactions from an issue -func FindIssueReactions(issue *Issue, listOptions ListOptions) (ReactionList, error) { +func FindIssueReactions(issue *Issue, listOptions db.ListOptions) (ReactionList, error) { return findReactions(db.GetEngine(db.DefaultContext), FindReactionsOptions{ ListOptions: listOptions, IssueID: issue.ID, @@ -92,7 +92,7 @@ func findReactions(e db.Engine, opts FindReactionsOptions) ([]*Reaction, error) In("reaction.`type`", setting.UI.Reactions). Asc("reaction.issue_id", "reaction.comment_id", "reaction.created_unix", "reaction.id") if opts.Page != 0 { - e = setEnginePagination(e, &opts) + e = db.SetEnginePagination(e, &opts) reactions := make([]*Reaction, 0, opts.PageSize) return reactions, e.Find(&reactions) diff --git a/models/issue_stopwatch.go b/models/issue_stopwatch.go index 157658e182df..e8f19dd738c4 100644 --- a/models/issue_stopwatch.go +++ b/models/issue_stopwatch.go @@ -46,11 +46,11 @@ func getStopwatch(e db.Engine, userID, issueID int64) (sw *Stopwatch, exists boo } // GetUserStopwatches return list of all stopwatches of a user -func GetUserStopwatches(userID int64, listOptions ListOptions) ([]*Stopwatch, error) { +func GetUserStopwatches(userID int64, listOptions db.ListOptions) ([]*Stopwatch, error) { sws := make([]*Stopwatch, 0, 8) sess := db.GetEngine(db.DefaultContext).Where("stopwatch.user_id = ?", userID) if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) } err := sess.Find(&sws) diff --git a/models/issue_test.go b/models/issue_test.go index d5f6f36e9c6e..d726a2434476 100644 --- a/models/issue_test.go +++ b/models/issue_test.go @@ -151,7 +151,7 @@ func TestIssues(t *testing.T) { IssuesOptions{ RepoIDs: []int64{1, 3}, SortType: "oldest", - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: 1, PageSize: 4, }, @@ -161,7 +161,7 @@ func TestIssues(t *testing.T) { { IssuesOptions{ LabelIDs: []int64{1}, - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: 1, PageSize: 4, }, @@ -171,7 +171,7 @@ func TestIssues(t *testing.T) { { IssuesOptions{ LabelIDs: []int64{1, 2}, - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: 1, PageSize: 4, }, diff --git a/models/issue_tracked_time.go b/models/issue_tracked_time.go index d024c6896f62..79de8910194a 100644 --- a/models/issue_tracked_time.go +++ b/models/issue_tracked_time.go @@ -75,7 +75,7 @@ func (tl TrackedTimeList) LoadAttributes() (err error) { // FindTrackedTimesOptions represent the filters for tracked times. If an ID is 0 it will be ignored. type FindTrackedTimesOptions struct { - ListOptions + db.ListOptions IssueID int64 UserID int64 RepositoryID int64 @@ -118,7 +118,7 @@ func (opts *FindTrackedTimesOptions) toSession(e db.Engine) db.Engine { sess = sess.Where(opts.toCond()) if opts.Page != 0 { - sess = setEnginePagination(sess, opts) + sess = db.SetEnginePagination(sess, opts) } return sess diff --git a/models/issue_watch.go b/models/issue_watch.go index cc1edcba1b66..5bac406ad01b 100644 --- a/models/issue_watch.go +++ b/models/issue_watch.go @@ -103,11 +103,11 @@ func getIssueWatchersIDs(e db.Engine, issueID int64, watching bool) ([]int64, er } // GetIssueWatchers returns watchers/unwatchers of a given issue -func GetIssueWatchers(issueID int64, listOptions ListOptions) (IssueWatchList, error) { +func GetIssueWatchers(issueID int64, listOptions db.ListOptions) (IssueWatchList, error) { return getIssueWatchers(db.GetEngine(db.DefaultContext), issueID, listOptions) } -func getIssueWatchers(e db.Engine, issueID int64, listOptions ListOptions) (IssueWatchList, error) { +func getIssueWatchers(e db.Engine, issueID int64, listOptions db.ListOptions) (IssueWatchList, error) { sess := e. Where("`issue_watch`.issue_id = ?", issueID). And("`issue_watch`.is_watching = ?", true). @@ -116,7 +116,7 @@ func getIssueWatchers(e db.Engine, issueID int64, listOptions ListOptions) (Issu Join("INNER", "`user`", "`user`.id = `issue_watch`.user_id") if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) watches := make([]*IssueWatch, 0, listOptions.PageSize) return watches, sess.Find(&watches) } diff --git a/models/issue_watch_test.go b/models/issue_watch_test.go index f85e7cef59a6..139ed41cb608 100644 --- a/models/issue_watch_test.go +++ b/models/issue_watch_test.go @@ -43,22 +43,22 @@ func TestGetIssueWatch(t *testing.T) { func TestGetIssueWatchers(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) - iws, err := GetIssueWatchers(1, ListOptions{}) + iws, err := GetIssueWatchers(1, db.ListOptions{}) assert.NoError(t, err) // Watcher is inactive, thus 0 assert.Len(t, iws, 0) - iws, err = GetIssueWatchers(2, ListOptions{}) + iws, err = GetIssueWatchers(2, db.ListOptions{}) assert.NoError(t, err) // Watcher is explicit not watching assert.Len(t, iws, 0) - iws, err = GetIssueWatchers(5, ListOptions{}) + iws, err = GetIssueWatchers(5, db.ListOptions{}) assert.NoError(t, err) // Issue has no Watchers assert.Len(t, iws, 0) - iws, err = GetIssueWatchers(7, ListOptions{}) + iws, err = GetIssueWatchers(7, db.ListOptions{}) assert.NoError(t, err) // Issue has one watcher assert.Len(t, iws, 1) diff --git a/models/login/main_test.go b/models/login/main_test.go new file mode 100644 index 000000000000..ef4b5907bfd4 --- /dev/null +++ b/models/login/main_test.go @@ -0,0 +1,21 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package login + +import ( + "path/filepath" + "testing" + + "code.gitea.io/gitea/models/db" +) + +func TestMain(m *testing.M) { + db.MainTest(m, filepath.Join("..", ".."), + "login_source.yml", + "oauth2_application.yml", + "oauth2_authorization_code.yml", + "oauth2_grant.yml", + ) +} diff --git a/models/login/oauth2.go b/models/login/oauth2.go new file mode 100644 index 000000000000..45ab59dd78ca --- /dev/null +++ b/models/login/oauth2.go @@ -0,0 +1,70 @@ +// Copyright 2017 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package login + +import ( + "fmt" + + "code.gitea.io/gitea/models/db" +) + +// ________ _____ __ .__ +// \_____ \ / _ \ __ ___/ |_| |__ +// / | \ / /_\ \| | \ __\ | \ +// / | \/ | \ | /| | | Y \ +// \_______ /\____|__ /____/ |__| |___| / +// \/ \/ \/ + +// ErrOAuthClientIDInvalid will be thrown if client id cannot be found +type ErrOAuthClientIDInvalid struct { + ClientID string +} + +// IsErrOauthClientIDInvalid checks if an error is a ErrReviewNotExist. +func IsErrOauthClientIDInvalid(err error) bool { + _, ok := err.(ErrOAuthClientIDInvalid) + return ok +} + +// Error returns the error message +func (err ErrOAuthClientIDInvalid) Error() string { + return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID) +} + +// ErrOAuthApplicationNotFound will be thrown if id cannot be found +type ErrOAuthApplicationNotFound struct { + ID int64 +} + +// IsErrOAuthApplicationNotFound checks if an error is a ErrReviewNotExist. +func IsErrOAuthApplicationNotFound(err error) bool { + _, ok := err.(ErrOAuthApplicationNotFound) + return ok +} + +// Error returns the error message +func (err ErrOAuthApplicationNotFound) Error() string { + return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID) +} + +// GetActiveOAuth2ProviderLoginSources returns all actived LoginOAuth2 sources +func GetActiveOAuth2ProviderLoginSources() ([]*Source, error) { + sources := make([]*Source, 0, 1) + if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil { + return nil, err + } + return sources, nil +} + +// GetActiveOAuth2LoginSourceByName returns a OAuth2 LoginSource based on the given name +func GetActiveOAuth2LoginSourceByName(name string) (*Source, error) { + loginSource := new(Source) + has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(loginSource) + if !has || err != nil { + return nil, err + } + + return loginSource, nil +} diff --git a/models/oauth2_application.go b/models/login/oauth2_application.go similarity index 96% rename from models/oauth2_application.go rename to models/login/oauth2_application.go index 0fd2e38472e0..060bfe5bc3b0 100644 --- a/models/oauth2_application.go +++ b/models/login/oauth2_application.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. -package models +package login import ( "crypto/sha256" @@ -23,19 +23,14 @@ import ( // OAuth2Application represents an OAuth2 client (RFC 6749) type OAuth2Application struct { - ID int64 `xorm:"pk autoincr"` - UID int64 `xorm:"INDEX"` - User *User `xorm:"-"` - - Name string - + ID int64 `xorm:"pk autoincr"` + UID int64 `xorm:"INDEX"` + Name string ClientID string `xorm:"unique"` ClientSecret string - - RedirectURIs []string `xorm:"redirect_uris JSON TEXT"` - - CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` - UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` + RedirectURIs []string `xorm:"redirect_uris JSON TEXT"` + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` } func init() { @@ -57,14 +52,6 @@ func (app *OAuth2Application) PrimaryRedirectURI() string { return app.RedirectURIs[0] } -// LoadUser will load User by UID -func (app *OAuth2Application) LoadUser() (err error) { - if app.User == nil { - app.User, err = GetUserByID(app.UID) - } - return -} - // ContainsRedirectURI checks if redirectURI is allowed for app func (app *OAuth2Application) ContainsRedirectURI(redirectURI string) bool { return util.IsStringInSlice(redirectURI, app.RedirectURIs, true) @@ -276,13 +263,13 @@ func DeleteOAuth2Application(id, userid int64) error { } // ListOAuth2Applications returns a list of oauth2 applications belongs to given user. -func ListOAuth2Applications(uid int64, listOptions ListOptions) ([]*OAuth2Application, int64, error) { +func ListOAuth2Applications(uid int64, listOptions db.ListOptions) ([]*OAuth2Application, int64, error) { sess := db.GetEngine(db.DefaultContext). Where("uid=?", uid). Desc("id") if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) apps := make([]*OAuth2Application, 0, listOptions.PageSize) total, err := sess.FindAndCount(&apps) diff --git a/models/oauth2_application_test.go b/models/login/oauth2_application_test.go similarity index 96% rename from models/oauth2_application_test.go rename to models/login/oauth2_application_test.go index b01ef967fc88..cb064cef1b4d 100644 --- a/models/oauth2_application_test.go +++ b/models/login/oauth2_application_test.go @@ -2,12 +2,13 @@ // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. -package models +package login import ( "testing" "code.gitea.io/gitea/models/db" + "github.com/stretchr/testify/assert" ) @@ -69,13 +70,6 @@ func TestCreateOAuth2Application(t *testing.T) { db.AssertExistsAndLoadBean(t, &OAuth2Application{Name: "newapp"}) } -func TestOAuth2Application_LoadUser(t *testing.T) { - assert.NoError(t, db.PrepareTestDatabase()) - app := db.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application) - assert.NoError(t, app.LoadUser()) - assert.NotNil(t, app.User) -} - func TestOAuth2Application_TableName(t *testing.T) { assert.Equal(t, "oauth2_application", new(OAuth2Application).TableName()) } diff --git a/models/login_source.go b/models/login/source.go similarity index 55% rename from models/login_source.go rename to models/login/source.go index e1f7a7e08e51..1001d49b51b9 100644 --- a/models/login_source.go +++ b/models/login/source.go @@ -3,9 +3,10 @@ // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. -package models +package login import ( + "fmt" "reflect" "strconv" @@ -17,43 +18,43 @@ import ( "xorm.io/xorm/convert" ) -// LoginType represents an login type. -type LoginType int +// Type represents an login type. +type Type int // Note: new type must append to the end of list to maintain compatibility. const ( - LoginNoType LoginType = iota - LoginPlain // 1 - LoginLDAP // 2 - LoginSMTP // 3 - LoginPAM // 4 - LoginDLDAP // 5 - LoginOAuth2 // 6 - LoginSSPI // 7 + NoType Type = iota + Plain // 1 + LDAP // 2 + SMTP // 3 + PAM // 4 + DLDAP // 5 + OAuth2 // 6 + SSPI // 7 ) // String returns the string name of the LoginType -func (typ LoginType) String() string { - return LoginNames[typ] +func (typ Type) String() string { + return Names[typ] } // Int returns the int value of the LoginType -func (typ LoginType) Int() int { +func (typ Type) Int() int { return int(typ) } -// LoginNames contains the name of LoginType values. -var LoginNames = map[LoginType]string{ - LoginLDAP: "LDAP (via BindDN)", - LoginDLDAP: "LDAP (simple auth)", // Via direct bind - LoginSMTP: "SMTP", - LoginPAM: "PAM", - LoginOAuth2: "OAuth2", - LoginSSPI: "SPNEGO with SSPI", +// Names contains the name of LoginType values. +var Names = map[Type]string{ + LDAP: "LDAP (via BindDN)", + DLDAP: "LDAP (simple auth)", // Via direct bind + SMTP: "SMTP", + PAM: "PAM", + OAuth2: "OAuth2", + SSPI: "SPNEGO with SSPI", } -// LoginConfig represents login config as far as the db is concerned -type LoginConfig interface { +// Config represents login config as far as the db is concerned +type Config interface { convert.Conversion } @@ -83,33 +84,33 @@ type RegisterableSource interface { UnregisterSource() error } -// LoginSourceSettable configurations can have their loginSource set on them -type LoginSourceSettable interface { - SetLoginSource(*LoginSource) +// SourceSettable configurations can have their loginSource set on them +type SourceSettable interface { + SetLoginSource(*Source) } -// RegisterLoginTypeConfig register a config for a provided type -func RegisterLoginTypeConfig(typ LoginType, exemplar LoginConfig) { +// RegisterTypeConfig register a config for a provided type +func RegisterTypeConfig(typ Type, exemplar Config) { if reflect.TypeOf(exemplar).Kind() == reflect.Ptr { // Pointer: - registeredLoginConfigs[typ] = func() LoginConfig { - return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(LoginConfig) + registeredConfigs[typ] = func() Config { + return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config) } return } // Not a Pointer - registeredLoginConfigs[typ] = func() LoginConfig { - return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(LoginConfig) + registeredConfigs[typ] = func() Config { + return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config) } } -var registeredLoginConfigs = map[LoginType]func() LoginConfig{} +var registeredConfigs = map[Type]func() Config{} -// LoginSource represents an external way for authorizing users. -type LoginSource struct { +// Source represents an external way for authorizing users. +type Source struct { ID int64 `xorm:"pk autoincr"` - Type LoginType + Type Type Name string `xorm:"UNIQUE"` IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"` IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"` @@ -119,8 +120,13 @@ type LoginSource struct { UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` } +// TableName xorm will read the table name from this method +func (Source) TableName() string { + return "login_source" +} + func init() { - db.RegisterModel(new(LoginSource)) + db.RegisterModel(new(Source)) } // Cell2Int64 converts a xorm.Cell type to int64, @@ -137,82 +143,82 @@ func Cell2Int64(val xorm.Cell) int64 { } // BeforeSet is invoked from XORM before setting the value of a field of this object. -func (source *LoginSource) BeforeSet(colName string, val xorm.Cell) { +func (source *Source) BeforeSet(colName string, val xorm.Cell) { if colName == "type" { - typ := LoginType(Cell2Int64(val)) - constructor, ok := registeredLoginConfigs[typ] + typ := Type(Cell2Int64(val)) + constructor, ok := registeredConfigs[typ] if !ok { return } source.Cfg = constructor() - if settable, ok := source.Cfg.(LoginSourceSettable); ok { + if settable, ok := source.Cfg.(SourceSettable); ok { settable.SetLoginSource(source) } } } // TypeName return name of this login source type. -func (source *LoginSource) TypeName() string { - return LoginNames[source.Type] +func (source *Source) TypeName() string { + return Names[source.Type] } // IsLDAP returns true of this source is of the LDAP type. -func (source *LoginSource) IsLDAP() bool { - return source.Type == LoginLDAP +func (source *Source) IsLDAP() bool { + return source.Type == LDAP } // IsDLDAP returns true of this source is of the DLDAP type. -func (source *LoginSource) IsDLDAP() bool { - return source.Type == LoginDLDAP +func (source *Source) IsDLDAP() bool { + return source.Type == DLDAP } // IsSMTP returns true of this source is of the SMTP type. -func (source *LoginSource) IsSMTP() bool { - return source.Type == LoginSMTP +func (source *Source) IsSMTP() bool { + return source.Type == SMTP } // IsPAM returns true of this source is of the PAM type. -func (source *LoginSource) IsPAM() bool { - return source.Type == LoginPAM +func (source *Source) IsPAM() bool { + return source.Type == PAM } // IsOAuth2 returns true of this source is of the OAuth2 type. -func (source *LoginSource) IsOAuth2() bool { - return source.Type == LoginOAuth2 +func (source *Source) IsOAuth2() bool { + return source.Type == OAuth2 } // IsSSPI returns true of this source is of the SSPI type. -func (source *LoginSource) IsSSPI() bool { - return source.Type == LoginSSPI +func (source *Source) IsSSPI() bool { + return source.Type == SSPI } // HasTLS returns true of this source supports TLS. -func (source *LoginSource) HasTLS() bool { +func (source *Source) HasTLS() bool { hasTLSer, ok := source.Cfg.(HasTLSer) return ok && hasTLSer.HasTLS() } // UseTLS returns true of this source is configured to use TLS. -func (source *LoginSource) UseTLS() bool { +func (source *Source) UseTLS() bool { useTLSer, ok := source.Cfg.(UseTLSer) return ok && useTLSer.UseTLS() } // SkipVerify returns true if this source is configured to skip SSL // verification. -func (source *LoginSource) SkipVerify() bool { +func (source *Source) SkipVerify() bool { skipVerifiable, ok := source.Cfg.(SkipVerifiable) return ok && skipVerifiable.IsSkipVerify() } -// CreateLoginSource inserts a LoginSource in the DB if not already +// CreateSource inserts a LoginSource in the DB if not already // existing with the given name. -func CreateLoginSource(source *LoginSource) error { - has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(LoginSource)) +func CreateSource(source *Source) error { + has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source)) if err != nil { return err } else if has { - return ErrLoginSourceAlreadyExist{source.Name} + return ErrSourceAlreadyExist{source.Name} } // Synchronization is only available with LDAP for now if !source.IsLDAP() { @@ -228,7 +234,7 @@ func CreateLoginSource(source *LoginSource) error { return nil } - if settable, ok := source.Cfg.(LoginSourceSettable); ok { + if settable, ok := source.Cfg.(SourceSettable); ok { settable.SetLoginSource(source) } @@ -241,40 +247,40 @@ func CreateLoginSource(source *LoginSource) error { if err != nil { // remove the LoginSource in case of errors while registering configuration if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil { - log.Error("CreateLoginSource: Error while wrapOpenIDConnectInitializeError: %v", err) + log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err) } } return err } -// LoginSources returns a slice of all login sources found in DB. -func LoginSources() ([]*LoginSource, error) { - auths := make([]*LoginSource, 0, 6) +// Sources returns a slice of all login sources found in DB. +func Sources() ([]*Source, error) { + auths := make([]*Source, 0, 6) return auths, db.GetEngine(db.DefaultContext).Find(&auths) } -// LoginSourcesByType returns all sources of the specified type -func LoginSourcesByType(loginType LoginType) ([]*LoginSource, error) { - sources := make([]*LoginSource, 0, 1) +// SourcesByType returns all sources of the specified type +func SourcesByType(loginType Type) ([]*Source, error) { + sources := make([]*Source, 0, 1) if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil { return nil, err } return sources, nil } -// AllActiveLoginSources returns all active sources -func AllActiveLoginSources() ([]*LoginSource, error) { - sources := make([]*LoginSource, 0, 5) +// AllActiveSources returns all active sources +func AllActiveSources() ([]*Source, error) { + sources := make([]*Source, 0, 5) if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil { return nil, err } return sources, nil } -// ActiveLoginSources returns all active sources of the specified type -func ActiveLoginSources(loginType LoginType) ([]*LoginSource, error) { - sources := make([]*LoginSource, 0, 1) - if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, loginType).Find(&sources); err != nil { +// ActiveSources returns all active sources of the specified type +func ActiveSources(tp Type) ([]*Source, error) { + sources := make([]*Source, 0, 1) + if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil { return nil, err } return sources, nil @@ -286,19 +292,19 @@ func IsSSPIEnabled() bool { if !db.HasEngine { return false } - sources, err := ActiveLoginSources(LoginSSPI) + sources, err := ActiveSources(SSPI) if err != nil { - log.Error("ActiveLoginSources: %v", err) + log.Error("ActiveSources: %v", err) return false } return len(sources) > 0 } -// GetLoginSourceByID returns login source by given ID. -func GetLoginSourceByID(id int64) (*LoginSource, error) { - source := new(LoginSource) +// GetSourceByID returns login source by given ID. +func GetSourceByID(id int64) (*Source, error) { + source := new(Source) if id == 0 { - source.Cfg = registeredLoginConfigs[LoginNoType]() + source.Cfg = registeredConfigs[NoType]() // Set this source to active // FIXME: allow disabling of db based password authentication in future source.IsActive = true @@ -309,18 +315,18 @@ func GetLoginSourceByID(id int64) (*LoginSource, error) { if err != nil { return nil, err } else if !has { - return nil, ErrLoginSourceNotExist{id} + return nil, ErrSourceNotExist{id} } return source, nil } -// UpdateSource updates a LoginSource record in DB. -func UpdateSource(source *LoginSource) error { - var originalLoginSource *LoginSource +// UpdateSource updates a Source record in DB. +func UpdateSource(source *Source) error { + var originalLoginSource *Source if source.IsOAuth2() { // keep track of the original values so we can restore in case of errors while registering OAuth2 providers var err error - if originalLoginSource, err = GetLoginSourceByID(source.ID); err != nil { + if originalLoginSource, err = GetSourceByID(source.ID); err != nil { return err } } @@ -334,7 +340,7 @@ func UpdateSource(source *LoginSource) error { return nil } - if settable, ok := source.Cfg.(LoginSourceSettable); ok { + if settable, ok := source.Cfg.(SourceSettable); ok { settable.SetLoginSource(source) } @@ -353,34 +359,53 @@ func UpdateSource(source *LoginSource) error { return err } -// DeleteSource deletes a LoginSource record in DB. -func DeleteSource(source *LoginSource) error { - count, err := db.GetEngine(db.DefaultContext).Count(&User{LoginSource: source.ID}) - if err != nil { - return err - } else if count > 0 { - return ErrLoginSourceInUse{source.ID} - } +// CountSources returns number of login sources. +func CountSources() int64 { + count, _ := db.GetEngine(db.DefaultContext).Count(new(Source)) + return count +} - count, err = db.GetEngine(db.DefaultContext).Count(&ExternalLoginUser{LoginSourceID: source.ID}) - if err != nil { - return err - } else if count > 0 { - return ErrLoginSourceInUse{source.ID} - } +// ErrSourceNotExist represents a "SourceNotExist" kind of error. +type ErrSourceNotExist struct { + ID int64 +} - if registerableSource, ok := source.Cfg.(RegisterableSource); ok { - if err := registerableSource.UnregisterSource(); err != nil { - return err - } - } +// IsErrSourceNotExist checks if an error is a ErrSourceNotExist. +func IsErrSourceNotExist(err error) bool { + _, ok := err.(ErrSourceNotExist) + return ok +} - _, err = db.GetEngine(db.DefaultContext).ID(source.ID).Delete(new(LoginSource)) - return err +func (err ErrSourceNotExist) Error() string { + return fmt.Sprintf("login source does not exist [id: %d]", err.ID) } -// CountLoginSources returns number of login sources. -func CountLoginSources() int64 { - count, _ := db.GetEngine(db.DefaultContext).Count(new(LoginSource)) - return count +// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error. +type ErrSourceAlreadyExist struct { + Name string +} + +// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist. +func IsErrSourceAlreadyExist(err error) bool { + _, ok := err.(ErrSourceAlreadyExist) + return ok +} + +func (err ErrSourceAlreadyExist) Error() string { + return fmt.Sprintf("login source already exists [name: %s]", err.Name) +} + +// ErrSourceInUse represents a "SourceInUse" kind of error. +type ErrSourceInUse struct { + ID int64 +} + +// IsErrSourceInUse checks if an error is a ErrSourceInUse. +func IsErrSourceInUse(err error) bool { + _, ok := err.(ErrSourceInUse) + return ok +} + +func (err ErrSourceInUse) Error() string { + return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID) } diff --git a/models/login_source_test.go b/models/login/source_test.go similarity index 86% rename from models/login_source_test.go rename to models/login/source_test.go index ea1812238f19..d98609037cd5 100644 --- a/models/login_source_test.go +++ b/models/login/source_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. -package models +package login import ( "strings" @@ -36,13 +36,13 @@ func (source *TestSource) ToDB() ([]byte, error) { func TestDumpLoginSource(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) - loginSourceSchema, err := db.TableInfo(new(LoginSource)) + loginSourceSchema, err := db.TableInfo(new(Source)) assert.NoError(t, err) - RegisterLoginTypeConfig(LoginOAuth2, new(TestSource)) + RegisterTypeConfig(OAuth2, new(TestSource)) - CreateLoginSource(&LoginSource{ - Type: LoginOAuth2, + CreateSource(&Source{ + Type: OAuth2, Name: "TestSource", IsActive: false, Cfg: &TestSource{ diff --git a/models/migrations/migrations_test.go b/models/migrations/migrations_test.go index fffc44be122e..78624f1e2715 100644 --- a/models/migrations/migrations_test.go +++ b/models/migrations/migrations_test.go @@ -241,7 +241,10 @@ func prepareTestEnv(t *testing.T, skip int, syncModels ...interface{}) (*xorm.En if _, err := os.Stat(fixturesDir); err == nil { t.Logf("initializing fixtures from: %s", fixturesDir) - if err := db.InitFixtures(fixturesDir, x); err != nil { + if err := db.InitFixtures( + db.FixturesOptions{ + Dir: fixturesDir, + }, x); err != nil { t.Errorf("error whilst initializing fixtures from %s: %v", fixturesDir, err) return x, deferFn } diff --git a/models/notification.go b/models/notification.go index af24a6cf5a7a..bcbe8b0988e8 100644 --- a/models/notification.go +++ b/models/notification.go @@ -74,7 +74,7 @@ func init() { // FindNotificationOptions represent the filters for notifications. If an ID is 0 it will be ignored. type FindNotificationOptions struct { - ListOptions + db.ListOptions UserID int64 RepoID int64 IssueID int64 @@ -115,7 +115,7 @@ func (opts *FindNotificationOptions) ToCond() builder.Cond { func (opts *FindNotificationOptions) ToSession(e db.Engine) *xorm.Session { sess := e.Where(opts.ToCond()) if opts.Page != 0 { - sess = setSessionPagination(sess, opts) + sess = db.SetSessionPagination(sess, opts) } return sess } diff --git a/models/oauth2.go b/models/oauth2.go deleted file mode 100644 index 7fdd5309fb92..000000000000 --- a/models/oauth2.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2017 The Gitea Authors. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE file. - -package models - -import "code.gitea.io/gitea/models/db" - -// GetActiveOAuth2ProviderLoginSources returns all actived LoginOAuth2 sources -func GetActiveOAuth2ProviderLoginSources() ([]*LoginSource, error) { - sources := make([]*LoginSource, 0, 1) - if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, LoginOAuth2).Find(&sources); err != nil { - return nil, err - } - return sources, nil -} - -// GetActiveOAuth2LoginSourceByName returns a OAuth2 LoginSource based on the given name -func GetActiveOAuth2LoginSourceByName(name string) (*LoginSource, error) { - loginSource := new(LoginSource) - has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, LoginOAuth2, true).Get(loginSource) - if !has || err != nil { - return nil, err - } - - return loginSource, nil -} diff --git a/models/org.go b/models/org.go index bc6c47fd456c..94939d2c74b8 100644 --- a/models/org.go +++ b/models/org.go @@ -78,7 +78,7 @@ func (org *User) GetMembers() (err error) { // FindOrgMembersOpts represensts find org members conditions type FindOrgMembersOpts struct { - ListOptions + db.ListOptions OrgID int64 PublicOnly bool } @@ -574,7 +574,7 @@ func GetOrgUsersByUserID(uid int64, opts *SearchOrganizationsOptions) ([]*OrgUse } if opts.PageSize != 0 { - sess = setSessionPagination(sess, opts) + sess = db.SetSessionPagination(sess, opts) } err := sess. @@ -594,7 +594,7 @@ func getOrgUsersByOrgID(e db.Engine, opts *FindOrgMembersOpts) ([]*OrgUser, erro sess.And("is_public = ?", true) } if opts.ListOptions.PageSize > 0 { - sess = setSessionPagination(sess, opts) + sess = db.SetSessionPagination(sess, opts) ous := make([]*OrgUser, 0, opts.PageSize) return ous, sess.Find(&ous) diff --git a/models/org_team.go b/models/org_team.go index 7ca715bb7899..fc6a5f2c3b18 100644 --- a/models/org_team.go +++ b/models/org_team.go @@ -47,7 +47,7 @@ func init() { // SearchTeamOptions holds the search options type SearchTeamOptions struct { - ListOptions + db.ListOptions UserID int64 Keyword string OrgID int64 @@ -56,7 +56,7 @@ type SearchTeamOptions struct { // SearchMembersOptions holds the search options type SearchMembersOptions struct { - ListOptions + db.ListOptions } // SearchTeam search for teams. Caller is responsible to check permissions. @@ -176,7 +176,7 @@ func (t *Team) GetRepositories(opts *SearchTeamOptions) error { return t.getRepositories(db.GetEngine(db.DefaultContext)) } - return t.getRepositories(getPaginatedSession(opts)) + return t.getRepositories(db.GetPaginatedSession(opts)) } func (t *Team) getMembers(e db.Engine) (err error) { @@ -190,7 +190,7 @@ func (t *Team) GetMembers(opts *SearchMembersOptions) (err error) { return t.getMembers(db.GetEngine(db.DefaultContext)) } - return t.getMembers(getPaginatedSession(opts)) + return t.getMembers(db.GetPaginatedSession(opts)) } // AddMember adds new membership of the team to the organization, diff --git a/models/org_test.go b/models/org_test.go index 75dfc4262d5c..2df89b2afcac 100644 --- a/models/org_test.go +++ b/models/org_test.go @@ -399,7 +399,7 @@ func TestGetOrgUsersByOrgID(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) orgUsers, err := GetOrgUsersByOrgID(&FindOrgMembersOpts{ - ListOptions: ListOptions{}, + ListOptions: db.ListOptions{}, OrgID: 3, PublicOnly: false, }) @@ -420,7 +420,7 @@ func TestGetOrgUsersByOrgID(t *testing.T) { } orgUsers, err = GetOrgUsersByOrgID(&FindOrgMembersOpts{ - ListOptions: ListOptions{}, + ListOptions: db.ListOptions{}, OrgID: db.NonexistentID, PublicOnly: false, }) diff --git a/models/pull_list.go b/models/pull_list.go index 57e2f9c85f7e..57ef21021395 100644 --- a/models/pull_list.go +++ b/models/pull_list.go @@ -17,7 +17,7 @@ import ( // PullRequestsOptions holds the options for PRs type PullRequestsOptions struct { - ListOptions + db.ListOptions State string SortType string Labels []string @@ -101,7 +101,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, log.Error("listPullRequestStatement: %v", err) return nil, maxResults, err } - findSession = setSessionPagination(findSession, opts) + findSession = db.SetSessionPagination(findSession, opts) prs := make([]*PullRequest, 0, opts.PageSize) return prs, maxResults, findSession.Find(&prs) } diff --git a/models/pull_sign.go b/models/pull_sign.go index e7cf4ab666ee..2e7cbff48b43 100644 --- a/models/pull_sign.go +++ b/models/pull_sign.go @@ -5,6 +5,7 @@ package models import ( + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -35,7 +36,7 @@ Loop: case always: break Loop case pubkey: - keys, err := ListGPGKeys(u.ID, ListOptions{}) + keys, err := ListGPGKeys(u.ID, db.ListOptions{}) if err != nil { return false, "", nil, err } diff --git a/models/pull_test.go b/models/pull_test.go index 6543d0ec9660..2b7ef2f664af 100644 --- a/models/pull_test.go +++ b/models/pull_test.go @@ -56,7 +56,7 @@ func TestPullRequest_LoadHeadRepo(t *testing.T) { func TestPullRequestsNewest(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) prs, count, err := PullRequests(1, &PullRequestsOptions{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: 1, }, State: "open", @@ -75,7 +75,7 @@ func TestPullRequestsNewest(t *testing.T) { func TestPullRequestsOldest(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) prs, count, err := PullRequests(1, &PullRequestsOptions{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: 1, }, State: "open", diff --git a/models/release.go b/models/release.go index d6b629cfe839..4624791b8ff4 100644 --- a/models/release.go +++ b/models/release.go @@ -177,7 +177,7 @@ func GetReleaseByID(id int64) (*Release, error) { // FindReleasesOptions describes the conditions to Find releases type FindReleasesOptions struct { - ListOptions + db.ListOptions IncludeDrafts bool IncludeTags bool IsPreRelease util.OptionalBool @@ -214,7 +214,7 @@ func GetReleasesByRepoID(repoID int64, opts FindReleasesOptions) ([]*Release, er Where(opts.toConds(repoID)) if opts.PageSize != 0 { - sess = setSessionPagination(sess, &opts.ListOptions) + sess = db.SetSessionPagination(sess, &opts.ListOptions) } rels := make([]*Release, 0, opts.PageSize) diff --git a/models/repo.go b/models/repo.go index ae149f467d6d..efd78c6042d6 100644 --- a/models/repo.go +++ b/models/repo.go @@ -1772,7 +1772,7 @@ func GetUserRepositories(opts *SearchRepoOptions) ([]*Repository, int64, error) sess.Where(cond).OrderBy(opts.OrderBy.String()) repos := make([]*Repository, 0, opts.PageSize) - return repos, count, setSessionPagination(sess, opts).Find(&repos) + return repos, count, db.SetSessionPagination(sess, opts).Find(&repos) } // GetUserMirrorRepositories returns a list of mirror repositories of given user. @@ -2057,13 +2057,13 @@ func CopyLFS(ctx context.Context, newRepo, oldRepo *Repository) error { } // GetForks returns all the forks of the repository -func (repo *Repository) GetForks(listOptions ListOptions) ([]*Repository, error) { +func (repo *Repository) GetForks(listOptions db.ListOptions) ([]*Repository, error) { if listOptions.Page == 0 { forks := make([]*Repository, 0, repo.NumForks) return forks, db.GetEngine(db.DefaultContext).Find(&forks, &Repository{ForkID: repo.ID}) } - sess := getPaginatedSession(&listOptions) + sess := db.GetPaginatedSession(&listOptions) forks := make([]*Repository, 0, listOptions.PageSize) return forks, sess.Find(&forks, &Repository{ForkID: repo.ID}) } diff --git a/models/repo_collaboration.go b/models/repo_collaboration.go index 08d2062dbba0..08360c102d8a 100644 --- a/models/repo_collaboration.go +++ b/models/repo_collaboration.go @@ -64,13 +64,13 @@ func (repo *Repository) AddCollaborator(u *User) error { return sess.Commit() } -func (repo *Repository) getCollaborations(e db.Engine, listOptions ListOptions) ([]*Collaboration, error) { +func (repo *Repository) getCollaborations(e db.Engine, listOptions db.ListOptions) ([]*Collaboration, error) { if listOptions.Page == 0 { collaborations := make([]*Collaboration, 0, 8) return collaborations, e.Find(&collaborations, &Collaboration{RepoID: repo.ID}) } - e = setEnginePagination(e, &listOptions) + e = db.SetEnginePagination(e, &listOptions) collaborations := make([]*Collaboration, 0, listOptions.PageSize) return collaborations, e.Find(&collaborations, &Collaboration{RepoID: repo.ID}) @@ -82,7 +82,7 @@ type Collaborator struct { Collaboration *Collaboration } -func (repo *Repository) getCollaborators(e db.Engine, listOptions ListOptions) ([]*Collaborator, error) { +func (repo *Repository) getCollaborators(e db.Engine, listOptions db.ListOptions) ([]*Collaborator, error) { collaborations, err := repo.getCollaborations(e, listOptions) if err != nil { return nil, fmt.Errorf("getCollaborations: %v", err) @@ -103,7 +103,7 @@ func (repo *Repository) getCollaborators(e db.Engine, listOptions ListOptions) ( } // GetCollaborators returns the collaborators for a repository -func (repo *Repository) GetCollaborators(listOptions ListOptions) ([]*Collaborator, error) { +func (repo *Repository) GetCollaborators(listOptions db.ListOptions) ([]*Collaborator, error) { return repo.getCollaborators(db.GetEngine(db.DefaultContext), listOptions) } diff --git a/models/repo_collaboration_test.go b/models/repo_collaboration_test.go index 5a3ffef5fae1..326fb4dbf7a8 100644 --- a/models/repo_collaboration_test.go +++ b/models/repo_collaboration_test.go @@ -30,7 +30,7 @@ func TestRepository_GetCollaborators(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) test := func(repoID int64) { repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository) - collaborators, err := repo.GetCollaborators(ListOptions{}) + collaborators, err := repo.GetCollaborators(db.ListOptions{}) assert.NoError(t, err) expectedLen, err := db.GetEngine(db.DefaultContext).Count(&Collaboration{RepoID: repoID}) assert.NoError(t, err) diff --git a/models/repo_generate.go b/models/repo_generate.go index cb8bf45184f9..650da711a347 100644 --- a/models/repo_generate.go +++ b/models/repo_generate.go @@ -151,7 +151,7 @@ func GenerateAvatar(ctx context.Context, templateRepo, generateRepo *Repository) // GenerateIssueLabels generates issue labels from a template repository func GenerateIssueLabels(ctx context.Context, templateRepo, generateRepo *Repository) error { - templateLabels, err := getLabelsByRepoID(db.GetEngine(ctx), templateRepo.ID, "", ListOptions{}) + templateLabels, err := getLabelsByRepoID(db.GetEngine(ctx), templateRepo.ID, "", db.ListOptions{}) if err != nil { return err } diff --git a/models/repo_list.go b/models/repo_list.go index 7179114f4672..6804a997c845 100644 --- a/models/repo_list.go +++ b/models/repo_list.go @@ -135,7 +135,7 @@ func (repos MirrorRepositoryList) LoadAttributes() error { // SearchRepoOptions holds the search options type SearchRepoOptions struct { - ListOptions + db.ListOptions Actor *User Keyword string OwnerID int64 diff --git a/models/repo_list_test.go b/models/repo_list_test.go index a1fd454c1061..3c30cad564d6 100644 --- a/models/repo_list_test.go +++ b/models/repo_list_test.go @@ -18,7 +18,7 @@ func TestSearchRepository(t *testing.T) { // test search public repository on explore page repos, count, err := SearchRepositoryByName(&SearchRepoOptions{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: 1, PageSize: 10, }, @@ -33,7 +33,7 @@ func TestSearchRepository(t *testing.T) { assert.Equal(t, int64(1), count) repos, count, err = SearchRepositoryByName(&SearchRepoOptions{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: 1, PageSize: 10, }, @@ -47,7 +47,7 @@ func TestSearchRepository(t *testing.T) { // test search private repository on explore page repos, count, err = SearchRepositoryByName(&SearchRepoOptions{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: 1, PageSize: 10, }, @@ -63,7 +63,7 @@ func TestSearchRepository(t *testing.T) { assert.Equal(t, int64(1), count) repos, count, err = SearchRepositoryByName(&SearchRepoOptions{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: 1, PageSize: 10, }, @@ -85,7 +85,7 @@ func TestSearchRepository(t *testing.T) { // Test search within description repos, count, err = SearchRepository(&SearchRepoOptions{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: 1, PageSize: 10, }, @@ -102,7 +102,7 @@ func TestSearchRepository(t *testing.T) { // Test NOT search within description repos, count, err = SearchRepository(&SearchRepoOptions{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ Page: 1, PageSize: 10, }, @@ -122,142 +122,142 @@ func TestSearchRepository(t *testing.T) { }{ { name: "PublicRepositoriesByName", - opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: ListOptions{PageSize: 10}, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: db.ListOptions{PageSize: 10}, Collaborate: util.OptionalBoolFalse}, count: 7, }, { name: "PublicAndPrivateRepositoriesByName", - opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: ListOptions{Page: 1, PageSize: 10}, Private: true, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: db.ListOptions{Page: 1, PageSize: 10}, Private: true, Collaborate: util.OptionalBoolFalse}, count: 14, }, { name: "PublicAndPrivateRepositoriesByNameWithPagesizeLimitFirstPage", - opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: ListOptions{Page: 1, PageSize: 5}, Private: true, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: db.ListOptions{Page: 1, PageSize: 5}, Private: true, Collaborate: util.OptionalBoolFalse}, count: 14, }, { name: "PublicAndPrivateRepositoriesByNameWithPagesizeLimitSecondPage", - opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: ListOptions{Page: 2, PageSize: 5}, Private: true, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: db.ListOptions{Page: 2, PageSize: 5}, Private: true, Collaborate: util.OptionalBoolFalse}, count: 14, }, { name: "PublicAndPrivateRepositoriesByNameWithPagesizeLimitThirdPage", - opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: ListOptions{Page: 3, PageSize: 5}, Private: true, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: db.ListOptions{Page: 3, PageSize: 5}, Private: true, Collaborate: util.OptionalBoolFalse}, count: 14, }, { name: "PublicAndPrivateRepositoriesByNameWithPagesizeLimitFourthPage", - opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: ListOptions{Page: 3, PageSize: 5}, Private: true, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: db.ListOptions{Page: 3, PageSize: 5}, Private: true, Collaborate: util.OptionalBoolFalse}, count: 14, }, { name: "PublicRepositoriesOfUser", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, Collaborate: util.OptionalBoolFalse}, count: 2, }, { name: "PublicRepositoriesOfUser2", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 18, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 18, Collaborate: util.OptionalBoolFalse}, count: 0, }, { name: "PublicRepositoriesOfUser3", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 20, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 20, Collaborate: util.OptionalBoolFalse}, count: 2, }, { name: "PublicAndPrivateRepositoriesOfUser", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, Private: true, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, Private: true, Collaborate: util.OptionalBoolFalse}, count: 4, }, { name: "PublicAndPrivateRepositoriesOfUser2", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 18, Private: true, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 18, Private: true, Collaborate: util.OptionalBoolFalse}, count: 0, }, { name: "PublicAndPrivateRepositoriesOfUser3", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 20, Private: true, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 20, Private: true, Collaborate: util.OptionalBoolFalse}, count: 4, }, { name: "PublicRepositoriesOfUserIncludingCollaborative", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 15}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 15}, count: 5, }, { name: "PublicRepositoriesOfUser2IncludingCollaborative", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 18}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 18}, count: 1, }, { name: "PublicRepositoriesOfUser3IncludingCollaborative", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 20}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 20}, count: 3, }, { name: "PublicAndPrivateRepositoriesOfUserIncludingCollaborative", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, Private: true}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, Private: true}, count: 9, }, { name: "PublicAndPrivateRepositoriesOfUser2IncludingCollaborative", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 18, Private: true}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 18, Private: true}, count: 4, }, { name: "PublicAndPrivateRepositoriesOfUser3IncludingCollaborative", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 20, Private: true}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 20, Private: true}, count: 7, }, { name: "PublicRepositoriesOfOrganization", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 17, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 17, Collaborate: util.OptionalBoolFalse}, count: 1, }, { name: "PublicAndPrivateRepositoriesOfOrganization", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 17, Private: true, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 17, Private: true, Collaborate: util.OptionalBoolFalse}, count: 2, }, { name: "AllPublic/PublicRepositoriesByName", - opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: ListOptions{PageSize: 10}, AllPublic: true, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: db.ListOptions{PageSize: 10}, AllPublic: true, Collaborate: util.OptionalBoolFalse}, count: 7, }, { name: "AllPublic/PublicAndPrivateRepositoriesByName", - opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: ListOptions{Page: 1, PageSize: 10}, Private: true, AllPublic: true, Collaborate: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{Keyword: "big_test_", ListOptions: db.ListOptions{Page: 1, PageSize: 10}, Private: true, AllPublic: true, Collaborate: util.OptionalBoolFalse}, count: 14, }, { name: "AllPublic/PublicRepositoriesOfUserIncludingCollaborative", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, AllPublic: true, Template: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, AllPublic: true, Template: util.OptionalBoolFalse}, count: 28, }, { name: "AllPublic/PublicAndPrivateRepositoriesOfUserIncludingCollaborative", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, Private: true, AllPublic: true, AllLimited: true, Template: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, Private: true, AllPublic: true, AllLimited: true, Template: util.OptionalBoolFalse}, count: 33, }, { name: "AllPublic/PublicAndPrivateRepositoriesOfUserIncludingCollaborativeByName", - opts: &SearchRepoOptions{Keyword: "test", ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, Private: true, AllPublic: true}, + opts: &SearchRepoOptions{Keyword: "test", ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 15, Private: true, AllPublic: true}, count: 15, }, { name: "AllPublic/PublicAndPrivateRepositoriesOfUser2IncludingCollaborativeByName", - opts: &SearchRepoOptions{Keyword: "test", ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 18, Private: true, AllPublic: true}, + opts: &SearchRepoOptions{Keyword: "test", ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 18, Private: true, AllPublic: true}, count: 13, }, { name: "AllPublic/PublicRepositoriesOfOrganization", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, OwnerID: 17, AllPublic: true, Collaborate: util.OptionalBoolFalse, Template: util.OptionalBoolFalse}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, OwnerID: 17, AllPublic: true, Collaborate: util.OptionalBoolFalse, Template: util.OptionalBoolFalse}, count: 28, }, { name: "AllTemplates", - opts: &SearchRepoOptions{ListOptions: ListOptions{Page: 1, PageSize: 10}, Template: util.OptionalBoolTrue}, + opts: &SearchRepoOptions{ListOptions: db.ListOptions{Page: 1, PageSize: 10}, Template: util.OptionalBoolTrue}, count: 2, }, } diff --git a/models/repo_sign.go b/models/repo_sign.go index be9309ed4eac..ae0895df7646 100644 --- a/models/repo_sign.go +++ b/models/repo_sign.go @@ -7,6 +7,7 @@ package models import ( "strings" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/process" @@ -120,7 +121,7 @@ Loop: case always: break Loop case pubkey: - keys, err := ListGPGKeys(u.ID, ListOptions{}) + keys, err := ListGPGKeys(u.ID, db.ListOptions{}) if err != nil { return false, "", nil, err } @@ -156,7 +157,7 @@ Loop: case always: break Loop case pubkey: - keys, err := ListGPGKeys(u.ID, ListOptions{}) + keys, err := ListGPGKeys(u.ID, db.ListOptions{}) if err != nil { return false, "", nil, err } @@ -209,7 +210,7 @@ Loop: case always: break Loop case pubkey: - keys, err := ListGPGKeys(u.ID, ListOptions{}) + keys, err := ListGPGKeys(u.ID, db.ListOptions{}) if err != nil { return false, "", nil, err } diff --git a/models/repo_transfer.go b/models/repo_transfer.go index e3eb756eb405..fe50c1cc0460 100644 --- a/models/repo_transfer.go +++ b/models/repo_transfer.go @@ -266,7 +266,7 @@ func TransferOwnership(doer *User, newOwnerName string, repo *Repository) (err e } // Remove redundant collaborators. - collaborators, err := repo.getCollaborators(sess, ListOptions{}) + collaborators, err := repo.getCollaborators(sess, db.ListOptions{}) if err != nil { return fmt.Errorf("getCollaborators: %v", err) } diff --git a/models/repo_unit.go b/models/repo_unit.go index c35312be604f..7061119bd851 100644 --- a/models/repo_unit.go +++ b/models/repo_unit.go @@ -8,6 +8,7 @@ import ( "fmt" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/json" "code.gitea.io/gitea/modules/timeutil" @@ -153,7 +154,7 @@ func (cfg *PullRequestsConfig) AllowedMergeStyleCount() int { func (r *RepoUnit) BeforeSet(colName string, val xorm.Cell) { switch colName { case "type": - switch UnitType(Cell2Int64(val)) { + switch UnitType(login.Cell2Int64(val)) { case UnitTypeCode, UnitTypeReleases, UnitTypeWiki, UnitTypeProjects: r.Config = new(UnitConfig) case UnitTypeExternalWiki: diff --git a/models/repo_watch.go b/models/repo_watch.go index d3720fe857a4..b37d47874e34 100644 --- a/models/repo_watch.go +++ b/models/repo_watch.go @@ -165,12 +165,12 @@ func getRepoWatchersIDs(e db.Engine, repoID int64) ([]int64, error) { } // GetWatchers returns range of users watching given repository. -func (repo *Repository) GetWatchers(opts ListOptions) ([]*User, error) { +func (repo *Repository) GetWatchers(opts db.ListOptions) ([]*User, error) { sess := db.GetEngine(db.DefaultContext).Where("watch.repo_id=?", repo.ID). Join("LEFT", "watch", "`user`.id=`watch`.user_id"). And("`watch`.mode<>?", RepoWatchModeDont) if opts.Page > 0 { - sess = setSessionPagination(sess, &opts) + sess = db.SetSessionPagination(sess, &opts) users := make([]*User, 0, opts.PageSize) return users, sess.Find(&users) diff --git a/models/repo_watch_test.go b/models/repo_watch_test.go index 1a94b8ad3092..52222af2ca14 100644 --- a/models/repo_watch_test.go +++ b/models/repo_watch_test.go @@ -60,7 +60,7 @@ func TestRepository_GetWatchers(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository) - watchers, err := repo.GetWatchers(ListOptions{Page: 1}) + watchers, err := repo.GetWatchers(db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, repo.NumWatches) for _, watcher := range watchers { @@ -68,7 +68,7 @@ func TestRepository_GetWatchers(t *testing.T) { } repo = db.AssertExistsAndLoadBean(t, &Repository{ID: 9}).(*Repository) - watchers, err = repo.GetWatchers(ListOptions{Page: 1}) + watchers, err = repo.GetWatchers(db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, 0) } @@ -114,7 +114,7 @@ func TestWatchIfAuto(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository) - watchers, err := repo.GetWatchers(ListOptions{Page: 1}) + watchers, err := repo.GetWatchers(db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, repo.NumWatches) @@ -124,13 +124,13 @@ func TestWatchIfAuto(t *testing.T) { // Must not add watch assert.NoError(t, WatchIfAuto(8, 1, true)) - watchers, err = repo.GetWatchers(ListOptions{Page: 1}) + watchers, err = repo.GetWatchers(db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) // Should not add watch assert.NoError(t, WatchIfAuto(10, 1, true)) - watchers, err = repo.GetWatchers(ListOptions{Page: 1}) + watchers, err = repo.GetWatchers(db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) @@ -138,31 +138,31 @@ func TestWatchIfAuto(t *testing.T) { // Must not add watch assert.NoError(t, WatchIfAuto(8, 1, true)) - watchers, err = repo.GetWatchers(ListOptions{Page: 1}) + watchers, err = repo.GetWatchers(db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) // Should not add watch assert.NoError(t, WatchIfAuto(12, 1, false)) - watchers, err = repo.GetWatchers(ListOptions{Page: 1}) + watchers, err = repo.GetWatchers(db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) // Should add watch assert.NoError(t, WatchIfAuto(12, 1, true)) - watchers, err = repo.GetWatchers(ListOptions{Page: 1}) + watchers, err = repo.GetWatchers(db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount+1) // Should remove watch, inhibit from adding auto assert.NoError(t, WatchRepo(12, 1, false)) - watchers, err = repo.GetWatchers(ListOptions{Page: 1}) + watchers, err = repo.GetWatchers(db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) // Must not add watch assert.NoError(t, WatchIfAuto(12, 1, true)) - watchers, err = repo.GetWatchers(ListOptions{Page: 1}) + watchers, err = repo.GetWatchers(db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) } diff --git a/models/review.go b/models/review.go index 3f81c39b2fc8..5b12e6ffa29e 100644 --- a/models/review.go +++ b/models/review.go @@ -172,7 +172,7 @@ func GetReviewByID(id int64) (*Review, error) { // FindReviewOptions represent possible filters to find reviews type FindReviewOptions struct { - ListOptions + db.ListOptions Type ReviewType IssueID int64 ReviewerID int64 @@ -200,7 +200,7 @@ func findReviews(e db.Engine, opts FindReviewOptions) ([]*Review, error) { reviews := make([]*Review, 0, 10) sess := e.Where(opts.toCond()) if opts.Page > 0 { - sess = setSessionPagination(sess, &opts) + sess = db.SetSessionPagination(sess, &opts) } return reviews, sess. Asc("created_unix"). diff --git a/models/ssh_key.go b/models/ssh_key.go index 41016537eb65..c08fb72e7513 100644 --- a/models/ssh_key.go +++ b/models/ssh_key.go @@ -11,6 +11,7 @@ import ( "time" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/timeutil" "code.gitea.io/gitea/modules/util" @@ -197,10 +198,10 @@ func SearchPublicKey(uid int64, fingerprint string) ([]*PublicKey, error) { } // ListPublicKeys returns a list of public keys belongs to given user. -func ListPublicKeys(uid int64, listOptions ListOptions) ([]*PublicKey, error) { +func ListPublicKeys(uid int64, listOptions db.ListOptions) ([]*PublicKey, error) { sess := db.GetEngine(db.DefaultContext).Where("owner_id = ? AND type != ?", uid, KeyTypePrincipal) if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) keys := make([]*PublicKey, 0, listOptions.PageSize) return keys, sess.Find(&keys) @@ -255,7 +256,7 @@ func deletePublicKeys(e db.Engine, keyIDs ...int64) error { // PublicKeysAreExternallyManaged returns whether the provided KeyID represents an externally managed Key func PublicKeysAreExternallyManaged(keys []*PublicKey) ([]bool, error) { - sources := make([]*LoginSource, 0, 5) + sources := make([]*login.Source, 0, 5) externals := make([]bool, len(keys)) keyloop: for i, key := range keys { @@ -264,7 +265,7 @@ keyloop: continue keyloop } - var source *LoginSource + var source *login.Source sourceloop: for _, s := range sources { @@ -276,11 +277,11 @@ keyloop: if source == nil { var err error - source, err = GetLoginSourceByID(key.LoginSourceID) + source, err = login.GetSourceByID(key.LoginSourceID) if err != nil { - if IsErrLoginSourceNotExist(err) { + if login.IsErrSourceNotExist(err) { externals[i] = false - sources[i] = &LoginSource{ + sources[i] = &login.Source{ ID: key.LoginSourceID, } continue keyloop @@ -289,7 +290,7 @@ keyloop: } } - if sshKeyProvider, ok := source.Cfg.(SSHKeyProvider); ok && sshKeyProvider.ProvidesSSHKeys() { + if sshKeyProvider, ok := source.Cfg.(login.SSHKeyProvider); ok && sshKeyProvider.ProvidesSSHKeys() { // Disable setting SSH keys for this user externals[i] = true } @@ -307,14 +308,14 @@ func PublicKeyIsExternallyManaged(id int64) (bool, error) { if key.LoginSourceID == 0 { return false, nil } - source, err := GetLoginSourceByID(key.LoginSourceID) + source, err := login.GetSourceByID(key.LoginSourceID) if err != nil { - if IsErrLoginSourceNotExist(err) { + if login.IsErrSourceNotExist(err) { return false, nil } return false, err } - if sshKeyProvider, ok := source.Cfg.(SSHKeyProvider); ok && sshKeyProvider.ProvidesSSHKeys() { + if sshKeyProvider, ok := source.Cfg.(login.SSHKeyProvider); ok && sshKeyProvider.ProvidesSSHKeys() { // Disable setting SSH keys for this user return true, nil } @@ -387,7 +388,7 @@ func deleteKeysMarkedForDeletion(keys []string) (bool, error) { } // AddPublicKeysBySource add a users public keys. Returns true if there are changes. -func AddPublicKeysBySource(usr *User, s *LoginSource, sshPublicKeys []string) bool { +func AddPublicKeysBySource(usr *User, s *login.Source, sshPublicKeys []string) bool { var sshKeysNeedUpdate bool for _, sshKey := range sshPublicKeys { var err error @@ -425,7 +426,7 @@ func AddPublicKeysBySource(usr *User, s *LoginSource, sshPublicKeys []string) bo } // SynchronizePublicKeys updates a users public keys. Returns true if there are changes. -func SynchronizePublicKeys(usr *User, s *LoginSource, sshPublicKeys []string) bool { +func SynchronizePublicKeys(usr *User, s *login.Source, sshPublicKeys []string) bool { var sshKeysNeedUpdate bool log.Trace("synchronizePublicKeys[%s]: Handling Public SSH Key synchronization for user %s", s.Name, usr.Name) diff --git a/models/ssh_key_deploy.go b/models/ssh_key_deploy.go index 3b9a16828074..34cf03e92518 100644 --- a/models/ssh_key_deploy.go +++ b/models/ssh_key_deploy.go @@ -271,7 +271,7 @@ func deleteDeployKey(sess db.Engine, doer *User, id int64) error { // ListDeployKeysOptions are options for ListDeployKeys type ListDeployKeysOptions struct { - ListOptions + db.ListOptions RepoID int64 KeyID int64 Fingerprint string @@ -300,7 +300,7 @@ func listDeployKeys(e db.Engine, opts *ListDeployKeysOptions) ([]*DeployKey, err sess := e.Where(opts.toCond()) if opts.Page != 0 { - sess = setSessionPagination(sess, opts) + sess = db.SetSessionPagination(sess, opts) keys := make([]*DeployKey, 0, opts.PageSize) return keys, sess.Find(&keys) diff --git a/models/ssh_key_principals.go b/models/ssh_key_principals.go index 383693e14ed7..44b2ee0bb4e6 100644 --- a/models/ssh_key_principals.go +++ b/models/ssh_key_principals.go @@ -112,10 +112,10 @@ func CheckPrincipalKeyString(user *User, content string) (_ string, err error) { } // ListPrincipalKeys returns a list of principals belongs to given user. -func ListPrincipalKeys(uid int64, listOptions ListOptions) ([]*PublicKey, error) { +func ListPrincipalKeys(uid int64, listOptions db.ListOptions) ([]*PublicKey, error) { sess := db.GetEngine(db.DefaultContext).Where("owner_id = ? AND type = ?", uid, KeyTypePrincipal) if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) keys := make([]*PublicKey, 0, listOptions.PageSize) return keys, sess.Find(&keys) diff --git a/models/star.go b/models/star.go index ad583f19985a..ee7791513d97 100644 --- a/models/star.go +++ b/models/star.go @@ -74,11 +74,11 @@ func isStaring(e db.Engine, userID, repoID int64) bool { } // GetStargazers returns the users that starred the repo. -func (repo *Repository) GetStargazers(opts ListOptions) ([]*User, error) { +func (repo *Repository) GetStargazers(opts db.ListOptions) ([]*User, error) { sess := db.GetEngine(db.DefaultContext).Where("star.repo_id = ?", repo.ID). Join("LEFT", "star", "`user`.id = star.uid") if opts.Page > 0 { - sess = setSessionPagination(sess, &opts) + sess = db.SetSessionPagination(sess, &opts) users := make([]*User, 0, opts.PageSize) return users, sess.Find(&users) diff --git a/models/star_test.go b/models/star_test.go index c0c0a607be35..326da8a861d9 100644 --- a/models/star_test.go +++ b/models/star_test.go @@ -34,7 +34,7 @@ func TestRepository_GetStargazers(t *testing.T) { // repo with stargazers assert.NoError(t, db.PrepareTestDatabase()) repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 4}).(*Repository) - gazers, err := repo.GetStargazers(ListOptions{Page: 0}) + gazers, err := repo.GetStargazers(db.ListOptions{Page: 0}) assert.NoError(t, err) if assert.Len(t, gazers, 1) { assert.Equal(t, int64(2), gazers[0].ID) @@ -45,7 +45,7 @@ func TestRepository_GetStargazers2(t *testing.T) { // repo with stargazers assert.NoError(t, db.PrepareTestDatabase()) repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository) - gazers, err := repo.GetStargazers(ListOptions{Page: 0}) + gazers, err := repo.GetStargazers(db.ListOptions{Page: 0}) assert.NoError(t, err) assert.Len(t, gazers, 0) } diff --git a/models/statistic.go b/models/statistic.go index d192a971f59f..43b1afbc4817 100644 --- a/models/statistic.go +++ b/models/statistic.go @@ -4,7 +4,10 @@ package models -import "code.gitea.io/gitea/models/db" +import ( + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" +) // Statistic contains the database statistics type Statistic struct { @@ -52,7 +55,7 @@ func GetStatistic() (stats Statistic) { stats.Counter.Follow, _ = db.GetEngine(db.DefaultContext).Count(new(Follow)) stats.Counter.Mirror, _ = db.GetEngine(db.DefaultContext).Count(new(Mirror)) stats.Counter.Release, _ = db.GetEngine(db.DefaultContext).Count(new(Release)) - stats.Counter.LoginSource = CountLoginSources() + stats.Counter.LoginSource = login.CountSources() stats.Counter.Webhook, _ = db.GetEngine(db.DefaultContext).Count(new(Webhook)) stats.Counter.Milestone, _ = db.GetEngine(db.DefaultContext).Count(new(Milestone)) stats.Counter.Label, _ = db.GetEngine(db.DefaultContext).Count(new(Label)) diff --git a/models/token.go b/models/token.go index 48ae79542461..07d013ac8ed4 100644 --- a/models/token.go +++ b/models/token.go @@ -147,7 +147,7 @@ func AccessTokenByNameExists(token *AccessToken) (bool, error) { // ListAccessTokensOptions contain filter options type ListAccessTokensOptions struct { - ListOptions + db.ListOptions Name string UserID int64 } @@ -163,7 +163,7 @@ func ListAccessTokens(opts ListAccessTokensOptions) ([]*AccessToken, error) { sess = sess.Desc("id") if opts.Page != 0 { - sess = setSessionPagination(sess, &opts) + sess = db.SetSessionPagination(sess, &opts) tokens := make([]*AccessToken, 0, opts.PageSize) return tokens, sess.Find(&tokens) diff --git a/models/topic.go b/models/topic.go index cf563e9b11c5..6eb8c67b8dd8 100644 --- a/models/topic.go +++ b/models/topic.go @@ -164,7 +164,7 @@ func removeTopicsFromRepo(e db.Engine, repoID int64) error { // FindTopicOptions represents the options when fdin topics type FindTopicOptions struct { - ListOptions + db.ListOptions RepoID int64 Keyword string } @@ -189,7 +189,7 @@ func FindTopics(opts *FindTopicOptions) ([]*Topic, int64, error) { sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") } if opts.PageSize != 0 && opts.Page != 0 { - sess = setSessionPagination(sess, opts) + sess = db.SetSessionPagination(sess, opts) } topics := make([]*Topic, 0, 10) total, err := sess.Desc("topic.repo_count").FindAndCount(&topics) diff --git a/models/topic_test.go b/models/topic_test.go index 9f6352e7e753..b069deaba3c2 100644 --- a/models/topic_test.go +++ b/models/topic_test.go @@ -23,7 +23,7 @@ func TestAddTopic(t *testing.T) { assert.Len(t, topics, totalNrOfTopics) topics, total, err := FindTopics(&FindTopicOptions{ - ListOptions: ListOptions{Page: 1, PageSize: 2}, + ListOptions: db.ListOptions{Page: 1, PageSize: 2}, }) assert.NoError(t, err) assert.Len(t, topics, 2) diff --git a/models/user.go b/models/user.go index fc5d417d3609..fb40a0acc8f2 100644 --- a/models/user.go +++ b/models/user.go @@ -21,6 +21,7 @@ import ( "unicode/utf8" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" @@ -106,7 +107,7 @@ type User struct { // is to change his/her password after registration. MustChangePassword bool `xorm:"NOT NULL DEFAULT false"` - LoginType LoginType + LoginType login.Type LoginSource int64 `xorm:"NOT NULL DEFAULT 0"` LoginName string Type UserType @@ -169,7 +170,7 @@ func init() { // SearchOrganizationsOptions options to filter organizations type SearchOrganizationsOptions struct { - ListOptions + db.ListOptions All bool } @@ -241,12 +242,12 @@ func GetAllUsers() ([]*User, error) { // IsLocal returns true if user login type is LoginPlain. func (u *User) IsLocal() bool { - return u.LoginType <= LoginPlain + return u.LoginType <= login.Plain } // IsOAuth2 returns true if user login type is LoginOAuth2. func (u *User) IsOAuth2() bool { - return u.LoginType == LoginOAuth2 + return u.LoginType == login.OAuth2 } // HasForkedRepo checks if user has already forked a repository with given ID. @@ -331,13 +332,13 @@ func (u *User) GenerateEmailActivateCode(email string) string { } // GetFollowers returns range of user's followers. -func (u *User) GetFollowers(listOptions ListOptions) ([]*User, error) { +func (u *User) GetFollowers(listOptions db.ListOptions) ([]*User, error) { sess := db.GetEngine(db.DefaultContext). Where("follow.follow_id=?", u.ID). Join("LEFT", "follow", "`user`.id=follow.user_id") if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) users := make([]*User, 0, listOptions.PageSize) return users, sess.Find(&users) @@ -353,13 +354,13 @@ func (u *User) IsFollowing(followID int64) bool { } // GetFollowing returns range of user's following. -func (u *User) GetFollowing(listOptions ListOptions) ([]*User, error) { +func (u *User) GetFollowing(listOptions db.ListOptions) ([]*User, error) { sess := db.GetEngine(db.DefaultContext). Where("follow.user_id=?", u.ID). Join("LEFT", "follow", "`user`.id=follow.follow_id") if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) users := make([]*User, 0, listOptions.PageSize) return users, sess.Find(&users) @@ -542,7 +543,7 @@ func (u *User) GetOrganizationCount() (int64, error) { } // GetRepositories returns repositories that user owns, including private repositories. -func (u *User) GetRepositories(listOpts ListOptions, names ...string) (err error) { +func (u *User) GetRepositories(listOpts db.ListOptions, names ...string) (err error) { u.Repos, _, err = GetUserRepositories(&SearchRepoOptions{Actor: u, Private: true, ListOptions: listOpts, LowerNames: names}) return err } @@ -1252,7 +1253,7 @@ func deleteUser(e db.Engine, u *User) error { // ***** END: PublicKey ***** // ***** START: GPGPublicKey ***** - keys, err := listGPGKeys(e, u.ID, ListOptions{}) + keys, err := listGPGKeys(e, u.ID, db.ListOptions{}) if err != nil { return fmt.Errorf("ListGPGKeys: %v", err) } @@ -1488,7 +1489,7 @@ func GetUserIDsByNames(names []string, ignoreNonExistent bool) ([]int64, error) } // GetUsersBySource returns a list of Users for a login source -func GetUsersBySource(s *LoginSource) ([]*User, error) { +func GetUsersBySource(s *login.Source) ([]*User, error) { var users []*User err := db.GetEngine(db.DefaultContext).Where("login_type = ? AND login_source = ?", s.Type, s.ID).Find(&users) return users, err @@ -1592,7 +1593,7 @@ func GetUser(user *User) (bool, error) { // SearchUserOptions contains the options for searching type SearchUserOptions struct { - ListOptions + db.ListOptions Keyword string Type UserType UID int64 @@ -1675,7 +1676,7 @@ func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { sess := db.GetEngine(db.DefaultContext).Where(cond).OrderBy(opts.OrderBy.String()) if opts.Page != 0 { - sess = setSessionPagination(sess, opts) + sess = db.SetSessionPagination(sess, opts) } users = make([]*User, 0, opts.PageSize) @@ -1683,7 +1684,7 @@ func SearchUsers(opts *SearchUserOptions) (users []*User, _ int64, _ error) { } // GetStarredRepos returns the repos starred by a particular user -func GetStarredRepos(userID int64, private bool, listOptions ListOptions) ([]*Repository, error) { +func GetStarredRepos(userID int64, private bool, listOptions db.ListOptions) ([]*Repository, error) { sess := db.GetEngine(db.DefaultContext).Where("star.uid=?", userID). Join("LEFT", "star", "`repository`.id=`star`.repo_id") if !private { @@ -1691,7 +1692,7 @@ func GetStarredRepos(userID int64, private bool, listOptions ListOptions) ([]*Re } if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) repos := make([]*Repository, 0, listOptions.PageSize) return repos, sess.Find(&repos) @@ -1702,7 +1703,7 @@ func GetStarredRepos(userID int64, private bool, listOptions ListOptions) ([]*Re } // GetWatchedRepos returns the repos watched by a particular user -func GetWatchedRepos(userID int64, private bool, listOptions ListOptions) ([]*Repository, int64, error) { +func GetWatchedRepos(userID int64, private bool, listOptions db.ListOptions) ([]*Repository, int64, error) { sess := db.GetEngine(db.DefaultContext).Where("watch.user_id=?", userID). And("`watch`.mode<>?", RepoWatchModeDont). Join("LEFT", "watch", "`repository`.id=`watch`.repo_id") @@ -1711,7 +1712,7 @@ func GetWatchedRepos(userID int64, private bool, listOptions ListOptions) ([]*Re } if listOptions.Page != 0 { - sess = setSessionPagination(sess, &listOptions) + sess = db.SetSessionPagination(sess, &listOptions) repos := make([]*Repository, 0, listOptions.PageSize) total, err := sess.FindAndCount(&repos) diff --git a/models/user_mail.go b/models/user_mail.go index 51d34d26826a..caa931788d5b 100644 --- a/models/user_mail.go +++ b/models/user_mail.go @@ -301,7 +301,7 @@ const ( // SearchEmailOptions are options to search e-mail addresses for the admin panel type SearchEmailOptions struct { - ListOptions + db.ListOptions Keyword string SortType SearchEmailOrderBy IsPrimary util.OptionalBool @@ -357,7 +357,7 @@ func SearchEmails(opts *SearchEmailOptions) ([]*SearchEmailResult, int64, error) orderby = SearchEmailOrderByEmail.String() } - opts.setDefaultValues() + opts.SetDefaultValues() emails := make([]*SearchEmailResult, 0, opts.PageSize) err = db.GetEngine(db.DefaultContext).Table("email_address"). diff --git a/models/user_mail_test.go b/models/user_mail_test.go index 22e5f786bf91..384f28b7bf4e 100644 --- a/models/user_mail_test.go +++ b/models/user_mail_test.go @@ -189,7 +189,7 @@ func TestListEmails(t *testing.T) { // Must find all users and their emails opts := &SearchEmailOptions{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ PageSize: 10000, }, } @@ -241,7 +241,7 @@ func TestListEmails(t *testing.T) { // Must find more than one page, but retrieve only one opts = &SearchEmailOptions{ - ListOptions: ListOptions{ + ListOptions: db.ListOptions{ PageSize: 5, Page: 1, }, diff --git a/models/user_test.go b/models/user_test.go index 6c616a60a902..bf796a8c6252 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -11,6 +11,7 @@ import ( "testing" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/modules/util" @@ -18,6 +19,14 @@ import ( "github.com/stretchr/testify/assert" ) +func TestOAuth2Application_LoadUser(t *testing.T) { + assert.NoError(t, db.PrepareTestDatabase()) + app := db.AssertExistsAndLoadBean(t, &login.OAuth2Application{ID: 1}).(*login.OAuth2Application) + user, err := GetUserByID(app.UID) + assert.NoError(t, err) + assert.NotNil(t, user) +} + func TestUserIsPublicMember(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) @@ -116,19 +125,19 @@ func TestSearchUsers(t *testing.T) { testSuccess(opts, expectedOrgIDs) } - testOrgSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: ListOptions{Page: 1, PageSize: 2}}, + testOrgSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: db.ListOptions{Page: 1, PageSize: 2}}, []int64{3, 6}) - testOrgSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: ListOptions{Page: 2, PageSize: 2}}, + testOrgSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: db.ListOptions{Page: 2, PageSize: 2}}, []int64{7, 17}) - testOrgSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: ListOptions{Page: 3, PageSize: 2}}, + testOrgSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: db.ListOptions{Page: 3, PageSize: 2}}, []int64{19, 25}) - testOrgSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: ListOptions{Page: 4, PageSize: 2}}, + testOrgSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: db.ListOptions{Page: 4, PageSize: 2}}, []int64{26}) - testOrgSuccess(&SearchUserOptions{ListOptions: ListOptions{Page: 5, PageSize: 2}}, + testOrgSuccess(&SearchUserOptions{ListOptions: db.ListOptions{Page: 5, PageSize: 2}}, []int64{}) // test users @@ -137,20 +146,20 @@ func TestSearchUsers(t *testing.T) { testSuccess(opts, expectedUserIDs) } - testUserSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: ListOptions{Page: 1}}, + testUserSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: db.ListOptions{Page: 1}}, []int64{1, 2, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 21, 24, 27, 28, 29, 30}) - testUserSuccess(&SearchUserOptions{ListOptions: ListOptions{Page: 1}, IsActive: util.OptionalBoolFalse}, + testUserSuccess(&SearchUserOptions{ListOptions: db.ListOptions{Page: 1}, IsActive: util.OptionalBoolFalse}, []int64{9}) - testUserSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: ListOptions{Page: 1}, IsActive: util.OptionalBoolTrue}, + testUserSuccess(&SearchUserOptions{OrderBy: "id ASC", ListOptions: db.ListOptions{Page: 1}, IsActive: util.OptionalBoolTrue}, []int64{1, 2, 4, 5, 8, 10, 11, 12, 13, 14, 15, 16, 18, 20, 21, 24, 28, 29, 30}) - testUserSuccess(&SearchUserOptions{Keyword: "user1", OrderBy: "id ASC", ListOptions: ListOptions{Page: 1}, IsActive: util.OptionalBoolTrue}, + testUserSuccess(&SearchUserOptions{Keyword: "user1", OrderBy: "id ASC", ListOptions: db.ListOptions{Page: 1}, IsActive: util.OptionalBoolTrue}, []int64{1, 10, 11, 12, 13, 14, 15, 16, 18}) // order by name asc default - testUserSuccess(&SearchUserOptions{Keyword: "user1", ListOptions: ListOptions{Page: 1}, IsActive: util.OptionalBoolTrue}, + testUserSuccess(&SearchUserOptions{Keyword: "user1", ListOptions: db.ListOptions{Page: 1}, IsActive: util.OptionalBoolTrue}, []int64{1, 10, 11, 12, 13, 14, 15, 16, 18}) } @@ -407,7 +416,7 @@ func TestAddLdapSSHPublicKeys(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) user := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User) - s := &LoginSource{ID: 1} + s := &login.Source{ID: 1} testCases := []struct { keyString string diff --git a/models/webhook.go b/models/webhook.go index 034b37263aad..9d04f8f5e4fb 100644 --- a/models/webhook.go +++ b/models/webhook.go @@ -397,7 +397,7 @@ func GetWebhookByOrgID(orgID, id int64) (*Webhook, error) { // ListWebhookOptions are options to filter webhooks on ListWebhooksByOpts type ListWebhookOptions struct { - ListOptions + db.ListOptions RepoID int64 OrgID int64 IsActive util.OptionalBool @@ -421,7 +421,7 @@ func listWebhooksByOpts(e db.Engine, opts *ListWebhookOptions) ([]*Webhook, erro sess := e.Where(opts.toCond()) if opts.Page != 0 { - sess = setSessionPagination(sess, opts) + sess = db.SetSessionPagination(sess, opts) webhooks := make([]*Webhook, 0, opts.PageSize) err := sess.Find(&webhooks) return webhooks, err diff --git a/modules/convert/convert.go b/modules/convert/convert.go index 404786ec9c9a..3c309a82f861 100644 --- a/modules/convert/convert.go +++ b/modules/convert/convert.go @@ -12,6 +12,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/structs" @@ -338,8 +339,8 @@ func ToTopicResponse(topic *models.Topic) *api.TopicResponse { } } -// ToOAuth2Application convert from models.OAuth2Application to api.OAuth2Application -func ToOAuth2Application(app *models.OAuth2Application) *api.OAuth2Application { +// ToOAuth2Application convert from login.OAuth2Application to api.OAuth2Application +func ToOAuth2Application(app *login.OAuth2Application) *api.OAuth2Application { return &api.OAuth2Application{ ID: app.ID, Name: app.Name, diff --git a/modules/gitgraph/graph_models.go b/modules/gitgraph/graph_models.go index ec47f0ad84ac..86bd8cb237bf 100644 --- a/modules/gitgraph/graph_models.go +++ b/modules/gitgraph/graph_models.go @@ -10,6 +10,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" ) @@ -114,7 +115,7 @@ func (graph *Graph) LoadAndProcessCommits(repository *models.Repository, gitRepo _ = models.CalculateTrustStatus(c.Verification, repository, &keyMap) - statuses, err := models.GetLatestCommitStatus(repository.ID, c.Commit.ID.String(), models.ListOptions{}) + statuses, err := models.GetLatestCommitStatus(repository.ID, c.Commit.ID.String(), db.ListOptions{}) if err != nil { log.Error("GetLatestCommitStatus: %v", err) } else { diff --git a/modules/indexer/issues/indexer.go b/modules/indexer/issues/indexer.go index 676b6686ea5b..4e133b4dd393 100644 --- a/modules/indexer/issues/indexer.go +++ b/modules/indexer/issues/indexer.go @@ -12,6 +12,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/graceful" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/queue" @@ -241,7 +242,7 @@ func populateIssueIndexer(ctx context.Context) { default: } repos, _, err := models.SearchRepositoryByName(&models.SearchRepoOptions{ - ListOptions: models.ListOptions{Page: page, PageSize: models.RepositoryListDefaultPageSize}, + ListOptions: db.ListOptions{Page: page, PageSize: models.RepositoryListDefaultPageSize}, OrderBy: models.SearchOrderByID, Private: true, Collaborate: util.OptionalBoolFalse, diff --git a/modules/migrations/gitea_uploader_test.go b/modules/migrations/gitea_uploader_test.go index 73293f1f8540..b8b947961f4b 100644 --- a/modules/migrations/gitea_uploader_test.go +++ b/modules/migrations/gitea_uploader_test.go @@ -69,12 +69,12 @@ func TestGiteaUploadRepo(t *testing.T) { assert.NoError(t, err) assert.Empty(t, milestones) - labels, err := models.GetLabelsByRepoID(repo.ID, "", models.ListOptions{}) + labels, err := models.GetLabelsByRepoID(repo.ID, "", db.ListOptions{}) assert.NoError(t, err) assert.Len(t, labels, 12) releases, err := models.GetReleasesByRepoID(repo.ID, models.FindReleasesOptions{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ PageSize: 10, Page: 0, }, @@ -84,7 +84,7 @@ func TestGiteaUploadRepo(t *testing.T) { assert.Len(t, releases, 8) releases, err = models.GetReleasesByRepoID(repo.ID, models.FindReleasesOptions{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ PageSize: 10, Page: 0, }, diff --git a/modules/repository/adopt.go b/modules/repository/adopt.go index daefee9c7460..c5c059f4718b 100644 --- a/modules/repository/adopt.go +++ b/modules/repository/adopt.go @@ -122,7 +122,7 @@ func DeleteUnadoptedRepository(doer, u *models.User, repoName string) error { } // ListUnadoptedRepositories lists all the unadopted repositories that match the provided query -func ListUnadoptedRepositories(query string, opts *models.ListOptions) ([]string, int, error) { +func ListUnadoptedRepositories(query string, opts *db.ListOptions) ([]string, int, error) { globUser, _ := glob.Compile("*") globRepo, _ := glob.Compile("*") @@ -165,10 +165,13 @@ func ListUnadoptedRepositories(query string, opts *models.ListOptions) ([]string // Clean up old repoNamesToCheck if len(repoNamesToCheck) > 0 { - repos, _, err := models.GetUserRepositories(&models.SearchRepoOptions{Actor: ctxUser, Private: true, ListOptions: models.ListOptions{ - Page: 1, - PageSize: opts.PageSize, - }, LowerNames: repoNamesToCheck}) + repos, _, err := models.GetUserRepositories(&models.SearchRepoOptions{ + Actor: ctxUser, + Private: true, + ListOptions: db.ListOptions{ + Page: 1, + PageSize: opts.PageSize, + }, LowerNames: repoNamesToCheck}) if err != nil { return err } @@ -219,10 +222,13 @@ func ListUnadoptedRepositories(query string, opts *models.ListOptions) ([]string if count < end { repoNamesToCheck = append(repoNamesToCheck, name) if len(repoNamesToCheck) >= opts.PageSize { - repos, _, err := models.GetUserRepositories(&models.SearchRepoOptions{Actor: ctxUser, Private: true, ListOptions: models.ListOptions{ - Page: 1, - PageSize: opts.PageSize, - }, LowerNames: repoNamesToCheck}) + repos, _, err := models.GetUserRepositories(&models.SearchRepoOptions{ + Actor: ctxUser, + Private: true, + ListOptions: db.ListOptions{ + Page: 1, + PageSize: opts.PageSize, + }, LowerNames: repoNamesToCheck}) if err != nil { return err } @@ -254,10 +260,13 @@ func ListUnadoptedRepositories(query string, opts *models.ListOptions) ([]string } if len(repoNamesToCheck) > 0 { - repos, _, err := models.GetUserRepositories(&models.SearchRepoOptions{Actor: ctxUser, Private: true, ListOptions: models.ListOptions{ - Page: 1, - PageSize: opts.PageSize, - }, LowerNames: repoNamesToCheck}) + repos, _, err := models.GetUserRepositories(&models.SearchRepoOptions{ + Actor: ctxUser, + Private: true, + ListOptions: db.ListOptions{ + Page: 1, + PageSize: opts.PageSize, + }, LowerNames: repoNamesToCheck}) if err != nil { return nil, 0, err } diff --git a/modules/repository/repo.go b/modules/repository/repo.go index ee970fd711ee..6b40a894fb8b 100644 --- a/modules/repository/repo.go +++ b/modules/repository/repo.go @@ -224,7 +224,11 @@ func CleanUpMigrateInfo(repo *models.Repository) (*models.Repository, error) { // SyncReleasesWithTags synchronizes release table with repository tags func SyncReleasesWithTags(repo *models.Repository, gitRepo *git.Repository) error { existingRelTags := make(map[string]struct{}) - opts := models.FindReleasesOptions{IncludeDrafts: true, IncludeTags: true, ListOptions: models.ListOptions{PageSize: 50}} + opts := models.FindReleasesOptions{ + IncludeDrafts: true, + IncludeTags: true, + ListOptions: db.ListOptions{PageSize: 50}, + } for page := 1; ; page++ { opts.Page = page rels, err := models.GetReleasesByRepoID(repo.ID, opts) diff --git a/routers/api/v1/admin/user.go b/routers/api/v1/admin/user.go index e5a75da759ea..2d585b60401b 100644 --- a/routers/api/v1/admin/user.go +++ b/routers/api/v1/admin/user.go @@ -11,6 +11,7 @@ import ( "net/http" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/convert" "code.gitea.io/gitea/modules/log" @@ -27,12 +28,12 @@ func parseLoginSource(ctx *context.APIContext, u *models.User, sourceID int64, l return } - source, err := models.GetLoginSourceByID(sourceID) + source, err := login.GetSourceByID(sourceID) if err != nil { - if models.IsErrLoginSourceNotExist(err) { + if login.IsErrSourceNotExist(err) { ctx.Error(http.StatusUnprocessableEntity, "", err) } else { - ctx.Error(http.StatusInternalServerError, "GetLoginSourceByID", err) + ctx.Error(http.StatusInternalServerError, "login.GetSourceByID", err) } return } @@ -74,7 +75,7 @@ func CreateUser(ctx *context.APIContext) { Passwd: form.Password, MustChangePassword: true, IsActive: true, - LoginType: models.LoginPlain, + LoginType: login.Plain, } if form.MustChangePassword != nil { u.MustChangePassword = *form.MustChangePassword diff --git a/routers/api/v1/repo/issue.go b/routers/api/v1/repo/issue.go index c576d18e5fed..7b97a5683a36 100644 --- a/routers/api/v1/repo/issue.go +++ b/routers/api/v1/repo/issue.go @@ -13,6 +13,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/convert" issue_indexer "code.gitea.io/gitea/modules/indexer/issues" @@ -226,7 +227,7 @@ func SearchIssues(ctx *context.APIContext) { // This would otherwise return all issues if no issues were found by the search. if len(keyword) == 0 || len(issueIDs) > 0 || len(includedLabelNames) > 0 || len(includedMilestones) > 0 { issuesOpt := &models.IssuesOptions{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ Page: ctx.FormInt("page"), PageSize: limit, }, @@ -261,7 +262,7 @@ func SearchIssues(ctx *context.APIContext) { return } - issuesOpt.ListOptions = models.ListOptions{ + issuesOpt.ListOptions = db.ListOptions{ Page: -1, } if filteredCount, err = models.CountIssues(issuesOpt); err != nil { @@ -470,7 +471,7 @@ func ListIssues(ctx *context.APIContext) { return } - issuesOpt.ListOptions = models.ListOptions{ + issuesOpt.ListOptions = db.ListOptions{ Page: -1, } if filteredCount, err = models.CountIssues(issuesOpt); err != nil { diff --git a/routers/api/v1/user/app.go b/routers/api/v1/user/app.go index f0f7cb4159b6..bf45bf4dd5a6 100644 --- a/routers/api/v1/user/app.go +++ b/routers/api/v1/user/app.go @@ -12,6 +12,7 @@ import ( "strconv" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/convert" api "code.gitea.io/gitea/modules/structs" @@ -212,7 +213,7 @@ func CreateOauth2Application(ctx *context.APIContext) { data := web.GetForm(ctx).(*api.CreateOAuth2ApplicationOptions) - app, err := models.CreateOAuth2Application(models.CreateOAuth2ApplicationOptions{ + app, err := login.CreateOAuth2Application(login.CreateOAuth2ApplicationOptions{ Name: data.Name, UserID: ctx.User.ID, RedirectURIs: data.RedirectURIs, @@ -251,7 +252,7 @@ func ListOauth2Applications(ctx *context.APIContext) { // "200": // "$ref": "#/responses/OAuth2ApplicationList" - apps, total, err := models.ListOAuth2Applications(ctx.User.ID, utils.GetListOptions(ctx)) + apps, total, err := login.ListOAuth2Applications(ctx.User.ID, utils.GetListOptions(ctx)) if err != nil { ctx.Error(http.StatusInternalServerError, "ListOAuth2Applications", err) return @@ -287,8 +288,8 @@ func DeleteOauth2Application(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" appID := ctx.ParamsInt64(":id") - if err := models.DeleteOAuth2Application(appID, ctx.User.ID); err != nil { - if models.IsErrOAuthApplicationNotFound(err) { + if err := login.DeleteOAuth2Application(appID, ctx.User.ID); err != nil { + if login.IsErrOAuthApplicationNotFound(err) { ctx.NotFound() } else { ctx.Error(http.StatusInternalServerError, "DeleteOauth2ApplicationByID", err) @@ -319,9 +320,9 @@ func GetOauth2Application(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" appID := ctx.ParamsInt64(":id") - app, err := models.GetOAuth2ApplicationByID(appID) + app, err := login.GetOAuth2ApplicationByID(appID) if err != nil { - if models.IsErrOauthClientIDInvalid(err) || models.IsErrOAuthApplicationNotFound(err) { + if login.IsErrOauthClientIDInvalid(err) || login.IsErrOAuthApplicationNotFound(err) { ctx.NotFound() } else { ctx.Error(http.StatusInternalServerError, "GetOauth2ApplicationByID", err) @@ -362,14 +363,14 @@ func UpdateOauth2Application(ctx *context.APIContext) { data := web.GetForm(ctx).(*api.CreateOAuth2ApplicationOptions) - app, err := models.UpdateOAuth2Application(models.UpdateOAuth2ApplicationOptions{ + app, err := login.UpdateOAuth2Application(login.UpdateOAuth2ApplicationOptions{ Name: data.Name, UserID: ctx.User.ID, ID: appID, RedirectURIs: data.RedirectURIs, }) if err != nil { - if models.IsErrOauthClientIDInvalid(err) || models.IsErrOAuthApplicationNotFound(err) { + if login.IsErrOauthClientIDInvalid(err) || login.IsErrOAuthApplicationNotFound(err) { ctx.NotFound() } else { ctx.Error(http.StatusInternalServerError, "UpdateOauth2ApplicationByID", err) diff --git a/routers/api/v1/user/gpg_key.go b/routers/api/v1/user/gpg_key.go index f32d60d03816..9066268bba29 100644 --- a/routers/api/v1/user/gpg_key.go +++ b/routers/api/v1/user/gpg_key.go @@ -9,6 +9,7 @@ import ( "net/http" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/convert" api "code.gitea.io/gitea/modules/structs" @@ -16,7 +17,7 @@ import ( "code.gitea.io/gitea/routers/api/v1/utils" ) -func listGPGKeys(ctx *context.APIContext, uid int64, listOptions models.ListOptions) { +func listGPGKeys(ctx *context.APIContext, uid int64, listOptions db.ListOptions) { keys, err := models.ListGPGKeys(uid, listOptions) if err != nil { ctx.Error(http.StatusInternalServerError, "ListGPGKeys", err) diff --git a/routers/api/v1/user/star.go b/routers/api/v1/user/star.go index 8ee167685639..f067722bfa82 100644 --- a/routers/api/v1/user/star.go +++ b/routers/api/v1/user/star.go @@ -9,6 +9,7 @@ import ( "net/http" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/convert" api "code.gitea.io/gitea/modules/structs" @@ -17,7 +18,7 @@ import ( // getStarredRepos returns the repos that the user with the specified userID has // starred -func getStarredRepos(user *models.User, private bool, listOptions models.ListOptions) ([]*api.Repository, error) { +func getStarredRepos(user *models.User, private bool, listOptions db.ListOptions) ([]*api.Repository, error) { starredRepos, err := models.GetStarredRepos(user.ID, private, listOptions) if err != nil { return nil, err diff --git a/routers/api/v1/user/watch.go b/routers/api/v1/user/watch.go index f32ce7359864..3c6f8b30704f 100644 --- a/routers/api/v1/user/watch.go +++ b/routers/api/v1/user/watch.go @@ -8,6 +8,7 @@ import ( "net/http" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/convert" api "code.gitea.io/gitea/modules/structs" @@ -15,7 +16,7 @@ import ( ) // getWatchedRepos returns the repos that the user with the specified userID is watching -func getWatchedRepos(user *models.User, private bool, listOptions models.ListOptions) ([]*api.Repository, int64, error) { +func getWatchedRepos(user *models.User, private bool, listOptions db.ListOptions) ([]*api.Repository, int64, error) { watchedRepos, total, err := models.GetWatchedRepos(user.ID, private, listOptions) if err != nil { return nil, 0, err diff --git a/routers/api/v1/utils/utils.go b/routers/api/v1/utils/utils.go index 81f5086c9611..756485711571 100644 --- a/routers/api/v1/utils/utils.go +++ b/routers/api/v1/utils/utils.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/convert" ) @@ -60,8 +60,8 @@ func prepareQueryArg(ctx *context.APIContext, name string) (value string, err er } // GetListOptions returns list options using the page and limit parameters -func GetListOptions(ctx *context.APIContext) models.ListOptions { - return models.ListOptions{ +func GetListOptions(ctx *context.APIContext) db.ListOptions { + return db.ListOptions{ Page: ctx.FormInt("page"), PageSize: convert.ToCorrectPageSize(ctx.FormInt("limit")), } diff --git a/routers/web/admin/auths.go b/routers/web/admin/auths.go index 2937190a1f50..1b005e5c7bce 100644 --- a/routers/web/admin/auths.go +++ b/routers/web/admin/auths.go @@ -11,6 +11,7 @@ import ( "regexp" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/auth/pam" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" @@ -18,6 +19,7 @@ import ( "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/util" "code.gitea.io/gitea/modules/web" + auth_service "code.gitea.io/gitea/services/auth" "code.gitea.io/gitea/services/auth/source/ldap" "code.gitea.io/gitea/services/auth/source/oauth2" pamService "code.gitea.io/gitea/services/auth/source/pam" @@ -46,13 +48,13 @@ func Authentications(ctx *context.Context) { ctx.Data["PageIsAdminAuthentications"] = true var err error - ctx.Data["Sources"], err = models.LoginSources() + ctx.Data["Sources"], err = login.Sources() if err != nil { - ctx.ServerError("LoginSources", err) + ctx.ServerError("login.Sources", err) return } - ctx.Data["Total"] = models.CountLoginSources() + ctx.Data["Total"] = login.CountSources() ctx.HTML(http.StatusOK, tplAuths) } @@ -64,14 +66,14 @@ type dropdownItem struct { var ( authSources = func() []dropdownItem { items := []dropdownItem{ - {models.LoginNames[models.LoginLDAP], models.LoginLDAP}, - {models.LoginNames[models.LoginDLDAP], models.LoginDLDAP}, - {models.LoginNames[models.LoginSMTP], models.LoginSMTP}, - {models.LoginNames[models.LoginOAuth2], models.LoginOAuth2}, - {models.LoginNames[models.LoginSSPI], models.LoginSSPI}, + {login.LDAP.String(), login.LDAP}, + {login.DLDAP.String(), login.DLDAP}, + {login.SMTP.String(), login.SMTP}, + {login.OAuth2.String(), login.OAuth2}, + {login.SSPI.String(), login.SSPI}, } if pam.Supported { - items = append(items, dropdownItem{models.LoginNames[models.LoginPAM], models.LoginPAM}) + items = append(items, dropdownItem{login.Names[login.PAM], login.PAM}) } return items }() @@ -89,8 +91,8 @@ func NewAuthSource(ctx *context.Context) { ctx.Data["PageIsAdmin"] = true ctx.Data["PageIsAdminAuthentications"] = true - ctx.Data["type"] = models.LoginLDAP - ctx.Data["CurrentTypeName"] = models.LoginNames[models.LoginLDAP] + ctx.Data["type"] = login.LDAP + ctx.Data["CurrentTypeName"] = login.Names[login.LDAP] ctx.Data["CurrentSecurityProtocol"] = ldap.SecurityProtocolNames[ldap.SecurityProtocolUnencrypted] ctx.Data["smtp_auth"] = "PLAIN" ctx.Data["is_active"] = true @@ -217,7 +219,7 @@ func NewAuthSourcePost(ctx *context.Context) { ctx.Data["PageIsAdmin"] = true ctx.Data["PageIsAdminAuthentications"] = true - ctx.Data["CurrentTypeName"] = models.LoginNames[models.LoginType(form.Type)] + ctx.Data["CurrentTypeName"] = login.Type(form.Type).String() ctx.Data["CurrentSecurityProtocol"] = ldap.SecurityProtocolNames[ldap.SecurityProtocol(form.SecurityProtocol)] ctx.Data["AuthSources"] = authSources ctx.Data["SecurityProtocols"] = securityProtocols @@ -233,28 +235,28 @@ func NewAuthSourcePost(ctx *context.Context) { hasTLS := false var config convert.Conversion - switch models.LoginType(form.Type) { - case models.LoginLDAP, models.LoginDLDAP: + switch login.Type(form.Type) { + case login.LDAP, login.DLDAP: config = parseLDAPConfig(form) hasTLS = ldap.SecurityProtocol(form.SecurityProtocol) > ldap.SecurityProtocolUnencrypted - case models.LoginSMTP: + case login.SMTP: config = parseSMTPConfig(form) hasTLS = true - case models.LoginPAM: + case login.PAM: config = &pamService.Source{ ServiceName: form.PAMServiceName, EmailDomain: form.PAMEmailDomain, } - case models.LoginOAuth2: + case login.OAuth2: config = parseOAuth2Config(form) - case models.LoginSSPI: + case login.SSPI: var err error config, err = parseSSPIConfig(ctx, form) if err != nil { ctx.RenderWithErr(err.Error(), tplAuthNew, form) return } - existing, err := models.LoginSourcesByType(models.LoginSSPI) + existing, err := login.SourcesByType(login.SSPI) if err != nil || len(existing) > 0 { ctx.Data["Err_Type"] = true ctx.RenderWithErr(ctx.Tr("admin.auths.login_source_of_type_exist"), tplAuthNew, form) @@ -271,18 +273,18 @@ func NewAuthSourcePost(ctx *context.Context) { return } - if err := models.CreateLoginSource(&models.LoginSource{ - Type: models.LoginType(form.Type), + if err := login.CreateSource(&login.Source{ + Type: login.Type(form.Type), Name: form.Name, IsActive: form.IsActive, IsSyncEnabled: form.IsSyncEnabled, Cfg: config, }); err != nil { - if models.IsErrLoginSourceAlreadyExist(err) { + if login.IsErrSourceAlreadyExist(err) { ctx.Data["Err_Name"] = true - ctx.RenderWithErr(ctx.Tr("admin.auths.login_source_exist", err.(models.ErrLoginSourceAlreadyExist).Name), tplAuthNew, form) + ctx.RenderWithErr(ctx.Tr("admin.auths.login_source_exist", err.(login.ErrSourceAlreadyExist).Name), tplAuthNew, form) } else { - ctx.ServerError("CreateSource", err) + ctx.ServerError("login.CreateSource", err) } return } @@ -304,9 +306,9 @@ func EditAuthSource(ctx *context.Context) { oauth2providers := oauth2.GetOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers - source, err := models.GetLoginSourceByID(ctx.ParamsInt64(":authid")) + source, err := login.GetSourceByID(ctx.ParamsInt64(":authid")) if err != nil { - ctx.ServerError("GetLoginSourceByID", err) + ctx.ServerError("login.GetSourceByID", err) return } ctx.Data["Source"] = source @@ -339,9 +341,9 @@ func EditAuthSourcePost(ctx *context.Context) { oauth2providers := oauth2.GetOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers - source, err := models.GetLoginSourceByID(ctx.ParamsInt64(":authid")) + source, err := login.GetSourceByID(ctx.ParamsInt64(":authid")) if err != nil { - ctx.ServerError("GetLoginSourceByID", err) + ctx.ServerError("login.GetSourceByID", err) return } ctx.Data["Source"] = source @@ -353,19 +355,19 @@ func EditAuthSourcePost(ctx *context.Context) { } var config convert.Conversion - switch models.LoginType(form.Type) { - case models.LoginLDAP, models.LoginDLDAP: + switch login.Type(form.Type) { + case login.LDAP, login.DLDAP: config = parseLDAPConfig(form) - case models.LoginSMTP: + case login.SMTP: config = parseSMTPConfig(form) - case models.LoginPAM: + case login.PAM: config = &pamService.Source{ ServiceName: form.PAMServiceName, EmailDomain: form.PAMEmailDomain, } - case models.LoginOAuth2: + case login.OAuth2: config = parseOAuth2Config(form) - case models.LoginSSPI: + case login.SSPI: config, err = parseSSPIConfig(ctx, form) if err != nil { ctx.RenderWithErr(err.Error(), tplAuthEdit, form) @@ -380,7 +382,7 @@ func EditAuthSourcePost(ctx *context.Context) { source.IsActive = form.IsActive source.IsSyncEnabled = form.IsSyncEnabled source.Cfg = config - if err := models.UpdateSource(source); err != nil { + if err := login.UpdateSource(source); err != nil { if models.IsErrOpenIDConnectInitialize(err) { ctx.Flash.Error(err.Error(), true) ctx.HTML(http.StatusOK, tplAuthEdit) @@ -397,17 +399,17 @@ func EditAuthSourcePost(ctx *context.Context) { // DeleteAuthSource response for deleting an auth source func DeleteAuthSource(ctx *context.Context) { - source, err := models.GetLoginSourceByID(ctx.ParamsInt64(":authid")) + source, err := login.GetSourceByID(ctx.ParamsInt64(":authid")) if err != nil { - ctx.ServerError("GetLoginSourceByID", err) + ctx.ServerError("login.GetSourceByID", err) return } - if err = models.DeleteSource(source); err != nil { - if models.IsErrLoginSourceInUse(err) { + if err = auth_service.DeleteLoginSource(source); err != nil { + if login.IsErrSourceInUse(err) { ctx.Flash.Error(ctx.Tr("admin.auths.still_in_used")) } else { - ctx.Flash.Error(fmt.Sprintf("DeleteSource: %v", err)) + ctx.Flash.Error(fmt.Sprintf("DeleteLoginSource: %v", err)) } ctx.JSON(http.StatusOK, map[string]interface{}{ "redirect": setting.AppSubURL + "/admin/auths/" + ctx.Params(":authid"), diff --git a/routers/web/admin/emails.go b/routers/web/admin/emails.go index 017d696e202d..5cbe70020b70 100644 --- a/routers/web/admin/emails.go +++ b/routers/web/admin/emails.go @@ -10,6 +10,7 @@ import ( "net/url" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" @@ -28,7 +29,7 @@ func Emails(ctx *context.Context) { ctx.Data["PageIsAdminEmails"] = true opts := &models.SearchEmailOptions{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ PageSize: setting.UI.Admin.UserPagingNum, Page: ctx.FormInt("page"), }, diff --git a/routers/web/admin/orgs.go b/routers/web/admin/orgs.go index a2b3ed1bcc0f..df3118b60f22 100644 --- a/routers/web/admin/orgs.go +++ b/routers/web/admin/orgs.go @@ -7,6 +7,7 @@ package admin import ( "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/setting" @@ -27,7 +28,7 @@ func Organizations(ctx *context.Context) { explore.RenderUserSearch(ctx, &models.SearchUserOptions{ Actor: ctx.User, Type: models.UserTypeOrganization, - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ PageSize: setting.UI.Admin.OrgPagingNum, }, Visible: []structs.VisibleType{structs.VisibleTypePublic, structs.VisibleTypeLimited, structs.VisibleTypePrivate}, diff --git a/routers/web/admin/repos.go b/routers/web/admin/repos.go index 4c3f2ad614ee..2f4d182af8fe 100644 --- a/routers/web/admin/repos.go +++ b/routers/web/admin/repos.go @@ -10,6 +10,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" @@ -68,7 +69,7 @@ func UnadoptedRepos(ctx *context.Context) { ctx.Data["PageIsAdmin"] = true ctx.Data["PageIsAdminRepositories"] = true - opts := models.ListOptions{ + opts := db.ListOptions{ PageSize: setting.UI.Admin.UserPagingNum, Page: ctx.FormInt("page"), } diff --git a/routers/web/admin/users.go b/routers/web/admin/users.go index acccc516bb45..2556cae3a87a 100644 --- a/routers/web/admin/users.go +++ b/routers/web/admin/users.go @@ -12,6 +12,8 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" @@ -39,7 +41,7 @@ func Users(ctx *context.Context) { explore.RenderUserSearch(ctx, &models.SearchUserOptions{ Actor: ctx.User, Type: models.UserTypeIndividual, - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ PageSize: setting.UI.Admin.UserPagingNum, }, SearchByEmail: true, @@ -56,9 +58,9 @@ func NewUser(ctx *context.Context) { ctx.Data["login_type"] = "0-0" - sources, err := models.LoginSources() + sources, err := login.Sources() if err != nil { - ctx.ServerError("LoginSources", err) + ctx.ServerError("login.Sources", err) return } ctx.Data["Sources"] = sources @@ -75,9 +77,9 @@ func NewUserPost(ctx *context.Context) { ctx.Data["PageIsAdminUsers"] = true ctx.Data["DefaultUserVisibilityMode"] = setting.Service.DefaultUserVisibilityMode - sources, err := models.LoginSources() + sources, err := login.Sources() if err != nil { - ctx.ServerError("LoginSources", err) + ctx.ServerError("login.Sources", err) return } ctx.Data["Sources"] = sources @@ -94,19 +96,19 @@ func NewUserPost(ctx *context.Context) { Email: form.Email, Passwd: form.Password, IsActive: true, - LoginType: models.LoginPlain, + LoginType: login.Plain, } if len(form.LoginType) > 0 { fields := strings.Split(form.LoginType, "-") if len(fields) == 2 { lType, _ := strconv.ParseInt(fields[0], 10, 0) - u.LoginType = models.LoginType(lType) + u.LoginType = login.Type(lType) u.LoginSource, _ = strconv.ParseInt(fields[1], 10, 64) u.LoginName = form.LoginName } } - if u.LoginType == models.LoginNoType || u.LoginType == models.LoginPlain { + if u.LoginType == login.NoType || u.LoginType == login.Plain { if len(form.Password) < setting.MinPasswordLength { ctx.Data["Err_Password"] = true ctx.RenderWithErr(ctx.Tr("auth.password_too_short", setting.MinPasswordLength), tplUserNew, &form) @@ -176,18 +178,18 @@ func prepareUserInfo(ctx *context.Context) *models.User { ctx.Data["User"] = u if u.LoginSource > 0 { - ctx.Data["LoginSource"], err = models.GetLoginSourceByID(u.LoginSource) + ctx.Data["LoginSource"], err = login.GetSourceByID(u.LoginSource) if err != nil { - ctx.ServerError("GetLoginSourceByID", err) + ctx.ServerError("login.GetSourceByID", err) return nil } } else { - ctx.Data["LoginSource"] = &models.LoginSource{} + ctx.Data["LoginSource"] = &login.Source{} } - sources, err := models.LoginSources() + sources, err := login.Sources() if err != nil { - ctx.ServerError("LoginSources", err) + ctx.ServerError("login.Sources", err) return nil } ctx.Data["Sources"] = sources @@ -247,7 +249,7 @@ func EditUserPost(ctx *context.Context) { if u.LoginSource != loginSource { u.LoginSource = loginSource - u.LoginType = models.LoginType(loginType) + u.LoginType = login.Type(loginType) } } diff --git a/routers/web/events/events.go b/routers/web/events/events.go index a630d9c224e8..974aa755d16c 100644 --- a/routers/web/events/events.go +++ b/routers/web/events/events.go @@ -9,6 +9,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/convert" "code.gitea.io/gitea/modules/eventsource" @@ -93,7 +94,7 @@ loop: go unregister() break loop case <-stopwatchTimer.C: - sws, err := models.GetUserStopwatches(ctx.User.ID, models.ListOptions{}) + sws, err := models.GetUserStopwatches(ctx.User.ID, db.ListOptions{}) if err != nil { log.Error("Unable to GetUserStopwatches: %v", err) continue diff --git a/routers/web/explore/org.go b/routers/web/explore/org.go index 470e0eb8530b..d005cfa50322 100644 --- a/routers/web/explore/org.go +++ b/routers/web/explore/org.go @@ -6,6 +6,7 @@ package explore import ( "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/setting" @@ -33,7 +34,7 @@ func Organizations(ctx *context.Context) { RenderUserSearch(ctx, &models.SearchUserOptions{ Actor: ctx.User, Type: models.UserTypeOrganization, - ListOptions: models.ListOptions{PageSize: setting.UI.ExplorePagingNum}, + ListOptions: db.ListOptions{PageSize: setting.UI.ExplorePagingNum}, Visible: visibleTypes, }, tplExploreOrganizations) } diff --git a/routers/web/explore/repo.go b/routers/web/explore/repo.go index dfc6261b33ed..78035037e510 100644 --- a/routers/web/explore/repo.go +++ b/routers/web/explore/repo.go @@ -8,6 +8,7 @@ import ( "net/http" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/setting" @@ -77,7 +78,7 @@ func RenderRepoSearch(ctx *context.Context, opts *RepoSearchOptions) { ctx.Data["TopicOnly"] = topicOnly repos, count, err = models.SearchRepository(&models.SearchRepoOptions{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ Page: page, PageSize: opts.PageSize, }, diff --git a/routers/web/explore/user.go b/routers/web/explore/user.go index aeaaf92c1221..4ddb90132d16 100644 --- a/routers/web/explore/user.go +++ b/routers/web/explore/user.go @@ -9,6 +9,7 @@ import ( "net/http" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/setting" @@ -99,7 +100,7 @@ func Users(ctx *context.Context) { RenderUserSearch(ctx, &models.SearchUserOptions{ Actor: ctx.User, Type: models.UserTypeIndividual, - ListOptions: models.ListOptions{PageSize: setting.UI.ExplorePagingNum}, + ListOptions: db.ListOptions{PageSize: setting.UI.ExplorePagingNum}, IsActive: util.OptionalBoolTrue, Visible: []structs.VisibleType{structs.VisibleTypePublic, structs.VisibleTypeLimited, structs.VisibleTypePrivate}, }, tplExploreUsers) diff --git a/routers/web/org/home.go b/routers/web/org/home.go index f682dc5cb697..89bd12a18f78 100644 --- a/routers/web/org/home.go +++ b/routers/web/org/home.go @@ -8,6 +8,7 @@ import ( "net/http" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/markup" @@ -91,7 +92,7 @@ func Home(ctx *context.Context) { err error ) repos, count, err = models.SearchRepository(&models.SearchRepoOptions{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ PageSize: setting.UI.User.RepoPagingNum, Page: page, }, @@ -110,7 +111,7 @@ func Home(ctx *context.Context) { var opts = &models.FindOrgMembersOpts{ OrgID: org.ID, PublicOnly: true, - ListOptions: models.ListOptions{Page: 1, PageSize: 25}, + ListOptions: db.ListOptions{Page: 1, PageSize: 25}, } if ctx.User != nil { diff --git a/routers/web/org/org_labels.go b/routers/web/org/org_labels.go index 13728a31b30a..5079d9baa71d 100644 --- a/routers/web/org/org_labels.go +++ b/routers/web/org/org_labels.go @@ -16,7 +16,7 @@ import ( // RetrieveLabels find all the labels of an organization func RetrieveLabels(ctx *context.Context) { - labels, err := models.GetLabelsByOrgID(ctx.Org.Organization.ID, ctx.FormString("sort"), models.ListOptions{}) + labels, err := models.GetLabelsByOrgID(ctx.Org.Organization.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { ctx.ServerError("RetrieveLabels.GetLabels", err) return diff --git a/routers/web/org/setting.go b/routers/web/org/setting.go index 7e6fc5bf4cd9..277ff9d97359 100644 --- a/routers/web/org/setting.go +++ b/routers/web/org/setting.go @@ -10,6 +10,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" @@ -103,7 +104,7 @@ func SettingsPost(ctx *context.Context) { // update forks visibility if visibilityChanged { - if err := org.GetRepositories(models.ListOptions{Page: 1, PageSize: org.NumRepos}); err != nil { + if err := org.GetRepositories(db.ListOptions{Page: 1, PageSize: org.NumRepos}); err != nil { ctx.ServerError("GetRepositories", err) return } diff --git a/routers/web/repo/commit.go b/routers/web/repo/commit.go index 810581640cc2..61435527da1b 100644 --- a/routers/web/repo/commit.go +++ b/routers/web/repo/commit.go @@ -12,6 +12,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/charset" "code.gitea.io/gitea/modules/context" @@ -287,7 +288,7 @@ func Diff(ctx *context.Context) { commitID = commit.ID.String() } - statuses, err := models.GetLatestCommitStatus(ctx.Repo.Repository.ID, commitID, models.ListOptions{}) + statuses, err := models.GetLatestCommitStatus(ctx.Repo.Repository.ID, commitID, db.ListOptions{}) if err != nil { log.Error("GetLatestCommitStatus: %v", err) } diff --git a/routers/web/repo/issue.go b/routers/web/repo/issue.go index 013286f2de90..7498830d94cc 100644 --- a/routers/web/repo/issue.go +++ b/routers/web/repo/issue.go @@ -16,6 +16,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/convert" @@ -216,7 +217,7 @@ func issues(ctx *context.Context, milestoneID, projectID int64, isPullOption uti issues = []*models.Issue{} } else { issues, err = models.Issues(&models.IssuesOptions{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ Page: pager.Paginater.Current(), PageSize: setting.UI.IssuePagingNum, }, @@ -278,14 +279,14 @@ func issues(ctx *context.Context, milestoneID, projectID int64, isPullOption uti return } - labels, err := models.GetLabelsByRepoID(repo.ID, "", models.ListOptions{}) + labels, err := models.GetLabelsByRepoID(repo.ID, "", db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByRepoID", err) return } if repo.Owner.IsOrganization() { - orgLabels, err := models.GetLabelsByOrgID(repo.Owner.ID, ctx.FormString("sort"), models.ListOptions{}) + orgLabels, err := models.GetLabelsByOrgID(repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByOrgID", err) return @@ -645,14 +646,14 @@ func RetrieveRepoMetas(ctx *context.Context, repo *models.Repository, isPull boo return nil } - labels, err := models.GetLabelsByRepoID(repo.ID, "", models.ListOptions{}) + labels, err := models.GetLabelsByRepoID(repo.ID, "", db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByRepoID", err) return nil } ctx.Data["Labels"] = labels if repo.Owner.IsOrganization() { - orgLabels, err := models.GetLabelsByOrgID(repo.Owner.ID, ctx.FormString("sort"), models.ListOptions{}) + orgLabels, err := models.GetLabelsByOrgID(repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { return nil } @@ -735,10 +736,10 @@ func setTemplateIfExists(ctx *context.Context, ctxDataKey string, possibleDirs [ ctx.Data[issueTemplateTitleKey] = meta.Title ctx.Data[ctxDataKey] = templateBody labelIDs := make([]string, 0, len(meta.Labels)) - if repoLabels, err := models.GetLabelsByRepoID(ctx.Repo.Repository.ID, "", models.ListOptions{}); err == nil { + if repoLabels, err := models.GetLabelsByRepoID(ctx.Repo.Repository.ID, "", db.ListOptions{}); err == nil { ctx.Data["Labels"] = repoLabels if ctx.Repo.Owner.IsOrganization() { - if orgLabels, err := models.GetLabelsByOrgID(ctx.Repo.Owner.ID, ctx.FormString("sort"), models.ListOptions{}); err == nil { + if orgLabels, err := models.GetLabelsByOrgID(ctx.Repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}); err == nil { ctx.Data["OrgLabels"] = orgLabels repoLabels = append(repoLabels, orgLabels...) } @@ -1164,7 +1165,7 @@ func ViewIssue(ctx *context.Context) { for i := range issue.Labels { labelIDMark[issue.Labels[i].ID] = true } - labels, err := models.GetLabelsByRepoID(repo.ID, "", models.ListOptions{}) + labels, err := models.GetLabelsByRepoID(repo.ID, "", db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByRepoID", err) return @@ -1172,7 +1173,7 @@ func ViewIssue(ctx *context.Context) { ctx.Data["Labels"] = labels if repo.Owner.IsOrganization() { - orgLabels, err := models.GetLabelsByOrgID(repo.Owner.ID, ctx.FormString("sort"), models.ListOptions{}) + orgLabels, err := models.GetLabelsByOrgID(repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByOrgID", err) return diff --git a/routers/web/repo/issue_label.go b/routers/web/repo/issue_label.go index 0ce511448547..b97f57175b4c 100644 --- a/routers/web/repo/issue_label.go +++ b/routers/web/repo/issue_label.go @@ -54,7 +54,7 @@ func InitializeLabels(ctx *context.Context) { // RetrieveLabels find all the labels of a repository and organization func RetrieveLabels(ctx *context.Context) { - labels, err := models.GetLabelsByRepoID(ctx.Repo.Repository.ID, ctx.FormString("sort"), models.ListOptions{}) + labels, err := models.GetLabelsByRepoID(ctx.Repo.Repository.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { ctx.ServerError("RetrieveLabels.GetLabels", err) return @@ -67,7 +67,7 @@ func RetrieveLabels(ctx *context.Context) { ctx.Data["Labels"] = labels if ctx.Repo.Owner.IsOrganization() { - orgLabels, err := models.GetLabelsByOrgID(ctx.Repo.Owner.ID, ctx.FormString("sort"), models.ListOptions{}) + orgLabels, err := models.GetLabelsByOrgID(ctx.Repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByOrgID", err) return diff --git a/routers/web/repo/milestone.go b/routers/web/repo/milestone.go index 80f1eb52318a..21e1fb2eab8d 100644 --- a/routers/web/repo/milestone.go +++ b/routers/web/repo/milestone.go @@ -9,6 +9,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/markup" @@ -59,7 +60,7 @@ func Milestones(ctx *context.Context) { } miles, total, err := models.GetMilestones(models.GetMilestonesOption{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ Page: page, PageSize: setting.UI.IssuePagingNum, }, diff --git a/routers/web/repo/pull.go b/routers/web/repo/pull.go index 6b369195de35..c370e7f04d63 100644 --- a/routers/web/repo/pull.go +++ b/routers/web/repo/pull.go @@ -16,6 +16,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/git" @@ -335,7 +336,7 @@ func PrepareMergedViewPullInfo(ctx *context.Context, issue *models.Issue) *git.C if len(compareInfo.Commits) != 0 { sha := compareInfo.Commits[0].ID.String() - commitStatuses, err := models.GetLatestCommitStatus(ctx.Repo.Repository.ID, sha, models.ListOptions{}) + commitStatuses, err := models.GetLatestCommitStatus(ctx.Repo.Repository.ID, sha, db.ListOptions{}) if err != nil { ctx.ServerError("GetLatestCommitStatus", err) return nil @@ -389,7 +390,7 @@ func PrepareViewPullInfo(ctx *context.Context, issue *models.Issue) *git.Compare ctx.ServerError(fmt.Sprintf("GetRefCommitID(%s)", pull.GetGitRefName()), err) return nil } - commitStatuses, err := models.GetLatestCommitStatus(repo.ID, sha, models.ListOptions{}) + commitStatuses, err := models.GetLatestCommitStatus(repo.ID, sha, db.ListOptions{}) if err != nil { ctx.ServerError("GetLatestCommitStatus", err) return nil @@ -478,7 +479,7 @@ func PrepareViewPullInfo(ctx *context.Context, issue *models.Issue) *git.Compare return nil } - commitStatuses, err := models.GetLatestCommitStatus(repo.ID, sha, models.ListOptions{}) + commitStatuses, err := models.GetLatestCommitStatus(repo.ID, sha, db.ListOptions{}) if err != nil { ctx.ServerError("GetLatestCommitStatus", err) return nil diff --git a/routers/web/repo/release.go b/routers/web/repo/release.go index 0603f0ee9745..df1fd745d867 100644 --- a/routers/web/repo/release.go +++ b/routers/web/repo/release.go @@ -11,6 +11,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" @@ -83,7 +84,7 @@ func releasesOrTags(ctx *context.Context, isTagList bool) { ctx.Data["PageIsTagList"] = false } - listOptions := models.ListOptions{ + listOptions := db.ListOptions{ Page: ctx.FormInt("page"), PageSize: ctx.FormInt("limit"), } diff --git a/routers/web/repo/setting.go b/routers/web/repo/setting.go index ed82c2eeb5f8..e71a5bf482bb 100644 --- a/routers/web/repo/setting.go +++ b/routers/web/repo/setting.go @@ -15,6 +15,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/git" @@ -768,7 +769,7 @@ func Collaboration(ctx *context.Context) { ctx.Data["Title"] = ctx.Tr("repo.settings") ctx.Data["PageIsSettingsCollaboration"] = true - users, err := ctx.Repo.Repository.GetCollaborators(models.ListOptions{}) + users, err := ctx.Repo.Repository.GetCollaborators(db.ListOptions{}) if err != nil { ctx.ServerError("GetCollaborators", err) return diff --git a/routers/web/repo/view.go b/routers/web/repo/view.go index addde15de153..c0a35bcb4f05 100644 --- a/routers/web/repo/view.go +++ b/routers/web/repo/view.go @@ -18,6 +18,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/cache" "code.gitea.io/gitea/modules/charset" @@ -377,7 +378,7 @@ func renderDirectory(ctx *context.Context, treeLink string) { ctx.Data["LatestCommitUser"] = models.ValidateCommitWithEmail(latestCommit) - statuses, err := models.GetLatestCommitStatus(ctx.Repo.Repository.ID, ctx.Repo.Commit.ID.String(), models.ListOptions{}) + statuses, err := models.GetLatestCommitStatus(ctx.Repo.Repository.ID, ctx.Repo.Commit.ID.String(), db.ListOptions{}) if err != nil { log.Error("GetLatestCommitStatus: %v", err) } @@ -758,7 +759,7 @@ func renderCode(ctx *context.Context) { } // RenderUserCards render a page show users according the input template -func RenderUserCards(ctx *context.Context, total int, getter func(opts models.ListOptions) ([]*models.User, error), tpl base.TplName) { +func RenderUserCards(ctx *context.Context, total int, getter func(opts db.ListOptions) ([]*models.User, error), tpl base.TplName) { page := ctx.FormInt("page") if page <= 0 { page = 1 @@ -766,7 +767,7 @@ func RenderUserCards(ctx *context.Context, total int, getter func(opts models.Li pager := context.NewPagination(total, models.ItemsPerPage, page, 5) ctx.Data["Page"] = pager - items, err := getter(models.ListOptions{ + items, err := getter(db.ListOptions{ Page: pager.Paginater.Current(), PageSize: models.ItemsPerPage, }) @@ -801,7 +802,7 @@ func Forks(ctx *context.Context) { ctx.Data["Title"] = ctx.Tr("repos.forks") // TODO: need pagination - forks, err := ctx.Repo.Repository.GetForks(models.ListOptions{}) + forks, err := ctx.Repo.Repository.GetForks(db.ListOptions{}) if err != nil { ctx.ServerError("GetForks", err) return diff --git a/routers/web/user/auth.go b/routers/web/user/auth.go index 9785ca68d51c..733ace81b02a 100644 --- a/routers/web/user/auth.go +++ b/routers/web/user/auth.go @@ -14,6 +14,7 @@ import ( "code.gitea.io/gitea/models" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/eventsource" @@ -147,7 +148,7 @@ func SignIn(ctx *context.Context) { ctx.Data["SignInLink"] = setting.AppSubURL + "/user/login" ctx.Data["PageIsSignIn"] = true ctx.Data["PageIsLogin"] = true - ctx.Data["EnableSSPI"] = models.IsSSPIEnabled() + ctx.Data["EnableSSPI"] = login.IsSSPIEnabled() ctx.HTML(http.StatusOK, tplSignIn) } @@ -167,7 +168,7 @@ func SignInPost(ctx *context.Context) { ctx.Data["SignInLink"] = setting.AppSubURL + "/user/login" ctx.Data["PageIsSignIn"] = true ctx.Data["PageIsLogin"] = true - ctx.Data["EnableSSPI"] = models.IsSSPIEnabled() + ctx.Data["EnableSSPI"] = login.IsSSPIEnabled() if ctx.HasError() { ctx.HTML(http.StatusOK, tplSignIn) @@ -573,7 +574,7 @@ func handleSignInFull(ctx *context.Context, u *models.User, remember bool, obeyR func SignInOAuth(ctx *context.Context) { provider := ctx.Params(":provider") - loginSource, err := models.GetActiveOAuth2LoginSourceByName(provider) + loginSource, err := login.GetActiveOAuth2LoginSourceByName(provider) if err != nil { ctx.ServerError("SignIn", err) return @@ -608,7 +609,7 @@ func SignInOAuthCallback(ctx *context.Context) { provider := ctx.Params(":provider") // first look if the provider is still active - loginSource, err := models.GetActiveOAuth2LoginSourceByName(provider) + loginSource, err := login.GetActiveOAuth2LoginSourceByName(provider) if err != nil { ctx.ServerError("SignIn", err) return @@ -653,7 +654,7 @@ func SignInOAuthCallback(ctx *context.Context) { FullName: gothUser.Name, Email: gothUser.Email, IsActive: !setting.OAuth2Client.RegisterEmailConfirm, - LoginType: models.LoginOAuth2, + LoginType: login.OAuth2, LoginSource: loginSource.ID, LoginName: gothUser.UserID, } @@ -711,7 +712,7 @@ func updateAvatarIfNeed(url string, u *models.User) { } } -func handleOAuth2SignIn(ctx *context.Context, source *models.LoginSource, u *models.User, gothUser goth.User) { +func handleOAuth2SignIn(ctx *context.Context, source *login.Source, u *models.User, gothUser goth.User) { updateAvatarIfNeed(gothUser.AvatarURL, u) needs2FA := false @@ -785,7 +786,7 @@ func handleOAuth2SignIn(ctx *context.Context, source *models.LoginSource, u *mod // OAuth2UserLoginCallback attempts to handle the callback from the OAuth2 provider and if successful // login the user -func oAuth2UserLoginCallback(loginSource *models.LoginSource, request *http.Request, response http.ResponseWriter) (*models.User, goth.User, error) { +func oAuth2UserLoginCallback(loginSource *login.Source, request *http.Request, response http.ResponseWriter) (*models.User, goth.User, error) { gothUser, err := loginSource.Cfg.(*oauth2.Source).Callback(request, response) if err != nil { if err.Error() == "securecookie: the value is too long" { @@ -797,7 +798,7 @@ func oAuth2UserLoginCallback(loginSource *models.LoginSource, request *http.Requ user := &models.User{ LoginName: gothUser.UserID, - LoginType: models.LoginOAuth2, + LoginType: login.OAuth2, LoginSource: loginSource.ID, } @@ -1068,7 +1069,7 @@ func LinkAccountPostRegister(ctx *context.Context) { } } - loginSource, err := models.GetActiveOAuth2LoginSourceByName(gothUser.Provider) + loginSource, err := login.GetActiveOAuth2LoginSourceByName(gothUser.Provider) if err != nil { ctx.ServerError("CreateUser", err) } @@ -1078,7 +1079,7 @@ func LinkAccountPostRegister(ctx *context.Context) { Email: form.Email, Passwd: form.Password, IsActive: !(setting.Service.RegisterEmailConfirm || setting.Service.RegisterManualConfirm), - LoginType: models.LoginOAuth2, + LoginType: login.OAuth2, LoginSource: loginSource.ID, LoginName: gothUser.UserID, } diff --git a/routers/web/user/home.go b/routers/web/user/home.go index bb75558dc85a..2f1fca452711 100644 --- a/routers/web/user/home.go +++ b/routers/web/user/home.go @@ -15,6 +15,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" issue_indexer "code.gitea.io/gitea/modules/indexer/issues" @@ -846,7 +847,7 @@ func repoIDMap(ctxUser *models.User, issueCountByRepo map[int64]int64, unitType // ShowSSHKeys output all the ssh keys of user by uid func ShowSSHKeys(ctx *context.Context, uid int64) { - keys, err := models.ListPublicKeys(uid, models.ListOptions{}) + keys, err := models.ListPublicKeys(uid, db.ListOptions{}) if err != nil { ctx.ServerError("ListPublicKeys", err) return @@ -862,7 +863,7 @@ func ShowSSHKeys(ctx *context.Context, uid int64) { // ShowGPGKeys output all the public GPG keys of user by uid func ShowGPGKeys(ctx *context.Context, uid int64) { - keys, err := models.ListGPGKeys(uid, models.ListOptions{}) + keys, err := models.ListGPGKeys(uid, db.ListOptions{}) if err != nil { ctx.ServerError("ListGPGKeys", err) return diff --git a/routers/web/user/oauth.go b/routers/web/user/oauth.go index cec6a92bbea4..d9fc5eeaf923 100644 --- a/routers/web/user/oauth.go +++ b/routers/web/user/oauth.go @@ -13,6 +13,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/json" @@ -115,7 +116,7 @@ type AccessTokenResponse struct { IDToken string `json:"id_token,omitempty"` } -func newAccessTokenResponse(grant *models.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) { +func newAccessTokenResponse(grant *login.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) { if setting.OAuth2.InvalidateRefreshTokens { if err := grant.IncreaseCounter(); err != nil { return nil, &AccessTokenError{ @@ -162,7 +163,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, serverKey, clientKey oaut // generate OpenID Connect id_token signedIDToken := "" if grant.ScopeContains("openid") { - app, err := models.GetOAuth2ApplicationByID(grant.ApplicationID) + app, err := login.GetOAuth2ApplicationByID(grant.ApplicationID) if err != nil { return nil, &AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidRequest, @@ -268,9 +269,9 @@ func IntrospectOAuth(ctx *context.Context) { token, err := oauth2.ParseToken(form.Token, oauth2.DefaultSigningKey) if err == nil { if token.Valid() == nil { - grant, err := models.GetOAuth2GrantByID(token.GrantID) + grant, err := login.GetOAuth2GrantByID(token.GrantID) if err == nil && grant != nil { - app, err := models.GetOAuth2ApplicationByID(grant.ApplicationID) + app, err := login.GetOAuth2ApplicationByID(grant.ApplicationID) if err == nil && app != nil { response.Active = true response.Scope = grant.Scope @@ -299,9 +300,9 @@ func AuthorizeOAuth(ctx *context.Context) { return } - app, err := models.GetOAuth2ApplicationByClientID(form.ClientID) + app, err := login.GetOAuth2ApplicationByClientID(form.ClientID) if err != nil { - if models.IsErrOauthClientIDInvalid(err) { + if login.IsErrOauthClientIDInvalid(err) { handleAuthorizeError(ctx, AuthorizeError{ ErrorCode: ErrorCodeUnauthorizedClient, ErrorDescription: "Client ID not registered", @@ -312,8 +313,10 @@ func AuthorizeOAuth(ctx *context.Context) { ctx.ServerError("GetOAuth2ApplicationByClientID", err) return } - if err := app.LoadUser(); err != nil { - ctx.ServerError("LoadUser", err) + + user, err := models.GetUserByID(app.UID) + if err != nil { + ctx.ServerError("GetUserByID", err) return } @@ -406,7 +409,7 @@ func AuthorizeOAuth(ctx *context.Context) { ctx.Data["State"] = form.State ctx.Data["Scope"] = form.Scope ctx.Data["Nonce"] = form.Nonce - ctx.Data["ApplicationUserLink"] = "@" + html.EscapeString(app.User.Name) + "" + ctx.Data["ApplicationUserLink"] = "@" + html.EscapeString(user.Name) + "" ctx.Data["ApplicationRedirectDomainHTML"] = "" + html.EscapeString(form.RedirectURI) + "" // TODO document SESSION <=> FORM err = ctx.Session.Set("client_id", app.ClientID) @@ -443,7 +446,7 @@ func GrantApplicationOAuth(ctx *context.Context) { ctx.Error(http.StatusBadRequest) return } - app, err := models.GetOAuth2ApplicationByClientID(form.ClientID) + app, err := login.GetOAuth2ApplicationByClientID(form.ClientID) if err != nil { ctx.ServerError("GetOAuth2ApplicationByClientID", err) return @@ -581,7 +584,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, server return } // get grant before increasing counter - grant, err := models.GetOAuth2GrantByID(token.GrantID) + grant, err := login.GetOAuth2GrantByID(token.GrantID) if err != nil || grant == nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidGrant, @@ -608,7 +611,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, server } func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) { - app, err := models.GetOAuth2ApplicationByClientID(form.ClientID) + app, err := login.GetOAuth2ApplicationByClientID(form.ClientID) if err != nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidClient, @@ -630,7 +633,7 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s }) return } - authorizationCode, err := models.GetOAuth2AuthorizationByCode(form.Code) + authorizationCode, err := login.GetOAuth2AuthorizationByCode(form.Code) if err != nil || authorizationCode == nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeUnauthorizedClient, diff --git a/routers/web/user/oauth_test.go b/routers/web/user/oauth_test.go index 09abf1ee2a68..27d339b778ea 100644 --- a/routers/web/user/oauth_test.go +++ b/routers/web/user/oauth_test.go @@ -9,13 +9,14 @@ import ( "code.gitea.io/gitea/models" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/services/auth/source/oauth2" "github.com/golang-jwt/jwt" "github.com/stretchr/testify/assert" ) -func createAndParseToken(t *testing.T, grant *models.OAuth2Grant) *oauth2.OIDCToken { +func createAndParseToken(t *testing.T, grant *login.OAuth2Grant) *oauth2.OIDCToken { signingKey, err := oauth2.CreateJWTSigningKey("HS256", make([]byte, 32)) assert.NoError(t, err) assert.NotNil(t, signingKey) @@ -42,7 +43,7 @@ func createAndParseToken(t *testing.T, grant *models.OAuth2Grant) *oauth2.OIDCTo func TestNewAccessTokenResponse_OIDCToken(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) - grants, err := models.GetOAuth2GrantsByUserID(3) + grants, err := login.GetOAuth2GrantsByUserID(3) assert.NoError(t, err) assert.Len(t, grants, 1) @@ -58,7 +59,7 @@ func TestNewAccessTokenResponse_OIDCToken(t *testing.T) { assert.False(t, oidcToken.EmailVerified) user := db.AssertExistsAndLoadBean(t, &models.User{ID: 5}).(*models.User) - grants, err = models.GetOAuth2GrantsByUserID(user.ID) + grants, err = login.GetOAuth2GrantsByUserID(user.ID) assert.NoError(t, err) assert.Len(t, grants, 1) diff --git a/routers/web/user/profile.go b/routers/web/user/profile.go index 9ecdc2345c5e..d64d5621dead 100644 --- a/routers/web/user/profile.go +++ b/routers/web/user/profile.go @@ -12,6 +12,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/markup" "code.gitea.io/gitea/modules/markup/markdown" @@ -192,7 +193,7 @@ func Profile(ctx *context.Context) { ctx.Data["Keyword"] = keyword switch tab { case "followers": - items, err := ctxUser.GetFollowers(models.ListOptions{ + items, err := ctxUser.GetFollowers(db.ListOptions{ PageSize: setting.UI.User.RepoPagingNum, Page: page, }) @@ -204,7 +205,7 @@ func Profile(ctx *context.Context) { total = ctxUser.NumFollowers case "following": - items, err := ctxUser.GetFollowing(models.ListOptions{ + items, err := ctxUser.GetFollowing(db.ListOptions{ PageSize: setting.UI.User.RepoPagingNum, Page: page, }) @@ -229,7 +230,7 @@ func Profile(ctx *context.Context) { case "stars": ctx.Data["PageIsProfileStarList"] = true repos, count, err = models.SearchRepository(&models.SearchRepoOptions{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ PageSize: setting.UI.User.RepoPagingNum, Page: page, }, @@ -260,7 +261,7 @@ func Profile(ctx *context.Context) { } case "watching": repos, count, err = models.SearchRepository(&models.SearchRepoOptions{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ PageSize: setting.UI.User.RepoPagingNum, Page: page, }, @@ -281,7 +282,7 @@ func Profile(ctx *context.Context) { total = int(count) default: repos, count, err = models.SearchRepository(&models.SearchRepoOptions{ - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ PageSize: setting.UI.User.RepoPagingNum, Page: page, }, diff --git a/routers/web/user/setting/applications.go b/routers/web/user/setting/applications.go index 5e208afafea3..9976337bfab9 100644 --- a/routers/web/user/setting/applications.go +++ b/routers/web/user/setting/applications.go @@ -9,6 +9,7 @@ import ( "net/http" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/setting" @@ -92,12 +93,12 @@ func loadApplicationsData(ctx *context.Context) { ctx.Data["Tokens"] = tokens ctx.Data["EnableOAuth2"] = setting.OAuth2.Enable if setting.OAuth2.Enable { - ctx.Data["Applications"], err = models.GetOAuth2ApplicationsByUserID(ctx.User.ID) + ctx.Data["Applications"], err = login.GetOAuth2ApplicationsByUserID(ctx.User.ID) if err != nil { ctx.ServerError("GetOAuth2ApplicationsByUserID", err) return } - ctx.Data["Grants"], err = models.GetOAuth2GrantsByUserID(ctx.User.ID) + ctx.Data["Grants"], err = login.GetOAuth2GrantsByUserID(ctx.User.ID) if err != nil { ctx.ServerError("GetOAuth2GrantsByUserID", err) return diff --git a/routers/web/user/setting/keys.go b/routers/web/user/setting/keys.go index 24b9a9e205a0..bb7a50841bb8 100644 --- a/routers/web/user/setting/keys.go +++ b/routers/web/user/setting/keys.go @@ -9,6 +9,7 @@ import ( "net/http" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/setting" @@ -233,7 +234,7 @@ func DeleteKey(ctx *context.Context) { } func loadKeysData(ctx *context.Context) { - keys, err := models.ListPublicKeys(ctx.User.ID, models.ListOptions{}) + keys, err := models.ListPublicKeys(ctx.User.ID, db.ListOptions{}) if err != nil { ctx.ServerError("ListPublicKeys", err) return @@ -247,7 +248,7 @@ func loadKeysData(ctx *context.Context) { } ctx.Data["ExternalKeys"] = externalKeys - gpgkeys, err := models.ListGPGKeys(ctx.User.ID, models.ListOptions{}) + gpgkeys, err := models.ListGPGKeys(ctx.User.ID, db.ListOptions{}) if err != nil { ctx.ServerError("ListGPGKeys", err) return @@ -258,7 +259,7 @@ func loadKeysData(ctx *context.Context) { // generate a new aes cipher using the csrfToken ctx.Data["TokenToSign"] = tokenToSign - principals, err := models.ListPrincipalKeys(ctx.User.ID, models.ListOptions{}) + principals, err := models.ListPrincipalKeys(ctx.User.ID, db.ListOptions{}) if err != nil { ctx.ServerError("ListPrincipalKeys", err) return diff --git a/routers/web/user/setting/oauth2.go b/routers/web/user/setting/oauth2.go index 8de0720b5105..0f338ab5d1cb 100644 --- a/routers/web/user/setting/oauth2.go +++ b/routers/web/user/setting/oauth2.go @@ -8,7 +8,7 @@ import ( "fmt" "net/http" - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" @@ -34,7 +34,7 @@ func OAuthApplicationsPost(ctx *context.Context) { return } // TODO validate redirect URI - app, err := models.CreateOAuth2Application(models.CreateOAuth2ApplicationOptions{ + app, err := login.CreateOAuth2Application(login.CreateOAuth2ApplicationOptions{ Name: form.Name, RedirectURIs: []string{form.RedirectURI}, UserID: ctx.User.ID, @@ -67,7 +67,7 @@ func OAuthApplicationsEdit(ctx *context.Context) { } // TODO validate redirect URI var err error - if ctx.Data["App"], err = models.UpdateOAuth2Application(models.UpdateOAuth2ApplicationOptions{ + if ctx.Data["App"], err = login.UpdateOAuth2Application(login.UpdateOAuth2ApplicationOptions{ ID: ctx.ParamsInt64("id"), Name: form.Name, RedirectURIs: []string{form.RedirectURI}, @@ -85,9 +85,9 @@ func OAuthApplicationsRegenerateSecret(ctx *context.Context) { ctx.Data["Title"] = ctx.Tr("settings") ctx.Data["PageIsSettingsApplications"] = true - app, err := models.GetOAuth2ApplicationByID(ctx.ParamsInt64("id")) + app, err := login.GetOAuth2ApplicationByID(ctx.ParamsInt64("id")) if err != nil { - if models.IsErrOAuthApplicationNotFound(err) { + if login.IsErrOAuthApplicationNotFound(err) { ctx.NotFound("Application not found", err) return } @@ -110,9 +110,9 @@ func OAuthApplicationsRegenerateSecret(ctx *context.Context) { // OAuth2ApplicationShow displays the given application func OAuth2ApplicationShow(ctx *context.Context) { - app, err := models.GetOAuth2ApplicationByID(ctx.ParamsInt64("id")) + app, err := login.GetOAuth2ApplicationByID(ctx.ParamsInt64("id")) if err != nil { - if models.IsErrOAuthApplicationNotFound(err) { + if login.IsErrOAuthApplicationNotFound(err) { ctx.NotFound("Application not found", err) return } @@ -129,7 +129,7 @@ func OAuth2ApplicationShow(ctx *context.Context) { // DeleteOAuth2Application deletes the given oauth2 application func DeleteOAuth2Application(ctx *context.Context) { - if err := models.DeleteOAuth2Application(ctx.FormInt64("id"), ctx.User.ID); err != nil { + if err := login.DeleteOAuth2Application(ctx.FormInt64("id"), ctx.User.ID); err != nil { ctx.ServerError("DeleteOAuth2Application", err) return } @@ -147,7 +147,7 @@ func RevokeOAuth2Grant(ctx *context.Context) { ctx.ServerError("RevokeOAuth2Grant", fmt.Errorf("user id or grant id is zero")) return } - if err := models.RevokeOAuth2Grant(ctx.FormInt64("id"), ctx.User.ID); err != nil { + if err := login.RevokeOAuth2Grant(ctx.FormInt64("id"), ctx.User.ID); err != nil { ctx.ServerError("RevokeOAuth2Grant", err) return } diff --git a/routers/web/user/setting/profile.go b/routers/web/user/setting/profile.go index bd967af32b56..d75149b8fc71 100644 --- a/routers/web/user/setting/profile.go +++ b/routers/web/user/setting/profile.go @@ -15,6 +15,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" @@ -235,7 +236,7 @@ func Repos(ctx *context.Context) { ctx.Data["allowAdopt"] = ctx.IsUserSiteAdmin() || setting.Repository.AllowAdoptionOfUnadoptedRepositories ctx.Data["allowDelete"] = ctx.IsUserSiteAdmin() || setting.Repository.AllowDeleteOfUnadoptedRepositories - opts := models.ListOptions{ + opts := db.ListOptions{ PageSize: setting.UI.Admin.UserPagingNum, Page: ctx.FormInt("page"), } @@ -284,7 +285,7 @@ func Repos(ctx *context.Context) { return } - if err := ctxUser.GetRepositories(models.ListOptions{Page: 1, PageSize: setting.UI.Admin.UserPagingNum}, repoNames...); err != nil { + if err := ctxUser.GetRepositories(db.ListOptions{Page: 1, PageSize: setting.UI.Admin.UserPagingNum}, repoNames...); err != nil { ctx.ServerError("GetRepositories", err) return } diff --git a/routers/web/user/setting/security.go b/routers/web/user/setting/security.go index 3406194015f6..d4abe84d9601 100644 --- a/routers/web/user/setting/security.go +++ b/routers/web/user/setting/security.go @@ -9,6 +9,7 @@ import ( "net/http" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/setting" @@ -87,9 +88,9 @@ func loadSecurityData(ctx *context.Context) { } // map the provider display name with the LoginSource - sources := make(map[*models.LoginSource]string) + sources := make(map[*login.Source]string) for _, externalAccount := range accountLinks { - if loginSource, err := models.GetLoginSourceByID(externalAccount.LoginSourceID); err == nil { + if loginSource, err := login.GetSourceByID(externalAccount.LoginSourceID); err == nil { var providerDisplayName string type DisplayNamed interface { diff --git a/services/auth/login_source.go b/services/auth/login_source.go new file mode 100644 index 000000000000..723dd2b1a5d0 --- /dev/null +++ b/services/auth/login_source.go @@ -0,0 +1,41 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" +) + +// DeleteLoginSource deletes a LoginSource record in DB. +func DeleteLoginSource(source *login.Source) error { + count, err := db.GetEngine(db.DefaultContext).Count(&models.User{LoginSource: source.ID}) + if err != nil { + return err + } else if count > 0 { + return login.ErrSourceInUse{ + ID: source.ID, + } + } + + count, err = db.GetEngine(db.DefaultContext).Count(&models.ExternalLoginUser{LoginSourceID: source.ID}) + if err != nil { + return err + } else if count > 0 { + return login.ErrSourceInUse{ + ID: source.ID, + } + } + + if registerableSource, ok := source.Cfg.(login.RegisterableSource); ok { + if err := registerableSource.UnregisterSource(); err != nil { + return err + } + } + + _, err = db.GetEngine(db.DefaultContext).ID(source.ID).Delete(new(login.Source)) + return err +} diff --git a/services/auth/oauth2.go b/services/auth/oauth2.go index e79b640ce46d..9b342f3458f6 100644 --- a/services/auth/oauth2.go +++ b/services/auth/oauth2.go @@ -12,6 +12,7 @@ import ( "code.gitea.io/gitea/models" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/timeutil" "code.gitea.io/gitea/modules/web/middleware" @@ -35,8 +36,8 @@ func CheckOAuthAccessToken(accessToken string) int64 { log.Trace("oauth2.ParseToken: %v", err) return 0 } - var grant *models.OAuth2Grant - if grant, err = models.GetOAuth2GrantByID(token.GrantID); err != nil || grant == nil { + var grant *login.OAuth2Grant + if grant, err = login.GetOAuth2GrantByID(token.GrantID); err != nil || grant == nil { return 0 } if token.Type != oauth2.TypeAccessToken { diff --git a/services/auth/signin.go b/services/auth/signin.go index a7acb95ba250..a7ad029456df 100644 --- a/services/auth/signin.go +++ b/services/auth/signin.go @@ -9,6 +9,7 @@ import ( "code.gitea.io/gitea/models" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/log" // Register the sources @@ -21,7 +22,7 @@ import ( ) // UserSignIn validates user name and password. -func UserSignIn(username, password string) (*models.User, *models.LoginSource, error) { +func UserSignIn(username, password string) (*models.User, *login.Source, error) { var user *models.User if strings.Contains(username, "@") { user = &models.User{Email: strings.ToLower(strings.TrimSpace(username))} @@ -50,7 +51,7 @@ func UserSignIn(username, password string) (*models.User, *models.LoginSource, e } if hasUser { - source, err := models.GetLoginSourceByID(user.LoginSource) + source, err := login.GetSourceByID(user.LoginSource) if err != nil { return nil, nil, err } @@ -78,7 +79,7 @@ func UserSignIn(username, password string) (*models.User, *models.LoginSource, e return user, source, nil } - sources, err := models.AllActiveLoginSources() + sources, err := login.AllActiveSources() if err != nil { return nil, nil, err } diff --git a/services/auth/source/db/assert_interface_test.go b/services/auth/source/db/assert_interface_test.go index 2e0fa9ba2247..a8b137ec4817 100644 --- a/services/auth/source/db/assert_interface_test.go +++ b/services/auth/source/db/assert_interface_test.go @@ -5,7 +5,7 @@ package db_test import ( - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/services/auth" "code.gitea.io/gitea/services/auth/source/db" ) @@ -15,7 +15,7 @@ import ( type sourceInterface interface { auth.PasswordAuthenticator - models.LoginConfig + login.Config } var _ (sourceInterface) = &db.Source{} diff --git a/services/auth/source/db/source.go b/services/auth/source/db/source.go index 182c05f0dfcc..2fedff3a7ea8 100644 --- a/services/auth/source/db/source.go +++ b/services/auth/source/db/source.go @@ -4,7 +4,10 @@ package db -import "code.gitea.io/gitea/models" +import ( + "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" +) // Source is a password authentication service type Source struct{} @@ -26,6 +29,6 @@ func (source *Source) Authenticate(user *models.User, login, password string) (* } func init() { - models.RegisterLoginTypeConfig(models.LoginNoType, &Source{}) - models.RegisterLoginTypeConfig(models.LoginPlain, &Source{}) + login.RegisterTypeConfig(login.NoType, &Source{}) + login.RegisterTypeConfig(login.Plain, &Source{}) } diff --git a/services/auth/source/ldap/assert_interface_test.go b/services/auth/source/ldap/assert_interface_test.go index a0425d2f763c..c480119cd3fe 100644 --- a/services/auth/source/ldap/assert_interface_test.go +++ b/services/auth/source/ldap/assert_interface_test.go @@ -5,7 +5,7 @@ package ldap_test import ( - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/services/auth" "code.gitea.io/gitea/services/auth/source/ldap" ) @@ -17,12 +17,12 @@ type sourceInterface interface { auth.PasswordAuthenticator auth.SynchronizableSource auth.LocalTwoFASkipper - models.SSHKeyProvider - models.LoginConfig - models.SkipVerifiable - models.HasTLSer - models.UseTLSer - models.LoginSourceSettable + login.SSHKeyProvider + login.Config + login.SkipVerifiable + login.HasTLSer + login.UseTLSer + login.SourceSettable } var _ (sourceInterface) = &ldap.Source{} diff --git a/services/auth/source/ldap/source.go b/services/auth/source/ldap/source.go index d1228d41aeb1..82ff7313b288 100644 --- a/services/auth/source/ldap/source.go +++ b/services/auth/source/ldap/source.go @@ -8,6 +8,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/json" "code.gitea.io/gitea/modules/secret" "code.gitea.io/gitea/modules/setting" @@ -55,7 +56,7 @@ type Source struct { SkipLocalTwoFA bool // Skip Local 2fa for users authenticated with this source // reference to the loginSource - loginSource *models.LoginSource + loginSource *login.Source } // FromDB fills up a LDAPConfig from serialized format. @@ -109,11 +110,11 @@ func (source *Source) ProvidesSSHKeys() bool { } // SetLoginSource sets the related LoginSource -func (source *Source) SetLoginSource(loginSource *models.LoginSource) { +func (source *Source) SetLoginSource(loginSource *login.Source) { source.loginSource = loginSource } func init() { - models.RegisterLoginTypeConfig(models.LoginLDAP, &Source{}) - models.RegisterLoginTypeConfig(models.LoginDLDAP, &Source{}) + login.RegisterTypeConfig(login.LDAP, &Source{}) + login.RegisterTypeConfig(login.DLDAP, &Source{}) } diff --git a/services/auth/source/ldap/source_authenticate.go b/services/auth/source/ldap/source_authenticate.go index 46478e60296e..f302a9d5837f 100644 --- a/services/auth/source/ldap/source_authenticate.go +++ b/services/auth/source/ldap/source_authenticate.go @@ -9,16 +9,17 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/services/mailer" ) // Authenticate queries if login/password is valid against the LDAP directory pool, // and create a local user if success when enabled. -func (source *Source) Authenticate(user *models.User, login, password string) (*models.User, error) { - sr := source.SearchEntry(login, password, source.loginSource.Type == models.LoginDLDAP) +func (source *Source) Authenticate(user *models.User, userName, password string) (*models.User, error) { + sr := source.SearchEntry(userName, password, source.loginSource.Type == login.DLDAP) if sr == nil { // User not in LDAP, do nothing - return nil, models.ErrUserNotExist{Name: login} + return nil, models.ErrUserNotExist{Name: userName} } isAttributeSSHPublicKeySet := len(strings.TrimSpace(source.AttributeSSHPublicKey)) > 0 @@ -64,7 +65,7 @@ func (source *Source) Authenticate(user *models.User, login, password string) (* // Fallback. if len(sr.Username) == 0 { - sr.Username = login + sr.Username = userName } if len(sr.Mail) == 0 { @@ -78,7 +79,7 @@ func (source *Source) Authenticate(user *models.User, login, password string) (* Email: sr.Mail, LoginType: source.loginSource.Type, LoginSource: source.loginSource.ID, - LoginName: login, + LoginName: userName, IsActive: true, IsAdmin: sr.IsAdmin, IsRestricted: sr.IsRestricted, diff --git a/services/auth/source/oauth2/assert_interface_test.go b/services/auth/source/oauth2/assert_interface_test.go index 4157427ff2f7..0a1986a3b29b 100644 --- a/services/auth/source/oauth2/assert_interface_test.go +++ b/services/auth/source/oauth2/assert_interface_test.go @@ -5,7 +5,7 @@ package oauth2_test import ( - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/services/auth" "code.gitea.io/gitea/services/auth/source/oauth2" ) @@ -14,9 +14,9 @@ import ( // It tightly binds the interfaces and implementation without breaking go import cycles type sourceInterface interface { - models.LoginConfig - models.LoginSourceSettable - models.RegisterableSource + login.Config + login.SourceSettable + login.RegisterableSource auth.PasswordAuthenticator } diff --git a/services/auth/source/oauth2/init.go b/services/auth/source/oauth2/init.go index be31503eeff5..343b24cf6f58 100644 --- a/services/auth/source/oauth2/init.go +++ b/services/auth/source/oauth2/init.go @@ -8,8 +8,8 @@ import ( "net/http" "sync" - "code.gitea.io/gitea/models" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -74,7 +74,7 @@ func ResetOAuth2() error { // initOAuth2LoginSources is used to load and register all active OAuth2 providers func initOAuth2LoginSources() error { - loginSources, _ := models.GetActiveOAuth2ProviderLoginSources() + loginSources, _ := login.GetActiveOAuth2ProviderLoginSources() for _, source := range loginSources { oauth2Source, ok := source.Cfg.(*Source) if !ok { diff --git a/services/auth/source/oauth2/providers.go b/services/auth/source/oauth2/providers.go index 2196e304928e..0fd57a8dbd5a 100644 --- a/services/auth/source/oauth2/providers.go +++ b/services/auth/source/oauth2/providers.go @@ -9,6 +9,7 @@ import ( "sort" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -87,7 +88,7 @@ func GetOAuth2Providers() []Provider { func GetActiveOAuth2Providers() ([]string, map[string]Provider, error) { // Maybe also separate used and unused providers so we can force the registration of only 1 active provider for each type - loginSources, err := models.GetActiveOAuth2ProviderLoginSources() + loginSources, err := login.GetActiveOAuth2ProviderLoginSources() if err != nil { return nil, nil, err } diff --git a/services/auth/source/oauth2/source.go b/services/auth/source/oauth2/source.go index 7b22383d7ed6..49bb9a0148fb 100644 --- a/services/auth/source/oauth2/source.go +++ b/services/auth/source/oauth2/source.go @@ -6,6 +6,7 @@ package oauth2 import ( "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/json" ) @@ -27,7 +28,7 @@ type Source struct { SkipLocalTwoFA bool // reference to the loginSource - loginSource *models.LoginSource + loginSource *login.Source } // FromDB fills up an OAuth2Config from serialized format. @@ -41,10 +42,10 @@ func (source *Source) ToDB() ([]byte, error) { } // SetLoginSource sets the related LoginSource -func (source *Source) SetLoginSource(loginSource *models.LoginSource) { +func (source *Source) SetLoginSource(loginSource *login.Source) { source.loginSource = loginSource } func init() { - models.RegisterLoginTypeConfig(models.LoginOAuth2, &Source{}) + login.RegisterTypeConfig(login.OAuth2, &Source{}) } diff --git a/services/auth/source/pam/assert_interface_test.go b/services/auth/source/pam/assert_interface_test.go index a0bebdf9c679..a151c2f52e6e 100644 --- a/services/auth/source/pam/assert_interface_test.go +++ b/services/auth/source/pam/assert_interface_test.go @@ -5,7 +5,7 @@ package pam_test import ( - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/services/auth" "code.gitea.io/gitea/services/auth/source/pam" ) @@ -15,8 +15,8 @@ import ( type sourceInterface interface { auth.PasswordAuthenticator - models.LoginConfig - models.LoginSourceSettable + login.Config + login.SourceSettable } var _ (sourceInterface) = &pam.Source{} diff --git a/services/auth/source/pam/source.go b/services/auth/source/pam/source.go index 75aa99e45fd4..0bfa7cdb06b8 100644 --- a/services/auth/source/pam/source.go +++ b/services/auth/source/pam/source.go @@ -6,6 +6,7 @@ package pam import ( "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/json" ) @@ -22,7 +23,7 @@ type Source struct { EmailDomain string // reference to the loginSource - loginSource *models.LoginSource + loginSource *login.Source } // FromDB fills up a PAMConfig from serialized format. @@ -36,10 +37,10 @@ func (source *Source) ToDB() ([]byte, error) { } // SetLoginSource sets the related LoginSource -func (source *Source) SetLoginSource(loginSource *models.LoginSource) { +func (source *Source) SetLoginSource(loginSource *login.Source) { source.loginSource = loginSource } func init() { - models.RegisterLoginTypeConfig(models.LoginPAM, &Source{}) + login.RegisterTypeConfig(login.PAM, &Source{}) } diff --git a/services/auth/source/pam/source_authenticate.go b/services/auth/source/pam/source_authenticate.go index 8241aed7256f..ad6fbb5cce8f 100644 --- a/services/auth/source/pam/source_authenticate.go +++ b/services/auth/source/pam/source_authenticate.go @@ -9,6 +9,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/auth/pam" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/services/mailer" @@ -18,11 +19,11 @@ import ( // Authenticate queries if login/password is valid against the PAM, // and create a local user if success when enabled. -func (source *Source) Authenticate(user *models.User, login, password string) (*models.User, error) { - pamLogin, err := pam.Auth(source.ServiceName, login, password) +func (source *Source) Authenticate(user *models.User, userName, password string) (*models.User, error) { + pamLogin, err := pam.Auth(source.ServiceName, userName, password) if err != nil { if strings.Contains(err.Error(), "Authentication failure") { - return nil, models.ErrUserNotExist{Name: login} + return nil, models.ErrUserNotExist{Name: userName} } return nil, err } @@ -54,9 +55,9 @@ func (source *Source) Authenticate(user *models.User, login, password string) (* Name: username, Email: email, Passwd: password, - LoginType: models.LoginPAM, + LoginType: login.PAM, LoginSource: source.loginSource.ID, - LoginName: login, // This is what the user typed in + LoginName: userName, // This is what the user typed in IsActive: true, } diff --git a/services/auth/source/smtp/assert_interface_test.go b/services/auth/source/smtp/assert_interface_test.go index bc2042e06996..d1c982472fc6 100644 --- a/services/auth/source/smtp/assert_interface_test.go +++ b/services/auth/source/smtp/assert_interface_test.go @@ -5,7 +5,7 @@ package smtp_test import ( - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/services/auth" "code.gitea.io/gitea/services/auth/source/smtp" ) @@ -15,11 +15,11 @@ import ( type sourceInterface interface { auth.PasswordAuthenticator - models.LoginConfig - models.SkipVerifiable - models.HasTLSer - models.UseTLSer - models.LoginSourceSettable + login.Config + login.SkipVerifiable + login.HasTLSer + login.UseTLSer + login.SourceSettable } var _ (sourceInterface) = &smtp.Source{} diff --git a/services/auth/source/smtp/source.go b/services/auth/source/smtp/source.go index 39c9851ede23..487375c3044b 100644 --- a/services/auth/source/smtp/source.go +++ b/services/auth/source/smtp/source.go @@ -6,6 +6,7 @@ package smtp import ( "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/json" ) @@ -28,7 +29,7 @@ type Source struct { DisableHelo bool // reference to the loginSource - loginSource *models.LoginSource + loginSource *login.Source } // FromDB fills up an SMTPConfig from serialized format. @@ -57,10 +58,10 @@ func (source *Source) UseTLS() bool { } // SetLoginSource sets the related LoginSource -func (source *Source) SetLoginSource(loginSource *models.LoginSource) { +func (source *Source) SetLoginSource(loginSource *login.Source) { source.loginSource = loginSource } func init() { - models.RegisterLoginTypeConfig(models.LoginSMTP, &Source{}) + login.RegisterTypeConfig(login.SMTP, &Source{}) } diff --git a/services/auth/source/smtp/source_authenticate.go b/services/auth/source/smtp/source_authenticate.go index cff64c69d2f9..f50baa56a253 100644 --- a/services/auth/source/smtp/source_authenticate.go +++ b/services/auth/source/smtp/source_authenticate.go @@ -11,31 +11,32 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/util" "code.gitea.io/gitea/services/mailer" ) // Authenticate queries if the provided login/password is authenticates against the SMTP server // Users will be autoregistered as required -func (source *Source) Authenticate(user *models.User, login, password string) (*models.User, error) { +func (source *Source) Authenticate(user *models.User, userName, password string) (*models.User, error) { // Verify allowed domains. if len(source.AllowedDomains) > 0 { - idx := strings.Index(login, "@") + idx := strings.Index(userName, "@") if idx == -1 { - return nil, models.ErrUserNotExist{Name: login} - } else if !util.IsStringInSlice(login[idx+1:], strings.Split(source.AllowedDomains, ","), true) { - return nil, models.ErrUserNotExist{Name: login} + return nil, models.ErrUserNotExist{Name: userName} + } else if !util.IsStringInSlice(userName[idx+1:], strings.Split(source.AllowedDomains, ","), true) { + return nil, models.ErrUserNotExist{Name: userName} } } var auth smtp.Auth switch source.Auth { case PlainAuthentication: - auth = smtp.PlainAuth("", login, password, source.Host) + auth = smtp.PlainAuth("", userName, password, source.Host) case LoginAuthentication: - auth = &loginAuthenticator{login, password} + auth = &loginAuthenticator{userName, password} case CRAMMD5Authentication: - auth = smtp.CRAMMD5Auth(login, password) + auth = smtp.CRAMMD5Auth(userName, password) default: return nil, errors.New("unsupported SMTP auth type") } @@ -46,11 +47,11 @@ func (source *Source) Authenticate(user *models.User, login, password string) (* tperr, ok := err.(*textproto.Error) if (ok && tperr.Code == 535) || strings.Contains(err.Error(), "Username and Password not accepted") { - return nil, models.ErrUserNotExist{Name: login} + return nil, models.ErrUserNotExist{Name: userName} } if (ok && tperr.Code == 534) || strings.Contains(err.Error(), "Application-specific password required") { - return nil, models.ErrUserNotExist{Name: login} + return nil, models.ErrUserNotExist{Name: userName} } return nil, err } @@ -59,20 +60,20 @@ func (source *Source) Authenticate(user *models.User, login, password string) (* return user, nil } - username := login - idx := strings.Index(login, "@") + username := userName + idx := strings.Index(userName, "@") if idx > -1 { - username = login[:idx] + username = userName[:idx] } user = &models.User{ LowerName: strings.ToLower(username), Name: strings.ToLower(username), - Email: login, + Email: userName, Passwd: password, - LoginType: models.LoginSMTP, + LoginType: login.SMTP, LoginSource: source.loginSource.ID, - LoginName: login, + LoginName: userName, IsActive: true, } diff --git a/services/auth/source/sspi/assert_interface_test.go b/services/auth/source/sspi/assert_interface_test.go index 605a6ec6c541..1efa69c05ba6 100644 --- a/services/auth/source/sspi/assert_interface_test.go +++ b/services/auth/source/sspi/assert_interface_test.go @@ -5,7 +5,7 @@ package sspi_test import ( - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/services/auth/source/sspi" ) @@ -13,7 +13,7 @@ import ( // It tightly binds the interfaces and implementation without breaking go import cycles type sourceInterface interface { - models.LoginConfig + login.Config } var _ (sourceInterface) = &sspi.Source{} diff --git a/services/auth/source/sspi/source.go b/services/auth/source/sspi/source.go index 58cb10de1df9..68fd6a607948 100644 --- a/services/auth/source/sspi/source.go +++ b/services/auth/source/sspi/source.go @@ -6,6 +6,7 @@ package sspi import ( "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/json" ) @@ -36,5 +37,5 @@ func (cfg *Source) ToDB() ([]byte, error) { } func init() { - models.RegisterLoginTypeConfig(models.LoginSSPI, &Source{}) + login.RegisterTypeConfig(login.SSPI, &Source{}) } diff --git a/services/auth/sspi_windows.go b/services/auth/sspi_windows.go index d7e0f55242aa..a4c39b00646d 100644 --- a/services/auth/sspi_windows.go +++ b/services/auth/sspi_windows.go @@ -10,6 +10,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -152,7 +153,7 @@ func (s *SSPI) Verify(req *http.Request, w http.ResponseWriter, store DataStore, // getConfig retrieves the SSPI configuration from login sources func (s *SSPI) getConfig() (*sspi.Source, error) { - sources, err := models.ActiveLoginSources(models.LoginSSPI) + sources, err := login.ActiveSources(login.SSPI) if err != nil { return nil, err } @@ -248,7 +249,7 @@ func sanitizeUsername(username string, cfg *sspi.Source) string { // fails (or if negotiation should continue), which would prevent other authentication methods // to execute at all. func specialInit() { - if models.IsSSPIEnabled() { + if login.IsSSPIEnabled() { Register(&SSPI{}) } } diff --git a/services/auth/sync.go b/services/auth/sync.go index a34b4d1d2694..6d69650e5baa 100644 --- a/services/auth/sync.go +++ b/services/auth/sync.go @@ -8,6 +8,7 @@ import ( "context" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/log" ) @@ -15,7 +16,7 @@ import ( func SyncExternalUsers(ctx context.Context, updateExisting bool) error { log.Trace("Doing: SyncExternalUsers") - ls, err := models.LoginSources() + ls, err := login.Sources() if err != nil { log.Error("SyncExternalUsers: %v", err) return err diff --git a/services/externalaccount/user.go b/services/externalaccount/user.go index 45773fdb127d..e43b3ca7c5f5 100644 --- a/services/externalaccount/user.go +++ b/services/externalaccount/user.go @@ -8,6 +8,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/structs" "github.com/markbates/goth" @@ -15,7 +16,7 @@ import ( // LinkAccountToUser link the gothUser to the user func LinkAccountToUser(user *models.User, gothUser goth.User) error { - loginSource, err := models.GetActiveOAuth2LoginSourceByName(gothUser.Provider) + loginSource, err := login.GetActiveOAuth2LoginSourceByName(gothUser.Provider) if err != nil { return err } diff --git a/services/pull/commit_status.go b/services/pull/commit_status.go index c5c930ee0d8b..f1f351138b8e 100644 --- a/services/pull/commit_status.go +++ b/services/pull/commit_status.go @@ -7,6 +7,7 @@ package pull import ( "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/structs" @@ -129,7 +130,7 @@ func GetPullRequestCommitStatusState(pr *models.PullRequest) (structs.CommitStat return "", errors.Wrap(err, "LoadBaseRepo") } - commitStatuses, err := models.GetLatestCommitStatus(pr.BaseRepo.ID, sha, models.ListOptions{}) + commitStatuses, err := models.GetLatestCommitStatus(pr.BaseRepo.ID, sha, db.ListOptions{}) if err != nil { return "", errors.Wrap(err, "GetLatestCommitStatus") } diff --git a/services/pull/pull.go b/services/pull/pull.go index bd5551b6dcc0..f7e231379b7f 100644 --- a/services/pull/pull.go +++ b/services/pull/pull.go @@ -780,7 +780,7 @@ func getLastCommitStatus(gitRepo *git.Repository, pr *models.PullRequest) (statu return nil, err } - statusList, err := models.GetLatestCommitStatus(pr.BaseRepo.ID, sha, models.ListOptions{}) + statusList, err := models.GetLatestCommitStatus(pr.BaseRepo.ID, sha, db.ListOptions{}) if err != nil { return nil, err } diff --git a/services/pull/review.go b/services/pull/review.go index f65314c45d10..081b17cd83e8 100644 --- a/services/pull/review.go +++ b/services/pull/review.go @@ -138,7 +138,7 @@ func createCodeComment(doer *models.User, repo *models.Repository, issue *models Line: line, TreePath: treePath, Type: models.CommentTypeCode, - ListOptions: models.ListOptions{ + ListOptions: db.ListOptions{ PageSize: 1, Page: 1, }, From 623d2dd411b6a84a01bff3ca8046f1bd01773ffb Mon Sep 17 00:00:00 2001 From: zeripath Date: Fri, 24 Sep 2021 14:29:32 +0100 Subject: [PATCH 07/13] Prevent panic in Org mode HighlightCodeBlock (#17140) When rendering source in org mode there is a mistake in the highlight code that causes a panic. This PR fixes this. Fix #17139 Signed-off-by: Andrew Thornton --- modules/highlight/highlight.go | 23 ++++++++++++----------- modules/markup/orgmode/orgmode.go | 9 ++++++++- modules/markup/orgmode/orgmode_test.go | 26 ++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/modules/highlight/highlight.go b/modules/highlight/highlight.go index 079f7a44bd11..6684fbe84291 100644 --- a/modules/highlight/highlight.go +++ b/modules/highlight/highlight.go @@ -66,17 +66,6 @@ func Code(fileName, code string) string { if len(code) > sizeLimit { return code } - formatter := html.New(html.WithClasses(true), - html.WithLineNumbers(false), - html.PreventSurroundingPre(true), - ) - if formatter == nil { - log.Error("Couldn't create chroma formatter") - return code - } - - htmlbuf := bytes.Buffer{} - htmlw := bufio.NewWriter(&htmlbuf) var lexer chroma.Lexer if val, ok := highlightMapping[filepath.Ext(fileName)]; ok { @@ -97,6 +86,18 @@ func Code(fileName, code string) string { } cache.Add(fileName, lexer) } + return CodeFromLexer(lexer, code) +} + +// CodeFromLexer returns a HTML version of code string with chroma syntax highlighting classes +func CodeFromLexer(lexer chroma.Lexer, code string) string { + formatter := html.New(html.WithClasses(true), + html.WithLineNumbers(false), + html.PreventSurroundingPre(true), + ) + + htmlbuf := bytes.Buffer{} + htmlw := bufio.NewWriter(&htmlbuf) iterator, err := lexer.Tokenise(nil, string(code)) if err != nil { diff --git a/modules/markup/orgmode/orgmode.go b/modules/markup/orgmode/orgmode.go index 7e9f1f45c5a7..b035e04a1fcc 100644 --- a/modules/markup/orgmode/orgmode.go +++ b/modules/markup/orgmode/orgmode.go @@ -12,6 +12,7 @@ import ( "strings" "code.gitea.io/gitea/modules/highlight" + "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/markup" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/util" @@ -51,6 +52,12 @@ func (Renderer) SanitizerRules() []setting.MarkupSanitizerRule { func Render(ctx *markup.RenderContext, input io.Reader, output io.Writer) error { htmlWriter := org.NewHTMLWriter() htmlWriter.HighlightCodeBlock = func(source, lang string, inline bool) string { + defer func() { + if err := recover(); err != nil { + log.Error("Panic in HighlightCodeBlock: %v\n%s", err, log.Stack(2)) + panic(err) + } + }() var w strings.Builder if _, err := w.WriteString(`
`); err != nil {
 			return ""
@@ -80,7 +87,7 @@ func Render(ctx *markup.RenderContext, input io.Reader, output io.Writer) error
 			}
 			lexer = chroma.Coalesce(lexer)
 
-			if _, err := w.WriteString(highlight.Code(lexer.Config().Filenames[0], source)); err != nil {
+			if _, err := w.WriteString(highlight.CodeFromLexer(lexer, source)); err != nil {
 				return ""
 			}
 		}
diff --git a/modules/markup/orgmode/orgmode_test.go b/modules/markup/orgmode/orgmode_test.go
index da89326e9e13..81d0d66a76c5 100644
--- a/modules/markup/orgmode/orgmode_test.go
+++ b/modules/markup/orgmode/orgmode_test.go
@@ -57,3 +57,29 @@ func TestRender_Images(t *testing.T) {
 	test("[[file:"+url+"]]",
 		"

\""+result+"\"

") } + +func TestRender_Source(t *testing.T) { + setting.AppURL = AppURL + setting.AppSubURL = AppSubURL + + test := func(input, expected string) { + buffer, err := RenderString(&markup.RenderContext{ + URLPrefix: setting.AppSubURL, + }, input) + assert.NoError(t, err) + assert.Equal(t, strings.TrimSpace(expected), strings.TrimSpace(buffer)) + } + + test(`#+begin_src go +// HelloWorld prints "Hello World" +func HelloWorld() { + fmt.Println("Hello World") +} +#+end_src +`, `
+
// HelloWorld prints "Hello World"
+func HelloWorld() {
+	fmt.Println("Hello World")
+}
+
`) +} From cbd5dc4dd6936862e57e9fa257879176c273d6e6 Mon Sep 17 00:00:00 2001 From: GiteaBot Date: Sat, 25 Sep 2021 00:04:51 +0000 Subject: [PATCH 08/13] [skip ci] Updated translations via Crowdin --- options/locale/locale_cs-CZ.ini | 5 ---- options/locale/locale_de-DE.ini | 5 ---- options/locale/locale_el-GR.ini | 5 ---- options/locale/locale_es-ES.ini | 5 ---- options/locale/locale_fa-IR.ini | 2 -- options/locale/locale_fr-FR.ini | 5 ---- options/locale/locale_it-IT.ini | 2 -- options/locale/locale_ja-JP.ini | 5 ---- options/locale/locale_lv-LV.ini | 5 ---- options/locale/locale_nl-NL.ini | 2 -- options/locale/locale_pl-PL.ini | 2 -- options/locale/locale_pt-BR.ini | 2 -- options/locale/locale_pt-PT.ini | 5 ---- options/locale/locale_ru-RU.ini | 48 +++++++++++++++++++++++++++++---- options/locale/locale_sv-SE.ini | 1 - options/locale/locale_tr-TR.ini | 5 ---- options/locale/locale_uk-UA.ini | 5 ---- options/locale/locale_zh-CN.ini | 5 ---- options/locale/locale_zh-TW.ini | 12 ++++----- 19 files changed, 49 insertions(+), 77 deletions(-) diff --git a/options/locale/locale_cs-CZ.ini b/options/locale/locale_cs-CZ.ini index 9261e5cd3d55..0bff600e394c 100644 --- a/options/locale/locale_cs-CZ.ini +++ b/options/locale/locale_cs-CZ.ini @@ -1155,11 +1155,6 @@ issues.action_milestone_no_select=Žádný milník issues.action_assignee=Zpracovatel issues.action_assignee_no_select=Bez zpracovatele issues.opened_by=otevřeno %[1]s uživatelem %[3]s -pulls.merged_by=od %[3]s sloučen %[1]s -pulls.merged_by_fake=od %[2]s sloučen %[1]s -issues.closed_by=od %[3]s uzavřen %[1]s -issues.opened_by_fake=od %[2]s otevřen %[1]s -issues.closed_by_fake=od %[2]s uzavřen %[1]s issues.previous=Předchozí issues.next=Další issues.open_title=otevřený diff --git a/options/locale/locale_de-DE.ini b/options/locale/locale_de-DE.ini index 79ed26a51839..f8490488dac7 100644 --- a/options/locale/locale_de-DE.ini +++ b/options/locale/locale_de-DE.ini @@ -1197,11 +1197,6 @@ issues.action_milestone_no_select=Kein Meilenstein issues.action_assignee=Zuständig issues.action_assignee_no_select=Niemand zuständig issues.opened_by=%[1]s von %[3]s geöffnet -pulls.merged_by=von %[3]s %[1]s zusammengefügt -pulls.merged_by_fake=von %[2]s zusammengefügt %[1]s -issues.closed_by=von %[3]s %[1]s geschlossen -issues.opened_by_fake=von %[2]s %[1]s geöffnet -issues.closed_by_fake=von %[2]s %[1]s geschlossen issues.previous=Vorherige issues.next=Nächste issues.open_title=Offen diff --git a/options/locale/locale_el-GR.ini b/options/locale/locale_el-GR.ini index 9a42db7d7191..a8cc787351b0 100644 --- a/options/locale/locale_el-GR.ini +++ b/options/locale/locale_el-GR.ini @@ -1197,11 +1197,6 @@ issues.action_milestone_no_select=Χωρίς ορόσημο issues.action_assignee=Αποδέκτης issues.action_assignee_no_select=Κανένας Αποδέκτης issues.opened_by=άνοιξαν %[1]s από %[3]s -pulls.merged_by=από %[3]s συγχωνεύτηκαν %[1]s -pulls.merged_by_fake=από %[2]s συγχωνεύτηκαν %[1]s -issues.closed_by=από %[3]s έκλεισε %[1]s -issues.opened_by_fake=από %[2]s άνοιξαν %[1]s -issues.closed_by_fake=από %[2]s έκλεισε %[1]s issues.previous=Προηγούμενο issues.next=Επόμενο issues.open_title=Ανοιχτό diff --git a/options/locale/locale_es-ES.ini b/options/locale/locale_es-ES.ini index fb777b3fae25..48a3f350790a 100644 --- a/options/locale/locale_es-ES.ini +++ b/options/locale/locale_es-ES.ini @@ -1197,11 +1197,6 @@ issues.action_milestone_no_select=Sin hito issues.action_assignee=Asignado a issues.action_assignee_no_select=Sin asignado issues.opened_by=abierta %[1]s por %[3]s -pulls.merged_by=por %[3]s fusionado %[1]s -pulls.merged_by_fake=por %[2]s fusionado %[1]s -issues.closed_by=por %[3]s cerrado %[1]s -issues.opened_by_fake=por %[2]s abierto %[1]s -issues.closed_by_fake=por %[2]s cerrado %[1]s issues.previous=Página Anterior issues.next=Página Siguiente issues.open_title=Abierta diff --git a/options/locale/locale_fa-IR.ini b/options/locale/locale_fa-IR.ini index 4a440031370b..66ca3f8a37d5 100644 --- a/options/locale/locale_fa-IR.ini +++ b/options/locale/locale_fa-IR.ini @@ -950,8 +950,6 @@ issues.action_milestone_no_select=بدون نقطه عطف issues.action_assignee=مسئول رسیدگی issues.action_assignee_no_select=بدون مسئول رسیدگی issues.opened_by=%[1]s باز شده توسط %[3]s -issues.closed_by=%[1]s بوسیله %[3]s بسته شده است -issues.closed_by_fake=%[1]s بوسیله %[2]s بسته شده است issues.previous=قبلی issues.next=بعدی issues.open_title=باز diff --git a/options/locale/locale_fr-FR.ini b/options/locale/locale_fr-FR.ini index f0943f835b93..0b74a9501bb6 100644 --- a/options/locale/locale_fr-FR.ini +++ b/options/locale/locale_fr-FR.ini @@ -1137,11 +1137,6 @@ issues.action_milestone_no_select=Aucun jalon issues.action_assignee=Assigné à issues.action_assignee_no_select=Pas d'assignataire issues.opened_by=créé %[1]s par %[3]s -pulls.merged_by=par %[3]s fusionné %[1]s -pulls.merged_by_fake=par %[2]s fusionnés %[1]s -issues.closed_by=par %[3]s fermé %[1]s -issues.opened_by_fake=par %[2]s ouverts %[1]s -issues.closed_by_fake=par %[2]s fermé %[1]s issues.previous=Page Précédente issues.next=Page Suivante issues.open_title=Ouvert diff --git a/options/locale/locale_it-IT.ini b/options/locale/locale_it-IT.ini index d318de2c1459..0f685d0b8459 100644 --- a/options/locale/locale_it-IT.ini +++ b/options/locale/locale_it-IT.ini @@ -1055,8 +1055,6 @@ issues.action_milestone_no_select=Nessuna pietra miliare issues.action_assignee=Assegnatario issues.action_assignee_no_select=Nessun assegnatario issues.opened_by=aperto %[1]s da %[3]s -issues.closed_by=del %[3]s chiuso %[1]s -issues.closed_by_fake=della %[2]s chiusa %[1]s issues.previous=Pagina precedente issues.next=Pagina successiva issues.open_title=Aperto diff --git a/options/locale/locale_ja-JP.ini b/options/locale/locale_ja-JP.ini index a6ed9c362c93..a78c233a1816 100644 --- a/options/locale/locale_ja-JP.ini +++ b/options/locale/locale_ja-JP.ini @@ -1197,11 +1197,6 @@ issues.action_milestone_no_select=マイルストーンなし issues.action_assignee=担当者 issues.action_assignee_no_select=担当者なし issues.opened_by=%[3]sが%[1]sに作成 -pulls.merged_by=%[3]sが作成、%[1]sにマージ -pulls.merged_by_fake=%[2]sが作成、%[1]sにマージ -issues.closed_by=%[3]sが作成、%[1]sにクローズ -issues.opened_by_fake=%[2]sが%[1]sにオープン -issues.closed_by_fake=%[2]sが作成、%[1]sにクローズ issues.previous=前ページ issues.next=次ページ issues.open_title=オープン diff --git a/options/locale/locale_lv-LV.ini b/options/locale/locale_lv-LV.ini index 5af6c34b6eb7..142fc3c0edc1 100644 --- a/options/locale/locale_lv-LV.ini +++ b/options/locale/locale_lv-LV.ini @@ -1186,11 +1186,6 @@ issues.action_milestone_no_select=Nav atskaites punkta issues.action_assignee=Atbildīgais issues.action_assignee_no_select=Nav atbildīgā issues.opened_by=%[3]s atvēra %[1]s -pulls.merged_by=%[3]s sapludināja %[1]s -pulls.merged_by_fake=%[2]s sapludināja %[1]s -issues.closed_by=%[3]s aizvēra %[1]s -issues.opened_by_fake=%[2]s atvēra %[1]s -issues.closed_by_fake=%[2]s aizvēra %[1]s issues.previous=Iepriekšējā issues.next=Nākamā issues.open_title=Atvērta diff --git a/options/locale/locale_nl-NL.ini b/options/locale/locale_nl-NL.ini index 40d16843cc92..90c4a4b189b8 100644 --- a/options/locale/locale_nl-NL.ini +++ b/options/locale/locale_nl-NL.ini @@ -1044,8 +1044,6 @@ issues.action_milestone_no_select=Geen mijlpaal issues.action_assignee=Toegewezene issues.action_assignee_no_select=Geen verantwoordelijke issues.opened_by=%[1]s geopend door %[3]s -issues.closed_by=door %[3]s gesloten %[1]s -issues.closed_by_fake=met %[2]gesloten %[1]s issues.previous=Vorige issues.next=Volgende issues.open_title=Open diff --git a/options/locale/locale_pl-PL.ini b/options/locale/locale_pl-PL.ini index 15db4b37c356..aca75f6cfb3c 100644 --- a/options/locale/locale_pl-PL.ini +++ b/options/locale/locale_pl-PL.ini @@ -972,8 +972,6 @@ issues.action_milestone_no_select=Brak kamieni milowych issues.action_assignee=Przypisany issues.action_assignee_no_select=Brak przypisania issues.opened_by=otworzone %[1]s przez %[3]s -issues.closed_by=przez %[3]s zamknięte %[1]s -issues.closed_by_fake=przez %[2]s zamknięte %[1]s issues.previous=Poprzedni issues.next=Następny issues.open_title=Otwarty diff --git a/options/locale/locale_pt-BR.ini b/options/locale/locale_pt-BR.ini index 3beef302b8e4..44ba152e518a 100644 --- a/options/locale/locale_pt-BR.ini +++ b/options/locale/locale_pt-BR.ini @@ -1136,8 +1136,6 @@ issues.action_milestone_no_select=Sem marco issues.action_assignee=Responsável issues.action_assignee_no_select=Sem responsável issues.opened_by=aberto por %[3]s %[1]s -pulls.merged_by=por %[3]s merge aplicado %[1]s -pulls.merged_by_fake=por %[2]s merge aplicado %[1]s issues.previous=Anterior issues.next=Próximo issues.open_title=Aberto diff --git a/options/locale/locale_pt-PT.ini b/options/locale/locale_pt-PT.ini index e5b3a5e17ca9..30ee45048ed0 100644 --- a/options/locale/locale_pt-PT.ini +++ b/options/locale/locale_pt-PT.ini @@ -1197,11 +1197,6 @@ issues.action_milestone_no_select=Sem etapa issues.action_assignee=Responsável issues.action_assignee_no_select=Sem responsável issues.opened_by=aberta %[1]s por %[3]s -pulls.merged_by=de %[3]s integrado %[1]s -pulls.merged_by_fake=por %[2]s integrou %[1]s -issues.closed_by=de %[3]s fechada %[1]s -issues.opened_by_fake=de %[2]s aberto %[1]s -issues.closed_by_fake=de %[2]s fechada %[1]s issues.previous=Anterior issues.next=Seguinte issues.open_title=Aberta diff --git a/options/locale/locale_ru-RU.ini b/options/locale/locale_ru-RU.ini index 1596b9347861..c6ea930cd0bd 100644 --- a/options/locale/locale_ru-RU.ini +++ b/options/locale/locale_ru-RU.ini @@ -350,6 +350,7 @@ issue_assigned.issue=@%[1]s назначил вам задачу %[2]s в реп issue.x_mentioned_you=@%s упомянул вас: issue.action.force_push=%[1]s форсировал отправку изменений %[2]s с %[3]s до %[4]s. +issue.action.push_1=@%[1]s отправил %[3]d изменение %[2]s issue.action.push_n=@%[1]s отправил %[3]d изменений %[2]s issue.action.close=@%[1]s закрыты #%[2]d. issue.action.reopen=@%[1]s переоткрыты #%[2]d. @@ -596,6 +597,7 @@ ssh_principal_been_used=Участник уже был добавлен на с gpg_key_id_used=Публичный GPG ключ с таким же идентификатором уже существует. gpg_no_key_email_found=Этот GPG ключ не соответствует ни одному активному адресу электронной почты, связанному с вашей учетной записью. Он по-прежнему может быть добавлен, если вы подписали указанный токен. gpg_key_matched_identities=Соответствующие идентификаторы: +gpg_key_matched_identities_long=Встроенные в этот ключ идентификаторы соответствуют следующим активным email-адресам этого пользователя и коммиты, соответствующие этим email-адресам могут быть проверены с помощью этого ключа. gpg_key_verified=Проверенный ключ gpg_key_verified_long=Ключ был проверен токеном и может быть использован для проверки коммитов, соответствующих любым активным адресом электронной почты этого пользователя в дополнение к любым соответствующим идентификаторам этого ключа. gpg_key_verify=Проверить @@ -766,6 +768,10 @@ fork_repo=Форкнуть репозиторий fork_from=Форк от fork_visibility_helper=Видимость форкнутого репозитория изменить нельзя. use_template=Использовать этот шаблон +clone_in_vsc=Клонировать в VS Code +download_zip=Скачать ZIP +download_tar=Скачать TAR.GZ +download_bundle=Скачать BUNDLE generate_repo=Создать репозиторий generate_from=Создать из repo_desc=Описание @@ -894,6 +900,12 @@ migrate.migrate=Миграция из %s migrate.migrating=Перенос из %s... migrate.migrating_failed=Перенос из %s не удался. migrate.migrating_failed.error=Ошибка: %s +migrate.github.description=Перенести данные с github.com или других экземпляров GitHub Enterprise Server. +migrate.git.description=Перенести только репозиторий из любого Git сервиса. +migrate.gitlab.description=Перенести данные с gitlab.com или других экземпляров GitLab. +migrate.gitea.description=Перенести данные с gitea.com или других экземпляров Gitea. +migrate.gogs.description=Перенести данные с notabug.org или других экземпляров Gogs. +migrate.onedev.description=Перенести данные с code.onedev.io или других экземпляров OneDev. migrate.migrating_git=Перенос Git данных migrate.migrating_topics=Миграция тем migrate.migrating_milestones=Миграция этапов @@ -1033,6 +1045,7 @@ editor.require_signed_commit=Ветка ожидает подписанный к commits.desc=Просмотр истории изменений исходного кода. commits.commits=Коммитов commits.no_commits=Ничего общего в коммитах. '%s' и '%s' имеют совершенно разные истории. +commits.nothing_to_compare=Эти ветки одинаковы. commits.search=Поиск коммитов… commits.search.tooltip=Вы можете предварять ключевые слова словами "author:", "committer:", "after:", или "before:", например, "revert author:Alice before:2019-04-01". commits.find=Поиск @@ -1184,11 +1197,11 @@ issues.action_milestone_no_select=Нет этапа issues.action_assignee=Ответственный issues.action_assignee_no_select=Нет ответственного issues.opened_by=открыта %[1]s %[3]s -pulls.merged_by=на %[3]s мигрированных %[1]s -pulls.merged_by_fake=%[2]s мигрировал %[1]s -issues.closed_by=на %[3]s закрытых %[1]s -issues.opened_by_fake=%[2]s открыл(а) %[1]s -issues.closed_by_fake=%[2]s закрыл(а) %[1]s +pulls.merged_by=слито %[1]s пользователем %[3]s +pulls.merged_by_fake=слито %[1]s пользователем %[2]s +issues.closed_by=закрыт %[1]s пользователем %[3]s +issues.opened_by_fake=открыт %[1]s пользователем %[2]s +issues.closed_by_fake=закрыт %[1]s пользователем %[2]s issues.previous=Предыдущая страница issues.next=Следующая страница issues.open_title=Открыто @@ -1319,6 +1332,8 @@ issues.dependency.remove=Удалить issues.dependency.remove_info=Удалить эту зависимость issues.dependency.added_dependency=`добавить новую зависимость %s` issues.dependency.removed_dependency=`убрал зависимость %s` +issues.dependency.pr_closing_blockedby=Закрытие этого Pull Request'а блокируется следующими задачами +issues.dependency.issue_closing_blockedby=Закрытие этой задачи блокируется следующими задачами issues.dependency.issue_close_blocks=Эта задача блокирует закрытие следующих задач issues.dependency.pr_close_blocks=Этот запрос на слияние блокирует закрытие следующих задач issues.dependency.issue_close_blocked=Вам необходимо закрыть все задачи, блокирующие эту задачу, прежде чем вы сможете её закрыть. @@ -1427,6 +1442,10 @@ pulls.no_merge_helper=Включите опции слияния в настро pulls.no_merge_wip=Данный Pull Request не может быть принят, поскольку он помечен как находящийся в разработке. pulls.no_merge_not_ready=Этот запрос не готов к слиянию, обратите внимания на ревью и проверки. pulls.no_merge_access=У вас нет права для слияния данного запроса. +pulls.merge_pull_request=Создать коммит на слияние +pulls.rebase_merge_pull_request=Выпольнить Rebase, а затем fast-forward слияние +pulls.rebase_merge_commit_pull_request=Выпольнить rebase, а затем создать коммит слияния +pulls.squash_merge_pull_request=Создать объединенный (squash) коммит pulls.merge_manually=Слито вручную pulls.merge_commit_id=ID коммита слияния pulls.require_signed_wont_sign=Данная ветка ожидает подписанные коммиты, однако слияние не будет подписано @@ -1449,6 +1468,8 @@ pulls.status_checks_failure=Некоторые проверки не удали pulls.status_checks_error=Некоторые проверки сообщили об ошибках pulls.status_checks_requested=Требуется pulls.status_checks_details=Информация +pulls.update_branch=Обновить ветку посредством слияния +pulls.update_branch_rebase=Обновить ветку через rebase pulls.update_branch_success=Обновление ветки выполнено успешно pulls.update_not_allowed=У вас недостаточно прав для обновления ветки pulls.outdated_with_base_branch=Эта ветка отстает от базовой ветки @@ -1834,6 +1855,7 @@ settings.add_telegram_hook_desc=Добавить интеграцию с Matrix в ваш репозиторий. settings.add_msteams_hook_desc=Добавить интеграцию с Microsoft Teams в ваш репозиторий. settings.add_feishu_hook_desc=Добавить интеграцию Feishu в ваш репозиторий. +settings.add_Wechat_hook_desc=Добавить интеграцию с Wechatwork в ваш репозиторий. settings.deploy_keys=Ключи развертывания settings.add_deploy_key=Добавить ключ развертывания settings.deploy_key_desc=Ключи развёртывания доступны только для чтения. Это не то же самое что и SSH-ключи аккаунта. @@ -1886,6 +1908,8 @@ settings.require_signed_commits=Требовать подписанные ком settings.require_signed_commits_desc=Отклонить push'ы в эту ветку, если они не подписаны или не проверены. settings.protect_protected_file_patterns=Защищённые шаблоны файлов (разделённые через '\;'): settings.protect_protected_file_patterns_desc=Защищенные файлы, которые не могут быть изменены напрямую, даже если пользователь имеет право добавлять, редактировать или удалять файлы в этой ветке. Шаблоны могут быть разделены точкой с запятой ('\;'). Смотрите github.com/gobwas/glob документацию для синтаксиса шаблонов. Например: .drone.yml, /docs/**/*.txt. +settings.protect_unprotected_file_patterns=Незащищённые шаблоны файлов (разделённые через '\;'): +settings.protect_unprotected_file_patterns_desc=Незащищенные файлы, которые могут быть изменены напрямую, если пользователь имеет доступ на запись, ограничения связанные с push здесь не влияют. Шаблоны могут быть разделены точкой с запятой ('\;'). Смотрите github.com/gobwas/glob документацию для синтаксиса шаблонов. Например: .drone.yml, /docs/**/*.txt. settings.add_protected_branch=Включить защиту settings.delete_protected_branch=Отключить защиту settings.update_protect_branch_success=Настройки защиты ветки '%s' были успешно изменены. @@ -1989,6 +2013,8 @@ diff.file_byte_size=Размер diff.file_suppressed=Разница между файлами не показана из-за своего большого размера diff.file_suppressed_line_too_long=Различия файлов скрыты, потому что одна или несколько строк слишком длинны diff.too_many_files=Некоторые файлы не были показаны из-за большого количества измененных файлов +diff.generated=сгенерированный +diff.vendored=поставляемый diff.comment.placeholder=Оставить комментарий diff.comment.markdown_info=Поддерживается синтаксис Markdown. diff.comment.add_single_comment=Добавить простой комментарий @@ -2155,12 +2181,15 @@ members.member_role=Роль участника: members.owner=Владелец members.member=Участник members.remove=Удалить +members.remove.detail=Исключить %[1]s из %[2]s? members.leave=Покинуть +members.leave.detail=Покинуть %s? members.invite_desc=Добавить нового участника в %s: members.invite_now=Пригласите сейчас teams.join=Объединить teams.leave=Выйти +teams.leave.detail=Покинуть %s? teams.can_create_org_repo=Создать репозитории teams.can_create_org_repo_helper=Участники могут создавать новые репозитории в организации. Создатель получит администраторский доступ к новому репозиторию. teams.read_access=Доступ на чтение @@ -2412,6 +2441,11 @@ auths.smtpport=SMTP-порт auths.allowed_domains=Разрешенные домены auths.allowed_domains_helper=Оставьте пустым, чтобы разрешить все домены. Разделите несколько доменов запятой (','). auths.skip_tls_verify=Пропустить проверку TLS +auths.force_smtps=Принудительный SMTPS +auths.force_smtps_helper=SMTPS всегда использует 465 порт. Установите это, что бы принудительно использовать SMTPS на других портах. (Иначе STARTTLS будет использоваться на других портах, если это поддерживается хостом.) +auths.helo_hostname=HELO Hostname +auths.helo_hostname_helper=Имя хоста отправляется с HELO. Оставьте поле пустым, чтобы отправить текущее имя хоста. +auths.disable_helo=Отключить HELO auths.pam_service_name=Имя службы PAM auths.pam_email_domain=Домен почты PAM (необязательно) auths.oauth2_provider=Поставщик OAuth2 @@ -2424,6 +2458,9 @@ auths.oauth2_tokenURL=URL токена auths.oauth2_authURL=URL авторизации auths.oauth2_profileURL=URL аккаунта auths.oauth2_emailURL=URL-адрес электронной почты +auths.skip_local_two_fa=Пропустить локальную двухфакторную аутентификацию +auths.skip_local_two_fa_helper=Если значение не задано, локальным пользователям с установленной двухфакторной аутентификацией все равно придется пройти двухфакторную аутентификацию для входа в систему +auths.oauth2_tenant=Tenant auths.enable_auto_register=Включить автоматическую регистрацию auths.sspi_auto_create_users=Автоматически создавать пользователей auths.sspi_auto_create_users_helper=Разрешить метод аутентификации SSPI для автоматического создания новых учётных записей для пользователей, которые впервые входят в систему @@ -2695,6 +2732,7 @@ comment_issue=`прокомментировал(а) задачу %s#%[2]s` merge_pull_request=`принял(а) Pull Request %s#%[2]s` transfer_repo=передал(а) репозиторий %s %s +push_tag=создал(а) тэг %[4]s в %[3]s delete_tag=удалил(а) тэг %[2]s из %[3]s delete_branch=удалил(а) ветку %[2]s из %[3]s compare_branch=Сравнить diff --git a/options/locale/locale_sv-SE.ini b/options/locale/locale_sv-SE.ini index 6bb47ca0ee69..236202f17dd3 100644 --- a/options/locale/locale_sv-SE.ini +++ b/options/locale/locale_sv-SE.ini @@ -1000,7 +1000,6 @@ issues.action_milestone_no_select=Ingen Milsten issues.action_assignee=Tilldelad issues.action_assignee_no_select=Ingen tilldelad issues.opened_by=öppnade %[1]s av %[3]s -issues.closed_by_fake=av %[2]s stängde %[1]s issues.previous=Föregående issues.next=Nästa issues.open_title=Öppen diff --git a/options/locale/locale_tr-TR.ini b/options/locale/locale_tr-TR.ini index 043e4b63c65f..fdd7973dc62e 100644 --- a/options/locale/locale_tr-TR.ini +++ b/options/locale/locale_tr-TR.ini @@ -1175,11 +1175,6 @@ issues.action_milestone_no_select=Kilometre Taşı Yok issues.action_assignee=Atanan issues.action_assignee_no_select=Atanan yok issues.opened_by=%[3]s tarafından %[1]s açıldı -pulls.merged_by=%[1]s %[3]s tarafından açılan istek birleştirildi -pulls.merged_by_fake=%[2]s tarafından açılan istek %[1]s birleştirildi -issues.closed_by=%[1]s %[3]s tarafından kapatıldı -issues.opened_by_fake=%[1]s %[2]s tarafından açıldı -issues.closed_by_fake=%[1]s %[2]s tarafından kapatıldı issues.previous=Önceki issues.next=Sonraki issues.open_title=Açık diff --git a/options/locale/locale_uk-UA.ini b/options/locale/locale_uk-UA.ini index 15555f315c14..12e695374f76 100644 --- a/options/locale/locale_uk-UA.ini +++ b/options/locale/locale_uk-UA.ini @@ -1166,11 +1166,6 @@ issues.action_milestone_no_select=Етап відсутній issues.action_assignee=Виконавець issues.action_assignee_no_select=Немає виконавеця issues.opened_by=%[1]s відкрито %[3]s -pulls.merged_by=до %[3] злито %[1]s -pulls.merged_by_fake=%[2]s об'єднаний %[1]s -issues.closed_by=закрито %[3]s %[1]s -issues.opened_by_fake=%[2]s відкрив(ла) %[1]s -issues.closed_by_fake=закрито %[2]s %[1]s issues.previous=Попередній issues.next=Далі issues.open_title=Відкрито diff --git a/options/locale/locale_zh-CN.ini b/options/locale/locale_zh-CN.ini index f5a76a6097d8..2439f2acea6a 100644 --- a/options/locale/locale_zh-CN.ini +++ b/options/locale/locale_zh-CN.ini @@ -1197,11 +1197,6 @@ issues.action_milestone_no_select=无里程碑 issues.action_assignee=指派人筛选 issues.action_assignee_no_select=未指派 issues.opened_by=由 %[3]s 于 %[1]s创建 -pulls.merged_by=%[3]s 合并于 %[1]s -pulls.merged_by_fake=%[2]s 合并于 %[1]s -issues.closed_by=%[3]s 关闭于 %[1]s -issues.opened_by_fake=%[2]s 创建于 %[1]s -issues.closed_by_fake=%[2]s 关闭于 %[1]s issues.previous=上一页 issues.next=下一页 issues.open_title=开启中 diff --git a/options/locale/locale_zh-TW.ini b/options/locale/locale_zh-TW.ini index 03a021f7ff09..086383868791 100644 --- a/options/locale/locale_zh-TW.ini +++ b/options/locale/locale_zh-TW.ini @@ -1196,12 +1196,12 @@ issues.action_milestone=里程碑 issues.action_milestone_no_select=無里程碑 issues.action_assignee=成員 issues.action_assignee_no_select=沒有成員 -issues.opened_by=由 %[3]s 於 %[1]s建立 -pulls.merged_by=由 %[3]s 建立,%[1]s合併 -pulls.merged_by_fake=由 %[2]s 建立,%[1]s合併 -issues.closed_by=由 %[3]s 建立,%[1]s關閉 -issues.opened_by_fake=由 %[2]s 建立,%[1]s開放 -issues.closed_by_fake=由 %[2]s 建立,%[1]s關閉 +issues.opened_by=建立於 %[1]s 由 %[3]s +pulls.merged_by=合併於 %[1]s,由 %[3]s 建立 +pulls.merged_by_fake=合併於 %[1]s,由 %[2]s 建立 +issues.closed_by=關閉於 %[1]s,由 %[3]s 建立 +issues.opened_by_fake=建立於 %[1]s 由 %[2]s +issues.closed_by_fake=關閉於 %[1]s,由 %[2]s 建立 issues.previous=上一頁 issues.next=下一頁 issues.open_title=開放中 From 6fb7fb6cfc4e5dd14caa13c5d2965a1e98efdcaf Mon Sep 17 00:00:00 2001 From: sebastian-sauer Date: Sat, 25 Sep 2021 08:45:55 +0200 Subject: [PATCH 09/13] Force color-adjust for markdown checkboxes (#17146) this forces browsers to render background correctly Co-authored-by: techknowlogick --- web_src/less/markup/content.less | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/web_src/less/markup/content.less b/web_src/less/markup/content.less index df87c21d8c54..8d9858f0df30 100644 --- a/web_src/less/markup/content.less +++ b/web_src/less/markup/content.less @@ -181,6 +181,8 @@ opacity: 1 !important; // override fomantic on edit preview pointer-events: auto !important; // override fomantic on edit preview vertical-align: middle !important; // override fomantic on edit preview + -webkit-print-color-adjust: exact; + color-adjust: exact; } input[type="checkbox"]:not([disabled]):hover, @@ -204,6 +206,8 @@ content: ""; mask-image: var(--checkbox-mask-checked); -webkit-mask-image: var(--checkbox-mask-checked); + -webkit-print-color-adjust: exact; + color-adjust: exact; } input[type="checkbox"]:indeterminate::after { From 91e21d4fca8b867614d08537e92bc6c8fc7b0444 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 25 Sep 2021 21:00:12 +0800 Subject: [PATCH 10/13] Move twofactor to models/login (#17143) --- models/error.go | 41 ---------------------- models/login/main_test.go | 1 + models/{ => login}/twofactor.go | 28 ++++++++++++--- models/{ => login}/u2f.go | 34 +++++++++++++++--- models/{ => login}/u2f_test.go | 8 ++--- models/pull_sign.go | 5 +-- models/repo_sign.go | 13 +++---- models/token.go | 5 +-- models/userlist.go | 5 +-- modules/context/api.go | 5 +-- modules/context/auth.go | 6 ++-- routers/web/admin/users.go | 10 +++--- routers/web/repo/http.go | 5 +-- routers/web/user/auth.go | 40 ++++++++++----------- routers/web/user/setting/security.go | 6 ++-- routers/web/user/setting/security_twofa.go | 28 +++++++-------- routers/web/user/setting/security_u2f.go | 12 +++---- 17 files changed, 131 insertions(+), 121 deletions(-) rename models/{ => login}/twofactor.go (83%) rename models/{ => login}/u2f.go (71%) rename models/{ => login}/u2f_test.go (91%) diff --git a/models/error.go b/models/error.go index 956b24009735..1179fa6eb751 100644 --- a/models/error.go +++ b/models/error.go @@ -1876,25 +1876,6 @@ func (err ErrTeamNotExist) Error() string { return fmt.Sprintf("team does not exist [org_id %d, team_id %d, name: %s]", err.OrgID, err.TeamID, err.Name) } -// -// Two-factor authentication -// - -// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication. -type ErrTwoFactorNotEnrolled struct { - UID int64 -} - -// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled. -func IsErrTwoFactorNotEnrolled(err error) bool { - _, ok := err.(ErrTwoFactorNotEnrolled) - return ok -} - -func (err ErrTwoFactorNotEnrolled) Error() string { - return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID) -} - // ____ ___ .__ .___ // | | \______ | | _________ __| _/ // | | /\____ \| | / _ \__ \ / __ | @@ -1959,28 +1940,6 @@ func (err ErrExternalLoginUserNotExist) Error() string { return fmt.Sprintf("external login user link does not exists [userID: %d, loginSourceID: %d]", err.UserID, err.LoginSourceID) } -// ____ ________________________________ .__ __ __ .__ -// | | \_____ \_ _____/\______ \ ____ ____ |__| _______/ |_____________ _/ |_|__| ____ ____ -// | | // ____/| __) | _// __ \ / ___\| |/ ___/\ __\_ __ \__ \\ __\ |/ _ \ / \ -// | | // \| \ | | \ ___// /_/ > |\___ \ | | | | \// __ \| | | ( <_> ) | \ -// |______/ \_______ \___ / |____|_ /\___ >___ /|__/____ > |__| |__| (____ /__| |__|\____/|___| / -// \/ \/ \/ \/_____/ \/ \/ \/ - -// ErrU2FRegistrationNotExist represents a "ErrU2FRegistrationNotExist" kind of error. -type ErrU2FRegistrationNotExist struct { - ID int64 -} - -func (err ErrU2FRegistrationNotExist) Error() string { - return fmt.Sprintf("U2F registration does not exist [id: %d]", err.ID) -} - -// IsErrU2FRegistrationNotExist checks if an error is a ErrU2FRegistrationNotExist. -func IsErrU2FRegistrationNotExist(err error) bool { - _, ok := err.(ErrU2FRegistrationNotExist) - return ok -} - // .___ ________ .___ .__ // | | ______ ________ __ ____ \______ \ ____ ______ ____ ____ __| _/____ ____ ____ |__| ____ ______ // | |/ ___// ___/ | \_/ __ \ | | \_/ __ \\____ \_/ __ \ / \ / __ |/ __ \ / \_/ ___\| |/ __ \ / ___/ diff --git a/models/login/main_test.go b/models/login/main_test.go index ef4b5907bfd4..141952a5941d 100644 --- a/models/login/main_test.go +++ b/models/login/main_test.go @@ -17,5 +17,6 @@ func TestMain(m *testing.M) { "oauth2_application.yml", "oauth2_authorization_code.yml", "oauth2_grant.yml", + "u2f_registration.yml", ) } diff --git a/models/twofactor.go b/models/login/twofactor.go similarity index 83% rename from models/twofactor.go rename to models/login/twofactor.go index dd7fde77e21e..1c4d2734fca0 100644 --- a/models/twofactor.go +++ b/models/login/twofactor.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. -package models +package login import ( "crypto/md5" @@ -21,6 +21,25 @@ import ( "golang.org/x/crypto/pbkdf2" ) +// +// Two-factor authentication +// + +// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication. +type ErrTwoFactorNotEnrolled struct { + UID int64 +} + +// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled. +func IsErrTwoFactorNotEnrolled(err error) bool { + _, ok := err.(ErrTwoFactorNotEnrolled) + return ok +} + +func (err ErrTwoFactorNotEnrolled) Error() string { + return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID) +} + // TwoFactor represents a two-factor authentication token. type TwoFactor struct { ID int64 `xorm:"pk autoincr"` @@ -44,11 +63,12 @@ func (t *TwoFactor) GenerateScratchToken() (string, error) { return "", err } t.ScratchSalt, _ = util.RandomString(10) - t.ScratchHash = hashToken(token, t.ScratchSalt) + t.ScratchHash = HashToken(token, t.ScratchSalt) return token, nil } -func hashToken(token, salt string) string { +// HashToken return the hashable salt +func HashToken(token, salt string) string { tempHash := pbkdf2.Key([]byte(token), []byte(salt), 10000, 50, sha256.New) return fmt.Sprintf("%x", tempHash) } @@ -58,7 +78,7 @@ func (t *TwoFactor) VerifyScratchToken(token string) bool { if len(token) == 0 { return false } - tempHash := hashToken(token, t.ScratchSalt) + tempHash := HashToken(token, t.ScratchSalt) return subtle.ConstantTimeCompare([]byte(t.ScratchHash), []byte(tempHash)) == 1 } diff --git a/models/u2f.go b/models/login/u2f.go similarity index 71% rename from models/u2f.go rename to models/login/u2f.go index 17b829562634..64b1fb322ac8 100644 --- a/models/u2f.go +++ b/models/login/u2f.go @@ -2,9 +2,11 @@ // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. -package models +package login import ( + "fmt" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/timeutil" @@ -12,6 +14,28 @@ import ( "github.com/tstranex/u2f" ) +// ____ ________________________________ .__ __ __ .__ +// | | \_____ \_ _____/\______ \ ____ ____ |__| _______/ |_____________ _/ |_|__| ____ ____ +// | | // ____/| __) | _// __ \ / ___\| |/ ___/\ __\_ __ \__ \\ __\ |/ _ \ / \ +// | | // \| \ | | \ ___// /_/ > |\___ \ | | | | \// __ \| | | ( <_> ) | \ +// |______/ \_______ \___ / |____|_ /\___ >___ /|__/____ > |__| |__| (____ /__| |__|\____/|___| / +// \/ \/ \/ \/_____/ \/ \/ \/ + +// ErrU2FRegistrationNotExist represents a "ErrU2FRegistrationNotExist" kind of error. +type ErrU2FRegistrationNotExist struct { + ID int64 +} + +func (err ErrU2FRegistrationNotExist) Error() string { + return fmt.Sprintf("U2F registration does not exist [id: %d]", err.ID) +} + +// IsErrU2FRegistrationNotExist checks if an error is a ErrU2FRegistrationNotExist. +func IsErrU2FRegistrationNotExist(err error) bool { + _, ok := err.(ErrU2FRegistrationNotExist) + return ok +} + // U2FRegistration represents the registration data and counter of a security key type U2FRegistration struct { ID int64 `xorm:"pk autoincr"` @@ -91,13 +115,13 @@ func GetU2FRegistrationsByUID(uid int64) (U2FRegistrationList, error) { return getU2FRegistrationsByUID(db.GetEngine(db.DefaultContext), uid) } -func createRegistration(e db.Engine, user *User, name string, reg *u2f.Registration) (*U2FRegistration, error) { +func createRegistration(e db.Engine, userID int64, name string, reg *u2f.Registration) (*U2FRegistration, error) { raw, err := reg.MarshalBinary() if err != nil { return nil, err } r := &U2FRegistration{ - UserID: user.ID, + UserID: userID, Name: name, Counter: 0, Raw: raw, @@ -110,8 +134,8 @@ func createRegistration(e db.Engine, user *User, name string, reg *u2f.Registrat } // CreateRegistration will create a new U2FRegistration from the given Registration -func CreateRegistration(user *User, name string, reg *u2f.Registration) (*U2FRegistration, error) { - return createRegistration(db.GetEngine(db.DefaultContext), user, name, reg) +func CreateRegistration(userID int64, name string, reg *u2f.Registration) (*U2FRegistration, error) { + return createRegistration(db.GetEngine(db.DefaultContext), userID, name, reg) } // DeleteRegistration will delete U2FRegistration diff --git a/models/u2f_test.go b/models/login/u2f_test.go similarity index 91% rename from models/u2f_test.go rename to models/login/u2f_test.go index 44eca6953d43..b0305775caf5 100644 --- a/models/u2f_test.go +++ b/models/login/u2f_test.go @@ -2,12 +2,13 @@ // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. -package models +package login import ( "testing" "code.gitea.io/gitea/models/db" + "github.com/stretchr/testify/assert" "github.com/tstranex/u2f" ) @@ -55,14 +56,13 @@ func TestU2FRegistration_UpdateLargeCounter(t *testing.T) { func TestCreateRegistration(t *testing.T) { assert.NoError(t, db.PrepareTestDatabase()) - user := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User) - res, err := CreateRegistration(user, "U2F Created Key", &u2f.Registration{Raw: []byte("Test")}) + res, err := CreateRegistration(1, "U2F Created Key", &u2f.Registration{Raw: []byte("Test")}) assert.NoError(t, err) assert.Equal(t, "U2F Created Key", res.Name) assert.Equal(t, []byte("Test"), res.Raw) - db.AssertExistsIf(t, true, &U2FRegistration{Name: "U2F Created Key", UserID: user.ID}) + db.AssertExistsIf(t, true, &U2FRegistration{Name: "U2F Created Key", UserID: 1}) } func TestDeleteRegistration(t *testing.T) { diff --git a/models/pull_sign.go b/models/pull_sign.go index 2e7cbff48b43..028a3e5c3b65 100644 --- a/models/pull_sign.go +++ b/models/pull_sign.go @@ -6,6 +6,7 @@ package models import ( "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -44,8 +45,8 @@ Loop: return false, "", nil, &ErrWontSign{pubkey} } case twofa: - twofaModel, err := GetTwoFactorByUID(u.ID) - if err != nil && !IsErrTwoFactorNotEnrolled(err) { + twofaModel, err := login.GetTwoFactorByUID(u.ID) + if err != nil && !login.IsErrTwoFactorNotEnrolled(err) { return false, "", nil, err } if twofaModel == nil { diff --git a/models/repo_sign.go b/models/repo_sign.go index ae0895df7646..f7a303b0c124 100644 --- a/models/repo_sign.go +++ b/models/repo_sign.go @@ -8,6 +8,7 @@ import ( "strings" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/process" @@ -129,8 +130,8 @@ Loop: return false, "", nil, &ErrWontSign{pubkey} } case twofa: - twofaModel, err := GetTwoFactorByUID(u.ID) - if err != nil && !IsErrTwoFactorNotEnrolled(err) { + twofaModel, err := login.GetTwoFactorByUID(u.ID) + if err != nil && !login.IsErrTwoFactorNotEnrolled(err) { return false, "", nil, err } if twofaModel == nil { @@ -165,8 +166,8 @@ Loop: return false, "", nil, &ErrWontSign{pubkey} } case twofa: - twofaModel, err := GetTwoFactorByUID(u.ID) - if err != nil && !IsErrTwoFactorNotEnrolled(err) { + twofaModel, err := login.GetTwoFactorByUID(u.ID) + if err != nil && !login.IsErrTwoFactorNotEnrolled(err) { return false, "", nil, err } if twofaModel == nil { @@ -218,8 +219,8 @@ Loop: return false, "", nil, &ErrWontSign{pubkey} } case twofa: - twofaModel, err := GetTwoFactorByUID(u.ID) - if err != nil && !IsErrTwoFactorNotEnrolled(err) { + twofaModel, err := login.GetTwoFactorByUID(u.ID) + if err != nil && !login.IsErrTwoFactorNotEnrolled(err) { return false, "", nil, err } if twofaModel == nil { diff --git a/models/token.go b/models/token.go index 07d013ac8ed4..3cffdd9ba276 100644 --- a/models/token.go +++ b/models/token.go @@ -11,6 +11,7 @@ import ( "time" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/timeutil" @@ -67,7 +68,7 @@ func NewAccessToken(t *AccessToken) error { } t.TokenSalt = salt t.Token = base.EncodeSha1(gouuid.New().String()) - t.TokenHash = hashToken(t.Token, t.TokenSalt) + t.TokenHash = login.HashToken(t.Token, t.TokenSalt) t.TokenLastEight = t.Token[len(t.Token)-8:] _, err = db.GetEngine(db.DefaultContext).Insert(t) return err @@ -129,7 +130,7 @@ func GetAccessTokenBySHA(token string) (*AccessToken, error) { } for _, t := range tokens { - tempHash := hashToken(token, t.TokenSalt) + tempHash := login.HashToken(token, t.TokenSalt) if subtle.ConstantTimeCompare([]byte(t.TokenHash), []byte(tempHash)) == 1 { if successfulAccessTokenCache != nil { successfulAccessTokenCache.Add(token, t.ID) diff --git a/models/userlist.go b/models/userlist.go index bfa7ea1e2ea2..aebdb4f48c25 100644 --- a/models/userlist.go +++ b/models/userlist.go @@ -8,6 +8,7 @@ import ( "fmt" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/log" ) @@ -79,13 +80,13 @@ func (users UserList) GetTwoFaStatus() map[int64]bool { return results } -func (users UserList) loadTwoFactorStatus(e db.Engine) (map[int64]*TwoFactor, error) { +func (users UserList) loadTwoFactorStatus(e db.Engine) (map[int64]*login.TwoFactor, error) { if len(users) == 0 { return nil, nil } userIDs := users.getUserIDs() - tokenMaps := make(map[int64]*TwoFactor, len(userIDs)) + tokenMaps := make(map[int64]*login.TwoFactor, len(userIDs)) err := e. In("uid", userIDs). Find(&tokenMaps) diff --git a/modules/context/api.go b/modules/context/api.go index e80e63cd9623..e5216d911f8a 100644 --- a/modules/context/api.go +++ b/modules/context/api.go @@ -14,6 +14,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -219,9 +220,9 @@ func (ctx *APIContext) CheckForOTP() { } otpHeader := ctx.Req.Header.Get("X-Gitea-OTP") - twofa, err := models.GetTwoFactorByUID(ctx.Context.User.ID) + twofa, err := login.GetTwoFactorByUID(ctx.Context.User.ID) if err != nil { - if models.IsErrTwoFactorNotEnrolled(err) { + if login.IsErrTwoFactorNotEnrolled(err) { return // No 2FA enrollment for this user } ctx.Context.Error(http.StatusInternalServerError) diff --git a/modules/context/auth.go b/modules/context/auth.go index 0a62b2741e4a..7faa93d78b59 100644 --- a/modules/context/auth.go +++ b/modules/context/auth.go @@ -8,7 +8,7 @@ package context import ( "net/http" - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/web/middleware" @@ -154,9 +154,9 @@ func ToggleAPI(options *ToggleOptions) func(ctx *APIContext) { if skip, ok := ctx.Data["SkipLocalTwoFA"]; ok && skip.(bool) { return // Skip 2FA } - twofa, err := models.GetTwoFactorByUID(ctx.User.ID) + twofa, err := login.GetTwoFactorByUID(ctx.User.ID) if err != nil { - if models.IsErrTwoFactorNotEnrolled(err) { + if login.IsErrTwoFactorNotEnrolled(err) { return // No 2FA enrollment for this user } ctx.InternalServerError(err) diff --git a/routers/web/admin/users.go b/routers/web/admin/users.go index 2556cae3a87a..ea666ab4d4db 100644 --- a/routers/web/admin/users.go +++ b/routers/web/admin/users.go @@ -195,9 +195,9 @@ func prepareUserInfo(ctx *context.Context) *models.User { ctx.Data["Sources"] = sources ctx.Data["TwoFactorEnabled"] = true - _, err = models.GetTwoFactorByUID(u.ID) + _, err = login.GetTwoFactorByUID(u.ID) if err != nil { - if !models.IsErrTwoFactorNotEnrolled(err) { + if !login.IsErrTwoFactorNotEnrolled(err) { ctx.ServerError("IsErrTwoFactorNotEnrolled", err) return nil } @@ -295,13 +295,13 @@ func EditUserPost(ctx *context.Context) { } if form.Reset2FA { - tf, err := models.GetTwoFactorByUID(u.ID) - if err != nil && !models.IsErrTwoFactorNotEnrolled(err) { + tf, err := login.GetTwoFactorByUID(u.ID) + if err != nil && !login.IsErrTwoFactorNotEnrolled(err) { ctx.ServerError("GetTwoFactorByUID", err) return } - if err = models.DeleteTwoFactorByID(tf.ID, u.ID); err != nil { + if err = login.DeleteTwoFactorByID(tf.ID, u.ID); err != nil { ctx.ServerError("DeleteTwoFactorByID", err) return } diff --git a/routers/web/repo/http.go b/routers/web/repo/http.go index fbd1e19a8219..162338a9597c 100644 --- a/routers/web/repo/http.go +++ b/routers/web/repo/http.go @@ -21,6 +21,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/log" @@ -174,12 +175,12 @@ func httpBase(ctx *context.Context) (h *serviceHandler) { } if ctx.IsBasicAuth && ctx.Data["IsApiToken"] != true { - _, err = models.GetTwoFactorByUID(ctx.User.ID) + _, err = login.GetTwoFactorByUID(ctx.User.ID) if err == nil { // TODO: This response should be changed to "invalid credentials" for security reasons once the expectation behind it (creating an app token to authenticate) is properly documented ctx.HandleText(http.StatusUnauthorized, "Users with two-factor authentication enabled cannot perform HTTP/HTTPS operations via plain username and password. Please create and use a personal access token on the user settings page") return - } else if !models.IsErrTwoFactorNotEnrolled(err) { + } else if !login.IsErrTwoFactorNotEnrolled(err) { ctx.ServerError("IsErrTwoFactorNotEnrolled", err) return } diff --git a/routers/web/user/auth.go b/routers/web/user/auth.go index 733ace81b02a..12328e46a165 100644 --- a/routers/web/user/auth.go +++ b/routers/web/user/auth.go @@ -213,9 +213,9 @@ func SignInPost(ctx *context.Context) { // If this user is enrolled in 2FA, we can't sign the user in just yet. // Instead, redirect them to the 2FA authentication page. - _, err = models.GetTwoFactorByUID(u.ID) + _, err = login.GetTwoFactorByUID(u.ID) if err != nil { - if models.IsErrTwoFactorNotEnrolled(err) { + if login.IsErrTwoFactorNotEnrolled(err) { handleSignIn(ctx, u, form.Remember) } else { ctx.ServerError("UserSignIn", err) @@ -237,7 +237,7 @@ func SignInPost(ctx *context.Context) { return } - regs, err := models.GetU2FRegistrationsByUID(u.ID) + regs, err := login.GetU2FRegistrationsByUID(u.ID) if err == nil && len(regs) > 0 { ctx.Redirect(setting.AppSubURL + "/user/u2f") return @@ -277,7 +277,7 @@ func TwoFactorPost(ctx *context.Context) { } id := idSess.(int64) - twofa, err := models.GetTwoFactorByUID(id) + twofa, err := login.GetTwoFactorByUID(id) if err != nil { ctx.ServerError("UserSignIn", err) return @@ -313,7 +313,7 @@ func TwoFactorPost(ctx *context.Context) { } twofa.LastUsedPasscode = form.Passcode - if err = models.UpdateTwoFactor(twofa); err != nil { + if err = login.UpdateTwoFactor(twofa); err != nil { ctx.ServerError("UserSignIn", err) return } @@ -356,7 +356,7 @@ func TwoFactorScratchPost(ctx *context.Context) { } id := idSess.(int64) - twofa, err := models.GetTwoFactorByUID(id) + twofa, err := login.GetTwoFactorByUID(id) if err != nil { ctx.ServerError("UserSignIn", err) return @@ -370,7 +370,7 @@ func TwoFactorScratchPost(ctx *context.Context) { ctx.ServerError("UserSignIn", err) return } - if err = models.UpdateTwoFactor(twofa); err != nil { + if err = login.UpdateTwoFactor(twofa); err != nil { ctx.ServerError("UserSignIn", err) return } @@ -418,7 +418,7 @@ func U2FChallenge(ctx *context.Context) { return } id := idSess.(int64) - regs, err := models.GetU2FRegistrationsByUID(id) + regs, err := login.GetU2FRegistrationsByUID(id) if err != nil { ctx.ServerError("UserSignIn", err) return @@ -454,7 +454,7 @@ func U2FSign(ctx *context.Context) { } challenge := challSess.(*u2f.Challenge) id := idSess.(int64) - regs, err := models.GetU2FRegistrationsByUID(id) + regs, err := login.GetU2FRegistrationsByUID(id) if err != nil { ctx.ServerError("UserSignIn", err) return @@ -717,8 +717,8 @@ func handleOAuth2SignIn(ctx *context.Context, source *login.Source, u *models.Us needs2FA := false if !source.Cfg.(*oauth2.Source).SkipLocalTwoFA { - _, err := models.GetTwoFactorByUID(u.ID) - if err != nil && !models.IsErrTwoFactorNotEnrolled(err) { + _, err := login.GetTwoFactorByUID(u.ID) + if err != nil && !login.IsErrTwoFactorNotEnrolled(err) { ctx.ServerError("UserSignIn", err) return } @@ -775,7 +775,7 @@ func handleOAuth2SignIn(ctx *context.Context, source *login.Source, u *models.Us } // If U2F is enrolled -> Redirect to U2F instead - regs, err := models.GetU2FRegistrationsByUID(u.ID) + regs, err := login.GetU2FRegistrationsByUID(u.ID) if err == nil && len(regs) > 0 { ctx.Redirect(setting.AppSubURL + "/user/u2f") return @@ -935,9 +935,9 @@ func linkAccount(ctx *context.Context, u *models.User, gothUser goth.User, remem // If this user is enrolled in 2FA, we can't sign the user in just yet. // Instead, redirect them to the 2FA authentication page. // We deliberately ignore the skip local 2fa setting here because we are linking to a previous user here - _, err := models.GetTwoFactorByUID(u.ID) + _, err := login.GetTwoFactorByUID(u.ID) if err != nil { - if !models.IsErrTwoFactorNotEnrolled(err) { + if !login.IsErrTwoFactorNotEnrolled(err) { ctx.ServerError("UserLinkAccount", err) return } @@ -967,7 +967,7 @@ func linkAccount(ctx *context.Context, u *models.User, gothUser goth.User, remem } // If U2F is enrolled -> Redirect to U2F instead - regs, err := models.GetU2FRegistrationsByUID(u.ID) + regs, err := login.GetU2FRegistrationsByUID(u.ID) if err == nil && len(regs) > 0 { ctx.Redirect(setting.AppSubURL + "/user/u2f") return @@ -1561,7 +1561,7 @@ func ForgotPasswdPost(ctx *context.Context) { ctx.HTML(http.StatusOK, tplForgotPassword) } -func commonResetPassword(ctx *context.Context) (*models.User, *models.TwoFactor) { +func commonResetPassword(ctx *context.Context) (*models.User, *login.TwoFactor) { code := ctx.FormString("code") ctx.Data["Title"] = ctx.Tr("auth.reset_password") @@ -1583,9 +1583,9 @@ func commonResetPassword(ctx *context.Context) (*models.User, *models.TwoFactor) return nil, nil } - twofa, err := models.GetTwoFactorByUID(u.ID) + twofa, err := login.GetTwoFactorByUID(u.ID) if err != nil { - if !models.IsErrTwoFactorNotEnrolled(err) { + if !login.IsErrTwoFactorNotEnrolled(err) { ctx.Error(http.StatusInternalServerError, "CommonResetPassword", err.Error()) return nil, nil } @@ -1680,7 +1680,7 @@ func ResetPasswdPost(ctx *context.Context) { } twofa.LastUsedPasscode = passcode - if err = models.UpdateTwoFactor(twofa); err != nil { + if err = login.UpdateTwoFactor(twofa); err != nil { ctx.ServerError("ResetPasswdPost: UpdateTwoFactor", err) return } @@ -1712,7 +1712,7 @@ func ResetPasswdPost(ctx *context.Context) { ctx.ServerError("UserSignIn", err) return } - if err = models.UpdateTwoFactor(twofa); err != nil { + if err = login.UpdateTwoFactor(twofa); err != nil { ctx.ServerError("UserSignIn", err) return } diff --git a/routers/web/user/setting/security.go b/routers/web/user/setting/security.go index d4abe84d9601..53f672282d1a 100644 --- a/routers/web/user/setting/security.go +++ b/routers/web/user/setting/security.go @@ -56,9 +56,9 @@ func DeleteAccountLink(ctx *context.Context) { func loadSecurityData(ctx *context.Context) { enrolled := true - _, err := models.GetTwoFactorByUID(ctx.User.ID) + _, err := login.GetTwoFactorByUID(ctx.User.ID) if err != nil { - if models.IsErrTwoFactorNotEnrolled(err) { + if login.IsErrTwoFactorNotEnrolled(err) { enrolled = false } else { ctx.ServerError("SettingsTwoFactor", err) @@ -67,7 +67,7 @@ func loadSecurityData(ctx *context.Context) { } ctx.Data["TwofaEnrolled"] = enrolled if enrolled { - ctx.Data["U2FRegistrations"], err = models.GetU2FRegistrationsByUID(ctx.User.ID) + ctx.Data["U2FRegistrations"], err = login.GetU2FRegistrationsByUID(ctx.User.ID) if err != nil { ctx.ServerError("GetU2FRegistrationsByUID", err) return diff --git a/routers/web/user/setting/security_twofa.go b/routers/web/user/setting/security_twofa.go index 7b08a05939b3..5b1cbab17fe7 100644 --- a/routers/web/user/setting/security_twofa.go +++ b/routers/web/user/setting/security_twofa.go @@ -13,7 +13,7 @@ import ( "net/http" "strings" - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -29,9 +29,9 @@ func RegenerateScratchTwoFactor(ctx *context.Context) { ctx.Data["Title"] = ctx.Tr("settings") ctx.Data["PageIsSettingsSecurity"] = true - t, err := models.GetTwoFactorByUID(ctx.User.ID) + t, err := login.GetTwoFactorByUID(ctx.User.ID) if err != nil { - if models.IsErrTwoFactorNotEnrolled(err) { + if login.IsErrTwoFactorNotEnrolled(err) { ctx.Flash.Error(ctx.Tr("setting.twofa_not_enrolled")) ctx.Redirect(setting.AppSubURL + "/user/settings/security") } @@ -45,7 +45,7 @@ func RegenerateScratchTwoFactor(ctx *context.Context) { return } - if err = models.UpdateTwoFactor(t); err != nil { + if err = login.UpdateTwoFactor(t); err != nil { ctx.ServerError("SettingsTwoFactor: Failed to UpdateTwoFactor", err) return } @@ -59,9 +59,9 @@ func DisableTwoFactor(ctx *context.Context) { ctx.Data["Title"] = ctx.Tr("settings") ctx.Data["PageIsSettingsSecurity"] = true - t, err := models.GetTwoFactorByUID(ctx.User.ID) + t, err := login.GetTwoFactorByUID(ctx.User.ID) if err != nil { - if models.IsErrTwoFactorNotEnrolled(err) { + if login.IsErrTwoFactorNotEnrolled(err) { ctx.Flash.Error(ctx.Tr("setting.twofa_not_enrolled")) ctx.Redirect(setting.AppSubURL + "/user/settings/security") } @@ -69,8 +69,8 @@ func DisableTwoFactor(ctx *context.Context) { return } - if err = models.DeleteTwoFactorByID(t.ID, ctx.User.ID); err != nil { - if models.IsErrTwoFactorNotEnrolled(err) { + if err = login.DeleteTwoFactorByID(t.ID, ctx.User.ID); err != nil { + if login.IsErrTwoFactorNotEnrolled(err) { // There is a potential DB race here - we must have been disabled by another request in the intervening period ctx.Flash.Success(ctx.Tr("settings.twofa_disabled")) ctx.Redirect(setting.AppSubURL + "/user/settings/security") @@ -146,7 +146,7 @@ func EnrollTwoFactor(ctx *context.Context) { ctx.Data["Title"] = ctx.Tr("settings") ctx.Data["PageIsSettingsSecurity"] = true - t, err := models.GetTwoFactorByUID(ctx.User.ID) + t, err := login.GetTwoFactorByUID(ctx.User.ID) if t != nil { // already enrolled - we should redirect back! log.Warn("Trying to re-enroll %-v in twofa when already enrolled", ctx.User) @@ -154,7 +154,7 @@ func EnrollTwoFactor(ctx *context.Context) { ctx.Redirect(setting.AppSubURL + "/user/settings/security") return } - if err != nil && !models.IsErrTwoFactorNotEnrolled(err) { + if err != nil && !login.IsErrTwoFactorNotEnrolled(err) { ctx.ServerError("SettingsTwoFactor: GetTwoFactorByUID", err) return } @@ -172,14 +172,14 @@ func EnrollTwoFactorPost(ctx *context.Context) { ctx.Data["Title"] = ctx.Tr("settings") ctx.Data["PageIsSettingsSecurity"] = true - t, err := models.GetTwoFactorByUID(ctx.User.ID) + t, err := login.GetTwoFactorByUID(ctx.User.ID) if t != nil { // already enrolled ctx.Flash.Error(ctx.Tr("setting.twofa_is_enrolled")) ctx.Redirect(setting.AppSubURL + "/user/settings/security") return } - if err != nil && !models.IsErrTwoFactorNotEnrolled(err) { + if err != nil && !login.IsErrTwoFactorNotEnrolled(err) { ctx.ServerError("SettingsTwoFactor: Failed to check if already enrolled with GetTwoFactorByUID", err) return } @@ -209,7 +209,7 @@ func EnrollTwoFactorPost(ctx *context.Context) { return } - t = &models.TwoFactor{ + t = &login.TwoFactor{ UID: ctx.User.ID, } err = t.SetSecret(secret) @@ -238,7 +238,7 @@ func EnrollTwoFactorPost(ctx *context.Context) { log.Error("Unable to save changes to the session: %v", err) } - if err = models.NewTwoFactor(t); err != nil { + if err = login.NewTwoFactor(t); err != nil { // FIXME: We need to handle a unique constraint fail here it's entirely possible that another request has beaten us. // If there is a unique constraint fail we should just tolerate the error ctx.ServerError("SettingsTwoFactor: Failed to save two factor", err) diff --git a/routers/web/user/setting/security_u2f.go b/routers/web/user/setting/security_u2f.go index f9e35549fbfa..d1d6d1e8cad8 100644 --- a/routers/web/user/setting/security_u2f.go +++ b/routers/web/user/setting/security_u2f.go @@ -8,7 +8,7 @@ import ( "errors" "net/http" - "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/login" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" @@ -34,7 +34,7 @@ func U2FRegister(ctx *context.Context) { ctx.ServerError("Unable to set session key for u2fChallenge", err) return } - regs, err := models.GetU2FRegistrationsByUID(ctx.User.ID) + regs, err := login.GetU2FRegistrationsByUID(ctx.User.ID) if err != nil { ctx.ServerError("GetU2FRegistrationsByUID", err) return @@ -78,7 +78,7 @@ func U2FRegisterPost(ctx *context.Context) { ctx.ServerError("u2f.Register", err) return } - if _, err = models.CreateRegistration(ctx.User, name, reg); err != nil { + if _, err = login.CreateRegistration(ctx.User.ID, name, reg); err != nil { ctx.ServerError("u2f.Register", err) return } @@ -88,9 +88,9 @@ func U2FRegisterPost(ctx *context.Context) { // U2FDelete deletes an security key by id func U2FDelete(ctx *context.Context) { form := web.GetForm(ctx).(*forms.U2FDeleteForm) - reg, err := models.GetU2FRegistrationByID(form.ID) + reg, err := login.GetU2FRegistrationByID(form.ID) if err != nil { - if models.IsErrU2FRegistrationNotExist(err) { + if login.IsErrU2FRegistrationNotExist(err) { ctx.Status(200) return } @@ -101,7 +101,7 @@ func U2FDelete(ctx *context.Context) { ctx.Status(401) return } - if err := models.DeleteRegistration(reg); err != nil { + if err := login.DeleteRegistration(reg); err != nil { ctx.ServerError("DeleteRegistration", err) return } From 58d81835e2e78cc218d238c580a58cdd54535b44 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 25 Sep 2021 22:27:01 +0800 Subject: [PATCH 11/13] Fix wrong i18n keys (#17150) Co-authored-by: 6543 <6543@obermui.de> --- routers/web/user/setting/keys.go | 2 +- routers/web/user/setting/security_twofa.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/routers/web/user/setting/keys.go b/routers/web/user/setting/keys.go index bb7a50841bb8..22a0fe474171 100644 --- a/routers/web/user/setting/keys.go +++ b/routers/web/user/setting/keys.go @@ -209,7 +209,7 @@ func DeleteKey(ctx *context.Context) { return } if external { - ctx.Flash.Error(ctx.Tr("setting.ssh_externally_managed")) + ctx.Flash.Error(ctx.Tr("settings.ssh_externally_managed")) ctx.Redirect(setting.AppSubURL + "/user/settings/keys") return } diff --git a/routers/web/user/setting/security_twofa.go b/routers/web/user/setting/security_twofa.go index 5b1cbab17fe7..94f975f9fe6f 100644 --- a/routers/web/user/setting/security_twofa.go +++ b/routers/web/user/setting/security_twofa.go @@ -32,7 +32,7 @@ func RegenerateScratchTwoFactor(ctx *context.Context) { t, err := login.GetTwoFactorByUID(ctx.User.ID) if err != nil { if login.IsErrTwoFactorNotEnrolled(err) { - ctx.Flash.Error(ctx.Tr("setting.twofa_not_enrolled")) + ctx.Flash.Error(ctx.Tr("settings.twofa_not_enrolled")) ctx.Redirect(setting.AppSubURL + "/user/settings/security") } ctx.ServerError("SettingsTwoFactor: Failed to GetTwoFactorByUID", err) @@ -62,7 +62,7 @@ func DisableTwoFactor(ctx *context.Context) { t, err := login.GetTwoFactorByUID(ctx.User.ID) if err != nil { if login.IsErrTwoFactorNotEnrolled(err) { - ctx.Flash.Error(ctx.Tr("setting.twofa_not_enrolled")) + ctx.Flash.Error(ctx.Tr("settings.twofa_not_enrolled")) ctx.Redirect(setting.AppSubURL + "/user/settings/security") } ctx.ServerError("SettingsTwoFactor: Failed to GetTwoFactorByUID", err) @@ -150,7 +150,7 @@ func EnrollTwoFactor(ctx *context.Context) { if t != nil { // already enrolled - we should redirect back! log.Warn("Trying to re-enroll %-v in twofa when already enrolled", ctx.User) - ctx.Flash.Error(ctx.Tr("setting.twofa_is_enrolled")) + ctx.Flash.Error(ctx.Tr("settings.twofa_is_enrolled")) ctx.Redirect(setting.AppSubURL + "/user/settings/security") return } @@ -175,7 +175,7 @@ func EnrollTwoFactorPost(ctx *context.Context) { t, err := login.GetTwoFactorByUID(ctx.User.ID) if t != nil { // already enrolled - ctx.Flash.Error(ctx.Tr("setting.twofa_is_enrolled")) + ctx.Flash.Error(ctx.Tr("settings.twofa_is_enrolled")) ctx.Redirect(setting.AppSubURL + "/user/settings/security") return } From 7e9bd206fd13a34ccdd59dfa8bf5474e5e7f000d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexey=20=E3=80=92erentyev?= Date: Sun, 26 Sep 2021 00:29:25 +0300 Subject: [PATCH 12/13] Fix bundle creation (#17079) Signed-off-by: Alexey Terentyev Co-authored-by: 6543 <6543@obermui.de> Co-authored-by: Gwyneth Morgan <87623694+gwymor@users.noreply.github.com> Co-authored-by: Gwyneth Morgan --- modules/git/repo.go | 24 +++++++++++++++++------- services/archiver/archiver.go | 6 ++++-- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/modules/git/repo.go b/modules/git/repo.go index e7d42dacb165..89af7aa9e1da 100644 --- a/modules/git/repo.go +++ b/modules/git/repo.go @@ -425,14 +425,24 @@ func (repo *Repository) CreateBundle(ctx context.Context, commit string, out io. } defer os.RemoveAll(tmp) - tmpFile := filepath.Join(tmp, "bundle") - args := []string{ - "bundle", - "create", - tmpFile, - commit, + env := append(os.Environ(), "GIT_OBJECT_DIRECTORY="+filepath.Join(repo.Path, "objects")) + _, err = NewCommandContext(ctx, "init", "--bare").RunInDirWithEnv(tmp, env) + if err != nil { + return err + } + + _, err = NewCommandContext(ctx, "reset", "--soft", commit).RunInDirWithEnv(tmp, env) + if err != nil { + return err } - _, err = NewCommandContext(ctx, args...).RunInDir(repo.Path) + + _, err = NewCommandContext(ctx, "branch", "-m", "bundle").RunInDirWithEnv(tmp, env) + if err != nil { + return err + } + + tmpFile := filepath.Join(tmp, "bundle") + _, err = NewCommandContext(ctx, "bundle", "create", tmpFile, "bundle", "HEAD").RunInDirWithEnv(tmp, env) if err != nil { return err } diff --git a/services/archiver/archiver.go b/services/archiver/archiver.go index 6d4d46e4e02c..d602b9ed7fd5 100644 --- a/services/archiver/archiver.go +++ b/services/archiver/archiver.go @@ -136,9 +136,11 @@ func doArchive(r *ArchiveRequest) (*models.RepoArchiver, error) { if err == nil { if archiver.Status == models.RepoArchiverGenerating { archiver.Status = models.RepoArchiverReady - return archiver, models.UpdateRepoArchiverStatus(ctx, archiver) + if err = models.UpdateRepoArchiverStatus(ctx, archiver); err != nil { + return nil, err + } } - return archiver, nil + return archiver, committer.Commit() } if !errors.Is(err, os.ErrNotExist) { From 74542ad35bae2195972df86862da43e8d45f425f Mon Sep 17 00:00:00 2001 From: GiteaBot Date: Sun, 26 Sep 2021 00:05:01 +0000 Subject: [PATCH 13/13] [skip ci] Updated translations via Crowdin --- options/locale/locale_ru-RU.ini | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/options/locale/locale_ru-RU.ini b/options/locale/locale_ru-RU.ini index c6ea930cd0bd..4e61469f4136 100644 --- a/options/locale/locale_ru-RU.ini +++ b/options/locale/locale_ru-RU.ini @@ -508,7 +508,7 @@ public_profile=Открытый профиль biography_placeholder=Расскажите немного о себе profile_desc=Ваш адрес электронной почты будет использован для уведомлений и других операций. password_username_disabled=Нелокальным пользователям запрещено изменение их имени пользователя. Для получения более подробной информации обратитесь к администратору сайта. -full_name=ФИО +full_name=Имя и фамилия website=Веб-сайт location=Местоположение update_theme=Обновить тему @@ -1202,8 +1202,8 @@ pulls.merged_by_fake=слито %[1]s пользователем %[2]s issues.closed_by=закрыт %[1]s пользователем %[3]s issues.opened_by_fake=открыт %[1]s пользователем %[2]s issues.closed_by_fake=закрыт %[1]s пользователем %[2]s -issues.previous=Предыдущая страница -issues.next=Следующая страница +issues.previous=Предыдущая +issues.next=Следующая issues.open_title=Открыто issues.closed_title=Закрыто issues.num_comments=комментариев: %d