From 15d1a5af7f5a95ee90d8c8eb9589cc23e9ba1c4b Mon Sep 17 00:00:00 2001 From: favonia Date: Thu, 11 Jul 2024 12:04:06 +0300 Subject: [PATCH] feat(api): recheck tokens if the network was temporarily down (#790) --- cmd/ddns/ddns.go | 9 + internal/api/base.go | 6 + internal/api/cloudflare.go | 94 ++++++++-- internal/api/cloudflare_test.go | 318 ++++++++++++++++++++++++-------- internal/mocks/mock_api.go | 38 ++++ internal/mocks/mock_setter.go | 38 ++++ internal/setter/base.go | 6 + internal/setter/setter.go | 4 + internal/setter/setter_test.go | 29 +++ 9 files changed, 445 insertions(+), 97 deletions(-) diff --git a/cmd/ddns/ddns.go b/cmd/ddns/ddns.go index b898fb46..f40dafc3 100644 --- a/cmd/ddns/ddns.go +++ b/cmd/ddns/ddns.go @@ -129,6 +129,15 @@ func realMain() int { //nolint:funlen if first && !c.UpdateOnStart { monitor.SuccessAll(ctx, ppfmt, c.Monitors, "Started (no action)") } else { + if c.UpdateCron != nil && !s.SanityCheck(ctx, ppfmt) { + monitor.SuccessAll(ctx, ppfmt, c.Monitors, "Invalid Cloudflare API token") + notifier.SendAll(ctx, ppfmt, c.Notifiers, + "The Cloudflare API token is invalid. "+ + "Please check the value of CF_API_TOKEN or CF_API_TOKEN_FILE.", + ) + return 1 + } + msg := updater.UpdateIPs(ctxWithSignals, ppfmt, c, s) monitor.PingMessageAll(ctx, ppfmt, c.Monitors, msg) notifier.SendMessageAll(ctx, ppfmt, c.Notifiers, msg) diff --git a/internal/api/base.go b/internal/api/base.go index 0d93f9e6..61b5ee45 100644 --- a/internal/api/base.go +++ b/internal/api/base.go @@ -15,7 +15,13 @@ import ( // A Handle represents a generic API to update DNS records. Currently, the only implementation is Cloudflare. type Handle interface { + // Perform basic checking. It returns false when we should give up + // all future operations. + SanityCheck(ctx context.Context, ppfmt pp.PP) bool + // ListRecords lists all matching DNS records. + // + // The second return value means whether the list is cached. ListRecords(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type) (map[string]netip.Addr, bool, bool) //nolint:lll // DeleteRecord deletes one DNS record. diff --git a/internal/api/cloudflare.go b/internal/api/cloudflare.go index 871f0941..50391db8 100644 --- a/internal/api/cloudflare.go +++ b/internal/api/cloudflare.go @@ -2,6 +2,7 @@ package api import ( "context" + "errors" "net/netip" "time" @@ -33,9 +34,11 @@ func newCache[K comparable, V any](cacheExpiration time.Duration) *ttlcache.Cach // A CloudflareHandle implements the [Handle] interface with the Cloudflare API. type CloudflareHandle struct { - cf *cloudflare.API - accountID string - cache CloudflareCache + cf *cloudflare.API + sanityPermChecked bool + sanityPermPassed bool + accountID string + cache CloudflareCache } // A CloudflareAuth implements the [Auth] interface, holding the authentication data to create a [CloudflareHandle]. @@ -46,7 +49,7 @@ type CloudflareAuth struct { } // New creates a [CloudflareHandle] from the authentication data. -func (t *CloudflareAuth) New(ctx context.Context, ppfmt pp.PP, cacheExpiration time.Duration) (Handle, bool) { +func (t *CloudflareAuth) New(_ context.Context, ppfmt pp.PP, cacheExpiration time.Duration) (Handle, bool) { handle, err := cloudflare.NewWithAPIToken(t.Token) if err != nil { ppfmt.Errorf(pp.EmojiUserError, "Failed to prepare the Cloudflare authentication: %v", err) @@ -58,19 +61,11 @@ func (t *CloudflareAuth) New(ctx context.Context, ppfmt pp.PP, cacheExpiration t handle.BaseURL = t.BaseURL } - // verify Cloudflare token - // - // ideally, we should also verify accountID here, but that is impossible without - // more permissions included in the API token. - if _, err := handle.VerifyAPIToken(ctx); err != nil { - ppfmt.Errorf(pp.EmojiUserError, "The Cloudflare API token could not be verified: %v", err) - ppfmt.Errorf(pp.EmojiUserError, "Please double-check the value of CF_API_TOKEN or CF_API_TOKEN_FILE") - return nil, false - } - - return &CloudflareHandle{ - cf: handle, - accountID: t.AccountID, + h := &CloudflareHandle{ + cf: handle, + sanityPermChecked: false, + sanityPermPassed: false, + accountID: t.AccountID, cache: CloudflareCache{ listRecords: map[ipnet.Type]*ttlcache.Cache[string, map[string]netip.Addr]{ ipnet.IP4: newCache[string, map[string]netip.Addr](cacheExpiration), @@ -79,7 +74,9 @@ func (t *CloudflareAuth) New(ctx context.Context, ppfmt pp.PP, cacheExpiration t activeZones: newCache[string, []string](cacheExpiration), zoneOfDomain: newCache[string, string](cacheExpiration), }, - }, true + } + + return h, true } // FlushCache flushes the API cache. @@ -91,6 +88,59 @@ func (h *CloudflareHandle) FlushCache() { h.cache.zoneOfDomain.DeleteAll() } +// SanityCheck verifies Cloudflare tokens. +// +// Ideally, we should also verify accountID here, but that is impossible without +// more permissions included in the API token. +func (h *CloudflareHandle) SanityCheck(ctx context.Context, ppfmt pp.PP) bool { + if h.sanityPermChecked { + return h.sanityPermPassed + } + + quickCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + ok := true + res, err := h.cf.VerifyAPIToken(quickCtx) + if err != nil { + // Check if the token is permanently invalid... + var aerr *cloudflare.AuthorizationError + var rerr *cloudflare.RequestError + if errors.As(err, &aerr) || errors.As(err, &rerr) { + ppfmt.Errorf(pp.EmojiUserError, "The Cloudflare API token is invalid: %v", err) + ok = false + goto permanently + } + ppfmt.Warningf(pp.EmojiWarning, "Could not verify the Cloudflare API token: %v", err) + return true // It could be that the network times out. + } + switch res.Status { + case "active": + case "disabled", "expired": + ppfmt.Errorf(pp.EmojiUserError, "The Cloudflare API token is %s", res.Status) + ok = false + goto permanently + default: + ppfmt.Errorf(pp.EmojiImpossible, "The Cloudflare API token is in an undocumented state: %s", res.Status) + ppfmt.Errorf(pp.EmojiImpossible, "Please report the bug at https://github.com/favonia/cloudflare-ddns/issues/new") //nolint:lll + ok = false + goto permanently + } + + if !res.ExpiresOn.IsZero() { + ppfmt.Warningf(pp.EmojiAlarm, "The token will expire at %s", + res.ExpiresOn.In(time.Local).Format(time.RFC1123Z)) + } + +permanently: + if !ok { + ppfmt.Errorf(pp.EmojiUserError, "Please double-check the value of CF_API_TOKEN or CF_API_TOKEN_FILE") + } + h.sanityPermChecked = true + h.sanityPermPassed = ok + return ok +} + // ActiveZones returns a list of zone IDs with the zone name. func (h *CloudflareHandle) ActiveZones(ctx context.Context, ppfmt pp.PP, name string) ([]string, bool) { // WithZoneFilters does not work with the empty zone name, @@ -109,6 +159,14 @@ func (h *CloudflareHandle) ActiveZones(ctx context.Context, ppfmt pp.PP, name st return nil, false } + // No need to perform any sanity checking in future. ;-) + // + // This is the best place to force pass the sanity check + // because ListZonesContext will be the first real + // API call. + h.sanityPermChecked = true + h.sanityPermPassed = true + ids := make([]string, 0, len(res.Result)) for _, zone := range res.Result { // The list of possible statuses was at https://api.cloudflare.com/#zone-list-zones diff --git a/internal/api/cloudflare_test.go b/internal/api/cloudflare_test.go index 9472e228..9802e603 100644 --- a/internal/api/cloudflare_test.go +++ b/internal/api/cloudflare_test.go @@ -72,7 +72,7 @@ func newServerAuth(t *testing.T, emptyAccountID bool) (*http.ServeMux, *api.Clou return mux, &auth } -func handleTokensVerify(t *testing.T, w http.ResponseWriter, r *http.Request) { +func handleSanityCheck(t *testing.T, w http.ResponseWriter, r *http.Request) { t.Helper() require.Equal(t, http.MethodGet, r.Method) @@ -82,31 +82,29 @@ func handleTokensVerify(t *testing.T, w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") fmt.Fprintf(w, `{ - "result": { "id": "%s", "status": "active" }, - "success": true, - "errors": [], - "messages": [ - { - "code": 10000, - "message": "This API Token is valid and active", - "type": null - } - ] - }`, + "result": { "id": "%s", "status": "active" }, + "success": true, + "errors": [], + "messages": [ + { + "code": 10000, + "message": "This API Token is valid and active", + "type": null + } + ] + }`, mockID("result", 0)) } -func newHandle(t *testing.T, emptyAccountID bool) (*http.ServeMux, api.Handle) { +func newHandle(t *testing.T, emptyAccountID bool, mockPP *mocks.MockPP) (*http.ServeMux, api.Handle) { t.Helper() - mockCtrl := gomock.NewController(t) mux, auth := newServerAuth(t, emptyAccountID) mux.HandleFunc("/user/tokens/verify", func(w http.ResponseWriter, r *http.Request) { - handleTokensVerify(t, w, r) + handleSanityCheck(t, w, r) }) - mockPP := mocks.NewMockPP(mockCtrl) h, ok := auth.New(context.Background(), mockPP, time.Second) require.True(t, ok) require.NotNil(t, h) @@ -116,56 +114,217 @@ func newHandle(t *testing.T, emptyAccountID bool) (*http.ServeMux, api.Handle) { func TestNewValid(t *testing.T) { t.Parallel() + mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) + + _, h := newHandle(t, false, mockPP) - newHandle(t, false) + require.True(t, h.SanityCheck(context.Background(), mockPP)) + + // Test again to test the caching + require.True(t, h.SanityCheck(context.Background(), mockPP)) } func TestNewEmpty(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) _, auth := newServerAuth(t, false) auth.Token = "" - mockPP := mocks.NewMockPP(mockCtrl) mockPP.EXPECT().Errorf(pp.EmojiUserError, "Failed to prepare the Cloudflare authentication: %v", gomock.Any()) h, ok := auth.New(context.Background(), mockPP, time.Second) require.False(t, ok) require.Nil(t, h) } +func TestSanityCheckExpiring(t *testing.T) { + t.Parallel() + + for name, tc := range map[string]struct { + resp string + ok bool + prepareMockPP func(*mocks.MockPP) + }{ + "expiring": { + `{ + "success": true, + "errors": [], + "messages": [], + "result": { + "id": "11111111111111111111111111111111", + "status": "active", + "expires_on": "3000-01-01T00:00:00Z" + } +}`, + true, + func(p *mocks.MockPP) { + deadline, err := time.Parse(time.RFC3339, "3000-01-01T00:00:00Z") + require.NoError(t, err) + p.EXPECT().Warningf(pp.EmojiAlarm, "The token will expire at %s", + deadline.In(time.Local).Format(time.RFC1123Z)) + }, + }, + "expired": { + `{ + "success": true, + "errors": [], + "messages": [], + "result": { + "id": "11111111111111111111111111111111", + "status": "expired" + } +}`, + false, + func(p *mocks.MockPP) { + gomock.InOrder( + p.EXPECT().Errorf(pp.EmojiUserError, "The Cloudflare API token is %s", "expired"), + p.EXPECT().Errorf(pp.EmojiUserError, "Please double-check the value of CF_API_TOKEN or CF_API_TOKEN_FILE"), + ) + }, + }, + "funny": { + `{ + "success": true, + "errors": [], + "messages": [], + "result": { + "id": "11111111111111111111111111111111", + "status": "funny" + } +}`, + false, + func(p *mocks.MockPP) { + gomock.InOrder( + p.EXPECT().Errorf(pp.EmojiImpossible, "The Cloudflare API token is in an undocumented state: %s", "funny"), + p.EXPECT().Errorf(pp.EmojiImpossible, "Please report the bug at https://github.com/favonia/cloudflare-ddns/issues/new"), //nolint:lll + p.EXPECT().Errorf(pp.EmojiUserError, "Please double-check the value of CF_API_TOKEN or CF_API_TOKEN_FILE"), + ) + }, + }, + "disabled": { + `{ + "success": true, + "errors": [], + "messages": [], + "result": { + "id": "11111111111111111111111111111111", + "status": "disabled" + } +}`, + false, + func(p *mocks.MockPP) { + gomock.InOrder( + p.EXPECT().Errorf(pp.EmojiUserError, "The Cloudflare API token is %s", "disabled"), + p.EXPECT().Errorf(pp.EmojiUserError, "Please double-check the value of CF_API_TOKEN or CF_API_TOKEN_FILE"), + ) + }, + }, + } { + t.Run(name, func(t *testing.T) { + t.Parallel() + + mockCtrl := gomock.NewController(t) + mux, auth := newServerAuth(t, false) + mux.HandleFunc("/user/tokens/verify", func(w http.ResponseWriter, r *http.Request) { + if !assert.Equal(t, http.MethodGet, r.Method) || + !assert.Equal(t, []string{mockAuthString}, r.Header["Authorization"]) || + !assert.Empty(t, r.URL.Query()) { + panic(http.ErrAbortHandler) + } + + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, tc.resp) + }) + + mockPP := mocks.NewMockPP(mockCtrl) + if tc.prepareMockPP != nil { + tc.prepareMockPP(mockPP) + } + h, ok := auth.New(context.Background(), mockPP, time.Second) + require.True(t, ok) + require.NotNil(t, h) + require.Equal(t, tc.ok, h.SanityCheck(context.Background(), mockPP)) + }) + } +} + func TestNewInvalid(t *testing.T) { t.Parallel() + + for name, resp := range map[string]string{ + "invalid-token": `{ + "success": false, + "errors": [{ "code": 1000, "message": "Invalid API Token" }], + "messages": [], + "result": null +}`, + "invalid-format": `{ + "success": false, + "errors": [ + { + "code": 6003, + "message": "Invalid request headers", + "error_chain": [ + { "code": 6111, "message": "Invalid format for Authorization header" } + ] + } + ], + "messages": [], + "result": null +}`, + } { + t.Run(name, func(t *testing.T) { + t.Parallel() + + mockCtrl := gomock.NewController(t) + mux, auth := newServerAuth(t, false) + mux.HandleFunc("/user/tokens/verify", func(w http.ResponseWriter, r *http.Request) { + if !assert.Equal(t, http.MethodGet, r.Method) || + !assert.Equal(t, []string{mockAuthString}, r.Header["Authorization"]) || + !assert.Empty(t, r.URL.Query()) { + panic(http.ErrAbortHandler) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, resp) + }) + + mockPP := mocks.NewMockPP(mockCtrl) + gomock.InOrder( + mockPP.EXPECT().Errorf(pp.EmojiUserError, "The Cloudflare API token is invalid: %v", gomock.Any()), + mockPP.EXPECT().Errorf(pp.EmojiUserError, "Please double-check the value of CF_API_TOKEN or CF_API_TOKEN_FILE"), + ) + h, ok := auth.New(context.Background(), mockPP, time.Second) + require.True(t, ok) + require.NotNil(t, h) + require.False(t, h.SanityCheck(context.Background(), mockPP)) + }) + } +} + +func TestSanityCheckTimeout(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) mux, auth := newServerAuth(t, false) + mux.HandleFunc("/user/tokens/verify", func(_ http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method) + assert.Equal(t, []string{mockAuthString}, r.Header["Authorization"]) + assert.Empty(t, r.URL.Query()) - mux.HandleFunc("/user/tokens/verify", func(w http.ResponseWriter, r *http.Request) { - if !assert.Equal(t, http.MethodGet, r.Method) || - !assert.Equal(t, []string{mockAuthString}, r.Header["Authorization"]) || - !assert.Empty(t, r.URL.Query()) { - panic(http.ErrAbortHandler) - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - fmt.Fprintf(w, - `{ - "success": false, - "errors": [{ "code": 1000, "message": "Invalid API Token" }], - "messages": [], - "result": null - }`) + panic(http.ErrAbortHandler) }) - mockPP := mocks.NewMockPP(mockCtrl) - gomock.InOrder( - mockPP.EXPECT().Errorf(pp.EmojiUserError, "The Cloudflare API token could not be verified: %v", gomock.Any()), - mockPP.EXPECT().Errorf(pp.EmojiUserError, "Please double-check the value of CF_API_TOKEN or CF_API_TOKEN_FILE"), - ) + mockPP.EXPECT().Warningf(pp.EmojiWarning, "Could not verify the Cloudflare API token: %v", gomock.Any()) h, ok := auth.New(context.Background(), mockPP, time.Second) - require.False(t, ok) - require.Nil(t, h) + require.True(t, ok) + require.NotNil(t, h) + require.True(t, h.SanityCheck(context.Background(), mockPP)) } func mockZone(name string, i int, status string) *cloudflare.Zone { @@ -279,10 +438,10 @@ func (h *zonesHandler) isExhausted() bool { func TestActiveZonesRoot(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - _, h := newHandle(t, false) + _, h := newHandle(t, false, mockPP) - mockPP := mocks.NewMockPP(mockCtrl) zones, ok := h.(*api.CloudflareHandle).ActiveZones(context.Background(), mockPP, "") require.True(t, ok) require.Empty(t, zones) @@ -291,13 +450,13 @@ func TestActiveZonesRoot(t *testing.T) { func TestActiveZonesTwo(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{"test.org": {"active", "active"}}, 1) - mockPP := mocks.NewMockPP(mockCtrl) zones, ok := h.(*api.CloudflareHandle).ActiveZones(context.Background(), mockPP, "test.org") require.True(t, ok) require.Equal(t, mockIDs("test.org", 0, 1), zones) @@ -328,13 +487,13 @@ func TestActiveZonesTwo(t *testing.T) { func TestActiveZonesEmpty(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{}, 1) - mockPP := mocks.NewMockPP(mockCtrl) zones, ok := h.(*api.CloudflareHandle).ActiveZones(context.Background(), mockPP, "test.org") require.True(t, ok) require.Empty(t, zones) @@ -495,12 +654,13 @@ func TestZoneOfDomain(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) - mux, h := newHandle(t, tc.emptyAccountID) + mockPP := mocks.NewMockPP(mockCtrl) + + mux, h := newHandle(t, tc.emptyAccountID, mockPP) zh := newZonesHandler(t, mux, tc.emptyAccountID) zh.set(tc.zoneStatuses, tc.accessCount) - mockPP := mocks.NewMockPP(mockCtrl) if tc.prepareMockPP != nil { tc.prepareMockPP(mockPP) } @@ -524,10 +684,10 @@ func TestZoneOfDomain(t *testing.T) { func TestZoneOfDomainInvalid(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - _, h := newHandle(t, false) + _, h := newHandle(t, false, mockPP) - mockPP := mocks.NewMockPP(mockCtrl) mockPP.EXPECT().Warningf( pp.EmojiError, "Failed to check the existence of a zone named %q: %v", @@ -595,8 +755,9 @@ func mockDNSListResponseFromAddr(ipNet ipnet.Type, name string, ips map[string]n func TestListRecords(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{"test.org": {"active"}}, 2) @@ -635,7 +796,6 @@ func TestListRecords(t *testing.T) { expected := map[string]netip.Addr{"record1": mustIP("::1"), "record2": mustIP("::2")} ipNet, ips, accessCount = ipnet.IP6, expected, 1 - mockPP := mocks.NewMockPP(mockCtrl) ips, cached, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6) require.True(t, ok) require.False(t, cached) @@ -654,8 +814,9 @@ func TestListRecords(t *testing.T) { func TestListRecordsInvalidIPAddress(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{"test.org": {"active"}}, 2) @@ -693,7 +854,6 @@ func TestListRecordsInvalidIPAddress(t *testing.T) { }) ipNet, accessCount = ipnet.IP6, 1 - mockPP := mocks.NewMockPP(mockCtrl) mockPP.EXPECT().Warningf( pp.EmojiImpossible, "Failed to parse the IP address in records of %q: %v", @@ -725,8 +885,9 @@ func TestListRecordsInvalidIPAddress(t *testing.T) { func TestListRecordsWildcard(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{"test.org": {"active"}}, 1) @@ -765,7 +926,6 @@ func TestListRecordsWildcard(t *testing.T) { expected := map[string]netip.Addr{"record1": mustIP("::1"), "record2": mustIP("::2")} ipNet, ips, accessCount = ipnet.IP6, expected, 1 - mockPP := mocks.NewMockPP(mockCtrl) ips, cached, ok := h.ListRecords(context.Background(), mockPP, domain.Wildcard("test.org"), ipnet.IP6) require.True(t, ok) require.False(t, cached) @@ -783,13 +943,13 @@ func TestListRecordsWildcard(t *testing.T) { func TestListRecordsInvalidDomain(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{"test.org": {"active"}}, 2) - mockPP := mocks.NewMockPP(mockCtrl) mockPP.EXPECT().Warningf(pp.EmojiError, "Failed to retrieve records of %q: %v", "sub.test.org", gomock.Any()) ips, cached, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP4) require.False(t, ok) @@ -807,10 +967,10 @@ func TestListRecordsInvalidDomain(t *testing.T) { func TestListRecordsInvalidZone(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - _, h := newHandle(t, false) + _, h := newHandle(t, false, mockPP) - mockPP := mocks.NewMockPP(mockCtrl) mockPP.EXPECT().Warningf( pp.EmojiError, "Failed to check the existence of a zone named %q: %v", @@ -862,8 +1022,9 @@ func mockDNSRecordResponse(id string, ipNet ipnet.Type, name string, ip string) func TestDeleteRecordValid(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{"test.org": {"active"}}, 2) @@ -911,7 +1072,6 @@ func TestDeleteRecordValid(t *testing.T) { }) deleteAccessCount = 1 - mockPP := mocks.NewMockPP(mockCtrl) ok := h.DeleteRecord(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6, "record1") require.True(t, ok) @@ -928,13 +1088,13 @@ func TestDeleteRecordValid(t *testing.T) { func TestDeleteRecordInvalid(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{"test.org": {"active"}}, 2) - mockPP := mocks.NewMockPP(mockCtrl) mockPP.EXPECT().Warningf(pp.EmojiError, "Failed to delete a stale %s record of %q (ID: %s): %v", "AAAA", "sub.test.org", @@ -948,10 +1108,10 @@ func TestDeleteRecordInvalid(t *testing.T) { func TestDeleteRecordZoneInvalid(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - _, h := newHandle(t, false) + _, h := newHandle(t, false, mockPP) - mockPP := mocks.NewMockPP(mockCtrl) mockPP.EXPECT().Warningf(pp.EmojiError, "Failed to check the existence of a zone named %q: %v", "sub.test.org", gomock.Any(), @@ -964,8 +1124,9 @@ func TestDeleteRecordZoneInvalid(t *testing.T) { func TestUpdateRecordValid(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{"test.org": {"active"}}, 2) @@ -1027,7 +1188,6 @@ func TestUpdateRecordValid(t *testing.T) { }) updateAccessCount = 1 - mockPP := mocks.NewMockPP(mockCtrl) ok := h.UpdateRecord(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6, "record1", mustIP("::2")) require.True(t, ok) @@ -1044,13 +1204,13 @@ func TestUpdateRecordValid(t *testing.T) { func TestUpdateRecordInvalid(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{"test.org": {"active"}}, 2) - mockPP := mocks.NewMockPP(mockCtrl) mockPP.EXPECT().Warningf(pp.EmojiError, "Failed to update a stale %s record of %q (ID: %s): %v", "AAAA", "sub.test.org", @@ -1064,10 +1224,10 @@ func TestUpdateRecordInvalid(t *testing.T) { func TestUpdateRecordInvalidZone(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - _, h := newHandle(t, false) + _, h := newHandle(t, false, mockPP) - mockPP := mocks.NewMockPP(mockCtrl) mockPP.EXPECT().Warningf(pp.EmojiError, "Failed to check the existence of a zone named %q: %v", "sub.test.org", gomock.Any(), @@ -1080,8 +1240,9 @@ func TestUpdateRecordInvalidZone(t *testing.T) { func TestCreateRecordValid(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{"test.org": {"active"}}, 2) @@ -1143,7 +1304,6 @@ func TestCreateRecordValid(t *testing.T) { }) createAccessCount = 1 - mockPP := mocks.NewMockPP(mockCtrl) actualID, ok := h.CreateRecord(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6, mustIP("::1"), 100, false, "hello") //nolint:lll require.True(t, ok) require.Equal(t, "record1", actualID) @@ -1161,13 +1321,13 @@ func TestCreateRecordValid(t *testing.T) { func TestCreateRecordInvalid(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - mux, h := newHandle(t, false) + mux, h := newHandle(t, false, mockPP) zh := newZonesHandler(t, mux, false) zh.set(map[string][]string{"test.org": {"active"}}, 2) - mockPP := mocks.NewMockPP(mockCtrl) mockPP.EXPECT().Warningf(pp.EmojiError, "Failed to add a new %s record of %q: %v", "AAAA", "sub.test.org", @@ -1181,10 +1341,10 @@ func TestCreateRecordInvalid(t *testing.T) { func TestCreateRecordInvalidZone(t *testing.T) { t.Parallel() mockCtrl := gomock.NewController(t) + mockPP := mocks.NewMockPP(mockCtrl) - _, h := newHandle(t, false) + _, h := newHandle(t, false, mockPP) - mockPP := mocks.NewMockPP(mockCtrl) mockPP.EXPECT().Warningf(pp.EmojiError, "Failed to check the existence of a zone named %q: %v", "sub.test.org", gomock.Any(), diff --git a/internal/mocks/mock_api.go b/internal/mocks/mock_api.go index ccb26f4d..2ab77f63 100644 --- a/internal/mocks/mock_api.go +++ b/internal/mocks/mock_api.go @@ -196,6 +196,44 @@ func (c *HandleListRecordsCall) DoAndReturn(f func(context.Context, pp.PP, domai return c } +// SanityCheck mocks base method. +func (m *MockHandle) SanityCheck(arg0 context.Context, arg1 pp.PP) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SanityCheck", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// SanityCheck indicates an expected call of SanityCheck. +func (mr *MockHandleMockRecorder) SanityCheck(arg0, arg1 any) *HandleSanityCheckCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SanityCheck", reflect.TypeOf((*MockHandle)(nil).SanityCheck), arg0, arg1) + return &HandleSanityCheckCall{Call: call} +} + +// HandleSanityCheckCall wrap *gomock.Call +type HandleSanityCheckCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *HandleSanityCheckCall) Return(arg0 bool) *HandleSanityCheckCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *HandleSanityCheckCall) Do(f func(context.Context, pp.PP) bool) *HandleSanityCheckCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *HandleSanityCheckCall) DoAndReturn(f func(context.Context, pp.PP) bool) *HandleSanityCheckCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // UpdateRecord mocks base method. func (m *MockHandle) UpdateRecord(arg0 context.Context, arg1 pp.PP, arg2 domain.Domain, arg3 ipnet.Type, arg4 string, arg5 netip.Addr) bool { m.ctrl.T.Helper() diff --git a/internal/mocks/mock_setter.go b/internal/mocks/mock_setter.go index 5b1995c6..6f1fc56b 100644 --- a/internal/mocks/mock_setter.go +++ b/internal/mocks/mock_setter.go @@ -82,6 +82,44 @@ func (c *SetterDeleteCall) DoAndReturn(f func(context.Context, pp.PP, domain.Dom return c } +// SanityCheck mocks base method. +func (m *MockSetter) SanityCheck(arg0 context.Context, arg1 pp.PP) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SanityCheck", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// SanityCheck indicates an expected call of SanityCheck. +func (mr *MockSetterMockRecorder) SanityCheck(arg0, arg1 any) *SetterSanityCheckCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SanityCheck", reflect.TypeOf((*MockSetter)(nil).SanityCheck), arg0, arg1) + return &SetterSanityCheckCall{Call: call} +} + +// SetterSanityCheckCall wrap *gomock.Call +type SetterSanityCheckCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *SetterSanityCheckCall) Return(arg0 bool) *SetterSanityCheckCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *SetterSanityCheckCall) Do(f func(context.Context, pp.PP) bool) *SetterSanityCheckCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *SetterSanityCheckCall) DoAndReturn(f func(context.Context, pp.PP) bool) *SetterSanityCheckCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Set mocks base method. func (m *MockSetter) Set(arg0 context.Context, arg1 pp.PP, arg2 domain.Domain, arg3 ipnet.Type, arg4 netip.Addr, arg5 api.TTL, arg6 bool, arg7 string) setter.ResponseCode { m.ctrl.T.Helper() diff --git a/internal/setter/base.go b/internal/setter/base.go index fd8c964b..e67b9d79 100644 --- a/internal/setter/base.go +++ b/internal/setter/base.go @@ -19,6 +19,12 @@ import ( // Setter uses [api.Handle] to update DNS records. type Setter interface { + // SanityCheck determines whether one should continue trying + SanityCheck( + ctx context.Context, + ppfmt pp.PP, + ) bool + // Set sets a particular domain to the given IP address. Set( ctx context.Context, diff --git a/internal/setter/setter.go b/internal/setter/setter.go index 82fa1e56..e0614c55 100644 --- a/internal/setter/setter.go +++ b/internal/setter/setter.go @@ -43,6 +43,10 @@ func New(_ppfmt pp.PP, handle api.Handle) (Setter, bool) { }, true } +func (s *setter) SanityCheck(ctx context.Context, ppfmt pp.PP) bool { + return s.Handle.SanityCheck(ctx, ppfmt) +} + // Set updates the IP address of one domain to the given ip. The ip must be non-zero. // //nolint:funlen diff --git a/internal/setter/setter_test.go b/internal/setter/setter_test.go index 4f91cbe7..278b4614 100644 --- a/internal/setter/setter_test.go +++ b/internal/setter/setter_test.go @@ -37,6 +37,35 @@ func wrapCancelAsDelete(cancel func()) func(context.Context, pp.PP, domain.Domai } } +//nolint:funlen +func TestSanityCheck(t *testing.T) { + t.Parallel() + + for name, tc := range map[string]struct { + answer bool + }{ + "true": {true}, + "false": {false}, + } { + t.Run(name, func(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockPP := mocks.NewMockPP(mockCtrl) + mockHandle := mocks.NewMockHandle(mockCtrl) + + s, ok := setter.New(mockPP, mockHandle) + require.True(t, ok) + + mockHandle.EXPECT().SanityCheck(ctx, mockPP).Return(tc.answer) + require.Equal(t, tc.answer, s.SanityCheck(ctx, mockPP)) + }) + } +} + //nolint:funlen func TestSet(t *testing.T) { t.Parallel()