Skip to content

Commit

Permalink
[v0.7.5] kgw cookie multi providers (#715)
Browse files Browse the repository at this point in the history
* core: kgw cookie on multi providers (#691)
* fix: use conf.GrpcURL
  • Loading branch information
Yaiba committed May 10, 2024
1 parent a325df6 commit 5bdd5c3
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 13 deletions.
41 changes: 33 additions & 8 deletions cmd/kwil-cli/cmds/common/authinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"

"github.com/kwilteam/kwil-db/cmd/kwil-cli/config"
Expand Down Expand Up @@ -77,6 +79,30 @@ func convertToHttpCookie(c cookie) *http.Cookie {
}
}

// getDomain returns the domain of the URL.
func getDomain(target string) (string, error) {
if target == "" {
return "", fmt.Errorf("target is empty")
}

if !(strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://")) {
return "", fmt.Errorf("target missing scheme")
}

parsedTarget, err := url.Parse(target)
if err != nil {
return "", fmt.Errorf("parse target: %w", err)
}

return parsedTarget.Scheme + "://" + parsedTarget.Host, nil
}

// getCookieIdentifier returns a unique identifier for a cookie, base64 encoded.
func getCookieIdentifier(domain string, userIdentifier []byte) string {
return base64.StdEncoding.EncodeToString(
append([]byte(domain+"_"), userIdentifier...))
}

// PersistedCookies is a set of Gateway Auth cookies that can be saved to a file.
// It maps a base64 user identifier to a cookie, ensuring only one cookie per wallet.
// It uses a custom cookie type that is json serializable.
Expand All @@ -85,7 +111,7 @@ type PersistedCookies map[string]cookie
// LoadPersistedCookie loads a persisted cookie from the auth file.
// It will look up the cookie for the given user identifier.
// If nothing is found, it returns nil, nil.
func LoadPersistedCookie(authFile string, userIdentifier []byte) (*http.Cookie, error) {
func LoadPersistedCookie(authFile string, domain string, userIdentifier []byte) (*http.Cookie, error) {
if _, err := os.Stat(authFile); os.IsNotExist(err) {
return nil, nil
}
Expand All @@ -101,24 +127,23 @@ func LoadPersistedCookie(authFile string, userIdentifier []byte) (*http.Cookie,
return nil, fmt.Errorf("unmarshal kgw auth file: %w", err)
}

b64Identifier := base64.StdEncoding.EncodeToString(userIdentifier)
b64Identifier := getCookieIdentifier(domain, userIdentifier)
cookie := aInfo[b64Identifier]

return convertToHttpCookie(cookie), nil
}

// SaveCookie saves the cookie to auth file.
// It will overwrite the cookie if the address already exists.
func SaveCookie(authFile string, userIdentifier []byte, originCookie *http.Cookie) error {
func SaveCookie(authFile string, domain string, userIdentifier []byte, originCookie *http.Cookie) error {
b64Identifier := getCookieIdentifier(domain, userIdentifier)
cookie := convertToCookie(originCookie)

authInfoBytes, err := utils.ReadOrCreateFile(authFile)
if err != nil {
return fmt.Errorf("read kgw auth file: %w", err)
}

b64Identifier := base64.StdEncoding.EncodeToString(userIdentifier)

var aInfo PersistedCookies
if len(authInfoBytes) == 0 {
aInfo = make(PersistedCookies)
Expand All @@ -144,14 +169,12 @@ func SaveCookie(authFile string, userIdentifier []byte, originCookie *http.Cooki

// DeleteCookie will delete a cookie that exists for a given user identifier.
// If no cookie exists for the user identifier, it will do nothing.
func DeleteCookie(authFile string, userIdentifier []byte) error {
func DeleteCookie(authFile string, domain string, userIdentifier []byte) error {
authInfoBytes, err := utils.ReadOrCreateFile(authFile)
if err != nil {
return fmt.Errorf("read kgw auth file: %w", err)
}

b64Identifier := base64.StdEncoding.EncodeToString(userIdentifier)

var aInfo PersistedCookies
if len(authInfoBytes) == 0 {
aInfo = make(PersistedCookies)
Expand All @@ -161,6 +184,8 @@ func DeleteCookie(authFile string, userIdentifier []byte) error {
return fmt.Errorf("unmarshal kgw auth file: %w", err)
}
}

b64Identifier := getCookieIdentifier(domain, userIdentifier)
delete(aInfo, b64Identifier)

jsonBytes, err := json.MarshalIndent(&aInfo, "", " ")
Expand Down
103 changes: 101 additions & 2 deletions cmd/kwil-cli/cmds/common/authinfo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,45 @@ import (
"github.com/stretchr/testify/assert"
)

func TestLoadKGWAuthInfo_without_domain(t *testing.T) {
// this test just to show what the old behavior was

ckA := http.Cookie{
Name: "AAA",
Value: "AAA",
Path: "AAA",
Domain: "AAA",
Expires: time.Date(2023, 10, 27, 15, 46, 58, 651387237, time.UTC),
}

ckB := http.Cookie{
Name: "BBB",
Value: "BBB",
Path: "BBB",
Domain: "BBB",
Expires: time.Date(2023, 10, 27, 15, 46, 58, 651387237, time.UTC),
}

var err error
authFile := filepath.Join(t.TempDir(), "auth.json")
domain := ""

// authn on site A
err = SaveCookie(authFile, domain, []byte("0x123"), &ckA)
assert.NoError(t, err)

// authn on site B
err = SaveCookie(authFile, domain, []byte("0x123"), &ckB)
assert.NoError(t, err)

got, err := LoadPersistedCookie(authFile, domain, []byte("0x123"))
assert.NoError(t, err)

// ckA has been overwritten by ckB
assert.NotEqualValues(t, &ckA, got)
assert.EqualValues(t, &ckB, got)
}

func TestLoadKGWAuthInfo(t *testing.T) {
ck := http.Cookie{
Name: "test",
Expand All @@ -27,12 +66,72 @@ func TestLoadKGWAuthInfo(t *testing.T) {

var err error
authFile := filepath.Join(t.TempDir(), "auth.json")
domain := "https://kgw.kwil.com"

err = SaveCookie(authFile, []byte("0x123"), &ck)
err = SaveCookie(authFile, domain, []byte("0x123"), &ck)
assert.NoError(t, err)

got, err := LoadPersistedCookie(authFile, []byte("0x123"))
got, err := LoadPersistedCookie(authFile, domain, []byte("0x123"))
assert.NoError(t, err)

assert.EqualValues(t, &ck, got)
}

func Test_getDomain(t *testing.T) {
type args struct {
target string
}
tests := []struct {
name string
args args
wantErr bool
wantDoamin string
}{
// TODO: Add test cases.
{
name: "empty string",
args: args{
target: "",
},
wantErr: true,
wantDoamin: "",
},
{
name: "localhost with port",
args: args{
target: "http://localhost:8080/api",
},
wantDoamin: "http://localhost:8080",
},
{
name: "https localhost with port",
args: args{
target: "https://localhost:8080/api/",
},
wantDoamin: "https://localhost:8080",
},
{
name: "http example.com",
args: args{
target: "http://example.com/a/b/c",
},
wantDoamin: "http://example.com",
},
{
name: "https example.com",
args: args{
target: "https://example.com/a/b/c",
},
wantDoamin: "https://example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
domain, err := getDomain(tt.args.target)
if tt.wantErr {
assert.Errorf(t, err, "getDomain(%v)", tt.args.target)
}
assert.Equalf(t, tt.wantDoamin, domain, "getDomain(%v)", tt.args.target)
})
}
}
11 changes: 8 additions & 3 deletions cmd/kwil-cli/cmds/common/roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,19 @@ func DialClient(ctx context.Context, cmd *cobra.Command, flags uint8, fn RoundTr
return fn(ctx, client, conf)
}

cookie, err := LoadPersistedCookie(KGWAuthTokenFilePath(), clientConfig.Signer.Identity())
providerDomain, err := getDomain(conf.GrpcURL)
if err != nil {
return err
}

cookie, err := LoadPersistedCookie(KGWAuthTokenFilePath(), providerDomain, clientConfig.Signer.Identity())
if err == nil && cookie != nil {
// if setting fails, then don't do fail usage- failure likely means that the client has
// switched providers, and the cookie is no longer valid. The gatewayclient will re-authenticate.
// delete the cookie if it is invalid
err = client.SetAuthCookie(cookie)
if err != nil {
err2 := DeleteCookie(KGWAuthTokenFilePath(), clientConfig.Signer.Identity())
err2 := DeleteCookie(KGWAuthTokenFilePath(), providerDomain, clientConfig.Signer.Identity())
if err2 != nil {
return fmt.Errorf("failed to delete cookie: %w", err2)
}
Expand All @@ -123,7 +128,7 @@ func DialClient(ctx context.Context, cmd *cobra.Command, flags uint8, fn RoundTr
return nil
}

err = SaveCookie(KGWAuthTokenFilePath(), clientConfig.Signer.Identity(), cookie)
err = SaveCookie(KGWAuthTokenFilePath(), providerDomain, clientConfig.Signer.Identity(), cookie)
if err != nil {
return fmt.Errorf("save cookie: %w", err)
}
Expand Down

0 comments on commit 5bdd5c3

Please sign in to comment.