diff --git a/registry/remote/auth/cache.go b/registry/remote/auth/cache.go index 91a52e1c..7b9816fc 100644 --- a/registry/remote/auth/cache.go +++ b/registry/remote/auth/cache.go @@ -157,3 +157,76 @@ func (noCache) GetToken(ctx context.Context, registry string, scheme Scheme, key func (noCache) Set(ctx context.Context, registry string, scheme Scheme, key string, fetch func(context.Context) (string, error)) (string, error) { return fetch(ctx) } + +// hostCache is an auth cache that ignores scopes. Uses only the registry's hostname to find a token. +type hostCache struct { + Cache +} + +// GetToken implements Cache. +func (c *hostCache) GetToken(ctx context.Context, registry string, scheme Scheme, key string) (string, error) { + return c.Cache.GetToken(ctx, registry, scheme, "") +} + +// Set implements Cache. +func (c *hostCache) Set(ctx context.Context, registry string, scheme Scheme, key string, fetch func(context.Context) (string, error)) (string, error) { + return c.Cache.Set(ctx, registry, scheme, "", fetch) +} + +// fallbackCache tries the primary cache then falls back to the secondary cache. +type fallbackCache struct { + primary Cache + secondary Cache +} + +// GetScheme implements Cache. +func (fc *fallbackCache) GetScheme(ctx context.Context, registry string) (Scheme, error) { + scheme, err := fc.primary.GetScheme(ctx, registry) + if err == nil { + return scheme, nil + } + + // fallback + return fc.secondary.GetScheme(ctx, registry) +} + +// GetToken implements Cache. +func (fc *fallbackCache) GetToken(ctx context.Context, registry string, scheme Scheme, key string) (string, error) { + token, err := fc.primary.GetToken(ctx, registry, scheme, key) + if err == nil { + return token, nil + } + + // fallback + return fc.secondary.GetToken(ctx, registry, scheme, key) +} + +// Set implements Cache. +func (fc *fallbackCache) Set(ctx context.Context, registry string, scheme Scheme, key string, fetch func(context.Context) (string, error)) (string, error) { + token, err := fc.primary.Set(ctx, registry, scheme, key, fetch) + if err != nil { + return token, err + } + + return fc.secondary.Set(ctx, registry, scheme, key, func(ctx context.Context) (string, error) { + return token, nil + }) +} + +// NewSingleContextCache creates a host-based cache for optimizing the auth flow for non-compliant registries. +// It is intended to be used in a single context, such as pulling from a single repository. +// This cache should not be shared. +// +// Note: [NewCache] should be used for compliant registries as it can be shared +// across context and will generally make less re-authentication requests. +func NewSingleContextCache() Cache { + cache := NewCache() + return &fallbackCache{ + primary: cache, + // We can re-use the came concurrentCache here because the key space is different + // (keys are always empty for the hostCache) so there is no collision. + // Even if there is a collision it is not an issue. + // Re-using saves a little memory. + secondary: &hostCache{cache}, + } +} diff --git a/registry/remote/auth/cache_test.go b/registry/remote/auth/cache_test.go index d4edfb6a..570feede 100644 --- a/registry/remote/auth/cache_test.go +++ b/registry/remote/auth/cache_test.go @@ -540,3 +540,136 @@ func Test_concurrentCache_Set_Fetch_Failure(t *testing.T) { } } } + +func Test_hostCache(t *testing.T) { + base := NewCache() + + // no entry in the cache + ctx := context.Background() + + hc := hostCache{base} + + fetch := func(i int) func(context.Context) (string, error) { + return func(context.Context) (string, error) { + return strconv.Itoa(i), nil + } + } + + // The key is ignored in the hostCache implementation. + + { // Set the token to 100 + gotToken, err := hc.Set(ctx, "reg.example.com", SchemeBearer, "key1", fetch(100)) + if err != nil { + t.Fatalf("hostCache.Set() error = %v", err) + } + if want := strconv.Itoa(100); gotToken != want { + t.Errorf("hostCache.Set() = %v, want %v", gotToken, want) + } + } + + { // Overwrite the token entry to 101 + gotToken, err := hc.Set(ctx, "reg.example.com", SchemeBearer, "key2", fetch(101)) + if err != nil { + t.Fatalf("hostCache.Set() error = %v", err) + } + if want := strconv.Itoa(101); gotToken != want { + t.Errorf("hostCache.Set() = %v, want %v", gotToken, want) + } + } + + { // Add entry for another host + gotToken, err := hc.Set(ctx, "reg2.example.com", SchemeBearer, "key3", fetch(102)) + if err != nil { + t.Fatalf("hostCache.Set() error = %v", err) + } + if want := strconv.Itoa(102); gotToken != want { + t.Errorf("hostCache.Set() = %v, want %v", gotToken, want) + } + } + + { // Ensure the token for key1 is 101 now + gotToken, err := hc.GetToken(ctx, "reg.example.com", SchemeBearer, "key1") + if err != nil { + t.Fatalf("hostCache.GetToken() error = %v", err) + } + if want := strconv.Itoa(101); gotToken != want { + t.Errorf("hostCache.GetToken() = %v, want %v", gotToken, want) + } + } + + { // Make sure GetScheme still works + gotScheme, err := hc.GetScheme(ctx, "reg.example.com") + if err != nil { + t.Fatalf("hostCache.GetScheme() error = %v", err) + } + if want := SchemeBearer; gotScheme != want { + t.Errorf("hostCache.GetScheme() = %v, want %v", gotScheme, want) + } + } +} + +func Test_fallbackCache(t *testing.T) { + // no entry in the cache + ctx := context.Background() + + scc := NewSingleContextCache() + + fetch := func(i int) func(context.Context) (string, error) { + return func(context.Context) (string, error) { + return strconv.Itoa(i), nil + } + } + + // Test that fallback works + + { // Set the token to 100 + gotToken, err := scc.Set(ctx, "reg.example.com", SchemeBearer, "key1", fetch(100)) + if err != nil { + t.Fatalf("hostCache.Set() error = %v", err) + } + if want := strconv.Itoa(100); gotToken != want { + t.Errorf("hostCache.Set() = %v, want %v", gotToken, want) + } + } + + { // Ensure the token for key2 falls back to 100 + gotToken, err := scc.GetToken(ctx, "reg.example.com", SchemeBearer, "key2") + if err != nil { + t.Fatalf("hostCache.GetToken() error = %v", err) + } + if want := strconv.Itoa(100); gotToken != want { + t.Errorf("hostCache.GetToken() = %v, want %v", gotToken, want) + } + } + + { // Make sure GetScheme works as expected + gotScheme, err := scc.GetScheme(ctx, "reg.example.com") + if err != nil { + t.Fatalf("hostCache.GetScheme() error = %v", err) + } + if want := SchemeBearer; gotScheme != want { + t.Errorf("hostCache.GetScheme() = %v, want %v", gotScheme, want) + } + } + + { // Make sure GetScheme falls back + gotScheme, err := scc.GetScheme(ctx, "reg.example.com") + if err != nil { + t.Fatalf("hostCache.GetScheme() error = %v", err) + } + if want := SchemeBearer; gotScheme != want { + t.Errorf("hostCache.GetScheme() = %v, want %v", gotScheme, want) + } + } + + { // Check GetScheme fallback + // scc.(*fallbackCache).primary = NewCache() + gotScheme, err := scc.GetScheme(ctx, "reg2.example.com") + if !errors.Is(err, errdef.ErrNotFound) { + t.Fatalf("hostCache.GetScheme() error = %v", err) + } + if want := SchemeUnknown; gotScheme != want { + t.Errorf("hostCache.GetScheme() = %v, want %v", gotScheme, want) + } + } +}