Skip to content

Commit

Permalink
Updates the S3 sink to use the AWS Plugin for loading AWS credentials.
Browse files Browse the repository at this point in the history
…Resolves opensearch-project#2767

Signed-off-by: David Venable <dlv@amazon.com>
  • Loading branch information
dlvenable committed May 31, 2023
1 parent 29225c2 commit fcf0969
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 281 deletions.
1 change: 1 addition & 0 deletions data-prepper-plugins/s3-sink/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
dependencies {
implementation project(':data-prepper-api')
implementation project(path: ':data-prepper-plugins:common')
implementation project(':data-prepper-plugins:aws-plugin-api')
implementation 'io.micrometer:micrometer-core'
implementation 'com.fasterxml.jackson.core:jackson-core'
implementation 'com.fasterxml.jackson.core:jackson-databind'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ public void setUp() {
when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("2mb"));
when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.parse("PT3M"));
when(s3SinkConfig.getThresholdOptions()).thenReturn(thresholdOptions);
when(s3SinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(s3region));

lenient().when(pluginMetrics.counter(S3SinkService.OBJECTS_SUCCEEDED)).thenReturn(snapshotSuccessCounter);
lenient().when(pluginMetrics.counter(S3SinkService.OBJECTS_FAILED)).thenReturn(snapshotFailedCounter);
Expand Down Expand Up @@ -136,7 +134,7 @@ void verify_flushed_records_into_s3_bucket() {
}

private S3SinkService createObjectUnderTest() {
return new S3SinkService(s3SinkConfig, bufferFactory, codec, pluginMetrics);
return new S3SinkService(s3SinkConfig, bufferFactory, codec, s3Client, pluginMetrics);
}

private int gets3ObjectCount() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink;

import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.sink.configuration.AwsAuthenticationOptions;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.services.s3.S3Client;

public final class ClientFactory {
private ClientFactory() { }

static S3Client createS3Client(final S3SinkConfig s3SinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) {
final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(s3SinkConfig.getAwsAuthenticationOptions());
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions);

return S3Client.builder()
.region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion())
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build();
}

private static ClientOverrideConfiguration createOverrideConfiguration(final S3SinkConfig s3SinkConfig) {
final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(s3SinkConfig.getMaxConnectionRetries()).build();
return ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.build();
}

private static AwsCredentialsOptions convertToCredentialsOptions(final AwsAuthenticationOptions awsAuthenticationOptions) {
return AwsCredentialsOptions.builder()
.withRegion(awsAuthenticationOptions.getAwsRegion())
.withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn())
.withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.dataprepper.plugins.sink;

import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.configuration.PluginModel;
Expand All @@ -22,6 +23,8 @@
import org.opensearch.dataprepper.plugins.sink.codec.Codec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.s3.S3Client;

import java.util.Collection;

/**
Expand All @@ -35,7 +38,7 @@ public class S3Sink extends AbstractSink<Record<Event>> {
private final S3SinkConfig s3SinkConfig;
private final Codec codec;
private volatile boolean sinkInitialized;
private S3SinkService s3SinkService;
private final S3SinkService s3SinkService;
private final BufferFactory bufferFactory;

/**
Expand All @@ -44,8 +47,10 @@ public class S3Sink extends AbstractSink<Record<Event>> {
* @param pluginFactory dp plugin factory.
*/
@DataPrepperPluginConstructor
public S3Sink(final PluginSetting pluginSetting, final S3SinkConfig s3SinkConfig,
final PluginFactory pluginFactory) {
public S3Sink(final PluginSetting pluginSetting,
final S3SinkConfig s3SinkConfig,
final PluginFactory pluginFactory,
final AwsCredentialsSupplier awsCredentialsSupplier) {
super(pluginSetting);
this.s3SinkConfig = s3SinkConfig;
final PluginModel codecConfiguration = s3SinkConfig.getCodec();
Expand All @@ -59,6 +64,8 @@ public S3Sink(final PluginSetting pluginSetting, final S3SinkConfig s3SinkConfig
} else {
bufferFactory = new InMemoryBufferFactory();
}
final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);
s3SinkService = new S3SinkService(s3SinkConfig, bufferFactory, codec, s3Client, pluginMetrics);
}

