From fcf096987f88a6403a5a53feacf3e91ac9f27d64 Mon Sep 17 00:00:00 2001 From: David Venable Date: Wed, 31 May 2023 13:56:11 -0500 Subject: [PATCH] Updates the S3 sink to use the AWS Plugin for loading AWS credentials. Resolves #2767 Signed-off-by: David Venable --- data-prepper-plugins/s3-sink/build.gradle | 1 + .../plugins/sink/S3SinkServiceIT.java | 4 +- .../plugins/sink/ClientFactory.java | 43 +++++ .../dataprepper/plugins/sink/S3Sink.java | 14 +- .../plugins/sink/S3SinkService.java | 22 +-- .../AwsAuthenticationOptions.java | 64 +------ .../plugins/sink/ClientFactoryTest.java | 95 ++++++++++ .../plugins/sink/S3SinkServiceTest.java | 21 +-- .../dataprepper/plugins/sink/S3SinkTest.java | 23 +-- .../AwsAuthenticationOptionsTest.java | 176 +----------------- 10 files changed, 182 insertions(+), 281 deletions(-) create mode 100644 data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/ClientFactory.java create mode 100644 data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/ClientFactoryTest.java diff --git a/data-prepper-plugins/s3-sink/build.gradle b/data-prepper-plugins/s3-sink/build.gradle index b4fbc5b841..6d8b44cdb8 100644 --- a/data-prepper-plugins/s3-sink/build.gradle +++ b/data-prepper-plugins/s3-sink/build.gradle @@ -6,6 +6,7 @@ dependencies { implementation project(':data-prepper-api') implementation project(path: ':data-prepper-plugins:common') + implementation project(':data-prepper-plugins:aws-plugin-api') implementation 'io.micrometer:micrometer-core' implementation 'com.fasterxml.jackson.core:jackson-core' implementation 'com.fasterxml.jackson.core:jackson-databind' diff --git a/data-prepper-plugins/s3-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/S3SinkServiceIT.java b/data-prepper-plugins/s3-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/S3SinkServiceIT.java index 897c1eae87..c635650546 100644 --- a/data-prepper-plugins/s3-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/S3SinkServiceIT.java +++ b/data-prepper-plugins/s3-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/S3SinkServiceIT.java @@ -99,8 +99,6 @@ public void setUp() { when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("2mb")); when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.parse("PT3M")); when(s3SinkConfig.getThresholdOptions()).thenReturn(thresholdOptions); - when(s3SinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); - when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(s3region)); lenient().when(pluginMetrics.counter(S3SinkService.OBJECTS_SUCCEEDED)).thenReturn(snapshotSuccessCounter); lenient().when(pluginMetrics.counter(S3SinkService.OBJECTS_FAILED)).thenReturn(snapshotFailedCounter); @@ -136,7 +134,7 @@ void verify_flushed_records_into_s3_bucket() { } private S3SinkService createObjectUnderTest() { - return new S3SinkService(s3SinkConfig, bufferFactory, codec, pluginMetrics); + return new S3SinkService(s3SinkConfig, bufferFactory, codec, s3Client, pluginMetrics); } private int gets3ObjectCount() { diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/ClientFactory.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/ClientFactory.java new file mode 100644 index 0000000000..7d71faa41f --- /dev/null +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/ClientFactory.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink; + +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.sink.configuration.AwsAuthenticationOptions; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.services.s3.S3Client; + +public final class ClientFactory { + private ClientFactory() { } + + static S3Client createS3Client(final S3SinkConfig s3SinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) { + final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(s3SinkConfig.getAwsAuthenticationOptions()); + final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions); + + return S3Client.builder() + .region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion()) + .credentialsProvider(awsCredentialsProvider) + .overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build(); + } + + private static ClientOverrideConfiguration createOverrideConfiguration(final S3SinkConfig s3SinkConfig) { + final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(s3SinkConfig.getMaxConnectionRetries()).build(); + return ClientOverrideConfiguration.builder() + .retryPolicy(retryPolicy) + .build(); + } + + private static AwsCredentialsOptions convertToCredentialsOptions(final AwsAuthenticationOptions awsAuthenticationOptions) { + return AwsCredentialsOptions.builder() + .withRegion(awsAuthenticationOptions.getAwsRegion()) + .withStsRoleArn(awsAuthenticationOptions.getAwsStsRoleArn()) + .withStsHeaderOverrides(awsAuthenticationOptions.getAwsStsHeaderOverrides()) + .build(); + } +} diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/S3Sink.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/S3Sink.java index 2e312631f4..1dc6963c23 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/S3Sink.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/S3Sink.java @@ -5,6 +5,7 @@ package org.opensearch.dataprepper.plugins.sink; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; import org.opensearch.dataprepper.model.configuration.PluginModel; @@ -22,6 +23,8 @@ import org.opensearch.dataprepper.plugins.sink.codec.Codec; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.s3.S3Client; + import java.util.Collection; /** @@ -35,7 +38,7 @@ public class S3Sink extends AbstractSink> { private final S3SinkConfig s3SinkConfig; private final Codec codec; private volatile boolean sinkInitialized; - private S3SinkService s3SinkService; + private final S3SinkService s3SinkService; private final BufferFactory bufferFactory; /** @@ -44,8 +47,10 @@ public class S3Sink extends AbstractSink> { * @param pluginFactory dp plugin factory. */ @DataPrepperPluginConstructor - public S3Sink(final PluginSetting pluginSetting, final S3SinkConfig s3SinkConfig, - final PluginFactory pluginFactory) { + public S3Sink(final PluginSetting pluginSetting, + final S3SinkConfig s3SinkConfig, + final PluginFactory pluginFactory, + final AwsCredentialsSupplier awsCredentialsSupplier) { super(pluginSetting); this.s3SinkConfig = s3SinkConfig; final PluginModel codecConfiguration = s3SinkConfig.getCodec(); @@ -59,6 +64,8 @@ public S3Sink(final PluginSetting pluginSetting, final S3SinkConfig s3SinkConfig } else { bufferFactory = new InMemoryBufferFactory(); } + final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier); + s3SinkService = new S3SinkService(s3SinkConfig, bufferFactory, codec, s3Client, pluginMetrics); } @Override @@ -85,7 +92,6 @@ public void doInitialize() { * Initialize {@link S3SinkService} */ private void doInitializeInternal() { - s3SinkService = new S3SinkService(s3SinkConfig, bufferFactory, codec, pluginMetrics); sinkInitialized = Boolean.TRUE; } diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/S3SinkService.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/S3SinkService.java index 9c8e402060..6796ed933b 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/S3SinkService.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/S3SinkService.java @@ -19,9 +19,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.awscore.exception.AwsServiceException; -import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.core.retry.RetryPolicy; import software.amazon.awssdk.services.s3.S3Client; import java.io.IOException; @@ -47,6 +45,7 @@ public class S3SinkService { private final BufferFactory bufferFactory; private final Collection bufferedEventHandles; private final Codec codec; + private final S3Client s3Client; private Buffer currentBuffer; private final int maxEvents; private final ByteCount maxBytes; @@ -61,15 +60,17 @@ public class S3SinkService { /** * @param s3SinkConfig s3 sink related configuration. - * @param bufferFactory factory of buffer. + * @param bufferFactory factory of buffer. * @param codec parser. + * @param s3Client * @param pluginMetrics metrics. */ public S3SinkService(final S3SinkConfig s3SinkConfig, final BufferFactory bufferFactory, - final Codec codec, final PluginMetrics pluginMetrics) { + final Codec codec, final S3Client s3Client, final PluginMetrics pluginMetrics) { this.s3SinkConfig = s3SinkConfig; this.bufferFactory = bufferFactory; this.codec = codec; + this.s3Client = s3Client; reentrantLock = new ReentrantLock(); bufferedEventHandles = new LinkedList<>(); @@ -154,7 +155,7 @@ protected boolean retryFlushToS3(final Buffer currentBuffer, final String s3Key) int retryCount = maxRetries; do { try { - currentBuffer.flushToS3(createS3Client(), bucket, s3Key); + currentBuffer.flushToS3(s3Client, bucket, s3Key); isUploadedToS3 = Boolean.TRUE; } catch (AwsServiceException | SdkClientException e) { LOG.error("Exception occurred while uploading records to s3 bucket. Retry countdown : {} | exception:", @@ -179,15 +180,4 @@ protected String generateKey() { final String namePattern = ObjectKey.objectFileName(s3SinkConfig); return (!pathPrefix.isEmpty()) ? pathPrefix + namePattern : namePattern; } - - /** - * create s3 client instance. - * @return {@link S3Client} - */ - public S3Client createS3Client() { - return S3Client.builder().region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion()) - .credentialsProvider(s3SinkConfig.getAwsAuthenticationOptions().authenticateAwsConfiguration()) - .overrideConfiguration(ClientOverrideConfiguration.builder().retryPolicy(RetryPolicy.builder() - .numRetries(s3SinkConfig.getMaxConnectionRetries()).build()).build()).build(); - } } \ No newline at end of file diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/configuration/AwsAuthenticationOptions.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/configuration/AwsAuthenticationOptions.java index 913f978ad1..783f98fa55 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/configuration/AwsAuthenticationOptions.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/configuration/AwsAuthenticationOptions.java @@ -7,22 +7,11 @@ import com.fasterxml.jackson.annotation.JsonProperty; import jakarta.validation.constraints.Size; -import software.amazon.awssdk.arns.Arn; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.sts.StsClient; -import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; -import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; import java.util.Map; -import java.util.Optional; -import java.util.UUID; public class AwsAuthenticationOptions { - private static final String AWS_IAM_ROLE = "role"; - private static final String AWS_IAM = "iam"; - @JsonProperty("region") @Size(min = 1, message = "Region cannot be empty string") private String awsRegion; @@ -35,58 +24,15 @@ public class AwsAuthenticationOptions { @Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override") private Map awsStsHeaderOverrides; - private void validateStsRoleArn() { - final Arn arn = getArn(); - if (!AWS_IAM.equals(arn.service())) { - throw new IllegalArgumentException("sts_role_arn must be an IAM Role"); - } - final Optional resourceType = arn.resource().resourceType(); - if (resourceType.isEmpty() || !resourceType.get().equals(AWS_IAM_ROLE)) { - throw new IllegalArgumentException("sts_role_arn must be an IAM Role"); - } - } - - private Arn getArn() { - try { - return Arn.fromString(awsStsRoleArn); - } catch (final Exception e) { - throw new IllegalArgumentException(String.format("Invalid ARN format for awsStsRoleArn. Check the format of %s", awsStsRoleArn)); - } - } - public Region getAwsRegion() { return awsRegion != null ? Region.of(awsRegion) : null; } - public AwsCredentialsProvider authenticateAwsConfiguration() { - - final AwsCredentialsProvider awsCredentialsProvider; - if (awsStsRoleArn != null && !awsStsRoleArn.isEmpty()) { - - validateStsRoleArn(); - - final StsClient stsClient = StsClient.builder() - .region(getAwsRegion()) - .build(); - - AssumeRoleRequest.Builder assumeRoleRequestBuilder = AssumeRoleRequest.builder() - .roleSessionName("S3-Sink-" + UUID.randomUUID()) - .roleArn(awsStsRoleArn); - if(awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) { - assumeRoleRequestBuilder = assumeRoleRequestBuilder - .overrideConfiguration(configuration -> awsStsHeaderOverrides.forEach(configuration::putHeader)); - } - - awsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder() - .stsClient(stsClient) - .refreshRequest(assumeRoleRequestBuilder.build()) - .build(); - - } else { - // use default credential provider - awsCredentialsProvider = DefaultCredentialsProvider.create(); - } + public String getAwsStsRoleArn() { + return awsStsRoleArn; + } - return awsCredentialsProvider; + public Map getAwsStsHeaderOverrides() { + return awsStsHeaderOverrides; } } \ No newline at end of file diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/ClientFactoryTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/ClientFactoryTest.java new file mode 100644 index 0000000000..38aaaddab6 --- /dev/null +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/ClientFactoryTest.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.sink.configuration.AwsAuthenticationOptions; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3ClientBuilder; + +import java.util.Map; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class ClientFactoryTest { + @Mock + private S3SinkConfig s3SinkConfig; + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private AwsAuthenticationOptions awsAuthenticationOptions; + + @BeforeEach + void setUp() { + when(s3SinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); + } + + @Test + void createS3Client_with_real_S3Client() { + final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier); + + assertThat(s3Client, notNullValue()); + } + + @Test + void createS3Client_provides_correct_inputs() { + Region region = Region.US_WEST_2; + String stsRoleArn = UUID.randomUUID().toString(); + final Map stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + when(awsAuthenticationOptions.getAwsRegion()).thenReturn(region); + when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(stsRoleArn); + when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides); + + AwsCredentialsProvider expectedCredentialsProvider = mock(AwsCredentialsProvider.class); + when(awsCredentialsSupplier.getProvider(any())).thenReturn(expectedCredentialsProvider); + + S3ClientBuilder s3ClientBuilder = mock(S3ClientBuilder.class); + when(s3ClientBuilder.region(region)).thenReturn(s3ClientBuilder); + when(s3ClientBuilder.credentialsProvider(any())).thenReturn(s3ClientBuilder); + when(s3ClientBuilder.overrideConfiguration(any(ClientOverrideConfiguration.class))).thenReturn(s3ClientBuilder); + try(MockedStatic s3ClientMockedStatic = mockStatic(S3Client.class)) { + s3ClientMockedStatic.when(S3Client::builder) + .thenReturn(s3ClientBuilder); + ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier); + } + + ArgumentCaptor credentialsProviderArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsProvider.class); + verify(s3ClientBuilder).credentialsProvider(credentialsProviderArgumentCaptor.capture()); + + final AwsCredentialsProvider actualCredentialsProvider = credentialsProviderArgumentCaptor.getValue(); + + assertThat(actualCredentialsProvider, equalTo(expectedCredentialsProvider)); + + ArgumentCaptor optionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); + verify(awsCredentialsSupplier).getProvider(optionsArgumentCaptor.capture()); + + final AwsCredentialsOptions actualCredentialsOptions = optionsArgumentCaptor.getValue(); + assertThat(actualCredentialsOptions.getRegion(), equalTo(region)); + assertThat(actualCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn)); + assertThat(actualCredentialsOptions.getStsHeaderOverrides(), equalTo(stsHeaderOverrides)); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/S3SinkServiceTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/S3SinkServiceTest.java index ee7c4275a1..009f65c0b1 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/S3SinkServiceTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/S3SinkServiceTest.java @@ -55,6 +55,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.lenient; @@ -75,6 +76,7 @@ class S3SinkServiceTest { public static final String CODEC_PLUGIN_NAME = "json"; public static final String PATH_PREFIX = "logdata/"; private S3SinkConfig s3SinkConfig; + private S3Client s3Client; private JsonCodec codec; private PluginMetrics pluginMetrics; private BufferFactory bufferFactory; @@ -83,10 +85,11 @@ class S3SinkServiceTest { private Random random; @BeforeEach - void setUp() throws Exception { + void setUp() { random = new Random(); s3SinkConfig = mock(S3SinkConfig.class); + s3Client = mock(S3Client.class); ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); ObjectKeyOptions objectKeyOptions = mock(ObjectKeyOptions.class); AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class); @@ -129,7 +132,7 @@ void setUp() throws Exception { } private S3SinkService createObjectUnderTest() { - return new S3SinkService(s3SinkConfig, bufferFactory, codec, pluginMetrics); + return new S3SinkService(s3SinkConfig, bufferFactory, codec, s3Client, pluginMetrics); } @Test @@ -139,14 +142,6 @@ void test_s3SinkService_notNull() { assertThat(s3SinkService, instanceOf(S3SinkService.class)); } - @Test - void test_s3Client_notNull() { - S3SinkService s3SinkService = createObjectUnderTest(); - S3Client s3Client = s3SinkService.createS3Client(); - assertNotNull(s3Client); - assertThat(s3Client, instanceOf(S3Client.class)); - } - @Test void test_generateKey_with_general_prefix() { String pathPrefix = "events/"; @@ -310,13 +305,15 @@ void test_retryFlushToS3_positive() throws InterruptedException, IOException { @Test void test_retryFlushToS3_negative() throws InterruptedException, IOException { + bufferFactory = mock(BufferFactory.class); + InMemoryBuffer buffer = mock(InMemoryBuffer.class); + when(bufferFactory.getBuffer()).thenReturn(buffer); when(s3SinkConfig.getBucketName()).thenReturn(""); S3SinkService s3SinkService = createObjectUnderTest(); assertNotNull(s3SinkService); - Buffer buffer = bufferFactory.getBuffer(); - assertNotNull(buffer); buffer.writeEvent(generateByteArray()); final String s3Key = UUID.randomUUID().toString(); + doThrow(AwsServiceException.class).when(buffer).flushToS3(eq(s3Client), anyString(), anyString()); boolean isUploadedToS3 = s3SinkService.retryFlushToS3(buffer, s3Key); assertFalse(isUploadedToS3); } diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/S3SinkTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/S3SinkTest.java index b7fc33d196..25941e718f 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/S3SinkTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/S3SinkTest.java @@ -8,6 +8,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.model.configuration.PluginModel; import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.event.Event; @@ -27,7 +28,6 @@ import java.util.Collection; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -48,6 +48,7 @@ class S3SinkTest { private S3Sink s3Sink; private PluginSetting pluginSetting; private PluginFactory pluginFactory; + private AwsCredentialsSupplier awsCredentialsSupplier; @BeforeEach void setUp() { @@ -60,6 +61,7 @@ void setUp() { pluginSetting = mock(PluginSetting.class); PluginModel pluginModel = mock(PluginModel.class); pluginFactory = mock(PluginFactory.class); + awsCredentialsSupplier = mock(AwsCredentialsSupplier.class); when(s3SinkConfig.getBufferType()).thenReturn(BufferTypeOptions.INMEMORY); when(s3SinkConfig.getThresholdOptions()).thenReturn(thresholdOptions); @@ -77,9 +79,13 @@ void setUp() { when(s3SinkConfig.getBucketName()).thenReturn(BUCKET_NAME); } + private S3Sink createObjectUnderTest() { + return new S3Sink(pluginSetting, s3SinkConfig, pluginFactory, awsCredentialsSupplier); + } + @Test void test_s3_sink_plugin_isReady_positive() { - s3Sink = new S3Sink(pluginSetting, s3SinkConfig, pluginFactory); + s3Sink = createObjectUnderTest(); Assertions.assertNotNull(s3Sink); s3Sink.doInitialize(); assertTrue(s3Sink.isReady(), "s3 sink is not initialized and not ready to work"); @@ -87,24 +93,15 @@ void test_s3_sink_plugin_isReady_positive() { @Test void test_s3_Sink_plugin_isReady_negative() { - s3Sink = new S3Sink(pluginSetting, s3SinkConfig, pluginFactory); + s3Sink = createObjectUnderTest(); Assertions.assertNotNull(s3Sink); assertFalse(s3Sink.isReady(), "s3 sink is initialized and ready to work"); } - @Test - void test_doInitialize_with_exception() { - when(s3SinkConfig.getBufferType()).thenReturn(BufferTypeOptions.INMEMORY); - s3Sink = new S3Sink(pluginSetting, s3SinkConfig, pluginFactory); - Assertions.assertNotNull(s3Sink); - when(s3SinkConfig.getThresholdOptions()).thenReturn(null); - assertThrows(NullPointerException.class, s3Sink::doInitialize); - } - @Test void test_doOutput_with_empty_records() { when(s3SinkConfig.getBucketName()).thenReturn(BUCKET_NAME); - s3Sink = new S3Sink(pluginSetting, s3SinkConfig, pluginFactory); + s3Sink = createObjectUnderTest(); Assertions.assertNotNull(s3Sink); s3Sink.doInitialize(); Collection> records = new ArrayList<>(); diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/configuration/AwsAuthenticationOptionsTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/configuration/AwsAuthenticationOptionsTest.java index cc52739dbc..1d1f5af40f 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/configuration/AwsAuthenticationOptionsTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/configuration/AwsAuthenticationOptionsTest.java @@ -6,41 +6,23 @@ package org.opensearch.dataprepper.plugins.sink.configuration; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; import org.opensearch.dataprepper.test.helper.ReflectivelySetField; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.sts.StsClient; -import software.amazon.awssdk.services.sts.StsClientBuilder; -import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; -import java.util.Collections; -import java.util.Map; + import java.util.UUID; -import java.util.function.Consumer; + import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.nullValue; -import static org.hamcrest.CoreMatchers.sameInstance; import static org.hamcrest.MatcherAssert.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; class AwsAuthenticationOptionsTest { private AwsAuthenticationOptions awsAuthenticationOptions; - private final String TEST_ROLE = "arn:aws:iam::123456789012:role/test-role"; - @BeforeEach void setUp() { awsAuthenticationOptions = new AwsAuthenticationOptions(); @@ -64,158 +46,4 @@ void getAwsRegion_returns_null_when_region_is_null() throws NoSuchFieldException ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsRegion", null); assertThat(awsAuthenticationOptions.getAwsRegion(), nullValue()); } - - @Test - void authenticateAWSConfiguration_should_return_s3Client_without_sts_role_arn() throws NoSuchFieldException, IllegalAccessException { - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsRegion", "us-east-1"); - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsStsRoleArn", null); - - final DefaultCredentialsProvider mockedCredentialsProvider = mock(DefaultCredentialsProvider.class); - final AwsCredentialsProvider actualCredentialsProvider; - try (final MockedStatic defaultCredentialsProviderMockedStatic = mockStatic(DefaultCredentialsProvider.class)) { - defaultCredentialsProviderMockedStatic.when(DefaultCredentialsProvider::create) - .thenReturn(mockedCredentialsProvider); - actualCredentialsProvider = awsAuthenticationOptions.authenticateAwsConfiguration(); - } - - assertThat(actualCredentialsProvider, sameInstance(mockedCredentialsProvider)); - } - - @Nested - class WithSts { - private StsClient stsClient; - private StsClientBuilder stsClientBuilder; - - @BeforeEach - void setUp() { - stsClient = mock(StsClient.class); - stsClientBuilder = mock(StsClientBuilder.class); - - when(stsClientBuilder.build()).thenReturn(stsClient); - } - - @Test - void authenticateAWSConfiguration_should_return_s3Client_with_sts_role_arn() throws NoSuchFieldException, IllegalAccessException { - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsRegion", "us-east-1"); - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsStsRoleArn", TEST_ROLE); - - when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder); - final AssumeRoleRequest.Builder assumeRoleRequestBuilder = mock(AssumeRoleRequest.Builder.class); - when(assumeRoleRequestBuilder.roleSessionName(anyString())) - .thenReturn(assumeRoleRequestBuilder); - when(assumeRoleRequestBuilder.roleArn(anyString())) - .thenReturn(assumeRoleRequestBuilder); - - final AwsCredentialsProvider actualCredentialsProvider; - try (final MockedStatic stsClientMockedStatic = mockStatic(StsClient.class); - final MockedStatic assumeRoleRequestMockedStatic = mockStatic(AssumeRoleRequest.class)) { - stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); - assumeRoleRequestMockedStatic.when(AssumeRoleRequest::builder).thenReturn(assumeRoleRequestBuilder); - actualCredentialsProvider = awsAuthenticationOptions.authenticateAwsConfiguration(); - } - - assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); - - verify(assumeRoleRequestBuilder).roleArn(TEST_ROLE); - verify(assumeRoleRequestBuilder).roleSessionName(anyString()); - verify(assumeRoleRequestBuilder).build(); - verifyNoMoreInteractions(assumeRoleRequestBuilder); - } - - @Test - void authenticateAWSConfiguration_should_return_s3Client_with_sts_role_arn_when_no_region() throws NoSuchFieldException, IllegalAccessException { - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsRegion", null); - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsStsRoleArn", TEST_ROLE); - assertThat(awsAuthenticationOptions.getAwsRegion(), equalTo(null)); - - when(stsClientBuilder.region(null)).thenReturn(stsClientBuilder); - - final AwsCredentialsProvider actualCredentialsProvider; - try (final MockedStatic stsClientMockedStatic = mockStatic(StsClient.class)) { - stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); - actualCredentialsProvider = awsAuthenticationOptions.authenticateAwsConfiguration(); - } - - assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); - } - - @Test - void authenticateAWSConfiguration_should_override_STS_Headers_when_HeaderOverrides_when_set() throws NoSuchFieldException, IllegalAccessException { - final String headerName1 = UUID.randomUUID().toString(); - final String headerValue1 = UUID.randomUUID().toString(); - final String headerName2 = UUID.randomUUID().toString(); - final String headerValue2 = UUID.randomUUID().toString(); - final Map overrideHeaders = Map.of(headerName1, headerValue1, headerName2, headerValue2); - - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsRegion", "us-east-1"); - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsStsRoleArn", TEST_ROLE); - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsStsHeaderOverrides", overrideHeaders); - - when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder); - - final AssumeRoleRequest.Builder assumeRoleRequestBuilder = mock(AssumeRoleRequest.Builder.class); - when(assumeRoleRequestBuilder.roleSessionName(anyString())) - .thenReturn(assumeRoleRequestBuilder); - when(assumeRoleRequestBuilder.roleArn(anyString())) - .thenReturn(assumeRoleRequestBuilder); - when(assumeRoleRequestBuilder.overrideConfiguration(any(Consumer.class))) - .thenReturn(assumeRoleRequestBuilder); - - final AwsCredentialsProvider actualCredentialsProvider; - try (final MockedStatic stsClientMockedStatic = mockStatic(StsClient.class); - final MockedStatic assumeRoleRequestMockedStatic = mockStatic(AssumeRoleRequest.class)) { - stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); - assumeRoleRequestMockedStatic.when(AssumeRoleRequest::builder).thenReturn(assumeRoleRequestBuilder); - actualCredentialsProvider = awsAuthenticationOptions.authenticateAwsConfiguration(); - } - - assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); - - final ArgumentCaptor> configurationCaptor = ArgumentCaptor.forClass(Consumer.class); - - verify(assumeRoleRequestBuilder).roleArn(TEST_ROLE); - verify(assumeRoleRequestBuilder).roleSessionName(anyString()); - verify(assumeRoleRequestBuilder).overrideConfiguration(configurationCaptor.capture()); - verify(assumeRoleRequestBuilder).build(); - verifyNoMoreInteractions(assumeRoleRequestBuilder); - - final Consumer actualOverride = configurationCaptor.getValue(); - - final AwsRequestOverrideConfiguration.Builder configurationBuilder = mock(AwsRequestOverrideConfiguration.Builder.class); - actualOverride.accept(configurationBuilder); - verify(configurationBuilder).putHeader(headerName1, headerValue1); - verify(configurationBuilder).putHeader(headerName2, headerValue2); - verifyNoMoreInteractions(configurationBuilder); - } - - @Test - void authenticateAWSConfiguration_should_not_override_STS_Headers_when_HeaderOverrides_are_empty() throws NoSuchFieldException, IllegalAccessException { - - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsRegion", "us-east-1"); - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsStsRoleArn", TEST_ROLE); - ReflectivelySetField.setField(AwsAuthenticationOptions.class, awsAuthenticationOptions, "awsStsHeaderOverrides", Collections.emptyMap()); - - when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder); - final AssumeRoleRequest.Builder assumeRoleRequestBuilder = mock(AssumeRoleRequest.Builder.class); - when(assumeRoleRequestBuilder.roleSessionName(anyString())) - .thenReturn(assumeRoleRequestBuilder); - when(assumeRoleRequestBuilder.roleArn(anyString())) - .thenReturn(assumeRoleRequestBuilder); - - final AwsCredentialsProvider actualCredentialsProvider; - try (final MockedStatic stsClientMockedStatic = mockStatic(StsClient.class); - final MockedStatic assumeRoleRequestMockedStatic = mockStatic(AssumeRoleRequest.class)) { - stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); - assumeRoleRequestMockedStatic.when(AssumeRoleRequest::builder).thenReturn(assumeRoleRequestBuilder); - actualCredentialsProvider = awsAuthenticationOptions.authenticateAwsConfiguration(); - } - - assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); - - verify(assumeRoleRequestBuilder).roleArn(TEST_ROLE); - verify(assumeRoleRequestBuilder).roleSessionName(anyString()); - verify(assumeRoleRequestBuilder).build(); - verifyNoMoreInteractions(assumeRoleRequestBuilder); - } - } } \ No newline at end of file