Skip to content

Commit

Permalink
feat: SSO MFA - Add ephemeral SSO MFA device (#46704)
Browse files Browse the repository at this point in the history
* Add MFA device proto.

* Add MFA device helper functions.

* Include SSO MFA device in backend GetMFADevices method.

* Handle SSO device in oneof switches.

* Display sso mfa devices as 'SSO' in tsh and 'SSO Provider' in the WebUI.

* Handle SSO mfa devices in deletion flow.

* Fix lint and unit test.

* Resolve comments; separate SSO MFA device tests; Prevent SSO MFA devices from being stored in the backend.

* Get SSO mfa device concurrently.

* Update test.

* Resolve comments.

* Fix lint.

* Minor change to use cancel context.

* Use errgroup.
  • Loading branch information
Joerger authored Oct 16, 2024
1 parent 98c11b6 commit 5d27bf8
Show file tree
Hide file tree
Showing 13 changed files with 3,075 additions and 2,298 deletions.
9 changes: 9 additions & 0 deletions api/proto/teleport/legacy/types/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3736,6 +3736,7 @@ message MFADevice {
TOTPDevice totp = 8;
U2FDevice u2f = 9;
WebauthnDevice webauthn = 10;
SSOMFADevice sso = 11;
}
}

Expand Down Expand Up @@ -3800,6 +3801,14 @@ message WebauthnDevice {
google.protobuf.BoolValue credential_backed_up = 10;
}

// SSOMFADevice contains details of an SSO MFA method.
message SSOMFADevice {
// connector_id is the ID of the SSO connector.
string connector_id = 1;
// connector_type is the type of the SSO connector.
string connector_type = 2;
}

