Skip to content

Commit

Permalink
Add support for IMDSv2 token enforcement (#80)
Browse files Browse the repository at this point in the history
* add imdsv2 token support

* update example config with imds enforcement flag

* fix weird find & replace issue

* clean up exports, add xff header test cases

* fix imports in server package
  • Loading branch information
patricksanders committed Aug 3, 2021
1 parent 8069e54 commit d610bc4
Show file tree
Hide file tree
Showing 11 changed files with 236 additions and 68 deletions.
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func init() {
viper.SetDefault("feature_flags.consoleme_metadata", false)
viper.SetDefault("log_file", getDefaultLogFile())
viper.SetDefault("mtls_settings.old_cert_message", "mTLS certificate is too old, please refresh mtls certificate")
viper.SetDefault("server.enforce_imdsv2", false)
viper.SetDefault("server.http_timeout", 20)
viper.SetDefault("server.address", "127.0.0.1")
viper.SetDefault("server.port", 9091)
Expand Down
1 change: 1 addition & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func (e Error) Error() string { return string(e) }

const (
NoCredentialsFoundInCache = Error("no credentials found in cache")
NoTokenFoundInCache = Error("no token found in cache")
NoDefaultRoleSet = Error("no default role set")
BrowserOpenError = Error("could not launch browser, open link manually")
CredentialRetrievalError = Error("failed to retrieve credentials from broker")
Expand Down
1 change: 1 addition & 0 deletions example-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ server:
http_timeout: 20
address: 127.0.0.1
port: 9091
enforce_imdsv2: false # Enforce use of a token in IMDS emulation mode (weep serve <role>)
service:
command: serve
flags: # Flags are CLI options
Expand Down
7 changes: 7 additions & 0 deletions server/baseHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"fmt"
"net/http"

"github.com/netflix/weep/util"

"github.com/netflix/weep/logging"
)

Expand Down Expand Up @@ -60,3 +62,8 @@ user-data`

fmt.Fprintln(w, baseVersionPath)
}

func NotFoundHandler(w http.ResponseWriter, r *http.Request) {
util.WriteError(w, "not found", http.StatusNotFound)
return
}
37 changes: 0 additions & 37 deletions server/customHandler.go

This file was deleted.

57 changes: 44 additions & 13 deletions server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,50 @@ import (
"strings"
"time"

"github.com/netflix/weep/session"
"github.com/spf13/viper"

"github.com/sirupsen/logrus"

"github.com/netflix/weep/util"
)

// CredentialServiceMiddleware is a convenience wrapper that chains BrowserFilterMiddleware and AWSHeaderMiddleware
func CredentialServiceMiddleware(next http.HandlerFunc) http.HandlerFunc {
// InstanceMetadataMiddleware is a convenience wrapper that chains TokenMiddleware, BrowserFilterMiddleware, and AWSHeaderMiddleware
func InstanceMetadataMiddleware(next http.HandlerFunc) http.HandlerFunc {
return TokenMiddleware(TaskMetadataMiddleware(next))
}

// TaskMetadataMiddleware is a convenience wrapper that chains BrowserFilterMiddleware and AWSHeaderMiddleware
func TaskMetadataMiddleware(next http.HandlerFunc) http.HandlerFunc {
return BrowserFilterMiddleware(AWSHeaderMiddleware(next))
}

func TokenMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var remainingTtl int
var ok bool

token := r.Header.Get("x-aws-ec2-metadata-token")
if token != "" {
if ok, remainingTtl = session.CheckToken(token); !ok {
log.Debug("token invalid")
util.WriteError(w, "invalid session token", http.StatusForbidden)
return
}
} else if token == "" && viper.GetBool("server.enforce_imdsv2") {
log.Info("request forbidden, imdsv2 required")
util.WriteError(w, "IMDSv2 required, please upgrade your SDK or CLI", http.StatusForbidden)
return
}

// Return the token's remaining TTL in a header
if remainingTtl > 0 {
w.Header().Set("X-Aws-Ec2-Metadata-Token-Ttl-Seconds", strconv.Itoa(remainingTtl))
}
next.ServeHTTP(w, r)
}
}

func AWSHeaderMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {

Expand All @@ -48,7 +82,6 @@ func AWSHeaderMiddleware(next http.HandlerFunc) http.HandlerFunc {
// If either of these request headers exist, we can be reasonably confident that the request is for IMDSv2.
// `X-Aws-Ec2-Metadata-Token-Ttl-Seconds` is used when requesting a token
// `X-aws-ec2-metadata-token` is used to pass the token to the metadata service
// Weep uses a static token, and does not perform any token validation.
if token != "" || tokenTtl != "" {
metadataVersion = 2
}
Expand All @@ -71,11 +104,10 @@ var allowedHosts = map[string]bool{
}

// deniedHeaders is a list of headers that will cause a 403 if present at all
var deniedHeaders = []string{
"Referrer",
"referrer",
"Origin",
"origin",
var deniedHeaders = map[string]bool{
"referrer": true,
"origin": true,
"x-forwarded-for": true,
}

// BrowserFilterMiddleware is a middleware designed mitigate risks related to DNS rebinding,
Expand All @@ -94,18 +126,17 @@ func BrowserFilterMiddleware(next http.HandlerFunc) http.HandlerFunc {

// Check for presence of deniedHeaders
// These also indicate a likely browser request
headers := r.Header
for _, h := range deniedHeaders {
if _, ok := headers[h]; ok {
log.Warnf("%s detected", h)
for h, _ := range r.Header {
if deniedHeaders[strings.ToLower(h)] {
log.Warnf("%s header detected", h)
util.WriteError(w, "forbidden", http.StatusForbidden)
return
}
}

// Check host header
// This should only be 127.0.0.1 or 169.254.169.254
if host := r.Header.Get("Host"); host != "" && !allowedHosts[host] {
if host := r.Header.Get("Host"); host != "" && !allowedHosts[strings.ToLower(host)] {
log.Warn("bad host detected")
util.WriteError(w, "forbidden", http.StatusForbidden)
return
Expand Down
18 changes: 15 additions & 3 deletions server/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ var browserHeaderTestCases = []struct {
HeaderValue: "",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "xff header set",
HeaderName: "X-Forwarded-For",
HeaderValue: "anything",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "empty xff header set",
HeaderName: "X-Forwarded-For",
HeaderValue: "",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "host header not in allowlist",
HeaderName: "Host",
Expand Down Expand Up @@ -107,7 +119,7 @@ func TestAWSHeaderMiddleware(t *testing.T) {
w.WriteHeader(http.StatusOK)
})
t.Logf("test case: %s", description)
bfmHandler := CredentialServiceMiddleware(nextHandler)
bfmHandler := InstanceMetadataMiddleware(nextHandler)
req := httptest.NewRequest("GET", "http://localhost", nil)
rec := httptest.NewRecorder()
bfmHandler.ServeHTTP(rec, req)
Expand All @@ -129,14 +141,14 @@ func TestAWSHeaderMiddleware(t *testing.T) {
}

// TestCredentialServiceMiddleware is a superset of TestBrowserFilterMiddleware and TestAWSHeaderMiddleware
// since CredentialServiceMiddleware is a chain of BrowserFilterMiddleware and AWSHeaderMiddleware
// since InstanceMetadataMiddleware is a chain of BrowserFilterMiddleware and AWSHeaderMiddleware
func TestCredentialServiceMiddleware(t *testing.T) {
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
for i, tc := range browserHeaderTestCases {
t.Logf("test case %d: %s", i, tc.Description)
bfmHandler := CredentialServiceMiddleware(nextHandler)
bfmHandler := InstanceMetadataMiddleware(nextHandler)
req := httptest.NewRequest("GET", "http://localhost", nil)
req.Header.Add(tc.HeaderName, tc.HeaderValue)
rec := httptest.NewRecorder()
Expand Down
26 changes: 15 additions & 11 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,24 @@ func Run(host string, port int, role, region string, shutdown chan os.Signal) er
if err != nil {
return err
}
router.HandleFunc("/{version}/", CredentialServiceMiddleware(BaseVersionHandler))
router.HandleFunc("/{version}/api/token", CredentialServiceMiddleware(TokenHandler)).Methods("PUT")
router.HandleFunc("/{version}/meta-data", CredentialServiceMiddleware(BaseHandler))
router.HandleFunc("/{version}/meta-data/", CredentialServiceMiddleware(BaseHandler))
router.HandleFunc("/{version}/meta-data/iam/info", CredentialServiceMiddleware(IamInfoHandler))

// Unauthenticated endpoints
router.HandleFunc("/{version}/api/token", TaskMetadataMiddleware(TokenHandler)).Methods("PUT")

// Authenticated endpoints
router.HandleFunc("/{version}/", InstanceMetadataMiddleware(BaseVersionHandler))
router.HandleFunc("/{version}/meta-data", InstanceMetadataMiddleware(BaseHandler))
router.HandleFunc("/{version}/meta-data/", InstanceMetadataMiddleware(BaseHandler))
router.HandleFunc("/{version}/meta-data/iam/info", InstanceMetadataMiddleware(IamInfoHandler))
// There's an extra route here to support the lack of trailing slash without the redirect that StrictSlash(true) does
router.HandleFunc("/{version}/meta-data/iam/security-credentials", CredentialServiceMiddleware(RoleHandler))
router.HandleFunc("/{version}/meta-data/iam/security-credentials/", CredentialServiceMiddleware(RoleHandler))
router.HandleFunc("/{version}/meta-data/iam/security-credentials/{role}", CredentialServiceMiddleware(IMDSHandler))
router.HandleFunc("/{version}/dynamic/instance-identity/document", CredentialServiceMiddleware(InstanceIdentityDocumentHandler))
router.HandleFunc("/{version}/meta-data/iam/security-credentials", InstanceMetadataMiddleware(RoleHandler))
router.HandleFunc("/{version}/meta-data/iam/security-credentials/", InstanceMetadataMiddleware(RoleHandler))
router.HandleFunc("/{version}/meta-data/iam/security-credentials/{role}", InstanceMetadataMiddleware(IMDSHandler))
router.HandleFunc("/{version}/dynamic/instance-identity/document", InstanceMetadataMiddleware(InstanceIdentityDocumentHandler))
}

router.HandleFunc("/ecs/{role:.*}", CredentialServiceMiddleware(getCredentialHandler(region)))
router.HandleFunc("/{path:.*}", CredentialServiceMiddleware(CustomHandler))
router.HandleFunc("/ecs/{role:.*}", TaskMetadataMiddleware(getCredentialHandler(region)))
router.HandleFunc("/{path:.*}", TaskMetadataMiddleware(NotFoundHandler))

go func() {
log.Info("starting weep on ", listenAddr)
Expand Down
20 changes: 16 additions & 4 deletions server/tokenHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,23 @@ package server
import (
"fmt"
"net/http"
)
"strconv"

const staticToken = "AQAEANQlVdnIoNfmJQHofbSTjkIm8eoMIBZZZX05Xk9jLiFuJuL2_A=="
"github.com/netflix/weep/session"
"github.com/netflix/weep/util"
"github.com/sirupsen/logrus"
)

func TokenHandler(w http.ResponseWriter, r *http.Request) {
// Returning a static token allows us to support IMDSv2 with minimal effort.
fmt.Fprint(w, staticToken)
ttlString := r.Header.Get("X-aws-ec2-metadata-token-ttl-seconds")
ttlSeconds, err := strconv.Atoi(ttlString)
log.WithFields(logrus.Fields{
"ttlSeconds": ttlSeconds,
}).Debug("generating IMDSv2 token")
if err != nil {
util.WriteError(w, "bad request", http.StatusBadRequest)
return
}
token := session.GenerateToken("", ttlSeconds)
fmt.Fprint(w, token)
}
Loading

0 comments on commit d610bc4

Please sign in to comment.