Skip to content

Commit

Permalink
Merge pull request ipfs/go-blockservice#1 from Jorropo/fix-sessions
Browse files Browse the repository at this point in the history
Fix sessions

This commit was moved from ipfs/go-blockservice@06e2b5d
  • Loading branch information
MichaelMure authored Jul 27, 2022
2 parents 8ee16bb + 620a34f commit 41de1a5
Showing 1 changed file with 53 additions and 30 deletions.
83 changes: 53 additions & 30 deletions blockservice/blockservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,18 @@ func NewSession(ctx context.Context, bs BlockService) *Session {
exch := bs.Exchange()
if sessEx, ok := exch.(exchange.SessionExchange); ok {
return &Session{
sessCtx: ctx,
ses: nil,
sessEx: sessEx,
bs: bs.Blockstore(),
sessCtx: ctx,
ses: nil,
sessEx: sessEx,
bs: bs.Blockstore(),
notifier: exch,
}
}
return &Session{
ses: exch,
sessCtx: ctx,
bs: bs.Blockstore(),
ses: exch,
sessCtx: ctx,
bs: bs.Blockstore(),
notifier: exch,
}
}

Expand Down Expand Up @@ -214,19 +216,19 @@ func (s *blockService) GetBlock(ctx context.Context, c cid.Cid) (blocks.Block, e
ctx, span := internal.StartSpan(ctx, "blockService.GetBlock", trace.WithAttributes(attribute.Stringer("CID", c)))
defer span.End()

var f func() exchange.Interface
var f func() notifiableFetcher
if s.exchange != nil {
f = s.getExchange
}

return getBlock(ctx, c, s.blockstore, f) // hash security
}

func (s *blockService) getExchange() exchange.Interface {
func (s *blockService) getExchange() notifiableFetcher {
return s.exchange
}

func getBlock(ctx context.Context, c cid.Cid, bs blockstore.Blockstore, fget func() exchange.Interface) (blocks.Block, error) {
func getBlock(ctx context.Context, c cid.Cid, bs blockstore.Blockstore, fget func() notifiableFetcher) (blocks.Block, error) {
err := verifcid.ValidateCid(c) // hash security
if err != nil {
return nil, err
Expand Down Expand Up @@ -271,15 +273,15 @@ func (s *blockService) GetBlocks(ctx context.Context, ks []cid.Cid) <-chan block
ctx, span := internal.StartSpan(ctx, "blockService.GetBlocks")
defer span.End()

var f func() exchange.Interface
var f func() notifiableFetcher
if s.exchange != nil {
f = s.getExchange
}

return getBlocks(ctx, ks, s.blockstore, f) // hash security
}

func getBlocks(ctx context.Context, ks []cid.Cid, bs blockstore.Blockstore, fget func() exchange.Interface) <-chan blocks.Block {
func getBlocks(ctx context.Context, ks []cid.Cid, bs blockstore.Blockstore, fget func() notifiableFetcher) <-chan blocks.Block {
out := make(chan blocks.Block)

go func() {
Expand Down Expand Up @@ -397,48 +399,69 @@ func (s *blockService) Close() error {
return s.exchange.Close()
}

type notifier interface {
NotifyNewBlocks(context.Context, ...blocks.Block) error
}

// Session is a helper type to provide higher level access to bitswap sessions
type Session struct {
bs blockstore.Blockstore
ses exchange.Fetcher
sessEx exchange.SessionExchange
sessCtx context.Context
lk sync.Mutex
bs blockstore.Blockstore
ses exchange.Fetcher
sessEx exchange.SessionExchange
sessCtx context.Context
notifier notifier
lk sync.Mutex
}

func (s *Session) getSession() exchange.Interface {
type notifiableFetcher interface {
exchange.Fetcher
notifier
}

type notifiableFetcherWrapper struct {
exchange.Fetcher
notifier
}

func (s *Session) getSession() notifiableFetcher {
s.lk.Lock()
defer s.lk.Unlock()
if s.ses == nil {
s.ses = s.sessEx.NewSession(s.sessCtx)
}

// TODO: don't do that
return s.ses.(exchange.Interface)
return notifiableFetcherWrapper{s.ses, s.notifier}
}

func (s *Session) getExchange() notifiableFetcher {
return notifiableFetcherWrapper{s.ses, s.notifier}
}

func (s *Session) getFetcherFactory() func() notifiableFetcher {
if s.sessEx != nil {
return s.getSession
}
if s.ses != nil {
// Our exchange isn't session compatible, let's fallback to non sessions fetches
return s.getExchange
}
return nil
}

// GetBlock gets a block in the context of a request session
func (s *Session) GetBlock(ctx context.Context, c cid.Cid) (blocks.Block, error) {
ctx, span := internal.StartSpan(ctx, "Session.GetBlock", trace.WithAttributes(attribute.Stringer("CID", c)))
defer span.End()

var f func() exchange.Interface
if s.sessEx != nil {
f = s.getSession
}
return getBlock(ctx, c, s.bs, f) // hash security
return getBlock(ctx, c, s.bs, s.getFetcherFactory()) // hash security
}

// GetBlocks gets blocks in the context of a request session
func (s *Session) GetBlocks(ctx context.Context, ks []cid.Cid) <-chan blocks.Block {
ctx, span := internal.StartSpan(ctx, "Session.GetBlocks")
defer span.End()

var f func() exchange.Interface
if s.sessEx != nil {
f = s.getSession
}
return getBlocks(ctx, ks, s.bs, f) // hash security
return getBlocks(ctx, ks, s.bs, s.getFetcherFactory()) // hash security
}

var _ BlockGetter = (*Session)(nil)

0 comments on commit 41de1a5

Please sign in to comment.