@Override
Expand All @@ -85,7 +92,6 @@ public void doInitialize() {
* Initialize {@link S3SinkService}
*/
private void doInitializeInternal() {
s3SinkService = new S3SinkService(s3SinkConfig, bufferFactory, codec, pluginMetrics);
sinkInitialized = Boolean.TRUE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.services.s3.S3Client;

import java.io.IOException;
Expand All @@ -47,6 +45,7 @@ public class S3SinkService {
private final BufferFactory bufferFactory;
private final Collection<EventHandle> bufferedEventHandles;
private final Codec codec;
private final S3Client s3Client;
private Buffer currentBuffer;
private final int maxEvents;
private final ByteCount maxBytes;
Expand All @@ -61,15 +60,17 @@ public class S3SinkService {

/**
* @param s3SinkConfig s3 sink related configuration.
* @param bufferFactory factory of buffer.
* @param bufferFactory factory of buffer.
* @param codec parser.
* @param s3Client
* @param pluginMetrics metrics.
*/
public S3SinkService(final S3SinkConfig s3SinkConfig, final BufferFactory bufferFactory,
final Codec codec, final PluginMetrics pluginMetrics) {
final Codec codec, final S3Client s3Client, final PluginMetrics pluginMetrics) {
this.s3SinkConfig = s3SinkConfig;
this.bufferFactory = bufferFactory;
this.codec = codec;
this.s3Client = s3Client;
reentrantLock = new ReentrantLock();

bufferedEventHandles = new LinkedList<>();
Expand Down Expand Up @@ -154,7 +155,7 @@ protected boolean retryFlushToS3(final Buffer currentBuffer, final String s3Key)
int retryCount = maxRetries;
do {
try {
currentBuffer.flushToS3(createS3Client(), bucket, s3Key);
currentBuffer.flushToS3(s3Client, bucket, s3Key);
isUploadedToS3 = Boolean.TRUE;
} catch (AwsServiceException | SdkClientException e) {
LOG.error("Exception occurred while uploading records to s3 bucket. Retry countdown : {} | exception:",
Expand All @@ -179,15 +180,4 @@ protected String generateKey() {
final String namePattern = ObjectKey.objectFileName(s3SinkConfig);
return (!pathPrefix.isEmpty()) ? pathPrefix + namePattern : namePattern;
}

/**
* create s3 client instance.
* @return {@link S3Client}
*/
public S3Client createS3Client() {
return S3Client.builder().region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion())
.credentialsProvider(s3SinkConfig.getAwsAuthenticationOptions().authenticateAwsConfiguration())
.overrideConfiguration(ClientOverrideConfiguration.builder().retryPolicy(RetryPolicy.builder()
.numRetries(s3SinkConfig.getMaxConnectionRetries()).build()).build()).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,11 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.constraints.Size;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;

import java.util.Map;
import java.util.Optional;
import java.util.UUID;

public class AwsAuthenticationOptions {
private static final String AWS_IAM_ROLE = "role";
private static final String AWS_IAM = "iam";

@JsonProperty("region")
@Size(min = 1, message = "Region cannot be empty string")
private String awsRegion;
Expand All @@ -35,58 +24,15 @@ public class AwsAuthenticationOptions {
@Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
private Map<String, String> awsStsHeaderOverrides;

private void validateStsRoleArn() {
final Arn arn = getArn();
if (!AWS_IAM.equals(arn.service())) {
throw new IllegalArgumentException("sts_role_arn must be an IAM Role");
}
final Optional<String> resourceType = arn.resource().resourceType();
if (resourceType.isEmpty() || !resourceType.get().equals(AWS_IAM_ROLE)) {
throw new IllegalArgumentException("sts_role_arn must be an IAM Role");
}
}

private Arn getArn() {
try {
return Arn.fromString(awsStsRoleArn);
} catch (final Exception e) {
throw new IllegalArgumentException(String.format("Invalid ARN format for awsStsRoleArn. Check the format of %s", awsStsRoleArn));
}
}

public Region getAwsRegion() {
return awsRegion != null ? Region.of(awsRegion) : null;
}

public AwsCredentialsProvider authenticateAwsConfiguration() {

final AwsCredentialsProvider awsCredentialsProvider;
if (awsStsRoleArn != null && !awsStsRoleArn.isEmpty()) {

validateStsRoleArn();

final StsClient stsClient = StsClient.builder()
.region(getAwsRegion())
.build();

AssumeRoleRequest.Builder assumeRoleRequestBuilder = AssumeRoleRequest.builder()
.roleSessionName("S3-Sink-" + UUID.randomUUID())
.roleArn(awsStsRoleArn);
if(awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) {
assumeRoleRequestBuilder = assumeRoleRequestBuilder
.overrideConfiguration(configuration -> awsStsHeaderOverrides.forEach(configuration::putHeader));
}

awsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient)
.refreshRequest(assumeRoleRequestBuilder.build())
.build();

} else {
// use default credential provider
awsCredentialsProvider = DefaultCredentialsProvider.create();
}
public String getAwsStsRoleArn() {
return awsStsRoleArn;
}

return awsCredentialsProvider;
public Map<String, String> getAwsStsHeaderOverrides() {
return awsStsHeaderOverrides;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.sink;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.sink.configuration.AwsAuthenticationOptions;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;

import java.util.Map;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class ClientFactoryTest {
@Mock
private S3SinkConfig s3SinkConfig;
@Mock
private AwsCredentialsSupplier awsCredentialsSupplier;

@Mock
private AwsAuthenticationOptions awsAuthenticationOptions;

@BeforeEach
void setUp() {
when(s3SinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
}

@Test
void createS3Client_with_real_S3Client() {
final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);

assertThat(s3Client, notNullValue());
}

@Test
void createS3Client_provides_correct_inputs() {
Region region = Region.US_WEST_2;
String stsRoleArn = UUID.randomUUID().toString();
final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString());
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region);
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn);
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides);

AwsCredentialsProvider expectedCredentialsProvider = mock(AwsCredentialsProvider.class);
when(awsCredentialsSupplier.getProvider(any())).thenReturn(expectedCredentialsProvider);

S3ClientBuilder s3ClientBuilder = mock(S3ClientBuilder.class);
when(s3ClientBuilder.region(region)).thenReturn(s3ClientBuilder);
when(s3ClientBuilder.credentialsProvider(any())).thenReturn(s3ClientBuilder);
when(s3ClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(s3ClientBuilder);
try(MockedStatic<S3Client> s3ClientMockedStatic = mockStatic(S3Client.class)) {
s3ClientMockedStatic.when(S3Client::builder)
.thenReturn(s3ClientBuilder);
ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier);
}

ArgumentCaptor<AwsCredentialsProvider> credentialsProviderArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsProvider.class);
verify(s3ClientBuilder).credentialsProvider(credentialsProviderArgumentCaptor.capture());

final AwsCredentialsProvider actualCredentialsProvider = credentialsProviderArgumentCaptor.getValue();

assertThat(actualCredentialsProvider, equalTo(expectedCredentialsProvider));

ArgumentCaptor<AwsCredentialsOptions> optionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class);
verify(awsCredentialsSupplier).getProvider(optionsArgumentCaptor.capture());

final AwsCredentialsOptions actualCredentialsOptions = optionsArgumentCaptor.getValue();
assertThat(actualCredentialsOptions.getRegion(), equalTo(region));
assertThat(actualCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn));
assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides));
}
}
Loading

0 comments on commit fcf0969

Please sign in to comment.