From 942da4a72845d2175dfe6e2b760aa2cfa2d44273 Mon Sep 17 00:00:00 2001 From: Edward Dowling Date: Wed, 16 Oct 2024 18:30:57 +0100 Subject: [PATCH] Add access monitoring rules to msteams plugins --- integrations/access/msteams/app.go | 96 ++++++++++++++++---- integrations/access/msteams/testlib/suite.go | 80 ++++++++++++++-- 2 files changed, 149 insertions(+), 27 deletions(-) diff --git a/integrations/access/msteams/app.go b/integrations/access/msteams/app.go index 885d9ca8de65..3cbbe5004d88 100644 --- a/integrations/access/msteams/app.go +++ b/integrations/access/msteams/app.go @@ -17,12 +17,14 @@ package msteams import ( "context" "log/slog" + "slices" "time" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integrations/access/accessmonitoring" "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/access/common/teleport" "github.com/gravitational/teleport/integrations/lib" @@ -53,7 +55,8 @@ type App struct { watcherJob lib.ServiceJob pd *pd.CompareAndSwap[PluginData] - log *slog.Logger + log *slog.Logger + accessMonitoringRules *accessmonitoring.RuleHandler *lib.Process } @@ -85,13 +88,11 @@ func (a *App) Run(ctx context.Context) error { } a.Process = lib.NewProcess(ctx) - a.watcherJob, err = a.newWatcherJob() if err != nil { return trace.Wrap(err) } a.SpawnCriticalJob(a.mainJob) - a.SpawnCriticalJob(a.watcherJob) select { case <-ctx.Done(): @@ -116,10 +117,14 @@ func (a *App) init(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, initTimeout) defer cancel() - var err error - a.apiClient, err = common.GetTeleportClient(ctx, a.conf.Teleport) - if err != nil { - return trace.Wrap(err) + if a.conf.Client != nil { + a.apiClient = a.conf.Client + } else { + var err error + a.apiClient, err = common.GetTeleportClient(ctx, a.conf.Teleport) + if err != nil { + return trace.Wrap(err) + } } a.pd = pd.NewCAS( @@ -145,6 +150,24 @@ func (a *App) init(ctx context.Context) error { return trace.Wrap(err) } + a.accessMonitoringRules = accessmonitoring.NewRuleHandler(accessmonitoring.RuleHandlerConfig{ + Client: a.apiClient, + PluginName: pluginName, + // Map msteams.RecipientData onto the common recipient type used + // by the access monitoring rules watcher. + FetchRecipientCallback: func(ctx context.Context, name string) (*common.Recipient, error) { + msTeamsRecipient, err := a.bot.FetchRecipient(ctx, name) + if err != nil { + return nil, trace.Wrap(err) + } + return &common.Recipient{ + Name: name, + ID: msTeamsRecipient.ID, + Kind: string(msTeamsRecipient.Kind), + }, nil + }, + }) + return a.initBot(ctx) } @@ -187,27 +210,52 @@ func (a *App) initBot(ctx context.Context) error { return nil } -// newWatcherJob creates WatcherJob -func (a *App) newWatcherJob() (lib.ServiceJob, error) { - return watcherjob.NewJob( +// run starts the main process +func (a *App) run(ctx context.Context) error { + + process := lib.MustGetProcess(ctx) + + watchKinds := []types.WatchKind{ + {Kind: types.KindAccessRequest}, + {Kind: types.KindAccessMonitoringRule}, + } + acceptedWatchKinds := make([]string, 0, len(watchKinds)) + watcherJob, err := watcherjob.NewJobWithConfirmedWatchKinds( a.apiClient, watcherjob.Config{ - Watch: types.Watch{ - Kinds: []types.WatchKind{{Kind: types.KindAccessRequest}}, - }, + Watch: types.Watch{Kinds: watchKinds, AllowPartialSuccess: true}, EventFuncTimeout: handlerTimeout, }, a.onWatcherEvent, + func(ws types.WatchStatus) { + for _, watchKind := range ws.GetKinds() { + acceptedWatchKinds = append(acceptedWatchKinds, watchKind.Kind) + } + }, ) -} - -// run starts the main process -func (a *App) run(ctx context.Context) error { - ok, err := a.watcherJob.WaitReady(ctx) if err != nil { return trace.Wrap(err) } + process.SpawnCriticalJob(watcherJob) + + ok, err := watcherJob.WaitReady(ctx) + if err != nil { + return trace.Wrap(err) + } + if len(acceptedWatchKinds) == 0 { + return trace.BadParameter("failed to initialize watcher for all the required resources: %+v", + watchKinds) + } + // Check if KindAccessMonitoringRule resources are being watched, + // the role the plugin is running as may not have access. + if slices.Contains(acceptedWatchKinds, types.KindAccessMonitoringRule) { + if err := a.accessMonitoringRules.InitAccessMonitoringRulesCache(ctx); err != nil { + return trace.Wrap(err, "initializing Access Monitoring Rule cache") + } + } + a.watcherJob = watcherJob + a.watcherJob.SetReady(ok) if ok { a.log.InfoContext(ctx, "Plugin is ready") } else { @@ -243,6 +291,10 @@ func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, err // onWatcherEvent called when an access request event is received func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { kind := event.Resource.GetKind() + if kind == types.KindAccessMonitoringRule { + return trace.Wrap(a.accessMonitoringRules.HandleAccessMonitoringRule(ctx, event)) + } + if kind != types.KindAccessRequest { return trace.Errorf("unexpected kind %s", kind) } @@ -480,6 +532,14 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest) recipientSet := stringset.New() a.log.DebugContext(ctx, "Getting suggested reviewer recipients") + accessRuleRecipients := a.accessMonitoringRules.RecipientsFromAccessMonitoringRules(ctx, req) + accessRuleRecipients.ForEach(func(r common.Recipient) { + recipientSet.Add(r.Name) + }) + if recipientSet.Len() != 0 { + return recipientSet.ToSlice() + } + var validEmailsSuggReviewers []string for _, reviewer := range req.GetSuggestedReviewers() { if !lib.IsEmail(reviewer) { diff --git a/integrations/access/msteams/testlib/suite.go b/integrations/access/msteams/testlib/suite.go index eaa9e138a9a1..a9a209c04cd1 100644 --- a/integrations/access/msteams/testlib/suite.go +++ b/integrations/access/msteams/testlib/suite.go @@ -28,6 +28,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" + v1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/access/msteams" @@ -37,7 +39,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/testing/integration" ) -// MsTeamsBaseSuite is the Slack access plugin test suite. +// MsTeamsBaseSuite is the MsTeams access plugin test suite. // It implements the testify.TestingSuite interface. type MsTeamsBaseSuite struct { *integration.AccessRequestSuite @@ -51,16 +53,19 @@ type MsTeamsBaseSuite struct { reviewer2TeamsUser msapi.User } -// SetupTest starts a fake Slack, generates the plugin configuration, and loads -// the fixtures in Slack. It runs for each test. +// SetupTest starts a fake MsTeams, generates the plugin configuration, and loads +// the fixtures in MsTeams. It runs for each test. func (s *MsTeamsBaseSuite) SetupTest() { t := s.T() + + err := logger.Setup(logger.Config{Severity: "debug"}) + require.NoError(t, err) s.raceNumber = runtime.GOMAXPROCS(0) s.fakeTeams = NewFakeTeams(s.raceNumber) t.Cleanup(s.fakeTeams.Close) - // We need requester users as well, the slack plugin sends messages to users + // We need requester users as well, the MsTeams plugin sends messages to users // when their access request got approved. s.requesterOSSTeamsUser = s.fakeTeams.StoreUser(msapi.User{Name: "Requester OSS", Mail: integration.RequesterOSSUserName}) s.requester1TeamsUser = s.fakeTeams.StoreUser(msapi.User{Name: "Requester Ent", Mail: integration.Requester1UserName}) @@ -71,16 +76,17 @@ func (s *MsTeamsBaseSuite) SetupTest() { var conf msteams.Config conf.Teleport = s.TeleportConfig() + apiClient, err := common.GetTeleportClient(context.Background(), s.TeleportConfig()) + require.NoError(t, err) + conf.Client = apiClient + conf.StatusSink = s.fakeStatusSink conf.MSAPI = s.fakeTeams.Config conf.MSAPI.SetBaseURLs(s.fakeTeams.URL(), s.fakeTeams.URL(), s.fakeTeams.URL()) - conf.Log = logger.Config{ - Severity: "debug", - } s.appConfig = &conf } -// startApp starts the Slack plugin, waits for it to become ready and returns. +// startApp starts the MsTeams plugin, waits for it to become ready and returns. func (s *MsTeamsBaseSuite) startApp() { s.T().Helper() t := s.T() @@ -414,7 +420,9 @@ func (s *MsTeamsSuiteEnterprise) TestRace() { ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) t.Cleanup(cancel) - s.appConfig.Log.Severity = "debug" // Turn off noisy debug logging + err := logger.Setup(logger.Config{Severity: "info"}) // Turn off noisy debug logging + require.NoError(t, err) + s.startApp() var ( @@ -527,3 +535,57 @@ func (s *MsTeamsSuiteEnterprise) TestRace() { return next }) } + +func (s *MsTeamsSuiteOSS) TestRecipientsFromAccessMonitoringRule() { + t := s.T() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + t.Cleanup(cancel) + + s.startApp() + + _, err := s.ClientByName(integration.RulerUserName). + AccessMonitoringRulesClient(). + CreateAccessMonitoringRule(ctx, &accessmonitoringrulesv1.AccessMonitoringRule{ + Kind: types.KindAccessMonitoringRule, + Version: types.V1, + Metadata: &v1.Metadata{ + Name: "test-msteams-amr", + }, + Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ + Subjects: []string{types.KindAccessRequest}, + Condition: "!is_empty(access_request.spec.roles)", + Notification: &accessmonitoringrulesv1.Notification{ + Name: "msteams", + Recipients: []string{ + s.reviewer1TeamsUser.ID, + s.reviewer2TeamsUser.Mail, + }, + }, + }, + }) + assert.NoError(t, err) + + // Test execution: create an access request + req := s.CreateAccessRequest(ctx, integration.RequesterOSSUserName, nil) + + s.checkPluginData(ctx, req.GetName(), func(data msteams.PluginData) bool { + return len(data.TeamsData) > 0 + }) + + title := "Access Request " + req.GetName() + msgs, err := s.getNewMessages(ctx, 2) + require.NoError(t, err) + + var body1 testTeamsMessage + require.NoError(t, json.Unmarshal([]byte(msgs[0].Body), &body1)) + body1.checkTitle(t, title) + require.Equal(t, msgs[0].RecipientID, s.reviewer1TeamsUser.ID) + + var body2 testTeamsMessage + require.NoError(t, json.Unmarshal([]byte(msgs[1].Body), &body2)) + body1.checkTitle(t, title) + require.Equal(t, msgs[1].RecipientID, s.reviewer2TeamsUser.ID) + + assert.NoError(t, s.ClientByName(integration.RulerUserName). + AccessMonitoringRulesClient().DeleteAccessMonitoringRule(ctx, "test-msteams-amr")) +}