Skip to content

Commit

Permalink
Fixes in adal.Token for ADFS
Browse files Browse the repository at this point in the history
Changed a few fields in adal.Token from string to json.Number to handle
differences between AAD and ADFS in how they send data over the wire.
  • Loading branch information
jhendrixMSFT committed Sep 28, 2018
1 parent 9bc4033 commit cc7d4d2
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 26 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# CHANGELOG

## v11.0.0

### Breaking Changes

- To handle differences between ADFS and AAD the following fields have had their types changed from `string` to `json.Number`
- ExpiresIn
- ExpiresOn
- NotBefore

## v10.15.5

### Bug Fixes
Expand Down
21 changes: 15 additions & 6 deletions autorest/adal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -101,27 +100,35 @@ type Token struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`

ExpiresIn string `json:"expires_in"`
ExpiresOn string `json:"expires_on"`
NotBefore string `json:"not_before"`
ExpiresIn json.Number `json:"expires_in"`
ExpiresOn json.Number `json:"expires_on"`
NotBefore json.Number `json:"not_before"`

Resource string `json:"resource"`
Type string `json:"token_type"`
}

func newToken() Token {
return Token{
ExpiresIn: "0",
ExpiresOn: "0",
NotBefore: "0",
}
}

// IsZero returns true if the token object is zero-initialized.
func (t Token) IsZero() bool {
return t == Token{}
}

// Expires returns the time.Time when the Token expires.
func (t Token) Expires() time.Time {
s, err := strconv.Atoi(t.ExpiresOn)
s, err := t.ExpiresOn.Float64()
if err != nil {
s = -3600
}

expiration := date.NewUnixTimeFromSeconds(float64(s))
expiration := date.NewUnixTimeFromSeconds(s)

return time.Time(expiration).UTC()
}
Expand Down Expand Up @@ -414,6 +421,7 @@ func NewServicePrincipalTokenWithSecret(oauthConfig OAuthConfig, id string, reso
}
spt := &ServicePrincipalToken{
inner: servicePrincipalToken{
Token: newToken(),
OauthConfig: oauthConfig,
Secret: secret,
ClientID: id,
Expand Down Expand Up @@ -653,6 +661,7 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI

spt := &ServicePrincipalToken{
inner: servicePrincipalToken{
Token: newToken(),
OauthConfig: OAuthConfig{
TokenEndpoint: *msiEndpointURL,
},
Expand Down
30 changes: 12 additions & 18 deletions autorest/adal/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,8 @@ func TestServicePrincipalTokenRefreshUnmarshals(t *testing.T) {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
} else if spt.inner.Token.AccessToken != "accessToken" ||
spt.inner.Token.ExpiresIn != "3600" ||
spt.inner.Token.ExpiresOn != expiresOn ||
spt.inner.Token.NotBefore != expiresOn ||
spt.inner.Token.ExpiresOn != json.Number(expiresOn) ||
spt.inner.Token.NotBefore != json.Number(expiresOn) ||
spt.inner.Token.Resource != "resource" ||
spt.inner.Token.Type != "Bearer" {
t.Fatalf("adal: ServicePrincipalToken#Refresh failed correctly unmarshal the JSON -- expected %v, received %v",
Expand Down Expand Up @@ -684,13 +684,13 @@ func TestNewServicePrincipalTokenFromManualTokenSecret(t *testing.T) {
RedirectURI: "redirect",
}

spt, err := NewServicePrincipalTokenFromManualTokenSecret(TestOAuthConfig, "id", "resource", *token, secret, nil)
spt, err := NewServicePrincipalTokenFromManualTokenSecret(TestOAuthConfig, "id", "resource", token, secret, nil)
if err != nil {
t.Fatalf("Failed creating new SPT: %s", err)
}

if !reflect.DeepEqual(*token, spt.inner.Token) {
t.Fatalf("Tokens do not match: %s, %s", *token, spt.inner.Token)
if !reflect.DeepEqual(token, spt.inner.Token) {
t.Fatalf("Tokens do not match: %s, %s", token, spt.inner.Token)
}

if !reflect.DeepEqual(secret, spt.inner.Secret) {
Expand Down Expand Up @@ -822,7 +822,7 @@ func TestMarshalInnerToken(t *testing.T) {
t.Fatalf("tokens don't match: %s, %s", tokenJSON, testTokenJSON)
}

var t1 *Token
var t1 Token
err = json.Unmarshal(tokenJSON, &t1)
if err != nil {
t.Fatalf("failed to unmarshal token: %+v", err)
Expand All @@ -833,14 +833,6 @@ func TestMarshalInnerToken(t *testing.T) {
}
}

func newToken() *Token {
return &Token{
AccessToken: "ASECRETVALUE",
Resource: "https://azure.microsoft.com/",
Type: "Bearer",
}
}

func newTokenJSON(expiresOn string, resource string) string {
return fmt.Sprintf(`{
"access_token" : "accessToken",
Expand All @@ -855,11 +847,13 @@ func newTokenJSON(expiresOn string, resource string) string {
}

func newTokenExpiresIn(expireIn time.Duration) *Token {
return setTokenToExpireIn(newToken(), expireIn)
t := newToken()
return setTokenToExpireIn(&t, expireIn)
}

func newTokenExpiresAt(expireAt time.Time) *Token {
return setTokenToExpireAt(newToken(), expireAt)
t := newToken()
return setTokenToExpireAt(&t, expireAt)
}

func expireToken(t *Token) *Token {
Expand All @@ -868,7 +862,7 @@ func expireToken(t *Token) *Token {

func setTokenToExpireAt(t *Token, expireAt time.Time) *Token {
t.ExpiresIn = "3600"
t.ExpiresOn = strconv.Itoa(int(expireAt.Sub(date.UnixEpoch()).Seconds()))
t.ExpiresOn = json.Number(strconv.Itoa(int(expireAt.Sub(date.UnixEpoch()).Seconds())))
t.NotBefore = t.ExpiresOn
return t
}
Expand All @@ -885,7 +879,7 @@ func newServicePrincipalToken(callbacks ...TokenRefreshCallback) *ServicePrincip
func newServicePrincipalTokenManual() *ServicePrincipalToken {
token := newToken()
token.RefreshToken = "refreshtoken"
spt, _ := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", *token)
spt, _ := NewServicePrincipalTokenFromManualToken(TestOAuthConfig, "id", "resource", token)
return spt
}

Expand Down
2 changes: 1 addition & 1 deletion autorest/azure/cli/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (t Token) ToADALToken() (converted adal.Token, err error) {
AccessToken: t.AccessToken,
Type: t.TokenType,
ExpiresIn: "3600",
ExpiresOn: strconv.Itoa(int(difference.Seconds())),
ExpiresOn: json.Number(strconv.Itoa(int(difference.Seconds()))),
RefreshToken: t.RefreshToken,
Resource: t.Resource,
}
Expand Down
2 changes: 1 addition & 1 deletion version/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
)

// Number contains the semantic version of this SDK.
const Number = "v10.15.5"
const Number = "v11.0.0"

var (
userAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s",
Expand Down

0 comments on commit cc7d4d2

Please sign in to comment.