Skip to content

Commit

Permalink
API: Add context to each raw request call (#4987)
Browse files Browse the repository at this point in the history
  • Loading branch information
briankassouf authored Jul 24, 2018
1 parent 754fd0b commit 4ca3b84
Show file tree
Hide file tree
Showing 22 changed files with 351 additions and 97 deletions.
57 changes: 44 additions & 13 deletions api/auth_token.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package api

import "context"

// TokenAuth is used to perform token backend operations on Vault
type TokenAuth struct {
c *Client
Expand All @@ -16,7 +18,9 @@ func (c *TokenAuth) Create(opts *TokenCreateRequest) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -31,7 +35,9 @@ func (c *TokenAuth) CreateOrphan(opts *TokenCreateRequest) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -46,7 +52,9 @@ func (c *TokenAuth) CreateWithRole(opts *TokenCreateRequest, roleName string) (*
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -63,7 +71,9 @@ func (c *TokenAuth) Lookup(token string) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -79,7 +89,10 @@ func (c *TokenAuth) LookupAccessor(accessor string) (*Secret, error) {
}); err != nil {
return nil, err
}
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -91,7 +104,9 @@ func (c *TokenAuth) LookupAccessor(accessor string) (*Secret, error) {
func (c *TokenAuth) LookupSelf() (*Secret, error) {
r := c.c.NewRequest("GET", "/v1/auth/token/lookup-self")

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -109,7 +124,9 @@ func (c *TokenAuth) Renew(token string, increment int) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -126,7 +143,9 @@ func (c *TokenAuth) RenewSelf(increment int) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -146,7 +165,9 @@ func (c *TokenAuth) RenewTokenAsSelf(token string, increment int) (*Secret, erro
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -164,7 +185,10 @@ func (c *TokenAuth) RevokeAccessor(accessor string) error {
}); err != nil {
return err
}
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
Expand All @@ -183,7 +207,9 @@ func (c *TokenAuth) RevokeOrphan(token string) error {
return err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
Expand All @@ -197,7 +223,10 @@ func (c *TokenAuth) RevokeOrphan(token string) error {
// an effect.
func (c *TokenAuth) RevokeSelf(token string) error {
r := c.c.NewRequest("PUT", "/v1/auth/token/revoke-self")
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
Expand All @@ -217,7 +246,9 @@ func (c *TokenAuth) RevokeTree(token string) error {
return err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return err
}
Expand Down
19 changes: 10 additions & 9 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,13 @@ func (c *Client) NewRequest(method, requestPath string) *Request {
// a Vault server not configured with this client. This is an advanced operation
// that generally won't need to be called externally.
func (c *Client) RawRequest(r *Request) (*Response, error) {
return c.RawRequestWithContext(context.Background(), r)
}

// RawRequestWithContext performs the raw request given. This request may be against
// a Vault server not configured with this client. This is an advanced operation
// that generally won't need to be called externally.
func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Response, error) {
c.modifyLock.RLock()
token := c.token

Expand All @@ -622,7 +629,7 @@ func (c *Client) RawRequest(r *Request) (*Response, error) {
c.modifyLock.RUnlock()

if limiter != nil {
limiter.Wait(context.Background())
limiter.Wait(ctx)
}

// Sanity check the token before potentially erroring from the API
Expand All @@ -643,13 +650,10 @@ START:
return nil, fmt.Errorf("nil request created")
}

// Set the timeout, if any
var cancelFunc context.CancelFunc
if timeout != 0 {
var ctx context.Context
ctx, cancelFunc = context.WithTimeout(context.Background(), timeout)
req.Request = req.Request.WithContext(ctx)
ctx, _ = context.WithTimeout(ctx, timeout)
}
req.Request = req.Request.WithContext(ctx)

if backoff == nil {
backoff = retryablehttp.LinearJitterBackoff
Expand All @@ -667,9 +671,6 @@ START:

var result *Response
resp, err := client.Do(req)
if cancelFunc != nil {
cancelFunc()
}
if resp != nil {
result = &Response{Response: resp}
}
Expand Down
6 changes: 5 additions & 1 deletion api/help.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package api

import (
"context"
"fmt"
)

// Help reads the help information for the given path.
func (c *Client) Help(path string) (*Help, error) {
r := c.NewRequest("GET", fmt.Sprintf("/v1/%s", path))
r.Params.Add("help", "1")
resp, err := c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand Down
24 changes: 19 additions & 5 deletions api/logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"bytes"
"context"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -46,7 +47,10 @@ func (c *Client) Logical() *Logical {

func (c *Logical) Read(path string) (*Secret, error) {
r := c.c.NewRequest("GET", "/v1/"+path)
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
Expand Down Expand Up @@ -77,7 +81,10 @@ func (c *Logical) List(path string) (*Secret, error) {
// handle the wrapping lookup function
r.Method = "GET"
r.Params.Set("list", "true")
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
Expand Down Expand Up @@ -108,7 +115,9 @@ func (c *Logical) Write(path string, data map[string]interface{}) (*Secret, erro
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
Expand All @@ -134,7 +143,10 @@ func (c *Logical) Write(path string, data map[string]interface{}) (*Secret, erro

func (c *Logical) Delete(path string) (*Secret, error) {
r := c.c.NewRequest("DELETE", "/v1/"+path)
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
Expand Down Expand Up @@ -175,7 +187,9 @@ func (c *Logical) Unwrap(wrappingToken string) (*Secret, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if resp != nil {
defer resp.Body.Close()
}
Expand Down
13 changes: 10 additions & 3 deletions api/ssh.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package api

import "fmt"
import (
"context"
"fmt"
)

// SSH is used to return a client to invoke operations on SSH backend.
type SSH struct {
Expand Down Expand Up @@ -28,7 +31,9 @@ func (c *SSH) Credential(role string, data map[string]interface{}) (*Secret, err
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand All @@ -45,7 +50,9 @@ func (c *SSH) SignKey(role string, data map[string]interface{}) (*Secret, error)
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand Down
5 changes: 4 additions & 1 deletion api/ssh_agent.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -207,7 +208,9 @@ func (c *SSHHelper) Verify(otp string) (*SSHVerifyResponse, error) {
return nil, err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return nil, err
}
Expand Down
22 changes: 18 additions & 4 deletions api/sys_audit.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"fmt"

"github.com/mitchellh/mapstructure"
Expand All @@ -16,7 +17,9 @@ func (c *Sys) AuditHash(path string, input string) (string, error) {
return "", err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)
if err != nil {
return "", err
}
Expand All @@ -37,7 +40,11 @@ func (c *Sys) AuditHash(path string, input string) (string, error) {

func (c *Sys) ListAudit() (map[string]*Audit, error) {
r := c.c.NewRequest("GET", "/v1/sys/audit")
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -87,7 +94,10 @@ func (c *Sys) EnableAuditWithOptions(path string, options *EnableAuditOptions) e
return err
}

resp, err := c.c.RawRequest(r)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)

if err != nil {
return err
}
Expand All @@ -98,7 +108,11 @@ func (c *Sys) EnableAuditWithOptions(path string, options *EnableAuditOptions) e

func (c *Sys) DisableAudit(path string) error {
r := c.c.NewRequest("DELETE", fmt.Sprintf("/v1/sys/audit/%s", path))
resp, err := c.c.RawRequest(r)

ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
resp, err := c.c.RawRequestWithContext(ctx, r)

if err == nil {
defer resp.Body.Close()
}
Expand Down
Loading

0 comments on commit 4ca3b84

Please sign in to comment.