Skip to content

Commit

Permalink
fix(api): always use ASCII forms of domains (#61)
Browse files Browse the repository at this point in the history
* test(api): CloudflareAuth.New

* fix(api): always use ASCII forms to access Cloudflare API

* test(api): update the test of api

* docs(api): remove the bogus comment

* test(api): rename test
  • Loading branch information
favonia committed Aug 6, 2021
1 parent 37a9f03 commit befb0a9
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 67 deletions.
4 changes: 2 additions & 2 deletions README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ A small and fast DDNS updater for Cloudflare.
🔸 Effective GID: 1000
🔸 Supplementary GIDs: (empty)
🔇 Quiet mode enabled.
🐣 Added a new A record of …… (ID: ……).
🐣 Added a new AAAA record of …… (ID: ……).
🐣 Added a new A record of "……" (ID: ……).
🐣 Added a new AAAA record of "……" (ID: ……).
```

## 📜 Highlights
Expand Down
47 changes: 24 additions & 23 deletions internal/api/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const (
type CloudflareAuth struct {
Token string
AccountID string
URL string
BaseURL string
}

func (t *CloudflareAuth) New(ctx context.Context, indent pp.Indent, cacheExpiration time.Duration) (Handle, bool) {
Expand All @@ -41,8 +41,8 @@ func (t *CloudflareAuth) New(ctx context.Context, indent pp.Indent, cacheExpirat
}

// set the base URL (mostly for testing)
if t.URL != "" {
handle.BaseURL = t.URL
if t.BaseURL != "" {
handle.BaseURL = t.BaseURL
}

// this is not needed, but is helpful for diagnosing the problem
Expand Down Expand Up @@ -82,7 +82,7 @@ func (h *CloudflareHandle) ActiveZones(ctx context.Context, indent pp.Indent, na

res, err := h.cf.ListZonesContext(ctx, cloudflare.WithZoneFilters(name, h.cf.AccountID, "active"))
if err != nil {
pp.Printf(indent, pp.EmojiError, "Failed to check the existence of a zone named %s: %v", name, err)
pp.Printf(indent, pp.EmojiError, "Failed to check the existence of a zone named %q: %v", name, err)
return nil, false
}

Expand All @@ -97,7 +97,7 @@ func (h *CloudflareHandle) ActiveZones(ctx context.Context, indent pp.Indent, na
}

func (h *CloudflareHandle) ZoneOfDomain(ctx context.Context, indent pp.Indent, domain FQDN) (string, bool) {
if id, found := h.cache.zoneOfDomain.Get(domain.String()); found {
if id, found := h.cache.zoneOfDomain.Get(domain.ToASCII()); found {
return id.(string), true
}

Expand All @@ -113,7 +113,7 @@ zoneSearch:
case 0: // len(zones) == 0
continue zoneSearch
case 1: // len(zones) == 1
h.cache.zoneOfDomain.SetDefault(domain.String(), zones[0])
h.cache.zoneOfDomain.SetDefault(domain.ToASCII(), zones[0])

return zones[0], true

Expand All @@ -124,13 +124,13 @@ zoneSearch:
}
}

pp.Printf(indent, pp.EmojiError, "Failed to find the zone of %s.", domain)
pp.Printf(indent, pp.EmojiError, "Failed to find the zone of %q.", domain.Describe())
return "", false
}

func (h *CloudflareHandle) ListRecords(ctx context.Context, indent pp.Indent,
domain FQDN, ipNet ipnet.Type) (map[string]net.IP, bool) {
if rmap, found := h.cache.listRecords[ipNet].Get(domain.String()); found {
if rmap, found := h.cache.listRecords[ipNet].Get(domain.ToASCII()); found {
return rmap.(map[string]net.IP), true
}

Expand All @@ -141,11 +141,11 @@ func (h *CloudflareHandle) ListRecords(ctx context.Context, indent pp.Indent,

//nolint:exhaustivestruct // Other fields are intentionally unspecified
rs, err := h.cf.DNSRecords(ctx, zone, cloudflare.DNSRecord{
Name: domain.String(),
Name: domain.ToASCII(),
Type: ipNet.RecordType(),
})
if err != nil {
pp.Printf(indent, pp.EmojiError, "Failed to retrieve records of %s: %v", domain, err)
pp.Printf(indent, pp.EmojiError, "Failed to retrieve records of %q: %v", domain.Describe(), err)
return nil, false
}

Expand All @@ -165,15 +165,15 @@ func (h *CloudflareHandle) DeleteRecord(ctx context.Context, indent pp.Indent,
}

if err := h.cf.DeleteDNSRecord(ctx, zone, id); err != nil {
pp.Printf(indent, pp.EmojiError, "Failed to delete a stale %s record of %s (ID: %s): %v",
ipNet.RecordType(), domain, id, err)
pp.Printf(indent, pp.EmojiError, "Failed to delete a stale %s record of %q (ID: %s): %v",
ipNet.RecordType(), domain.Describe(), id, err)

h.cache.listRecords[ipNet].Delete(domain.String())
h.cache.listRecords[ipNet].Delete(domain.ToASCII())

return false
}

if rmap, found := h.cache.listRecords[ipNet].Get(domain.String()); found {
if rmap, found := h.cache.listRecords[ipNet].Get(domain.ToASCII()); found {
delete(rmap.(map[string]net.IP), id)
}

Expand All @@ -189,21 +189,21 @@ func (h *CloudflareHandle) UpdateRecord(ctx context.Context, indent pp.Indent,

//nolint:exhaustivestruct // Other fields are intentionally omitted
payload := cloudflare.DNSRecord{
Name: domain.String(),
Name: domain.ToASCII(),
Type: ipNet.RecordType(),
Content: ip.String(),
}

if err := h.cf.UpdateDNSRecord(ctx, zone, id, payload); err != nil {
pp.Printf(indent, pp.EmojiError, "Failed to update a stale %s record of %s (ID: %s): %v",
ipNet.RecordType(), domain, id, err)
pp.Printf(indent, pp.EmojiError, "Failed to update a stale %s record of %q (ID: %s): %v",
ipNet.RecordType(), domain.Describe(), id, err)

h.cache.listRecords[ipNet].Delete(domain.String())
h.cache.listRecords[ipNet].Delete(domain.ToASCII())

return false
}

if rmap, found := h.cache.listRecords[ipNet].Get(domain.String()); found {
if rmap, found := h.cache.listRecords[ipNet].Get(domain.ToASCII()); found {
rmap.(map[string]net.IP)[id] = ip
}

Expand All @@ -219,7 +219,7 @@ func (h *CloudflareHandle) CreateRecord(ctx context.Context, indent pp.Indent,

//nolint:exhaustivestruct // Other fields are intentionally omitted
payload := cloudflare.DNSRecord{
Name: domain.String(),
Name: domain.ToASCII(),
Type: ipNet.RecordType(),
Content: ip.String(),
TTL: ttl,
Expand All @@ -228,14 +228,15 @@ func (h *CloudflareHandle) CreateRecord(ctx context.Context, indent pp.Indent,

res, err := h.cf.CreateDNSRecord(ctx, zone, payload)
if err != nil {
pp.Printf(indent, pp.EmojiError, "Failed to add a new %s record of %s: %v", ipNet.RecordType(), domain, err)
pp.Printf(indent, pp.EmojiError, "Failed to add a new %s record of %q: %v",
ipNet.RecordType(), domain.Describe(), err)

h.cache.listRecords[ipNet].Delete(domain.String())
h.cache.listRecords[ipNet].Delete(domain.ToASCII())

return "", false
}

if rmap, found := h.cache.listRecords[ipNet].Get(domain.String()); found {
if rmap, found := h.cache.listRecords[ipNet].Get(domain.ToASCII()); found {
rmap.(map[string]net.IP)[res.Result.ID] = ip
}

Expand Down
118 changes: 118 additions & 0 deletions internal/api/cloudflare_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package api_test

import (
"context"
"crypto/sha512"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/favonia/cloudflare-ddns/internal/api"
)

// mockID returns a hex string of length 32, suitable for all kinds of IDs
// used in the Cloudflare API.
func mockID(seed string) string {
arr := sha512.Sum512([]byte(seed))
return hex.EncodeToString(arr[:16])
}

const (
mockToken = "token123"
mockAccount = "account456"
)

func TestCloudflareAuthNewValid(t *testing.T) {
t.Parallel()

mux := http.NewServeMux()
ts := httptest.NewServer(mux)
defer ts.Close()

mux.HandleFunc("/user/tokens/verify", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, []string{fmt.Sprintf("Bearer %s", mockToken)}, r.Header["Authorization"])

w.Header().Set("content-type", "application/json")
fmt.Fprintf(w,
`{
"result": { "id": "%s", "status": "active" },
"success": true,
"errors": [],
"messages": [
{
"code": 10000,
"message": "This API Token is valid and active",
"type": null
}
]
}`, mockID("result"))
})

auth := api.CloudflareAuth{
Token: mockToken,
AccountID: mockAccount,
BaseURL: ts.URL,
}

h, ok := auth.New(context.Background(), 3, time.Second)
require.NotNil(t, h)
require.True(t, ok)
}

func TestCloudflareAuthNewEmpty(t *testing.T) {
t.Parallel()

mux := http.NewServeMux()
ts := httptest.NewServer(mux)
defer ts.Close()

auth := api.CloudflareAuth{
Token: "",
AccountID: mockAccount,
BaseURL: ts.URL,
}

h, ok := auth.New(context.Background(), 3, time.Second)
require.Nil(t, h)
require.False(t, ok)
}

func TestCloudflareAuthNewInvalid(t *testing.T) {
t.Parallel()

mux := http.NewServeMux()
ts := httptest.NewServer(mux)
defer ts.Close()

mux.HandleFunc("/user/tokens/verify", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, []string{fmt.Sprintf("Bearer %s", mockToken)}, r.Header["Authorization"])

w.WriteHeader(http.StatusUnauthorized)
w.Header().Set("content-type", "application/json")
fmt.Fprintf(w,
`{
"success": false,
"errors": [{ "code": 1000, "message": "Invalid API Token" }],
"messages": [],
"result": null
}`)
})

auth := api.CloudflareAuth{
Token: mockToken,
AccountID: mockAccount,
BaseURL: ts.URL,
}

h, ok := auth.New(context.Background(), 3, time.Second)
require.Nil(t, h)
require.False(t, ok)
}
21 changes: 15 additions & 6 deletions internal/api/fqdn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,26 @@ type FQDN string
// safelyToUnicode takes an ASCII form and returns the Unicode form
// when the round trip gives the same ASCII form back without errors.
// Otherwise, the input ASCII form is returned.
func safelyToUnicode(ascii string) string {
func safelyToUnicode(ascii string) (string, bool) {
unicode, errToA := profile.ToUnicode(ascii)
roundTrip, errToU := profile.ToASCII(unicode)
if errToA != nil || errToU != nil || roundTrip != ascii {
return ascii
return ascii, false
}

return unicode
return unicode, true
}

func (f FQDN) String() string { return string(f) }
func (f FQDN) ToASCII() string { return string(f) }

func (f FQDN) Describe() string {
best, ok := safelyToUnicode(string(f))
if !ok {
return string(f)
}

return best
}

// NewFQDN normalizes a domain to its ASCII form and then stores
// the normalized domain in its Unicode form when the round trip
Expand All @@ -45,7 +54,7 @@ func NewFQDN(domain string) (FQDN, error) {
// Remove the final dot for consistency
normalized = strings.TrimSuffix(normalized, ".")

return FQDN(safelyToUnicode(normalized)), err
return FQDN(normalized), err
}

func SortFQDNs(s []FQDN) { sort.Slice(s, func(i, j int) bool { return s[i] < s[j] }) }
Expand All @@ -58,7 +67,7 @@ type FQDNSplitter struct {

func NewFQDNSplitter(domain FQDN) *FQDNSplitter {
return &FQDNSplitter{
domain: domain.String(),
domain: domain.ToASCII(),
cursor: 0,
exhausted: false,
}
Expand Down
Loading

0 comments on commit befb0a9

Please sign in to comment.