diff --git a/tapgarden/mock.go b/tapgarden/mock.go index 9fa65a0c6..bd363c077 100644 --- a/tapgarden/mock.go +++ b/tapgarden/mock.go @@ -310,12 +310,12 @@ type MockChainBridge struct { NewBlocks chan int32 - ReqCount int + ReqCount atomic.Int32 ConfReqs map[int]*chainntnfs.ConfirmationEvent failFeeEstimates atomic.Bool - emptyConf bool - errConf bool + errConf atomic.Int32 + emptyConf atomic.Int32 confErr chan error } @@ -334,19 +334,30 @@ func (m *MockChainBridge) FailFeeEstimatesOnce() { m.failFeeEstimates.Store(true) } -func (m *MockChainBridge) FailConf(enable bool) { - m.errConf = enable +// FailConfOnce updates the ChainBridge such that the next call to +// RegisterConfirmationsNtfn will fail by returning an error on the error channel +// returned from RegisterConfirmationsNtfn. +func (m *MockChainBridge) FailConfOnce() { + // Store the incremented request count so we never store 0 as a value. + m.errConf.Store(m.ReqCount.Load() + 1) } -func (m *MockChainBridge) EmptyConf(enable bool) { - m.emptyConf = enable + +// EmptyConfOnce updates the ChainBridge such that the next confirmation event +// sent via SendConfNtfn will have an empty confirmation. +func (m *MockChainBridge) EmptyConfOnce() { + // Store the incremented request count so we never store 0 as a value. + m.emptyConf.Store(m.ReqCount.Load() + 1) } func (m *MockChainBridge) SendConfNtfn(reqNo int, blockHash *chainhash.Hash, blockHeight, blockIndex int, block *wire.MsgBlock, tx *wire.MsgTx) { + // Compare to the incremented request count since we incremented it + // when storing the request number. 
req := m.ConfReqs[reqNo] - if m.emptyConf { + if m.emptyConf.Load() == int32(reqNo)+1 { + m.emptyConf.Store(0) req.Confirmed <- nil return } @@ -371,7 +382,7 @@ func (m *MockChainBridge) RegisterConfirmationsNtfn(ctx context.Context, } defer func() { - m.ReqCount++ + m.ReqCount.Add(1) }() req := &chainntnfs.ConfirmationEvent{ @@ -380,15 +391,18 @@ func (m *MockChainBridge) RegisterConfirmationsNtfn(ctx context.Context, } m.confErr = make(chan error, 1) - m.ConfReqs[m.ReqCount] = req + currentReqCount := m.ReqCount.Load() + m.ConfReqs[int(currentReqCount)] = req select { - case m.ConfReqSignal <- m.ReqCount: + case m.ConfReqSignal <- int(currentReqCount): case <-ctx.Done(): } - if m.errConf { - m.confErr <- fmt.Errorf("confirmation error") + // Compare to the incremented request count since we incremented it + // when storing the request number. + if m.errConf.CompareAndSwap(currentReqCount+1, 0) { + m.confErr <- fmt.Errorf("confirmation registration error") } return req, m.confErr, nil @@ -661,7 +675,7 @@ func (m *MockKeyRing) IsLocalKey(context.Context, keychain.KeyDescriptor) bool { type MockGenSigner struct { KeyRing *MockKeyRing - FailSigning bool + failSigning atomic.Bool } func NewMockGenSigner(keyRing *MockKeyRing) *MockGenSigner { @@ -670,11 +684,17 @@ func NewMockGenSigner(keyRing *MockKeyRing) *MockGenSigner { } } +// FailSigningOnce updates the GenSigner such that the next call to +// SignVirtualTx will fail by returning an error. 
+func (m *MockGenSigner) FailSigningOnce() { + m.failSigning.Store(true) +} + func (m *MockGenSigner) SignVirtualTx(signDesc *lndclient.SignDescriptor, virtualTx *wire.MsgTx, prevOut *wire.TxOut) (*schnorr.Signature, error) { - if m.FailSigning { + if m.failSigning.CompareAndSwap(true, false) { return nil, fmt.Errorf("failed to sign virtual tx") } diff --git a/tapgarden/planter_test.go b/tapgarden/planter_test.go index a1553c1aa..1ae6262a8 100644 --- a/tapgarden/planter_test.go +++ b/tapgarden/planter_test.go @@ -930,7 +930,7 @@ func (t *mintingTestHarness) assertSeedlingsMatchSprouts( ) require.NoError(t, err) - // Filter out any cancelled batches. + // Filter out any cancelled or frozen batches. isCommittedBatch := func(batch *tapgarden.MintingBatch) bool { return batchCommittedStates.Contains(batch.State()) } @@ -947,7 +947,7 @@ func (t *mintingTestHarness) assertSeedlingsMatchSprouts( ) // The amount of assets committed to in the Taproot Asset commitment - // should match up + // should match up. dbAssets := pendingBatch.RootAssetCommitment.CommittedAssets() require.Len(t, dbAssets, len(seedlings)) @@ -1147,7 +1147,7 @@ func testBasicAssetCreation(t *mintingTestHarness) { t.assertNumCaretakersActive(1) // We'll now force yet another restart to ensure correctness of the - // state machine, we expect the PSBT packet to still be funded. + // state machine. We expect the PSBT packet to still be funded. t.refreshChainPlanter() batch = t.fetchSingleBatch(nil) t.assertBatchGenesisTx(batch.GenesisPacket) @@ -1457,7 +1457,7 @@ func testFinalizeBatch(t *mintingTestHarness) { // Queue another batch, reset fee estimation behavior, and set TX // confirmation registration to fail. t.queueInitialBatch(numSeedlings) - t.chain.FailConf(true) + t.chain.FailConfOnce() // Finalize the pending batch to start a caretaker, and progress the // caretaker to TX confirmation. 
The finalize call should report no @@ -1483,8 +1483,7 @@ func testFinalizeBatch(t *mintingTestHarness) { // Queue another batch, set TX confirmation to succeed, and set the // confirmation event to be empty. t.queueInitialBatch(numSeedlings) - t.chain.FailConf(false) - t.chain.EmptyConf(true) + t.chain.EmptyConfOnce() // Start a new caretaker that should reach TX broadcast. t.finalizeBatch(&wg, respChan, nil) @@ -1515,7 +1514,6 @@ func testFinalizeBatch(t *mintingTestHarness) { // Queue another batch and drive the caretaker to a successful minting. t.queueInitialBatch(numSeedlings) - t.chain.EmptyConf(false) // Use a custom feerate and verify that the TX uses that feerate. manualFeeRate := chainfee.FeePerKwFloor * 2 @@ -1915,7 +1913,7 @@ func testFundSealOnRestart(t *mintingTestHarness) { // Allow batch funding to succeed, but set group key signing to fail so // that batch sealing fails. - t.genSigner.FailSigning = true + t.genSigner.FailSigningOnce() failedBatchCount++ // Create a seedling with emission enabled, to ensure that batch sealing @@ -1942,7 +1940,6 @@ func testFundSealOnRestart(t *mintingTestHarness) { // Allow batch sealing to succeed. The planter should now be able to // start a caretaker for the batch on restart. - t.genSigner.FailSigning = false t.queueSeedlingsInBatch(false, seedlings...) batchCount++