// WebauthnLocalAuth holds settings necessary for local webauthn use.
message WebauthnLocalAuth {
// UserID is the random user handle generated for the user.
Expand Down
104 changes: 0 additions & 104 deletions api/types/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
package types

import (
"bytes"
"context"
"encoding/json"
"fmt"
Expand All @@ -27,7 +26,6 @@ import (
"strings"
"time"

"github.com/gogo/protobuf/jsonpb"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/constants"
Expand Down Expand Up @@ -939,108 +937,6 @@ func (wal *WebauthnLocalAuth) Check() error {
return nil
}

// NewMFADevice creates a new MFADevice with the given name. Caller must set
// the Device field in the returned MFADevice.
func NewMFADevice(name, id string, addedAt time.Time) *MFADevice {
return &MFADevice{
Metadata: Metadata{
Name: name,
},
Id: id,
AddedAt: addedAt,
LastUsed: addedAt,
}
}

// setStaticFields sets static resource header and metadata fields.
func (d *MFADevice) setStaticFields() {
d.Kind = KindMFADevice
d.Version = V1
}

// CheckAndSetDefaults validates MFADevice fields and populates empty fields
// with default values.
func (d *MFADevice) CheckAndSetDefaults() error {
d.setStaticFields()
if err := d.Metadata.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
if d.Id == "" {
return trace.BadParameter("MFADevice missing ID field")
}
if d.AddedAt.IsZero() {
return trace.BadParameter("MFADevice missing AddedAt field")
}
if d.LastUsed.IsZero() {
return trace.BadParameter("MFADevice missing LastUsed field")
}
if d.LastUsed.Before(d.AddedAt) {
return trace.BadParameter("MFADevice LastUsed field must be earlier than AddedAt")
}
if d.Device == nil {
return trace.BadParameter("MFADevice missing Device field")
}
if err := checkWebauthnDevice(d); err != nil {
return trace.Wrap(err)
}
return nil
}

func checkWebauthnDevice(d *MFADevice) error {
wrapper, ok := d.Device.(*MFADevice_Webauthn)
if !ok {
return nil
}
switch webDev := wrapper.Webauthn; {
case webDev == nil:
return trace.BadParameter("MFADevice has malformed WebauthnDevice")
case len(webDev.CredentialId) == 0:
return trace.BadParameter("WebauthnDevice missing CredentialId field")
case len(webDev.PublicKeyCbor) == 0:
return trace.BadParameter("WebauthnDevice missing PublicKeyCbor field")
default:
return nil
}
}

func (d *MFADevice) GetKind() string { return d.Kind }
func (d *MFADevice) GetSubKind() string { return d.SubKind }
func (d *MFADevice) SetSubKind(sk string) { d.SubKind = sk }
func (d *MFADevice) GetVersion() string { return d.Version }
func (d *MFADevice) GetMetadata() Metadata { return d.Metadata }
func (d *MFADevice) GetName() string { return d.Metadata.GetName() }
func (d *MFADevice) SetName(n string) { d.Metadata.SetName(n) }
func (d *MFADevice) GetRevision() string { return d.Metadata.GetRevision() }
func (d *MFADevice) SetRevision(rev string) { d.Metadata.SetRevision(rev) }
func (d *MFADevice) Expiry() time.Time { return d.Metadata.Expiry() }
func (d *MFADevice) SetExpiry(exp time.Time) { d.Metadata.SetExpiry(exp) }

// MFAType returns the human-readable name of the MFA protocol of this device.
func (d *MFADevice) MFAType() string {
switch d.Device.(type) {
case *MFADevice_Totp:
return "TOTP"
case *MFADevice_U2F:
return "U2F"
case *MFADevice_Webauthn:
return "WebAuthn"
default:
return "unknown"
}
}

func (d *MFADevice) MarshalJSON() ([]byte, error) {
buf := new(bytes.Buffer)
err := (&jsonpb.Marshaler{}).Marshal(buf, d)
return buf.Bytes(), trace.Wrap(err)
}

func (d *MFADevice) UnmarshalJSON(buf []byte) error {
unmarshaler := jsonpb.Unmarshaler{AllowUnknownFields: true}
err := unmarshaler.Unmarshal(bytes.NewReader(buf), d)
return trace.Wrap(err)
}

// IsSessionMFARequired returns whether this RequireMFAType requires per-session MFA.
func (r RequireMFAType) IsSessionMFARequired() bool {
return r != RequireMFAType_OFF
Expand Down
6 changes: 5 additions & 1 deletion api/types/authentication_mfadevice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ func TestMFADevice_CheckAndSetDefaults(t *testing.T) {
Id: "otp-0001",
AddedAt: now,
LastUsed: now,
Device: &types.MFADevice_Totp{}, // validated elsewhere
Device: &types.MFADevice_Totp{
Totp: &types.TOTPDevice{
Key: "key",
},
},
},
},
{
Expand Down
134 changes: 134 additions & 0 deletions api/types/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,103 @@
package types

import (
"bytes"
"time"

"github.com/gogo/protobuf/jsonpb"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/utils"
)

// NewMFADevice creates a new MFADevice with the given name. Caller must set
// the Device field in the returned MFADevice.
func NewMFADevice(name, id string, addedAt time.Time, device isMFADevice_Device) (*MFADevice, error) {
dev := &MFADevice{
Metadata: Metadata{
Name: name,
},
Id: id,
AddedAt: addedAt,
LastUsed: addedAt,
Device: device,
}
return dev, dev.CheckAndSetDefaults()
}

// setStaticFields sets static resource header and metadata fields.
func (d *MFADevice) setStaticFields() {
d.Kind = KindMFADevice
d.Version = V1
}

// CheckAndSetDefaults validates MFADevice fields and populates empty fields
// with default values.
func (d *MFADevice) CheckAndSetDefaults() error {
d.setStaticFields()
if err := d.Metadata.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
if d.Id == "" {
return trace.BadParameter("MFADevice missing ID field")
}
if d.AddedAt.IsZero() {
return trace.BadParameter("MFADevice missing AddedAt field")
}
if d.LastUsed.IsZero() {
return trace.BadParameter("MFADevice missing LastUsed field")
}
if d.LastUsed.Before(d.AddedAt) {
return trace.BadParameter("MFADevice LastUsed field must be earlier than AddedAt")
}
if d.Device == nil {
return trace.BadParameter("MFADevice missing Device field")
}
if err := d.validateDevice(); err != nil {
return trace.Wrap(err)
}
return nil
}

// validateDevice runs additional validations for OTP devices.
// Prefer adding new validation logic to types.MFADevice.CheckAndSetDefaults
// instead.
func (d *MFADevice) validateDevice() error {
switch dev := d.Device.(type) {
case *MFADevice_Totp:
if dev.Totp == nil {
return trace.BadParameter("MFADevice has malformed TOTPDevice")
}
if dev.Totp.Key == "" {
return trace.BadParameter("TOTPDevice missing Key field")
}
case *MFADevice_Webauthn:
if dev.Webauthn == nil {
return trace.BadParameter("MFADevice has malformed WebauthnDevice")
}
if len(dev.Webauthn.CredentialId) == 0 {
return trace.BadParameter("WebauthnDevice missing CredentialId field")
}
if len(dev.Webauthn.PublicKeyCbor) == 0 {
return trace.BadParameter("WebauthnDevice missing PublicKeyCbor field")
}
case *MFADevice_Sso:
if dev.Sso == nil {
return trace.BadParameter("MFADevice has malformed SSODevice")
}
if dev.Sso.ConnectorId == "" {
return trace.BadParameter("SSODevice missing ConnectorId field")
}
if dev.Sso.ConnectorType == "" {
return trace.BadParameter("SSODevice missing ConnectorType field")
}
case *MFADevice_U2F:
default:
return trace.BadParameter("MFADevice has Device field of unknown type %T", dev)
}
return nil
}

func (d *MFADevice) WithoutSensitiveData() (*MFADevice, error) {
if d == nil {
return nil, trace.BadParameter("cannot hide sensitive data on empty object")
Expand All @@ -33,9 +125,51 @@ func (d *MFADevice) WithoutSensitiveData() (*MFADevice, error) {
// OK, no sensitive secrets.
case *MFADevice_Webauthn:
// OK, no sensitive secrets.
case *MFADevice_Sso:
// OK, no sensitive secrets.
default:
return nil, trace.BadParameter("unsupported MFADevice type %T", d.Device)
}

return out, nil
}

func (d *MFADevice) GetKind() string { return d.Kind }
func (d *MFADevice) GetSubKind() string { return d.SubKind }
func (d *MFADevice) SetSubKind(sk string) { d.SubKind = sk }
func (d *MFADevice) GetVersion() string { return d.Version }
func (d *MFADevice) GetMetadata() Metadata { return d.Metadata }
func (d *MFADevice) GetName() string { return d.Metadata.GetName() }
func (d *MFADevice) SetName(n string) { d.Metadata.SetName(n) }
func (d *MFADevice) GetRevision() string { return d.Metadata.GetRevision() }
func (d *MFADevice) SetRevision(rev string) { d.Metadata.SetRevision(rev) }
func (d *MFADevice) Expiry() time.Time { return d.Metadata.Expiry() }
func (d *MFADevice) SetExpiry(exp time.Time) { d.Metadata.SetExpiry(exp) }

// MFAType returns the human-readable name of the MFA protocol of this device.
func (d *MFADevice) MFAType() string {
switch d.Device.(type) {
case *MFADevice_Totp:
return "TOTP"
case *MFADevice_U2F:
return "U2F"
case *MFADevice_Webauthn:
return "WebAuthn"
case *MFADevice_Sso:
return "SSO"
default:
return "unknown"
}
}

func (d *MFADevice) MarshalJSON() ([]byte, error) {
buf := new(bytes.Buffer)
err := (&jsonpb.Marshaler{}).Marshal(buf, d)
return buf.Bytes(), trace.Wrap(err)
}

func (d *MFADevice) UnmarshalJSON(buf []byte) error {
unmarshaler := jsonpb.Unmarshaler{AllowUnknownFields: true}
err := unmarshaler.Unmarshal(bytes.NewReader(buf), d)
return trace.Wrap(err)
}
Loading

0 comments on commit 5d27bf8

Please sign in to comment.