diff --git a/dht_net.go b/dht_net.go index e4dcf004d..03a7cfc43 100644 --- a/dht_net.go +++ b/dht_net.go @@ -1,6 +1,7 @@ package dht import ( + "bufio" "context" "fmt" "io" @@ -17,6 +18,31 @@ import ( var dhtReadMessageTimeout = time.Minute var ErrReadTimeout = fmt.Errorf("timed out reading response") +type bufferedWriteCloser interface { + ggio.WriteCloser + Flush() error +} + +// The Protobuf writer performs multiple small writes when writing a message. +// We need to buffer those writes, to make sure that we're not sending a new +// packet for every single write. +type bufferedDelimitedWriter struct { + *bufio.Writer + ggio.WriteCloser +} + +func newBufferedDelimitedWriter(str io.Writer) bufferedWriteCloser { + w := bufio.NewWriter(str) + return &bufferedDelimitedWriter{ + Writer: w, + WriteCloser: ggio.NewDelimitedWriter(w), + } +} + +func (w *bufferedDelimitedWriter) Flush() error { + return w.Writer.Flush() +} + // handleNewStream implements the inet.StreamHandler func (dht *IpfsDHT) handleNewStream(s inet.Stream) { go dht.handleNewMessage(s) @@ -27,7 +53,7 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) { cr := ctxio.NewReader(ctx, s) // ok to use. we defer close stream in this func cw := ctxio.NewWriter(ctx, s) // ok to use. we defer close stream in this func r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax) - w := ggio.NewDelimitedWriter(cw) + w := newBufferedDelimitedWriter(cw) mPeer := s.Conn().RemotePeer() for { @@ -70,7 +96,11 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) { } // send out response msg - if err := w.WriteMsg(rpmes); err != nil { + err = w.WriteMsg(rpmes) + if err == nil { + err = w.Flush() + } + if err != nil { s.Reset() log.Debugf("send response error: %s", err) return @@ -160,7 +190,7 @@ func (dht *IpfsDHT) messageSenderForPeer(p peer.ID) (*messageSender, error) { type messageSender struct { s inet.Stream r ggio.ReadCloser - w ggio.WriteCloser + w bufferedWriteCloser lk sync.Mutex p peer.ID dht *IpfsDHT @@ -204,7 +234,7 @@ func (ms *messageSender) prep() error { } ms.r = ggio.NewDelimitedReader(nstr, inet.MessageSizeMax) - ms.w = ggio.NewDelimitedWriter(nstr) + ms.w = newBufferedDelimitedWriter(nstr) ms.s = nstr return nil @@ -224,7 +254,7 @@ func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) erro return err } - if err := ms.w.WriteMsg(pmes); err != nil { + if err := ms.writeMsg(pmes); err != nil { ms.s.Reset() ms.s = nil @@ -260,7 +290,7 @@ func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb return nil, err } - if err := ms.w.WriteMsg(pmes); err != nil { + if err := ms.writeMsg(pmes); err != nil { ms.s.Reset() ms.s = nil @@ -302,6 +332,13 @@ func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb } } +func (ms *messageSender) writeMsg(pmes *pb.Message) error { + if err := ms.w.WriteMsg(pmes); err != nil { + return err + } + return ms.w.Flush() +} + func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error { errc := make(chan error, 1) go func(r ggio.ReadCloser) {