feat: content-length for tar checkpoint downloads (#8684)
erikwilson authored Jan 19, 2024
1 parent 11e3ba9 commit 6d744f7
Showing 11 changed files with 649 additions and 269 deletions.
7 changes: 7 additions & 0 deletions docs/release-notes/checkpoint-tar.rst
@@ -0,0 +1,7 @@
:orphan:

**API Changes**

- Checkpoints: The checkpoint download endpoint now accepts ``application/x-tar`` as a content
  type in the request. The response is an uncompressed tar file, and its ``Content-Length``
  header reports the size of the archive.
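
To see how a client exercises the new content type, here is a minimal Go sketch (not part of this commit). The master address, checkpoint UUID, output file name, and the commented-out authorization line are placeholder assumptions; only the Accept header and the Content-Length behaviour come from the change described above.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"os"
)

func main() {
	// Hypothetical master address and checkpoint UUID; replace with real values.
	url := "http://localhost:8080/checkpoints/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		panic(err)
	}
	// Request an uncompressed tar; application/gzip and application/zip remain
	// supported, and an empty or wildcard Accept now defaults to tar.
	req.Header.Set("Accept", "application/x-tar")
	// req.Header.Set("Authorization", "Bearer <token>") // auth depends on the deployment

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// For the tar content type the response carries a Content-Length header
	// computed from a dry run of the archive.
	fmt.Println("Content-Length:", resp.Header.Get("Content-Length"))

	out, err := os.Create("checkpoint.tar")
	if err != nil {
		panic(err)
	}
	defer out.Close()
	if _, err := io.Copy(out, resp.Body); err != nil {
		panic(err)
	}
}
```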
47 changes: 41 additions & 6 deletions master/internal/core_checkpoint.go
@@ -7,8 +7,10 @@ import (
"fmt"
"io"
"net/http"
"strconv"

"github.com/pkg/errors"
"github.com/sirupsen/logrus"

"github.com/google/uuid"
"github.com/labstack/echo/v4"
@@ -26,14 +28,20 @@ import (
)

const (
// MIMEApplicationXTar is Tar's MIME type.
MIMEApplicationXTar = "application/x-tar"
// MIMEApplicationGZip is GZip's MIME type.
MIMEApplicationGZip = "application/gzip"
// MIMEApplicationZip is Zip's MIME type.
MIMEApplicationZip = "application/zip"
)

var checkpointLogger = logrus.WithField("component", "core-checkpoint")

func mimeToArchiveType(mimeType string) archive.ArchiveType {
switch mimeType {
case MIMEApplicationXTar:
return archive.ArchiveTar
case MIMEApplicationGZip:
return archive.ArchiveTgz
case MIMEApplicationZip:
@@ -88,7 +96,7 @@ func (m *Master) getCheckpointStorageConfig(id uuid.UUID) (
}

func (m *Master) getCheckpointImpl(
ctx context.Context, id uuid.UUID, mimeType string, content io.Writer,
ctx context.Context, id uuid.UUID, mimeType string, content *echo.Response,
) error {
// Assume a checkpoint always has experiment configs
storageConfig, err := m.getCheckpointStorageConfig(id)
@@ -104,12 +112,34 @@ func (m *Master) getCheckpointImpl(
// DelayWriter delays the first write until we have successfully downloaded
// some bytes and are more confident that the download will succeed.
dw := newDelayWriter(content, 16*1024)
downloader, err := checkpoints.NewDownloader(
dw, id.String(), storageConfig, mimeToArchiveType(mimeType))
aw, err := archive.NewArchiveWriter(dw, mimeToArchiveType(mimeType))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}

downloader, err := checkpoints.NewDownloader(ctx, dw, id.String(), storageConfig, aw)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}

if aw.DryRunEnabled() {
files, err := downloader.ListFiles(ctx)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError,
fmt.Sprintf("unable to list checkpoint %s files: %s", id.String(), err.Error()))
}

log := checkpointLogger.WithField("checkpoint", id.String())
contentLength, err := archive.DryRunLength(aw, files)
if err != nil {
log.Warnf("failed to get dry-run content-length: %s", err.Error())
}
if contentLength > 0 {
log.Debugf("dry-run content-length: %d", contentLength)
content.Header().Set(echo.HeaderContentLength, strconv.FormatInt(contentLength, 10))
}
}

err = downloader.Download(ctx)
switch {
case err != nil && errors.Is(err, context.Canceled):
@@ -131,11 +161,11 @@ func (m *Master) getCheckpointImpl(
return nil
}

// @Summary Get a checkpoint's contents in a tgz or zip file.
// @Summary Get a checkpoint's contents in a tar, tgz, or zip file.
// @Tags Checkpoints
// @ID get-checkpoint
// @Accept json
// @Produce application/gzip,application/zip
// @Produce application/x-tar,application/gzip,application/zip
// @Param checkpoint_uuid path string true "Checkpoint UUID"
// @Success 200 {} string ""
// @Router /checkpoints/{checkpoint_uuid} [get]
@@ -144,7 +174,12 @@ func (m *Master) getCheckpointImpl(
func (m *Master) getCheckpoint(c echo.Context) error {
// Get the MIME type. Only a single type is accepted.
mimeType := c.Request().Header.Get("Accept")
if mimeType != MIMEApplicationGZip &&
// Default to tar if no MIME type is specified.
if mimeType == "" || mimeType == "*/*" || mimeType == "application/*" {
mimeType = MIMEApplicationXTar
}
if mimeType != MIMEApplicationXTar &&
mimeType != MIMEApplicationGZip &&
mimeType != MIMEApplicationZip {
return echo.NewHTTPError(http.StatusUnsupportedMediaType,
fmt.Sprintf("unsupported media type to download a checkpoint: '%s'", mimeType))
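
A note on the delayWriter used in getCheckpointImpl above: its definition is elsewhere in the package and is not part of this diff. Per the comment in the handler, it buffers the first chunk of output and only starts writing to the HTTP response once enough bytes have been downloaded that the transfer is likely to succeed, presumably so an early failure can still surface as a proper error status rather than a truncated 200 response. A minimal, hypothetical sketch of that idea (names and details are assumptions, not the repository's implementation):

```go
package sketch

import "io"

// delayWriterSketch buffers output until a threshold of bytes has accumulated,
// then flushes and writes straight through. Illustrative only; the real
// newDelayWriter in the master package may differ.
type delayWriterSketch struct {
	next      io.Writer
	buf       []byte
	threshold int
	flushed   bool
}

func newDelayWriterSketch(next io.Writer, threshold int) *delayWriterSketch {
	return &delayWriterSketch{next: next, threshold: threshold}
}

func (w *delayWriterSketch) Write(p []byte) (int, error) {
	if w.flushed {
		return w.next.Write(p)
	}
	w.buf = append(w.buf, p...)
	if len(w.buf) < w.threshold {
		// Still holding back: nothing has touched the underlying response yet,
		// so an error at this point can still become a clean HTTP error.
		return len(p), nil
	}
	if _, err := w.next.Write(w.buf); err != nil {
		return 0, err
	}
	w.buf = nil
	w.flushed = true
	return len(p), nil
}

// Close flushes whatever was buffered if the threshold was never reached,
// e.g. for checkpoints smaller than the delay window.
func (w *delayWriterSketch) Close() error {
	if w.flushed || len(w.buf) == 0 {
		return nil
	}
	_, err := w.next.Write(w.buf)
	w.buf = nil
	return err
}
```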
47 changes: 43 additions & 4 deletions master/internal/core_checkpoint_intg_test.go
@@ -13,6 +13,7 @@ import (
"net/http"
"net/http/httptest"
"os"
"strconv"
"strings"
"testing"
"time"
@@ -93,7 +94,29 @@ func createMockCheckpointS3(bucket string, prefix string) error {
return nil
}

func checkTgz(t *testing.T, content io.Reader, id string) {
func checkTar(t *testing.T, rec *httptest.ResponseRecorder, id string) {
require.Equal(t, strconv.Itoa(rec.Body.Len()), rec.Header().Get("Content-Length"))
content := rec.Body
tr := tar.NewReader(content)
gotMap := make(map[string]string)
for {
hdr, err := tr.Next()
if err == io.EOF {
break // End of archive
}
require.NoError(t, err, "failed to read record header")
buf := &strings.Builder{}
if hdr.Size > 0 {
_, err := io.Copy(buf, tr) //nolint: gosec
require.NoError(t, err, "failed to read content of file", hdr.Name)
}
gotMap[hdr.Name] = buf.String()
}
require.Equal(t, mockCheckpointContent, gotMap)
}

func checkTgz(t *testing.T, rec *httptest.ResponseRecorder, id string) {
content := rec.Body
zr, err := gzip.NewReader(content)
require.NoError(t, err, "failed to create a gzip reader")
tr := tar.NewReader(zr)
@@ -114,7 +137,8 @@ func checkTgz(t *testing.T, content io.Reader, id string) {
require.Equal(t, mockCheckpointContent, gotMap)
}

func checkZip(t *testing.T, content string, id string) {
func checkZip(t *testing.T, rec *httptest.ResponseRecorder, id string) {
content := rec.Body.String()
zr, err := zip.NewReader(strings.NewReader(content), int64(len(content)))
require.NoError(t, err, "failed to create a zip reader")
gotMap := make(map[string]string)
@@ -184,6 +208,21 @@ func testGetCheckpointEcho(t *testing.T, bucket string) {
IDToReqCall func() error
Params []any
}{
{"CanGetCheckpointTar", func() error {
api, ctx, rec := setupCheckpointTestEcho(t)
id, err := createCheckpoint(t, api.m.db, bucket)
if err != nil {
return err
}
ctx.SetParamNames("checkpoint_uuid")
ctx.SetParamValues(id)
ctx.SetRequest(httptest.NewRequest(http.MethodGet, "/", nil))
ctx.Request().Header.Set("Accept", MIMEApplicationXTar)
err = api.m.getCheckpoint(ctx)
require.NoError(t, err, "API call returns error")
checkTar(t, rec, id)
return err
}, []any{mock.Anything, mock.Anything, mock.Anything}},
{"CanGetCheckpointTgz", func() error {
api, ctx, rec := setupCheckpointTestEcho(t)
id, err := createCheckpoint(t, api.m.db, bucket)
@@ -196,7 +235,7 @@ func testGetCheckpointEcho(t *testing.T, bucket string) {
ctx.Request().Header.Set("Accept", MIMEApplicationGZip)
err = api.m.getCheckpoint(ctx)
require.NoError(t, err, "API call returns error")
checkTgz(t, rec.Body, id)
checkTgz(t, rec, id)
return err
}, []any{mock.Anything, mock.Anything, mock.Anything}},
{"CanGetCheckpointZip", func() error {
@@ -211,7 +250,7 @@ func testGetCheckpointEcho(t *testing.T, bucket string) {
ctx.Request().Header.Set("Accept", MIMEApplicationZip)
err = api.m.getCheckpoint(ctx)
require.NoError(t, err, "API call returns error")
checkZip(t, rec.Body.String(), id)
checkZip(t, rec, id)
return err
}, []any{mock.Anything, mock.Anything, mock.Anything}},
}
94 changes: 39 additions & 55 deletions master/pkg/checkpoints/archive/archive_writer.go
@@ -1,18 +1,19 @@
package archive

import (
"archive/tar"
"archive/zip"
"compress/gzip"
"fmt"
"io"
"strings"

"github.com/pkg/errors"
)

// ArchiveType currently includes tgz and zip.
type ArchiveType string

const (
// ArchiveTar is a tar ball.
ArchiveTar = "tar"
// ArchiveTgz is a gzipped tar ball.
ArchiveTgz = "tgz"
// ArchiveZip is a zip file.
@@ -21,35 +22,42 @@ const (
ArchiveUnknown = "unknown"
)

// FileEntry represents a file in an archive.
type FileEntry struct {
// Path is the path of the file in the archive.
Path string
// Size is the size of the file in bytes.
Size int64
}

// ArchiveWriter defines an interface to create an archive file.
type ArchiveWriter interface {
WriteHeader(path string, size int64) error
Write(b []byte) (int, error)
Close() error
DryRunEnabled() bool
DryRunLength(path string, size int64) (int64, error)
DryRunClose() (int64, error)
}

// NewArchiveWriter returns a new ArchiveWriter for archiveType that writes to w.
func NewArchiveWriter(w io.Writer, archiveType ArchiveType) (ArchiveWriter, error) {
closers := []io.Closer{}
switch archiveType {
case ArchiveTar:
return newTarArchiveWriter(w, closers).enableDryRun(), nil

case ArchiveTgz:
gz := gzip.NewWriter(w)
closers = append(closers, gz)

tw := tar.NewWriter(gz)
closers = append(closers, tw)

return &tarArchiveWriter{archiveClosers{closers}, tw}, nil
return newTarArchiveWriter(gz, closers), nil

case ArchiveZip:
zw := zip.NewWriter(w)
closers = append(closers, zw)

return &zipArchiveWriter{archiveClosers{closers}, zw, nil}, nil
return newZipArchiveWriter(w, closers), nil

default:
return nil, fmt.Errorf(
"archive type must be %s or %s but got %s", ArchiveTgz, ArchiveZip, archiveType)
"archive type must be %s, %s, or %s. received %s", ArchiveTar, ArchiveTgz, ArchiveZip, archiveType)
}
}

@@ -68,50 +76,26 @@ func (ac *archiveClosers) Close() error {
return nil
}

type tarArchiveWriter struct {
archiveClosers
tw *tar.Writer
}

func (aw *tarArchiveWriter) WriteHeader(path string, size int64) error {
hdr := tar.Header{
Name: path,
Mode: 0o666,
Size: size,
// DryRunLength returns the length of the archive file that would be created if the files were
// written to the archive.
func DryRunLength(
aw ArchiveWriter,
files []FileEntry,
) (int64, error) {
if !aw.DryRunEnabled() {
return 0, errors.New("dry run not enabled")
}
if strings.HasSuffix(path, "/") {
// This a directory
hdr.Mode = 0o777
contentLength := int64(0)
for _, file := range files {
size, err := aw.DryRunLength(file.Path, file.Size)
if err != nil {
return 0, err
}
contentLength += size
}
return aw.tw.WriteHeader(&hdr)
}

func (aw *tarArchiveWriter) Write(p []byte) (int, error) {
return aw.tw.Write(p)
}

type zipArchiveWriter struct {
archiveClosers
zw *zip.Writer
zwContent io.Writer
}

func (aw *zipArchiveWriter) WriteHeader(path string, size int64) error {
// Zip by default sets mode 0666 and 0777 for files and folders respectively.
zwc, err := aw.zw.Create(path)
closeSize, err := aw.DryRunClose()
if err != nil {
return err
}
aw.zwContent = zwc
return nil
}

func (aw *zipArchiveWriter) Write(p []byte) (int, error) {
// Guard against the mistake where WriteHeader() is not called before
// calling Write(). The AWS SDK likely will not make this mistake but
// zipArchiveWriter is not just limited to being used with AWS.
if aw.zwContent == nil {
return 0, nil
return 0, err
}
return aw.zwContent.Write(p)
return contentLength + closeSize, nil
}
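
The tar writer itself (newTarArchiveWriter and its DryRunLength/DryRunClose methods) is collapsed out of this view. The dry run is enabled only for the uncompressed tar type because a tar stream's length is a pure function of the entry names and sizes, whereas gzip and zip output depends on the bytes being compressed. Assuming plain USTAR framing with no PAX/GNU long-name extension records, the arithmetic looks like this sketch (illustrative only, not the commit's code):

```go
package main

import "fmt"

// tarEntrySize returns the bytes one uncompressed USTAR entry occupies: a
// 512-byte header block plus the content rounded up to a 512-byte boundary.
// PAX/GNU extension records for very long paths would add more blocks; this
// illustrates why the total is knowable in advance, it is not the commit's
// actual DryRunLength implementation.
func tarEntrySize(size int64) int64 {
	const blockSize = 512
	padded := (size + blockSize - 1) / blockSize * blockSize
	return blockSize + padded
}

// tarArchiveSize sums the entries and adds the two zero blocks that terminate
// every tar stream, presumably what DryRunClose accounts for.
func tarArchiveSize(sizes []int64) int64 {
	total := int64(2 * 512) // end-of-archive trailer
	for _, s := range sizes {
		total += tarEntrySize(s)
	}
	return total
}

func main() {
	// Files of 1 B, 512 B, and 1000 B:
	// (512+512) + (512+512) + (512+1024) + 1024 = 4608 bytes.
	fmt.Println(tarArchiveSize([]int64{1, 512, 1000}))
}
```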