From b03c5dcf3273ce33efa654b52628b14768aca46c Mon Sep 17 00:00:00 2001 From: fiftin Date: Sun, 24 Mar 2024 21:45:54 +0100 Subject: [PATCH 1/3] feat: add format for oidc claims --- api/login.go | 53 ++++++++++++++++++++++++++++---------------------- util/config.go | 1 - 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/api/login.go b/api/login.go index 6c159c517..d75f9b6b4 100644 --- a/api/login.go +++ b/api/login.go @@ -1,6 +1,7 @@ package api import ( + "bytes" "context" "crypto/tls" "encoding/base64" @@ -12,8 +13,8 @@ import ( "net/url" "os" "sort" - "strconv" "strings" + "text/template" "time" "golang.org/x/crypto/bcrypt" @@ -425,37 +426,43 @@ type oidcClaimResult struct { email string } -func parseClaims(claims map[string]interface{}, provider util.OidcProvider) (res oidcClaimResult, err error) { - var ok bool +func parseClaim(str string, claims map[string]interface{}) (string, bool) { + if strings.Contains(str, "{{") { + tpl, err := template.New("").Parse(str) - res.email, ok = claims[provider.EmailClaim].(string) - if !ok { + if err != nil { + return "", false + } - var username string + email := bytes.NewBufferString("") - if provider.EmailSuffix == "" { - err = fmt.Errorf("claim '%s' missing from id_token or not a string", provider.EmailClaim) - return + if err = tpl.Execute(email, claims); err != nil { + return "", false } - switch claims[provider.UsernameClaim].(type) { - case float64: - username = strconv.FormatFloat(claims[provider.UsernameClaim].(float64), 'f', -1, 64) - case string: - username = claims[provider.UsernameClaim].(string) - default: - err = fmt.Errorf("claim '%s' missing from id_token or not a string or an number", provider.UsernameClaim) - b, _ := json.MarshalIndent(claims, "", " ") - fmt.Print(string(b)) - return - } + return email.String(), true + } - res.email = username + "@" + provider.EmailSuffix + res, ok := claims[str].(string) + return res, ok +} + +func parseClaims(claims map[string]interface{}, provider util.OidcProvider) (res oidcClaimResult, err error) { + + var ok bool + res.email, ok = parseClaim(provider.EmailClaim, claims) + + if !ok { + err = fmt.Errorf("claim '%s' missing or has bad format", provider.EmailClaim) + return } - res.username = getRandomUsername() + res.username, ok = parseClaim(provider.UsernameClaim, claims) + if !ok || res.username == "" { + res.username = getRandomUsername() + } - res.name, ok = claims[provider.NameClaim].(string) + res.name, ok = parseClaim(provider.NameClaim, claims) if !ok || res.name == "" { res.name = getRandomProfileName() } diff --git a/util/config.go b/util/config.go index 7217ab8b0..02435d698 100644 --- a/util/config.go +++ b/util/config.go @@ -74,7 +74,6 @@ type OidcProvider struct { UsernameClaim string `json:"username_claim" default:"preferred_username"` NameClaim string `json:"name_claim" default:"preferred_username"` EmailClaim string `json:"email_claim" default:"email"` - EmailSuffix string `json:"email_suffix"` Order int `json:"order"` } From f31a3500d1b84a5c579d93ff97848a77c771a404 Mon Sep 17 00:00:00 2001 From: fiftin Date: Sun, 24 Mar 2024 21:54:40 +0100 Subject: [PATCH 2/3] feat: support claim pipes --- api/login.go | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/api/login.go b/api/login.go index d75f9b6b4..5b3b6a5ac 100644 --- a/api/login.go +++ b/api/login.go @@ -427,24 +427,31 @@ type oidcClaimResult struct { } func parseClaim(str string, claims map[string]interface{}) (string, bool) { - if strings.Contains(str, "{{") { - tpl, err := template.New("").Parse(str) - if err != nil { - return "", false - } + for _, s := range strings.Split(str, "|") { + if strings.Contains(s, "{{") { + tpl, err := template.New("").Parse(s) + + if err != nil { + return "", false + } - email := bytes.NewBufferString("") + email := bytes.NewBufferString("") - if err = tpl.Execute(email, claims); err != nil { - return "", false + if err = tpl.Execute(email, claims); err != nil { + return "", false + } + + return email.String(), true } - return email.String(), true + res, ok := claims[s].(string) + if ok { + return res, ok + } } - res, ok := claims[str].(string) - return res, ok + return "", false } func parseClaims(claims map[string]interface{}, provider util.OidcProvider) (res oidcClaimResult, err error) { From 54587b0e074013c7814e9455c78b00d8ada5da1e Mon Sep 17 00:00:00 2001 From: fiftin Date: Sun, 24 Mar 2024 22:08:49 +0100 Subject: [PATCH 3/3] test: add tests for parseClaim --- api/login.go | 12 +++++++---- api/login_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) create mode 100644 api/login_test.go diff --git a/api/login.go b/api/login.go index 5b3b6a5ac..839b01210 100644 --- a/api/login.go +++ b/api/login.go @@ -429,6 +429,8 @@ type oidcClaimResult struct { func parseClaim(str string, claims map[string]interface{}) (string, bool) { for _, s := range strings.Split(str, "|") { + s = strings.TrimSpace(s) + if strings.Contains(s, "{{") { tpl, err := template.New("").Parse(s) @@ -442,11 +444,13 @@ func parseClaim(str string, claims map[string]interface{}) (string, bool) { return "", false } - return email.String(), true + res := email.String() + + return res, res != "" } res, ok := claims[s].(string) - if ok { + if res != "" && ok { return res, ok } } @@ -465,12 +469,12 @@ func parseClaims(claims map[string]interface{}, provider util.OidcProvider) (res } res.username, ok = parseClaim(provider.UsernameClaim, claims) - if !ok || res.username == "" { + if !ok { res.username = getRandomUsername() } res.name, ok = parseClaim(provider.NameClaim, claims) - if !ok || res.name == "" { + if !ok { res.name = getRandomProfileName() } diff --git a/api/login_test.go b/api/login_test.go new file mode 100644 index 000000000..ba501e473 --- /dev/null +++ b/api/login_test.go @@ -0,0 +1,55 @@ +package api + +import ( + "testing" +) + +func TestParseClaim(t *testing.T) { + claims := map[string]interface{}{ + "username": "fiftin", + "email": "", + "id": 1234567, + } + + res, ok := parseClaim("email | {{ .id }}@test.com", claims) + + if !ok { + t.Fail() + } + + if res != "1234567@test.com" { + t.Fatalf("%s must be %d@test.com", res, claims["id"]) + } +} + +func TestParseClaim2(t *testing.T) { + claims := map[string]interface{}{ + "username": "fiftin", + "email": "", + "id": 1234567, + } + + res, ok := parseClaim("username", claims) + + if !ok { + t.Fail() + } + + if res != claims["username"] { + t.Fail() + } +} + +func TestParseClaim3(t *testing.T) { + claims := map[string]interface{}{ + "username": "fiftin", + "email": "", + "id": 1234567, + } + + _, ok := parseClaim("email", claims) + + if ok { + t.Fail() + } +}