diff --git a/.changeset/soft-hotels-decide.md b/.changeset/soft-hotels-decide.md
new file mode 100644
index 00000000000..75b4cadd4e5
--- /dev/null
+++ b/.changeset/soft-hotels-decide.md
@@ -0,0 +1,5 @@
+---
+"chainlink": patch
+---
+
+switch more EVM components to use sqlutil.DataStore #internal
diff --git a/common/txmgr/broadcaster.go b/common/txmgr/broadcaster.go
index a13673bf91b..1651f6417bf 100644
--- a/common/txmgr/broadcaster.go
+++ b/common/txmgr/broadcaster.go
@@ -689,7 +689,7 @@ func (eb *Broadcaster[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) save
 	// is relatively benign and probably nobody will ever run into it in
 	// practice, but something to be aware of.
 	if etx.PipelineTaskRunID.Valid && eb.resumeCallback != nil && etx.SignalCallback {
-		err := eb.resumeCallback(etx.PipelineTaskRunID.UUID, nil, fmt.Errorf("fatal error while sending transaction: %s", etx.Error.String))
+		err := eb.resumeCallback(ctx, etx.PipelineTaskRunID.UUID, nil, fmt.Errorf("fatal error while sending transaction: %s", etx.Error.String))
 		if errors.Is(err, sql.ErrNoRows) {
 			lgr.Debugw("callback missing or already resumed", "etxID", etx.ID)
 		} else if err != nil {
diff --git a/common/txmgr/confirmer.go b/common/txmgr/confirmer.go
index 53e1c3c4206..d61f9a3dddd 100644
--- a/common/txmgr/confirmer.go
+++ b/common/txmgr/confirmer.go
@@ -1120,7 +1120,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Res
 		}

 		ec.lggr.Debugw("Callback: resuming tx with receipt", "output", output, "taskErr", taskErr, "pipelineTaskRunID", data.ID)
-		if err := ec.resumeCallback(data.ID, output, taskErr); err != nil {
+		if err := ec.resumeCallback(ctx, data.ID, output, taskErr); err != nil {
 			return fmt.Errorf("failed to resume suspended pipeline run: %w", err)
 		}
 		// Mark tx as having completed callback
diff --git a/common/txmgr/txmgr.go b/common/txmgr/txmgr.go
index d183a8c3ade..b996b76f1a5 100644
--- a/common/txmgr/txmgr.go
+++ b/common/txmgr/txmgr.go
@@ -27,7 +27,7 @@ import (
 // https://www.notion.so/chainlink/Txm-Architecture-Overview-9dc62450cd7a443ba9e7dceffa1a8d6b

 // ResumeCallback is assumed to be idempotent
-type ResumeCallback func(id uuid.UUID, result interface{}, err error) error
+type ResumeCallback func(ctx context.Context, id uuid.UUID, result interface{}, err error) error

 // TxManager is the main component of the transaction manager.
 // It is also the interface to external callers.
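Note: the ResumeCallback signature change above drives most of the context-threading in the rest of this diff. The broadcaster and confirmer now pass their event-loop context to the callback, which ultimately reaches pipeline.Runner.ResumeRun (see the core/services/chainlink/application.go hunk further down). A minimal illustrative sketch of a conforming callback under that assumption (imports elided; runner is assumed to be the node's pipeline.Runner; this is not the exact wiring used in application.go):

	// Sketch: adapting a pipeline runner to the context-aware ResumeCallback.
	func makeResumeCallback(runner pipeline.Runner) txmgr.ResumeCallback {
		return func(ctx context.Context, id uuid.UUID, result interface{}, err error) error {
			// The context now comes from the txmgr event loop, so DB work done while
			// resuming the suspended run respects shutdown and query timeouts.
			if ctx.Err() != nil {
				return ctx.Err()
			}
			return runner.ResumeRun(ctx, id, result, err)
		}
	}
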
diff --git a/core/bridges/orm_test.go b/core/bridges/orm_test.go index 204dc5fe115..85e8b9ecdef 100644 --- a/core/bridges/orm_test.go +++ b/core/bridges/orm_test.go @@ -17,7 +17,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/store/models" ) @@ -144,8 +143,8 @@ func TestORM_TestCachedResponse(t *testing.T) { db := pgtest.NewSqlxDB(t) orm := bridges.NewORM(db) - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) _, err = orm.GetCachedResponse(ctx, "dot", specID, 1*time.Second) diff --git a/core/chains/evm/forwarders/forwarder_manager.go b/core/chains/evm/forwarders/forwarder_manager.go index f0786c091c4..7a7a274127f 100644 --- a/core/chains/evm/forwarders/forwarder_manager.go +++ b/core/chains/evm/forwarders/forwarder_manager.go @@ -54,13 +54,13 @@ type FwdMgr struct { wg sync.WaitGroup } -func NewFwdMgr(db sqlutil.DataSource, client evmclient.Client, logpoller evmlogpoller.LogPoller, l logger.Logger, cfg Config) *FwdMgr { +func NewFwdMgr(ds sqlutil.DataSource, client evmclient.Client, logpoller evmlogpoller.LogPoller, l logger.Logger, cfg Config) *FwdMgr { lggr := logger.Sugared(logger.Named(l, "EVMForwarderManager")) fwdMgr := FwdMgr{ logger: lggr, cfg: cfg, evmClient: client, - ORM: NewORM(db), + ORM: NewORM(ds), logpoller: logpoller, sendersCache: make(map[common.Address][]common.Address), } diff --git a/core/chains/evm/forwarders/orm.go b/core/chains/evm/forwarders/orm.go index cf498518d6d..8076cba4831 100644 --- a/core/chains/evm/forwarders/orm.go +++ b/core/chains/evm/forwarders/orm.go @@ -23,50 +23,50 @@ type ORM interface { FindForwardersInListByChain(ctx context.Context, evmChainId big.Big, addrs []common.Address) ([]Forwarder, error) } -type DbORM struct { - db sqlutil.DataSource +type DSORM struct { + ds sqlutil.DataSource } -var _ ORM = &DbORM{} +var _ ORM = &DSORM{} -func NewORM(db sqlutil.DataSource) *DbORM { - return &DbORM{db: db} +func NewORM(ds sqlutil.DataSource) *DSORM { + return &DSORM{ds: ds} } -func (o *DbORM) Transaction(ctx context.Context, fn func(*DbORM) error) (err error) { - return sqlutil.Transact(ctx, o.new, o.db, nil, fn) +func (o *DSORM) Transact(ctx context.Context, fn func(*DSORM) error) (err error) { + return sqlutil.Transact(ctx, o.new, o.ds, nil, fn) } // new returns a NewORM like o, but backed by q. -func (o *DbORM) new(q sqlutil.DataSource) *DbORM { return NewORM(q) } +func (o *DSORM) new(q sqlutil.DataSource) *DSORM { return NewORM(q) } // CreateForwarder creates the Forwarder address associated with the current EVM chain id. 
-func (o *DbORM) CreateForwarder(ctx context.Context, addr common.Address, evmChainId big.Big) (fwd Forwarder, err error) { +func (o *DSORM) CreateForwarder(ctx context.Context, addr common.Address, evmChainId big.Big) (fwd Forwarder, err error) { sql := `INSERT INTO evm.forwarders (address, evm_chain_id, created_at, updated_at) VALUES ($1, $2, now(), now()) RETURNING *` - err = o.db.GetContext(ctx, &fwd, sql, addr, evmChainId) + err = o.ds.GetContext(ctx, &fwd, sql, addr, evmChainId) return fwd, err } // DeleteForwarder removes a forwarder address. // If cleanup is non-nil, it can be used to perform any chain- or contract-specific cleanup that need to happen atomically // on forwarder deletion. If cleanup returns an error, forwarder deletion will be aborted. -func (o *DbORM) DeleteForwarder(ctx context.Context, id int64, cleanup func(tx sqlutil.DataSource, evmChainID int64, addr common.Address) error) (err error) { - return o.Transaction(ctx, func(orm *DbORM) error { +func (o *DSORM) DeleteForwarder(ctx context.Context, id int64, cleanup func(tx sqlutil.DataSource, evmChainID int64, addr common.Address) error) (err error) { + return o.Transact(ctx, func(orm *DSORM) error { var dest struct { EvmChainId int64 Address common.Address } - err := orm.db.GetContext(ctx, &dest, `SELECT evm_chain_id, address FROM evm.forwarders WHERE id = $1`, id) + err := orm.ds.GetContext(ctx, &dest, `SELECT evm_chain_id, address FROM evm.forwarders WHERE id = $1`, id) if err != nil { return err } if cleanup != nil { - if err = cleanup(orm.db, dest.EvmChainId, dest.Address); err != nil { + if err = cleanup(orm.ds, dest.EvmChainId, dest.Address); err != nil { return err } } - result, err := orm.db.ExecContext(ctx, `DELETE FROM evm.forwarders WHERE id = $1`, id) + result, err := orm.ds.ExecContext(ctx, `DELETE FROM evm.forwarders WHERE id = $1`, id) // If the forwarder wasn't found, we still want to delete the filter. // In that case, the transaction must return nil, even though DeleteForwarder // will return sql.ErrNoRows @@ -82,27 +82,27 @@ func (o *DbORM) DeleteForwarder(ctx context.Context, id int64, cleanup func(tx s } // FindForwarders returns all forwarder addresses from offset up until limit. -func (o *DbORM) FindForwarders(ctx context.Context, offset, limit int) (fwds []Forwarder, count int, err error) { +func (o *DSORM) FindForwarders(ctx context.Context, offset, limit int) (fwds []Forwarder, count int, err error) { sql := `SELECT count(*) FROM evm.forwarders` - if err = o.db.GetContext(ctx, &count, sql); err != nil { + if err = o.ds.GetContext(ctx, &count, sql); err != nil { return } sql = `SELECT * FROM evm.forwarders ORDER BY created_at DESC, id DESC LIMIT $1 OFFSET $2` - if err = o.db.SelectContext(ctx, &fwds, sql, limit, offset); err != nil { + if err = o.ds.SelectContext(ctx, &fwds, sql, limit, offset); err != nil { return } return } // FindForwardersByChain returns all forwarder addresses for a chain. 
-func (o *DbORM) FindForwardersByChain(ctx context.Context, evmChainId big.Big) (fwds []Forwarder, err error) { +func (o *DSORM) FindForwardersByChain(ctx context.Context, evmChainId big.Big) (fwds []Forwarder, err error) { sql := `SELECT * FROM evm.forwarders where evm_chain_id = $1 ORDER BY created_at DESC, id DESC` - err = o.db.SelectContext(ctx, &fwds, sql, evmChainId) + err = o.ds.SelectContext(ctx, &fwds, sql, evmChainId) return } -func (o *DbORM) FindForwardersInListByChain(ctx context.Context, evmChainId big.Big, addrs []common.Address) ([]Forwarder, error) { +func (o *DSORM) FindForwardersInListByChain(ctx context.Context, evmChainId big.Big, addrs []common.Address) ([]Forwarder, error) { var fwdrs []Forwarder arg := map[string]interface{}{ @@ -127,8 +127,8 @@ func (o *DbORM) FindForwardersInListByChain(ctx context.Context, evmChainId big. return nil, pkgerrors.Wrap(err, "Failed to run sqlx.IN on query") } - query = o.db.Rebind(query) - err = o.db.SelectContext(ctx, &fwdrs, query, args...) + query = o.ds.Rebind(query) + err = o.ds.SelectContext(ctx, &fwdrs, query, args...) if err != nil { return nil, pkgerrors.Wrap(err, "Failed to execute query") diff --git a/core/chains/evm/headtracker/orm.go b/core/chains/evm/headtracker/orm.go index 8912bafecdf..9d569ade08d 100644 --- a/core/chains/evm/headtracker/orm.go +++ b/core/chains/evm/headtracker/orm.go @@ -31,14 +31,14 @@ var _ ORM = &DbORM{} type DbORM struct { chainID ubig.Big - db sqlutil.DataSource + ds sqlutil.DataSource } // NewORM creates an ORM scoped to chainID. -func NewORM(chainID big.Int, db sqlutil.DataSource) *DbORM { +func NewORM(chainID big.Int, ds sqlutil.DataSource) *DbORM { return &DbORM{ chainID: ubig.Big(chainID), - db: db, + ds: ds, } } @@ -48,19 +48,19 @@ func (orm *DbORM) IdempotentInsertHead(ctx context.Context, head *evmtypes.Head) INSERT INTO evm.heads (hash, number, parent_hash, created_at, timestamp, l1_block_number, evm_chain_id, base_fee_per_gas) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (evm_chain_id, hash) DO NOTHING` - _, err := orm.db.ExecContext(ctx, query, head.Hash, head.Number, head.ParentHash, head.CreatedAt, head.Timestamp, head.L1BlockNumber, orm.chainID, head.BaseFeePerGas) + _, err := orm.ds.ExecContext(ctx, query, head.Hash, head.Number, head.ParentHash, head.CreatedAt, head.Timestamp, head.L1BlockNumber, orm.chainID, head.BaseFeePerGas) return pkgerrors.Wrap(err, "IdempotentInsertHead failed to insert head") } func (orm *DbORM) TrimOldHeads(ctx context.Context, minBlockNumber int64) (err error) { query := `DELETE FROM evm.heads WHERE evm_chain_id = $1 AND number < $2` - _, err = orm.db.ExecContext(ctx, query, orm.chainID, minBlockNumber) + _, err = orm.ds.ExecContext(ctx, query, orm.chainID, minBlockNumber) return err } func (orm *DbORM) LatestHead(ctx context.Context) (head *evmtypes.Head, err error) { head = new(evmtypes.Head) - err = orm.db.GetContext(ctx, head, `SELECT * FROM evm.heads WHERE evm_chain_id = $1 ORDER BY number DESC, created_at DESC, id DESC LIMIT 1`, orm.chainID) + err = orm.ds.GetContext(ctx, head, `SELECT * FROM evm.heads WHERE evm_chain_id = $1 ORDER BY number DESC, created_at DESC, id DESC LIMIT 1`, orm.chainID) if pkgerrors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -69,14 +69,14 @@ func (orm *DbORM) LatestHead(ctx context.Context) (head *evmtypes.Head, err erro } func (orm *DbORM) LatestHeads(ctx context.Context, minBlockNumer int64) (heads []*evmtypes.Head, err error) { - err = orm.db.SelectContext(ctx, &heads, `SELECT * FROM evm.heads WHERE 
evm_chain_id = $1 AND number >= $2 ORDER BY number DESC, created_at DESC, id DESC`, orm.chainID, minBlockNumer) + err = orm.ds.SelectContext(ctx, &heads, `SELECT * FROM evm.heads WHERE evm_chain_id = $1 AND number >= $2 ORDER BY number DESC, created_at DESC, id DESC`, orm.chainID, minBlockNumer) err = pkgerrors.Wrap(err, "LatestHeads failed") return } func (orm *DbORM) HeadByHash(ctx context.Context, hash common.Hash) (head *evmtypes.Head, err error) { head = new(evmtypes.Head) - err = orm.db.GetContext(ctx, head, `SELECT * FROM evm.heads WHERE evm_chain_id = $1 AND hash = $2`, orm.chainID, hash) + err = orm.ds.GetContext(ctx, head, `SELECT * FROM evm.heads WHERE evm_chain_id = $1 AND hash = $2`, orm.chainID, hash) if pkgerrors.Is(err, sql.ErrNoRows) { return nil, nil } diff --git a/core/chains/evm/log/broadcaster.go b/core/chains/evm/log/broadcaster.go index a96474c0f78..148c36148c2 100644 --- a/core/chains/evm/log/broadcaster.go +++ b/core/chains/evm/log/broadcaster.go @@ -9,14 +9,13 @@ import ( "sync/atomic" "time" - "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" pkgerrors "github.com/pkg/errors" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" @@ -60,12 +59,10 @@ type ( Register(listener Listener, opts ListenerOpts) (unsubscribe func()) WasAlreadyConsumed(ctx context.Context, lb Broadcast) (bool, error) - MarkConsumed(ctx context.Context, lb Broadcast) error - - // MarkManyConsumed marks all the provided log broadcasts as consumed. - MarkManyConsumed(ctx context.Context, lbs []Broadcast) error + // ds is optional + MarkConsumed(ctx context.Context, ds sqlutil.DataSource, lb Broadcast) error - // NOTE: WasAlreadyConsumed, MarkConsumed and MarkManyConsumed MUST be used within a single goroutine in order for WasAlreadyConsumed to be accurate + // NOTE: WasAlreadyConsumed, and MarkConsumed MUST be used within a single goroutine in order for WasAlreadyConsumed to be accurate } BroadcasterInTest interface { @@ -422,12 +419,15 @@ func (b *broadcaster) eventLoop(chRawLogs <-chan types.Log, chErr <-chan error) debounceResubscribe := time.NewTicker(1 * time.Second) defer debounceResubscribe.Stop() + ctx, cancel := b.chStop.NewCtx() + defer cancel() + b.logger.Debug("Starting the event loop") for { // Replay requests take priority. select { case req := <-b.replayChannel: - b.onReplayRequest(req) + b.onReplayRequest(ctx, req) return true, nil default: } @@ -456,7 +456,7 @@ func (b *broadcaster) eventLoop(chRawLogs <-chan types.Log, chErr <-chan error) needsResubscribe = b.onChangeSubscriberStatus() || needsResubscribe case req := <-b.replayChannel: - b.onReplayRequest(req) + b.onReplayRequest(ctx, req) return true, nil case <-debounceResubscribe.C: @@ -480,7 +480,7 @@ func (b *broadcaster) eventLoop(chRawLogs <-chan types.Log, chErr <-chan error) } // onReplayRequest clears the pool and sets the block backfill number. -func (b *broadcaster) onReplayRequest(replayReq replayRequest) { +func (b *broadcaster) onReplayRequest(ctx context.Context, replayReq replayRequest) { // notify subscribers that we are about to replay. 
for subscriber := range b.registrations.registeredSubs { if subscriber.opts.ReplayStartedCallback != nil { @@ -495,11 +495,11 @@ func (b *broadcaster) onReplayRequest(replayReq replayRequest) { b.backfillBlockNumber.Int64 = replayReq.fromBlock b.backfillBlockNumber.Valid = true if replayReq.forceBroadcast { - ctx, cancel := b.chStop.CtxCancel(context.WithTimeout(context.Background(), time.Minute)) - ctx = sqlutil.WithoutDefaultTimeout(ctx) - defer cancel() // Use a longer timeout in the event that a very large amount of logs need to be marked - // as consumed. + // as unconsumed. + var cancel func() + ctx, cancel = context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() err := b.orm.MarkBroadcastsUnconsumed(ctx, replayReq.fromBlock) if err != nil { b.logger.Errorw("Error marking broadcasts as unconsumed", @@ -694,25 +694,12 @@ func (b *broadcaster) WasAlreadyConsumed(ctx context.Context, lb Broadcast) (boo } // MarkConsumed marks the log as having been successfully consumed by the subscriber -func (b *broadcaster) MarkConsumed(ctx context.Context, lb Broadcast) error { - return b.orm.MarkBroadcastConsumed(ctx, lb.RawLog().BlockHash, lb.RawLog().BlockNumber, lb.RawLog().Index, lb.JobID()) -} - -// MarkManyConsumed marks the logs as having been successfully consumed by the subscriber -func (b *broadcaster) MarkManyConsumed(ctx context.Context, lbs []Broadcast) (err error) { - var ( - blockHashes = make([]common.Hash, len(lbs)) - blockNumbers = make([]uint64, len(lbs)) - logIndexes = make([]uint, len(lbs)) - jobIDs = make([]int32, len(lbs)) - ) - for i := range lbs { - blockHashes[i] = lbs[i].RawLog().BlockHash - blockNumbers[i] = lbs[i].RawLog().BlockNumber - logIndexes[i] = lbs[i].RawLog().Index - jobIDs[i] = lbs[i].JobID() +func (b *broadcaster) MarkConsumed(ctx context.Context, ds sqlutil.DataSource, lb Broadcast) error { + orm := b.orm + if ds != nil { + orm = orm.WithDataSource(ds) } - return b.orm.MarkBroadcastsConsumed(ctx, blockHashes, blockNumbers, logIndexes, jobIDs) + return orm.MarkBroadcastConsumed(ctx, lb.RawLog().BlockHash, lb.RawLog().BlockNumber, lb.RawLog().Index, lb.JobID()) } // test only @@ -779,10 +766,7 @@ func (n *NullBroadcaster) TrackedAddressesCount() uint32 { func (n *NullBroadcaster) WasAlreadyConsumed(ctx context.Context, lb Broadcast) (bool, error) { return false, pkgerrors.New(n.ErrMsg) } -func (n *NullBroadcaster) MarkConsumed(ctx context.Context, lb Broadcast) error { - return pkgerrors.New(n.ErrMsg) -} -func (n *NullBroadcaster) MarkManyConsumed(ctx context.Context, lbs []Broadcast) error { +func (n *NullBroadcaster) MarkConsumed(ctx context.Context, ds sqlutil.DataSource, lb Broadcast) error { return pkgerrors.New(n.ErrMsg) } diff --git a/core/chains/evm/log/helpers_test.go b/core/chains/evm/log/helpers_test.go index 18f396fab9d..85c2fe783bb 100644 --- a/core/chains/evm/log/helpers_test.go +++ b/core/chains/evm/log/helpers_test.go @@ -281,7 +281,7 @@ func (listener *simpleLogListener) SkipMarkingConsumed(skip bool) { listener.skipMarkingConsumed.Store(skip) } -func (listener *simpleLogListener) HandleLog(lb log.Broadcast) { +func (listener *simpleLogListener) HandleLog(ctx context.Context, lb log.Broadcast) { listener.received.Lock() defer listener.received.Unlock() listener.lggr.Tracef("Listener %v HandleLog for block %v %v received at %v %v", listener.name, lb.RawLog().BlockNumber, lb.RawLog().BlockHash, lb.LatestBlockNumber(), lb.LatestBlockHash()) diff --git a/core/chains/evm/log/mocks/broadcaster.go 
b/core/chains/evm/log/mocks/broadcaster.go index 26fe1a35101..e5164b56611 100644 --- a/core/chains/evm/log/mocks/broadcaster.go +++ b/core/chains/evm/log/mocks/broadcaster.go @@ -8,6 +8,8 @@ import ( log "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" mock "github.com/stretchr/testify/mock" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + types "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) @@ -102,35 +104,17 @@ func (_m *Broadcaster) IsConnected() bool { return r0 } -// MarkConsumed provides a mock function with given fields: ctx, lb -func (_m *Broadcaster) MarkConsumed(ctx context.Context, lb log.Broadcast) error { - ret := _m.Called(ctx, lb) +// MarkConsumed provides a mock function with given fields: ctx, ds, lb +func (_m *Broadcaster) MarkConsumed(ctx context.Context, ds sqlutil.DataSource, lb log.Broadcast) error { + ret := _m.Called(ctx, ds, lb) if len(ret) == 0 { panic("no return value specified for MarkConsumed") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, log.Broadcast) error); ok { - r0 = rf(ctx, lb) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MarkManyConsumed provides a mock function with given fields: ctx, lbs -func (_m *Broadcaster) MarkManyConsumed(ctx context.Context, lbs []log.Broadcast) error { - ret := _m.Called(ctx, lbs) - - if len(ret) == 0 { - panic("no return value specified for MarkManyConsumed") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, []log.Broadcast) error); ok { - r0 = rf(ctx, lbs) + if rf, ok := ret.Get(0).(func(context.Context, sqlutil.DataSource, log.Broadcast) error); ok { + r0 = rf(ctx, ds, lb) } else { r0 = ret.Error(0) } diff --git a/core/chains/evm/log/orm.go b/core/chains/evm/log/orm.go index 71c9675d6fd..6e94d3bf8a8 100644 --- a/core/chains/evm/log/orm.go +++ b/core/chains/evm/log/orm.go @@ -3,16 +3,13 @@ package log import ( "context" "database/sql" - "fmt" "math/big" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" - "github.com/jmoiron/sqlx" pkgerrors "github.com/pkg/errors" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" - "github.com/smartcontractkit/chainlink-common/pkg/utils" ubig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" ) @@ -31,8 +28,6 @@ type ORM interface { WasBroadcastConsumed(ctx context.Context, blockHash common.Hash, logIndex uint, jobID int32) (bool, error) // MarkBroadcastConsumed marks the log broadcast as consumed by jobID. MarkBroadcastConsumed(ctx context.Context, blockHash common.Hash, blockNumber uint64, logIndex uint, jobID int32) error - // MarkBroadcastsConsumed marks the log broadcasts as consumed by jobID. - MarkBroadcastsConsumed(ctx context.Context, blockHashes []common.Hash, blockNumbers []uint64, logIndexes []uint, jobIDs []int32) error // MarkBroadcastsUnconsumed marks all log broadcasts from all jobs on or after fromBlock as // unconsumed. MarkBroadcastsUnconsumed(ctx context.Context, fromBlock int64) error @@ -45,20 +40,23 @@ type ORM interface { // Reinitialize cleans up the database by removing any unconsumed broadcasts, then updating (if necessary) and // returning the pending minimum block number. 
Reinitialize(ctx context.Context) (blockNumber *int64, err error) + + WithDataSource(sqlutil.DataSource) ORM } type orm struct { - db sqlutil.DataSource + ds sqlutil.DataSource evmChainID ubig.Big } var _ ORM = (*orm)(nil) -func NewORM(db sqlutil.DataSource, evmChainID big.Int) *orm { - return &orm{ - db: db, - evmChainID: *ubig.New(&evmChainID), - } +func NewORM(ds sqlutil.DataSource, evmChainID big.Int) *orm { + return &orm{ds, *ubig.New(&evmChainID)} +} + +func (o *orm) WithDataSource(ds sqlutil.DataSource) ORM { + return &orm{ds, o.evmChainID} } func (o *orm) WasBroadcastConsumed(ctx context.Context, blockHash common.Hash, logIndex uint, jobID int32) (consumed bool, err error) { @@ -75,7 +73,7 @@ func (o *orm) WasBroadcastConsumed(ctx context.Context, blockHash common.Hash, l jobID, o.evmChainID, } - err = o.db.GetContext(ctx, &consumed, query, args...) + err = o.ds.GetContext(ctx, &consumed, query, args...) if pkgerrors.Is(err, sql.ErrNoRows) { return false, nil } @@ -90,7 +88,7 @@ func (o *orm) FindBroadcasts(ctx context.Context, fromBlockNum int64, toBlockNum AND block_number <= $2 AND evm_chain_id = $3 ` - err := o.db.SelectContext(ctx, &broadcasts, query, fromBlockNum, toBlockNum, o.evmChainID) + err := o.ds.SelectContext(ctx, &broadcasts, query, fromBlockNum, toBlockNum, o.evmChainID) if err != nil { return nil, pkgerrors.Wrap(err, "failed to find log broadcasts") } @@ -98,7 +96,7 @@ func (o *orm) FindBroadcasts(ctx context.Context, fromBlockNum int64, toBlockNum } func (o *orm) CreateBroadcast(ctx context.Context, blockHash common.Hash, blockNumber uint64, logIndex uint, jobID int32) error { - _, err := o.db.ExecContext(ctx, ` + _, err := o.ds.ExecContext(ctx, ` INSERT INTO log_broadcasts (block_hash, block_number, log_index, job_id, created_at, updated_at, consumed, evm_chain_id) VALUES ($1, $2, $3, $4, NOW(), NOW(), false, $5) `, blockHash, blockNumber, logIndex, jobID, o.evmChainID) @@ -106,7 +104,7 @@ func (o *orm) CreateBroadcast(ctx context.Context, blockHash common.Hash, blockN } func (o *orm) MarkBroadcastConsumed(ctx context.Context, blockHash common.Hash, blockNumber uint64, logIndex uint, jobID int32) error { - _, err := o.db.ExecContext(ctx, ` + _, err := o.ds.ExecContext(ctx, ` INSERT INTO log_broadcasts (block_hash, block_number, log_index, job_id, created_at, updated_at, consumed, evm_chain_id) VALUES ($1, $2, $3, $4, NOW(), NOW(), true, $5) ON CONFLICT (job_id, block_hash, log_index, evm_chain_id) DO UPDATE @@ -115,45 +113,9 @@ func (o *orm) MarkBroadcastConsumed(ctx context.Context, blockHash common.Hash, return pkgerrors.Wrap(err, "failed to mark log broadcast as consumed") } -// MarkBroadcastsConsumed marks many broadcasts as consumed. -// The lengths of all the provided slices must be equal, otherwise an error is returned. 
-func (o *orm) MarkBroadcastsConsumed(ctx context.Context, blockHashes []common.Hash, blockNumbers []uint64, logIndexes []uint, jobIDs []int32) error { - if !utils.AllEqual(len(blockHashes), len(blockNumbers), len(logIndexes), len(jobIDs)) { - return fmt.Errorf("all arg slice lengths must be equal, got: %d %d %d %d", - len(blockHashes), len(blockNumbers), len(logIndexes), len(jobIDs), - ) - } - - type input struct { - BlockHash common.Hash `db:"blockHash"` - BlockNumber uint64 `db:"blockNumber"` - LogIndex uint `db:"logIndex"` - JobID int32 `db:"jobID"` - ChainID ubig.Big `db:"chainID"` - } - inputs := make([]input, len(blockHashes)) - query := ` -INSERT INTO log_broadcasts (block_hash, block_number, log_index, job_id, created_at, updated_at, consumed, evm_chain_id) -VALUES (:blockHash, :blockNumber, :logIndex, :jobID, NOW(), NOW(), true, :chainID) -ON CONFLICT (job_id, block_hash, log_index, evm_chain_id) DO UPDATE -SET consumed = true, updated_at = NOW(); - ` - for i := range blockHashes { - inputs[i] = input{ - BlockHash: blockHashes[i], - BlockNumber: blockNumbers[i], - LogIndex: logIndexes[i], - JobID: jobIDs[i], - ChainID: o.evmChainID, - } - } - _, err := o.db.(*sqlx.DB).NamedExecContext(ctx, query, inputs) - return pkgerrors.Wrap(err, "mark broadcasts consumed") -} - // MarkBroadcastsUnconsumed implements the ORM interface. func (o *orm) MarkBroadcastsUnconsumed(ctx context.Context, fromBlock int64) error { - _, err := o.db.ExecContext(ctx, ` + _, err := o.ds.ExecContext(ctx, ` UPDATE log_broadcasts SET consumed = false WHERE block_number >= $1 @@ -193,7 +155,7 @@ func (o *orm) Reinitialize(ctx context.Context) (*int64, error) { } func (o *orm) SetPendingMinBlock(ctx context.Context, blockNumber *int64) error { - _, err := o.db.ExecContext(ctx, ` + _, err := o.ds.ExecContext(ctx, ` INSERT INTO log_broadcasts_pending (evm_chain_id, block_number, created_at, updated_at) VALUES ($1, $2, NOW(), NOW()) ON CONFLICT (evm_chain_id) DO UPDATE SET block_number = $3, updated_at = NOW() `, o.evmChainID, blockNumber, blockNumber) @@ -202,7 +164,7 @@ func (o *orm) SetPendingMinBlock(ctx context.Context, blockNumber *int64) error func (o *orm) GetPendingMinBlock(ctx context.Context) (*int64, error) { var blockNumber *int64 - err := o.db.GetContext(ctx, &blockNumber, ` + err := o.ds.GetContext(ctx, &blockNumber, ` SELECT block_number FROM log_broadcasts_pending WHERE evm_chain_id = $1 `, o.evmChainID) if pkgerrors.Is(err, sql.ErrNoRows) { @@ -215,7 +177,7 @@ func (o *orm) GetPendingMinBlock(ctx context.Context) (*int64, error) { func (o *orm) getUnconsumedMinBlock(ctx context.Context) (*int64, error) { var blockNumber *int64 - err := o.db.GetContext(ctx, &blockNumber, ` + err := o.ds.GetContext(ctx, &blockNumber, ` SELECT min(block_number) FROM log_broadcasts WHERE evm_chain_id = $1 AND consumed = false @@ -230,7 +192,7 @@ func (o *orm) getUnconsumedMinBlock(ctx context.Context) (*int64, error) { } func (o *orm) removeUnconsumed(ctx context.Context) error { - _, err := o.db.ExecContext(ctx, ` + _, err := o.ds.ExecContext(ctx, ` DELETE FROM log_broadcasts WHERE evm_chain_id = $1 AND consumed = false diff --git a/core/chains/evm/log/orm_test.go b/core/chains/evm/log/orm_test.go index ba9509d4518..1a6d927cd50 100644 --- a/core/chains/evm/log/orm_test.go +++ b/core/chains/evm/log/orm_test.go @@ -21,7 +21,6 @@ func TestORM_broadcasts(t *testing.T) { db := pgtest.NewSqlxDB(t) cfg := configtest.NewGeneralConfig(t, nil) ethKeyStore := cltest.NewKeyStore(t, db, cfg.Database()).Eth() - ctx := 
testutils.Context(t) orm := log.NewORM(db, cltest.FixtureChainID) @@ -44,12 +43,12 @@ func TestORM_broadcasts(t *testing.T) { require.Zero(t, rowsAffected) t.Run("WasBroadcastConsumed_DNE", func(t *testing.T) { - _, err := orm.WasBroadcastConsumed(ctx, rawLog.BlockHash, rawLog.Index, listener.JobID()) + _, err := orm.WasBroadcastConsumed(testutils.Context(t), rawLog.BlockHash, rawLog.Index, listener.JobID()) require.NoError(t, err) }) require.True(t, t.Run("CreateBroadcast", func(t *testing.T) { - err := orm.CreateBroadcast(ctx, rawLog.BlockHash, rawLog.BlockNumber, rawLog.Index, listener.JobID()) + err := orm.CreateBroadcast(testutils.Context(t), rawLog.BlockHash, rawLog.BlockNumber, rawLog.Index, listener.JobID()) require.NoError(t, err) var consumed null.Bool @@ -59,13 +58,13 @@ func TestORM_broadcasts(t *testing.T) { })) t.Run("WasBroadcastConsumed_false", func(t *testing.T) { - was, err := orm.WasBroadcastConsumed(ctx, rawLog.BlockHash, rawLog.Index, listener.JobID()) + was, err := orm.WasBroadcastConsumed(testutils.Context(t), rawLog.BlockHash, rawLog.Index, listener.JobID()) require.NoError(t, err) require.False(t, was) }) require.True(t, t.Run("MarkBroadcastConsumed", func(t *testing.T) { - err := orm.MarkBroadcastConsumed(ctx, rawLog.BlockHash, rawLog.BlockNumber, rawLog.Index, listener.JobID()) + err := orm.MarkBroadcastConsumed(testutils.Context(t), rawLog.BlockHash, rawLog.BlockNumber, rawLog.Index, listener.JobID()) require.NoError(t, err) var consumed null.Bool @@ -74,66 +73,17 @@ func TestORM_broadcasts(t *testing.T) { require.Equal(t, null.BoolFrom(true), consumed) })) - t.Run("MarkBroadcastsConsumed Success", func(t *testing.T) { - var ( - err error - blockHashes []common.Hash - blockNumbers []uint64 - logIndexes []uint - jobIDs []int32 - ) - for i := 0; i < 3; i++ { - l := cltest.RandomLog(t) - err = orm.CreateBroadcast(ctx, l.BlockHash, l.BlockNumber, l.Index, listener.JobID()) - require.NoError(t, err) - blockHashes = append(blockHashes, l.BlockHash) - blockNumbers = append(blockNumbers, l.BlockNumber) - logIndexes = append(logIndexes, l.Index) - jobIDs = append(jobIDs, listener.JobID()) - - } - err = orm.MarkBroadcastsConsumed(ctx, blockHashes, blockNumbers, logIndexes, jobIDs) - require.NoError(t, err) - - for i := range blockHashes { - was, err := orm.WasBroadcastConsumed(ctx, blockHashes[i], logIndexes[i], jobIDs[i]) - require.NoError(t, err) - require.True(t, was) - } - }) - - t.Run("MarkBroadcastsConsumed Failure", func(t *testing.T) { - var ( - err error - blockHashes []common.Hash - blockNumbers []uint64 - logIndexes []uint - jobIDs []int32 - ) - for i := 0; i < 5; i++ { - l := cltest.RandomLog(t) - err = orm.CreateBroadcast(ctx, l.BlockHash, l.BlockNumber, l.Index, listener.JobID()) - require.NoError(t, err) - blockHashes = append(blockHashes, l.BlockHash) - blockNumbers = append(blockNumbers, l.BlockNumber) - logIndexes = append(logIndexes, l.Index) - jobIDs = append(jobIDs, listener.JobID()) - } - err = orm.MarkBroadcastsConsumed(ctx, blockHashes[:len(blockHashes)-2], blockNumbers, logIndexes, jobIDs) - require.Error(t, err) - }) - t.Run("WasBroadcastConsumed_true", func(t *testing.T) { - was, err := orm.WasBroadcastConsumed(ctx, rawLog.BlockHash, rawLog.Index, listener.JobID()) + was, err := orm.WasBroadcastConsumed(testutils.Context(t), rawLog.BlockHash, rawLog.Index, listener.JobID()) require.NoError(t, err) require.True(t, was) }) } func TestORM_pending(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) orm := log.NewORM(db, 
cltest.FixtureChainID) - ctx := testutils.Context(t) num, err := orm.GetPendingMinBlock(ctx) require.NoError(t, err) @@ -156,9 +106,9 @@ func TestORM_pending(t *testing.T) { } func TestORM_MarkUnconsumed(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) cfg := configtest.NewGeneralConfig(t, nil) - ctx := testutils.Context(t) ethKeyStore := cltest.NewKeyStore(t, db, cfg.Database()).Eth() orm := log.NewORM(db, cltest.FixtureChainID) @@ -256,8 +206,8 @@ func TestORM_Reinitialize(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { db := pgtest.NewSqlxDB(t) - orm := log.NewORM(db, cltest.FixtureChainID) ctx := testutils.Context(t) + orm := log.NewORM(db, cltest.FixtureChainID) jobID := cltest.MustInsertV2JobSpec(t, db, common.BigToAddress(big.NewInt(rand.Int63()))).ID diff --git a/core/chains/evm/log/registrations.go b/core/chains/evm/log/registrations.go index b56d3f4aaaa..c82fee43b6e 100644 --- a/core/chains/evm/log/registrations.go +++ b/core/chains/evm/log/registrations.go @@ -62,7 +62,7 @@ type ( // The Listener responds to log events through HandleLog. Listener interface { - HandleLog(b Broadcast) + HandleLog(ctx context.Context, b Broadcast) JobID() int32 } @@ -240,6 +240,9 @@ func (r *registrations) sendLogs(ctx context.Context, logsToSend []logsOnBlock, for _, log := range logsPerBlock.Logs { handlers.sendLog(ctx, log, latestHead, broadcastsExisting, bc, r.logger) + if ctx.Err() != nil { + return + } } } } @@ -442,7 +445,7 @@ func (r *handler) sendLog(ctx context.Context, log types.Log, latestHead evmtype wg.Add(1) go func() { defer wg.Done() - handleLog(&broadcast{ + handleLog(ctx, &broadcast{ latestBlockNumber, latestHead.Hash, latestHead.ReceiptsRoot, diff --git a/core/chains/evm/log/registrations_test.go b/core/chains/evm/log/registrations_test.go index 2be01dca2bf..8c0beaa9379 100644 --- a/core/chains/evm/log/registrations_test.go +++ b/core/chains/evm/log/registrations_test.go @@ -1,6 +1,7 @@ package log import ( + "context" "testing" "github.com/ethereum/go-ethereum/common" @@ -18,8 +19,8 @@ type testListener struct { jobID int32 } -func (tl testListener) JobID() int32 { return tl.jobID } -func (tl testListener) HandleLog(Broadcast) { panic("not implemented") } +func (tl testListener) JobID() int32 { return tl.jobID } +func (tl testListener) HandleLog(context.Context, Broadcast) { panic("not implemented") } func newTestListener(t *testing.T, jobID int32) testListener { return testListener{jobID} diff --git a/core/chains/evm/txmgr/broadcaster_test.go b/core/chains/evm/txmgr/broadcaster_test.go index 1e8f1c73b34..3500002e8da 100644 --- a/core/chains/evm/txmgr/broadcaster_test.go +++ b/core/chains/evm/txmgr/broadcaster_test.go @@ -1113,7 +1113,7 @@ func TestEthBroadcaster_ProcessUnstartedEthTxs_Errors(t *testing.T) { t.Run("with erroring callback bails out", func(t *testing.T) { require.NoError(t, txStore.InsertTx(ctx, &etx)) - fn := func(id uuid.UUID, result interface{}, err error) error { + fn := func(ctx context.Context, id uuid.UUID, result interface{}, err error) error { return errors.New("something exploded in the callback") } @@ -1130,7 +1130,7 @@ func TestEthBroadcaster_ProcessUnstartedEthTxs_Errors(t *testing.T) { }) t.Run("calls resume with error", func(t *testing.T) { - fn := func(id uuid.UUID, result interface{}, err error) error { + fn := func(ctx context.Context, id uuid.UUID, result interface{}, err error) error { require.Equal(t, id, tr.ID) require.Nil(t, result) require.Error(t, err) diff --git 
a/core/chains/evm/txmgr/confirmer_test.go b/core/chains/evm/txmgr/confirmer_test.go index 80868d448e0..357dafcbdc4 100644 --- a/core/chains/evm/txmgr/confirmer_test.go +++ b/core/chains/evm/txmgr/confirmer_test.go @@ -1,6 +1,7 @@ package txmgr_test import ( + "context" "encoding/json" "errors" "fmt" @@ -2966,7 +2967,7 @@ func TestEthConfirmer_ResumePendingRuns(t *testing.T) { pgtest.MustExec(t, db, `SET CONSTRAINTS pipeline_runs_pipeline_spec_id_fkey DEFERRED`) t.Run("doesn't process task runs that are not suspended (possibly already previously resumed)", func(t *testing.T) { - ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(uuid.UUID, interface{}, error) error { + ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(context.Context, uuid.UUID, interface{}, error) error { t.Fatal("No value expected") return nil }) @@ -2985,7 +2986,7 @@ func TestEthConfirmer_ResumePendingRuns(t *testing.T) { }) t.Run("doesn't process task runs where the receipt is younger than minConfirmations", func(t *testing.T) { - ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(uuid.UUID, interface{}, error) error { + ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(context.Context, uuid.UUID, interface{}, error) error { t.Fatal("No value expected") return nil }) @@ -3006,7 +3007,7 @@ func TestEthConfirmer_ResumePendingRuns(t *testing.T) { ch := make(chan interface{}) nonce := evmtypes.Nonce(3) var err error - ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(id uuid.UUID, value interface{}, thisErr error) error { + ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(ctx context.Context, id uuid.UUID, value interface{}, thisErr error) error { err = thisErr ch <- value return nil @@ -3059,7 +3060,7 @@ func TestEthConfirmer_ResumePendingRuns(t *testing.T) { } ch := make(chan data) nonce := evmtypes.Nonce(4) - ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(id uuid.UUID, value interface{}, err error) error { + ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(ctx context.Context, id uuid.UUID, value interface{}, err error) error { ch <- data{value, err} return nil }) @@ -3106,7 +3107,7 @@ func TestEthConfirmer_ResumePendingRuns(t *testing.T) { t.Run("does not mark callback complete if callback fails", func(t *testing.T) { nonce := evmtypes.Nonce(5) - ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(uuid.UUID, interface{}, error) error { + ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(context.Context, uuid.UUID, interface{}, error) error { return errors.New("error") }) diff --git a/core/internal/cltest/cltest.go b/core/internal/cltest/cltest.go index 3a92269cc03..ba182d60515 100644 --- a/core/internal/cltest/cltest.go +++ b/core/internal/cltest/cltest.go @@ -182,7 +182,7 @@ type JobPipelineConfig interface { func NewJobPipelineV2(t testing.TB, cfg pipeline.BridgeConfig, jpcfg JobPipelineConfig, dbCfg pg.QConfig, legacyChains legacyevm.LegacyChainContainer, db *sqlx.DB, keyStore keystore.Master, restrictedHTTPClient, unrestrictedHTTPClient *http.Client) JobPipelineV2TestHelper { lggr := logger.TestLogger(t) - prm := pipeline.NewORM(db, lggr, dbCfg, jpcfg.MaxSuccessfulRuns()) + prm := pipeline.NewORM(db, lggr, jpcfg.MaxSuccessfulRuns()) btORM := bridges.NewORM(db) jrm := job.NewORM(db, prm, btORM, keyStore, lggr, dbCfg) pr := pipeline.NewRunner(prm, btORM, jpcfg, cfg, legacyChains, keyStore.Eth(), keyStore.VRF(), 
lggr, restrictedHTTPClient, unrestrictedHTTPClient) diff --git a/core/internal/cltest/factories.go b/core/internal/cltest/factories.go index d7e1036bcac..43cf902ca8a 100644 --- a/core/internal/cltest/factories.go +++ b/core/internal/cltest/factories.go @@ -402,7 +402,7 @@ func MustInsertKeeperJob(t *testing.T, db *sqlx.DB, korm *keeper.ORM, from evmty cfg := configtest.NewTestGeneralConfig(t) tlg := logger.TestLogger(t) - prm := pipeline.NewORM(db, tlg, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + prm := pipeline.NewORM(db, tlg, cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) jrm := job.NewORM(db, prm, btORM, nil, tlg, cfg.Database()) err = jrm.InsertJob(&jb) diff --git a/core/internal/cltest/job_factories.go b/core/internal/cltest/job_factories.go index 5d8f75e36c3..2b527fbc29c 100644 --- a/core/internal/cltest/job_factories.go +++ b/core/internal/cltest/job_factories.go @@ -11,6 +11,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" @@ -43,12 +44,13 @@ func MinimalOCRNonBootstrapSpec(contractAddress, transmitterAddress types.EIP55A } func MustInsertWebhookSpec(t *testing.T, db *sqlx.DB) (job.Job, job.WebhookSpec) { + ctx := testutils.Context(t) jobORM, pipelineORM := getORMs(t, db) webhookSpec := job.WebhookSpec{} require.NoError(t, jobORM.InsertWebhookSpec(&webhookSpec)) pSpec := pipeline.Pipeline{} - pipelineSpecID, err := pipelineORM.CreateSpec(pSpec, 0) + pipelineSpecID, err := pipelineORM.CreateSpec(ctx, nil, pSpec, 0) require.NoError(t, err) createdJob := job.Job{WebhookSpecID: &webhookSpec.ID, WebhookSpec: &webhookSpec, SchemaVersion: 1, Type: "webhook", @@ -62,7 +64,7 @@ func getORMs(t *testing.T, db *sqlx.DB) (jobORM job.ORM, pipelineORM pipeline.OR config := configtest.NewTestGeneralConfig(t) keyStore := NewKeyStore(t, db, config.Database()) lggr := logger.TestLogger(t) - pipelineORM = pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM = pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(db) jobORM = job.NewORM(db, pipelineORM, bridgeORM, keyStore, lggr, config.Database()) t.Cleanup(func() { jobORM.Close() }) diff --git a/core/internal/features/features_test.go b/core/internal/features/features_test.go index 2c40c848263..4afad453110 100644 --- a/core/internal/features/features_test.go +++ b/core/internal/features/features_test.go @@ -236,7 +236,7 @@ observationSource = """ _ = cltest.CreateJobRunViaExternalInitiatorV2(t, app, jobUUID, *eia, cltest.MustJSONMarshal(t, eiRequest)) - pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(app.GetSqlxDB()) jobORM := job.NewORM(app.GetSqlxDB(), pipelineORM, bridgeORM, app.KeyStore, logger.TestLogger(t), cfg.Database()) diff --git a/core/services/blockhashstore/delegate.go b/core/services/blockhashstore/delegate.go index 9a11c057c32..6bcfc26ddb6 100644 --- a/core/services/blockhashstore/delegate.go +++ b/core/services/blockhashstore/delegate.go @@ -19,7 +19,6 @@ 
import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -194,7 +193,7 @@ func (d *Delegate) BeforeJobCreated(spec job.Job) {} func (d *Delegate) BeforeJobDeleted(spec job.Job) {} // OnDeleteJob satisfies the job.Delegate interface. -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // service is a job.Service that runs the BHS feeder every pollPeriod. type service struct { diff --git a/core/services/blockheaderfeeder/delegate.go b/core/services/blockheaderfeeder/delegate.go index 19edb43bc23..07cab534af7 100644 --- a/core/services/blockheaderfeeder/delegate.go +++ b/core/services/blockheaderfeeder/delegate.go @@ -19,7 +19,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/blockhashstore" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -208,7 +207,7 @@ func (d *Delegate) BeforeJobCreated(spec job.Job) {} func (d *Delegate) BeforeJobDeleted(spec job.Job) {} // OnDeleteJob satisfies the job.Delegate interface. -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // service is a job.Service that runs the BHS feeder every pollPeriod. type service struct { diff --git a/core/services/chainlink/application.go b/core/services/chainlink/application.go index 832bea523b5..8542074c27c 100644 --- a/core/services/chainlink/application.go +++ b/core/services/chainlink/application.go @@ -308,7 +308,7 @@ func NewApplication(opts ApplicationOpts) (Application, error) { } var ( - pipelineORM = pipeline.NewORM(sqlxDB, globalLogger, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM = pipeline.NewORM(sqlxDB, globalLogger, cfg.JobPipeline().MaxSuccessfulRuns()) bridgeORM = bridges.NewORM(sqlxDB) mercuryORM = mercury.NewORM(opts.DB) pipelineRunner = pipeline.NewRunner(pipelineORM, bridgeORM, cfg.JobPipeline(), cfg.WebServer(), legacyEVMChains, keyStore.Eth(), keyStore.VRF(), globalLogger, restrictedHTTPClient, unrestrictedHTTPClient) @@ -346,7 +346,6 @@ func NewApplication(opts ApplicationOpts) (Application, error) { pipelineORM, legacyEVMChains, globalLogger, - cfg.Database(), mailMon), job.Webhook: webhook.NewDelegate( pipelineRunner, @@ -829,7 +828,7 @@ func (app *ChainlinkApplication) ResumeJobV2( taskID uuid.UUID, result pipeline.Result, ) error { - return app.pipelineRunner.ResumeRun(taskID, result.Value, result.Error) + return app.pipelineRunner.ResumeRun(ctx, taskID, result.Value, result.Error) } func (app *ChainlinkApplication) GetFeedsService() feeds.Service { diff --git a/core/services/cron/cron_test.go b/core/services/cron/cron_test.go index 3ace0f3ceae..c3ecc0957c7 100644 --- a/core/services/cron/cron_test.go +++ b/core/services/cron/cron_test.go @@ -27,7 +27,7 @@ func TestCronV2Pipeline(t *testing.T) { keyStore := cltest.NewKeyStore(t, db, cfg.Database()) lggr := logger.TestLogger(t) - orm := pipeline.NewORM(db, lggr, cfg.Database(), 
cfg.JobPipeline().MaxSuccessfulRuns()) + orm := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) jobORM := job.NewORM(db, orm, btORM, keyStore, lggr, cfg.Database()) diff --git a/core/services/cron/delegate.go b/core/services/cron/delegate.go index 05b5b36c00f..d8a1390103e 100644 --- a/core/services/cron/delegate.go +++ b/core/services/cron/delegate.go @@ -7,7 +7,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -29,10 +28,10 @@ func (d *Delegate) JobType() job.Type { return job.Cron } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec returns the scheduler to be used for running cron jobs func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services []job.ServiceCtx, err error) { diff --git a/core/services/directrequest/delegate.go b/core/services/directrequest/delegate.go index d6afc215fb9..33a0a7e73da 100644 --- a/core/services/directrequest/delegate.go +++ b/core/services/directrequest/delegate.go @@ -11,6 +11,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" @@ -19,7 +20,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/operator_wrapper" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/store/models" ) @@ -63,10 +63,10 @@ func (d *Delegate) JobType() job.Type { return job.DirectRequest } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec returns the log listener service for a direct request job func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.ServiceCtx, error) { @@ -191,7 +191,7 @@ func (l *listener) Close() error { }) } -func (l *listener) HandleLog(lb log.Broadcast) { +func (l *listener) HandleLog(ctx context.Context, lb log.Broadcast) { log := lb.DecodedLog() if log == nil || reflect.ValueOf(log).IsNil() { l.logger.Error("HandleLog: ignoring nil value") @@ -374,7 +374,7 @@ func (l *listener) handleOracleRequest(ctx 
context.Context, request *operator_wr }, }) run := pipeline.NewRun(*l.job.PipelineSpec, vars) - _, err := l.pipelineRunner.Run(ctx, run, l.logger, true, func(tx pg.Queryer) error { + _, err := l.pipelineRunner.Run(ctx, run, l.logger, true, func(tx sqlutil.DataSource) error { l.markLogConsumed(ctx, lb) return nil }) @@ -407,7 +407,7 @@ func (l *listener) handleCancelOracleRequest(ctx context.Context, request *opera } func (l *listener) markLogConsumed(ctx context.Context, lb log.Broadcast) { - if err := l.logBroadcaster.MarkConsumed(ctx, lb); err != nil { + if err := l.logBroadcaster.MarkConsumed(ctx, nil, lb); err != nil { l.logger.Errorw("Unable to mark log consumed", "err", err, "log", lb.String()) } } diff --git a/core/services/directrequest/delegate_test.go b/core/services/directrequest/delegate_test.go index a7f2ba01315..0235a0c4eec 100644 --- a/core/services/directrequest/delegate_test.go +++ b/core/services/directrequest/delegate_test.go @@ -15,6 +15,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox/mailboxtest" "github.com/smartcontractkit/chainlink/v2/core/bridges" @@ -31,7 +32,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" "github.com/smartcontractkit/chainlink/v2/core/services/directrequest" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" pipeline_mocks "github.com/smartcontractkit/chainlink/v2/core/services/pipeline/mocks" evmrelay "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" @@ -88,7 +88,7 @@ func NewDirectRequestUniverseWithConfig(t *testing.T, cfg chainlink.GeneralConfi keyStore := cltest.NewKeyStore(t, db, cfg.Database()) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: cfg, Client: ethClient, LogBroadcaster: broadcaster, MailMon: mailMon, KeyStore: keyStore.Eth()}) lggr := logger.TestLogger(t) - orm := pipeline.NewORM(db, lggr, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + orm := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) jobORM := job.NewORM(db, orm, btORM, keyStore, lggr, cfg.Database()) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -159,28 +159,29 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { }) log.On("DecodedLog").Return(&logOracleRequest) log.On("String").Return("") - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) runBeganAwaiter := cltest.NewAwaiter() uni.runner.On("Run", mock.Anything, mock.AnythingOfType("*pipeline.Run"), mock.Anything, mock.Anything, mock.Anything). Return(false, nil). 
Run(func(args mock.Arguments) { runBeganAwaiter.ItHappened() - fn := args.Get(4).(func(pg.Queryer) error) + fn := args.Get(4).(func(source sqlutil.DataSource) error) require.NoError(t, fn(nil)) }).Once() - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) require.NotNil(t, uni.listener, "listener was nil; expected broadcaster.Register to have been called") // check if the job exists under the correct ID - drJob, jErr := uni.jobORM.FindJob(testutils.Context(t), uni.listener.JobID()) + drJob, jErr := uni.jobORM.FindJob(ctx, uni.listener.JobID()) require.NoError(t, jErr) require.Equal(t, drJob.ID, uni.listener.JobID()) require.NotNil(t, drJob.DirectRequestSpec) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) runBeganAwaiter.AwaitOrFail(t, 5*time.Second) @@ -207,12 +208,13 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log.On("DecodedLog").Return(&logOracleRequest).Maybe() log.On("String").Return("") log.On("EVMChainID").Return(*big.NewInt(0)) - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil).Maybe() + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) uni.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) @@ -224,7 +226,7 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { uni.runner.On("Run", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { runBeganAwaiter.ItHappened() - fn := args.Get(4).(func(pg.Queryer) error) + fn := args.Get(4).(func(sqlutil.DataSource) error) require.NoError(t, fn(nil)) }).Once().Return(false, nil) @@ -241,7 +243,7 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log := log_mocks.NewBroadcast(t) lbAwaiter := cltest.NewAwaiter() uni.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { lbAwaiter.ItHappened() }).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { lbAwaiter.ItHappened() }).Return(nil) logCancelOracleRequest := operator_wrapper.OperatorCancelOracleRequest{RequestId: uni.spec.ExternalIDEncodeStringToTopic()} logAwaiter := cltest.NewAwaiter() @@ -251,10 +253,11 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { }) log.On("String").Return("") - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) logAwaiter.AwaitOrFail(t) lbAwaiter.AwaitOrFail(t) @@ -279,12 +282,13 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log.On("String").Return("") log.On("DecodedLog").Return(&logCancelOracleRequest) lbAwaiter := cltest.NewAwaiter() - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { lbAwaiter.ItHappened() }).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { lbAwaiter.ItHappened() }).Return(nil) - err := uni.service.Start(testutils.Context(t)) + ctx 
:= testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) lbAwaiter.AwaitOrFail(t) @@ -314,7 +318,7 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { }) runLog.On("DecodedLog").Return(&logOracleRequest) runLog.On("String").Return("") - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) cancelLog := log_mocks.NewBroadcast(t) @@ -328,9 +332,10 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { }) cancelLog.On("DecodedLog").Return(&logCancelOracleRequest) cancelLog.On("String").Return("") - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) timeout := 5 * time.Second @@ -346,11 +351,11 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { runCancelledAwaiter.ItHappened() } }).Once().Return(false, nil) - uni.listener.HandleLog(runLog) + uni.listener.HandleLog(ctx, runLog) runBeganAwaiter.AwaitOrFail(t, timeout) - uni.listener.HandleLog(cancelLog) + uni.listener.HandleLog(ctx, cancelLog) runCancelledAwaiter.AwaitOrFail(t, timeout) @@ -384,25 +389,26 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { }) log.On("DecodedLog").Return(&logOracleRequest) log.On("String").Return("") - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) runBeganAwaiter := cltest.NewAwaiter() uni.runner.On("Run", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { runBeganAwaiter.ItHappened() - fn := args.Get(4).(func(pg.Queryer) error) + fn := args.Get(4).(func(sqlutil.DataSource) error) require.NoError(t, fn(nil)) }).Once().Return(false, nil) - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) // check if the job exists under the correct ID - drJob, jErr := uni.jobORM.FindJob(testutils.Context(t), uni.listener.JobID()) + drJob, jErr := uni.jobORM.FindJob(ctx, uni.listener.JobID()) require.NoError(t, jErr) require.Equal(t, drJob.ID, uni.listener.JobID()) require.NotNil(t, drJob.DirectRequestSpec) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) runBeganAwaiter.AwaitOrFail(t, 5*time.Second) @@ -433,14 +439,15 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log.On("DecodedLog").Return(&logOracleRequest) log.On("String").Return("") markConsumedLogAwaiter := cltest.NewAwaiter() - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { markConsumedLogAwaiter.ItHappened() }).Return(nil) - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) markConsumedLogAwaiter.AwaitOrFail(t, 5*time.Second) @@ -479,27 +486,28 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log.On("DecodedLog").Return(&logOracleRequest) 
log.On("String").Return("") markConsumedLogAwaiter := cltest.NewAwaiter() - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { markConsumedLogAwaiter.ItHappened() }).Return(nil) runBeganAwaiter := cltest.NewAwaiter() uni.runner.On("Run", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { runBeganAwaiter.ItHappened() - fn := args.Get(4).(func(pg.Queryer) error) + fn := args.Get(4).(func(sqlutil.DataSource) error) require.NoError(t, fn(nil)) }).Once().Return(false, nil) - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) // check if the job exists under the correct ID - drJob, jErr := uni.jobORM.FindJob(testutils.Context(t), uni.listener.JobID()) + drJob, jErr := uni.jobORM.FindJob(ctx, uni.listener.JobID()) require.NoError(t, jErr) require.Equal(t, drJob.ID, uni.listener.JobID()) require.NotNil(t, drJob.DirectRequestSpec) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) runBeganAwaiter.AwaitOrFail(t, 5*time.Second) @@ -534,14 +542,15 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log.On("DecodedLog").Return(&logOracleRequest) log.On("String").Return("") markConsumedLogAwaiter := cltest.NewAwaiter() - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { markConsumedLogAwaiter.ItHappened() }).Return(nil) - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) markConsumedLogAwaiter.AwaitOrFail(t, 5*time.Second) diff --git a/core/services/feeds/orm_test.go b/core/services/feeds/orm_test.go index 23f40b9d55c..51a85a33a46 100644 --- a/core/services/feeds/orm_test.go +++ b/core/services/feeds/orm_test.go @@ -1652,7 +1652,7 @@ func createJob(t *testing.T, db *sqlx.DB, externalJobID uuid.UUID) *job.Job { config = configtest.NewGeneralConfig(t, nil) keyStore = cltest.NewKeyStore(t, db, config.Database()) lggr = logger.TestLogger(t) - pipelineORM = pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM = pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgeORM = bridges.NewORM(db) relayExtenders = evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) ) diff --git a/core/services/feeds/service.go b/core/services/feeds/service.go index aa6ccdb39d7..27d324d2342 100644 --- a/core/services/feeds/service.go +++ b/core/services/feeds/service.go @@ -1191,7 +1191,7 @@ func (s *service) newChainConfigMsg(cfg ChainConfig) (*pb.ChainConfig, error) { }, nil } -// newFMConfigMsg generates a FMConfig protobuf message. Flux Monitor does not +// newFluxMonitorConfigMsg generates a FMConfig protobuf message. Flux Monitor does not // have any configuration but this is here for consistency. 
func (*service) newFluxMonitorConfigMsg(cfg FluxMonitorConfig) *pb.FluxMonitorConfig { return &pb.FluxMonitorConfig{Enabled: cfg.Enabled} diff --git a/core/services/fluxmonitorv2/delegate.go b/core/services/fluxmonitorv2/delegate.go index 1e2eba8d000..ddb255800b1 100644 --- a/core/services/fluxmonitorv2/delegate.go +++ b/core/services/fluxmonitorv2/delegate.go @@ -13,7 +13,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -56,10 +55,10 @@ func (d *Delegate) JobType() job.Type { return job.FluxMonitor } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec returns the flux monitor service for the job spec func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services []job.ServiceCtx, err error) { @@ -80,7 +79,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] fm, err := NewFromJobSpec( jb, d.db, - NewORM(d.db, d.lggr, chain.Config().Database(), chain.TxManager(), strategy, checker), + NewORM(d.db, d.lggr, chain.TxManager(), strategy, checker), d.jobORM, d.pipelineORM, NewKeyStore(d.ethKeyStore), @@ -89,10 +88,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] d.pipelineRunner, chain.Config().EVM(), chain.Config().EVM().GasEstimator(), - chain.Config().EVM().Transactions(), - chain.Config().FluxMonitor(), chain.Config().JobPipeline(), - chain.Config().Database(), d.lggr, ) if err != nil { diff --git a/core/services/fluxmonitorv2/flux_monitor.go b/core/services/fluxmonitorv2/flux_monitor.go index 73034faa3ce..5eebb319030 100644 --- a/core/services/fluxmonitorv2/flux_monitor.go +++ b/core/services/fluxmonitorv2/flux_monitor.go @@ -16,6 +16,7 @@ import ( "github.com/jmoiron/sqlx" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/bridges" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" @@ -27,7 +28,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/recovery" "github.com/smartcontractkit/chainlink/v2/core/services/fluxmonitorv2/promfm" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -64,7 +64,7 @@ type FluxMonitor struct { jobSpec job.Job spec pipeline.Spec runner pipeline.Runner - q pg.Q + ds sqlutil.DataSource orm ORM jobORM job.ORM pipelineORM pipeline.ORM @@ -93,7 +93,7 @@ func NewFluxMonitor( pipelineRunner pipeline.Runner, jobSpec job.Job, spec pipeline.Spec, - q pg.Q, + ds sqlutil.DataSource, orm ORM, jobORM job.ORM, pipelineORM pipeline.ORM, @@ -111,7 +111,7 @@ func NewFluxMonitor( chainID *big.Int, ) (*FluxMonitor, error) 
{ fm := &FluxMonitor{ - q: q, + ds: ds, runner: pipelineRunner, jobSpec: jobSpec, spec: spec, @@ -159,10 +159,7 @@ func NewFromJobSpec( pipelineRunner pipeline.Runner, cfg Config, fcfg EvmFeeConfig, - ecfg EvmTransactionsConfig, - fmcfg FluxMonitorConfig, jcfg JobPipelineConfig, - dbCfg pg.QConfig, lggr logger.Logger, ) (*FluxMonitor, error) { fmSpec := jobSpec.FluxMonitorSpec @@ -253,7 +250,7 @@ func NewFromJobSpec( pipelineRunner, jobSpec, *jobSpec.PipelineSpec, - pg.NewQ(db, lggr, dbCfg), + db, orm, jobORM, pipelineORM, @@ -325,7 +322,7 @@ func (fm *FluxMonitor) Close() error { func (fm *FluxMonitor) JobID() int32 { return fm.spec.JobID } // HandleLog processes the contract logs -func (fm *FluxMonitor) HandleLog(broadcast log.Broadcast) { +func (fm *FluxMonitor) HandleLog(ctx context.Context, broadcast log.Broadcast) { log := broadcast.DecodedLog() if log == nil || reflect.ValueOf(log).IsNil() { fm.logger.Panic("HandleLog: failed to handle log of type nil") @@ -509,15 +506,16 @@ func (fm *FluxMonitor) SetOracleAddress() error { } func (fm *FluxMonitor) processLogs() { - for !fm.backlog.Empty() { + ctx, cancel := fm.chStop.NewCtx() + defer cancel() + + for ctx.Err() == nil && !fm.backlog.Empty() { broadcast := fm.backlog.Take() - fm.processBroadcast(broadcast) + fm.processBroadcast(ctx, broadcast) } } -func (fm *FluxMonitor) processBroadcast(broadcast log.Broadcast) { - ctx, cancel := fm.chStop.NewCtx() - defer cancel() +func (fm *FluxMonitor) processBroadcast(ctx context.Context, broadcast log.Broadcast) { // If the log is a duplicate of one we've seen before, ignore it (this // happens because of the LogBroadcaster's backfilling behavior). consumed, err := fm.logBroadcaster.WasAlreadyConsumed(ctx, broadcast) @@ -553,7 +551,7 @@ func (fm *FluxMonitor) processBroadcast(broadcast log.Broadcast) { } func (fm *FluxMonitor) markLogAsConsumed(ctx context.Context, broadcast log.Broadcast, decodedLog interface{}, started time.Time) { - if err := fm.logBroadcaster.MarkConsumed(ctx, broadcast); err != nil { + if err := fm.logBroadcaster.MarkConsumed(ctx, nil, broadcast); err != nil { fm.logger.Errorw("Failed to mark log as consumed", "err", err, "logType", fmt.Sprintf("%T", decodedLog), "log", broadcast.String(), "elapsed", time.Since(started)) } @@ -608,7 +606,7 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr var markConsumed = true defer func() { if markConsumed { - if err := fm.logBroadcaster.MarkConsumed(ctx, lb); err != nil { + if err := fm.logBroadcaster.MarkConsumed(ctx, nil, lb); err != nil { fm.logger.Errorw("Failed to mark log consumed", "err", err, "log", lb.String()) } } @@ -665,13 +663,13 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr // We always want to reset the idle timer upon receiving a NewRound log, so we do it before any `return` statements. 
fm.pollManager.ResetIdleTimer(log.StartedAt.Uint64()) - mostRecentRoundID, err := fm.orm.MostRecentFluxMonitorRoundID(fm.contractAddress) + mostRecentRoundID, err := fm.orm.MostRecentFluxMonitorRoundID(ctx, fm.contractAddress) if err != nil && !errors.Is(err, sql.ErrNoRows) { newRoundLogger.Errorf("error fetching Flux Monitor most recent round ID from DB: %v", err) return } - roundStats, jobRunStatus, err := fm.statsAndStatusForRound(logRoundID, 1) + roundStats, jobRunStatus, err := fm.statsAndStatusForRound(ctx, logRoundID, 1) if err != nil { newRoundLogger.Errorf("error determining round stats / run status for round: %v", err) return @@ -680,14 +678,14 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr if logRoundID < mostRecentRoundID && roundStats.NumNewRoundLogs > 0 { newRoundLogger.Debugf("Received an older round log (and number of previously received NewRound logs is: %v) - "+ "a possible reorg, hence deleting round ids from %v to %v", roundStats.NumNewRoundLogs, logRoundID, mostRecentRoundID) - err = fm.orm.DeleteFluxMonitorRoundsBackThrough(fm.contractAddress, logRoundID) + err = fm.orm.DeleteFluxMonitorRoundsBackThrough(ctx, fm.contractAddress, logRoundID) if err != nil { newRoundLogger.Errorf("error deleting reorged Flux Monitor rounds from DB: %v", err) return } // as all newer stats were deleted, at this point a new round stats entry will be created - roundStats, err = fm.orm.FindOrCreateFluxMonitorRoundStats(fm.contractAddress, logRoundID, 1) + roundStats, err = fm.orm.FindOrCreateFluxMonitorRoundStats(ctx, fm.contractAddress, logRoundID, 1) if err != nil { newRoundLogger.Errorf("error determining subsequent round stats for round: %v", err) return @@ -771,7 +769,7 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr return } - if !fm.isValidSubmission(newRoundLogger, answer, started) { + if !fm.isValidSubmission(ctx, newRoundLogger, answer, started) { return } @@ -779,14 +777,14 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr newRoundLogger.Error("roundState.PaymentAmount shouldn't be nil") } - err = fm.q.Transaction(func(tx pg.Queryer) error { - if err2 := fm.runner.InsertFinishedRun(run, false, pg.WithQueryer(tx)); err2 != nil { + err = fm.Transact(ctx, func(tx sqlutil.DataSource) error { + if err2 := fm.runner.InsertFinishedRun(ctx, tx, run, false); err2 != nil { return err2 } if err2 := fm.queueTransactionForTxm(ctx, tx, run.ID, answer, roundState.RoundId, &log); err2 != nil { return err2 } - return fm.logBroadcaster.MarkConsumed(ctx, lb) + return fm.logBroadcaster.MarkConsumed(ctx, tx, lb) }) // Either the tx failed and we want to reprocess the log, or it succeeded and already marked it consumed markConsumed = false @@ -796,6 +794,10 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr } } +func (fm *FluxMonitor) Transact(ctx context.Context, fn func(sqlutil.DataSource) error) error { + return sqlutil.TransactDataSource(ctx, fm.ds, nil, fn) +} + var ( // ErrNotEligible defines when the round is not eligible for submission ErrNotEligible = errors.New("not eligible to submit") @@ -832,7 +834,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker var markConsumed = true defer func() { if markConsumed && broadcast != nil { - if err := fm.logBroadcaster.MarkConsumed(ctx, broadcast); err != nil { + if err := fm.logBroadcaster.MarkConsumed(ctx, nil, broadcast); err != nil { l.Errorw("Failed to mark log consumed", "err", 
err, "log", broadcast.String()) } } @@ -863,8 +865,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker roundState, err := fm.roundState(0) if err != nil { l.Errorw("unable to determine eligibility to submit from FluxAggregator contract", "err", err) - fm.jobORM.TryRecordError( - fm.spec.JobID, + fm.jobORM.TryRecordError(fm.spec.JobID, "Unable to call roundState method on provided contract. Check contract address.", ) @@ -884,8 +885,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker roundStateNew, err2 := fm.roundState(roundState.RoundId) if err2 != nil { l.Errorw("unable to determine eligibility to submit from FluxAggregator contract", "err", err2) - fm.jobORM.TryRecordError( - fm.spec.JobID, + fm.jobORM.TryRecordError(fm.spec.JobID, "Unable to call roundState method on provided contract. Check contract address.", ) @@ -909,7 +909,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker } }() - roundStats, jobRunStatus, err := fm.statsAndStatusForRound(roundState.RoundId, 0) + roundStats, jobRunStatus, err := fm.statsAndStatusForRound(ctx, roundState.RoundId, 0) if err != nil { l.Errorw("error determining round stats / run status for round", "err", err) @@ -977,7 +977,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker return } - if !fm.isValidSubmission(l, answer, started) { + if !fm.isValidSubmission(ctx, l, answer, started) { return } @@ -1005,8 +1005,8 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker l.Error("roundState.PaymentAmount shouldn't be nil") } - err = fm.q.Transaction(func(tx pg.Queryer) error { - if err2 := fm.runner.InsertFinishedRun(run, true, pg.WithQueryer(tx)); err2 != nil { + err = fm.Transact(ctx, func(tx sqlutil.DataSource) error { + if err2 := fm.runner.InsertFinishedRun(ctx, tx, run, true); err2 != nil { return err2 } if err2 := fm.queueTransactionForTxm(ctx, tx, run.ID, answer, roundState.RoundId, nil); err2 != nil { @@ -1014,7 +1014,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker } if broadcast != nil { // In the case of a flag lowered, the pollEligible call is triggered by a log. - return fm.logBroadcaster.MarkConsumed(ctx, broadcast) + return fm.logBroadcaster.MarkConsumed(ctx, tx, broadcast) } return nil }) @@ -1031,7 +1031,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker // If the answer is outside the allowable range, log an error and don't submit. // to avoid an onchain reversion. 
-func (fm *FluxMonitor) isValidSubmission(l logger.Logger, answer decimal.Decimal, started time.Time) bool { +func (fm *FluxMonitor) isValidSubmission(ctx context.Context, l logger.Logger, answer decimal.Decimal, started time.Time) bool { if fm.submissionChecker.IsValid(answer) { return true } @@ -1085,7 +1085,7 @@ func (fm *FluxMonitor) initialRoundState() flux_aggregator_wrapper.OracleRoundSt return latestRoundState } -func (fm *FluxMonitor) queueTransactionForTxm(ctx context.Context, tx pg.Queryer, runID int64, answer decimal.Decimal, roundID uint32, log *flux_aggregator_wrapper.FluxAggregatorNewRound) error { +func (fm *FluxMonitor) queueTransactionForTxm(ctx context.Context, tx sqlutil.DataSource, runID int64, answer decimal.Decimal, roundID uint32, log *flux_aggregator_wrapper.FluxAggregatorNewRound) error { // Use pipeline run ID to generate globally unique key that can correlate this run to a Tx idempotencyKey := fmt.Sprintf("fluxmonitor-%d", runID) // Submit the Eth Tx @@ -1105,12 +1105,12 @@ func (fm *FluxMonitor) queueTransactionForTxm(ctx context.Context, tx pg.Queryer numLogs = 1 } // Update the flux monitor round stats - err = fm.orm.UpdateFluxMonitorRoundStats( + err = fm.orm.WithDataSource(tx).UpdateFluxMonitorRoundStats( + ctx, fm.contractAddress, roundID, runID, numLogs, - pg.WithQueryer(tx), ) if err != nil { fm.logger.Errorw( @@ -1124,8 +1124,8 @@ func (fm *FluxMonitor) queueTransactionForTxm(ctx context.Context, tx pg.Queryer return nil } -func (fm *FluxMonitor) statsAndStatusForRound(roundID uint32, newRoundLogs uint) (FluxMonitorRoundStatsV2, pipeline.RunStatus, error) { - roundStats, err := fm.orm.FindOrCreateFluxMonitorRoundStats(fm.contractAddress, roundID, newRoundLogs) +func (fm *FluxMonitor) statsAndStatusForRound(ctx context.Context, roundID uint32, newRoundLogs uint) (FluxMonitorRoundStatsV2, pipeline.RunStatus, error) { + roundStats, err := fm.orm.FindOrCreateFluxMonitorRoundStats(ctx, fm.contractAddress, roundID, newRoundLogs) if err != nil { return FluxMonitorRoundStatsV2{}, pipeline.RunStatusUnknown, err } @@ -1133,7 +1133,7 @@ func (fm *FluxMonitor) statsAndStatusForRound(roundID uint32, newRoundLogs uint) // JobRun will not exist if this is the first time responding to this round var run pipeline.Run if roundStats.PipelineRunID.Valid { - run, err = fm.pipelineORM.FindRun(roundStats.PipelineRunID.Int64) + run, err = fm.pipelineORM.FindRun(ctx, roundStats.PipelineRunID.Int64) if err != nil { return FluxMonitorRoundStatsV2{}, pipeline.RunStatusUnknown, err } diff --git a/core/services/fluxmonitorv2/flux_monitor_test.go b/core/services/fluxmonitorv2/flux_monitor_test.go index e4db716bbbb..042ddb99afb 100644 --- a/core/services/fluxmonitorv2/flux_monitor_test.go +++ b/core/services/fluxmonitorv2/flux_monitor_test.go @@ -22,6 +22,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" logmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log/mocks" @@ -31,7 +32,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/cltest/heavyweight" "github.com/smartcontractkit/chainlink/v2/core/internal/mocks" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" 
"github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" corenull "github.com/smartcontractkit/chainlink/v2/core/null" @@ -40,7 +40,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/job" jobmocks "github.com/smartcontractkit/chainlink/v2/core/services/job/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ethkey" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" pipelinemocks "github.com/smartcontractkit/chainlink/v2/core/services/pipeline/mocks" ) @@ -53,8 +52,8 @@ var ( type answerSet struct{ latestAnswer, polledAnswer int64 } -func newORM(t *testing.T, db *sqlx.DB, cfg pg.QConfig, txm txmgr.TxManager) fluxmonitorv2.ORM { - return fluxmonitorv2.NewORM(db, logger.TestLogger(t), cfg, txm, txmgrcommon.NewSendEveryStrategy(), txmgr.TransmitCheckerSpec{}) +func newORM(t *testing.T, db *sqlx.DB, txm txmgr.TxManager) fluxmonitorv2.ORM { + return fluxmonitorv2.NewORM(db, logger.TestLogger(t), txm, txmgrcommon.NewSendEveryStrategy(), txmgr.TransmitCheckerSpec{}) } var ( @@ -149,7 +148,7 @@ type setupOptions struct { // setup sets up a Flux Monitor for testing, allowing the test to provide // functional options to configure the setup -func setup(t *testing.T, db *sqlx.DB, optionFns ...func(*setupOptions)) (*fluxmonitorv2.FluxMonitor, *testMocks) { +func setup(t *testing.T, ds sqlutil.DataSource, optionFns ...func(*setupOptions)) (*fluxmonitorv2.FluxMonitor, *testMocks) { t.Helper() testutils.SkipShort(t, "long test") @@ -190,7 +189,7 @@ func setup(t *testing.T, db *sqlx.DB, optionFns ...func(*setupOptions)) (*fluxmo tm.pipelineRunner, job.Job{}, pipelineSpec, - pg.NewQ(db, lggr, pgtest.NewQConfig(true)), + ds, options.orm, tm.jobORM, tm.pipelineORM, @@ -386,7 +385,7 @@ func TestFluxMonitor_PollIfEligible(t *testing.T) { } tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(reportableRoundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(reportableRoundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: reportableRoundID, @@ -395,12 +394,12 @@ func TestFluxMonitor_PollIfEligible(t *testing.T) { }, nil) tm.pipelineORM. - On("FindRun", run.ID). + On("FindRun", mock.Anything, run.ID). Return(run, nil) } else { if tc.connected { tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(reportableRoundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(reportableRoundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: reportableRoundID, @@ -469,7 +468,7 @@ func TestFluxMonitor_PollIfEligible(t *testing.T) { tm.pipelineRunner.On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(2).(*pipeline.Run).ID = 1 }). Once() tm.contractSubmitter. @@ -479,13 +478,14 @@ func TestFluxMonitor_PollIfEligible(t *testing.T) { tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(reportableRoundID), int64(1), mock.Anything, - mock.Anything, ). 
Return(nil) + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) } oracles := []common.Address{nodeAddr, testutils.NewAddress()} @@ -560,6 +560,7 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { logsAwaiter := cltest.NewAwaiter() tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) tm.fluxAggregator.On("Address").Return(common.Address{}) tm.fluxAggregator.On("LatestRoundData", nilOpts).Return(freshContractRoundDataResponse()).Maybe() @@ -573,19 +574,18 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { tm.fluxAggregator.On("OracleRoundState", nilOpts, nodeAddr, uint32(3)).Return(makeRoundStateForRoundID(3), nil).Once() tm.fluxAggregator.On("OracleRoundState", nilOpts, nodeAddr, uint32(4)).Return(makeRoundStateForRoundID(4), nil).Once() tm.fluxAggregator.On("GetOracles", nilOpts).Return(oracles, nil) - // tm.fluxAggregator.On("Address").Return(contractAddress, nil) tm.logBroadcaster.On("Register", fm, mock.Anything).Return(func() {}) tm.logBroadcaster.On("IsConnected").Return(true).Maybe() - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(1), nil) - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(3), nil) - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(4), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(1), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(3), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(4), nil) // Round 1 run := &pipeline.Run{ID: 1} tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(1), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 1, @@ -605,7 +605,7 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(2).(*pipeline.Run).ID = 1 }).Once() tm.contractSubmitter. On("Submit", mock.Anything, big.NewInt(1), big.NewInt(fetchedValue), buildIdempotencyKey(run.ID)). @@ -614,18 +614,18 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(1), mock.AnythingOfType("int64"), //int64(1), mock.Anything, - mock.Anything, ). Return(nil).Once() // Round 3 run = &pipeline.Run{ID: 2} tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(3), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(3), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 3, @@ -645,7 +645,7 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 2 + args.Get(2).(*pipeline.Run).ID = 2 }).Once() tm.contractSubmitter. On("Submit", mock.Anything, big.NewInt(3), big.NewInt(fetchedValue), buildIdempotencyKey(run.ID)). 
@@ -653,18 +653,18 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { Once() tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(3), mock.AnythingOfType("int64"), //int64(2), mock.Anything, - mock.Anything, ). Return(nil).Once() // Round 4 run = &pipeline.Run{ID: 3} tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(4), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(4), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 3, @@ -684,7 +684,7 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 3 + args.Get(2).(*pipeline.Run).ID = 3 }).Once() tm.contractSubmitter. On("Submit", mock.Anything, big.NewInt(4), big.NewInt(fetchedValue), buildIdempotencyKey(run.ID)). @@ -692,11 +692,11 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { Once() tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(4), mock.AnythingOfType("int64"), //int64(3), mock.Anything, - mock.Anything, ). Return(nil). Once(). @@ -711,17 +711,17 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { logBroadcast.On("DecodedLog").Return(&flux_aggregator_wrapper.FluxAggregatorNewRound{RoundId: big.NewInt(int64(i)), StartedAt: big.NewInt(0)}) logBroadcast.On("String").Maybe().Return("") tm.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) logBroadcasts = append(logBroadcasts, logBroadcast) } - - fm.HandleLog(logBroadcasts[0]) // Get the checker to start processing a log so we can freeze it + ctx := testutils.Context(t) + fm.HandleLog(ctx, logBroadcasts[0]) // Get the checker to start processing a log so we can freeze it readyToFillQueue.AwaitOrFail(t) - fm.HandleLog(logBroadcasts[1]) // This log is evicted from the priority queue - fm.HandleLog(logBroadcasts[2]) - fm.HandleLog(logBroadcasts[3]) + fm.HandleLog(ctx, logBroadcasts[1]) // This log is evicted from the priority queue + fm.HandleLog(ctx, logBroadcasts[2]) + fm.HandleLog(ctx, logBroadcasts[3]) logsAwaiter.ItHappened() readyToAssert.AwaitOrFail(t) @@ -749,7 +749,7 @@ func TestFluxMonitor_TriggerIdleTimeThreshold(t *testing.T) { t.Parallel() var ( - orm = newORM(t, db, pgtest.NewQConfig(true), nil) + orm = newORM(t, db, nil) ) fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(tc.idleTimerDisabled), setIdleTimerPeriod(tc.idleDuration), withORM(orm)) @@ -795,8 +795,8 @@ func TestFluxMonitor_TriggerIdleTimeThreshold(t *testing.T) { tm.logBroadcast.On("DecodedLog").Return(&decodedLog) tm.logBroadcast.On("String").Maybe().Return("") tm.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) - fm.HandleLog(tm.logBroadcast) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) + fm.HandleLog(testutils.Context(t), tm.logBroadcast) g.Eventually(chBlock).Should(gomega.BeClosed()) @@ -856,7 +856,7 @@ func TestFluxMonitor_HibernationTickerFiresMultipleTimes(t *testing.T) { pollOccured <- struct{}{} }) tm.orm. 
- On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(1), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 1, @@ -873,7 +873,7 @@ func TestFluxMonitor_HibernationTickerFiresMultipleTimes(t *testing.T) { // Finds an existing run created by the initial poll tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(1), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -881,7 +881,7 @@ func TestFluxMonitor_HibernationTickerFiresMultipleTimes(t *testing.T) { NumSubmissions: 1, }, nil).Once() finishedAt := time.Now() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{ + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{ FinishedAt: null.TimeFrom(finishedAt), }, nil) @@ -893,7 +893,7 @@ func TestFluxMonitor_HibernationTickerFiresMultipleTimes(t *testing.T) { pollOccured <- struct{}{} }) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(2), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(2), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 2, @@ -950,7 +950,7 @@ func TestFluxMonitor_HibernationIsEnteredAndRetryTickerStopped(t *testing.T) { pollOccured <- struct{}{} }) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, roundOne, mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, roundOne, mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 1, @@ -970,7 +970,7 @@ func TestFluxMonitor_HibernationIsEnteredAndRetryTickerStopped(t *testing.T) { // Finds an error run, so that retry ticker will be kicked off tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, roundOne, mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, roundOne, mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -978,7 +978,7 @@ func TestFluxMonitor_HibernationIsEnteredAndRetryTickerStopped(t *testing.T) { NumSubmissions: 1, }, nil).Once() finishedAt := time.Now() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{ + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{ FinishedAt: null.TimeFrom(finishedAt), FatalErrors: []null.String{null.StringFrom("an error to start retry ticker")}, }, nil) @@ -997,7 +997,7 @@ func TestFluxMonitor_HibernationIsEnteredAndRetryTickerStopped(t *testing.T) { roundState2 := flux_aggregator_wrapper.OracleRoundState{RoundId: 2, EligibleToSubmit: false, LatestSubmission: answerBigInt, StartedAt: 0} tm.fluxAggregator.On("OracleRoundState", nilOpts, nodeAddr, roundZero).Return(roundState2, nil).Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, roundTwo, mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, roundTwo, mock.Anything). 
Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 2, @@ -1054,7 +1054,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { roundState1 := flux_aggregator_wrapper.OracleRoundState{RoundId: 1, EligibleToSubmit: false, LatestSubmission: answerBigInt, StartedAt: now()} tm.fluxAggregator.On("OracleRoundState", nilOpts, nodeAddr, uint32(0)).Return(roundState1, nil).Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(1), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 1, @@ -1072,7 +1072,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { }) // Finds an existing run created by the initial poll tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(1), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -1080,7 +1080,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { NumSubmissions: 1, }, nil).Once() finishedAt := time.Now() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{ + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{ FinishedAt: null.TimeFrom(finishedAt), }, nil) @@ -1092,7 +1092,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { idleDurationOccured <- struct{}{} }) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(2), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(2), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 2, @@ -1107,7 +1107,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { idleDurationOccured <- struct{}{} }) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(3), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(3), mock.Anything). 
Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 3, @@ -1118,7 +1118,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { tm.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil).Once() tm.logBroadcast.On("DecodedLog").Return(&flux_aggregator_wrapper.FluxAggregatorAnswerUpdated{}) tm.logBroadcast.On("String").Maybe().Return("") - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil).Once() + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() fm.ExportedBacklog().Add(fluxmonitorv2.PriorityNewRoundLog, tm.logBroadcast) fm.ExportedProcessLogs() @@ -1133,7 +1133,7 @@ func TestFluxMonitor_RoundTimeoutCausesPoll_timesOutAtZero(t *testing.T) { var ( oracles = []common.Address{nodeAddr, testutils.NewAddress()} - orm = newORM(t, db, pgtest.NewQConfig(true), nil) + orm = newORM(t, db, nil) ) fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), withORM(orm)) @@ -1193,10 +1193,7 @@ func TestFluxMonitor_UsesPreviousRoundStateOnStartup_RoundTimeout(t *testing.T) t.Run(test.name, func(t *testing.T) { t.Parallel() - cfg := configtest.NewTestGeneralConfig(t) - var ( - orm = newORM(t, db, cfg.Database(), nil) - ) + orm := newORM(t, db, nil) fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), withORM(orm)) @@ -1260,11 +1257,7 @@ func TestFluxMonitor_UsesPreviousRoundStateOnStartup_IdleTimer(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - cfg := configtest.NewTestGeneralConfig(t) - - var ( - orm = newORM(t, db, cfg.Database(), nil) - ) + orm := newORM(t, db, nil) fm, tm := setup(t, db, @@ -1323,11 +1316,7 @@ func TestFluxMonitor_RoundTimeoutCausesPoll_timesOutNotZero(t *testing.T) { g := gomega.NewWithT(t) db, nodeAddr := setupStoreWithKey(t) oracles := []common.Address{nodeAddr, testutils.NewAddress()} - cfg := configtest.NewTestGeneralConfig(t) - - var ( - orm = newORM(t, db, cfg.Database(), nil) - ) + orm := newORM(t, db, nil) fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), withORM(orm)) @@ -1381,14 +1370,14 @@ func TestFluxMonitor_RoundTimeoutCausesPoll_timesOutNotZero(t *testing.T) { servicetest.Run(t, fm) tm.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) tm.logBroadcast.On("DecodedLog").Return(&flux_aggregator_wrapper.FluxAggregatorNewRound{ RoundId: big.NewInt(0), StartedAt: big.NewInt(time.Now().UTC().Unix()), }) tm.logBroadcast.On("String").Maybe().Return("") // To mark it consumed, we need to be eligible to submit. 
- fm.HandleLog(tm.logBroadcast) + fm.HandleLog(testutils.Context(t), tm.logBroadcast) g.Eventually(chRoundState1).Should(gomega.BeClosed()) g.Eventually(chRoundState2).Should(gomega.BeClosed()) @@ -1409,7 +1398,7 @@ func TestFluxMonitor_ConsumeLogBroadcast(t *testing.T) { tm.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil).Once() tm.logBroadcast.On("DecodedLog").Return(&flux_aggregator_wrapper.FluxAggregatorAnswerUpdated{}) tm.logBroadcast.On("String").Maybe().Return("") - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil).Once() + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() fm.ExportedBacklog().Add(fluxmonitorv2.PriorityNewRoundLog, tm.logBroadcast) fm.ExportedProcessLogs() @@ -1468,11 +1457,12 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("IsConnected").Return(true).Maybe() + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) // Mocks initiated by the New Round log - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(roundID), nil).Once() + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(roundID), nil).Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(roundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(roundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: roundID, @@ -1492,17 +1482,17 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(2).(*pipeline.Run).ID = 1 }) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil).Once() + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() tm.contractSubmitter.On("Submit", mock.Anything, big.NewInt(roundID), big.NewInt(answer), buildIdempotencyKey(run.ID)).Return(nil).Once() tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(roundID), int64(1), uint(1), - mock.Anything, ). Return(nil) @@ -1545,7 +1535,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(roundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(roundID), mock.Anything). 
Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -1554,7 +1544,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { }, nil).Once() now := time.Now() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{ + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{ FinishedAt: null.TimeFrom(now), }, nil) @@ -1583,6 +1573,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { run := &pipeline.Run{ID: 1} tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("IsConnected").Return(true).Maybe() + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) // First, force the node to try to poll, which should result in a submission tm.fluxAggregator.On("LatestRoundData", nilOpts).Return(flux_aggregator_wrapper.LatestRoundData{ @@ -1600,7 +1591,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { }, nil). Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(roundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(roundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: roundID, @@ -1620,16 +1611,16 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(2).(*pipeline.Run).ID = 1 }) tm.contractSubmitter.On("Submit", mock.Anything, big.NewInt(roundID), big.NewInt(answer), buildIdempotencyKey(run.ID)).Return(nil).Once() tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(roundID), int64(1), uint(0), - mock.Anything, ). Return(nil). Once() @@ -1639,18 +1630,18 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { fm.ExportedPollIfEligible(0, 0) // Now fire off the NewRound log and ensure it does not respond this time - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(roundID), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(roundID), nil) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(roundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(roundID), mock.Anything). 
Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, RoundID: roundID, NumSubmissions: 1, }, nil).Once() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{}, nil) + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{}, nil) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) fm.ExportedRespondToNewRoundLog(&flux_aggregator_wrapper.FluxAggregatorNewRound{ RoundId: big.NewInt(roundID), StartedAt: big.NewInt(0), @@ -1679,6 +1670,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { run := &pipeline.Run{ID: 1} tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("IsConnected").Return(true).Maybe() + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) // First, force the node to try to poll, which should result in a submission tm.fluxAggregator.On("LatestRoundData", nilOpts).Return(flux_aggregator_wrapper.LatestRoundData{ @@ -1696,7 +1688,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { }, nil). Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(roundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(roundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: roundID, @@ -1716,16 +1708,16 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(2).(*pipeline.Run).ID = 1 }) tm.contractSubmitter.On("Submit", mock.Anything, big.NewInt(roundID), big.NewInt(answer), buildIdempotencyKey(run.ID)).Return(nil).Once() tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(roundID), int64(1), uint(0), - mock.Anything, ). Return(nil). Once() @@ -1735,27 +1727,27 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { fm.ExportedPollIfEligible(0, 0) // Now fire off the NewRound log and ensure it does not respond this time - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(roundID), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(roundID), nil) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(olderRoundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(olderRoundID), mock.Anything). 
Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, RoundID: olderRoundID, NumSubmissions: 1, }, nil).Once() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{}, nil) + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{}, nil) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) fm.ExportedRespondToNewRoundLog(&flux_aggregator_wrapper.FluxAggregatorNewRound{ RoundId: big.NewInt(olderRoundID), StartedAt: big.NewInt(0), }, log.NewLogBroadcast(types.Log{}, cltest.FixtureChainID, nil)) // Simulate a reorg - fire the same NewRound log again, which should result in a submission this time - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(roundID), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(roundID), nil) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(olderRoundID), uint(1)). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(olderRoundID), uint(1)). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -1763,14 +1755,14 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { NumSubmissions: 1, NumNewRoundLogs: 1, }, nil).Once() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{}, nil) + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{}, nil) // all newer round stats should be deleted - tm.orm.On("DeleteFluxMonitorRoundsBackThrough", contractAddress, uint32(olderRoundID)).Return(nil) + tm.orm.On("DeleteFluxMonitorRoundsBackThrough", mock.Anything, contractAddress, uint32(olderRoundID)).Return(nil) // then we are returning a fresh round stat, with NumSubmissions: 0 tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(olderRoundID), uint(1)). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(olderRoundID), uint(1)). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -1795,16 +1787,16 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(olderRoundID), int64(1), uint(1), - mock.Anything, ). Return(nil). Once() - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) fm.ExportedRespondToNewRoundLog(&flux_aggregator_wrapper.FluxAggregatorNewRound{ RoundId: big.NewInt(olderRoundID), StartedAt: big.NewInt(0), @@ -1824,6 +1816,7 @@ func TestFluxMonitor_DrumbeatTicker(t *testing.T) { fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), enableDrumbeatTicker("@every 3s", 2*time.Second)) tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil) + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) const fetchedAnswer = 100 answerBigInt := big.NewInt(fetchedAnswer) @@ -1853,7 +1846,7 @@ func TestFluxMonitor_DrumbeatTicker(t *testing.T) { Return(roundState, nil). Once() - tm.orm.On("FindOrCreateFluxMonitorRoundStats", contractAddress, roundID, mock.Anything). 
+ tm.orm.On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, roundID, mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{Aggregator: contractAddress, RoundID: roundID}, nil). Once() @@ -1895,7 +1888,7 @@ func TestFluxMonitor_DrumbeatTicker(t *testing.T) { tm.pipelineRunner.On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = runID + args.Get(2).(*pipeline.Run).ID = runID }). Once() tm.contractSubmitter. @@ -1904,7 +1897,7 @@ func TestFluxMonitor_DrumbeatTicker(t *testing.T) { Once() tm.orm. - On("UpdateFluxMonitorRoundStats", contractAddress, roundID, runID, mock.Anything, mock.Anything). + On("UpdateFluxMonitorRoundStats", mock.Anything, contractAddress, roundID, runID, mock.Anything). Return(nil). Once() } diff --git a/core/services/fluxmonitorv2/mocks/orm.go b/core/services/fluxmonitorv2/mocks/orm.go index 287c7ebb5fa..e5173db8264 100644 --- a/core/services/fluxmonitorv2/mocks/orm.go +++ b/core/services/fluxmonitorv2/mocks/orm.go @@ -11,7 +11,7 @@ import ( mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ) // ORM is an autogenerated mock type for the ORM type @@ -19,9 +19,9 @@ type ORM struct { mock.Mock } -// CountFluxMonitorRoundStats provides a mock function with given fields: -func (_m *ORM) CountFluxMonitorRoundStats() (int, error) { - ret := _m.Called() +// CountFluxMonitorRoundStats provides a mock function with given fields: ctx +func (_m *ORM) CountFluxMonitorRoundStats(ctx context.Context) (int, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for CountFluxMonitorRoundStats") @@ -29,17 +29,17 @@ func (_m *ORM) CountFluxMonitorRoundStats() (int, error) { var r0 int var r1 error - if rf, ok := ret.Get(0).(func() (int, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (int, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int) } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -65,17 +65,17 @@ func (_m *ORM) CreateEthTransaction(ctx context.Context, fromAddress common.Addr return r0 } -// DeleteFluxMonitorRoundsBackThrough provides a mock function with given fields: aggregator, roundID -func (_m *ORM) DeleteFluxMonitorRoundsBackThrough(aggregator common.Address, roundID uint32) error { - ret := _m.Called(aggregator, roundID) +// DeleteFluxMonitorRoundsBackThrough provides a mock function with given fields: ctx, aggregator, roundID +func (_m *ORM) DeleteFluxMonitorRoundsBackThrough(ctx context.Context, aggregator common.Address, roundID uint32) error { + ret := _m.Called(ctx, aggregator, roundID) if len(ret) == 0 { panic("no return value specified for DeleteFluxMonitorRoundsBackThrough") } var r0 error - if rf, ok := ret.Get(0).(func(common.Address, uint32) error); ok { - r0 = rf(aggregator, roundID) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, uint32) error); ok { + r0 = rf(ctx, aggregator, roundID) } else { r0 = ret.Error(0) } @@ -83,9 +83,9 @@ func (_m *ORM) DeleteFluxMonitorRoundsBackThrough(aggregator common.Address, rou return r0 } -// FindOrCreateFluxMonitorRoundStats 
provides a mock function with given fields: aggregator, roundID, newRoundLogs -func (_m *ORM) FindOrCreateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, newRoundLogs uint) (fluxmonitorv2.FluxMonitorRoundStatsV2, error) { - ret := _m.Called(aggregator, roundID, newRoundLogs) +// FindOrCreateFluxMonitorRoundStats provides a mock function with given fields: ctx, aggregator, roundID, newRoundLogs +func (_m *ORM) FindOrCreateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, newRoundLogs uint) (fluxmonitorv2.FluxMonitorRoundStatsV2, error) { + ret := _m.Called(ctx, aggregator, roundID, newRoundLogs) if len(ret) == 0 { panic("no return value specified for FindOrCreateFluxMonitorRoundStats") @@ -93,17 +93,17 @@ func (_m *ORM) FindOrCreateFluxMonitorRoundStats(aggregator common.Address, roun var r0 fluxmonitorv2.FluxMonitorRoundStatsV2 var r1 error - if rf, ok := ret.Get(0).(func(common.Address, uint32, uint) (fluxmonitorv2.FluxMonitorRoundStatsV2, error)); ok { - return rf(aggregator, roundID, newRoundLogs) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, uint32, uint) (fluxmonitorv2.FluxMonitorRoundStatsV2, error)); ok { + return rf(ctx, aggregator, roundID, newRoundLogs) } - if rf, ok := ret.Get(0).(func(common.Address, uint32, uint) fluxmonitorv2.FluxMonitorRoundStatsV2); ok { - r0 = rf(aggregator, roundID, newRoundLogs) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, uint32, uint) fluxmonitorv2.FluxMonitorRoundStatsV2); ok { + r0 = rf(ctx, aggregator, roundID, newRoundLogs) } else { r0 = ret.Get(0).(fluxmonitorv2.FluxMonitorRoundStatsV2) } - if rf, ok := ret.Get(1).(func(common.Address, uint32, uint) error); ok { - r1 = rf(aggregator, roundID, newRoundLogs) + if rf, ok := ret.Get(1).(func(context.Context, common.Address, uint32, uint) error); ok { + r1 = rf(ctx, aggregator, roundID, newRoundLogs) } else { r1 = ret.Error(1) } @@ -111,9 +111,9 @@ func (_m *ORM) FindOrCreateFluxMonitorRoundStats(aggregator common.Address, roun return r0, r1 } -// MostRecentFluxMonitorRoundID provides a mock function with given fields: aggregator -func (_m *ORM) MostRecentFluxMonitorRoundID(aggregator common.Address) (uint32, error) { - ret := _m.Called(aggregator) +// MostRecentFluxMonitorRoundID provides a mock function with given fields: ctx, aggregator +func (_m *ORM) MostRecentFluxMonitorRoundID(ctx context.Context, aggregator common.Address) (uint32, error) { + ret := _m.Called(ctx, aggregator) if len(ret) == 0 { panic("no return value specified for MostRecentFluxMonitorRoundID") @@ -121,17 +121,17 @@ func (_m *ORM) MostRecentFluxMonitorRoundID(aggregator common.Address) (uint32, var r0 uint32 var r1 error - if rf, ok := ret.Get(0).(func(common.Address) (uint32, error)); ok { - return rf(aggregator) + if rf, ok := ret.Get(0).(func(context.Context, common.Address) (uint32, error)); ok { + return rf(ctx, aggregator) } - if rf, ok := ret.Get(0).(func(common.Address) uint32); ok { - r0 = rf(aggregator) + if rf, ok := ret.Get(0).(func(context.Context, common.Address) uint32); ok { + r0 = rf(ctx, aggregator) } else { r0 = ret.Get(0).(uint32) } - if rf, ok := ret.Get(1).(func(common.Address) error); ok { - r1 = rf(aggregator) + if rf, ok := ret.Get(1).(func(context.Context, common.Address) error); ok { + r1 = rf(ctx, aggregator) } else { r1 = ret.Error(1) } @@ -139,24 +139,17 @@ func (_m *ORM) MostRecentFluxMonitorRoundID(aggregator common.Address) (uint32, return r0, r1 } -// UpdateFluxMonitorRoundStats provides a mock function 
with given fields: aggregator, roundID, runID, newRoundLogsAddition, qopts -func (_m *ORM) UpdateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, aggregator, roundID, runID, newRoundLogsAddition) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// UpdateFluxMonitorRoundStats provides a mock function with given fields: ctx, aggregator, roundID, runID, newRoundLogsAddition +func (_m *ORM) UpdateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint) error { + ret := _m.Called(ctx, aggregator, roundID, runID, newRoundLogsAddition) if len(ret) == 0 { panic("no return value specified for UpdateFluxMonitorRoundStats") } var r0 error - if rf, ok := ret.Get(0).(func(common.Address, uint32, int64, uint, ...pg.QOpt) error); ok { - r0 = rf(aggregator, roundID, runID, newRoundLogsAddition, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, uint32, int64, uint) error); ok { + r0 = rf(ctx, aggregator, roundID, runID, newRoundLogsAddition) } else { r0 = ret.Error(0) } @@ -164,6 +157,26 @@ func (_m *ORM) UpdateFluxMonitorRoundStats(aggregator common.Address, roundID ui return r0 } +// WithDataSource provides a mock function with given fields: _a0 +func (_m *ORM) WithDataSource(_a0 sqlutil.DataSource) fluxmonitorv2.ORM { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for WithDataSource") + } + + var r0 fluxmonitorv2.ORM + if rf, ok := ret.Get(0).(func(sqlutil.DataSource) fluxmonitorv2.ORM); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(fluxmonitorv2.ORM) + } + } + + return r0 +} + // NewORM creates a new instance of ORM. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. 
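
For tests against the regenerated mock, the net effect of the signature changes is one extra leading argument matcher for the context plus a WithDataSource stub that returns the mock itself. A minimal sketch of the wiring the test diffs above repeat; contractAddress stands in for whatever aggregator address the test already defines:

    // Sketch only: mock expectations for the context-first ORM, mirroring the test changes above.
    ormMock := mocks.NewORM(t)
    // WithDataSource must hand back the same mock so transaction-bound code keeps hitting these expectations.
    ormMock.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(ormMock))
    // Every query expectation now starts with mock.Anything for the context argument.
    ormMock.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(1), nil)
    ormMock.On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), mock.Anything).
    	Return(fluxmonitorv2.FluxMonitorRoundStatsV2{Aggregator: contractAddress, RoundID: 1}, nil)
    // The trailing qopts matcher is gone; the context matcher moved to the front.
    ormMock.On("UpdateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), int64(1), mock.Anything).Return(nil)
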
func NewORM(t interface { diff --git a/core/services/fluxmonitorv2/orm.go b/core/services/fluxmonitorv2/orm.go index 91973387e32..e090b84ed04 100644 --- a/core/services/fluxmonitorv2/orm.go +++ b/core/services/fluxmonitorv2/orm.go @@ -7,12 +7,10 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/common/txmgr/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type transmitter interface { @@ -23,48 +21,49 @@ type transmitter interface { // ORM defines an interface for database commands related to Flux Monitor v2 type ORM interface { - MostRecentFluxMonitorRoundID(aggregator common.Address) (uint32, error) - DeleteFluxMonitorRoundsBackThrough(aggregator common.Address, roundID uint32) error - FindOrCreateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, newRoundLogs uint) (FluxMonitorRoundStatsV2, error) - UpdateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint, qopts ...pg.QOpt) error + MostRecentFluxMonitorRoundID(ctx context.Context, aggregator common.Address) (uint32, error) + DeleteFluxMonitorRoundsBackThrough(ctx context.Context, aggregator common.Address, roundID uint32) error + FindOrCreateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, newRoundLogs uint) (FluxMonitorRoundStatsV2, error) + UpdateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint) error CreateEthTransaction(ctx context.Context, fromAddress, toAddress common.Address, payload []byte, gasLimit uint64, idempotencyKey *string) error - CountFluxMonitorRoundStats() (count int, err error) + CountFluxMonitorRoundStats(ctx context.Context) (count int, err error) + + WithDataSource(sqlutil.DataSource) ORM } type orm struct { - q pg.Q + ds sqlutil.DataSource txm transmitter strategy types.TxStrategy checker txmgr.TransmitCheckerSpec logger logger.Logger } +func (o *orm) WithDataSource(ds sqlutil.DataSource) ORM { return o.withDataSource(ds) } + +func (o *orm) withDataSource(ds sqlutil.DataSource) *orm { + return &orm{ds, o.txm, o.strategy, o.checker, o.logger} +} + // NewORM initializes a new ORM -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, txm transmitter, strategy types.TxStrategy, checker txmgr.TransmitCheckerSpec) ORM { +func NewORM(ds sqlutil.DataSource, lggr logger.Logger, txm transmitter, strategy types.TxStrategy, checker txmgr.TransmitCheckerSpec) ORM { namedLogger := lggr.Named("FluxMonitorORM") - q := pg.NewQ(db, namedLogger, cfg) - return &orm{ - q, - txm, - strategy, - checker, - namedLogger, - } + return &orm{ds, txm, strategy, checker, namedLogger} } // MostRecentFluxMonitorRoundID finds roundID of the most recent round that the // provided oracle address submitted to -func (o *orm) MostRecentFluxMonitorRoundID(aggregator common.Address) (uint32, error) { +func (o *orm) MostRecentFluxMonitorRoundID(ctx context.Context, aggregator common.Address) (uint32, error) { var stats FluxMonitorRoundStatsV2 - err := o.q.Get(&stats, `SELECT * FROM flux_monitor_round_stats_v2 WHERE aggregator = $1 ORDER BY round_id DESC LIMIT 1`, aggregator) + err := o.ds.GetContext(ctx, &stats, `SELECT * FROM flux_monitor_round_stats_v2 WHERE aggregator = 
$1 ORDER BY round_id DESC LIMIT 1`, aggregator) return stats.RoundID, errors.Wrap(err, "MostRecentFluxMonitorRoundID failed") } // DeleteFluxMonitorRoundsBackThrough deletes all the RoundStat records for a // given oracle address starting from the most recent round back through the // given round -func (o *orm) DeleteFluxMonitorRoundsBackThrough(aggregator common.Address, roundID uint32) error { - _, err := o.q.Exec(` +func (o *orm) DeleteFluxMonitorRoundsBackThrough(ctx context.Context, aggregator common.Address, roundID uint32) error { + _, err := o.ds.ExecContext(ctx, ` DELETE FROM flux_monitor_round_stats_v2 WHERE aggregator = $1 AND round_id >= $2 @@ -74,14 +73,14 @@ func (o *orm) DeleteFluxMonitorRoundsBackThrough(aggregator common.Address, roun // FindOrCreateFluxMonitorRoundStats finds the round stats record for a given // oracle on a given round, or creates it if no record exists -func (o *orm) FindOrCreateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, newRoundLogs uint) (stats FluxMonitorRoundStatsV2, err error) { - err = o.q.Transaction(func(tx pg.Queryer) error { - err = tx.Get(&stats, +func (o *orm) FindOrCreateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, newRoundLogs uint) (stats FluxMonitorRoundStatsV2, err error) { + err = sqlutil.Transact(ctx, o.withDataSource, o.ds, nil, func(tx *orm) error { + err = tx.ds.GetContext(ctx, &stats, `INSERT INTO flux_monitor_round_stats_v2 (aggregator, round_id, num_new_round_logs, num_submissions) VALUES ($1, $2, $3, 0) ON CONFLICT (aggregator, round_id) DO NOTHING`, aggregator, roundID, newRoundLogs) if errors.Is(err, sql.ErrNoRows) { - err = tx.Get(&stats, `SELECT * FROM flux_monitor_round_stats_v2 WHERE aggregator=$1 AND round_id=$2`, aggregator, roundID) + err = tx.ds.GetContext(ctx, &stats, `SELECT * FROM flux_monitor_round_stats_v2 WHERE aggregator=$1 AND round_id=$2`, aggregator, roundID) } return err }) @@ -91,9 +90,8 @@ func (o *orm) FindOrCreateFluxMonitorRoundStats(aggregator common.Address, round // UpdateFluxMonitorRoundStats tries to create a RoundStat record for the given oracle // at the given round. If one already exists, it increments the num_submissions column. -func (o *orm) UpdateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) 
- err := q.ExecQ(` +func (o *orm) UpdateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint) error { + _, err := o.ds.ExecContext(ctx, ` INSERT INTO flux_monitor_round_stats_v2 ( aggregator, round_id, pipeline_run_id, num_new_round_logs, num_submissions ) VALUES ( @@ -108,8 +106,8 @@ func (o *orm) UpdateFluxMonitorRoundStats(aggregator common.Address, roundID uin } // CountFluxMonitorRoundStats counts the total number of records -func (o *orm) CountFluxMonitorRoundStats() (count int, err error) { - err = o.q.Get(&count, `SELECT count(*) FROM flux_monitor_round_stats_v2`) +func (o *orm) CountFluxMonitorRoundStats(ctx context.Context) (count int, err error) { + err = o.ds.GetContext(ctx, &count, `SELECT count(*) FROM flux_monitor_round_stats_v2`) return count, errors.Wrap(err, "CountFluxMonitorRoundStats failed") } diff --git a/core/services/fluxmonitorv2/orm_test.go b/core/services/fluxmonitorv2/orm_test.go index 9b31525831b..f6904b9fe97 100644 --- a/core/services/fluxmonitorv2/orm_test.go +++ b/core/services/fluxmonitorv2/orm_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/utils/jsonserializable" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" commontxmmocks "github.com/smartcontractkit/chainlink/v2/common/txmgr/types/mocks" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" @@ -28,61 +29,62 @@ import ( func TestORM_MostRecentFluxMonitorRoundID(t *testing.T) { t.Parallel() + ctx := tests.Context(t) db := pgtest.NewSqlxDB(t) - cfg := pgtest.NewQConfig(true) - orm := newORM(t, db, cfg, nil) + orm := newORM(t, db, nil) address := testutils.NewAddress() // Setup the rounds for round := uint32(0); round < 10; round++ { - _, err := orm.FindOrCreateFluxMonitorRoundStats(address, round, 1) + _, err := orm.FindOrCreateFluxMonitorRoundStats(ctx, address, round, 1) require.NoError(t, err) } - count, err := orm.CountFluxMonitorRoundStats() + count, err := orm.CountFluxMonitorRoundStats(ctx) require.NoError(t, err) require.Equal(t, 10, count) // Ensure round stats are not created again for the same address/roundID - stats, err := orm.FindOrCreateFluxMonitorRoundStats(address, uint32(0), 1) + stats, err := orm.FindOrCreateFluxMonitorRoundStats(ctx, address, uint32(0), 1) require.NoError(t, err) require.Equal(t, uint32(0), stats.RoundID) require.Equal(t, address, stats.Aggregator) require.Equal(t, uint64(1), stats.NumNewRoundLogs) - count, err = orm.CountFluxMonitorRoundStats() + count, err = orm.CountFluxMonitorRoundStats(ctx) require.NoError(t, err) require.Equal(t, 10, count) - roundID, err := orm.MostRecentFluxMonitorRoundID(testutils.NewAddress()) + roundID, err := orm.MostRecentFluxMonitorRoundID(ctx, testutils.NewAddress()) require.Error(t, err) require.Equal(t, uint32(0), roundID) - roundID, err = orm.MostRecentFluxMonitorRoundID(address) + roundID, err = orm.MostRecentFluxMonitorRoundID(ctx, address) require.NoError(t, err) require.Equal(t, uint32(9), roundID) // Deleting rounds against a new address should incur no changes - err = orm.DeleteFluxMonitorRoundsBackThrough(testutils.NewAddress(), 5) + err = orm.DeleteFluxMonitorRoundsBackThrough(ctx, testutils.NewAddress(), 5) require.NoError(t, err) - count, err = orm.CountFluxMonitorRoundStats() + count, err = orm.CountFluxMonitorRoundStats(ctx) require.NoError(t, err) require.Equal(t, 10, count) // Deleting rounds 
against the address - err = orm.DeleteFluxMonitorRoundsBackThrough(address, 5) + err = orm.DeleteFluxMonitorRoundsBackThrough(ctx, address, 5) require.NoError(t, err) - count, err = orm.CountFluxMonitorRoundStats() + count, err = orm.CountFluxMonitorRoundStats(ctx) require.NoError(t, err) require.Equal(t, 5, count) } func TestORM_UpdateFluxMonitorRoundStats(t *testing.T) { t.Parallel() + ctx := tests.Context(t) cfg := configtest.NewGeneralConfig(t, nil) db := pgtest.NewSqlxDB(t) @@ -92,13 +94,13 @@ func TestORM_UpdateFluxMonitorRoundStats(t *testing.T) { // Instantiate a real pipeline ORM because we need to create a pipeline run // for the foreign key constraint of the stats record - pipelineORM := pipeline.NewORM(db, lggr, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(db) // Instantiate a real job ORM because we need to create a job to satisfy // a check in pipeline.CreateRun jobORM := job.NewORM(db, pipelineORM, bridgeORM, keyStore, lggr, cfg.Database()) - orm := newORM(t, db, cfg.Database(), nil) + orm := newORM(t, db, nil) address := testutils.NewAddress() var roundID uint32 = 1 @@ -129,13 +131,13 @@ func TestORM_UpdateFluxMonitorRoundStats(t *testing.T) { }, }, } - err := pipelineORM.InsertFinishedRun(run, true) + err := pipelineORM.InsertFinishedRun(ctx, run, true) require.NoError(t, err) - err = orm.UpdateFluxMonitorRoundStats(address, roundID, run.ID, 0) + err = orm.UpdateFluxMonitorRoundStats(ctx, address, roundID, run.ID, 0) require.NoError(t, err) - stats, err := orm.FindOrCreateFluxMonitorRoundStats(address, roundID, 0) + stats, err := orm.FindOrCreateFluxMonitorRoundStats(ctx, address, roundID, 0) require.NoError(t, err) require.Equal(t, expectedCount, stats.NumSubmissions) require.True(t, stats.PipelineRunID.Valid) @@ -177,7 +179,7 @@ func TestORM_CreateEthTransaction(t *testing.T) { var ( txm = txmmocks.NewMockEvmTxManager(t) - orm = fluxmonitorv2.NewORM(db, logger.TestLogger(t), cfg, txm, strategy, txmgr.TransmitCheckerSpec{}) + orm = fluxmonitorv2.NewORM(db, logger.TestLogger(t), txm, strategy, txmgr.TransmitCheckerSpec{}) _, from = cltest.MustInsertRandomKey(t, ethKeyStore) to = testutils.NewAddress() diff --git a/core/services/gateway/delegate.go b/core/services/gateway/delegate.go index ba34f2894de..8cddc027803 100644 --- a/core/services/gateway/delegate.go +++ b/core/services/gateway/delegate.go @@ -41,10 +41,10 @@ func (d *Delegate) JobType() job.Type { return job.Gateway } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec returns the scheduler to be used for running observer jobs func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services []job.ServiceCtx, err error) { diff --git a/core/services/job/job_orm_test.go b/core/services/job/job_orm_test.go index a6e3622df1b..c60f096c358 100644 --- a/core/services/job/job_orm_test.go +++ b/core/services/job/job_orm_test.go @@ -82,7 +82,7 @@ func TestORM(t *testing.T) { require.NoError(t, 
keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -346,7 +346,7 @@ func TestORM_DeleteJob_DeletesAssociatedRecords(t *testing.T) { require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) korm := keeper.NewORM(db, logger.TestLogger(t)) @@ -444,7 +444,7 @@ func TestORM_CreateJob_VRFV2(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -528,7 +528,7 @@ func TestORM_CreateJob_VRFV2Plus(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -615,7 +615,7 @@ func TestORM_CreateJob_OCRBootstrap(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -641,7 +641,7 @@ func TestORM_CreateJob_EVMChainID_Validation(t *testing.T) { keyStore := cltest.NewKeyStore(t, db, config.Database()) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -736,7 +736,7 @@ func TestORM_CreateJob_OCR_DuplicatedContractAddress(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -805,7 +805,7 @@ func TestORM_CreateJob_OCR2_DuplicatedContractAddress(t *testing.T) { require.NoError(t, keyStore.OCR2().Add(cltest.DefaultOCR2Key)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), 
config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -866,7 +866,7 @@ func TestORM_CreateJob_OCR2_Sending_Keys_Transmitter_Keys_Validations(t *testing require.NoError(t, keyStore.OCR2().Add(cltest.DefaultOCR2Key)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -986,7 +986,7 @@ func Test_FindJobs(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1067,7 +1067,7 @@ func Test_FindJob(t *testing.T) { require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) require.NoError(t, keyStore.CSA().Add(cltest.DefaultCSAKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1250,7 +1250,7 @@ func Test_FindJobsByPipelineSpecIDs(t *testing.T) { keyStore := cltest.NewKeyStore(t, db, config.Database()) require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1298,7 +1298,7 @@ func Test_FindPipelineRuns(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -1359,7 +1359,7 @@ func Test_PipelineRunsByJobID(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := 
evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -1419,7 +1419,7 @@ func Test_FindPipelineRunIDsByJobID(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -1527,7 +1527,7 @@ func Test_FindPipelineRunsByIDs(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -1585,7 +1585,7 @@ func Test_FindPipelineRunByID(t *testing.T) { err := keyStore.OCR().Add(cltest.DefaultOCRKey) require.NoError(t, err) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1628,7 +1628,7 @@ func Test_FindJobWithoutSpecErrors(t *testing.T) { err := keyStore.OCR().Add(cltest.DefaultOCRKey) require.NoError(t, err) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1665,7 +1665,7 @@ func Test_FindSpecErrorsByJobIDs(t *testing.T) { err := keyStore.OCR().Add(cltest.DefaultOCRKey) require.NoError(t, err) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1699,7 +1699,7 @@ func Test_CountPipelineRunsByJobID(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -1740,6 +1740,7 @@ func Test_CountPipelineRunsByJobID(t *testing.T) { func 
mustInsertPipelineRun(t *testing.T, orm pipeline.ORM, j job.Job) pipeline.Run { t.Helper() + ctx := testutils.Context(t) run := pipeline.Run{ PipelineSpecID: j.PipelineSpecID, @@ -1750,7 +1751,7 @@ func mustInsertPipelineRun(t *testing.T, orm pipeline.ORM, j job.Job) pipeline.R CreatedAt: time.Now(), FinishedAt: null.Time{}, } - err := orm.CreateRun(&run) + err := orm.CreateRun(ctx, &run) require.NoError(t, err) return run } diff --git a/core/services/job/job_pipeline_orm_integration_test.go b/core/services/job/job_pipeline_orm_integration_test.go index 698e60eca7b..696005c270e 100644 --- a/core/services/job/job_pipeline_orm_integration_test.go +++ b/core/services/job/job_pipeline_orm_integration_test.go @@ -126,14 +126,15 @@ func TestPipelineORM_Integration(t *testing.T) { _, bridge2 := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) t.Run("creates task DAGs", func(t *testing.T) { + ctx := testutils.Context(t) clearJobsDb(t, db) - orm := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + orm := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) p, err := pipeline.Parse(DotStr) require.NoError(t, err) - specID, err = orm.CreateSpec(*p, models.Interval(0)) + specID, err = orm.CreateSpec(ctx, nil, *p, models.Interval(0)) require.NoError(t, err) var pipelineSpecs []pipeline.Spec @@ -152,7 +153,7 @@ func TestPipelineORM_Integration(t *testing.T) { lggr := logger.TestLogger(t) cfg := configtest.NewTestGeneralConfig(t) clearJobsDb(t, db) - orm := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + orm := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{Client: evmtest.NewEthClientMockWithDefaultChain(t), DB: db, GeneralConfig: config, KeyStore: ethKeyStore}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) diff --git a/core/services/job/kv_orm_test.go b/core/services/job/kv_orm_test.go index 6a3269e9992..3ba03b8bc3c 100644 --- a/core/services/job/kv_orm_test.go +++ b/core/services/job/kv_orm_test.go @@ -29,7 +29,7 @@ func TestJobKVStore(t *testing.T) { lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobID := int32(1337) diff --git a/core/services/job/orm.go b/core/services/job/orm.go index 4ac1e7c6047..9d2a6545163 100644 --- a/core/services/job/orm.go +++ b/core/services/job/orm.go @@ -456,7 +456,7 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { o.lggr.Panicf("Unsupported jb.Type: %v", jb.Type) } - pipelineSpecID, err := o.pipelineORM.CreateSpec(p, jb.MaxTaskDuration, pg.WithQueryer(tx)) + pipelineSpecID, err := o.pipelineORM.CreateSpec(ctx, tx, p, jb.MaxTaskDuration) if err != nil { return errors.Wrap(err, "failed to create pipeline spec") } diff --git a/core/services/job/runner_integration_test.go b/core/services/job/runner_integration_test.go index ed2950ac382..6149bb71cf6 100644 --- a/core/services/job/runner_integration_test.go +++ b/core/services/job/runner_integration_test.go @@ -80,7 +80,7 @@ func TestRunner(t *testing.T) { ethClient.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Maybe().Return(nil, nil) ctx := testutils.Context(t) 
- pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) require.NoError(t, pipelineORM.Start(ctx)) t.Cleanup(func() { assert.NoError(t, pipelineORM.Close()) }) btORM := bridges.NewORM(db) @@ -888,7 +888,7 @@ func TestRunner_Success_Callback_AsyncJob(t *testing.T) { _ = cltest.CreateJobRunViaExternalInitiatorV2(t, app, jobUUID, *eia, cltest.MustJSONMarshal(t, eiRequest)) - pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(app.GetSqlxDB()) jobORM := NewTestORM(t, app.GetSqlxDB(), pipelineORM, bridgesORM, app.KeyStore, cfg.Database()) @@ -1065,7 +1065,7 @@ func TestRunner_Error_Callback_AsyncJob(t *testing.T) { t.Run("simulate request from EI -> Core node with erroring callback", func(t *testing.T) { _ = cltest.CreateJobRunViaExternalInitiatorV2(t, app, jobUUID, *eia, cltest.MustJSONMarshal(t, eiRequest)) - pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(app.GetSqlxDB()) jobORM := NewTestORM(t, app.GetSqlxDB(), pipelineORM, bridgesORM, app.KeyStore, cfg.Database()) diff --git a/core/services/job/spawner.go b/core/services/job/spawner.go index 3d30a3190b3..8024424226c 100644 --- a/core/services/job/spawner.go +++ b/core/services/job/spawner.go @@ -78,7 +78,7 @@ type ( // non-db side effects. This is required in order to guarantee mutual atomicity between // all tasks intended to happen during job deletion. For the same reason, the job will // not show up in the db within OnDeleteJob(), even though it is still actively running. - OnDeleteJob(ctx context.Context, jb Job, q pg.Queryer) error + OnDeleteJob(ctx context.Context, jb Job) error } activeJob struct { @@ -340,7 +340,7 @@ func (js *spawner) DeleteJob(jobID int32, qopts ...pg.QOpt) error { // we know the DELETE will succeed. The DELETE will be finalized only if all db transactions in OnDeleteJob() // succeed. If either of those fails, the job will not be stopped and everything will be rolled back. 
lggr.Debugw("Callback: OnDeleteJob") - err = aj.delegate.OnDeleteJob(ctx, aj.spec, tx) + err = aj.delegate.OnDeleteJob(ctx, aj.spec) if err != nil { return err } @@ -395,7 +395,9 @@ func (n *NullDelegate) ServicesForSpec(ctx context.Context, spec Job) (s []Servi return } -func (n *NullDelegate) BeforeJobCreated(spec Job) {} -func (n *NullDelegate) AfterJobCreated(spec Job) {} -func (n *NullDelegate) BeforeJobDeleted(spec Job) {} -func (n *NullDelegate) OnDeleteJob(ctx context.Context, spec Job, q pg.Queryer) error { return nil } +func (n *NullDelegate) BeforeJobCreated(spec Job) {} +func (n *NullDelegate) AfterJobCreated(spec Job) {} +func (n *NullDelegate) BeforeJobDeleted(spec Job) {} +func (n *NullDelegate) OnDeleteJob(context.Context, Job) error { + return nil +} diff --git a/core/services/job/spawner_test.go b/core/services/job/spawner_test.go index d2e7a80d5d4..802763cfaab 100644 --- a/core/services/job/spawner_test.go +++ b/core/services/job/spawner_test.go @@ -100,7 +100,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) t.Run("should respect its dependents", func(t *testing.T) { lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) a := utils.NewDependentAwaiter() a.AddDependents(1) spawner := job.NewSpawner(orm, config.Database(), noopChecker{}, map[job.Type]job.Delegate{}, db, lggr, []utils.DependentAwaiter{a}) @@ -123,7 +123,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { jobB := makeOCRJobSpec(t, address, bridge.Name.String(), bridge2.Name.String()) lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) eventuallyA := cltest.NewAwaiter() serviceA1 := mocks.NewServiceCtx(t) @@ -188,7 +188,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { serviceA2.On("Start", mock.Anything).Return(nil).Once().Run(func(mock.Arguments) { eventually.ItHappened() }) lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) mailMon := servicetest.Run(t, mailboxtest.NewMonitor(t)) d := ocr.NewDelegate(nil, orm, nil, nil, nil, monitoringEndpoint, legacyChains, logger.TestLogger(t), config.Database(), mailMon) delegateA := &delegate{jobA.Type, []job.ServiceCtx{serviceA1, serviceA2}, 0, nil, d} @@ -222,7 +222,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { serviceA2.On("Start", mock.Anything).Return(nil).Once().Run(func(mock.Arguments) { eventuallyStart.ItHappened() }) lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), 
bridges.NewORM(db), keyStore, config.Database()) mailMon := servicetest.Run(t, mailboxtest.NewMonitor(t)) d := ocr.NewDelegate(nil, orm, nil, nil, nil, monitoringEndpoint, legacyChains, logger.TestLogger(t), config.Database(), mailMon) delegateA := &delegate{jobA.Type, []job.ServiceCtx{serviceA1, serviceA2}, 0, nil, d} @@ -300,7 +300,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { jobOCR2VRF := makeOCR2VRFJobSpec(t, keyStore, config, address, chain.ID(), 2) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) mailMon := servicetest.Run(t, mailboxtest.NewMonitor(t)) processConfig := plugins.NewRegistrarConfig(loop.GRPCOpts{}, func(name string) (*plugins.RegisteredLoop, error) { return nil, nil }, func(loopId string) {}) diff --git a/core/services/keeper/delegate.go b/core/services/keeper/delegate.go index 679ccf3053d..9652434759b 100644 --- a/core/services/keeper/delegate.go +++ b/core/services/keeper/delegate.go @@ -11,7 +11,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -51,10 +50,10 @@ func (d *Delegate) JobType() job.Type { return job.Keeper } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. 
func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services []job.ServiceCtx, err error) { diff --git a/core/services/keeper/integration_test.go b/core/services/keeper/integration_test.go index 49073c8de56..08699d3d835 100644 --- a/core/services/keeper/integration_test.go +++ b/core/services/keeper/integration_test.go @@ -175,6 +175,7 @@ func TestKeeperEthIntegration(t *testing.T) { test := tt t.Run(test.name, func(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) g := gomega.NewWithT(t) // setup node key @@ -249,12 +250,12 @@ func TestKeeperEthIntegration(t *testing.T) { korm := keeper.NewORM(db, logger.TestLogger(t)) app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, backend.Backend(), nodeKey) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // create job regAddrEIP55 := evmtypes.EIP55AddressFromAddress(regAddr) job := cltest.MustInsertKeeperJob(t, db, korm, nodeAddressEIP55, regAddrEIP55) - err = app.JobSpawner().StartService(testutils.Context(t), job) + err = app.JobSpawner().StartService(ctx, job) require.NoError(t, err) // keeper job is triggered and payload is received @@ -311,7 +312,7 @@ func TestKeeperEthIntegration(t *testing.T) { cltest.AssertRecordEventually(t, app.GetSqlxDB(), ®istry, fmt.Sprintf("SELECT * FROM keeper_registries WHERE id = %d", registry.ID), func() bool { return registry.KeeperIndex == -1 }) - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) // Since we set grace period to 0, we can have more than 1 pipeline run per perform // This happens in case we start a pipeline run before previous perform tx is committed to chain @@ -481,6 +482,7 @@ func TestKeeperForwarderEthIntegration(t *testing.T) { func TestMaxPerformDataSize(t *testing.T) { t.Parallel() t.Run("max_perform_data_size_test", func(t *testing.T) { + ctx := testutils.Context(t) maxPerformDataSize := 1000 // Will be set as config override g := gomega.NewWithT(t) @@ -552,12 +554,12 @@ func TestMaxPerformDataSize(t *testing.T) { korm := keeper.NewORM(db, logger.TestLogger(t)) app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, backend.Backend(), nodeKey) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // create job regAddrEIP55 := evmtypes.EIP55AddressFromAddress(regAddr) job := cltest.MustInsertKeeperJob(t, db, korm, nodeAddressEIP55, regAddrEIP55) - err = app.JobSpawner().StartService(testutils.Context(t), job) + err = app.JobSpawner().StartService(ctx, job) require.NoError(t, err) // keeper job is triggered diff --git a/core/services/keeper/registry1_1_synchronizer_test.go b/core/services/keeper/registry1_1_synchronizer_test.go index 24a6a7288a7..61482208e5c 100644 --- a/core/services/keeper/registry1_1_synchronizer_test.go +++ b/core/services/keeper/registry1_1_synchronizer_test.go @@ -201,6 +201,7 @@ func Test_RegistrySynchronizer1_1_FullSync(t *testing.T) { } func Test_RegistrySynchronizer1_1_ConfigSetLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_1) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -235,11 +236,11 @@ func Test_RegistrySynchronizer1_1_ConfigSetLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, 
mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.BlockCountPerTurn == 40 @@ -248,6 +249,7 @@ func Test_RegistrySynchronizer1_1_ConfigSetLog(t *testing.T) { } func Test_RegistrySynchronizer1_1_KeepersUpdatedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_1) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -281,11 +283,11 @@ func Test_RegistrySynchronizer1_1_KeepersUpdatedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.NumKeepers == 2 @@ -293,6 +295,7 @@ func Test_RegistrySynchronizer1_1_KeepersUpdatedLog(t *testing.T) { cltest.AssertCount(t, db, "keeper_registries", 1) } func Test_RegistrySynchronizer1_1_UpkeepCanceledLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_1) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -320,16 +323,17 @@ func Test_RegistrySynchronizer1_1_UpkeepCanceledLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_1_UpkeepRegisteredLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_1) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -360,16 +364,17 @@ func Test_RegistrySynchronizer1_1_UpkeepRegisteredLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_1_UpkeepPerformedLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_1) @@ -401,11 +406,11 
@@ func Test_RegistrySynchronizer1_1_UpkeepPerformedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(func() int64 { var upkeep keeper.UpkeepRegistration diff --git a/core/services/keeper/registry1_2_synchronizer_test.go b/core/services/keeper/registry1_2_synchronizer_test.go index 23e6c0355ec..a62e27b8759 100644 --- a/core/services/keeper/registry1_2_synchronizer_test.go +++ b/core/services/keeper/registry1_2_synchronizer_test.go @@ -220,6 +220,7 @@ func Test_RegistrySynchronizer1_2_FullSync(t *testing.T) { } func Test_RegistrySynchronizer1_2_ConfigSetLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -258,11 +259,11 @@ func Test_RegistrySynchronizer1_2_ConfigSetLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.BlockCountPerTurn == 40 @@ -271,6 +272,7 @@ func Test_RegistrySynchronizer1_2_ConfigSetLog(t *testing.T) { } func Test_RegistrySynchronizer1_2_KeepersUpdatedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -308,11 +310,11 @@ func Test_RegistrySynchronizer1_2_KeepersUpdatedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.NumKeepers == 2 @@ -321,6 +323,7 @@ func Test_RegistrySynchronizer1_2_KeepersUpdatedLog(t *testing.T) { } func Test_RegistrySynchronizer1_2_UpkeepCanceledLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -349,16 +352,17 @@ func Test_RegistrySynchronizer1_2_UpkeepCanceledLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, 
mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_2_UpkeepRegisteredLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -390,16 +394,17 @@ func Test_RegistrySynchronizer1_2_UpkeepRegisteredLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_2_UpkeepPerformedLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) @@ -432,11 +437,11 @@ func Test_RegistrySynchronizer1_2_UpkeepPerformedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(func() int64 { var upkeep keeper.UpkeepRegistration @@ -454,6 +459,7 @@ func Test_RegistrySynchronizer1_2_UpkeepPerformedLog(t *testing.T) { } func Test_RegistrySynchronizer1_2_UpkeepGasLimitSetLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) @@ -496,16 +502,17 @@ func Test_RegistrySynchronizer1_2_UpkeepGasLimitSetLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(getExecuteGas, testutils.WaitTimeout(t), cltest.DBPollingInterval).Should(gomega.Equal(uint32(4_000_000))) } func Test_RegistrySynchronizer1_2_UpkeepReceivedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -537,16 +544,17 @@ func Test_RegistrySynchronizer1_2_UpkeepReceivedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, 
mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_2_UpkeepMigratedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -575,11 +583,11 @@ func Test_RegistrySynchronizer1_2_UpkeepMigratedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } diff --git a/core/services/keeper/registry1_3_synchronizer_test.go b/core/services/keeper/registry1_3_synchronizer_test.go index 2b5900ac189..7ebbbc25469 100644 --- a/core/services/keeper/registry1_3_synchronizer_test.go +++ b/core/services/keeper/registry1_3_synchronizer_test.go @@ -225,6 +225,7 @@ func Test_RegistrySynchronizer1_3_FullSync(t *testing.T) { } func Test_RegistrySynchronizer1_3_ConfigSetLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -263,11 +264,11 @@ func Test_RegistrySynchronizer1_3_ConfigSetLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.BlockCountPerTurn == 40 @@ -276,6 +277,7 @@ func Test_RegistrySynchronizer1_3_ConfigSetLog(t *testing.T) { } func Test_RegistrySynchronizer1_3_KeepersUpdatedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -313,11 +315,11 @@ func Test_RegistrySynchronizer1_3_KeepersUpdatedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.NumKeepers == 2 @@ -326,6 +328,7 @@ func Test_RegistrySynchronizer1_3_KeepersUpdatedLog(t *testing.T) 
{ } func Test_RegistrySynchronizer1_3_UpkeepCanceledLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -354,16 +357,17 @@ func Test_RegistrySynchronizer1_3_UpkeepCanceledLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_3_UpkeepRegisteredLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -395,16 +399,17 @@ func Test_RegistrySynchronizer1_3_UpkeepRegisteredLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_3_UpkeepPerformedLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) @@ -437,11 +442,11 @@ func Test_RegistrySynchronizer1_3_UpkeepPerformedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(func() int64 { var upkeep keeper.UpkeepRegistration @@ -459,6 +464,7 @@ func Test_RegistrySynchronizer1_3_UpkeepPerformedLog(t *testing.T) { } func Test_RegistrySynchronizer1_3_UpkeepGasLimitSetLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) @@ -501,16 +507,17 @@ func Test_RegistrySynchronizer1_3_UpkeepGasLimitSetLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(getExecuteGas, testutils.WaitTimeout(t), cltest.DBPollingInterval).Should(gomega.Equal(uint32(4_000_000))) } func Test_RegistrySynchronizer1_3_UpkeepReceivedLog(t *testing.T) { + 
ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -542,16 +549,17 @@ func Test_RegistrySynchronizer1_3_UpkeepReceivedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_3_UpkeepMigratedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -580,17 +588,18 @@ func Test_RegistrySynchronizer1_3_UpkeepMigratedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) // race condition: "wait for count" cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_3_UpkeepPausedLog_UpkeepUnpausedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -620,11 +629,11 @@ func Test_RegistrySynchronizer1_3_UpkeepPausedLog_UpkeepUnpausedLog(t *testing.T logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) @@ -635,11 +644,11 @@ func Test_RegistrySynchronizer1_3_UpkeepPausedLog_UpkeepUnpausedLog(t *testing.T logBroadcast.On("DecodedLog").Return(&unpausedlog) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 3) var upkeep keeper.UpkeepRegistration @@ -657,6 +666,7 @@ func Test_RegistrySynchronizer1_3_UpkeepPausedLog_UpkeepUnpausedLog(t *testing.T } func Test_RegistrySynchronizer1_3_UpkeepCheckDataUpdatedLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) @@ -694,11 +704,11 @@ func 
Test_RegistrySynchronizer1_3_UpkeepCheckDataUpdatedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&updatedLog) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(func() []byte { var upkeep keeper.UpkeepRegistration diff --git a/core/services/keeper/registry_synchronizer_log_listener.go b/core/services/keeper/registry_synchronizer_log_listener.go index 099d01d27f6..93ff2e9e950 100644 --- a/core/services/keeper/registry_synchronizer_log_listener.go +++ b/core/services/keeper/registry_synchronizer_log_listener.go @@ -1,6 +1,7 @@ package keeper import ( + "context" "reflect" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" @@ -10,7 +11,7 @@ func (rs *RegistrySynchronizer) JobID() int32 { return rs.job.ID } -func (rs *RegistrySynchronizer) HandleLog(broadcast log.Broadcast) { +func (rs *RegistrySynchronizer) HandleLog(ctx context.Context, broadcast log.Broadcast) { eventLog := broadcast.DecodedLog() if eventLog == nil || reflect.ValueOf(eventLog).IsNil() { rs.logger.Panicf("HandleLog: ignoring nil value, type: %T", broadcast) diff --git a/core/services/keeper/registry_synchronizer_process_logs.go b/core/services/keeper/registry_synchronizer_process_logs.go index 0a0e1613c95..a1bdcd8db0b 100644 --- a/core/services/keeper/registry_synchronizer_process_logs.go +++ b/core/services/keeper/registry_synchronizer_process_logs.go @@ -85,7 +85,7 @@ func (rs *RegistrySynchronizer) processLogs(ctx context.Context) { rs.logger.Error(err) } - err = rs.logBroadcaster.MarkConsumed(ctx, broadcast) + err = rs.logBroadcaster.MarkConsumed(ctx, nil, broadcast) if err != nil { rs.logger.Error(errors.Wrapf(err, "unable to mark %T log as consumed, log: %v", broadcast.RawLog(), broadcast.String())) } diff --git a/core/services/ocr/contract_tracker.go b/core/services/ocr/contract_tracker.go index e4845ee3bc2..5746f97cd38 100644 --- a/core/services/ocr/contract_tracker.go +++ b/core/services/ocr/contract_tracker.go @@ -14,13 +14,12 @@ import ( gethTypes "github.com/ethereum/go-ethereum/core/types" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/libocr/gethwrappers/offchainaggregator" "github.com/smartcontractkit/libocr/offchainreporting/confighelper" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting/types" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink/v2/common/config" @@ -31,7 +30,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/offchain_aggregator_wrapper" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) // configMailboxSanityLimit is the maximum number of configs that can be held @@ -64,7 +62,7 @@ type ( jobID int32 logger logger.Logger ocrDB OCRContractTrackerDB - q pg.Q + ds sqlutil.DataSource blockTranslator ocrcommon.BlockTranslator cfg ocrcommon.Config mailMon *mailbox.Monitor @@ -92,8 +90,8 @@ type ( } OCRContractTrackerDB interface 
{ - SaveLatestRoundRequested(tx pg.Queryer, rr offchainaggregator.OffchainAggregatorRoundRequested) error - LoadLatestRoundRequested() (rr offchainaggregator.OffchainAggregatorRoundRequested, err error) + SaveLatestRoundRequested(ctx context.Context, tx sqlutil.DataSource, rr offchainaggregator.OffchainAggregatorRoundRequested) error + LoadLatestRoundRequested(ctx context.Context) (rr offchainaggregator.OffchainAggregatorRoundRequested, err error) } ) @@ -112,10 +110,9 @@ func NewOCRContractTracker( logBroadcaster log.Broadcaster, jobID int32, logger logger.Logger, - db *sqlx.DB, + ds sqlutil.DataSource, ocrDB OCRContractTrackerDB, cfg ocrcommon.Config, - q pg.QConfig, headBroadcaster httypes.HeadBroadcaster, mailMon *mailbox.Monitor, ) (o *OCRContractTracker) { @@ -129,7 +126,7 @@ func NewOCRContractTracker( jobID: jobID, logger: logger, ocrDB: ocrDB, - q: pg.NewQ(db, logger, q), + ds: ds, blockTranslator: ocrcommon.NewBlockTranslator(cfg, ethClient, logger), cfg: cfg, mailMon: mailMon, @@ -144,9 +141,9 @@ func NewOCRContractTracker( // Start must be called before logs can be delivered // It ought to be called before starting OCR -func (t *OCRContractTracker) Start(context.Context) error { +func (t *OCRContractTracker) Start(ctx context.Context) error { return t.StartOnce("OCRContractTracker", func() (err error) { - t.latestRoundRequested, err = t.ocrDB.LoadLatestRoundRequested() + t.latestRoundRequested, err = t.ocrDB.LoadLatestRoundRequested(ctx) if err != nil { return errors.Wrap(err, "OCRContractTracker#Start: failed to load latest round requested") } @@ -240,10 +237,7 @@ func (t *OCRContractTracker) processLogs() { // HandleLog complies with LogListener interface // It is not thread safe -func (t *OCRContractTracker) HandleLog(lb log.Broadcast) { - ctx, cancel := t.chStop.NewCtx() - defer cancel() - +func (t *OCRContractTracker) HandleLog(ctx context.Context, lb log.Broadcast) { was, err := t.logBroadcaster.WasAlreadyConsumed(ctx, lb) if err != nil { t.logger.Errorw("could not determine if log was already consumed", "err", err) @@ -255,14 +249,14 @@ func (t *OCRContractTracker) HandleLog(lb log.Broadcast) { raw := lb.RawLog() if raw.Address != t.contract.Address() { t.logger.Errorf("log address of 0x%x does not match configured contract address of 0x%x", raw.Address, t.contract.Address()) - if err2 := t.logBroadcaster.MarkConsumed(ctx, lb); err2 != nil { + if err2 := t.logBroadcaster.MarkConsumed(ctx, nil, lb); err2 != nil { t.logger.Errorw("failed to mark log consumed", "err", err2) } return } topics := raw.Topics if len(topics) == 0 { - if err2 := t.logBroadcaster.MarkConsumed(ctx, lb); err2 != nil { + if err2 := t.logBroadcaster.MarkConsumed(ctx, nil, lb); err2 != nil { t.logger.Errorw("failed to mark log consumed", "err", err2) } return @@ -275,7 +269,7 @@ func (t *OCRContractTracker) HandleLog(lb log.Broadcast) { configSet, err = t.contractFilterer.ParseConfigSet(raw) if err != nil { t.logger.Errorw("could not parse config set", "err", err) - if err2 := t.logBroadcaster.MarkConsumed(ctx, lb); err2 != nil { + if err2 := t.logBroadcaster.MarkConsumed(ctx, nil, lb); err2 != nil { t.logger.Errorw("failed to mark log consumed", "err", err2) } return @@ -292,17 +286,17 @@ func (t *OCRContractTracker) HandleLog(lb log.Broadcast) { rr, err = t.contractFilterer.ParseRoundRequested(raw) if err != nil { t.logger.Errorw("could not parse round requested", "err", err) - if err2 := t.logBroadcaster.MarkConsumed(ctx, lb); err2 != nil { + if err2 := t.logBroadcaster.MarkConsumed(ctx, nil, lb); 
err2 != nil { t.logger.Errorw("failed to mark log consumed", "err", err2) } return } if IsLaterThan(raw, t.latestRoundRequested.Raw) { - err = t.q.Transaction(func(tx pg.Queryer) error { - if err = t.ocrDB.SaveLatestRoundRequested(tx, *rr); err != nil { + err = sqlutil.TransactDataSource(ctx, t.ds, nil, func(tx sqlutil.DataSource) error { + if err = t.ocrDB.SaveLatestRoundRequested(ctx, tx, *rr); err != nil { return err } - return t.logBroadcaster.MarkConsumed(ctx, lb) + return t.logBroadcaster.MarkConsumed(ctx, tx, lb) }) if err != nil { t.logger.Error(err) @@ -320,7 +314,7 @@ func (t *OCRContractTracker) HandleLog(lb log.Broadcast) { t.logger.Debugw("got unrecognised log topic", "topic", topics[0]) } if !consumed { - if err := t.logBroadcaster.MarkConsumed(ctx, lb); err != nil { + if err := t.logBroadcaster.MarkConsumed(ctx, nil, lb); err != nil { t.logger.Errorw("failed to mark log consumed", "err", err) } } diff --git a/core/services/ocr/contract_tracker_test.go b/core/services/ocr/contract_tracker_test.go index 678af35fa04..5473a2c924c 100644 --- a/core/services/ocr/contract_tracker_test.go +++ b/core/services/ocr/contract_tracker_test.go @@ -97,7 +97,6 @@ func newContractTrackerUni(t *testing.T, opts ...interface{}) (uni contractTrack db, uni.db, cfg.EVM(), - cfg.Database(), uni.hb, mailMon, ) @@ -146,7 +145,7 @@ func Test_OCRContractTracker_LatestBlockHeight(t *testing.T) { uni := newContractTrackerUni(t) uni.hb.On("Subscribe", uni.tracker).Return(&evmtypes.Head{Number: 42}, func() {}) - uni.db.On("LoadLatestRoundRequested").Return(offchainaggregator.OffchainAggregatorRoundRequested{}, nil) + uni.db.On("LoadLatestRoundRequested", mock.Anything).Return(offchainaggregator.OffchainAggregatorRoundRequested{}, nil) uni.lb.On("Register", uni.tracker, mock.Anything).Return(func() {}) servicetest.Run(t, uni.tracker) @@ -172,7 +171,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin rawLog := cltest.LogFromFixture(t, "../../testdata/jsonrpc/round_requested_log_1_1.json") logBroadcast.On("RawLog").Return(rawLog).Maybe() logBroadcast.On("String").Return("").Maybe() - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) configDigest, epoch, round, err := uni.tracker.LatestRoundRequested(testutils.Context(t), 0) @@ -181,7 +180,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin require.Equal(t, 0, int(round)) require.Equal(t, 0, int(epoch)) - uni.tracker.HandleLog(logBroadcast) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -203,7 +202,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin require.Equal(t, 0, int(round)) require.Equal(t, 0, int(epoch)) - uni.tracker.HandleLog(logBroadcast) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -228,13 +227,13 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast.On("RawLog").Return(rawLog).Maybe() logBroadcast.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, 
mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { return rr.Epoch == 1 && rr.Round == 1 })).Return(nil) - uni.tracker.HandleLog(logBroadcast) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -248,13 +247,13 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast2.On("RawLog").Return(rawLog2) logBroadcast2.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { return rr.Epoch == 1 && rr.Round == 9 })).Return(nil) - uni.tracker.HandleLog(logBroadcast2) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast2) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -263,7 +262,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin assert.Equal(t, 9, int(round)) // Same round with lower epoch is ignored - uni.tracker.HandleLog(logBroadcast) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -277,13 +276,13 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast3.On("RawLog").Return(rawLog3).Maybe() logBroadcast3.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { return rr.Epoch == 2 && rr.Round == 1 })).Return(nil) - uni.tracker.HandleLog(logBroadcast3) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast3) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -301,9 +300,9 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything).Return(errors.New("something exploded")) + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("something exploded")) - uni.tracker.HandleLog(logBroadcast) + 
uni.tracker.HandleLog(testutils.Context(t), logBroadcast) configDigest, epoch, round, err := uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -331,7 +330,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin eventuallyCloseHeadBroadcaster := cltest.NewAwaiter() uni.hb.On("Subscribe", uni.tracker).Return((*evmtypes.Head)(nil), func() { eventuallyCloseHeadBroadcaster.ItHappened() }) - uni.db.On("LoadLatestRoundRequested").Return(rr, nil) + uni.db.On("LoadLatestRoundRequested", mock.Anything).Return(rr, nil) require.NoError(t, uni.tracker.Start(testutils.Context(t))) diff --git a/core/services/ocr/database.go b/core/services/ocr/database.go index 977c371c15d..95993de9d5c 100644 --- a/core/services/ocr/database.go +++ b/core/services/ocr/database.go @@ -11,17 +11,16 @@ import ( "github.com/pkg/errors" "go.uber.org/multierr" - "github.com/jmoiron/sqlx" "github.com/smartcontractkit/libocr/gethwrappers/offchainaggregator" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting/types" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type db struct { - q pg.Q + ds sqlutil.DataSource oracleSpecID int32 lggr logger.SugaredLogger } @@ -32,11 +31,9 @@ var ( ) // NewDB returns a new DB scoped to this oracleSpecID -func NewDB(sqlxDB *sqlx.DB, oracleSpecID int32, lggr logger.Logger, cfg pg.QConfig) *db { - namedLogger := lggr.Named("OCR.DB") - +func NewDB(ds sqlutil.DataSource, oracleSpecID int32, lggr logger.Logger) *db { return &db{ - q: pg.NewQ(sqlxDB, namedLogger, cfg), + ds: ds, oracleSpecID: oracleSpecID, lggr: logger.Sugared(lggr), } @@ -54,7 +51,7 @@ func (d *db) ReadState(ctx context.Context, cd ocrtypes.ConfigDigest) (ps *ocrty var tmp []int64 var highestSentEpochTmp int64 - err = d.q.QueryRowxContext(ctx, stmt, d.oracleSpecID, cd).Scan(&ps.Epoch, &highestSentEpochTmp, pq.Array(&tmp)) + err = d.ds.QueryRowxContext(ctx, stmt, d.oracleSpecID, cd).Scan(&ps.Epoch, &highestSentEpochTmp, pq.Array(&tmp)) if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -90,7 +87,9 @@ func (d *db) WriteState(ctx context.Context, cd ocrtypes.ConfigDigest, state ocr NOW() ) ` - _, err := d.q.WithOpts(pg.WithLongQueryTimeout()).ExecContext( + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() + _, err := d.ds.ExecContext( ctx, stmt, d.oracleSpecID, cd, state.Epoch, state.HighestSentEpoch, pq.Array(&highestReceivedEpoch), ) @@ -109,7 +108,7 @@ func (d *db) ReadConfig(ctx context.Context) (c *ocrtypes.ContractConfig, err er var signers [][]byte var transmitters [][]byte - err = d.q.QueryRowContext(ctx, stmt, d.oracleSpecID).Scan( + err = d.ds.QueryRowxContext(ctx, stmt, d.oracleSpecID).Scan( &c.ConfigDigest, (*pq.ByteaArray)(&signers), (*pq.ByteaArray)(&transmitters), @@ -155,7 +154,7 @@ func (d *db) WriteConfig(ctx context.Context, c ocrtypes.ContractConfig) error { encoded = EXCLUDED.encoded, updated_at = NOW() ` - _, err := d.q.ExecContext(ctx, stmt, d.oracleSpecID, c.ConfigDigest, pq.ByteaArray(signers), pq.ByteaArray(transmitters), c.Threshold, int(c.EncodedConfigVersion), c.Encoded) + _, err := d.ds.ExecContext(ctx, stmt, d.oracleSpecID, c.ConfigDigest, pq.ByteaArray(signers), pq.ByteaArray(transmitters), c.Threshold, int(c.EncodedConfigVersion), c.Encoded) return errors.Wrap(err, "WriteConfig 
failed") } @@ -201,14 +200,14 @@ func (d *db) StorePendingTransmission(ctx context.Context, k ocrtypes.ReportTime updated_at = NOW() ` - _, err := d.q.ExecContext(ctx, stmt, d.oracleSpecID, k.ConfigDigest, k.Epoch, k.Round, p.Time, median, p.SerializedReport, pq.ByteaArray(rs), pq.ByteaArray(ss), p.Vs[:]) + _, err := d.ds.ExecContext(ctx, stmt, d.oracleSpecID, k.ConfigDigest, k.Epoch, k.Round, p.Time, median, p.SerializedReport, pq.ByteaArray(rs), pq.ByteaArray(ss), p.Vs[:]) return errors.Wrap(err, "StorePendingTransmission failed") } func (d *db) PendingTransmissionsWithConfigDigest(ctx context.Context, cd ocrtypes.ConfigDigest) (map[ocrtypes.ReportTimestamp]ocrtypes.PendingTransmission, error) { //nolint sqlclosecheck false positive - rows, err := d.q.QueryContext(ctx, ` + rows, err := d.ds.QueryContext(ctx, ` SELECT config_digest, epoch, @@ -269,7 +268,9 @@ WHERE ocr_oracle_spec_id = $1 AND config_digest = $2 } func (d *db) DeletePendingTransmission(ctx context.Context, k ocrtypes.ReportTimestamp) (err error) { - _, err = d.q.WithOpts(pg.WithLongQueryTimeout()).ExecContext(ctx, ` + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() + _, err = d.ds.ExecContext(ctx, ` DELETE FROM ocr_pending_transmissions WHERE ocr_oracle_spec_id = $1 AND config_digest = $2 AND epoch = $3 AND round = $4 `, d.oracleSpecID, k.ConfigDigest, k.Epoch, k.Round) @@ -280,7 +281,9 @@ WHERE ocr_oracle_spec_id = $1 AND config_digest = $2 AND epoch = $3 AND round = } func (d *db) DeletePendingTransmissionsOlderThan(ctx context.Context, t time.Time) (err error) { - _, err = d.q.WithOpts(pg.WithLongQueryTimeout()).ExecContext(ctx, ` + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() + _, err = d.ds.ExecContext(ctx, ` DELETE FROM ocr_pending_transmissions WHERE ocr_oracle_spec_id = $1 AND time < $2 `, d.oracleSpecID, t) @@ -290,12 +293,12 @@ WHERE ocr_oracle_spec_id = $1 AND time < $2 return } -func (d *db) SaveLatestRoundRequested(tx pg.Queryer, rr offchainaggregator.OffchainAggregatorRoundRequested) error { +func (d *db) SaveLatestRoundRequested(ctx context.Context, tx sqlutil.DataSource, rr offchainaggregator.OffchainAggregatorRoundRequested) error { rawLog, err := json.Marshal(rr.Raw) if err != nil { return errors.Wrap(err, "could not marshal log as JSON") } - _, err = tx.Exec(` + _, err = tx.ExecContext(ctx, ` INSERT INTO ocr_latest_round_requested (ocr_oracle_spec_id, requester, config_digest, epoch, round, raw) VALUES ($1,$2,$3,$4,$5,$6) ON CONFLICT (ocr_oracle_spec_id) DO UPDATE SET requester = EXCLUDED.requester, @@ -308,8 +311,8 @@ VALUES ($1,$2,$3,$4,$5,$6) ON CONFLICT (ocr_oracle_spec_id) DO UPDATE SET return errors.Wrap(err, "could not save latest round requested") } -func (d *db) LoadLatestRoundRequested() (rr offchainaggregator.OffchainAggregatorRoundRequested, err error) { - rows, err := d.q.Query(` +func (d *db) LoadLatestRoundRequested(ctx context.Context) (rr offchainaggregator.OffchainAggregatorRoundRequested, err error) { + rows, err := d.ds.QueryContext(ctx, ` SELECT requester, config_digest, epoch, round, raw FROM ocr_latest_round_requested WHERE ocr_oracle_spec_id = $1 diff --git a/core/services/ocr/database_test.go b/core/services/ocr/database_test.go index 5ccf257b2bb..8b8d64c49c9 100644 --- a/core/services/ocr/database_test.go +++ b/core/services/ocr/database_test.go @@ -410,7 +410,8 @@ func Test_DB_LatestRoundRequested(t *testing.T) { } t.Run("saves latest round requested", func(t *testing.T) 
{ - err := odb.SaveLatestRoundRequested(sqlDB, rr) + ctx := testutils.Context(t) + err := odb.SaveLatestRoundRequested(ctx, sqlDB, rr) require.NoError(t, err) rawLog.Index = 42 @@ -424,17 +425,18 @@ func Test_DB_LatestRoundRequested(t *testing.T) { Raw: rawLog, } - err = odb.SaveLatestRoundRequested(sqlDB, rr) + err = odb.SaveLatestRoundRequested(ctx, sqlDB, rr) require.NoError(t, err) }) t.Run("loads latest round requested", func(t *testing.T) { + ctx := testutils.Context(t) // There is no round for db2 - lrr, err := odb2.LoadLatestRoundRequested() + lrr, err := odb2.LoadLatestRoundRequested(ctx) require.NoError(t, err) require.Equal(t, 0, int(lrr.Epoch)) - lrr, err = odb.LoadLatestRoundRequested() + lrr, err = odb.LoadLatestRoundRequested(ctx) require.NoError(t, err) assert.Equal(t, rr, lrr) diff --git a/core/services/ocr/delegate.go b/core/services/ocr/delegate.go index bcdda397e20..63055543f88 100644 --- a/core/services/ocr/delegate.go +++ b/core/services/ocr/delegate.go @@ -28,7 +28,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/synchronization" "github.com/smartcontractkit/chainlink/v2/core/services/telemetry" @@ -82,10 +81,10 @@ func (d *Delegate) JobType() job.Type { return job.OffchainReporting } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec returns the OCR services that need to run for this job func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services []job.ServiceCtx, err error) { @@ -121,7 +120,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] return nil, errors.Wrap(err, "could not instantiate NewOffchainAggregatorCaller") } - ocrDB := NewDB(d.db, concreteSpec.ID, lggr, d.cfg) + ocrDB := NewDB(d.db, concreteSpec.ID, lggr) tracker := NewOCRContractTracker( contract, @@ -134,7 +133,6 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] d.db, ocrDB, chain.Config().EVM(), - chain.Config().Database(), chain.HeadBroadcaster(), d.mailMon, ) diff --git a/core/services/ocr/helpers_internal_test.go b/core/services/ocr/helpers_internal_test.go index 57b669ef401..c6a3d1ac401 100644 --- a/core/services/ocr/helpers_internal_test.go +++ b/core/services/ocr/helpers_internal_test.go @@ -5,7 +5,6 @@ import ( "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" ) @@ -14,5 +13,5 @@ func (c *ConfigOverriderImpl) ExportedUpdateFlagsStatus() error { } func NewTestDB(t *testing.T, sqldb *sqlx.DB, oracleSpecID int32) *db { - return NewDB(sqldb, oracleSpecID, logger.TestLogger(t), pgtest.NewQConfig(true)) + return NewDB(sqldb, oracleSpecID, logger.TestLogger(t)) } diff --git 
a/core/services/ocr/mocks/ocr_contract_tracker_db.go b/core/services/ocr/mocks/ocr_contract_tracker_db.go index 6724e418014..42eebf939d7 100644 --- a/core/services/ocr/mocks/ocr_contract_tracker_db.go +++ b/core/services/ocr/mocks/ocr_contract_tracker_db.go @@ -3,11 +3,13 @@ package mocks import ( + context "context" + mock "github.com/stretchr/testify/mock" offchainaggregator "github.com/smartcontractkit/libocr/gethwrappers/offchainaggregator" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ) // OCRContractTrackerDB is an autogenerated mock type for the OCRContractTrackerDB type @@ -15,9 +17,9 @@ type OCRContractTrackerDB struct { mock.Mock } -// LoadLatestRoundRequested provides a mock function with given fields: -func (_m *OCRContractTrackerDB) LoadLatestRoundRequested() (offchainaggregator.OffchainAggregatorRoundRequested, error) { - ret := _m.Called() +// LoadLatestRoundRequested provides a mock function with given fields: ctx +func (_m *OCRContractTrackerDB) LoadLatestRoundRequested(ctx context.Context) (offchainaggregator.OffchainAggregatorRoundRequested, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for LoadLatestRoundRequested") @@ -25,17 +27,17 @@ func (_m *OCRContractTrackerDB) LoadLatestRoundRequested() (offchainaggregator.O var r0 offchainaggregator.OffchainAggregatorRoundRequested var r1 error - if rf, ok := ret.Get(0).(func() (offchainaggregator.OffchainAggregatorRoundRequested, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (offchainaggregator.OffchainAggregatorRoundRequested, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() offchainaggregator.OffchainAggregatorRoundRequested); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) offchainaggregator.OffchainAggregatorRoundRequested); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(offchainaggregator.OffchainAggregatorRoundRequested) } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -43,17 +45,17 @@ func (_m *OCRContractTrackerDB) LoadLatestRoundRequested() (offchainaggregator.O return r0, r1 } -// SaveLatestRoundRequested provides a mock function with given fields: tx, rr -func (_m *OCRContractTrackerDB) SaveLatestRoundRequested(tx pg.Queryer, rr offchainaggregator.OffchainAggregatorRoundRequested) error { - ret := _m.Called(tx, rr) +// SaveLatestRoundRequested provides a mock function with given fields: ctx, tx, rr +func (_m *OCRContractTrackerDB) SaveLatestRoundRequested(ctx context.Context, tx sqlutil.DataSource, rr offchainaggregator.OffchainAggregatorRoundRequested) error { + ret := _m.Called(ctx, tx, rr) if len(ret) == 0 { panic("no return value specified for SaveLatestRoundRequested") } var r0 error - if rf, ok := ret.Get(0).(func(pg.Queryer, offchainaggregator.OffchainAggregatorRoundRequested) error); ok { - r0 = rf(tx, rr) + if rf, ok := ret.Get(0).(func(context.Context, sqlutil.DataSource, offchainaggregator.OffchainAggregatorRoundRequested) error); ok { + r0 = rf(ctx, tx, rr) } else { r0 = ret.Error(0) } diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index a00ed195903..da6d6a1b6e7 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -278,7 +278,7 @@ func (d *Delegate) BeforeJobCreated(spec job.Job) { } func (d *Delegate) AfterJobCreated(spec job.Job) {} 
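Note: the regenerated mock above follows the new OCRContractTrackerDB contract, in which both methods take a context and SaveLatestRoundRequested additionally takes the DataSource it should write through. That is what lets contract_tracker.go save the round request and mark the broadcast consumed atomically. A condensed sketch of that pattern, reusing the names from the tracker diff (t, rr and lb are the tracker, the parsed round request and the log broadcast):

    // Both writes go through the same transactional DataSource, so neither
    // can land without the other.
    err := sqlutil.TransactDataSource(ctx, t.ds, nil, func(tx sqlutil.DataSource) error {
        if err := t.ocrDB.SaveLatestRoundRequested(ctx, tx, *rr); err != nil {
            return err
        }
        return t.logBroadcaster.MarkConsumed(ctx, tx, lb)
    })
    if err != nil {
        t.logger.Error(err)
    }

Non-transactional call sites simply pass nil for the DataSource, as in MarkConsumed(ctx, nil, lb) elsewhere in this diff.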
func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job, q pg.Queryer) error { +func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job) error { // If the job spec is malformed in any way, we report the error but return nil so that // the job deletion itself isn't blocked. @@ -295,13 +295,13 @@ func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job, q pg.Queryer) er } // we only have clean to do for the EVM if rid.Network == relay.EVM { - return d.cleanupEVM(ctx, jb, q, rid) + return d.cleanupEVM(ctx, jb, rid) } return nil } // cleanupEVM is a helper for clean up EVM specific state when a job is deleted -func (d *Delegate) cleanupEVM(ctx context.Context, jb job.Job, q pg.Queryer, relayID relay.ID) error { +func (d *Delegate) cleanupEVM(ctx context.Context, jb job.Job, relayID relay.ID) error { // If UnregisterFilter returns an // error, that means it failed to remove a valid active filter from the db. We do abort the job deletion // in that case, since it should be easy for the user to retry and will avoid leaving the db in diff --git a/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go b/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go index 647aaf59056..d8678844d25 100644 --- a/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go +++ b/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go @@ -43,7 +43,7 @@ func TestAdapter_Integration(t *testing.T) { require.NoError(t, err) keystore := keystore.NewInMemory(db, utils.FastScryptParams, logger, cfg.Database()) - pipelineORM := pipeline.NewORM(db, logger, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger, cfg.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := job.NewORM(db, pipelineORM, bridgesORM, keystore, logger, cfg.Database()) pr := pipeline.NewRunner( diff --git a/core/services/ocrbootstrap/delegate.go b/core/services/ocrbootstrap/delegate.go index 9ed7cbea477..2d87cf80346 100644 --- a/core/services/ocrbootstrap/delegate.go +++ b/core/services/ocrbootstrap/delegate.go @@ -19,7 +19,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/validate" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/relay" ) @@ -190,6 +189,6 @@ func (d *Delegate) AfterJobCreated(spec job.Job) { func (d *Delegate) BeforeJobDeleted(spec job.Job) {} // OnDeleteJob satisfies the job.Delegate interface. 
-func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } diff --git a/core/services/ocrcommon/run_saver.go b/core/services/ocrcommon/run_saver.go index 6d85aa857a4..52ffb31cea0 100644 --- a/core/services/ocrcommon/run_saver.go +++ b/core/services/ocrcommon/run_saver.go @@ -2,15 +2,16 @@ package ocrcommon import ( "context" + "time" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) type Runner interface { - InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error + InsertFinishedRun(ctx context.Context, ds sqlutil.DataSource, run *pipeline.Run, saveSuccessfulTaskRuns bool) error } type RunResultSaver struct { @@ -19,7 +20,7 @@ type RunResultSaver struct { maxSuccessfulRuns uint64 runResults chan *pipeline.Run pipelineRunner Runner - done chan struct{} + stopCh services.StopChan logger logger.Logger } @@ -36,7 +37,7 @@ func NewResultRunSaver(pipelineRunner Runner, maxSuccessfulRuns: maxSuccessfulRuns, runResults: make(chan *pipeline.Run, resultsWriteDepth), pipelineRunner: pipelineRunner, - done: make(chan struct{}), + stopCh: make(chan struct{}), logger: logger.Named("RunResultSaver"), } } @@ -55,6 +56,8 @@ func (r *RunResultSaver) Save(run *pipeline.Run) { func (r *RunResultSaver) Start(context.Context) error { return r.StartOnce("RunResultSaver", func() error { go func() { + ctx, cancel := r.stopCh.NewCtx() + defer cancel() for { select { case run := <-r.runResults: @@ -66,10 +69,10 @@ func (r *RunResultSaver) Start(context.Context) error { r.logger.Tracew("RunSaver: saving job run", "run", run) // We do not want save successful TaskRuns as OCR runs very frequently so a lot of records // are produced and the successful TaskRuns do not provide value. - if err := r.pipelineRunner.InsertFinishedRun(run, false); err != nil { + if err := r.pipelineRunner.InsertFinishedRun(ctx, nil, run, false); err != nil { r.logger.Errorw("error inserting finished results", "err", err) } - case <-r.done: + case <-r.stopCh: return } } @@ -80,7 +83,10 @@ func (r *RunResultSaver) Start(context.Context) error { func (r *RunResultSaver) Close() error { return r.StopOnce("RunResultSaver", func() error { - r.done <- struct{}{} + close(r.stopCh) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() // In the unlikely event that there are remaining runResults to write, // drain the channel and save them. @@ -88,7 +94,7 @@ func (r *RunResultSaver) Close() error { select { case run := <-r.runResults: r.logger.Infow("RunSaver: saving job run before exiting", "run", run) - if err := r.pipelineRunner.InsertFinishedRun(run, false); err != nil { + if err := r.pipelineRunner.InsertFinishedRun(ctx, nil, run, false); err != nil { r.logger.Errorw("error inserting finished results", "err", err) } default: diff --git a/core/services/ocrcommon/run_saver_test.go b/core/services/ocrcommon/run_saver_test.go index 7bfe60f2a06..a965792ca1f 100644 --- a/core/services/ocrcommon/run_saver_test.go +++ b/core/services/ocrcommon/run_saver_test.go @@ -25,7 +25,7 @@ func TestRunSaver(t *testing.T) { pipelineRunner.On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). 
Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = int64(d) + args.Get(2).(*pipeline.Run).ID = int64(d) }). Once() rs.Save(&pipeline.Run{ID: int64(i)}) diff --git a/core/services/pipeline/helpers_test.go b/core/services/pipeline/helpers_test.go index 9ee2dc693f2..0bbdef7a7f2 100644 --- a/core/services/pipeline/helpers_test.go +++ b/core/services/pipeline/helpers_test.go @@ -5,6 +5,7 @@ import ( "github.com/google/uuid" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" ) @@ -63,3 +64,5 @@ func (t *ETHTxTask) HelperSetDependencies(legacyChains legacyevm.LegacyChainCont t.specGasLimit = specGasLimit t.jobType = jobType } + +func (o *orm) Prune(ds sqlutil.DataSource, pipelineSpecID int32) { o.prune(ds, pipelineSpecID) } diff --git a/core/services/pipeline/mocks/orm.go b/core/services/pipeline/mocks/orm.go index b06041767a1..fe9aa2823a4 100644 --- a/core/services/pipeline/mocks/orm.go +++ b/core/services/pipeline/mocks/orm.go @@ -8,10 +8,10 @@ import ( models "github.com/smartcontractkit/chainlink/v2/core/store/models" mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" - pipeline "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + time "time" uuid "github.com/google/uuid" @@ -40,24 +40,17 @@ func (_m *ORM) Close() error { return r0 } -// CreateRun provides a mock function with given fields: run, qopts -func (_m *ORM) CreateRun(run *pipeline.Run, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateRun provides a mock function with given fields: ctx, run +func (_m *ORM) CreateRun(ctx context.Context, run *pipeline.Run) error { + ret := _m.Called(ctx, run) if len(ret) == 0 { panic("no return value specified for CreateRun") } var r0 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, ...pg.QOpt) error); ok { - r0 = rf(run, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run) error); ok { + r0 = rf(ctx, run) } else { r0 = ret.Error(0) } @@ -65,16 +58,9 @@ func (_m *ORM) CreateRun(run *pipeline.Run, qopts ...pg.QOpt) error { return r0 } -// CreateSpec provides a mock function with given fields: _a0, maxTaskTimeout, qopts -func (_m *ORM) CreateSpec(_a0 pipeline.Pipeline, maxTaskTimeout models.Interval, qopts ...pg.QOpt) (int32, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0, maxTaskTimeout) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateSpec provides a mock function with given fields: ctx, ds, _a2, maxTaskTimeout +func (_m *ORM) CreateSpec(ctx context.Context, ds pipeline.CreateDataSource, _a2 pipeline.Pipeline, maxTaskTimeout models.Interval) (int32, error) { + ret := _m.Called(ctx, ds, _a2, maxTaskTimeout) if len(ret) == 0 { panic("no return value specified for CreateSpec") @@ -82,17 +68,17 @@ func (_m *ORM) CreateSpec(_a0 pipeline.Pipeline, maxTaskTimeout models.Interval, var r0 int32 var r1 error - if rf, ok := ret.Get(0).(func(pipeline.Pipeline, models.Interval, ...pg.QOpt) (int32, error)); ok { - return rf(_a0, maxTaskTimeout, qopts...) 
+ if rf, ok := ret.Get(0).(func(context.Context, pipeline.CreateDataSource, pipeline.Pipeline, models.Interval) (int32, error)); ok { + return rf(ctx, ds, _a2, maxTaskTimeout) } - if rf, ok := ret.Get(0).(func(pipeline.Pipeline, models.Interval, ...pg.QOpt) int32); ok { - r0 = rf(_a0, maxTaskTimeout, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, pipeline.CreateDataSource, pipeline.Pipeline, models.Interval) int32); ok { + r0 = rf(ctx, ds, _a2, maxTaskTimeout) } else { r0 = ret.Get(0).(int32) } - if rf, ok := ret.Get(1).(func(pipeline.Pipeline, models.Interval, ...pg.QOpt) error); ok { - r1 = rf(_a0, maxTaskTimeout, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, pipeline.CreateDataSource, pipeline.Pipeline, models.Interval) error); ok { + r1 = rf(ctx, ds, _a2, maxTaskTimeout) } else { r1 = ret.Error(1) } @@ -100,17 +86,37 @@ func (_m *ORM) CreateSpec(_a0 pipeline.Pipeline, maxTaskTimeout models.Interval, return r0, r1 } -// DeleteRun provides a mock function with given fields: id -func (_m *ORM) DeleteRun(id int64) error { - ret := _m.Called(id) +// DataSource provides a mock function with given fields: +func (_m *ORM) DataSource() sqlutil.DataSource { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for DataSource") + } + + var r0 sqlutil.DataSource + if rf, ok := ret.Get(0).(func() sqlutil.DataSource); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(sqlutil.DataSource) + } + } + + return r0 +} + +// DeleteRun provides a mock function with given fields: ctx, id +func (_m *ORM) DeleteRun(ctx context.Context, id int64) error { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for DeleteRun") } var r0 error - if rf, ok := ret.Get(0).(func(int64) error); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -136,9 +142,9 @@ func (_m *ORM) DeleteRunsOlderThan(_a0 context.Context, _a1 time.Duration) error return r0 } -// FindRun provides a mock function with given fields: id -func (_m *ORM) FindRun(id int64) (pipeline.Run, error) { - ret := _m.Called(id) +// FindRun provides a mock function with given fields: ctx, id +func (_m *ORM) FindRun(ctx context.Context, id int64) (pipeline.Run, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for FindRun") @@ -146,17 +152,17 @@ func (_m *ORM) FindRun(id int64) (pipeline.Run, error) { var r0 pipeline.Run var r1 error - if rf, ok := ret.Get(0).(func(int64) (pipeline.Run, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) (pipeline.Run, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64) pipeline.Run); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) pipeline.Run); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(pipeline.Run) } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -164,9 +170,9 @@ func (_m *ORM) FindRun(id int64) (pipeline.Run, error) { return r0, r1 } -// GetAllRuns provides a mock function with given fields: -func (_m *ORM) GetAllRuns() ([]pipeline.Run, error) { - ret := _m.Called() +// GetAllRuns provides a mock function with given fields: ctx +func (_m *ORM) GetAllRuns(ctx context.Context) ([]pipeline.Run, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified 
for GetAllRuns") @@ -174,19 +180,19 @@ func (_m *ORM) GetAllRuns() ([]pipeline.Run, error) { var r0 []pipeline.Run var r1 error - if rf, ok := ret.Get(0).(func() ([]pipeline.Run, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]pipeline.Run, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []pipeline.Run); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []pipeline.Run); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]pipeline.Run) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -194,24 +200,6 @@ func (_m *ORM) GetAllRuns() ([]pipeline.Run, error) { return r0, r1 } -// GetQ provides a mock function with given fields: -func (_m *ORM) GetQ() pg.Q { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for GetQ") - } - - var r0 pg.Q - if rf, ok := ret.Get(0).(func() pg.Q); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(pg.Q) - } - - return r0 -} - // GetUnfinishedRuns provides a mock function with given fields: _a0, _a1, _a2 func (_m *ORM) GetUnfinishedRuns(_a0 context.Context, _a1 time.Time, _a2 func(pipeline.Run) error) error { ret := _m.Called(_a0, _a1, _a2) @@ -250,24 +238,17 @@ func (_m *ORM) HealthReport() map[string]error { return r0 } -// InsertFinishedRun provides a mock function with given fields: run, saveSuccessfulTaskRuns, qopts -func (_m *ORM) InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run, saveSuccessfulTaskRuns) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertFinishedRun provides a mock function with given fields: ctx, run, saveSuccessfulTaskRuns +func (_m *ORM) InsertFinishedRun(ctx context.Context, run *pipeline.Run, saveSuccessfulTaskRuns bool) error { + ret := _m.Called(ctx, run, saveSuccessfulTaskRuns) if len(ret) == 0 { panic("no return value specified for InsertFinishedRun") } var r0 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, bool, ...pg.QOpt) error); ok { - r0 = rf(run, saveSuccessfulTaskRuns, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, bool) error); ok { + r0 = rf(ctx, run, saveSuccessfulTaskRuns) } else { r0 = ret.Error(0) } @@ -275,24 +256,17 @@ func (_m *ORM) InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bool, return r0 } -// InsertFinishedRunWithSpec provides a mock function with given fields: run, saveSuccessfulTaskRuns, qopts -func (_m *ORM) InsertFinishedRunWithSpec(run *pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run, saveSuccessfulTaskRuns) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertFinishedRunWithSpec provides a mock function with given fields: ctx, run, saveSuccessfulTaskRuns +func (_m *ORM) InsertFinishedRunWithSpec(ctx context.Context, run *pipeline.Run, saveSuccessfulTaskRuns bool) error { + ret := _m.Called(ctx, run, saveSuccessfulTaskRuns) if len(ret) == 0 { panic("no return value specified for InsertFinishedRunWithSpec") } var r0 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, bool, ...pg.QOpt) error); ok { - r0 = rf(run, saveSuccessfulTaskRuns, qopts...) 
+ if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, bool) error); ok { + r0 = rf(ctx, run, saveSuccessfulTaskRuns) } else { r0 = ret.Error(0) } @@ -300,24 +274,17 @@ func (_m *ORM) InsertFinishedRunWithSpec(run *pipeline.Run, saveSuccessfulTaskRu return r0 } -// InsertFinishedRuns provides a mock function with given fields: run, saveSuccessfulTaskRuns, qopts -func (_m *ORM) InsertFinishedRuns(run []*pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run, saveSuccessfulTaskRuns) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertFinishedRuns provides a mock function with given fields: ctx, run, saveSuccessfulTaskRuns +func (_m *ORM) InsertFinishedRuns(ctx context.Context, run []*pipeline.Run, saveSuccessfulTaskRuns bool) error { + ret := _m.Called(ctx, run, saveSuccessfulTaskRuns) if len(ret) == 0 { panic("no return value specified for InsertFinishedRuns") } var r0 error - if rf, ok := ret.Get(0).(func([]*pipeline.Run, bool, ...pg.QOpt) error); ok { - r0 = rf(run, saveSuccessfulTaskRuns, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []*pipeline.Run, bool) error); ok { + r0 = rf(ctx, run, saveSuccessfulTaskRuns) } else { r0 = ret.Error(0) } @@ -325,24 +292,17 @@ func (_m *ORM) InsertFinishedRuns(run []*pipeline.Run, saveSuccessfulTaskRuns bo return r0 } -// InsertRun provides a mock function with given fields: run, qopts -func (_m *ORM) InsertRun(run *pipeline.Run, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertRun provides a mock function with given fields: ctx, run +func (_m *ORM) InsertRun(ctx context.Context, run *pipeline.Run) error { + ret := _m.Called(ctx, run) if len(ret) == 0 { panic("no return value specified for InsertRun") } var r0 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, ...pg.QOpt) error); ok { - r0 = rf(run, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run) error); ok { + r0 = rf(ctx, run) } else { r0 = ret.Error(0) } @@ -404,16 +364,9 @@ func (_m *ORM) Start(_a0 context.Context) error { return r0 } -// StoreRun provides a mock function with given fields: run, qopts -func (_m *ORM) StoreRun(run *pipeline.Run, qopts ...pg.QOpt) (bool, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// StoreRun provides a mock function with given fields: ctx, run +func (_m *ORM) StoreRun(ctx context.Context, run *pipeline.Run) (bool, error) { + ret := _m.Called(ctx, run) if len(ret) == 0 { panic("no return value specified for StoreRun") @@ -421,17 +374,17 @@ func (_m *ORM) StoreRun(run *pipeline.Run, qopts ...pg.QOpt) (bool, error) { var r0 bool var r1 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, ...pg.QOpt) (bool, error)); ok { - return rf(run, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run) (bool, error)); ok { + return rf(ctx, run) } - if rf, ok := ret.Get(0).(func(*pipeline.Run, ...pg.QOpt) bool); ok { - r0 = rf(run, qopts...) 
+ if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run) bool); ok { + r0 = rf(ctx, run) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(*pipeline.Run, ...pg.QOpt) error); ok { - r1 = rf(run, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, *pipeline.Run) error); ok { + r1 = rf(ctx, run) } else { r1 = ret.Error(1) } @@ -439,9 +392,27 @@ func (_m *ORM) StoreRun(run *pipeline.Run, qopts ...pg.QOpt) (bool, error) { return r0, r1 } -// UpdateTaskRunResult provides a mock function with given fields: taskID, result -func (_m *ORM) UpdateTaskRunResult(taskID uuid.UUID, result pipeline.Result) (pipeline.Run, bool, error) { - ret := _m.Called(taskID, result) +// Transact provides a mock function with given fields: _a0, _a1 +func (_m *ORM) Transact(_a0 context.Context, _a1 func(pipeline.ORM) error) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Transact") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, func(pipeline.ORM) error) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateTaskRunResult provides a mock function with given fields: ctx, taskID, result +func (_m *ORM) UpdateTaskRunResult(ctx context.Context, taskID uuid.UUID, result pipeline.Result) (pipeline.Run, bool, error) { + ret := _m.Called(ctx, taskID, result) if len(ret) == 0 { panic("no return value specified for UpdateTaskRunResult") @@ -450,23 +421,23 @@ func (_m *ORM) UpdateTaskRunResult(taskID uuid.UUID, result pipeline.Result) (pi var r0 pipeline.Run var r1 bool var r2 error - if rf, ok := ret.Get(0).(func(uuid.UUID, pipeline.Result) (pipeline.Run, bool, error)); ok { - return rf(taskID, result) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID, pipeline.Result) (pipeline.Run, bool, error)); ok { + return rf(ctx, taskID, result) } - if rf, ok := ret.Get(0).(func(uuid.UUID, pipeline.Result) pipeline.Run); ok { - r0 = rf(taskID, result) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID, pipeline.Result) pipeline.Run); ok { + r0 = rf(ctx, taskID, result) } else { r0 = ret.Get(0).(pipeline.Run) } - if rf, ok := ret.Get(1).(func(uuid.UUID, pipeline.Result) bool); ok { - r1 = rf(taskID, result) + if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID, pipeline.Result) bool); ok { + r1 = rf(ctx, taskID, result) } else { r1 = ret.Get(1).(bool) } - if rf, ok := ret.Get(2).(func(uuid.UUID, pipeline.Result) error); ok { - r2 = rf(taskID, result) + if rf, ok := ret.Get(2).(func(context.Context, uuid.UUID, pipeline.Result) error); ok { + r2 = rf(ctx, taskID, result) } else { r2 = ret.Error(2) } @@ -474,6 +445,26 @@ func (_m *ORM) UpdateTaskRunResult(taskID uuid.UUID, result pipeline.Result) (pi return r0, r1, r2 } +// WithDataSource provides a mock function with given fields: _a0 +func (_m *ORM) WithDataSource(_a0 sqlutil.DataSource) pipeline.ORM { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for WithDataSource") + } + + var r0 pipeline.ORM + if rf, ok := ret.Get(0).(func(sqlutil.DataSource) pipeline.ORM); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(pipeline.ORM) + } + } + + return r0 +} + // NewORM creates a new instance of ORM. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. 
func NewORM(t interface { diff --git a/core/services/pipeline/mocks/runner.go b/core/services/pipeline/mocks/runner.go index 3de2703f0c7..e0378399f58 100644 --- a/core/services/pipeline/mocks/runner.go +++ b/core/services/pipeline/mocks/runner.go @@ -8,10 +8,10 @@ import ( logger "github.com/smartcontractkit/chainlink/v2/core/logger" mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" - pipeline "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + uuid "github.com/google/uuid" ) @@ -164,24 +164,17 @@ func (_m *Runner) InitializePipeline(spec pipeline.Spec) (*pipeline.Pipeline, er return r0, r1 } -// InsertFinishedRun provides a mock function with given fields: run, saveSuccessfulTaskRuns, qopts -func (_m *Runner) InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run, saveSuccessfulTaskRuns) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertFinishedRun provides a mock function with given fields: ctx, ds, run, saveSuccessfulTaskRuns +func (_m *Runner) InsertFinishedRun(ctx context.Context, ds sqlutil.DataSource, run *pipeline.Run, saveSuccessfulTaskRuns bool) error { + ret := _m.Called(ctx, ds, run, saveSuccessfulTaskRuns) if len(ret) == 0 { panic("no return value specified for InsertFinishedRun") } var r0 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, bool, ...pg.QOpt) error); ok { - r0 = rf(run, saveSuccessfulTaskRuns, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, sqlutil.DataSource, *pipeline.Run, bool) error); ok { + r0 = rf(ctx, ds, run, saveSuccessfulTaskRuns) } else { r0 = ret.Error(0) } @@ -189,24 +182,17 @@ func (_m *Runner) InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bo return r0 } -// InsertFinishedRuns provides a mock function with given fields: runs, saveSuccessfulTaskRuns, qopts -func (_m *Runner) InsertFinishedRuns(runs []*pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, runs, saveSuccessfulTaskRuns) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertFinishedRuns provides a mock function with given fields: ctx, ds, runs, saveSuccessfulTaskRuns +func (_m *Runner) InsertFinishedRuns(ctx context.Context, ds sqlutil.DataSource, runs []*pipeline.Run, saveSuccessfulTaskRuns bool) error { + ret := _m.Called(ctx, ds, runs, saveSuccessfulTaskRuns) if len(ret) == 0 { panic("no return value specified for InsertFinishedRuns") } var r0 error - if rf, ok := ret.Get(0).(func([]*pipeline.Run, bool, ...pg.QOpt) error); ok { - r0 = rf(runs, saveSuccessfulTaskRuns, qopts...) 
+ if rf, ok := ret.Get(0).(func(context.Context, sqlutil.DataSource, []*pipeline.Run, bool) error); ok { + r0 = rf(ctx, ds, runs, saveSuccessfulTaskRuns) } else { r0 = ret.Error(0) } @@ -255,17 +241,17 @@ func (_m *Runner) Ready() error { return r0 } -// ResumeRun provides a mock function with given fields: taskID, value, err -func (_m *Runner) ResumeRun(taskID uuid.UUID, value interface{}, err error) error { - ret := _m.Called(taskID, value, err) +// ResumeRun provides a mock function with given fields: ctx, taskID, value, err +func (_m *Runner) ResumeRun(ctx context.Context, taskID uuid.UUID, value interface{}, err error) error { + ret := _m.Called(ctx, taskID, value, err) if len(ret) == 0 { panic("no return value specified for ResumeRun") } var r0 error - if rf, ok := ret.Get(0).(func(uuid.UUID, interface{}, error) error); ok { - r0 = rf(taskID, value, err) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID, interface{}, error) error); ok { + r0 = rf(ctx, taskID, value, err) } else { r0 = ret.Error(0) } @@ -274,7 +260,7 @@ func (_m *Runner) ResumeRun(taskID uuid.UUID, value interface{}, err error) erro } // Run provides a mock function with given fields: ctx, run, l, saveSuccessfulTaskRuns, fn -func (_m *Runner) Run(ctx context.Context, run *pipeline.Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(pg.Queryer) error) (bool, error) { +func (_m *Runner) Run(ctx context.Context, run *pipeline.Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(sqlutil.DataSource) error) (bool, error) { ret := _m.Called(ctx, run, l, saveSuccessfulTaskRuns, fn) if len(ret) == 0 { @@ -283,16 +269,16 @@ func (_m *Runner) Run(ctx context.Context, run *pipeline.Run, l logger.Logger, s var r0 bool var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(pg.Queryer) error) (bool, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(sqlutil.DataSource) error) (bool, error)); ok { return rf(ctx, run, l, saveSuccessfulTaskRuns, fn) } - if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(pg.Queryer) error) bool); ok { + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(sqlutil.DataSource) error) bool); ok { r0 = rf(ctx, run, l, saveSuccessfulTaskRuns, fn) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(pg.Queryer) error) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(sqlutil.DataSource) error) error); ok { r1 = rf(ctx, run, l, saveSuccessfulTaskRuns, fn) } else { r1 = ret.Error(1) diff --git a/core/services/pipeline/orm.go b/core/services/pipeline/orm.go index c32693e4db4..3bebfb8cbad 100644 --- a/core/services/pipeline/orm.go +++ b/core/services/pipeline/orm.go @@ -14,6 +14,7 @@ import ( "github.com/jmoiron/sqlx" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/pg" @@ -71,33 +72,42 @@ const KeepersObservationSource = ` encode_check_upkeep_tx -> check_upkeep_tx -> decode_check_upkeep_tx -> calculate_perform_data_len -> perform_data_lessthan_limit -> check_perform_data_limit -> encode_perform_upkeep_tx -> simulate_perform_upkeep_tx -> decode_check_perform_tx -> check_success -> perform_upkeep_tx ` +type 
CreateDataSource interface { + GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error +} + //go:generate mockery --quiet --name ORM --output ./mocks/ --case=underscore type ORM interface { services.Service - CreateSpec(pipeline Pipeline, maxTaskTimeout models.Interval, qopts ...pg.QOpt) (int32, error) - CreateRun(run *Run, qopts ...pg.QOpt) (err error) - InsertRun(run *Run, qopts ...pg.QOpt) error - DeleteRun(id int64) error - StoreRun(run *Run, qopts ...pg.QOpt) (restart bool, err error) - UpdateTaskRunResult(taskID uuid.UUID, result Result) (run Run, start bool, err error) - InsertFinishedRun(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) (err error) - InsertFinishedRunWithSpec(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) (err error) + + // ds is optional and to be removed after completing https://smartcontract-it.atlassian.net/browse/BCF-2978 + CreateSpec(ctx context.Context, ds CreateDataSource, pipeline Pipeline, maxTaskTimeout models.Interval) (int32, error) + CreateRun(ctx context.Context, run *Run) (err error) + InsertRun(ctx context.Context, run *Run) error + DeleteRun(ctx context.Context, id int64) error + StoreRun(ctx context.Context, run *Run) (restart bool, err error) + UpdateTaskRunResult(ctx context.Context, taskID uuid.UUID, result Result) (run Run, start bool, err error) + InsertFinishedRun(ctx context.Context, run *Run, saveSuccessfulTaskRuns bool) (err error) + InsertFinishedRunWithSpec(ctx context.Context, run *Run, saveSuccessfulTaskRuns bool) (err error) // InsertFinishedRuns inserts all the given runs into the database. // If saveSuccessfulTaskRuns is false, only errored runs are saved. - InsertFinishedRuns(run []*Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) (err error) + InsertFinishedRuns(ctx context.Context, run []*Run, saveSuccessfulTaskRuns bool) (err error) DeleteRunsOlderThan(context.Context, time.Duration) error - FindRun(id int64) (Run, error) - GetAllRuns() ([]Run, error) + FindRun(ctx context.Context, id int64) (Run, error) + GetAllRuns(ctx context.Context) ([]Run, error) GetUnfinishedRuns(context.Context, time.Time, func(run Run) error) error - GetQ() pg.Q + + DataSource() sqlutil.DataSource + WithDataSource(sqlutil.DataSource) ORM + Transact(context.Context, func(ORM) error) error } type orm struct { services.StateMachine - q pg.Q + ds sqlutil.DataSource lggr logger.Logger maxSuccessfulRuns uint64 // jobID => count @@ -109,17 +119,14 @@ type orm struct { var _ ORM = (*orm)(nil) -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, jobPipelineMaxSuccessfulRuns uint64) *orm { +func NewORM(ds sqlutil.DataSource, lggr logger.Logger, jobPipelineMaxSuccessfulRuns uint64) *orm { ctx, cancel := context.WithCancel(context.Background()) return &orm{ - services.StateMachine{}, - pg.NewQ(db, lggr, cfg), - lggr.Named("PipelineORM"), - jobPipelineMaxSuccessfulRuns, - sync.Map{}, - sync.WaitGroup{}, - ctx, - cancel, + ds: ds, + lggr: lggr.Named("PipelineORM"), + maxSuccessfulRuns: jobPipelineMaxSuccessfulRuns, + ctx: ctx, + cncl: cancel, } } @@ -152,23 +159,56 @@ func (o *orm) HealthReport() map[string]error { return map[string]error{o.Name(): o.Healthy()} } -func (o *orm) CreateSpec(pipeline Pipeline, maxTaskDuration models.Interval, qopts ...pg.QOpt) (id int32, err error) { - q := o.q.WithOpts(qopts...) 
+func (o *orm) Transact(ctx context.Context, fn func(ORM) error) error { + return sqlutil.Transact(ctx, func(tx sqlutil.DataSource) ORM { + return o.withDataSource(tx) + }, o.ds, nil, func(tx ORM) error { + defer func() { + if err := tx.Close(); err != nil { + o.lggr.Warnw("Error closing temporary transactional ORM", "err", err) + } + }() + return fn(tx) + }) +} + +func (o *orm) DataSource() sqlutil.DataSource { return o.ds } + +func (o *orm) WithDataSource(ds sqlutil.DataSource) ORM { return o.withDataSource(ds) } + +func (o *orm) withDataSource(ds sqlutil.DataSource) *orm { + ctx, cancel := context.WithCancel(context.Background()) + return &orm{ + ds: ds, + lggr: o.lggr, + maxSuccessfulRuns: o.maxSuccessfulRuns, + ctx: ctx, + cncl: cancel, + } +} + +func (o *orm) transact(ctx context.Context, fn func(*orm) error) error { + return sqlutil.Transact(ctx, o.withDataSource, o.ds, nil, fn) +} + +func (o *orm) CreateSpec(ctx context.Context, ds CreateDataSource, pipeline Pipeline, maxTaskDuration models.Interval) (id int32, err error) { sql := `INSERT INTO pipeline_specs (dot_dag_source, max_task_duration, created_at) VALUES ($1, $2, NOW()) RETURNING id;` - err = q.Get(&id, sql, pipeline.Source, maxTaskDuration) + if ds == nil { + ds = o.ds + } + err = ds.GetContext(ctx, &id, sql, pipeline.Source, maxTaskDuration) return id, errors.WithStack(err) } -func (o *orm) CreateRun(run *Run, qopts ...pg.QOpt) (err error) { +func (o *orm) CreateRun(ctx context.Context, run *Run) (err error) { if run.CreatedAt.IsZero() { return errors.New("run.CreatedAt must be set") } - q := o.q.WithOpts(qopts...) - err = q.Transaction(func(tx pg.Queryer) error { - if e := o.InsertRun(run, pg.WithQueryer(tx)); e != nil { + err = o.transact(ctx, func(tx *orm) error { + if e := tx.InsertRun(ctx, run); e != nil { return errors.Wrap(e, "error inserting pipeline_run") } @@ -182,10 +222,9 @@ func (o *orm) CreateRun(run *Run, qopts ...pg.QOpt) (err error) { run.PipelineTaskRuns[i].PipelineRunID = run.ID } - sql := ` - INSERT INTO pipeline_task_runs (pipeline_run_id, id, type, index, output, error, dot_id, created_at) + sql := `INSERT INTO pipeline_task_runs (pipeline_run_id, id, type, index, output, error, dot_id, created_at) VALUES (:pipeline_run_id, :id, :type, :index, :output, :error, :dot_id, :created_at);` - _, err = tx.NamedExec(sql, run.PipelineTaskRuns) + _, err = tx.ds.NamedExecContext(ctx, sql, run.PipelineTaskRuns) return err }) @@ -193,33 +232,34 @@ func (o *orm) CreateRun(run *Run, qopts ...pg.QOpt) (err error) { } // InsertRun inserts a run into the database -func (o *orm) InsertRun(run *Run, qopts ...pg.QOpt) error { +func (o *orm) InsertRun(ctx context.Context, run *Run) error { if run.Status() == RunStatusCompleted { - defer o.Prune(o.q, run.PruningKey) + defer o.prune(o.ds, run.PruningKey) } - q := o.q.WithOpts(qopts...) - sql := `INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) + query, args, err := o.ds.BindNamed(`INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) VALUES (:pipeline_spec_id, :pruning_key, :meta, :all_errors, :fatal_errors, :inputs, :outputs, :created_at, :finished_at, :state) - RETURNING *;` - return q.GetNamed(sql, run, run) + RETURNING *;`, run) + if err != nil { + return fmt.Errorf("error binding arg: %w", err) + } + return o.ds.GetContext(ctx, run, query, args...) 
} // StoreRun will persist a partially executed run before suspending, or finish a run. // If `restart` is true, then new task run data is available and the run should be resumed immediately. -func (o *orm) StoreRun(run *Run, qopts ...pg.QOpt) (restart bool, err error) { - q := o.q.WithOpts(qopts...) - err = q.Transaction(func(tx pg.Queryer) error { +func (o *orm) StoreRun(ctx context.Context, run *Run) (restart bool, err error) { + err = o.transact(ctx, func(tx *orm) error { finished := run.FinishedAt.Valid if !finished { // Lock the current run. This prevents races with /v2/resume sql := `SELECT id FROM pipeline_runs WHERE id = $1 FOR UPDATE;` - if _, err = tx.Exec(sql, run.ID); err != nil { + if _, err = tx.ds.ExecContext(ctx, sql, run.ID); err != nil { return errors.Wrap(err, "StoreRun") } taskRuns := []TaskRun{} // Reload task runs, we want to check for any changes while the run was ongoing - if err = sqlx.Select(tx, &taskRuns, `SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = $1`, run.ID); err != nil { + if err = tx.ds.SelectContext(ctx, &taskRuns, `SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = $1`, run.ID); err != nil { return errors.Wrap(err, "StoreRun") } @@ -246,17 +286,17 @@ func (o *orm) StoreRun(run *Run, qopts ...pg.QOpt) (restart bool, err error) { // Suspend the run run.State = RunStatusSuspended - if _, err = sqlx.NamedExec(tx, `UPDATE pipeline_runs SET state = :state WHERE id = :id`, run); err != nil { + if _, err = tx.ds.NamedExecContext(ctx, `UPDATE pipeline_runs SET state = :state WHERE id = :id`, run); err != nil { return errors.Wrap(err, "StoreRun") } } else { - defer o.Prune(tx, run.PruningKey) + defer o.prune(tx.ds, run.PruningKey) // Simply finish the run, no need to do any sort of locking if run.Outputs.Val == nil || len(run.FatalErrors)+len(run.AllErrors) == 0 { return errors.Errorf("run must have both Outputs and Errors, got Outputs: %#v, FatalErrors: %#v, AllErrors: %#v", run.Outputs.Val, run.FatalErrors, run.AllErrors) } sql := `UPDATE pipeline_runs SET state = :state, finished_at = :finished_at, all_errors= :all_errors, fatal_errors= :fatal_errors, outputs = :outputs WHERE id = :id` - if _, err = sqlx.NamedExec(tx, sql, run); err != nil { + if _, err = tx.ds.NamedExecContext(ctx, sql, run); err != nil { return errors.Wrap(err, "StoreRun") } } @@ -272,7 +312,7 @@ func (o *orm) StoreRun(run *Run, qopts ...pg.QOpt) (restart bool, err error) { // NOTE: can't use Select() to auto scan because we're using NamedQuery, // sqlx.Named + Select is possible but it's about the same amount of code var rows *sqlx.Rows - rows, err = sqlx.NamedQuery(tx, sql, run.PipelineTaskRuns) + rows, err = sqlx.NamedQueryContext(ctx, tx.ds, sql, run.PipelineTaskRuns) if err != nil { return errors.Wrap(err, "StoreRun") } @@ -288,17 +328,17 @@ func (o *orm) StoreRun(run *Run, qopts ...pg.QOpt) (restart bool, err error) { } // DeleteRun cleans up a run that failed and is marked failEarly (should leave no trace of the run) -func (o *orm) DeleteRun(id int64) error { +func (o *orm) DeleteRun(ctx context.Context, id int64) error { // NOTE: this will cascade and wipe pipeline_task_runs too - _, err := o.q.Exec(`DELETE FROM pipeline_runs WHERE id = $1`, id) + _, err := o.ds.ExecContext(ctx, `DELETE FROM pipeline_runs WHERE id = $1`, id) return err } -func (o *orm) UpdateTaskRunResult(taskID uuid.UUID, result Result) (run Run, start bool, err error) { +func (o *orm) UpdateTaskRunResult(ctx context.Context, taskID uuid.UUID, result Result) (run Run, start bool, err error) { if 
result.OutputDB().Valid && result.ErrorDB().Valid { panic("run result must specify either output or error, not both") } - err = o.q.Transaction(func(tx pg.Queryer) error { + err = o.transact(ctx, func(tx *orm) error { sql := ` SELECT pipeline_runs.*, pipeline_specs.dot_dag_source "pipeline_spec.dot_dag_source", job_pipeline_specs.job_id "job_id" FROM pipeline_runs @@ -307,13 +347,13 @@ func (o *orm) UpdateTaskRunResult(taskID uuid.UUID, result Result) (run Run, sta JOIN job_pipeline_specs ON (job_pipeline_specs.pipeline_spec_id = pipeline_specs.id) WHERE pipeline_task_runs.id = $1 AND pipeline_runs.state in ('running', 'suspended') FOR UPDATE` - if err = tx.Get(&run, sql, taskID); err != nil { + if err = tx.ds.GetContext(ctx, &run, sql, taskID); err != nil { return fmt.Errorf("failed to find pipeline run for task ID %s: %w", taskID.String(), err) } // Update the task with result sql = `UPDATE pipeline_task_runs SET output = $2, error = $3, finished_at = $4 WHERE id = $1` - if _, err = tx.Exec(sql, taskID, result.OutputDB(), result.ErrorDB(), time.Now()); err != nil { + if _, err = tx.ds.ExecContext(ctx, sql, taskID, result.OutputDB(), result.ErrorDB(), time.Now()); err != nil { return fmt.Errorf("failed to update pipeline task run: %w", err) } @@ -322,21 +362,20 @@ func (o *orm) UpdateTaskRunResult(taskID uuid.UUID, result Result) (run Run, sta run.State = RunStatusRunning sql = `UPDATE pipeline_runs SET state = $2 WHERE id = $1` - if _, err = tx.Exec(sql, run.ID, run.State); err != nil { + if _, err = tx.ds.ExecContext(ctx, sql, run.ID, run.State); err != nil { return fmt.Errorf("failed to update pipeline run state: %w", err) } } - return loadAssociations(tx, []*Run{&run}) + return loadAssociations(ctx, tx.ds, []*Run{&run}) }) return run, start, err } // InsertFinishedRuns inserts all the given runs into the database. -func (o *orm) InsertFinishedRuns(runs []*Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) - err := q.Transaction(func(tx pg.Queryer) error { +func (o *orm) InsertFinishedRuns(ctx context.Context, runs []*Run, saveSuccessfulTaskRuns bool) error { + err := o.transact(ctx, func(tx *orm) error { pipelineRunsQuery := ` INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) @@ -344,7 +383,7 @@ VALUES (:pipeline_spec_id, :pruning_key, :meta, :all_errors, :fatal_errors, :inputs, :outputs, :created_at, :finished_at, :state) RETURNING id ` - rows, errQ := tx.NamedQuery(pipelineRunsQuery, runs) + rows, errQ := sqlx.NamedQueryContext(ctx, tx.ds, pipelineRunsQuery, runs) if errQ != nil { return errors.Wrap(errQ, "inserting finished pipeline runs") } @@ -369,7 +408,7 @@ RETURNING id defer func() { for pruningKey := range pruningKeysm { - o.Prune(tx, pruningKey) + o.prune(tx.ds, pruningKey) } }() @@ -385,7 +424,7 @@ VALUES (:pipeline_run_id, :id, :type, :index, :output, :error, :dot_id, :created pipelineTaskRuns = append(pipelineTaskRuns, run.PipelineTaskRuns...) } - _, errE := tx.NamedExec(pipelineTaskRunsQuery, pipelineTaskRuns) + _, errE := tx.ds.NamedExecContext(ctx, pipelineTaskRunsQuery, pipelineTaskRuns) return errors.Wrap(errE, "insert pipeline task runs") }) return errors.Wrap(err, "InsertFinishedRuns failed") @@ -411,7 +450,7 @@ func (o *orm) checkFinishedRun(run *Run, saveSuccessfulTaskRuns bool) error { // If saveSuccessfulTaskRuns = false, we only save errored runs. 
// That way if the job is run frequently (such as OCR) we avoid saving a large number of successful task runs // which do not provide much value. -func (o *orm) InsertFinishedRun(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) (err error) { +func (o *orm) InsertFinishedRun(ctx context.Context, run *Run, saveSuccessfulTaskRuns bool) (err error) { if err = o.checkFinishedRun(run, saveSuccessfulTaskRuns); err != nil { return err } @@ -421,13 +460,12 @@ func (o *orm) InsertFinishedRun(run *Run, saveSuccessfulTaskRuns bool, qopts ... return nil } - q := o.q.WithOpts(qopts...) - err = q.Transaction(o.insertFinishedRunTx(run, saveSuccessfulTaskRuns)) + err = o.insertFinishedRun(ctx, run, saveSuccessfulTaskRuns) return errors.Wrap(err, "InsertFinishedRun failed") } // InsertFinishedRunWithSpec works like InsertFinishedRun but also inserts the pipeline spec. -func (o *orm) InsertFinishedRunWithSpec(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) (err error) { +func (o *orm) InsertFinishedRunWithSpec(ctx context.Context, run *Run, saveSuccessfulTaskRuns bool) (err error) { if err = o.checkFinishedRun(run, saveSuccessfulTaskRuns); err != nil { return err } @@ -437,57 +475,55 @@ func (o *orm) InsertFinishedRunWithSpec(run *Run, saveSuccessfulTaskRuns bool, q return nil } - q := o.q.WithOpts(qopts...) - err = q.Transaction(func(tx pg.Queryer) error { + err = o.transact(ctx, func(tx *orm) error { sqlStmt1 := `INSERT INTO pipeline_specs (dot_dag_source, max_task_duration, created_at) VALUES ($1, $2, NOW()) RETURNING id;` - err = tx.Get(&run.PipelineSpecID, sqlStmt1, run.PipelineSpec.DotDagSource, run.PipelineSpec.MaxTaskDuration) + err = tx.ds.GetContext(ctx, &run.PipelineSpecID, sqlStmt1, run.PipelineSpec.DotDagSource, run.PipelineSpec.MaxTaskDuration) if err != nil { return errors.Wrap(err, "failed to insert pipeline_specs") } // This `job_pipeline_specs` record won't be primary since when this method is called, the job already exists, so it will have primary record. 
sqlStmt2 := `INSERT INTO job_pipeline_specs (job_id, pipeline_spec_id, is_primary) VALUES ($1, $2, false)` - _, err = tx.Exec(sqlStmt2, run.JobID, run.PipelineSpecID) + _, err = tx.ds.ExecContext(ctx, sqlStmt2, run.JobID, run.PipelineSpecID) if err != nil { return errors.Wrap(err, "failed to insert job_pipeline_specs") } - return o.insertFinishedRunTx(run, saveSuccessfulTaskRuns)(tx) + return tx.insertFinishedRun(ctx, run, saveSuccessfulTaskRuns) }) return errors.Wrap(err, "InsertFinishedRun failed") } -func (o *orm) insertFinishedRunTx(run *Run, saveSuccessfulTaskRuns bool) func(tx pg.Queryer) error { - return func(tx pg.Queryer) error { - sql := `INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) +func (o *orm) insertFinishedRun(ctx context.Context, run *Run, saveSuccessfulTaskRuns bool) error { + sql := `INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) VALUES (:pipeline_spec_id, :pruning_key, :meta, :all_errors, :fatal_errors, :inputs, :outputs, :created_at, :finished_at, :state) RETURNING id;` - query, args, e := tx.BindNamed(sql, run) - if e != nil { - return errors.Wrap(e, "failed to bind") - } + query, args, err := o.ds.BindNamed(sql, run) + if err != nil { + return errors.Wrap(err, "failed to bind") + } - if err := tx.QueryRowx(query, args...).Scan(&run.ID); err != nil { - return errors.Wrap(err, "error inserting finished pipeline_run") - } + if err = o.ds.QueryRowxContext(ctx, query, args...).Scan(&run.ID); err != nil { + return errors.Wrap(err, "error inserting finished pipeline_run") + } - // update the ID key everywhere - for i := range run.PipelineTaskRuns { - run.PipelineTaskRuns[i].PipelineRunID = run.ID - } + // update the ID key everywhere + for i := range run.PipelineTaskRuns { + run.PipelineTaskRuns[i].PipelineRunID = run.ID + } - if !saveSuccessfulTaskRuns && !run.HasErrors() { - return nil - } + if !saveSuccessfulTaskRuns && !run.HasErrors() { + return nil + } - defer o.Prune(tx, run.PruningKey) - sql = ` + defer o.prune(o.ds, run.PruningKey) + sql = ` INSERT INTO pipeline_task_runs (pipeline_run_id, id, type, index, output, error, dot_id, created_at, finished_at) VALUES (:pipeline_run_id, :id, :type, :index, :output, :error, :dot_id, :created_at, :finished_at);` - _, err := tx.NamedExec(sql, run.PipelineTaskRuns) - return errors.Wrap(err, "failed to insert pipeline_task_runs") - } + _, err = o.ds.NamedExecContext(ctx, sql, run.PipelineTaskRuns) + return errors.Wrap(err, "failed to insert pipeline_task_runs") + } // DeleteRunsOlderThan deletes all pipeline_runs that have been finished for a certain threshold to free DB space @@ -495,14 +531,12 @@ func (o *orm) insertFinishedRunTx(run *Run, saveSuccessfulTaskRuns bool) func(tx func (o *orm) DeleteRunsOlderThan(ctx context.Context, threshold time.Duration) error { start := time.Now() - q := o.q.WithOpts(pg.WithParentCtxInheritTimeout(ctx)) - queryThreshold := start.Add(-threshold) rowsDeleted := int64(0) err := pg.Batch(func(_, limit uint) (count uint, err error) { - result, cancel, err := q.ExecQIter(` + result, err := o.ds.ExecContext(ctx, ` WITH batched_pipeline_runs AS ( SELECT * FROM pipeline_runs WHERE finished_at < ($1) @@ -515,7 +549,6 @@ WHERE pipeline_runs.id = batched_pipeline_runs.id`, queryThreshold, limit, ) - defer cancel() if err != nil { return count, errors.Wrap(err, "DeleteRunsOlderThan failed to delete old pipeline_runs") } @@ 
-539,7 +572,7 @@ WHERE pipeline_runs.id = batched_pipeline_runs.id`, o.lggr.Debugw("pipeline_runs reaper VACUUM ANALYZE query completed", "duration", time.Since(start)) }(deleteTS) - err = q.ExecQ("VACUUM ANALYZE pipeline_runs") + _, err = o.ds.ExecContext(ctx, "VACUUM ANALYZE pipeline_runs") if err != nil { o.lggr.Warnw("DeleteRunsOlderThan successfully deleted old pipeline_runs rows, but failed to run VACUUM ANALYZE", "err", err) return nil @@ -548,13 +581,13 @@ WHERE pipeline_runs.id = batched_pipeline_runs.id`, return nil } -func (o *orm) FindRun(id int64) (r Run, err error) { +func (o *orm) FindRun(ctx context.Context, id int64) (r Run, err error) { var runs []*Run - err = o.q.Transaction(func(tx pg.Queryer) error { - if err = tx.Select(&runs, `SELECT * from pipeline_runs WHERE id = $1 LIMIT 1`, id); err != nil { + err = o.transact(ctx, func(tx *orm) error { + if err = tx.ds.SelectContext(ctx, &runs, `SELECT * from pipeline_runs WHERE id = $1 LIMIT 1`, id); err != nil { return errors.Wrap(err, "failed to load runs") } - return loadAssociations(tx, runs) + return loadAssociations(ctx, tx.ds, runs) }) if len(runs) == 0 { return r, sql.ErrNoRows @@ -562,15 +595,15 @@ func (o *orm) FindRun(id int64) (r Run, err error) { return *runs[0], err } -func (o *orm) GetAllRuns() (runs []Run, err error) { +func (o *orm) GetAllRuns(ctx context.Context) (runs []Run, err error) { var runsPtrs []*Run - err = o.q.Transaction(func(tx pg.Queryer) error { - err = tx.Select(&runsPtrs, `SELECT * from pipeline_runs ORDER BY created_at ASC, id ASC`) + err = o.transact(ctx, func(tx *orm) error { + err = tx.ds.SelectContext(ctx, &runsPtrs, `SELECT * from pipeline_runs ORDER BY created_at ASC, id ASC`) if err != nil { return errors.Wrap(err, "failed to load runs") } - return loadAssociations(tx, runsPtrs) + return loadAssociations(ctx, tx.ds, runsPtrs) }) runs = make([]Run, len(runsPtrs)) for i, runPtr := range runsPtrs { @@ -580,17 +613,16 @@ func (o *orm) GetAllRuns() (runs []Run, err error) { } func (o *orm) GetUnfinishedRuns(ctx context.Context, now time.Time, fn func(run Run) error) error { - q := o.q.WithOpts(pg.WithParentCtx(ctx)) return pg.Batch(func(offset, limit uint) (count uint, err error) { var runs []*Run - err = q.Transaction(func(tx pg.Queryer) error { - err = tx.Select(&runs, `SELECT * from pipeline_runs WHERE state = $1 AND created_at < $2 ORDER BY created_at ASC, id ASC OFFSET $3 LIMIT $4`, RunStatusRunning, now, offset, limit) + err = o.transact(ctx, func(tx *orm) error { + err = tx.ds.SelectContext(ctx, &runs, `SELECT * from pipeline_runs WHERE state = $1 AND created_at < $2 ORDER BY created_at ASC, id ASC OFFSET $3 LIMIT $4`, RunStatusRunning, now, offset, limit) if err != nil { return errors.Wrap(err, "failed to load runs") } - err = loadAssociations(tx, runs) + err = loadAssociations(ctx, tx.ds, runs) if err != nil { return err } @@ -608,7 +640,7 @@ func (o *orm) GetUnfinishedRuns(ctx context.Context, now time.Time, fn func(run } // loads PipelineSpec and PipelineTaskRuns for Runs in exactly 2 queries -func loadAssociations(q pg.Queryer, runs []*Run) error { +func loadAssociations(ctx context.Context, ds sqlutil.DataSource, runs []*Run) error { if len(runs) == 0 { return nil } @@ -635,7 +667,7 @@ func loadAssociations(q pg.Queryer, runs []*Run) error { LEFT JOIN job_pipeline_specs jps ON jps.pipeline_spec_id=ps.id LEFT JOIN jobs ON jobs.id=jps.job_id WHERE ps.id = ANY($1)` - if err := q.Select(&specs, sqlQuery, pipelineSpecIDs); err != nil { + if err := ds.SelectContext(ctx, &specs, 
sqlQuery, pipelineSpecIDs); err != nil { return errors.Wrap(err, "failed to postload pipeline_specs for runs") } for _, spec := range specs { @@ -647,7 +679,7 @@ func loadAssociations(q pg.Queryer, runs []*Run) error { var taskRuns []TaskRun taskRunPRIDM := make(map[int64][]TaskRun, len(runs)) // keyed by pipelineRunID - if err := q.Select(&taskRuns, `SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = ANY($1) ORDER BY created_at ASC, id ASC`, pipelineRunIDs); err != nil { + if err := ds.SelectContext(ctx, &taskRuns, `SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = ANY($1) ORDER BY created_at ASC, id ASC`, pipelineRunIDs); err != nil { return errors.Wrap(err, "failed to postload pipeline_task_runs for runs") } for _, taskRun := range taskRuns { @@ -662,10 +694,6 @@ func loadAssociations(q pg.Queryer, runs []*Run) error { return nil } -func (o *orm) GetQ() pg.Q { - return o.q -} - func (o *orm) loadCount(jobID int32) *atomic.Uint64 { // fast path; avoids allocation actual, exists := o.pm.Load(jobID) @@ -681,7 +709,7 @@ func (o *orm) loadCount(jobID int32) *atomic.Uint64 { // this value or higher const syncLimit = 1000 -// Prune attempts to keep the pipeline_runs table capped close to the +// prune attempts to keep the pipeline_runs table capped close to the // maxSuccessfulRuns length for each job_id. // // It does this synchronously for small values and async/sampled for large @@ -689,13 +717,13 @@ const syncLimit = 1000 // // Note this does not guarantee the pipeline_runs table is kept to exactly the // max length, rather that it doesn't excessively larger than it. -func (o *orm) Prune(tx pg.Queryer, jobID int32) { +func (o *orm) prune(ds sqlutil.DataSource, jobID int32) { if jobID == 0 { o.lggr.Panic("expected a non-zero job ID") } // For small maxSuccessfulRuns its fast enough to prune every time if o.maxSuccessfulRuns < syncLimit { - o.execPrune(tx, jobID) + o.execPrune(o.ctx, ds, jobID) return } // for large maxSuccessfulRuns we do it async on a sampled basis @@ -708,9 +736,11 @@ func (o *orm) Prune(tx pg.Queryer, jobID int32) { go func() { o.lggr.Debugw("Pruning runs", "jobID", jobID, "count", val, "every", every, "maxSuccessfulRuns", o.maxSuccessfulRuns) defer o.wg.Done() - // Must not use tx here since it's async and the transaction + // Must not use ds here since it's async and the transaction // could be stale - o.execPrune(o.q.WithOpts(pg.WithLongQueryTimeout()), jobID) + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(o.ctx), time.Minute) + defer cancel() + o.execPrune(ctx, o.ds, jobID) }() }) if !ok { @@ -720,8 +750,8 @@ func (o *orm) Prune(tx pg.Queryer, jobID int32) { } } -func (o *orm) execPrune(q pg.Queryer, jobID int32) { - res, err := q.ExecContext(o.ctx, `DELETE FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 AND id NOT IN ( +func (o *orm) execPrune(ctx context.Context, ds sqlutil.DataSource, jobID int32) { + res, err := ds.ExecContext(o.ctx, `DELETE FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 AND id NOT IN ( SELECT id FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 ORDER BY id DESC @@ -739,7 +769,7 @@ LIMIT $3 if rowsAffected == 0 { // check the spec still exists and garbage collect if necessary var exists bool - if err := q.GetContext(o.ctx, &exists, `SELECT EXISTS(SELECT ps.* FROM pipeline_specs ps JOIN job_pipeline_specs jps ON (ps.id=jps.pipeline_spec_id) WHERE jps.job_id = $1)`, jobID); err != nil { + if err := ds.GetContext(ctx, &exists, `SELECT EXISTS(SELECT ps.* FROM pipeline_specs ps JOIN 
job_pipeline_specs jps ON (ps.id=jps.pipeline_spec_id) WHERE jps.job_id = $1)`, jobID); err != nil { o.lggr.Errorw("Failed check existence of pipeline_spec while pruning runs", "err", err, "jobID", jobID) return } diff --git a/core/services/pipeline/orm_test.go b/core/services/pipeline/orm_test.go index e5bf319f056..88155bc04ba 100644 --- a/core/services/pipeline/orm_test.go +++ b/core/services/pipeline/orm_test.go @@ -71,7 +71,7 @@ func setupORM(t *testing.T, heavy bool) (db *sqlx.DB, orm pipeline.ORM, jorm job db = pgtest.NewSqlxDB(t) } cfg := ormconfig{pgtest.NewQConfig(true)} - orm = pipeline.NewORM(db, logger.TestLogger(t), cfg, cfg.JobPipelineMaxSuccessfulRuns()) + orm = pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipelineMaxSuccessfulRuns()) config := configtest.NewTestGeneralConfig(t) lggr := logger.TestLogger(t) keyStore := cltest.NewKeyStore(t, db, config.Database()) @@ -91,6 +91,7 @@ func setupLiteORM(t *testing.T) (db *sqlx.DB, orm pipeline.ORM, jorm job.ORM) { } func Test_PipelineORM_CreateSpec(t *testing.T) { + ctx := testutils.Context(t) db, orm, _ := setupLiteORM(t) var ( @@ -102,7 +103,7 @@ func Test_PipelineORM_CreateSpec(t *testing.T) { Source: source, } - id, err := orm.CreateSpec(p, maxTaskDuration) + id, err := orm.CreateSpec(ctx, nil, p, maxTaskDuration) require.NoError(t, err) actual := pipeline.Spec{} @@ -121,7 +122,8 @@ func Test_PipelineORM_FindRun(t *testing.T) { require.NoError(t, err) expected := mustInsertPipelineRun(t, orm) - run, err := orm.FindRun(expected.ID) + ctx := testutils.Context(t) + run, err := orm.FindRun(ctx, expected.ID) require.NoError(t, err) require.Equal(t, expected.ID, run.ID) @@ -138,12 +140,14 @@ func mustInsertPipelineRun(t *testing.T, orm pipeline.ORM) pipeline.Run { FinishedAt: null.Time{}, } - require.NoError(t, orm.InsertRun(&run)) + ctx := testutils.Context(t) + require.NoError(t, orm.InsertRun(ctx, &run)) return run } func mustInsertAsyncRun(t *testing.T, orm pipeline.ORM, jobORM job.ORM) *pipeline.Run { t.Helper() + ctx := testutils.Context(t) s := ` ds1 [type=bridge async=true name="example-bridge" timeout=0 requestData=<{"data": {"coin": "BTC", "market": "USD"}}>] @@ -178,12 +182,13 @@ answer2 [type=bridge name=election_winner index=1]; CreatedAt: time.Now(), } - err = orm.CreateRun(run) + err = orm.CreateRun(ctx, run) require.NoError(t, err) return run } func TestInsertFinishedRuns(t *testing.T) { + ctx := testutils.Context(t) db, orm, _ := setupLiteORM(t) _, err := db.Exec(`SET CONSTRAINTS fk_pipeline_runs_pruning_key DEFERRED`) @@ -207,7 +212,7 @@ func TestInsertFinishedRuns(t *testing.T) { Outputs: jsonserializable.JSONSerializable{}, } - require.NoError(t, orm.InsertRun(&r)) + require.NoError(t, orm.InsertRun(ctx, &r)) r.PipelineTaskRuns = []pipeline.TaskRun{ { @@ -238,12 +243,13 @@ func TestInsertFinishedRuns(t *testing.T) { runs = append(runs, &r) } - err = orm.InsertFinishedRuns(runs, true) + err = orm.InsertFinishedRuns(ctx, runs, true) require.NoError(t, err) } func Test_PipelineORM_InsertFinishedRunWithSpec(t *testing.T) { + ctx := testutils.Context(t) db, orm, jorm := setupLiteORM(t) s := ` @@ -314,7 +320,7 @@ answer2 [type=bridge name=election_winner index=1]; run.AllErrors = append(run.AllErrors, null.NewString("", false)) run.State = pipeline.RunStatusCompleted - err = orm.InsertFinishedRunWithSpec(run, true) + err = orm.InsertFinishedRunWithSpec(ctx, run, true) require.NoError(t, err) var pipelineSpec pipeline.Spec @@ -330,6 +336,7 @@ answer2 [type=bridge name=election_winner index=1]; // Tests that 
inserting run results, then later updating the run results via upsert will work correctly. func Test_PipelineORM_StoreRun_ShouldUpsert(t *testing.T) { + ctx := testutils.Context(t) _, orm, jorm := setupLiteORM(t) run := mustInsertAsyncRun(t, orm, jorm) @@ -357,14 +364,14 @@ func Test_PipelineORM_StoreRun_ShouldUpsert(t *testing.T) { FinishedAt: null.TimeFrom(now), }, } - restart, err := orm.StoreRun(run) + restart, err := orm.StoreRun(ctx, run) require.NoError(t, err) // no new data, so we don't need a restart require.Equal(t, false, restart) // the run is paused require.Equal(t, pipeline.RunStatusSuspended, run.State) - r, err := orm.FindRun(run.ID) + r, err := orm.FindRun(ctx, run.ID) require.NoError(t, err) run = &r // this is an incomplete run, so partial results should be present (regardless of saveSuccessfulTaskRuns) @@ -388,14 +395,14 @@ func Test_PipelineORM_StoreRun_ShouldUpsert(t *testing.T) { FinishedAt: null.TimeFrom(now), }, } - restart, err = orm.StoreRun(run) + restart, err = orm.StoreRun(ctx, run) require.NoError(t, err) // no new data, so we don't need a restart require.Equal(t, false, restart) // the run is paused require.Equal(t, pipeline.RunStatusSuspended, run.State) - r, err = orm.FindRun(run.ID) + r, err = orm.FindRun(ctx, run.ID) require.NoError(t, err) run = &r // this is an incomplete run, so partial results should be present (regardless of saveSuccessfulTaskRuns) @@ -409,11 +416,12 @@ func Test_PipelineORM_StoreRun_ShouldUpsert(t *testing.T) { // Tests that trying to persist a partial run while new data became available (i.e. via /v2/restart) // will detect a restart and update the result data on the Run. func Test_PipelineORM_StoreRun_DetectsRestarts(t *testing.T) { + ctx := testutils.Context(t) db, orm, jorm := setupLiteORM(t) run := mustInsertAsyncRun(t, orm, jorm) - r, err := orm.FindRun(run.ID) + r, err := orm.FindRun(ctx, run.ID) require.NoError(t, err) require.Equal(t, run.Inputs, r.Inputs) @@ -459,7 +467,7 @@ func Test_PipelineORM_StoreRun_DetectsRestarts(t *testing.T) { }, } - restart, err := orm.StoreRun(run) + restart, err := orm.StoreRun(ctx, run) require.NoError(t, err) // new data available! 
immediately restart the run require.Equal(t, true, restart) @@ -474,6 +482,7 @@ func Test_PipelineORM_StoreRun_DetectsRestarts(t *testing.T) { } func Test_PipelineORM_StoreRun_UpdateTaskRunResult(t *testing.T) { + ctx := testutils.Context(t) _, orm, jorm := setupLiteORM(t) run := mustInsertAsyncRun(t, orm, jorm) @@ -525,13 +534,13 @@ func Test_PipelineORM_StoreRun_UpdateTaskRunResult(t *testing.T) { require.Equal(t, pipeline.RunStatusRunning, run.State) // Now store a partial run - restart, err := orm.StoreRun(run) + restart, err := orm.StoreRun(ctx, run) require.NoError(t, err) require.False(t, restart) // assert that run should be in "paused" state require.Equal(t, pipeline.RunStatusSuspended, run.State) - r, start, err := orm.UpdateTaskRunResult(ds1_id, pipeline.Result{Value: "foo"}) + r, start, err := orm.UpdateTaskRunResult(ctx, ds1_id, pipeline.Result{Value: "foo"}) run = &r require.NoError(t, err) assert.Greater(t, run.ID, int64(0)) @@ -555,6 +564,7 @@ func Test_PipelineORM_StoreRun_UpdateTaskRunResult(t *testing.T) { } func Test_PipelineORM_DeleteRun(t *testing.T) { + ctx := testutils.Context(t) _, orm, jorm := setupLiteORM(t) run := mustInsertAsyncRun(t, orm, jorm) @@ -582,21 +592,22 @@ func Test_PipelineORM_DeleteRun(t *testing.T) { FinishedAt: null.TimeFrom(now), }, } - restart, err := orm.StoreRun(run) + restart, err := orm.StoreRun(ctx, run) require.NoError(t, err) // no new data, so we don't need a restart require.Equal(t, false, restart) // the run is paused require.Equal(t, pipeline.RunStatusSuspended, run.State) - err = orm.DeleteRun(run.ID) + err = orm.DeleteRun(ctx, run.ID) require.NoError(t, err) - _, err = orm.FindRun(run.ID) + _, err = orm.FindRun(ctx, run.ID) require.Error(t, err, "not found") } func Test_PipelineORM_DeleteRunsOlderThan(t *testing.T) { + ctx := testutils.Context(t) _, orm, jorm := setupHeavyORM(t) var runsIds []int64 @@ -623,7 +634,7 @@ func Test_PipelineORM_DeleteRunsOlderThan(t *testing.T) { run.Outputs = jsonserializable.JSONSerializable{Val: 1, Valid: true} run.AllErrors = pipeline.RunErrors{null.StringFrom("SOMETHING")} - restart, err := orm.StoreRun(run) + restart, err := orm.StoreRun(ctx, run) assert.NoError(t, err) // no new data, so we don't need a restart assert.Equal(t, false, restart) @@ -635,13 +646,14 @@ func Test_PipelineORM_DeleteRunsOlderThan(t *testing.T) { assert.NoError(t, err) for _, runId := range runsIds { - _, err := orm.FindRun(runId) + _, err := orm.FindRun(ctx, runId) require.Error(t, err, "not found") } } func Test_GetUnfinishedRuns_Keepers(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) // The test configures single Keeper job with two running tasks. // GetUnfinishedRuns() expects to catch both running tasks. 
@@ -650,7 +662,7 @@ func Test_GetUnfinishedRuns_Keepers(t *testing.T) { lggr := logger.TestLogger(t) db := pgtest.NewSqlxDB(t) keyStore := cltest.NewKeyStore(t, db, config.Database()) - porm := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + porm := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(db) jorm := job.NewORM(db, porm, bridgeORM, keyStore, lggr, config.Database()) @@ -684,7 +696,7 @@ func Test_GetUnfinishedRuns_Keepers(t *testing.T) { runID1 := uuid.New() runID2 := uuid.New() - err = porm.CreateRun(&pipeline.Run{ + err = porm.CreateRun(ctx, &pipeline.Run{ PipelineSpecID: keeperJob.PipelineSpecID, PruningKey: keeperJob.ID, State: pipeline.RunStatusRunning, @@ -701,7 +713,7 @@ func Test_GetUnfinishedRuns_Keepers(t *testing.T) { }) require.NoError(t, err) - err = porm.CreateRun(&pipeline.Run{ + err = porm.CreateRun(ctx, &pipeline.Run{ PipelineSpecID: keeperJob.PipelineSpecID, PruningKey: keeperJob.ID, State: pipeline.RunStatusRunning, @@ -744,6 +756,7 @@ func Test_GetUnfinishedRuns_Keepers(t *testing.T) { func Test_GetUnfinishedRuns_DirectRequest(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) // The test configures single DR job with two task runs: one is running and one is suspended. // GetUnfinishedRuns() expects to catch the one that is running. @@ -752,7 +765,7 @@ func Test_GetUnfinishedRuns_DirectRequest(t *testing.T) { lggr := logger.TestLogger(t) db := pgtest.NewSqlxDB(t) keyStore := cltest.NewKeyStore(t, db, config.Database()) - porm := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + porm := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(db) jorm := job.NewORM(db, porm, bridgeORM, keyStore, lggr, config.Database()) @@ -784,7 +797,7 @@ func Test_GetUnfinishedRuns_DirectRequest(t *testing.T) { runningID := uuid.New() - err = porm.CreateRun(&pipeline.Run{ + err = porm.CreateRun(ctx, &pipeline.Run{ PipelineSpecID: drJob.PipelineSpecID, PruningKey: drJob.ID, State: pipeline.RunStatusRunning, @@ -801,7 +814,7 @@ func Test_GetUnfinishedRuns_DirectRequest(t *testing.T) { }) require.NoError(t, err) - err = porm.CreateRun(&pipeline.Run{ + err = porm.CreateRun(ctx, &pipeline.Run{ PipelineSpecID: drJob.PipelineSpecID, PruningKey: drJob.ID, State: pipeline.RunStatusSuspended, @@ -846,7 +859,7 @@ func Test_Prune(t *testing.T) { }) lggr, observed := logger.TestLoggerObserved(t, zapcore.DebugLevel) db := pgtest.NewSqlxDB(t) - porm := pipeline.NewORM(db, lggr, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + porm := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) torm := newTestORM(porm, db) ps1 := cltest.MustInsertPipelineSpec(t, db) diff --git a/core/services/pipeline/runner.go b/core/services/pipeline/runner.go index 08d371716fc..862d2f49178 100644 --- a/core/services/pipeline/runner.go +++ b/core/services/pipeline/runner.go @@ -15,6 +15,7 @@ import ( "gopkg.in/guregu/null.v4" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" commonutils "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/utils/jsonserializable" @@ -23,7 +24,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/config/env" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/recovery" - 
"github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/store/models" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -36,15 +36,16 @@ type Runner interface { // Run is a blocking call that will execute the run until no further progress can be made. // If `incomplete` is true, the run is only partially complete and is suspended, awaiting to be resumed when more data comes in. // Note that `saveSuccessfulTaskRuns` value is ignored if the run contains async tasks. - Run(ctx context.Context, run *Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(tx pg.Queryer) error) (incomplete bool, err error) - ResumeRun(taskID uuid.UUID, value interface{}, err error) error + Run(ctx context.Context, run *Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(tx sqlutil.DataSource) error) (incomplete bool, err error) + ResumeRun(ctx context.Context, taskID uuid.UUID, value interface{}, err error) error // ExecuteRun executes a new run in-memory according to a spec and returns the results. // We expect spec.JobID and spec.JobName to be set for logging/prometheus. ExecuteRun(ctx context.Context, spec Spec, vars Vars, l logger.Logger) (run *Run, trrs TaskRunResults, err error) // InsertFinishedRun saves the run results in the database. - InsertFinishedRun(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error - InsertFinishedRuns(runs []*Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error + // ds is an optional override, for example when executing a transaction. + InsertFinishedRun(ctx context.Context, ds sqlutil.DataSource, run *Run, saveSuccessfulTaskRuns bool) error + InsertFinishedRuns(ctx context.Context, ds sqlutil.DataSource, runs []*Run, saveSuccessfulTaskRuns bool) error // ExecuteAndInsertFinishedRun executes a new run in-memory according to a spec, persists and saves the results. // It is a combination of ExecuteRun and InsertFinishedRun. 
@@ -566,9 +567,9 @@ func (r *runner) ExecuteAndInsertFinishedRun(ctx context.Context, spec Spec, var } if spec.ID == 0 { - err = r.orm.InsertFinishedRunWithSpec(run, saveSuccessfulTaskRuns) + err = r.orm.InsertFinishedRunWithSpec(ctx, run, saveSuccessfulTaskRuns) } else { - err = r.orm.InsertFinishedRun(run, saveSuccessfulTaskRuns) + err = r.orm.InsertFinishedRun(ctx, run, saveSuccessfulTaskRuns) } if err != nil { return 0, trrs, pkgerrors.Wrapf(err, "error inserting finished results for spec ID %v", run.PipelineSpecID) @@ -577,7 +578,7 @@ func (r *runner) ExecuteAndInsertFinishedRun(ctx context.Context, spec Spec, var } -func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(tx pg.Queryer) error) (incomplete bool, err error) { +func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(tx sqlutil.DataSource) error) (incomplete bool, err error) { pipeline, err := r.InitializePipeline(run.PipelineSpec) if err != nil { return false, err @@ -594,8 +595,7 @@ func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccess preinsert := pipeline.RequiresPreInsert() - q := r.orm.GetQ().WithOpts(pg.WithParentCtx(ctx)) - err = q.Transaction(func(tx pg.Queryer) error { + err = r.orm.Transact(ctx, func(tx ORM) error { // OPTIMISATION: avoid an extra db write if there is no async tasks present or if this is a resumed run if preinsert && run.ID == 0 { now := time.Now() @@ -614,13 +614,13 @@ func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccess default: } } - if err = r.orm.CreateRun(run, pg.WithQueryer(tx)); err != nil { + if err = tx.CreateRun(ctx, run); err != nil { return err } } if fn != nil { - return fn(tx) + return fn(tx.DataSource()) } return nil }) @@ -634,14 +634,14 @@ func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccess if preinsert { // FailSilently = run failed and task was marked failEarly. 
skip StoreRun and instead delete all trace of it if run.FailSilently { - if err = r.orm.DeleteRun(run.ID); err != nil { + if err = r.orm.DeleteRun(ctx, run.ID); err != nil { return false, pkgerrors.Wrap(err, "Run") } return false, nil } var restart bool - restart, err = r.orm.StoreRun(run) + restart, err = r.orm.StoreRun(ctx, run) if err != nil { return false, pkgerrors.Wrapf(err, "error storing run for spec ID %v state %v outputs %v errors %v finished_at %v", run.PipelineSpec.ID, run.State, run.Outputs, run.FatalErrors, run.FinishedAt) @@ -660,7 +660,7 @@ func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccess return false, nil } - if err = r.orm.InsertFinishedRun(run, saveSuccessfulTaskRuns, pg.WithParentCtx(ctx)); err != nil { + if err = r.orm.InsertFinishedRun(ctx, run, saveSuccessfulTaskRuns); err != nil { return false, pkgerrors.Wrapf(err, "error storing run for spec ID %v", run.PipelineSpec.ID) } } @@ -671,8 +671,8 @@ func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccess } } -func (r *runner) ResumeRun(taskID uuid.UUID, value interface{}, err error) error { - run, start, err := r.orm.UpdateTaskRunResult(taskID, Result{ +func (r *runner) ResumeRun(ctx context.Context, taskID uuid.UUID, value interface{}, err error) error { + run, start, err := r.orm.UpdateTaskRunResult(ctx, taskID, Result{ Value: value, Error: err, }) @@ -694,12 +694,20 @@ func (r *runner) ResumeRun(taskID uuid.UUID, value interface{}, err error) error return nil } -func (r *runner) InsertFinishedRun(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - return r.orm.InsertFinishedRun(run, saveSuccessfulTaskRuns, qopts...) +func (r *runner) InsertFinishedRun(ctx context.Context, ds sqlutil.DataSource, run *Run, saveSuccessfulTaskRuns bool) error { + orm := r.orm + if ds != nil { + orm = orm.WithDataSource(ds) + } + return orm.InsertFinishedRun(ctx, run, saveSuccessfulTaskRuns) } -func (r *runner) InsertFinishedRuns(runs []*Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - return r.orm.InsertFinishedRuns(runs, saveSuccessfulTaskRuns, qopts...) +func (r *runner) InsertFinishedRuns(ctx context.Context, ds sqlutil.DataSource, runs []*Run, saveSuccessfulTaskRuns bool) error { + orm := r.orm + if ds != nil { + orm = orm.WithDataSource(ds) + } + return orm.InsertFinishedRuns(ctx, runs, saveSuccessfulTaskRuns) } func (r *runner) runReaper() { diff --git a/core/services/pipeline/runner_test.go b/core/services/pipeline/runner_test.go index 52e668339ec..f27a6b35348 100644 --- a/core/services/pipeline/runner_test.go +++ b/core/services/pipeline/runner_test.go @@ -476,7 +476,7 @@ func Test_PipelineRunner_HandleFaultsPersistRun(t *testing.T) { orm.On("GetQ").Return(q).Maybe() orm.On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(1).(*pipeline.Run).ID = 1 }). Return(nil) cfg := configtest.NewTestGeneralConfig(t) @@ -517,7 +517,7 @@ func Test_PipelineRunner_ExecuteAndInsertFinishedRun_SavingTheSpec(t *testing.T) orm.On("GetQ").Return(q).Maybe() orm.On("InsertFinishedRunWithSpec", mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(1).(*pipeline.Run).ID = 1 }). 
Return(nil) cfg := configtest.NewTestGeneralConfig(t) @@ -642,7 +642,13 @@ func Test_PipelineRunner_AsyncJob_Basic(t *testing.T) { btORM := bridgesMocks.NewORM(t) btORM.On("FindBridge", mock.Anything, bt.Name).Return(*bt, nil) + r, orm := newRunner(t, db, btORM, cfg) + transactCall := orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm pipeline.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(orm)} + }) s := fmt.Sprintf(` ds1 [type=bridge async=true name="%s" timeout=0 requestData=<{"data": {"coin": "BTC", "market": "USD"}}>] @@ -673,11 +679,11 @@ ds5 [type=http method="GET" url="%s" index=2] // Start a new run run := pipeline.NewRun(spec, pipeline.NewVarsFrom(nil)) // we should receive a call to CreateRun because it's contains an async task - orm.On("CreateRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(nil).Run(func(args mock.Arguments) { - run := args.Get(0).(*pipeline.Run) + orm.On("CreateRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(nil).Run(func(args mock.Arguments) { + run := args.Get(1).(*pipeline.Run) run.ID = 1 // give it a valid "id" }).Once() - orm.On("StoreRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(false, nil).Once() + orm.On("StoreRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(false, nil).Once() lggr := logger.TestLogger(t) incomplete, err := r.Run(testutils.Context(t), run, lggr, false, nil) require.NoError(t, err) @@ -687,7 +693,7 @@ ds5 [type=http method="GET" url="%s" index=2] // TODO: test a pending run that's not marked async=true, that is not allowed // Trigger run resumption with no new data - orm.On("StoreRun", mock.AnythingOfType("*pipeline.Run")).Return(false, nil).Once() + orm.On("StoreRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(false, nil).Once() incomplete, err = r.Run(testutils.Context(t), run, lggr, false, nil) require.NoError(t, err) require.Equal(t, true, incomplete) // still incomplete @@ -700,7 +706,7 @@ ds5 [type=http method="GET" url="%s" index=2] Valid: true, } // Trigger run resumption - orm.On("StoreRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(false, nil).Once() + orm.On("StoreRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(false, nil).Once() incomplete, err = r.Run(testutils.Context(t), run, lggr, false, nil) require.NoError(t, err) require.Equal(t, false, incomplete) // done @@ -770,6 +776,11 @@ func Test_PipelineRunner_AsyncJob_InstantRestart(t *testing.T) { btORM.On("FindBridge", mock.Anything, bt.Name).Return(*bt, nil) r, orm := newRunner(t, db, btORM, cfg) + transactCall := orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm pipeline.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(orm)} + }) s := fmt.Sprintf(` ds1 [type=bridge async=true name="%s" timeout=0 requestData=<{"data": {"coin": "BTC", "market": "USD"}}>] @@ -800,13 +811,13 @@ ds5 [type=http method="GET" url="%s" index=2] // Start a new run run := pipeline.NewRun(spec, pipeline.NewVarsFrom(nil)) // we should receive a call to CreateRun because it's contains an async task - orm.On("CreateRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(nil).Run(func(args mock.Arguments) { - run := args.Get(0).(*pipeline.Run) + orm.On("CreateRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(nil).Run(func(args mock.Arguments) { + run := args.Get(1).(*pipeline.Run) 
run.ID = 1 // give it a valid "id" }).Once() // Simulate updated task run data - orm.On("StoreRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(true, nil).Run(func(args mock.Arguments) { - run := args.Get(0).(*pipeline.Run) + orm.On("StoreRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(true, nil).Run(func(args mock.Arguments) { + run := args.Get(1).(*pipeline.Run) // Now simulate a new result coming in while we were running task := run.ByDotID("ds1") task.Error = null.NewString("", false) @@ -816,7 +827,7 @@ ds5 [type=http method="GET" url="%s" index=2] } }).Once() // StoreRun is called again to store the final result - orm.On("StoreRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(false, nil).Once() + orm.On("StoreRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(false, nil).Once() incomplete, err := r.Run(testutils.Context(t), run, logger.TestLogger(t), false, nil) require.NoError(t, err) require.Len(t, run.PipelineTaskRuns, 12) diff --git a/core/services/pipeline/task.bridge_test.go b/core/services/pipeline/task.bridge_test.go index 922f82a533b..029c6c78ca8 100644 --- a/core/services/pipeline/task.bridge_test.go +++ b/core/services/pipeline/task.bridge_test.go @@ -216,8 +216,8 @@ func TestBridgeTask_Happy(t *testing.T) { RequestData: btcUSDPairing, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -258,8 +258,8 @@ func TestBridgeTask_HandlesIntermittentFailure(t *testing.T) { CacheTTL: "30s", // standard duration string format } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) result, runInfo := task.Run(testutils.Context(t), logger.TestLogger(t), @@ -321,8 +321,8 @@ func TestBridgeTask_DoesNotReturnStaleResults(t *testing.T) { RequestData: btcUSDPairing, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -481,8 +481,8 @@ func 
TestBridgeTask_AsyncJobPendingState(t *testing.T) { Async: "true", } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, id, c) @@ -659,8 +659,8 @@ func TestBridgeTask_Variables(t *testing.T) { IncludeInputAtKey: test.includeInputAtKey, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -728,8 +728,8 @@ func TestBridgeTask_Meta(t *testing.T) { Name: bridge.Name.String(), } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -782,8 +782,8 @@ func TestBridgeTask_IncludeInputAtKey(t *testing.T) { IncludeInputAtKey: test.includeInputAtKey, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -838,8 +838,8 @@ func TestBridgeTask_ErrorMessage(t *testing.T) { RequestData: ethUSDPairing, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ 
-877,8 +877,8 @@ func TestBridgeTask_OnlyErrorMessage(t *testing.T) { RequestData: ethUSDPairing, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -902,8 +902,8 @@ func TestBridgeTask_ErrorIfBridgeMissing(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() orm := bridges.NewORM(db) - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -992,8 +992,8 @@ func TestBridgeTask_Headers(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -1014,8 +1014,8 @@ func TestBridgeTask_Headers(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -1036,8 +1036,8 @@ func TestBridgeTask_Headers(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -1082,8 +1082,8 @@ func TestBridgeTask_AdapterResponseStatusFailure(t 
*testing.T) { RequestData: btcUSDPairing, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) diff --git a/core/services/pipeline/task.http_test.go b/core/services/pipeline/task.http_test.go index ce28fac478c..6264d1e591b 100644 --- a/core/services/pipeline/task.http_test.go +++ b/core/services/pipeline/task.http_test.go @@ -24,7 +24,6 @@ import ( clhttptest "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/httptest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/store/models" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -177,8 +176,8 @@ func TestHTTPTask_Variables(t *testing.T) { RequestData: test.requestData, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) diff --git a/core/services/relay/evm/mocks/request_round_db.go b/core/services/relay/evm/mocks/request_round_db.go index 725fc6e6b37..4168ba4a1b0 100644 --- a/core/services/relay/evm/mocks/request_round_db.go +++ b/core/services/relay/evm/mocks/request_round_db.go @@ -9,6 +9,8 @@ import ( mock "github.com/stretchr/testify/mock" ocr2aggregator "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" + + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ) // RequestRoundDB is an autogenerated mock type for the RequestRoundDB type @@ -62,19 +64,21 @@ func (_m *RequestRoundDB) SaveLatestRoundRequested(ctx context.Context, rr ocr2a return r0 } -// Transact provides a mock function with given fields: _a0, _a1 -func (_m *RequestRoundDB) Transact(_a0 context.Context, _a1 func(evm.RequestRoundDB) error) error { - ret := _m.Called(_a0, _a1) +// WithDataSource provides a mock function with given fields: _a0 +func (_m *RequestRoundDB) WithDataSource(_a0 sqlutil.DataSource) evm.RequestRoundDB { + ret := _m.Called(_a0) if len(ret) == 0 { - panic("no return value specified for Transact") + panic("no return value specified for WithDataSource") } - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, func(evm.RequestRoundDB) error) error); ok { - r0 = rf(_a0, _a1) + var r0 evm.RequestRoundDB + if rf, ok := ret.Get(0).(func(sqlutil.DataSource) evm.RequestRoundDB); ok { + r0 = rf(_a0) } else { - r0 = ret.Error(0) + if ret.Get(0) 
!= nil { + r0 = ret.Get(0).(evm.RequestRoundDB) + } } return r0 diff --git a/core/services/relay/evm/request_round_db.go b/core/services/relay/evm/request_round_db.go index 2b6ae10782d..96c5a05d1c7 100644 --- a/core/services/relay/evm/request_round_db.go +++ b/core/services/relay/evm/request_round_db.go @@ -12,16 +12,17 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" ) +//go:generate mockery --quiet --name RequestRoundDB --output ./mocks/ --case=underscore + // RequestRoundDB stores requested rounds for querying by the median plugin. type RequestRoundDB interface { SaveLatestRoundRequested(ctx context.Context, rr ocr2aggregator.OCR2AggregatorRoundRequested) error LoadLatestRoundRequested(context.Context) (rr ocr2aggregator.OCR2AggregatorRoundRequested, err error) - Transact(context.Context, func(db RequestRoundDB) error) error + WithDataSource(sqlutil.DataSource) RequestRoundDB } var _ RequestRoundDB = &requestRoundDB{} -//go:generate mockery --quiet --name RequestRoundDB --output ./mocks/ --case=underscore type requestRoundDB struct { ds sqlutil.DataSource oracleSpecID int32 @@ -33,10 +34,8 @@ func NewRoundRequestedDB(ds sqlutil.DataSource, oracleSpecID int32, lggr logger. return &requestRoundDB{ds, oracleSpecID, lggr} } -func (d *requestRoundDB) Transact(ctx context.Context, fn func(db RequestRoundDB) error) error { - return sqlutil.Transact(ctx, func(ds sqlutil.DataSource) RequestRoundDB { - return NewRoundRequestedDB(ds, d.oracleSpecID, d.lggr) - }, d.ds, nil, fn) +func (d *requestRoundDB) WithDataSource(ds sqlutil.DataSource) RequestRoundDB { + return NewRoundRequestedDB(ds, d.oracleSpecID, d.lggr) } func (d *requestRoundDB) SaveLatestRoundRequested(ctx context.Context, rr ocr2aggregator.OCR2AggregatorRoundRequested) error { diff --git a/core/services/relay/evm/request_round_db_test.go b/core/services/relay/evm/request_round_db_test.go index 10932c4e229..26f8e2ac1a6 100644 --- a/core/services/relay/evm/request_round_db_test.go +++ b/core/services/relay/evm/request_round_db_test.go @@ -37,9 +37,7 @@ func Test_DB_LatestRoundRequested(t *testing.T) { t.Run("saves latest round requested", func(t *testing.T) { ctx := testutils.Context(t) - err := db.Transact(ctx, func(tx evm.RequestRoundDB) error { - return tx.SaveLatestRoundRequested(ctx, rr) - }) + err := db.SaveLatestRoundRequested(ctx, rr) require.NoError(t, err) rawLog.Index = 42 @@ -53,9 +51,7 @@ func Test_DB_LatestRoundRequested(t *testing.T) { Raw: rawLog, } - err = db.Transact(ctx, func(tx evm.RequestRoundDB) error { - return tx.SaveLatestRoundRequested(ctx, rr) - }) + err = db.SaveLatestRoundRequested(ctx, rr) require.NoError(t, err) }) diff --git a/core/services/relay/evm/request_round_tracker.go b/core/services/relay/evm/request_round_tracker.go index bb39271f278..fe6b6826eb2 100644 --- a/core/services/relay/evm/request_round_tracker.go +++ b/core/services/relay/evm/request_round_tracker.go @@ -106,8 +106,8 @@ func (t *RequestRoundTracker) Close() error { // HandleLog complies with LogListener interface // It is not thread safe -func (t *RequestRoundTracker) HandleLog(lb log.Broadcast) { - was, err := t.logBroadcaster.WasAlreadyConsumed(t.ctx, lb) +func (t *RequestRoundTracker) HandleLog(ctx context.Context, lb log.Broadcast) { + was, err := t.logBroadcaster.WasAlreadyConsumed(ctx, lb) if err != nil { t.lggr.Errorw("OCRContract: could not determine if log was already consumed", "err", err) return @@ -118,12 +118,12 @@ func (t *RequestRoundTracker) HandleLog(lb log.Broadcast) { raw := lb.RawLog() if 
raw.Address != t.contract.Address() { t.lggr.Errorf("log address of 0x%x does not match configured contract address of 0x%x", raw.Address, t.contract.Address()) - t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(t.ctx, lb), "unable to mark consumed") + t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(ctx, nil, lb), "unable to mark consumed") return } topics := raw.Topics if len(topics) == 0 { - t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(t.ctx, lb), "unable to mark consumed") + t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(ctx, nil, lb), "unable to mark consumed") return } @@ -134,16 +134,15 @@ func (t *RequestRoundTracker) HandleLog(lb log.Broadcast) { rr, err = t.contractFilterer.ParseRoundRequested(raw) if err != nil { t.lggr.Errorw("could not parse round requested", "err", err) - t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(t.ctx, lb), "unable to mark consumed") + t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(ctx, nil, lb), "unable to mark consumed") return } if IsLaterThan(raw, t.latestRoundRequested.Raw) { - ctx := context.TODO() //TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 - err = t.odb.Transact(ctx, func(tx RequestRoundDB) error { - if err = tx.SaveLatestRoundRequested(ctx, *rr); err != nil { + err = sqlutil.TransactDataSource(ctx, t.ds, nil, func(tx sqlutil.DataSource) error { + if err = t.odb.WithDataSource(tx).SaveLatestRoundRequested(ctx, *rr); err != nil { return err } - return t.logBroadcaster.MarkConsumed(t.ctx, lb) + return t.logBroadcaster.MarkConsumed(ctx, tx, lb) }) if err != nil { t.lggr.Error(err) @@ -161,7 +160,7 @@ func (t *RequestRoundTracker) HandleLog(lb log.Broadcast) { t.lggr.Debugw("RequestRoundTracker: got unrecognised log topic", "topic", topics[0]) } if !consumed { - t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(t.ctx, lb), "unable to mark consumed") + t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(ctx, nil, lb), "unable to mark consumed") } } diff --git a/core/services/relay/evm/request_round_tracker_test.go b/core/services/relay/evm/request_round_tracker_test.go index 9feb4b77348..3421004ccf5 100644 --- a/core/services/relay/evm/request_round_tracker_test.go +++ b/core/services/relay/evm/request_round_tracker_test.go @@ -14,6 +14,7 @@ import ( "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" htmocks "github.com/smartcontractkit/chainlink/v2/common/headtracker/mocks" evmclimocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client/mocks" evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" @@ -112,7 +113,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin rawLog := cltest.LogFromFixture(t, "../../../testdata/jsonrpc/ocr2_round_requested_log_1_1.json") logBroadcast.On("RawLog").Return(rawLog).Maybe() logBroadcast.On("String").Return("").Maybe() - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) configDigest, epoch, round, err := uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) @@ -121,7 +122,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin require.Equal(t, 0, int(round)) require.Equal(t, 0, int(epoch)) - uni.requestRoundTracker.HandleLog(logBroadcast) + 
uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -143,7 +144,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin require.Equal(t, 0, int(round)) require.Equal(t, 0, int(epoch)) - uni.requestRoundTracker.HandleLog(logBroadcast) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -168,19 +169,14 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast.On("RawLog").Return(rawLog).Maybe() logBroadcast.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr ocr2aggregator.OCR2AggregatorRoundRequested) bool { return rr.Epoch == 1 && rr.Round == 1 })).Return(nil) - transact := uni.db.On("Transact", mock.Anything, mock.Anything) - transact.Run(func(args mock.Arguments) { - fn := args[1].(func(evm.RequestRoundDB) error) - err2 := fn(uni.db) - transact.ReturnArguments = []any{err2} - }) + uni.db.On("WithDataSource", mock.Anything).Return(uni.db) - uni.requestRoundTracker.HandleLog(logBroadcast) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -194,13 +190,13 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast2.On("RawLog").Return(rawLog2).Maybe() logBroadcast2.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr ocr2aggregator.OCR2AggregatorRoundRequested) bool { return rr.Epoch == 1 && rr.Round == 9 })).Return(nil) - uni.requestRoundTracker.HandleLog(logBroadcast2) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast2) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -209,7 +205,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin assert.Equal(t, 9, int(round)) // Same round with lower epoch is ignored - uni.requestRoundTracker.HandleLog(logBroadcast) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -224,13 +220,13 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast3.On("RawLog").Return(rawLog3).Maybe() logBroadcast3.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr 
ocr2aggregator.OCR2AggregatorRoundRequested) bool { return rr.Epoch == 2 && rr.Round == 1 })).Return(nil) - uni.requestRoundTracker.HandleLog(logBroadcast3) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast3) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -250,14 +246,9 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything).Return(errors.New("something exploded")) - transact := uni.db.On("Transact", mock.Anything, mock.Anything) - transact.Run(func(args mock.Arguments) { - fn := args[1].(func(evm.RequestRoundDB) error) - err := fn(uni.db) - transact.ReturnArguments = []any{err} - }) + uni.db.On("WithDataSource", mock.Anything).Return(uni.db) - uni.requestRoundTracker.HandleLog(logBroadcast) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast) configDigest, epoch, round, err := uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) diff --git a/core/services/streams/delegate.go b/core/services/streams/delegate.go index f9e2a64c4a3..bf492d4bd15 100644 --- a/core/services/streams/delegate.go +++ b/core/services/streams/delegate.go @@ -12,7 +12,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -38,10 +37,10 @@ func (d *Delegate) JobType() job.Type { return job.Stream } -func (d *Delegate) BeforeJobCreated(jb job.Job) {} -func (d *Delegate) AfterJobCreated(jb job.Job) {} -func (d *Delegate) BeforeJobDeleted(jb job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(jb job.Job) {} +func (d *Delegate) AfterJobCreated(jb job.Job) {} +func (d *Delegate) BeforeJobDeleted(jb job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services []job.ServiceCtx, err error) { if jb.StreamID == nil { diff --git a/core/services/streams/stream_test.go b/core/services/streams/stream_test.go index 3c0b4d0721f..3e8f58cd58b 100644 --- a/core/services/streams/stream_test.go +++ b/core/services/streams/stream_test.go @@ -11,9 +11,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -32,7 +32,7 @@ func (m *mockRunner) ExecuteRun(ctx context.Context, spec pipeline.Spec, vars pi func (m *mockRunner) InitializePipeline(spec pipeline.Spec) (p *pipeline.Pipeline, err error) { return m.p, m.err } -func (m *mockRunner) InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { +func (m *mockRunner) InsertFinishedRun(ctx context.Context, ds sqlutil.DataSource, run *pipeline.Run, saveSuccessfulTaskRuns bool) error { return m.err } diff --git 
a/core/services/vrf/delegate.go b/core/services/vrf/delegate.go index 617a28ac4d5..84c5126afef 100644 --- a/core/services/vrf/delegate.go +++ b/core/services/vrf/delegate.go @@ -11,8 +11,7 @@ import ( "github.com/theodesp/go-heaps/pairing" "go.uber.org/multierr" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" @@ -26,7 +25,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" v1 "github.com/smartcontractkit/chainlink/v2/core/services/vrf/v1" v2 "github.com/smartcontractkit/chainlink/v2/core/services/vrf/v2" @@ -34,7 +32,7 @@ import ( ) type Delegate struct { - q pg.Q + ds sqlutil.DataSource pr pipeline.Runner porm pipeline.ORM ks keystore.Master @@ -44,16 +42,15 @@ type Delegate struct { } func NewDelegate( - db *sqlx.DB, + ds sqlutil.DataSource, ks keystore.Master, pr pipeline.Runner, porm pipeline.ORM, legacyChains legacyevm.LegacyChainContainer, lggr logger.Logger, - cfg pg.QConfig, mailMon *mailbox.Monitor) *Delegate { return &Delegate{ - q: pg.NewQ(db, lggr, cfg), + ds: ds, ks: ks, pr: pr, porm: porm, @@ -67,10 +64,10 @@ func (d *Delegate) JobType() job.Type { return job.VRF } -func (d *Delegate) BeforeJobCreated(job.Job) {} -func (d *Delegate) AfterJobCreated(job.Job) {} -func (d *Delegate) BeforeJobDeleted(job.Job) {} -func (d *Delegate) OnDeleteJob(context.Context, job.Job, pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(job.Job) {} +func (d *Delegate) AfterJobCreated(job.Job) {} +func (d *Delegate) BeforeJobDeleted(job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. 
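The request_round_db and request_round_tracker hunks above replace the ORM-level Transact helper with a WithDataSource accessor: the caller now opens the transaction itself via sqlutil.TransactDataSource and re-binds each store to the tx-scoped DataSource, which is also how MarkConsumed(ctx, tx, lb) joins the same transaction. A minimal sketch of that shape, assuming an illustrative store type and demo table; only sqlutil.TransactDataSource and the DataSource methods already shown in the diff are taken as given:

    package example

    import (
        "context"

        "github.com/smartcontractkit/chainlink-common/pkg/sqlutil"
    )

    // store is a hypothetical ORM that can be re-bound to any DataSource,
    // mirroring the WithDataSource pattern introduced above.
    type store struct{ ds sqlutil.DataSource }

    func newStore(ds sqlutil.DataSource) *store { return &store{ds: ds} }

    // WithDataSource returns a copy of the store backed by ds, e.g. a transaction.
    func (s *store) WithDataSource(ds sqlutil.DataSource) *store { return newStore(ds) }

    func (s *store) Save(ctx context.Context, v int) error {
        _, err := s.ds.ExecContext(ctx, `INSERT INTO demo_values (v) VALUES ($1)`, v)
        return err
    }

    // saveAndMark performs both writes atomically: the caller owns the transaction
    // and hands the tx-scoped DataSource to every participant.
    func saveAndMark(ctx context.Context, ds sqlutil.DataSource, s *store,
        mark func(context.Context, sqlutil.DataSource) error, v int) error {
        return sqlutil.TransactDataSource(ctx, ds, nil, func(tx sqlutil.DataSource) error {
            if err := s.WithDataSource(tx).Save(ctx, v); err != nil {
                return err
            }
            return mark(ctx, tx) // e.g. logBroadcaster.MarkConsumed(ctx, tx, lb)
        })
    }

This is also why the mock setup in the tracker tests above can simply return the mock itself from WithDataSource instead of replaying a transaction callback.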
func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.ServiceCtx, error) { @@ -171,7 +168,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi lV2Plus, chain, chain.ID(), - d.q, + d.ds, v2.NewCoordinatorV2_5(coordinatorV2Plus), batchCoordinatorV2, vrfOwner, @@ -225,7 +222,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi lV2, chain, chain.ID(), - d.q, + d.ds, v2.NewCoordinatorV2(coordinatorV2), batchCoordinatorV2, vrfOwner, @@ -246,7 +243,6 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi Cfg: chain.Config().EVM(), FeeCfg: chain.Config().EVM().GasEstimator(), L: logger.Sugared(lV1), - Q: d.q, Coordinator: coordinator, PipelineRunner: d.pr, GethKs: d.ks.Eth(), diff --git a/core/services/vrf/delegate_test.go b/core/services/vrf/delegate_test.go index d009641e65f..db9724179e7 100644 --- a/core/services/vrf/delegate_test.go +++ b/core/services/vrf/delegate_test.go @@ -78,7 +78,7 @@ func buildVrfUni(t *testing.T, db *sqlx.DB, cfg chainlink.GeneralConfig) vrfUniv hb := headtracker.NewHeadBroadcaster(lggr) // Don't mock db interactions - prm := pipeline.NewORM(db, lggr, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + prm := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) ks := keystore.NewInMemory(db, utils.FastScryptParams, lggr, cfg.Database()) _, dbConfig, evmConfig := txmgr.MakeTestConfigs(t) @@ -160,7 +160,6 @@ func setup(t *testing.T) (vrfUniverse, *v1.Listener, job.Job) { vuni.prm, vuni.legacyChains, logger.TestLogger(t), - cfg.Database(), mailMon) vs := testspecs.GenerateVRFSpec(testspecs.VRFSpecParams{PublicKey: vuni.vrfkey.PublicKey.String(), EVMChainID: testutils.FixtureChainID.String()}) jb, err := vrfcommon.ValidatedVRFSpec(vs.Toml()) @@ -201,9 +200,10 @@ func TestDelegate_ReorgAttackProtection(t *testing.T) { preSeed := common.BigToHash(big.NewInt(42)).Bytes() txHash := evmutils.NewHash() vuni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil).Maybe() - vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil).Maybe() + vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() vuni.ec.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(generateCallbackReturnValues(t, false), nil).Maybe() - listener.HandleLog(log.NewLogBroadcast(types.Log{ + ctx := testutils.Context(t) + listener.HandleLog(ctx, log.NewLogBroadcast(types.Log{ // Data has all the NON-indexed parameters Data: bytes.Join([][]byte{pk.MustHash().Bytes(), // key hash preSeed, // preSeed @@ -302,14 +302,15 @@ func TestDelegate_ValidLog(t *testing.T) { consumed := make(chan struct{}) for i, tc := range tt { tc := tc + ctx := testutils.Context(t) vuni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { consumed <- struct{}{} }).Return(nil).Once() // Expect a call to check if the req is already fulfilled. 
vuni.ec.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(generateCallbackReturnValues(t, false), nil) - listener.HandleLog(log.NewLogBroadcast(tc.log, vuni.cid, nil)) + listener.HandleLog(ctx, log.NewLogBroadcast(tc.log, vuni.cid, nil)) // Wait until the log is present waitForChannel(t, added, time.Second, "request not added to the queue") // Feed it a head which confirms it. @@ -318,7 +319,7 @@ func TestDelegate_ValidLog(t *testing.T) { // Ensure we created a successful run. waitForChannel(t, runComplete, 2*time.Second, "pipeline not complete") - runs, err := vuni.prm.GetAllRuns() + runs, err := vuni.prm.GetAllRuns(ctx) require.NoError(t, err) require.Equal(t, i+1, len(runs)) assert.False(t, runs[0].FatalErrors.HasError()) @@ -328,13 +329,13 @@ func TestDelegate_ValidLog(t *testing.T) { p, err := vuni.ks.VRF().GenerateProof(keyID, evmutils.MustHash(string(bytes.Join([][]byte{preSeed, bh.Bytes()}, []byte{}))).Big()) require.NoError(t, err) vuni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { consumed <- struct{}{} }).Return(nil).Once() // If we send a completed log we should the respCount increase var reqIDBytes []byte copy(reqIDBytes[:], tc.reqID[:]) - listener.HandleLog(log.NewLogBroadcast(types.Log{ + listener.HandleLog(ctx, log.NewLogBroadcast(types.Log{ // Data has all the NON-indexed parameters Data: bytes.Join([][]byte{reqIDBytes, // output p.Output.Bytes(), @@ -354,7 +355,7 @@ func TestDelegate_InvalidLog(t *testing.T) { vuni, listener, jb := setup(t) vuni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) done := make(chan struct{}) - vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { done <- struct{}{} }).Return(nil).Once() // Expect a call to check if the req is already fulfilled. @@ -365,7 +366,8 @@ func TestDelegate_InvalidLog(t *testing.T) { added <- struct{}{} }) // Send an invalid log (keyhash doesnt match) - listener.HandleLog(log.NewLogBroadcast(types.Log{ + ctx := testutils.Context(t) + listener.HandleLog(ctx, log.NewLogBroadcast(types.Log{ // Data has all the NON-indexed parameters Data: append(append(append(append( evmutils.NewHash().Bytes(), // key hash @@ -392,7 +394,7 @@ func TestDelegate_InvalidLog(t *testing.T) { waitForChannel(t, done, time.Second, "log not consumed") // Should create a run that errors in the vrf task - runs, err := vuni.prm.GetAllRuns() + runs, err := vuni.prm.GetAllRuns(ctx) require.NoError(t, err) require.Equal(t, len(runs), 1) for _, tr := range runs[0].PipelineTaskRuns { @@ -417,7 +419,7 @@ func TestFulfilledCheck(t *testing.T) { vuni, listener, jb := setup(t) vuni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) done := make(chan struct{}) - vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { done <- struct{}{} }).Return(nil).Once() // Expect a call to check if the req is already fulfilled. 
@@ -429,7 +431,8 @@ func TestFulfilledCheck(t *testing.T) { added <- struct{}{} }) // Send an invalid log (keyhash doesn't match) - listener.HandleLog(log.NewLogBroadcast( + ctx := testutils.Context(t) + listener.HandleLog(ctx, log.NewLogBroadcast( types.Log{ // Data has all the NON-indexed parameters Data: bytes.Join([][]byte{ @@ -455,7 +458,7 @@ func TestFulfilledCheck(t *testing.T) { waitForChannel(t, done, time.Second, "log not consumed") // Should consume the log with no run - runs, err := vuni.prm.GetAllRuns() + runs, err := vuni.prm.GetAllRuns(ctx) require.NoError(t, err) require.Equal(t, len(runs), 0) } @@ -685,7 +688,6 @@ func Test_VRFV2PlusServiceFailsWhenVRFOwnerProvided(t *testing.T) { vuni.prm, vuni.legacyChains, logger.TestLogger(t), - cfg.Database(), mailMon) chain, err := vuni.legacyChains.Get(testutils.FixtureChainID.String()) require.NoError(t, err) diff --git a/core/services/vrf/v1/integration_test.go b/core/services/vrf/v1/integration_test.go index f68700a8af7..1d11615950b 100644 --- a/core/services/vrf/v1/integration_test.go +++ b/core/services/vrf/v1/integration_test.go @@ -45,6 +45,7 @@ func TestIntegration_VRF_JPV2(t *testing.T) { for _, tt := range tests { test := tt t.Run(test.name, func(t *testing.T) { + ctx := testutils.Context(t) config, _ := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.EVM[0].GasEstimator.EIP1559DynamicFees = &test.eip1559 c.EVM[0].ChainID = (*ubig.Big)(testutils.SimulatedChainID) @@ -75,7 +76,7 @@ func TestIntegration_VRF_JPV2(t *testing.T) { } var runs []pipeline.Run gomega.NewWithT(t).Eventually(func() bool { - runs, err = app.PipelineORM().GetAllRuns() + runs, err = app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) // It possible that we send the test request // before the Job spawner has started the vrf services, which is fine @@ -128,6 +129,7 @@ func TestIntegration_VRF_JPV2(t *testing.T) { func TestIntegration_VRF_WithBHS(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) config, _ := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.EVM[0].GasEstimator.EIP1559DynamicFees = ptr(true) c.EVM[0].BlockBackfillDepth = ptr[uint32](500) @@ -196,7 +198,7 @@ func TestIntegration_VRF_WithBHS(t *testing.T) { var runs []pipeline.Run gomega.NewWithT(t).Eventually(func() bool { - runs, err = app.PipelineORM().GetAllRuns() + runs, err = app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) cu.Backend.Commit() return len(runs) == 1 && runs[0].State == pipeline.RunStatusCompleted diff --git a/core/services/vrf/v1/listener_v1.go b/core/services/vrf/v1/listener_v1.go index c57265634e5..ddf5779deb0 100644 --- a/core/services/vrf/v1/listener_v1.go +++ b/core/services/vrf/v1/listener_v1.go @@ -17,6 +17,7 @@ import ( "github.com/theodesp/go-heaps/pairing" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink-common/pkg/utils/mathutil" @@ -303,17 +304,16 @@ func (lsn *Listener) RunLogListener(unsubscribes []func(), minConfs uint32) { break } recovery.WrapRecover(lsn.L, func() { - lsn.handleLog(lb, minConfs) + ctx, cancel := lsn.ChStop.NewCtx() + defer cancel() + lsn.handleLog(ctx, lb, minConfs) }) } } } } -func (lsn *Listener) handleLog(lb log.Broadcast, minConfs uint32) { - ctx, cancel := lsn.ChStop.NewCtx() - defer cancel() - +func (lsn *Listener) handleLog(ctx context.Context, lb log.Broadcast, minConfs 
uint32) { lggr := lsn.L.With( "log", lb.String(), "decodedLog", lb.DecodedLog(), @@ -380,7 +380,7 @@ func (lsn *Listener) shouldProcessLog(ctx context.Context, lb log.Broadcast) boo } func (lsn *Listener) markLogAsConsumed(ctx context.Context, lb log.Broadcast) { - err := lsn.Chain.LogBroadcaster().MarkConsumed(ctx, lb) + err := lsn.Chain.LogBroadcaster().MarkConsumed(ctx, nil, lb) lsn.L.ErrorIf(err, fmt.Sprintf("Unable to mark log %v as consumed", lb.String())) } @@ -486,9 +486,10 @@ func (lsn *Listener) ProcessRequest(ctx context.Context, req request) bool { run := pipeline.NewRun(*lsn.Job.PipelineSpec, vars) // The VRF pipeline has no async tasks, so we don't need to check for `incomplete` - if _, err = lsn.PipelineRunner.Run(ctx, run, lggr, true, func(tx pg.Queryer) error { + if _, err = lsn.PipelineRunner.Run(ctx, run, lggr, true, func(tx sqlutil.DataSource) error { // Always mark consumed regardless of whether the proof failed or not. - if err = lsn.Chain.LogBroadcaster().MarkConsumed(ctx, req.lb); err != nil { + //TODO restore tx https://smartcontract-it.atlassian.net/browse/BCF-2978 + if err = lsn.Chain.LogBroadcaster().MarkConsumed(ctx, nil, req.lb); err != nil { lggr.Errorw("Failed mark consumed", "err", err) } return nil @@ -525,7 +526,7 @@ func (lsn *Listener) Close() error { }) } -func (lsn *Listener) HandleLog(lb log.Broadcast) { +func (lsn *Listener) HandleLog(ctx context.Context, lb log.Broadcast) { if !lsn.Deduper.ShouldDeliver(lb.RawLog()) { lsn.L.Tracew("skipping duplicate log broadcast", "log", lb.RawLog()) return diff --git a/core/services/vrf/v2/integration_helpers_test.go b/core/services/vrf/v2/integration_helpers_test.go index f19f39f03f2..3d7a94ae833 100644 --- a/core/services/vrf/v2/integration_helpers_test.go +++ b/core/services/vrf/v2/integration_helpers_test.go @@ -62,6 +62,7 @@ func testSingleConsumerHappyPath( rwfe v22.RandomWordsFulfilled, subID *big.Int), ) { + ctx := testutils.Context(t) key1 := cltest.MustGenerateRandomKey(t) key2 := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) @@ -87,7 +88,7 @@ func testSingleConsumerHappyPath( // Fund gas lanes. sendEth(t, ownerKey, uni.backend, key1.Address, 10) sendEth(t, ownerKey, uni.backend, key2.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job using key1 and key2 on the same gas lane. jbs := createVRFJobs( @@ -111,7 +112,7 @@ func testSingleConsumerHappyPath( // Wait for fulfillment to be queued. 
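The listener_v1 hunk above stops caching a context on the struct and instead threads one from the caller: the run loop derives it from the service's stop channel so in-flight log handling and DB writes are cancelled when the listener closes. A rough sketch of that loop shape, with the broadcast and queue types reduced to illustrative stand-ins; only services.StopChan.NewCtx is taken from the diff:

    package example

    import (
        "context"

        "github.com/smartcontractkit/chainlink-common/pkg/services"
    )

    // broadcast stands in for log.Broadcast; only what the sketch needs.
    type broadcast interface{ String() string }

    type listener struct {
        chStop services.StopChan
        queue  chan broadcast
    }

    // run drains the queue, handing each handler a context that is cancelled
    // when the listener stops, mirroring ChStop.NewCtx in the diff above.
    func (l *listener) run(handle func(context.Context, broadcast)) {
        for {
            select {
            case <-l.chStop:
                return
            case lb := <-l.queue:
                ctx, cancel := l.chStop.NewCtx()
                handle(ctx, lb)
                cancel()
            }
        }
    }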
gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -133,7 +134,7 @@ func testSingleConsumerHappyPath( requestID2, _ := requestRandomnessAndAssertRandomWordsRequestedEvent(t, consumerContract, consumer, keyHash, subID, numWords, 500_000, coordinator, uni.backend, nativePayment) gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 2 @@ -153,11 +154,11 @@ func testSingleConsumerHappyPath( assertNumRandomWords(t, consumerContract, numWords) // Assert that both send addresses were used to fulfill the requests - n, err := uni.backend.PendingNonceAt(testutils.Context(t), key1.Address) + n, err := uni.backend.PendingNonceAt(ctx, key1.Address) require.NoError(t, err) require.EqualValues(t, 1, n) - n, err = uni.backend.PendingNonceAt(testutils.Context(t), key2.Address) + n, err = uni.backend.PendingNonceAt(ctx, key2.Address) require.NoError(t, err) require.EqualValues(t, 1, n) @@ -182,6 +183,7 @@ func testMultipleConsumersNeedBHS( coordinator v22.CoordinatorV2_X, rwfe v22.RandomWordsFulfilled), ) { + ctx := testutils.Context(t) nConsumers := len(consumers) vrfKey := cltest.MustGenerateRandomKey(t) sendEth(t, ownerKey, uni.backend, vrfKey.Address, 10) @@ -216,7 +218,7 @@ func testMultipleConsumersNeedBHS( }) keys = append(keys, ownerKey, vrfKey) app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, keys...) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. vrfJobs := createVRFJobs( @@ -250,7 +252,7 @@ func testMultipleConsumersNeedBHS( // Ensure log poller is ready and has all logs. require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Ready()) - require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Replay(testutils.Context(t), 1)) + require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Replay(ctx, 1)) for i := 0; i < nConsumers; i++ { consumer := consumers[i] @@ -284,7 +286,7 @@ func testMultipleConsumersNeedBHS( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -320,6 +322,7 @@ func testMultipleConsumersNeedTrustedBHS( coordinator v22.CoordinatorV2_X, rwfe v22.RandomWordsFulfilled), ) { + ctx := testutils.Context(t) nConsumers := len(consumers) vrfKey := cltest.MustGenerateRandomKey(t) sendEth(t, ownerKey, uni.backend, vrfKey.Address, 10) @@ -364,7 +367,7 @@ func testMultipleConsumersNeedTrustedBHS( }) keys = append(keys, ownerKey, vrfKey) app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, keys...) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. vrfJobs := createVRFJobs( @@ -403,7 +406,7 @@ func testMultipleConsumersNeedTrustedBHS( // Ensure log poller is ready and has all logs. 
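Pipeline ORM reads such as GetAllRuns are now context-aware, so these integration helpers hoist a single testutils.Context(t) to the top of each test and reuse it inside the gomega polls. A condensed sketch of that polling idiom, with fetchRuns standing in for app.PipelineORM().GetAllRuns and the timeouts chosen arbitrarily:

    package example

    import (
        "context"
        "testing"
        "time"

        "github.com/onsi/gomega"
    )

    // waitForRuns polls a context-aware query until it reports the expected
    // number of pipeline runs, mirroring the GetAllRuns(ctx) loops above.
    func waitForRuns(t *testing.T, ctx context.Context, want int,
        fetchRuns func(context.Context) (int, error)) {
        gomega.NewWithT(t).Eventually(func() bool {
            n, err := fetchRuns(ctx)
            if err != nil {
                t.Log("fetch error:", err)
                return false
            }
            t.Log("runs", n)
            return n == want
        }, 30*time.Second, time.Second).Should(gomega.BeTrue())
    }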
chain := app.GetRelayers().LegacyEVMChains().Slice()[0] require.NoError(t, chain.LogPoller().Ready()) - require.NoError(t, chain.LogPoller().Replay(testutils.Context(t), 1)) + require.NoError(t, chain.LogPoller().Replay(ctx, 1)) for i := 0; i < nConsumers; i++ { consumer := consumers[i] @@ -445,7 +448,7 @@ func testMultipleConsumersNeedTrustedBHS( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -534,6 +537,7 @@ func testSingleConsumerHappyPathBatchFulfillment( rwfe v22.RandomWordsFulfilled, subID *big.Int), ) { + ctx := testutils.Context(t) key1 := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) config, db := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -555,7 +559,7 @@ func testSingleConsumerHappyPathBatchFulfillment( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job using key1 and key2 on the same gas lane. jbs := createVRFJobs( @@ -590,7 +594,7 @@ func testSingleConsumerHappyPathBatchFulfillment( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) if bigGasCallback { @@ -640,6 +644,7 @@ func testSingleConsumerNeedsTopUp( coordinator v22.CoordinatorV2_X, rwfe v22.RandomWordsFulfilled), ) { + ctx := testutils.Context(t) key := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(1000) config, db := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -659,7 +664,7 @@ func testSingleConsumerNeedsTopUp( // Fund expensive gas lane. sendEth(t, ownerKey, uni.backend, key.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -682,7 +687,7 @@ func testSingleConsumerNeedsTopUp( // Fulfillment will not be enqueued because subscriber doesn't have enough LINK. gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 1", "runs", len(runs)) return len(runs) == 0 @@ -695,7 +700,7 @@ func testSingleConsumerNeedsTopUp( // Wait for fulfillment to go through. gomega.NewWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 2", "runs", len(runs)) return len(runs) == 1 @@ -737,6 +742,7 @@ func testBlockHeaderFeeder( coordinator v22.CoordinatorV2_X, rwfe v22.RandomWordsFulfilled), ) { + ctx := testutils.Context(t) nConsumers := len(consumers) vrfKey := cltest.MustGenerateRandomKey(t) @@ -760,7 +766,7 @@ func testBlockHeaderFeeder( c.EVM[0].FinalityDepth = ptr[uint32](2) }) app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, ownerKey, vrfKey, bhfKey) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. 
vrfJobs := createVRFJobs( @@ -792,7 +798,7 @@ func testBlockHeaderFeeder( // Ensure log poller is ready and has all logs. require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Ready()) - require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Replay(testutils.Context(t), 1)) + require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Replay(ctx, 1)) for i := 0; i < nConsumers; i++ { consumer := consumers[i] @@ -821,7 +827,7 @@ func testBlockHeaderFeeder( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -900,6 +906,7 @@ func testSingleConsumerForcedFulfillment( batchEnabled bool, vrfVersion vrfcommon.Version, ) { + ctx := testutils.Context(t) key1 := cltest.MustGenerateRandomKey(t) key2 := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) @@ -951,7 +958,7 @@ func testSingleConsumerForcedFulfillment( // Fund gas lanes. sendEth(t, ownerKey, uni.backend, key1.Address, 10) sendEth(t, ownerKey, uni.backend, key2.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job using key1 and key2 on the same gas lane. jbs := createVRFJobs( @@ -1065,6 +1072,7 @@ func testSingleConsumerEIP150( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) callBackGasLimit := int64(2_500_000) // base callback gas. key1 := cltest.MustGenerateRandomKey(t) @@ -1090,7 +1098,7 @@ func testSingleConsumerEIP150( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1114,7 +1122,7 @@ func testSingleConsumerEIP150( // Wait for simulation to pass. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1132,6 +1140,7 @@ func testSingleConsumerEIP150Revert( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) callBackGasLimit := int64(2_500_000) // base callback gas. eip150Fee := int64(0) // no premium given for callWithExactGas coordinatorFulfillmentOverhead := int64(90_000) // fixed gas used in coordinator fulfillment @@ -1160,7 +1169,7 @@ func testSingleConsumerEIP150Revert( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1184,7 +1193,7 @@ func testSingleConsumerEIP150Revert( // Simulation should not pass. 
gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 0 @@ -1202,6 +1211,7 @@ func testSingleConsumerBigGasCallbackSandwich( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) key1 := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(100) config, db := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -1224,7 +1234,7 @@ func testSingleConsumerBigGasCallbackSandwich( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1253,7 +1263,7 @@ func testSingleConsumerBigGasCallbackSandwich( // Assert that we've completed 0 runs before adding 3 new requests. { - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) assert.Equal(t, 0, len(runs)) assert.Equal(t, 3, len(reqIDs)) @@ -1262,7 +1272,7 @@ func testSingleConsumerBigGasCallbackSandwich( // Wait for the 50_000 gas randomness request to be enqueued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1271,7 +1281,7 @@ func testSingleConsumerBigGasCallbackSandwich( // After the first successful request, no more will be enqueued. gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 1", "runs", len(runs)) return len(runs) == 1 @@ -1285,7 +1295,7 @@ func testSingleConsumerBigGasCallbackSandwich( // Assert that we've still only completed 1 run before adding new requests. { - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) assert.Equal(t, 1, len(runs)) } @@ -1300,7 +1310,7 @@ func testSingleConsumerBigGasCallbackSandwich( // Fulfillment will not be enqueued because subscriber doesn't have enough LINK for any of the requests. gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 1", "runs", len(runs)) return len(runs) == 1 @@ -1318,6 +1328,7 @@ func testSingleConsumerMultipleGasLanes( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) cheapKey := cltest.MustGenerateRandomKey(t) expensiveKey := cltest.MustGenerateRandomKey(t) cheapGasLane := assets.GWei(10) @@ -1349,7 +1360,7 @@ func testSingleConsumerMultipleGasLanes( // Fund gas lanes. sendEth(t, ownerKey, uni.backend, cheapKey.Address, 10) sendEth(t, ownerKey, uni.backend, expensiveKey.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF jobs. jbs := createVRFJobs( @@ -1374,7 +1385,7 @@ func testSingleConsumerMultipleGasLanes( // Wait for fulfillment to be queued for cheap key hash. 
gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 1", "runs", len(runs)) return len(runs) == 1 @@ -1394,7 +1405,7 @@ func testSingleConsumerMultipleGasLanes( // We should not have any new fulfillments until a top up. gomega.NewWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 2", "runs", len(runs)) return len(runs) == 1 @@ -1406,7 +1417,7 @@ func testSingleConsumerMultipleGasLanes( // Wait for fulfillment to be queued for expensive key hash. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 1", "runs", len(runs)) return len(runs) == 2 @@ -1442,6 +1453,7 @@ func testSingleConsumerAlwaysRevertingCallbackStillFulfilled( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) key := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) config, db := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -1464,7 +1476,7 @@ func testSingleConsumerAlwaysRevertingCallbackStillFulfilled( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1488,7 +1500,7 @@ func testSingleConsumerAlwaysRevertingCallbackStillFulfilled( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1511,6 +1523,7 @@ func testConsumerProxyHappyPath( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) key1 := cltest.MustGenerateRandomKey(t) key2 := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) @@ -1540,7 +1553,7 @@ func testConsumerProxyHappyPath( // Create gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) sendEth(t, ownerKey, uni.backend, key2.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job using key1 and key2 on the same gas lane. jbs := createVRFJobs( @@ -1565,7 +1578,7 @@ func testConsumerProxyHappyPath( // Wait for fulfillment to be queued. 
gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1591,7 +1604,7 @@ func testConsumerProxyHappyPath( t, consumerContract, consumerOwner, keyHash, subID, numWords, 750_000, uni.rootContract, uni.backend, nativePayment) gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 2 @@ -1603,11 +1616,11 @@ func testConsumerProxyHappyPath( assertNumRandomWords(t, consumerContract, numWords) // Assert that both send addresses were used to fulfill the requests - n, err := uni.backend.PendingNonceAt(testutils.Context(t), key1.Address) + n, err := uni.backend.PendingNonceAt(ctx, key1.Address) require.NoError(t, err) require.EqualValues(t, 1, n) - n, err = uni.backend.PendingNonceAt(testutils.Context(t), key2.Address) + n, err = uni.backend.PendingNonceAt(ctx, key2.Address) require.NoError(t, err) require.EqualValues(t, 1, n) @@ -1644,6 +1657,7 @@ func testMaliciousConsumer( batchEnabled bool, vrfVersion vrfcommon.Version, ) { + ctx := testutils.Context(t) config, _ := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.EVM[0].GasEstimator.LimitDefault = ptr[uint64](2_000_000) c.EVM[0].GasEstimator.PriceMax = assets.GWei(1) @@ -1656,7 +1670,7 @@ func testMaliciousConsumer( carol := uni.vrfConsumers[0] app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, ownerKey) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) err := app.GetKeyStore().Unlock(cltest.Password) require.NoError(t, err) @@ -1702,7 +1716,7 @@ func testMaliciousConsumer( // by the node. var attempts []txmgr.TxAttempt gomega.NewWithT(t).Eventually(func() bool { - attempts, _, err = app.TxmStorageService().TxAttempts(testutils.Context(t), 0, 1000) + attempts, _, err = app.TxmStorageService().TxAttempts(ctx, 0, 1000) require.NoError(t, err) // It possible that we send the test request // before the job spawner has started the vrf services, which is fine @@ -1716,7 +1730,7 @@ func testMaliciousConsumer( // The fulfillment tx should succeed ch, err := app.GetRelayers().LegacyEVMChains().Get(evmtest.MustGetDefaultChainID(t, config.EVMConfigs()).String()) require.NoError(t, err) - r, err := ch.Client().TransactionReceipt(testutils.Context(t), attempts[0].Hash) + r, err := ch.Client().TransactionReceipt(ctx, attempts[0].Hash) require.NoError(t, err) require.Equal(t, uint64(1), r.Status) @@ -1759,6 +1773,7 @@ func testReplayOldRequestsOnStartUp( rwfe v22.RandomWordsFulfilled, subID *big.Int), ) { + ctx := testutils.Context(t) sendingKey := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) config, _ := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -1778,7 +1793,7 @@ func testReplayOldRequestsOnStartUp( // Fund gas lanes. 
sendEth(t, ownerKey, uni.backend, sendingKey.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF Key, register it to coordinator and export vrfkey, err := app.GetKeyStore().VRF().Create() @@ -1816,7 +1831,7 @@ func testReplayOldRequestsOnStartUp( // Start a new app and create VRF job using the same VRF key created above app = cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, ownerKey, sendingKey) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) vrfKey, err := app.GetKeyStore().VRF().Import(encodedVrfKey, testutils.Password) require.NoError(t, err) @@ -1863,7 +1878,7 @@ func testReplayOldRequestsOnStartUp( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 diff --git a/core/services/vrf/v2/integration_v2_plus_test.go b/core/services/vrf/v2/integration_v2_plus_test.go index bfec76afec3..742ff99071c 100644 --- a/core/services/vrf/v2/integration_v2_plus_test.go +++ b/core/services/vrf/v2/integration_v2_plus_test.go @@ -1141,6 +1141,7 @@ func setupSubscriptionAndFund( func TestVRFV2PlusIntegration_Migration(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) ownerKey := cltest.MustGenerateRandomKey(t) uni := newVRFCoordinatorV2PlusUniverse(t, ownerKey, 1, false) key1 := cltest.MustGenerateRandomKey(t) @@ -1200,7 +1201,7 @@ func TestVRFV2PlusIntegration_Migration(t *testing.T) { // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 diff --git a/core/services/vrf/v2/integration_v2_test.go b/core/services/vrf/v2/integration_v2_test.go index 1a7c15a2508..0c81c3faca5 100644 --- a/core/services/vrf/v2/integration_v2_test.go +++ b/core/services/vrf/v2/integration_v2_test.go @@ -73,7 +73,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/keystore" "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ethkey" "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/vrfkey" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" evmrelay "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" "github.com/smartcontractkit/chainlink/v2/core/services/signatures/secp256k1" @@ -466,7 +465,8 @@ func deployOldCoordinator( // Send eth from prefunded account. // Amount is number of ETH not wei. 
func sendEth(t *testing.T, key ethkey.KeyV2, ec *backends.SimulatedBackend, to common.Address, eth int) { - nonce, err := ec.PendingNonceAt(testutils.Context(t), key.Address) + ctx := testutils.Context(t) + nonce, err := ec.PendingNonceAt(ctx, key.Address) require.NoError(t, err) tx := gethtypes.NewTx(&gethtypes.DynamicFeeTx{ ChainID: testutils.SimulatedChainID, @@ -480,7 +480,7 @@ func sendEth(t *testing.T, key ethkey.KeyV2, ec *backends.SimulatedBackend, to c }) signedTx, err := gethtypes.SignTx(tx, gethtypes.NewLondonSigner(testutils.SimulatedChainID), key.ToEcdsaPrivKey()) require.NoError(t, err) - err = ec.SendTransaction(testutils.Context(t), signedTx) + err = ec.SendTransaction(ctx, signedTx) require.NoError(t, err) ec.Commit() } @@ -996,7 +996,9 @@ func testEoa( batchingEnabled bool, batchCoordinatorAddress common.Address, vrfOwnerAddress *common.Address, - vrfVersion vrfcommon.Version) { + vrfVersion vrfcommon.Version, +) { + ctx := testutils.Context(t) gasLimit := int64(2_500_000) finalityDepth := uint32(50) @@ -1030,7 +1032,7 @@ func testEoa( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1059,7 +1061,7 @@ func testEoa( // Ensure request is not fulfilled. gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 0 @@ -1069,10 +1071,9 @@ func testEoa( var broadcastsBeforeFinality []evmlogger.LogBroadcast var broadcastsAfterFinality []evmlogger.LogBroadcast query := `SELECT block_hash, consumed, log_index, job_id FROM log_broadcasts` - q := pg.NewQ(app.GetSqlxDB(), app.Logger, app.Config.Database()) // Execute the query. - require.NoError(t, q.Select(&broadcastsBeforeFinality, query)) + require.NoError(t, app.GetDB().SelectContext(ctx, &broadcastsBeforeFinality, query)) // Ensure there is only one log broadcast (our EOA request), and that // it hasn't been marked as consumed yet. @@ -1087,14 +1088,14 @@ func testEoa( // Ensure the request is still not fulfilled. gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 0 }, 5*time.Second, time.Second).Should(gomega.BeTrue()) // Execute the query for log broadcasts again after finality depth has elapsed. - require.NoError(t, q.Select(&broadcastsAfterFinality, query)) + require.NoError(t, app.GetDB().SelectContext(ctx, &broadcastsAfterFinality, query)) // Ensure that there is still only one log broadcast (our EOA request), but that // it has been marked as "consumed," such that it won't be retried. @@ -1158,6 +1159,7 @@ func deployWrapper(t *testing.T, uni coordinatorV2UniverseCommon, wrapperOverhea func TestVRFV2Integration_SingleConsumer_Wrapper(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) wrapperOverhead := uint32(30_000) coordinatorOverhead := uint32(90_000) @@ -1179,7 +1181,7 @@ func TestVRFV2Integration_SingleConsumer_Wrapper(t *testing.T) { // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. 
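The testEoa hunk above drops the pg.NewQ wrapper and issues the raw read through the application's data source with SelectContext. A small sketch of that query style; the row struct, its tags, and the loadBroadcasts name are illustrative assumptions, while SelectContext on sqlutil.DataSource is taken from the diff:

    package example

    import (
        "context"

        "github.com/smartcontractkit/chainlink-common/pkg/sqlutil"
    )

    // logBroadcastRow mirrors the columns selected in the test above; field
    // types and db tags are assumptions for illustration only.
    type logBroadcastRow struct {
        BlockHash []byte `db:"block_hash"`
        Consumed  bool   `db:"consumed"`
        LogIndex  int64  `db:"log_index"`
        JobID     int32  `db:"job_id"`
    }

    // loadBroadcasts reads rows with the context-aware SelectContext call that
    // replaces q.Select in the diff.
    func loadBroadcasts(ctx context.Context, ds sqlutil.DataSource) ([]logBroadcastRow, error) {
        var rows []logBroadcastRow
        err := ds.SelectContext(ctx, &rows,
            `SELECT block_hash, consumed, log_index, job_id FROM log_broadcasts`)
        return rows, err
    }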
jbs := createVRFJobs( @@ -1221,7 +1223,7 @@ func TestVRFV2Integration_SingleConsumer_Wrapper(t *testing.T) { // Wait for simulation to pass. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err2 := app.PipelineORM().GetAllRuns() + runs, err2 := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err2) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1238,6 +1240,7 @@ func TestVRFV2Integration_SingleConsumer_Wrapper(t *testing.T) { func TestVRFV2Integration_Wrapper_High_Gas(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) wrapperOverhead := uint32(30_000) coordinatorOverhead := uint32(90_000) @@ -1261,7 +1264,7 @@ func TestVRFV2Integration_Wrapper_High_Gas(t *testing.T) { // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1303,7 +1306,7 @@ func TestVRFV2Integration_Wrapper_High_Gas(t *testing.T) { // Wait for simulation to pass. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err2 := app.PipelineORM().GetAllRuns() + runs, err2 := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err2) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1631,6 +1634,7 @@ func TestSimpleConsumerExample(t *testing.T) { func TestIntegrationVRFV2(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) // Reconfigure the sim chain with a default gas price of 1 gwei, // max gas limit of 2M and a key specific max 10 gwei price. // Keep the prices low so we can operate with small link balance subscriptions. @@ -1650,11 +1654,11 @@ func TestIntegrationVRFV2(t *testing.T) { carolContractAddress := uni.consumerContractAddresses[0] app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, key) - keys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.Context(t), testutils.SimulatedChainID) + keys, err := app.KeyStore.Eth().EnabledKeysForChain(ctx, testutils.SimulatedChainID) require.NoError(t, err) require.Zero(t, key.Cmp(keys[0])) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) var chain legacyevm.Chain chain, err = app.GetRelayers().LegacyEVMChains().Get(testutils.SimulatedChainID.String()) require.NoError(t, err) @@ -1723,7 +1727,7 @@ func TestIntegrationVRFV2(t *testing.T) { // by the node. var runs []pipeline.Run gomega.NewWithT(t).Eventually(func() bool { - runs, err = app.PipelineORM().GetAllRuns() + runs, err = app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) // It is possible that we send the test request // before the job spawner has started the vrf services, which is fine @@ -1745,7 +1749,7 @@ func TestIntegrationVRFV2(t *testing.T) { return len(rf) == 1 }, testutils.WaitTimeout(t), 500*time.Millisecond).Should(gomega.BeTrue()) assert.True(t, rf[0].Success(), "expected callback to succeed") - fulfillReceipt, err := uni.backend.TransactionReceipt(testutils.Context(t), rf[0].Raw().TxHash) + fulfillReceipt, err := uni.backend.TransactionReceipt(ctx, rf[0].Raw().TxHash) require.NoError(t, err) // Assert all the random words received by the consumer are different and non-zero. 
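The test hunks above drop the `pg.NewQ`/`q.Select` helper in favor of calling `SelectContext` directly on the application's data source (`app.GetDB().SelectContext(ctx, ...)`), with a single `ctx := testutils.Context(t)` threaded through each test. A minimal, self-contained sketch of that call shape follows; the `dataSource` interface mirrors the sqlx-style method that `sqlutil.DataSource` exposes, while `broadcastRow`, `queryLogBroadcasts`, and the struct field types are illustrative names only, not part of this change.

```go
package example

import "context"

// dataSource captures the single method this sketch needs; sqlutil.DataSource
// exposes the same sqlx-style SelectContext signature.
type dataSource interface {
	SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
}

// broadcastRow mirrors the columns selected by the log_broadcasts query used
// in testEoa above; the Go types are assumptions for illustration.
type broadcastRow struct {
	BlockHash []byte `db:"block_hash"`
	Consumed  bool   `db:"consumed"`
	LogIndex  int64  `db:"log_index"`
	JobID     int32  `db:"job_id"`
}

// queryLogBroadcasts shows the post-change call shape: the caller passes a
// context explicitly instead of wrapping the DB in a pg.Q that carried its own
// logger, config, and default timeout.
func queryLogBroadcasts(ctx context.Context, ds dataSource) ([]broadcastRow, error) {
	var rows []broadcastRow
	const q = `SELECT block_hash, consumed, log_index, job_id FROM log_broadcasts`
	err := ds.SelectContext(ctx, &rows, q)
	return rows, err
}
```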
@@ -1813,7 +1817,7 @@ func TestIntegrationVRFV2(t *testing.T) { // We should see the response count present require.NoError(t, err) var counts map[string]uint64 - counts, err = listenerV2.GetStartingResponseCountsV2(testutils.Context(t)) + counts, err = listenerV2.GetStartingResponseCountsV2(ctx) require.NoError(t, err) t.Log(counts, rf[0].RequestID().String()) assert.Equal(t, uint64(1), counts[rf[0].RequestID().String()]) diff --git a/core/services/vrf/v2/listener_v2.go b/core/services/vrf/v2/listener_v2.go index 71c6e72a06f..e820cff63b7 100644 --- a/core/services/vrf/v2/listener_v2.go +++ b/core/services/vrf/v2/listener_v2.go @@ -14,6 +14,7 @@ import ( "github.com/theodesp/go-heaps/pairing" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" txmgrtypes "github.com/smartcontractkit/chainlink/v2/common/txmgr/types" @@ -29,7 +30,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/vrf/vrfcommon" ) @@ -70,7 +70,7 @@ func New( l logger.Logger, chain legacyevm.Chain, chainID *big.Int, - q pg.Q, + ds sqlutil.DataSource, coordinator CoordinatorV2_X, batchCoordinator batch_vrf_coordinator_v2.BatchVRFCoordinatorV2Interface, vrfOwner vrf_owner.VRFOwnerInterface, @@ -93,7 +93,7 @@ func New( vrfOwner: vrfOwner, pipelineRunner: pipelineRunner, job: job, - q: q, + ds: ds, gethks: gethks, chStop: make(chan struct{}), reqAdded: reqAdded, @@ -120,7 +120,7 @@ type listenerV2 struct { pipelineRunner pipeline.Runner job job.Job - q pg.Q + ds sqlutil.DataSource gethks keystore.Eth chStop services.StopChan diff --git a/core/services/vrf/v2/listener_v2_log_processor.go b/core/services/vrf/v2/listener_v2_log_processor.go index db84fb47e3e..673f8618c0b 100644 --- a/core/services/vrf/v2/listener_v2_log_processor.go +++ b/core/services/vrf/v2/listener_v2_log_processor.go @@ -20,6 +20,7 @@ import ( "github.com/pkg/errors" "go.uber.org/multierr" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/hex" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" txmgrtypes "github.com/smartcontractkit/chainlink/v2/common/txmgr/types" @@ -28,7 +29,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/vrf_coordinator_v2" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/vrf_coordinator_v2plus_interface" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/vrf/vrfcommon" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -565,55 +565,53 @@ func (lsn *listenerV2) enqueueForceFulfillment( } // fulfill the request through the VRF owner - err = lsn.q.Transaction(func(tx pg.Queryer) error { - lsn.l.Infow("VRFOwner.fulfillRandomWords vs. 
VRFCoordinatorV2.fulfillRandomWords", - "vrf_owner.fulfillRandomWords", hexutil.Encode(vrfOwnerABI.Methods["fulfillRandomWords"].ID), - "vrf_coordinator_v2.fulfillRandomWords", hexutil.Encode(coordinatorV2ABI.Methods["fulfillRandomWords"].ID), - ) + lsn.l.Infow("VRFOwner.fulfillRandomWords vs. VRFCoordinatorV2.fulfillRandomWords", + "vrf_owner.fulfillRandomWords", hexutil.Encode(vrfOwnerABI.Methods["fulfillRandomWords"].ID), + "vrf_coordinator_v2.fulfillRandomWords", hexutil.Encode(coordinatorV2ABI.Methods["fulfillRandomWords"].ID), + ) - vrfOwnerAddress1 := lsn.vrfOwner.Address() - vrfOwnerAddressSpec := lsn.job.VRFSpec.VRFOwnerAddress.Address() - lsn.l.Infow("addresses diff", "wrapper_address", vrfOwnerAddress1, "spec_address", vrfOwnerAddressSpec) + vrfOwnerAddress1 := lsn.vrfOwner.Address() + vrfOwnerAddressSpec := lsn.job.VRFSpec.VRFOwnerAddress.Address() + lsn.l.Infow("addresses diff", "wrapper_address", vrfOwnerAddress1, "spec_address", vrfOwnerAddressSpec) - lsn.l.Infow("fulfillRandomWords payload", "proof", p.proof, "commitment", p.reqCommitment.Get(), "payload", p.payload) - txData := hexutil.MustDecode(p.payload) - if err != nil { - return fmt.Errorf("abi pack VRFOwner.fulfillRandomWords: %w", err) - } - estimateGasLimit, err := lsn.chain.Client().EstimateGas(ctx, ethereum.CallMsg{ - From: fromAddress, - To: &vrfOwnerAddressSpec, - Data: txData, - }) - if err != nil { - return fmt.Errorf("failed to estimate gas on VRFOwner.fulfillRandomWords: %w", err) - } + lsn.l.Infow("fulfillRandomWords payload", "proof", p.proof, "commitment", p.reqCommitment.Get(), "payload", p.payload) + txData := hexutil.MustDecode(p.payload) + if err != nil { + err = fmt.Errorf("abi pack VRFOwner.fulfillRandomWords: %w", err) + return + } + estimateGasLimit, err := lsn.chain.Client().EstimateGas(ctx, ethereum.CallMsg{ + From: fromAddress, + To: &vrfOwnerAddressSpec, + Data: txData, + }) + if err != nil { + err = fmt.Errorf("failed to estimate gas on VRFOwner.fulfillRandomWords: %w", err) + return + } - lsn.l.Infow("Estimated gas limit on force fulfillment", - "estimateGasLimit", estimateGasLimit, "pipelineGasLimit", p.gasLimit) - if estimateGasLimit < p.gasLimit { - estimateGasLimit = p.gasLimit - } + lsn.l.Infow("Estimated gas limit on force fulfillment", + "estimateGasLimit", estimateGasLimit, "pipelineGasLimit", p.gasLimit) + if estimateGasLimit < p.gasLimit { + estimateGasLimit = p.gasLimit + } - requestID := common.BytesToHash(p.req.req.RequestID().Bytes()) - subID := p.req.req.SubID() - requestTxHash := p.req.req.Raw().TxHash - etx, err = lsn.chain.TxManager().CreateTransaction(ctx, txmgr.TxRequest{ - FromAddress: fromAddress, - ToAddress: lsn.vrfOwner.Address(), - EncodedPayload: txData, - FeeLimit: estimateGasLimit, - Strategy: txmgrcommon.NewSendEveryStrategy(), - Meta: &txmgr.TxMeta{ - RequestID: &requestID, - SubID: ptr(subID.Uint64()), - RequestTxHash: &requestTxHash, - // No max link since simulation failed - }, - }) - return err + requestID := common.BytesToHash(p.req.req.RequestID().Bytes()) + subID := p.req.req.SubID() + requestTxHash := p.req.req.Raw().TxHash + return lsn.chain.TxManager().CreateTransaction(ctx, txmgr.TxRequest{ + FromAddress: fromAddress, + ToAddress: lsn.vrfOwner.Address(), + EncodedPayload: txData, + FeeLimit: estimateGasLimit, + Strategy: txmgrcommon.NewSendEveryStrategy(), + Meta: &txmgr.TxMeta{ + RequestID: &requestID, + SubID: ptr(subID.Uint64()), + RequestTxHash: &requestTxHash, + // No max link since simulation failed + }, }) - return } // For an errored 
pipeline run, wait until the finality depth of the chain to have elapsed, @@ -786,8 +784,8 @@ func (lsn *listenerV2) processRequestsPerSubHelper( ll.Infow("Enqueuing fulfillment") var transaction txmgr.Tx - err = lsn.q.Transaction(func(tx pg.Queryer) error { - if err = lsn.pipelineRunner.InsertFinishedRun(p.run, true, pg.WithQueryer(tx)); err != nil { + err = sqlutil.TransactDataSource(ctx, lsn.ds, nil, func(tx sqlutil.DataSource) error { + if err = lsn.pipelineRunner.InsertFinishedRun(ctx, tx, p.run, true); err != nil { return err } diff --git a/core/services/vrf/v2/listener_v2_types.go b/core/services/vrf/v2/listener_v2_types.go index f10297f31a9..c7dc45bb3bd 100644 --- a/core/services/vrf/v2/listener_v2_types.go +++ b/core/services/vrf/v2/listener_v2_types.go @@ -8,10 +8,10 @@ import ( "github.com/ethereum/go-ethereum/common" heaps "github.com/theodesp/go-heaps" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/vrf/vrfcommon" ) @@ -222,8 +222,8 @@ func (lsn *listenerV2) processBatch( ) ll.Info("Enqueuing batch fulfillment") var ethTX txmgr.Tx - err = lsn.q.Transaction(func(tx pg.Queryer) error { - if err = lsn.pipelineRunner.InsertFinishedRuns(batch.runs, true, pg.WithQueryer(tx)); err != nil { + err = sqlutil.TransactDataSource(ctx, lsn.ds, nil, func(tx sqlutil.DataSource) error { + if err = lsn.pipelineRunner.InsertFinishedRuns(ctx, tx, batch.runs, true); err != nil { return fmt.Errorf("inserting finished pipeline runs: %w", err) } diff --git a/core/services/vrf/v2/reverted_txns.go b/core/services/vrf/v2/reverted_txns.go index d2f62fbf271..cfd9954a208 100644 --- a/core/services/vrf/v2/reverted_txns.go +++ b/core/services/vrf/v2/reverted_txns.go @@ -17,13 +17,13 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/pkg/errors" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" evmutils "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/vrf_coordinator_v2" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -71,15 +71,15 @@ func (lsn *listenerV2) handleRevertedTxns(ctx context.Context, pollPeriod time.D lsn.l.Infow("Handling reverted txns") // Fetch recent single and batch txns, that have not been force-fulfilled - recentSingleTxns, err := lsn.fetchRecentSingleTxns(ctx, lsn.q, lsn.chainID.Uint64(), pollPeriod) + recentSingleTxns, err := lsn.fetchRecentSingleTxns(ctx, lsn.ds, lsn.chainID.Uint64(), pollPeriod) if err != nil { lsn.l.Fatalw("Fetch recent txns", "err", err) } - recentBatchTxns, err := lsn.fetchRecentBatchTxns(ctx, lsn.q, lsn.chainID.Uint64(), pollPeriod) + recentBatchTxns, err := lsn.fetchRecentBatchTxns(ctx, lsn.ds, lsn.chainID.Uint64(), pollPeriod) if err != nil { lsn.l.Fatalw("Fetch recent batch txns", "err", err) } - recentForceFulfillmentTxns, 
err := lsn.fetchRevertedForceFulfilmentTxns(ctx, lsn.q, lsn.chainID.Uint64(), pollPeriod) + recentForceFulfillmentTxns, err := lsn.fetchRevertedForceFulfilmentTxns(ctx, lsn.ds, lsn.chainID.Uint64(), pollPeriod) if err != nil { lsn.l.Fatalw("Fetch recent reverted force-fulfillment txns", "err", err) } @@ -108,7 +108,7 @@ func (lsn *listenerV2) handleRevertedTxns(ctx context.Context, pollPeriod time.D } func (lsn *listenerV2) fetchRecentSingleTxns(ctx context.Context, - q pg.Q, + ds sqlutil.DataSource, chainID uint64, pollPeriod time.Duration) ([]TxnReceiptDB, error) { @@ -155,7 +155,7 @@ func (lsn *listenerV2) fetchRecentSingleTxns(ctx context.Context, var recentReceipts []TxnReceiptDB before := time.Now() - err := q.Select(&recentReceipts, sqlQuery, chainID) + err := ds.SelectContext(ctx, &recentReceipts, sqlQuery, chainID) lsn.postSqlLog(ctx, before, pollPeriod, "FetchRecentSingleTxns") if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "Error fetching recent non-force-fulfilled txns") @@ -172,7 +172,7 @@ func (lsn *listenerV2) fetchRecentSingleTxns(ctx context.Context, } func (lsn *listenerV2) fetchRecentBatchTxns(ctx context.Context, - q pg.Q, + ds sqlutil.DataSource, chainID uint64, pollPeriod time.Duration) ([]TxnReceiptDB, error) { sqlQuery := fmt.Sprintf(` @@ -217,7 +217,7 @@ func (lsn *listenerV2) fetchRecentBatchTxns(ctx context.Context, var recentReceipts []TxnReceiptDB before := time.Now() - err := q.Select(&recentReceipts, sqlQuery, chainID) + err := ds.SelectContext(ctx, &recentReceipts, sqlQuery, chainID) lsn.postSqlLog(ctx, before, pollPeriod, "FetchRecentBatchTxns") if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "Error fetching recent non-force-fulfilled txns") @@ -231,7 +231,7 @@ func (lsn *listenerV2) fetchRecentBatchTxns(ctx context.Context, } func (lsn *listenerV2) fetchRevertedForceFulfilmentTxns(ctx context.Context, - q pg.Q, + ds sqlutil.DataSource, chainID uint64, pollPeriod time.Duration) ([]TxnReceiptDB, error) { @@ -271,7 +271,7 @@ func (lsn *listenerV2) fetchRevertedForceFulfilmentTxns(ctx context.Context, var recentReceipts []TxnReceiptDB before := time.Now() - err := q.Select(&recentReceipts, sqlQuery, chainID) + err := ds.SelectContext(ctx, &recentReceipts, sqlQuery, chainID) lsn.postSqlLog(ctx, before, pollPeriod, "FetchRevertedForceFulfilmentTxns") if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "Error fetching recent reverted force-fulfilled txns") @@ -300,7 +300,7 @@ func (lsn *listenerV2) fetchRevertedForceFulfilmentTxns(ctx context.Context, `, ReqScanTimeRangeInDB) var allReceipts []TxnReceiptDB before = time.Now() - err = q.Select(&allReceipts, sqlQueryAll, chainID) + err = ds.SelectContext(ctx, &allReceipts, sqlQueryAll, chainID) lsn.postSqlLog(ctx, before, pollPeriod, "Fetch all ForceFulfilment Txns") if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "Error fetching all recent force-fulfilled txns") @@ -389,9 +389,10 @@ func (lsn *listenerV2) postSqlLog(ctx context.Context, begin time.Time, pollPeri lsn.l.Debugw("SQL context canceled", "ms", elapsed.Milliseconds(), "err", ctx.Err(), "sql", queryName) } - timeout := lsn.q.QueryTimeout - if timeout <= 0 { - timeout = pollPeriod + timeout := pollPeriod + deadline, ok := ctx.Deadline() + if ok { + timeout = deadline.Sub(begin) } pct := float64(elapsed) / float64(timeout) diff --git a/core/services/webhook/delegate.go b/core/services/webhook/delegate.go index 0c08e992f32..690ae38d088 
100644 --- a/core/services/webhook/delegate.go +++ b/core/services/webhook/delegate.go @@ -13,7 +13,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -74,7 +73,7 @@ func (d *Delegate) BeforeJobDeleted(spec job.Job) { ) } } -func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) ([]job.ServiceCtx, error) { diff --git a/core/services/workflows/delegate.go b/core/services/workflows/delegate.go index 6db39d52dd6..dedf53e369b 100644 --- a/core/services/workflows/delegate.go +++ b/core/services/workflows/delegate.go @@ -15,7 +15,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type Delegate struct { @@ -36,7 +35,7 @@ func (d *Delegate) AfterJobCreated(jb job.Job) {} func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) ([]job.ServiceCtx, error) { diff --git a/core/store/migrate/migrate_test.go b/core/store/migrate/migrate_test.go index 286e1b3a295..b3a15123efa 100644 --- a/core/store/migrate/migrate_test.go +++ b/core/store/migrate/migrate_test.go @@ -78,14 +78,15 @@ func TestMigrate_0100_BootstrapConfigs(t *testing.T) { err := goose.UpTo(db.DB, migrationDir, 99) require.NoError(t, err) - pipelineORM := pipeline.NewORM(db, lggr, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - pipelineID, err := pipelineORM.CreateSpec(pipeline.Pipeline{}, 0) + pipelineORM := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) + ctx := testutils.Context(t) + pipelineID, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) require.NoError(t, err) - pipelineID2, err := pipelineORM.CreateSpec(pipeline.Pipeline{}, 0) + pipelineID2, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) require.NoError(t, err) - nonBootstrapPipelineID, err := pipelineORM.CreateSpec(pipeline.Pipeline{}, 0) + nonBootstrapPipelineID, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) require.NoError(t, err) - newFormatBoostrapPipelineID2, err := pipelineORM.CreateSpec(pipeline.Pipeline{}, 0) + newFormatBoostrapPipelineID2, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) require.NoError(t, err) // OCR2 struct at migration v0099 diff --git a/core/web/pipeline_runs_controller.go b/core/web/pipeline_runs_controller.go index 2c6caa648fc..1bd52b021c3 100644 --- a/core/web/pipeline_runs_controller.go +++ b/core/web/pipeline_runs_controller.go @@ -66,6 +66,7 @@ func (prc *PipelineRunsController) Index(c *gin.Context, size, page, offset int) // Example: // "GET /jobs/:ID/runs/:runID" func (prc *PipelineRunsController) Show(c *gin.Context) { + ctx := c.Request.Context() pipelineRun := pipeline.Run{} err := 
pipelineRun.SetID(c.Param("runID")) if err != nil { @@ -73,7 +74,7 @@ func (prc *PipelineRunsController) Show(c *gin.Context) { return } - pipelineRun, err = prc.App.PipelineORM().FindRun(pipelineRun.ID) + pipelineRun, err = prc.App.PipelineORM().FindRun(ctx, pipelineRun.ID) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return @@ -87,8 +88,9 @@ func (prc *PipelineRunsController) Show(c *gin.Context) { // Example: // "POST /jobs/:ID/runs" func (prc *PipelineRunsController) Create(c *gin.Context) { + ctx := c.Request.Context() respondWithPipelineRun := func(jobRunID int64) { - pipelineRun, err := prc.App.PipelineORM().FindRun(jobRunID) + pipelineRun, err := prc.App.PipelineORM().FindRun(ctx, jobRunID) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return diff --git a/core/web/resolver/job_run_test.go b/core/web/resolver/job_run_test.go index 18036311155..a35a2f66ac5 100644 --- a/core/web/resolver/job_run_test.go +++ b/core/web/resolver/job_run_test.go @@ -286,7 +286,7 @@ func TestResolver_RunJob(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, (map[string]interface{})(nil)).Return(int64(25), nil) - f.Mocks.pipelineORM.On("FindRun", int64(25)).Return(pipeline.Run{ + f.Mocks.pipelineORM.On("FindRun", mock.Anything, int64(25)).Return(pipeline.Run{ ID: 2, PipelineSpecID: 5, CreatedAt: f.Timestamp(), @@ -377,7 +377,7 @@ func TestResolver_RunJob(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, (map[string]interface{})(nil)).Return(int64(25), nil) - f.Mocks.pipelineORM.On("FindRun", int64(25)).Return(pipeline.Run{}, gError) + f.Mocks.pipelineORM.On("FindRun", mock.Anything, int64(25)).Return(pipeline.Run{}, gError) f.App.On("PipelineORM").Return(f.Mocks.pipelineORM) }, query: mutation, diff --git a/core/web/resolver/mutation.go b/core/web/resolver/mutation.go index 85f3407169e..551b8d8e89a 100644 --- a/core/web/resolver/mutation.go +++ b/core/web/resolver/mutation.go @@ -1162,7 +1162,7 @@ func (r *Resolver) RunJob(ctx context.Context, args struct { return nil, err } - plnRun, err := r.App.PipelineORM().FindRun(jobRunID) + plnRun, err := r.App.PipelineORM().FindRun(ctx, jobRunID) if err != nil { return nil, err }
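Across the listener and reverted-txn changes above, the recurring pattern is that `lsn.q.Transaction(func(tx pg.Queryer) error {...})` becomes `sqlutil.TransactDataSource(ctx, lsn.ds, nil, func(tx sqlutil.DataSource) error {...})` and plain reads become `ds.SelectContext(ctx, ...)`, so both the context and the transactional handle flow explicitly. Below is a minimal sketch of that transactional shape, assuming only the `sqlutil` package already imported in these hunks; `insertAndEnqueue` and its two callbacks are illustrative stand-ins for the pipeline-run insert and txmgr enqueue performed in `listener_v2_log_processor.go` and `listener_v2_types.go`.

```go
package example

import (
	"context"

	"github.com/smartcontractkit/chainlink-common/pkg/sqlutil"
)

// insertAndEnqueue sketches the transactional shape the VRF listener now uses:
// every write inside the callback runs against the same tx-scoped DataSource,
// and returning an error rolls the whole transaction back.
func insertAndEnqueue(
	ctx context.Context,
	ds sqlutil.DataSource,
	insertRun func(ctx context.Context, tx sqlutil.DataSource) error,
	enqueueTx func(ctx context.Context, tx sqlutil.DataSource) error,
) error {
	return sqlutil.TransactDataSource(ctx, ds, nil, func(tx sqlutil.DataSource) error {
		if err := insertRun(ctx, tx); err != nil {
			return err
		}
		return enqueueTx(ctx, tx)
	})
}
```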