diff --git a/v8/spnego/http.go b/v8/spnego/http.go index 4e1d2bda..3971dbc1 100644 --- a/v8/spnego/http.go +++ b/v8/spnego/http.go @@ -18,6 +18,7 @@ import ( "github.com/jcmturner/gokrb5/v8/client" "github.com/jcmturner/gokrb5/v8/credentials" "github.com/jcmturner/gokrb5/v8/gssapi" + "github.com/jcmturner/gokrb5/v8/iana/nametype" "github.com/jcmturner/gokrb5/v8/keytab" "github.com/jcmturner/gokrb5/v8/krberror" "github.com/jcmturner/gokrb5/v8/service" @@ -162,18 +163,43 @@ func respUnauthorizedNegotiate(resp *http.Response) bool { return false } +func setRequestSPN(r *http.Request) (types.PrincipalName, error) { + h := strings.TrimSuffix(r.URL.Host, ".") + // This if statement checks if the host includes a port number + if strings.LastIndex(r.URL.Host, ":") > strings.LastIndex(r.URL.Host, "]") { + // There is a port number in the URL + h, p, err := net.SplitHostPort(h) + if err != nil { + return types.PrincipalName{}, err + } + name, err := net.LookupCNAME(h) + if err == nil { + // Underlyng canonical name should be used for SPN + h = name + } + h = strings.TrimSuffix(h, ".") + r.Host = fmt.Sprintf("%s:%s", h, p) + return types.NewPrincipalName(nametype.KRB_NT_PRINCIPAL, "HTTP/"+h), nil + } + name, err := net.LookupCNAME(h) + if err == nil { + // Underlyng canonical name should be used for SPN + h = name + } + h = strings.TrimSuffix(h, ".") + r.Host = h + return types.NewPrincipalName(nametype.KRB_NT_PRINCIPAL, "HTTP/"+h), nil +} + // SetSPNEGOHeader gets the service ticket and sets it as the SPNEGO authorization header on HTTP request object. // To auto generate the SPN from the request object pass a null string "". func SetSPNEGOHeader(cl *client.Client, r *http.Request, spn string) error { if spn == "" { - h := strings.TrimSuffix(strings.SplitN(r.URL.Host, ":", 2)[0], ".") - name, err := net.LookupCNAME(h) - if err == nil { - // Underlyng canonical name should be used for SPN - h = strings.TrimSuffix(name, ".") + pn, err := setRequestSPN(r) + if err != nil { + return err } - spn = "HTTP/" + h - r.Host = h + spn = pn.PrincipalNameString() } cl.Log("using SPN %s", spn) s := SPNEGOClient(cl, spn)