diff --git a/pkg/cloudprovider/aws/volume_snapshotter.go b/pkg/cloudprovider/aws/volume_snapshotter.go index 0be495e52e..9014ab3f67 100644 --- a/pkg/cloudprovider/aws/volume_snapshotter.go +++ b/pkg/cloudprovider/aws/volume_snapshotter.go @@ -21,6 +21,7 @@ import ( "os" "regexp" "strings" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" @@ -32,11 +33,16 @@ import ( "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" "github.com/heptio/velero/pkg/cloudprovider" ) -const regionKey = "region" +const ( + regionKey = "region" + snapshotCreationTimeoutKey = "snapshotCreationTimeout" + snapshotCreationTimeoutDefault = 20 * time.Minute +) // iopsVolumeTypes is a set of AWS EBS volume types for which IOPS should // be captured during snapshot and provided when creating a new volume @@ -44,8 +50,9 @@ const regionKey = "region" var iopsVolumeTypes = sets.NewString("io1") type VolumeSnapshotter struct { - log logrus.FieldLogger - ec2 *ec2.EC2 + log logrus.FieldLogger + ec2 *ec2.EC2 + snapshotCreationTimeout time.Duration } // takes AWS credential config & a profile to create a new session @@ -68,12 +75,27 @@ func NewVolumeSnapshotter(logger logrus.FieldLogger) *VolumeSnapshotter { } func (b *VolumeSnapshotter) Init(config map[string]string) error { - if err := cloudprovider.ValidateVolumeSnapshotterConfigKeys(config, regionKey, credentialProfileKey); err != nil { + if err := cloudprovider.ValidateVolumeSnapshotterConfigKeys( + config, + regionKey, + credentialProfileKey, + snapshotCreationTimeoutKey, + ); err != nil { return err } region := config[regionKey] credentialProfile := config[credentialProfileKey] + // if config["snapshotCreationTimeout"] is empty, default to 20m; otherwise, parse it + var err error + if val := config[snapshotCreationTimeoutKey]; val == "" { + b.snapshotCreationTimeout = snapshotCreationTimeoutDefault + } else { + b.snapshotCreationTimeout, err = time.ParseDuration(val) + if err != nil { + return errors.Wrapf(err, "unable to parse value %q for config key %q (expected a duration string)", val, snapshotCreationTimeoutKey) + } + } if region == "" { return errors.Errorf("missing %s in aws configuration", regionKey) } @@ -91,18 +113,12 @@ func (b *VolumeSnapshotter) Init(config map[string]string) error { } func (b *VolumeSnapshotter) CreateVolumeFromSnapshot(snapshotID, volumeType, volumeAZ string, iops *int64) (volumeID string, err error) { - // describe the snapshot so we can apply its tags to the volume - snapReq := &ec2.DescribeSnapshotsInput{ - SnapshotIds: []*string{&snapshotID}, - } - - snapRes, err := b.ec2.DescribeSnapshots(snapReq) + snapshot, err := b.snapshotWhenAvailable(snapshotID) if err != nil { return "", errors.WithStack(err) } - - if count := len(snapRes.Snapshots); count != 1 { - return "", errors.Errorf("expected 1 snapshot from DescribeSnapshots for %s, got %v", snapshotID, count) + if snapshot == nil { + return "", errors.Errorf("Snapshot %s is not available", snapshotID) } // filter tags through getTagsForCluster() function in order to apply @@ -111,11 +127,11 @@ func (b *VolumeSnapshotter) CreateVolumeFromSnapshot(snapshotID, volumeType, vol SnapshotId: &snapshotID, AvailabilityZone: &volumeAZ, VolumeType: &volumeType, - Encrypted: snapRes.Snapshots[0].Encrypted, + Encrypted: snapshot.Encrypted, TagSpecifications: []*ec2.TagSpecification{ { ResourceType: aws.String(ec2.ResourceTypeVolume), - Tags: getTagsForCluster(snapRes.Snapshots[0].Tags), + Tags: getTagsForCluster(snapshot.Tags), }, }, } @@ -132,6 +148,63 @@ func (b *VolumeSnapshotter) CreateVolumeFromSnapshot(snapshotID, volumeType, vol return *res.VolumeId, nil } +func (b *VolumeSnapshotter) snapshotWhenAvailable(snapshotID string) (*ec2.Snapshot, error) { + logger := b.log.WithField("snapshotID", snapshotID) + + var snapshot *ec2.Snapshot + err := wait.PollImmediate(time.Second, b.snapshotCreationTimeout, func() (bool, error) { + var err error + snapshot, err = b.getSnapshot(snapshotID) + if err != nil { + return true, err + } + if snapshot.State == nil { + snapshot = nil + logger.Debug("snapshot has nil state") + return true, errors.Errorf("Snapshot has nil state") + } + if *snapshot.State == ec2.SnapshotStatePending { + snapshot = nil + logger.Debug("snapshot not yet ready for use") + return false, nil + } + if *snapshot.State == ec2.SnapshotStateCompleted { + return true, nil + } + if *snapshot.State == ec2.SnapshotStateError { + snapshot = nil + logger.Debug("snapshot is in 'error' state") + return true, errors.Errorf("Snapshot is in 'error' state") + } + unknownState := *snapshot.State + snapshot = nil + return true, errors.Errorf("Snapshot is in unknown state '%s'", unknownState) + }) + + if err == wait.ErrWaitTimeout { + logger.Debug("timeout reached waiting for snapshot to be ready") + } + + return snapshot, err +} + +func (b *VolumeSnapshotter) getSnapshot(snapshotID string) (*ec2.Snapshot, error) { + // describe the snapshot so we can apply its tags to the volume + snapReq := &ec2.DescribeSnapshotsInput{ + SnapshotIds: []*string{&snapshotID}, + } + + snapRes, err := b.ec2.DescribeSnapshots(snapReq) + if err != nil { + return nil, errors.WithStack(err) + } + + if count := len(snapRes.Snapshots); count != 1 { + return nil, errors.Errorf("expected 1 snapshot from DescribeSnapshots for %s, got %v", snapshotID, count) + } + return snapRes.Snapshots[0], nil +} + func (b *VolumeSnapshotter) GetVolumeInfo(volumeID, volumeAZ string) (string, *int64, error) { volumeInfo, err := b.describeVolume(volumeID) if err != nil {