Skip to content

Commit

Permalink
feat: DialContext() better error handling
Browse files Browse the repository at this point in the history
Now DialContext() returns some errors immediately
instead of deffering it Read/Write operation on the returned connection.

Signed-off-by: Matej Vasek <mvasek@redhat.com>
  • Loading branch information
matejvasek committed Jun 8, 2023
1 parent a02623f commit 19af793
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 18 deletions.
49 changes: 41 additions & 8 deletions pkg/k8s/dialer.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package k8s

import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"regexp"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -67,18 +69,45 @@ func (c *contextDialer) DialContext(ctx context.Context, network string, addr st
return nil, fmt.Errorf("unsupported network: %q", network)
}

pr, pw, conn := newConn()

ctrStdin, ctrStdout, conn := newConn()
connectSuccess := make(chan struct{})
connectFailure := make(chan error, 1)
go func() {
errOut := bytes.NewBuffer(nil)
err := c.exec(addr, pr, pw, errOut)
stderrBuff := bytes.NewBuffer(nil)
ctrStderr := io.MultiWriter(stderrBuff, detectConnSuccess(connectSuccess))

err := c.exec(addr, ctrStdin, ctrStdout, ctrStderr)
if err != nil {
err = fmt.Errorf("failed to exec in pod: %w (stderr: %q)", err, errOut.String())
err = fmt.Errorf("failed to exec in pod: %w (stderr: %q)", err, stderrBuff.String())
}
_ = conn.closeWithError(err)
connectFailure <- err
}()

return conn, nil
select {
case <-connectSuccess:
return conn, nil
case err := <-connectFailure:
return nil, err
case <-ctx.Done():
_ = conn.closeWithError(ctx.Err())
return nil, ctx.Err()
}
}

var connSuccessfulRE = regexp.MustCompile("successfully connected")

// Creates io.Writer which closes connectSuccess channel when string "successfully connected" is written to it.
func detectConnSuccess(connectSuccess chan struct{}) io.Writer {
pr, pw := io.Pipe()
go func() {
ok := connSuccessfulRE.MatchReader(bufio.NewReader(pr))
if ok {
close(connectSuccess)
}
_, _ = io.Copy(io.Discard, pr)
}()
return pw
}

func (c *contextDialer) Close() error {
Expand Down Expand Up @@ -195,7 +224,7 @@ func (c *contextDialer) exec(hostPort string, in io.Reader, out, errOut io.Write
Namespace(c.namespace).
SubResource("exec")
req.VersionedParams(&coreV1.PodExecOptions{
Command: []string{"socat", "-", fmt.Sprintf("TCP:%s", hostPort)},
Command: []string{"socat", "-dd", "-", fmt.Sprintf("TCP:%s", hostPort)},
Container: c.podName,
Stdin: true,
Stdout: true,
Expand Down Expand Up @@ -348,6 +377,10 @@ func (c *conn) Write(b []byte) (n int, err error) {
}

func (c *conn) closeWithError(err error) error {
if err == nil {
err = net.ErrClosed
}

{
e := err
c.err.CompareAndSwap(nil, &e)
Expand All @@ -364,7 +397,7 @@ func (c *conn) closeWithError(err error) error {
}

func (c *conn) Close() error {
return c.closeWithError(net.ErrClosed)
return c.closeWithError(nil)
}

func (c *conn) LocalAddr() net.Addr {
Expand Down
12 changes: 2 additions & 10 deletions pkg/k8s/dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,8 @@ func TestDialUnreachable(t *testing.T) {
t.Cleanup(func() {
dialer.Close()
})

transport := &http.Transport{
DialContext: dialer.DialContext,
}

var client = http.Client{
Transport: transport,
}

_, err = client.Get("http://does-not.exists.svc")

_, err = dialer.DialContext(ctx, "tcp", "does-not.exists.svc:80")
if err == nil {
t.Error("error was expected but got nil")
return
Expand Down

0 comments on commit 19af793

Please sign in to comment.