From cd72f45c3358139d717392924260657ddf835055 Mon Sep 17 00:00:00 2001 From: Albert Zaharovits Date: Wed, 23 Dec 2020 23:46:59 +0200 Subject: [PATCH] Client-side encrypted snapshot repository (feature flag) (#66773) The client-side encrypted repository is a new type of snapshot repository that internally delegates to the regular variants of snapshot repositories (of types Azure, S3, GCS, FS, and maybe others but not yet tested). After the encrypted repository is set up, it is transparent to the snapshot and restore APIs (i.e. all snapshots stored in the encrypted repository are encrypted, no other parameters required). The encrypted repository is protected by a password stored on every node's keystore (which must be the same across the nodes). The password is used to generate a key encrytion key (KEK), using the PBKDF2 function, which is used to encrypt (using the AES Wrap algorithm) other symmetric keys (referred to as DEK - data encryption keys), which themselves are generated randomly, and which are ultimately used to encrypt the snapshot blobs. For example, here is how to set up an encrypted FS repository: ------ 1) make sure that the cluster runs under at least a "platinum" license (simplest test configuration is to put `xpack.license.self_generated.type: "trial"` in the elasticsearch.yml file) 2) identical to the un-encrypted FS repository, specify the mount point of the shared FS in the elasticsearch.yml conf file (on all the cluster nodes), e.g. `path.repo: ["/tmp/repo"]` 3) store the repository password inside the elasticsearch.keystore, *on every cluster node*. In order to support changing password on existing repository (implemented in a follow-up), the password itself must be names, e.g. for the "test_enc_key" repository password name: `./bin/elasticsearch-keystore add repository.encrypted.test_enc_pass.password` *type in the password* 4) start up the cluster and create the new encrypted FS repository, named "test_enc", by calling: ` curl -X PUT "localhost:9200/_snapshot/test_enc?pretty" -H 'Content-Type: application/json' -d' { "type": "encrypted", "settings": { "location": "/tmp/repo/enc", "delegate_type": "fs", "password_name": "test_enc_pass" } } ' ` 5) the snapshot and restore APIs work unmodified when they refer to this new repository, e.g. ` curl -X PUT "localhost:9200/_snapshot/test_enc/snapshot_1?wait_for_completion=true"` Related: #49896 #41910 #50846 #48221 #65768 --- plugins/repository-azure/build.gradle | 15 + .../azure/AzureBlobStoreRepositoryTests.java | 15 +- plugins/repository-gcs/build.gradle | 17 + ...eCloudStorageBlobStoreRepositoryTests.java | 18 +- .../hdfs/HdfsBlobStoreRepositoryTests.java | 8 +- plugins/repository-s3/build.gradle | 17 + .../s3/S3BlobStoreRepositoryTests.java | 18 +- .../fs/FsBlobStoreRepositoryIntegTests.java | 40 + .../common/blobstore/BlobPath.java | 14 + .../ESBlobStoreRepositoryIntegTestCase.java | 57 +- .../ESFsBasedRepositoryIntegTestCase.java | 67 +- ...ESMockAPIBasedRepositoryIntegTestCase.java | 39 +- .../license/XPackLicenseState.java | 2 + .../license/XPackLicenseStateTests.java | 19 + .../plugin/repository-encrypted/build.gradle | 30 + ...tedAzureBlobStoreRepositoryIntegTests.java | 101 ++ ...ryptedFSBlobStoreRepositoryIntegTests.java | 155 +++ ...yptedGCSBlobStoreRepositoryIntegTests.java | 102 ++ .../EncryptedRepositorySecretIntegTests.java | 806 ++++++++++++ ...ryptedS3BlobStoreRepositoryIntegTests.java | 106 ++ .../repositories/encrypted/AESKeyUtils.java | 76 ++ .../encrypted/BufferOnMarkInputStream.java | 547 ++++++++ .../encrypted/ChainingInputStream.java | 426 ++++++ .../encrypted/CountingInputStream.java | 115 ++ .../DecryptionPacketsInputStream.java | 176 +++ .../encrypted/EncryptedRepository.java | 697 ++++++++++ .../encrypted/EncryptedRepositoryPlugin.java | 199 +++ .../EncryptionPacketsInputStream.java | 198 +++ .../encrypted/PrefixInputStream.java | 150 +++ .../repositories/encrypted/SingleUseKey.java | 103 ++ .../plugin-metadata/plugin-security.policy | 8 + .../encrypted/AESKeyUtilsTests.java | 54 + .../BufferOnMarkInputStreamTests.java | 853 ++++++++++++ .../encrypted/ChainingInputStreamTests.java | 1170 +++++++++++++++++ .../encrypted/CountingInputStreamTests.java | 162 +++ .../DecryptionPacketsInputStreamTests.java | 198 +++ .../encrypted/EncryptedRepositoryTests.java | 175 +++ .../EncryptionPacketsInputStreamTests.java | 536 ++++++++ .../LocalStateEncryptedRepositoryPlugin.java | 153 +++ .../encrypted/PrefixInputStreamTests.java | 222 ++++ .../encrypted/SingleUseKeyTests.java | 156 +++ .../test/repository_encrypted/10_basic.yml | 16 + 42 files changed, 7956 insertions(+), 80 deletions(-) create mode 100644 server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIntegTests.java rename server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIT.java => test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESFsBasedRepositoryIntegTestCase.java (69%) create mode 100644 x-pack/plugin/repository-encrypted/build.gradle create mode 100644 x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedAzureBlobStoreRepositoryIntegTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedFSBlobStoreRepositoryIntegTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedGCSBlobStoreRepositoryIntegTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedRepositorySecretIntegTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedS3BlobStoreRepositoryIntegTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/AESKeyUtils.java create mode 100644 x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStream.java create mode 100644 x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/ChainingInputStream.java create mode 100644 x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/CountingInputStream.java create mode 100644 x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStream.java create mode 100644 x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptedRepository.java create mode 100644 x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryPlugin.java create mode 100644 x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStream.java create mode 100644 x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/PrefixInputStream.java create mode 100644 x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/SingleUseKey.java create mode 100644 x-pack/plugin/repository-encrypted/src/main/plugin-metadata/plugin-security.policy create mode 100644 x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/AESKeyUtilsTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStreamTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/ChainingInputStreamTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/CountingInputStreamTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStreamTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStreamTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/LocalStateEncryptedRepositoryPlugin.java create mode 100644 x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/PrefixInputStreamTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/SingleUseKeyTests.java create mode 100644 x-pack/plugin/repository-encrypted/src/test/resources/rest-api-spec/test/repository_encrypted/10_basic.yml diff --git a/plugins/repository-azure/build.gradle b/plugins/repository-azure/build.gradle index 20752e045e489..e06b71feed6e6 100644 --- a/plugins/repository-azure/build.gradle +++ b/plugins/repository-azure/build.gradle @@ -374,3 +374,18 @@ task azureThirdPartyTest(type: Test) { } } tasks.named("check").configure { dependsOn("azureThirdPartyTest") } + +// test jar is exported by the integTestArtifacts configuration to be used in the encrypted Azure repository test +configurations { + internalClusterTestArtifacts.extendsFrom internalClusterTestImplementation + internalClusterTestArtifacts.extendsFrom internalClusterTestRuntime +} + +def internalClusterTestJar = tasks.register("internalClusterTestJar", Jar) { + appendix 'internalClusterTest' + from sourceSets.internalClusterTest.output +} + +artifacts { + internalClusterTestArtifacts internalClusterTestJar +} diff --git a/plugins/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java b/plugins/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java index 69d0220f6f57e..b00936d3efa81 100644 --- a/plugins/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java +++ b/plugins/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java @@ -59,12 +59,15 @@ protected String repositoryType() { } @Override - protected Settings repositorySettings() { - return Settings.builder() - .put(super.repositorySettings()) - .put(AzureRepository.Repository.CONTAINER_SETTING.getKey(), "container") - .put(AzureStorageSettings.ACCOUNT_SETTING.getKey(), "test") - .build(); + protected Settings repositorySettings(String repoName) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.repositorySettings(repoName)) + .put(AzureRepository.Repository.CONTAINER_SETTING.getKey(), "container") + .put(AzureStorageSettings.ACCOUNT_SETTING.getKey(), "test"); + if (randomBoolean()) { + settingsBuilder.put(AzureRepository.Repository.BASE_PATH_SETTING.getKey(), randomFrom("test", "test/1")); + } + return settingsBuilder.build(); } @Override diff --git a/plugins/repository-gcs/build.gradle b/plugins/repository-gcs/build.gradle index 3436f5af7aa32..8cdf7c701a1be 100644 --- a/plugins/repository-gcs/build.gradle +++ b/plugins/repository-gcs/build.gradle @@ -334,3 +334,20 @@ def gcsThirdPartyTest = tasks.register("gcsThirdPartyTest", Test) { tasks.named("check").configure { dependsOn(largeBlobYamlRestTest, gcsThirdPartyTest) } + +// test jar is exported by the integTestArtifacts configuration to be used in the encrypted GCS repository test +configurations { + internalClusterTestArtifacts.extendsFrom internalClusterTestImplementation + internalClusterTestArtifacts.extendsFrom internalClusterTestRuntime +} + +def internalClusterTestJar = tasks.register("internalClusterTestJar", Jar) { + appendix 'internalClusterTest' + from sourceSets.internalClusterTest.output + // for the repositories.gcs.TestUtils class + from sourceSets.test.output +} + +artifacts { + internalClusterTestArtifacts internalClusterTestJar +} diff --git a/plugins/repository-gcs/src/internalClusterTest/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStoreRepositoryTests.java b/plugins/repository-gcs/src/internalClusterTest/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStoreRepositoryTests.java index 47918619c18e1..69b8fb669d441 100644 --- a/plugins/repository-gcs/src/internalClusterTest/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStoreRepositoryTests.java +++ b/plugins/repository-gcs/src/internalClusterTest/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStoreRepositoryTests.java @@ -67,6 +67,7 @@ import static org.elasticsearch.repositories.gcs.GoogleCloudStorageClientSettings.CREDENTIALS_FILE_SETTING; import static org.elasticsearch.repositories.gcs.GoogleCloudStorageClientSettings.ENDPOINT_SETTING; import static org.elasticsearch.repositories.gcs.GoogleCloudStorageClientSettings.TOKEN_URI_SETTING; +import static org.elasticsearch.repositories.gcs.GoogleCloudStorageRepository.BASE_PATH; import static org.elasticsearch.repositories.gcs.GoogleCloudStorageRepository.BUCKET; import static org.elasticsearch.repositories.gcs.GoogleCloudStorageRepository.CLIENT_NAME; @@ -79,12 +80,15 @@ protected String repositoryType() { } @Override - protected Settings repositorySettings() { - return Settings.builder() - .put(super.repositorySettings()) - .put(BUCKET.getKey(), "bucket") - .put(CLIENT_NAME.getKey(), "test") - .build(); + protected Settings repositorySettings(String repoName) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.repositorySettings(repoName)) + .put(BUCKET.getKey(), "bucket") + .put(CLIENT_NAME.getKey(), "test"); + if (randomBoolean()) { + settingsBuilder.put(BASE_PATH.getKey(), randomFrom("test", "test/1")); + } + return settingsBuilder.build(); } @Override @@ -120,7 +124,7 @@ protected Settings nodeSettings(int nodeOrdinal) { } public void testDeleteSingleItem() { - final String repoName = createRepository(randomName()); + final String repoName = createRepository(randomRepositoryName()); final RepositoriesService repositoriesService = internalCluster().getMasterNodeInstance(RepositoriesService.class); final BlobStoreRepository repository = (BlobStoreRepository) repositoriesService.repository(repoName); PlainActionFuture.get(f -> repository.threadPool().generic().execute(ActionRunnable.run(f, () -> diff --git a/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java b/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java index a427cee824209..e87c2ac5c1b29 100644 --- a/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java +++ b/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java @@ -38,7 +38,7 @@ protected String repositoryType() { } @Override - protected Settings repositorySettings() { + protected Settings repositorySettings(String repoName) { return Settings.builder() .put("uri", "hdfs:///") .put("conf.fs.AbstractFileSystem.hdfs.impl", TestingFs.class.getName()) @@ -47,6 +47,12 @@ protected Settings repositorySettings() { .put("compress", randomBoolean()).build(); } + @Override + public void testSnapshotAndRestore() throws Exception { + // the HDFS mockup doesn't preserve the repository contents after removing the repository + testSnapshotAndRestore(false); + } + @Override protected Collection> nodePlugins() { return Collections.singletonList(HdfsPlugin.class); diff --git a/plugins/repository-s3/build.gradle b/plugins/repository-s3/build.gradle index e583832787ea1..36dcd59a6d885 100644 --- a/plugins/repository-s3/build.gradle +++ b/plugins/repository-s3/build.gradle @@ -312,3 +312,20 @@ tasks.named("thirdPartyAudit").configure { 'javax.activation.DataHandler' ) } + +// test jar is exported by the integTestArtifacts configuration to be used in the encrypted S3 repository test +configurations { + internalClusterTestArtifacts.extendsFrom internalClusterTestImplementation + internalClusterTestArtifacts.extendsFrom internalClusterTestRuntime +} + +def internalClusterTestJar = tasks.register("internalClusterTestJar", Jar) { + appendix 'internalClusterTest' + from sourceSets.internalClusterTest.output + // for the plugin-security.policy resource + from sourceSets.test.output +} + +artifacts { + internalClusterTestArtifacts internalClusterTestJar +} diff --git a/plugins/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java b/plugins/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java index 6dcef8ddf554b..4a8f308ce1bba 100644 --- a/plugins/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java +++ b/plugins/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java @@ -94,14 +94,17 @@ protected String repositoryType() { } @Override - protected Settings repositorySettings() { - return Settings.builder() - .put(super.repositorySettings()) + protected Settings repositorySettings(String repoName) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.repositorySettings(repoName)) .put(S3Repository.BUCKET_SETTING.getKey(), "bucket") .put(S3Repository.CLIENT_NAME.getKey(), "test") // Don't cache repository data because some tests manually modify the repository data - .put(BlobStoreRepository.CACHE_REPOSITORY_DATA.getKey(), false) - .build(); + .put(BlobStoreRepository.CACHE_REPOSITORY_DATA.getKey(), false); + if (randomBoolean()) { + settingsBuilder.put(S3Repository.BASE_PATH_SETTING.getKey(), randomFrom("test", "test/1")); + } + return settingsBuilder.build(); } @Override @@ -145,8 +148,9 @@ protected Settings nodeSettings(int nodeOrdinal) { } public void testEnforcedCooldownPeriod() throws IOException { - final String repoName = createRepository(randomName(), Settings.builder().put(repositorySettings()) - .put(S3Repository.COOLDOWN_PERIOD.getKey(), TEST_COOLDOWN_PERIOD).build()); + final String repoName = randomRepositoryName(); + createRepository(repoName, Settings.builder().put(repositorySettings(repoName)) + .put(S3Repository.COOLDOWN_PERIOD.getKey(), TEST_COOLDOWN_PERIOD).build(), true); final SnapshotId fakeOldSnapshot = client().admin().cluster().prepareCreateSnapshot(repoName, "snapshot-old") .setWaitForCompletion(true).setIndices().get().getSnapshotInfo().snapshotId(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIntegTests.java b/server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIntegTests.java new file mode 100644 index 0000000000000..268f47130601d --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIntegTests.java @@ -0,0 +1,40 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.repositories.fs; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.repositories.blobstore.ESFsBasedRepositoryIntegTestCase; + +public class FsBlobStoreRepositoryIntegTests extends ESFsBasedRepositoryIntegTestCase { + + @Override + protected Settings repositorySettings(String repositoryName) { + final Settings.Builder settings = Settings.builder() + .put("compress", randomBoolean()) + .put("location", randomRepoPath()); + if (randomBoolean()) { + long size = 1 << randomInt(10); + settings.put("chunk_size", new ByteSizeValue(size, ByteSizeUnit.KB)); + } + return settings.build(); + } +} diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/BlobPath.java b/server/src/main/java/org/elasticsearch/common/blobstore/BlobPath.java index 6c6df937584d8..9c635fbe9201c 100644 --- a/server/src/main/java/org/elasticsearch/common/blobstore/BlobPath.java +++ b/server/src/main/java/org/elasticsearch/common/blobstore/BlobPath.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Objects; /** * The list of paths where a blob can reside. The contents of the paths are dependent upon the implementation of {@link BlobContainer}. @@ -90,4 +91,17 @@ public String toString() { } return sb.toString(); } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BlobPath other = (BlobPath) o; + return paths.equals(other.paths); + } + + @Override + public int hashCode() { + return Objects.hash(paths); + } } diff --git a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java index 1c529822504b5..e3f58ff38cbf7 100644 --- a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java @@ -41,6 +41,7 @@ import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.repositories.Repository; import org.elasticsearch.repositories.RepositoryData; +import org.elasticsearch.repositories.RepositoryMissingException; import org.elasticsearch.snapshots.SnapshotMissingException; import org.elasticsearch.snapshots.SnapshotRestoreException; import org.elasticsearch.test.ESIntegTestCase; @@ -78,17 +79,19 @@ public static RepositoryData getRepositoryData(Repository repository) { protected abstract String repositoryType(); - protected Settings repositorySettings() { + protected Settings repositorySettings(String repoName) { return Settings.builder().put("compress", randomBoolean()).build(); } protected final String createRepository(final String name) { - return createRepository(name, repositorySettings()); + return createRepository(name, true); } - protected final String createRepository(final String name, final Settings settings) { - final boolean verify = randomBoolean(); + protected final String createRepository(final String name, final boolean verify) { + return createRepository(name, repositorySettings(name), verify); + } + protected final String createRepository(final String name, final Settings settings, final boolean verify) { logger.info("--> creating repository [name: {}, verify: {}, settings: {}]", name, verify, settings); assertAcked(client().admin().cluster().preparePutRepository(name) .setType(repositoryType()) @@ -98,7 +101,7 @@ protected final String createRepository(final String name, final Settings settin internalCluster().getDataOrMasterNodeInstances(RepositoriesService.class).forEach(repositories -> { assertThat(repositories.repository(name), notNullValue()); assertThat(repositories.repository(name), instanceOf(BlobStoreRepository.class)); - assertThat(repositories.repository(name).isReadOnly(), is(false)); + assertThat(repositories.repository(name).isReadOnly(), is(settings.getAsBoolean("readonly", false))); BlobStore blobStore = ((BlobStoreRepository) repositories.repository(name)).getBlobStore(); assertThat("blob store has to be lazy initialized", blobStore, verify ? is(notNullValue()) : is(nullValue())); }); @@ -106,6 +109,15 @@ protected final String createRepository(final String name, final Settings settin return name; } + protected final void deleteRepository(final String name) { + logger.debug("--> deleting repository [name: {}]", name); + assertAcked(client().admin().cluster().prepareDeleteRepository(name)); + internalCluster().getDataOrMasterNodeInstances(RepositoriesService.class).forEach(repositories -> { + RepositoryMissingException e = expectThrows(RepositoryMissingException.class, () -> repositories.repository(name)); + assertThat(e.repository(), equalTo(name)); + }); + } + public void testReadNonExistingPath() throws IOException { try (BlobStore store = newBlobStore()) { final BlobContainer container = store.blobContainer(new BlobPath()); @@ -176,7 +188,7 @@ public void testList() throws IOException { BlobMetadata blobMetadata = blobs.get(generated.getKey()); assertThat(generated.getKey(), blobMetadata, CoreMatchers.notNullValue()); assertThat(blobMetadata.name(), CoreMatchers.equalTo(generated.getKey())); - assertThat(blobMetadata.length(), CoreMatchers.equalTo(generated.getValue())); + assertThat(blobMetadata.length(), CoreMatchers.equalTo(blobLengthFromContentLength(generated.getValue()))); } assertThat(container.listBlobsByPrefix("foo-").size(), CoreMatchers.equalTo(numberOfFooBlobs)); @@ -259,7 +271,11 @@ protected static void writeBlob(BlobContainer container, String blobName, BytesA } protected BlobStore newBlobStore() { - final String repository = createRepository(randomName()); + final String repository = createRepository(randomRepositoryName()); + return newBlobStore(repository); + } + + protected BlobStore newBlobStore(String repository) { final BlobStoreRepository blobStoreRepository = (BlobStoreRepository) internalCluster().getMasterNodeInstance(RepositoriesService.class).repository(repository); return PlainActionFuture.get( @@ -267,7 +283,13 @@ protected BlobStore newBlobStore() { } public void testSnapshotAndRestore() throws Exception { - final String repoName = createRepository(randomName()); + testSnapshotAndRestore(randomBoolean()); + } + + protected void testSnapshotAndRestore(boolean recreateRepositoryBeforeRestore) throws Exception { + final String repoName = randomRepositoryName(); + final Settings repoSettings = repositorySettings(repoName); + createRepository(repoName, repoSettings, randomBoolean()); int indexCount = randomIntBetween(1, 5); int[] docCounts = new int[indexCount]; String[] indexNames = generateRandomNames(indexCount); @@ -315,6 +337,11 @@ public void testSnapshotAndRestore() throws Exception { assertAcked(client().admin().indices().prepareClose(closeIndices.toArray(new String[closeIndices.size()]))); } + if (recreateRepositoryBeforeRestore) { + deleteRepository(repoName); + createRepository(repoName, repoSettings, randomBoolean()); + } + logger.info("--> restore all indices from the snapshot"); assertSuccessfulRestore(client().admin().cluster().prepareRestoreSnapshot(repoName, snapshotName).setWaitForCompletion(true)); @@ -339,7 +366,7 @@ public void testSnapshotAndRestore() throws Exception { } public void testMultipleSnapshotAndRollback() throws Exception { - final String repoName = createRepository(randomName()); + final String repoName = createRepository(randomRepositoryName()); int iterationCount = randomIntBetween(2, 5); int[] docCounts = new int[iterationCount]; String indexName = randomName(); @@ -394,7 +421,7 @@ public void testMultipleSnapshotAndRollback() throws Exception { } public void testIndicesDeletedFromRepository() throws Exception { - final String repoName = createRepository("test-repo"); + final String repoName = createRepository(randomRepositoryName()); Client client = client(); createIndex("test-idx-1", "test-idx-2", "test-idx-3"); ensureGreen(); @@ -491,7 +518,15 @@ private static void assertSuccessfulRestore(RestoreSnapshotResponse response) { assertThat(response.getRestoreInfo().successfulShards(), equalTo(response.getRestoreInfo().totalShards())); } - protected static String randomName() { + protected String randomName() { return randomAlphaOfLength(randomIntBetween(1, 10)).toLowerCase(Locale.ROOT); } + + protected String randomRepositoryName() { + return randomName(); + } + + protected long blobLengthFromContentLength(long contentLength) { + return contentLength; + } } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIT.java b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESFsBasedRepositoryIntegTestCase.java similarity index 69% rename from server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIT.java rename to test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESFsBasedRepositoryIntegTestCase.java index 9ad02412f3771..3a8501d65e95a 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIT.java +++ b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESFsBasedRepositoryIntegTestCase.java @@ -7,7 +7,7 @@ * not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an @@ -16,18 +16,16 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.repositories.fs; +package org.elasticsearch.repositories.blobstore; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.blobstore.BlobContainer; import org.elasticsearch.common.blobstore.BlobPath; -import org.elasticsearch.common.blobstore.fs.FsBlobStore; +import org.elasticsearch.common.blobstore.BlobStore; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.unit.ByteSizeUnit; -import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.internal.io.IOUtils; -import org.elasticsearch.repositories.blobstore.ESBlobStoreRepositoryIntegTestCase; +import org.elasticsearch.repositories.fs.FsRepository; import java.io.IOException; import java.nio.file.Files; @@ -39,35 +37,22 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; import static org.hamcrest.Matchers.instanceOf; -public class FsBlobStoreRepositoryIT extends ESBlobStoreRepositoryIntegTestCase { +public abstract class ESFsBasedRepositoryIntegTestCase extends ESBlobStoreRepositoryIntegTestCase { @Override protected String repositoryType() { return FsRepository.TYPE; } - @Override - protected Settings repositorySettings() { - final Settings.Builder settings = Settings.builder(); - settings.put(super.repositorySettings()); - settings.put("location", randomRepoPath()); - if (randomBoolean()) { - long size = 1 << randomInt(10); - settings.put("chunk_size", new ByteSizeValue(size, ByteSizeUnit.KB)); - } - return settings.build(); - } - public void testMissingDirectoriesNotCreatedInReadonlyRepository() throws IOException, InterruptedException { - final String repoName = randomName(); + final String repoName = randomRepositoryName(); final Path repoPath = randomRepoPath(); - logger.info("--> creating repository {} at {}", repoName, repoPath); - - assertAcked(client().admin().cluster().preparePutRepository(repoName).setType("fs").setSettings(Settings.builder() - .put("location", repoPath) - .put("compress", randomBoolean()) - .put("chunk_size", randomIntBetween(100, 1000), ByteSizeUnit.BYTES))); + final Settings repoSettings = Settings.builder() + .put(repositorySettings(repoName)) + .put("location", repoPath) + .build(); + createRepository(repoName, repoSettings, randomBoolean()); final String indexName = randomName(); int docCount = iterations(10, 1000); @@ -91,8 +76,7 @@ public void testMissingDirectoriesNotCreatedInReadonlyRepository() throws IOExce } assertFalse(Files.exists(deletedPath)); - assertAcked(client().admin().cluster().preparePutRepository(repoName).setType("fs").setSettings(Settings.builder() - .put("location", repoPath).put("readonly", true))); + createRepository(repoName, Settings.builder().put(repoSettings).put("readonly", true).build(), randomBoolean()); final ElasticsearchException exception = expectThrows(ElasticsearchException.class, () -> client().admin().cluster().prepareRestoreSnapshot(repoName, snapshotName).setWaitForCompletion(randomBoolean()).get()); @@ -102,25 +86,34 @@ public void testMissingDirectoriesNotCreatedInReadonlyRepository() throws IOExce } public void testReadOnly() throws Exception { - Path tempDir = createTempDir(); - Path path = tempDir.resolve("bar"); - - try (FsBlobStore store = new FsBlobStore(randomIntBetween(1, 8) * 1024, path, true)) { - assertFalse(Files.exists(path)); + final String repoName = randomRepositoryName(); + final Path repoPath = randomRepoPath(); + final Settings repoSettings = Settings.builder() + .put(repositorySettings(repoName)) + .put("readonly", true) + .put(FsRepository.LOCATION_SETTING.getKey(), repoPath) + .put(BlobStoreRepository.BUFFER_SIZE_SETTING.getKey(), String.valueOf(randomIntBetween(1, 8) * 1024) + "kb") + .build(); + createRepository(repoName, repoSettings, false); + + try (BlobStore store = newBlobStore(repoName)) { + assertFalse(Files.exists(repoPath)); BlobPath blobPath = BlobPath.cleanPath().add("foo"); store.blobContainer(blobPath); - Path storePath = store.path(); + Path storePath = repoPath; for (String d : blobPath) { storePath = storePath.resolve(d); } assertFalse(Files.exists(storePath)); } - try (FsBlobStore store = new FsBlobStore(randomIntBetween(1, 8) * 1024, path, false)) { - assertTrue(Files.exists(path)); + createRepository(repoName, Settings.builder().put(repoSettings).put("readonly", false).build(), false); + + try (BlobStore store = newBlobStore(repoName)) { + assertTrue(Files.exists(repoPath)); BlobPath blobPath = BlobPath.cleanPath().add("foo"); BlobContainer container = store.blobContainer(blobPath); - Path storePath = store.path(); + Path storePath = repoPath; for (String d : blobPath) { storePath = storePath.resolve(d); } diff --git a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESMockAPIBasedRepositoryIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESMockAPIBasedRepositoryIntegTestCase.java index e174d94d3716a..305faecc1bcf4 100644 --- a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESMockAPIBasedRepositoryIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESMockAPIBasedRepositoryIntegTestCase.java @@ -34,12 +34,15 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.network.InetAddresses; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.mocksocket.MockHttpServer; import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.repositories.Repository; import org.elasticsearch.repositories.RepositoryMissingException; import org.elasticsearch.repositories.RepositoryStats; import org.elasticsearch.test.BackgroundIndexer; +import org.elasticsearch.threadpool.ThreadPool; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; @@ -55,6 +58,9 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -83,6 +89,7 @@ protected interface BlobStoreHttpHandler extends HttpHandler { private static final byte[] BUFFER = new byte[1024]; private static HttpServer httpServer; + private static ExecutorService executorService; protected Map handlers; private static final Logger log = LogManager.getLogger(); @@ -90,13 +97,19 @@ protected interface BlobStoreHttpHandler extends HttpHandler { @BeforeClass public static void startHttpServer() throws Exception { httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0); + ThreadFactory threadFactory = EsExecutors.daemonThreadFactory("[" + ESMockAPIBasedRepositoryIntegTestCase.class.getName() + "]"); + // the EncryptedRepository can require more than one connection open at one time + executorService = EsExecutors.newScaling(ESMockAPIBasedRepositoryIntegTestCase.class.getName(), 0, 2, 60, + TimeUnit.SECONDS, threadFactory, new ThreadContext(Settings.EMPTY)); httpServer.setExecutor(r -> { - try { - r.run(); - } catch (Throwable t) { - log.error("Error in execution on mock http server IO thread", t); - throw t; - } + executorService.execute(() -> { + try { + r.run(); + } catch (Throwable t) { + log.error("Error in execution on mock http server IO thread", t); + throw t; + } + }); }); httpServer.start(); } @@ -111,6 +124,7 @@ public void setUpHttpServer() { @AfterClass public static void stopHttpServer() { httpServer.stop(0); + ThreadPool.terminate(executorService, 10, TimeUnit.SECONDS); httpServer = null; } @@ -124,14 +138,17 @@ public void tearDownHttpServer() { h = ((DelegatingHttpHandler) h).getDelegate(); } if (h instanceof BlobStoreHttpHandler) { - List blobs = ((BlobStoreHttpHandler) h).blobs().keySet().stream() - .filter(blob -> blob.contains("index") == false).collect(Collectors.toList()); - assertThat("Only index blobs should remain in repository but found " + blobs, blobs, hasSize(0)); + assertEmptyRepo(((BlobStoreHttpHandler) h).blobs()); } } } } + protected void assertEmptyRepo(Map blobsMap) { + List blobs = blobsMap.keySet().stream().filter(blob -> blob.contains("index") == false).collect(Collectors.toList()); + assertThat("Only index blobs should remain in repository but found " + blobs, blobs, hasSize(0)); + } + protected abstract Map createHttpHandlers(); protected abstract HttpHandler createErroneousHttpHandler(HttpHandler delegate); @@ -140,7 +157,7 @@ public void tearDownHttpServer() { * Test the snapshot and restore of an index which has large segments files. */ public final void testSnapshotWithLargeSegmentFiles() throws Exception { - final String repository = createRepository(randomName()); + final String repository = createRepository(randomRepositoryName()); final String index = "index-no-merges"; createIndex(index, Settings.builder() .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) @@ -171,7 +188,7 @@ public final void testSnapshotWithLargeSegmentFiles() throws Exception { } public void testRequestStats() throws Exception { - final String repository = createRepository(randomName()); + final String repository = createRepository(randomRepositoryName()); final String index = "index-no-merges"; createIndex(index, Settings.builder() .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java index c9fd7304d9712..2ce2086287f4d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java @@ -62,6 +62,8 @@ public enum Feature { MONITORING_CLUSTER_ALERTS(OperationMode.STANDARD, true), MONITORING_UPDATE_RETENTION(OperationMode.STANDARD, false), + ENCRYPTED_SNAPSHOT(OperationMode.PLATINUM, true), + CCR(OperationMode.PLATINUM, true), GRAPH(OperationMode.PLATINUM, true), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/license/XPackLicenseStateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/license/XPackLicenseStateTests.java index a325971704ea3..bd36e8d1a4e42 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/license/XPackLicenseStateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/license/XPackLicenseStateTests.java @@ -20,6 +20,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.license.License.OperationMode.BASIC; +import static org.elasticsearch.license.License.OperationMode.ENTERPRISE; import static org.elasticsearch.license.License.OperationMode.GOLD; import static org.elasticsearch.license.License.OperationMode.MISSING; import static org.elasticsearch.license.License.OperationMode.PLATINUM; @@ -299,6 +300,24 @@ public void testWatcherInactivePlatinumGoldTrial() throws Exception { assertAllowed(STANDARD, false, s -> s.checkFeature(Feature.WATCHER), false); } + public void testEncryptedSnapshotsWithInactiveLicense() { + assertAllowed(BASIC, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(TRIAL, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(GOLD, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(PLATINUM, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(ENTERPRISE, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(STANDARD, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + } + + public void testEncryptedSnapshotsWithActiveLicense() { + assertAllowed(BASIC, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(TRIAL, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), true); + assertAllowed(GOLD, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(PLATINUM, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), true); + assertAllowed(ENTERPRISE, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), true); + assertAllowed(STANDARD, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + } + public void testGraphPlatinumTrial() throws Exception { assertAllowed(TRIAL, true, s -> s.checkFeature(Feature.GRAPH), true); assertAllowed(PLATINUM, true, s -> s.checkFeature(Feature.GRAPH), true); diff --git a/x-pack/plugin/repository-encrypted/build.gradle b/x-pack/plugin/repository-encrypted/build.gradle new file mode 100644 index 0000000000000..1a0eccc6421e6 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/build.gradle @@ -0,0 +1,30 @@ +evaluationDependsOn(xpackModule('core')) + +apply plugin: 'elasticsearch.esplugin' +apply plugin: 'elasticsearch.internal-cluster-test' +esplugin { + name 'repository-encrypted' + description 'Elasticsearch Expanded Pack Plugin - client-side encrypted repositories.' + classname 'org.elasticsearch.repositories.encrypted.EncryptedRepositoryPlugin' + extendedPlugins = ['x-pack-core'] +} +archivesBaseName = 'x-pack-repository-encrypted' + +dependencies { + // necessary for the license check + compileOnly project(path: xpackModule('core'), configuration: 'default') + testImplementation project(path: xpackModule('core'), configuration: 'testArtifacts') + // required for integ tests of encrypted FS repository + internalClusterTestImplementation project(":test:framework") + // required for integ tests of encrypted cloud repositories + internalClusterTestImplementation project(path: ':plugins:repository-gcs', configuration: 'internalClusterTestArtifacts') + internalClusterTestImplementation project(path: ':plugins:repository-azure', configuration: 'internalClusterTestArtifacts') + internalClusterTestImplementation(project(path: ':plugins:repository-s3', configuration: 'internalClusterTestArtifacts')) { + // HACK, resolves jar hell, such as: + // jar1: jakarta.xml.bind/jakarta.xml.bind-api/2.3.2/8d49996a4338670764d7ca4b85a1c4ccf7fe665d/jakarta.xml.bind-api-2.3.2.jar + // jar2: javax.xml.bind/jaxb-api/2.2.2/aeb3021ca93dde265796d82015beecdcff95bf09/jaxb-api-2.2.2.jar + exclude group: 'javax.xml.bind', module: 'jaxb-api' + } + // for encrypted GCS repository integ tests + internalClusterTestRuntimeOnly 'com.google.guava:guava:26.0-jre' +} diff --git a/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedAzureBlobStoreRepositoryIntegTests.java b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedAzureBlobStoreRepositoryIntegTests.java new file mode 100644 index 0000000000000..01ac4b83fa838 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedAzureBlobStoreRepositoryIntegTests.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.license.License; +import org.elasticsearch.license.LicenseService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.repositories.azure.AzureBlobStoreRepositoryTests; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.DEK_ROOT_CONTAINER; +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.getEncryptedBlobByteLength; +import static org.hamcrest.Matchers.hasSize; + +public final class EncryptedAzureBlobStoreRepositoryIntegTests extends AzureBlobStoreRepositoryTests { + private static List repositoryNames; + + @BeforeClass + private static void preGenerateRepositoryNames() { + List names = new ArrayList<>(); + for (int i = 0; i < 32; i++) { + names.add("test-repo-" + i); + } + repositoryNames = Collections.synchronizedList(names); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), License.LicenseType.TRIAL.getTypeName()); + MockSecureSettings superSecureSettings = (MockSecureSettings) settingsBuilder.getSecureSettings(); + superSecureSettings.merge(nodeSecureSettings()); + return settingsBuilder.build(); + } + + protected MockSecureSettings nodeSecureSettings() { + MockSecureSettings secureSettings = new MockSecureSettings(); + for (String repositoryName : repositoryNames) { + secureSettings.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repositoryName + " ".repeat(14 - repositoryName.length()) // pad to the minimum pass length of 112 bits (14) + ); + } + return secureSettings; + } + + @Override + protected String randomRepositoryName() { + return repositoryNames.remove(randomIntBetween(0, repositoryNames.size() - 1)); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateEncryptedRepositoryPlugin.class, TestAzureRepositoryPlugin.class); + } + + @Override + protected String repositoryType() { + return EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME; + } + + @Override + protected Settings repositorySettings(String repositoryName) { + return Settings.builder() + .put(super.repositorySettings(repositoryName)) + .put(EncryptedRepositoryPlugin.DELEGATE_TYPE_SETTING.getKey(), "azure") + .put(EncryptedRepositoryPlugin.PASSWORD_NAME_SETTING.getKey(), repositoryName) + .build(); + } + + @Override + protected void assertEmptyRepo(Map blobsMap) { + List blobs = blobsMap.keySet() + .stream() + .filter(blob -> false == blob.contains("index")) + .filter(blob -> false == blob.contains(DEK_ROOT_CONTAINER)) // encryption metadata "leaks" + .collect(Collectors.toList()); + assertThat("Only index blobs should remain in repository but found " + blobs, blobs, hasSize(0)); + } + + @Override + protected long blobLengthFromContentLength(long contentLength) { + return getEncryptedBlobByteLength(contentLength); + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedFSBlobStoreRepositoryIntegTests.java b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedFSBlobStoreRepositoryIntegTests.java new file mode 100644 index 0000000000000..6eeca1bbcacff --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedFSBlobStoreRepositoryIntegTests.java @@ -0,0 +1,155 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.action.ActionRunnable; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.license.License; +import org.elasticsearch.license.LicenseService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.repositories.RepositoriesService; +import org.elasticsearch.repositories.RepositoryData; +import org.elasticsearch.repositories.RepositoryException; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; +import org.elasticsearch.repositories.blobstore.ESFsBasedRepositoryIntegTestCase; +import org.elasticsearch.repositories.fs.FsRepository; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.stream.Stream; + +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.getEncryptedBlobByteLength; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.containsString; + +public final class EncryptedFSBlobStoreRepositoryIntegTests extends ESFsBasedRepositoryIntegTestCase { + private static int NUMBER_OF_TEST_REPOSITORIES = 32; + + private static List repositoryNames = new ArrayList<>(); + + @BeforeClass + private static void preGenerateRepositoryNames() { + for (int i = 0; i < NUMBER_OF_TEST_REPOSITORIES; i++) { + repositoryNames.add("test-repo-" + i); + } + } + + @Override + protected Settings repositorySettings(String repositoryName) { + final Settings.Builder settings = Settings.builder() + .put("compress", randomBoolean()) + .put("location", randomRepoPath()) + .put("delegate_type", FsRepository.TYPE) + .put("password_name", repositoryName); + if (randomBoolean()) { + long size = 1 << randomInt(10); + settings.put("chunk_size", new ByteSizeValue(size, ByteSizeUnit.KB)); + } + return settings.build(); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), License.LicenseType.TRIAL.getTypeName()) + .setSecureSettings(nodeSecureSettings()) + .build(); + } + + protected MockSecureSettings nodeSecureSettings() { + MockSecureSettings secureSettings = new MockSecureSettings(); + for (String repositoryName : repositoryNames) { + secureSettings.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repositoryName + " ".repeat(14 - repositoryName.length()) // pad to the minimum pass length of 112 bits (14) + ); + } + return secureSettings; + } + + @Override + protected String randomRepositoryName() { + return repositoryNames.remove(randomIntBetween(0, repositoryNames.size() - 1)); + } + + @Override + protected long blobLengthFromContentLength(long contentLength) { + return getEncryptedBlobByteLength(contentLength); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateEncryptedRepositoryPlugin.class); + } + + @Override + protected String repositoryType() { + return EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME; + } + + public void testTamperedEncryptionMetadata() throws Exception { + final String repoName = randomRepositoryName(); + final Path repoPath = randomRepoPath(); + final Settings repoSettings = Settings.builder().put(repositorySettings(repoName)).put("location", repoPath).build(); + createRepository(repoName, repoSettings, true); + + final String snapshotName = randomName(); + logger.info("--> create snapshot {}:{}", repoName, snapshotName); + client().admin().cluster().prepareCreateSnapshot(repoName, snapshotName).setWaitForCompletion(true).setIndices("other*").get(); + + assertAcked(client().admin().cluster().prepareDeleteRepository(repoName)); + createRepository(repoName, Settings.builder().put(repoSettings).put("readonly", randomBoolean()).build(), randomBoolean()); + + try (Stream rootContents = Files.list(repoPath.resolve(EncryptedRepository.DEK_ROOT_CONTAINER))) { + // tamper all DEKs + rootContents.filter(Files::isDirectory).forEach(DEKRootPath -> { + try (Stream contents = Files.list(DEKRootPath)) { + contents.filter(Files::isRegularFile).forEach(DEKPath -> { + try { + byte[] originalDEKBytes = Files.readAllBytes(DEKPath); + // tamper DEK + int tamperPos = randomIntBetween(0, originalDEKBytes.length - 1); + originalDEKBytes[tamperPos] ^= 0xFF; + Files.write(DEKPath, originalDEKBytes); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + final BlobStoreRepository blobStoreRepository = (BlobStoreRepository) internalCluster().getCurrentMasterNodeInstance( + RepositoriesService.class + ).repository(repoName); + RepositoryException e = expectThrows( + RepositoryException.class, + () -> PlainActionFuture.get( + f -> blobStoreRepository.threadPool().generic().execute(ActionRunnable.wrap(f, blobStoreRepository::getRepositoryData)) + ) + ); + assertThat(e.getMessage(), containsString("the encryption metadata in the repository has been corrupted")); + e = expectThrows( + RepositoryException.class, + () -> client().admin().cluster().prepareRestoreSnapshot(repoName, snapshotName).setWaitForCompletion(true).get() + ); + assertThat(e.getMessage(), containsString("the encryption metadata in the repository has been corrupted")); + } + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedGCSBlobStoreRepositoryIntegTests.java b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedGCSBlobStoreRepositoryIntegTests.java new file mode 100644 index 0000000000000..128869dd87aff --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedGCSBlobStoreRepositoryIntegTests.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.license.License; +import org.elasticsearch.license.LicenseService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.repositories.gcs.GoogleCloudStorageBlobStoreRepositoryTests; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.DEK_ROOT_CONTAINER; +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.getEncryptedBlobByteLength; +import static org.hamcrest.Matchers.hasSize; + +public final class EncryptedGCSBlobStoreRepositoryIntegTests extends GoogleCloudStorageBlobStoreRepositoryTests { + + private static List repositoryNames; + + @BeforeClass + private static void preGenerateRepositoryNames() { + List names = new ArrayList<>(); + for (int i = 0; i < 32; i++) { + names.add("test-repo-" + i); + } + repositoryNames = Collections.synchronizedList(names); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), License.LicenseType.TRIAL.getTypeName()); + MockSecureSettings superSecureSettings = (MockSecureSettings) settingsBuilder.getSecureSettings(); + superSecureSettings.merge(nodeSecureSettings()); + return settingsBuilder.build(); + } + + protected MockSecureSettings nodeSecureSettings() { + MockSecureSettings secureSettings = new MockSecureSettings(); + for (String repositoryName : repositoryNames) { + secureSettings.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repositoryName + " ".repeat(14 - repositoryName.length()) // pad to the minimum pass length of 112 bits (14) + ); + } + return secureSettings; + } + + @Override + protected String randomRepositoryName() { + return repositoryNames.remove(randomIntBetween(0, repositoryNames.size() - 1)); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateEncryptedRepositoryPlugin.class, TestGoogleCloudStoragePlugin.class); + } + + @Override + protected String repositoryType() { + return EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME; + } + + @Override + protected Settings repositorySettings(String repositoryName) { + return Settings.builder() + .put(super.repositorySettings(repositoryName)) + .put(EncryptedRepositoryPlugin.DELEGATE_TYPE_SETTING.getKey(), "gcs") + .put(EncryptedRepositoryPlugin.PASSWORD_NAME_SETTING.getKey(), repositoryName) + .build(); + } + + @Override + protected void assertEmptyRepo(Map blobsMap) { + List blobs = blobsMap.keySet() + .stream() + .filter(blob -> false == blob.contains("index")) + .filter(blob -> false == blob.contains(DEK_ROOT_CONTAINER)) // encryption metadata "leaks" + .collect(Collectors.toList()); + assertThat("Only index blobs should remain in repository but found " + blobs, blobs, hasSize(0)); + } + + @Override + protected long blobLengthFromContentLength(long contentLength) { + return getEncryptedBlobByteLength(contentLength); + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedRepositorySecretIntegTests.java b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedRepositorySecretIntegTests.java new file mode 100644 index 0000000000000..3deba6f9b476f --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedRepositorySecretIntegTests.java @@ -0,0 +1,806 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.ActionRunnable; +import org.elasticsearch.action.admin.cluster.repositories.verify.VerifyRepositoryResponse; +import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotResponse; +import org.elasticsearch.action.admin.cluster.snapshots.get.GetSnapshotsResponse; +import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotResponse; +import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.SnapshotsInProgress; +import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.license.License; +import org.elasticsearch.license.LicenseService; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.repositories.RepositoriesService; +import org.elasticsearch.repositories.RepositoryData; +import org.elasticsearch.repositories.RepositoryException; +import org.elasticsearch.repositories.RepositoryMissingException; +import org.elasticsearch.repositories.RepositoryVerificationException; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; +import org.elasticsearch.repositories.fs.FsRepository; +import org.elasticsearch.snapshots.Snapshot; +import org.elasticsearch.snapshots.SnapshotInfo; +import org.elasticsearch.snapshots.SnapshotMissingException; +import org.elasticsearch.snapshots.SnapshotState; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.InternalTestCluster; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.Matchers.anyObject; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, autoManageMasterNodes = false) +public final class EncryptedRepositorySecretIntegTests extends ESIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateEncryptedRepositoryPlugin.class); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), License.LicenseType.TRIAL.getTypeName()) + .build(); + } + + public void testRepositoryCreationFailsForMissingPassword() throws Exception { + // if the password is missing on the master node, the repository creation fails + final String repositoryName = randomName(); + MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + randomAlphaOfLength(20) + ); + logger.info("--> start 3 nodes"); + internalCluster().setBootstrapMasterNodeIndex(0); + final String masterNodeName = internalCluster().startNode(); + logger.info("--> started master node " + masterNodeName); + ensureStableCluster(1); + internalCluster().startNodes(2, Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + logger.info("--> started two other nodes"); + ensureStableCluster(3); + assertThat(masterNodeName, equalTo(internalCluster().getMasterName())); + + final Settings repositorySettings = repositorySettings(repositoryName); + RepositoryException e = expectThrows( + RepositoryException.class, + () -> client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(randomBoolean()) + .setSettings(repositorySettings) + .get() + ); + assertThat(e.getMessage(), containsString("failed to create repository")); + expectThrows(RepositoryMissingException.class, () -> client().admin().cluster().prepareGetRepositories(repositoryName).get()); + + if (randomBoolean()) { + // stop the node with the missing password + internalCluster().stopRandomNode(InternalTestCluster.nameFilter(masterNodeName)); + ensureStableCluster(2); + } else { + // restart the node with the missing password + internalCluster().restartNode(masterNodeName, new InternalTestCluster.RestartCallback() { + @Override + public Settings onNodeStopped(String nodeName) throws Exception { + Settings.Builder newSettings = Settings.builder().put(super.onNodeStopped(nodeName)); + newSettings.setSecureSettings(secureSettingsWithPassword); + return newSettings.build(); + } + }); + ensureStableCluster(3); + } + // repository creation now successful + createRepository(repositoryName, repositorySettings, true); + } + + public void testRepositoryVerificationFailsForMissingPassword() throws Exception { + // if the password is missing on any non-master node, the repository verification fails + final String repositoryName = randomName(); + MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + randomAlphaOfLength(20) + ); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(0); + final String masterNodeName = internalCluster().startNode(Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + logger.info("--> started master node " + masterNodeName); + ensureStableCluster(1); + final String otherNodeName = internalCluster().startNode(); + logger.info("--> started other node " + otherNodeName); + ensureStableCluster(2); + assertThat(masterNodeName, equalTo(internalCluster().getMasterName())); + // repository create fails verification + final Settings repositorySettings = repositorySettings(repositoryName); + expectThrows( + RepositoryVerificationException.class, + () -> client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(true) + .setSettings(repositorySettings) + .get() + ); + if (randomBoolean()) { + // delete and recreate repo + logger.debug("--> deleting repository [name: {}]", repositoryName); + assertAcked(client().admin().cluster().prepareDeleteRepository(repositoryName)); + assertAcked( + client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(false) + .setSettings(repositorySettings) + .get() + ); + } + // test verify call fails + expectThrows(RepositoryVerificationException.class, () -> client().admin().cluster().prepareVerifyRepository(repositoryName).get()); + if (randomBoolean()) { + // stop the node with the missing password + internalCluster().stopRandomNode(InternalTestCluster.nameFilter(otherNodeName)); + ensureStableCluster(1); + // repository verification now succeeds + VerifyRepositoryResponse verifyRepositoryResponse = client().admin().cluster().prepareVerifyRepository(repositoryName).get(); + List verifiedNodes = verifyRepositoryResponse.getNodes().stream().map(n -> n.getName()).collect(Collectors.toList()); + assertThat(verifiedNodes, contains(masterNodeName)); + } else { + // restart the node with the missing password + internalCluster().restartNode(otherNodeName, new InternalTestCluster.RestartCallback() { + @Override + public Settings onNodeStopped(String nodeName) throws Exception { + Settings.Builder newSettings = Settings.builder().put(super.onNodeStopped(nodeName)); + newSettings.setSecureSettings(secureSettingsWithPassword); + return newSettings.build(); + } + }); + ensureStableCluster(2); + // repository verification now succeeds + VerifyRepositoryResponse verifyRepositoryResponse = client().admin().cluster().prepareVerifyRepository(repositoryName).get(); + List verifiedNodes = verifyRepositoryResponse.getNodes().stream().map(n -> n.getName()).collect(Collectors.toList()); + assertThat(verifiedNodes, containsInAnyOrder(masterNodeName, otherNodeName)); + } + } + + public void testRepositoryVerificationFailsForDifferentPassword() throws Exception { + final String repositoryName = randomName(); + final String repoPass1 = randomAlphaOfLength(20); + final String repoPass2 = randomAlphaOfLength(19); + // put a different repository password + MockSecureSettings secureSettings1 = new MockSecureSettings(); + secureSettings1.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repoPass1 + ); + MockSecureSettings secureSettings2 = new MockSecureSettings(); + secureSettings2.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repoPass2 + ); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(1); + final String node1 = internalCluster().startNode(Settings.builder().setSecureSettings(secureSettings1).build()); + final String node2 = internalCluster().startNode(Settings.builder().setSecureSettings(secureSettings2).build()); + ensureStableCluster(2); + // repository create fails verification + Settings repositorySettings = repositorySettings(repositoryName); + expectThrows( + RepositoryVerificationException.class, + () -> client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(true) + .setSettings(repositorySettings) + .get() + ); + if (randomBoolean()) { + // delete and recreate repo + logger.debug("--> deleting repository [name: {}]", repositoryName); + assertAcked(client().admin().cluster().prepareDeleteRepository(repositoryName)); + assertAcked( + client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(false) + .setSettings(repositorySettings) + .get() + ); + } + // test verify call fails + expectThrows(RepositoryVerificationException.class, () -> client().admin().cluster().prepareVerifyRepository(repositoryName).get()); + // restart one of the nodes to use the same password + if (randomBoolean()) { + secureSettings1.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repoPass2 + ); + internalCluster().restartNode(node1, new InternalTestCluster.RestartCallback()); + } else { + secureSettings2.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repoPass1 + ); + internalCluster().restartNode(node2, new InternalTestCluster.RestartCallback()); + } + ensureStableCluster(2); + // repository verification now succeeds + VerifyRepositoryResponse verifyRepositoryResponse = client().admin().cluster().prepareVerifyRepository(repositoryName).get(); + List verifiedNodes = verifyRepositoryResponse.getNodes().stream().map(n -> n.getName()).collect(Collectors.toList()); + assertThat(verifiedNodes, containsInAnyOrder(node1, node2)); + } + + public void testLicenseComplianceSnapshotAndRestore() throws Exception { + final String repositoryName = randomName(); + MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + randomAlphaOfLength(20) + ); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(1); + internalCluster().startNodes(2, Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + ensureStableCluster(2); + + logger.info("--> creating repo " + repositoryName); + createRepository(repositoryName); + final String indexName = randomName(); + logger.info("--> create random index {} with {} records", indexName, 3); + indexRandom( + true, + client().prepareIndex(indexName).setId("1").setSource("field1", "the quick brown fox jumps"), + client().prepareIndex(indexName).setId("2").setSource("field1", "quick brown"), + client().prepareIndex(indexName).setId("3").setSource("field1", "quick") + ); + assertHitCount(client().prepareSearch(indexName).setSize(0).get(), 3); + + final String snapshotName = randomName(); + logger.info("--> create snapshot {}:{}", repositoryName, snapshotName); + assertSuccessfulSnapshot( + client().admin() + .cluster() + .prepareCreateSnapshot(repositoryName, snapshotName) + .setIndices(indexName) + .setWaitForCompletion(true) + .get() + ); + + // make license not accept encrypted snapshots + EncryptedRepository encryptedRepository = (EncryptedRepository) internalCluster().getCurrentMasterNodeInstance( + RepositoriesService.class + ).repository(repositoryName); + encryptedRepository.licenseStateSupplier = () -> { + XPackLicenseState mockLicenseState = mock(XPackLicenseState.class); + when(mockLicenseState.isAllowed(anyObject())).thenReturn(false); + return mockLicenseState; + }; + + // now snapshot is not permitted + ElasticsearchSecurityException e = expectThrows( + ElasticsearchSecurityException.class, + () -> client().admin().cluster().prepareCreateSnapshot(repositoryName, snapshotName + "2").setWaitForCompletion(true).get() + ); + assertThat(e.getDetailedMessage(), containsString("current license is non-compliant for [encrypted snapshots]")); + + logger.info("--> delete index {}", indexName); + assertAcked(client().admin().indices().prepareDelete(indexName)); + + // but restore is permitted + logger.info("--> restore index from the snapshot"); + assertSuccessfulRestore( + client().admin().cluster().prepareRestoreSnapshot(repositoryName, snapshotName).setWaitForCompletion(true).get() + ); + ensureGreen(); + assertHitCount(client().prepareSearch(indexName).setSize(0).get(), 3); + // also delete snapshot is permitted + logger.info("--> delete snapshot {}:{}", repositoryName, snapshotName); + assertAcked(client().admin().cluster().prepareDeleteSnapshot(repositoryName, snapshotName).get()); + } + + public void testSnapshotIsPartialForMissingPassword() throws Exception { + final String repositoryName = randomName(); + final Settings repositorySettings = repositorySettings(repositoryName); + MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + randomAlphaOfLength(20) + ); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(0); + // master has the password + internalCluster().startNode(Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + ensureStableCluster(1); + final String otherNode = internalCluster().startNode(); + ensureStableCluster(2); + logger.debug("--> creating repository [name: {}, verify: {}, settings: {}]", repositoryName, false, repositorySettings); + assertAcked( + client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(false) + .setSettings(repositorySettings) + ); + // create an index with the shard on the node without a repository password + final String indexName = randomName(); + final Settings indexSettings = Settings.builder() + .put(indexSettings()) + .put("index.routing.allocation.include._name", otherNode) + .put(SETTING_NUMBER_OF_SHARDS, 1) + .build(); + logger.info("--> create random index {}", indexName); + createIndex(indexName, indexSettings); + indexRandom( + true, + client().prepareIndex(indexName).setId("1").setSource("field1", "the quick brown fox jumps"), + client().prepareIndex(indexName).setId("2").setSource("field1", "quick brown"), + client().prepareIndex(indexName).setId("3").setSource("field1", "quick") + ); + assertHitCount(client().prepareSearch(indexName).setSize(0).get(), 3); + + // empty snapshot completes successfully because it does not involve data on the node without a repository password + final String snapshotName = randomName(); + logger.info("--> create snapshot {}:{}", repositoryName, snapshotName); + CreateSnapshotResponse createSnapshotResponse = client().admin() + .cluster() + .prepareCreateSnapshot(repositoryName, snapshotName) + .setIndices(indexName + "other*") + .setWaitForCompletion(true) + .get(); + assertThat( + createSnapshotResponse.getSnapshotInfo().successfulShards(), + equalTo(createSnapshotResponse.getSnapshotInfo().totalShards()) + ); + assertThat(createSnapshotResponse.getSnapshotInfo().successfulShards(), equalTo(0)); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_HASH_USER_METADATA_KEY)) + ); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_SALT_USER_METADATA_KEY)) + ); + + // snapshot is PARTIAL because it includes shards on nodes with a missing repository password + final String snapshotName2 = snapshotName + "2"; + CreateSnapshotResponse incompleteSnapshotResponse = client().admin() + .cluster() + .prepareCreateSnapshot(repositoryName, snapshotName2) + .setWaitForCompletion(true) + .setIndices(indexName) + .get(); + assertThat(incompleteSnapshotResponse.getSnapshotInfo().state(), equalTo(SnapshotState.PARTIAL)); + assertTrue( + incompleteSnapshotResponse.getSnapshotInfo() + .shardFailures() + .stream() + .allMatch(shardFailure -> shardFailure.reason().contains("[" + repositoryName + "] missing")) + ); + assertThat( + incompleteSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_HASH_USER_METADATA_KEY)) + ); + assertThat( + incompleteSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_SALT_USER_METADATA_KEY)) + ); + final Set nodesWithFailures = incompleteSnapshotResponse.getSnapshotInfo() + .shardFailures() + .stream() + .map(sf -> sf.nodeId()) + .collect(Collectors.toSet()); + assertThat(nodesWithFailures.size(), equalTo(1)); + final ClusterStateResponse clusterState = client().admin().cluster().prepareState().clear().setNodes(true).get(); + assertThat(clusterState.getState().nodes().get(nodesWithFailures.iterator().next()).getName(), equalTo(otherNode)); + } + + public void testSnapshotIsPartialForDifferentPassword() throws Exception { + final String repoName = randomName(); + final Settings repoSettings = repositorySettings(repoName); + final String repoPass1 = randomAlphaOfLength(20); + final String repoPass2 = randomAlphaOfLength(19); + MockSecureSettings secureSettingsMaster = new MockSecureSettings(); + secureSettingsMaster.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repoName).getKey(), + repoPass1 + ); + MockSecureSettings secureSettingsOther = new MockSecureSettings(); + secureSettingsOther.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repoName).getKey(), + repoPass2 + ); + final boolean putRepoEarly = randomBoolean(); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(0); + final String masterNode = internalCluster().startNode(Settings.builder().setSecureSettings(secureSettingsMaster).build()); + ensureStableCluster(1); + if (putRepoEarly) { + createRepository(repoName, repoSettings, true); + } + final String otherNode = internalCluster().startNode(Settings.builder().setSecureSettings(secureSettingsOther).build()); + ensureStableCluster(2); + if (false == putRepoEarly) { + createRepository(repoName, repoSettings, false); + } + + // create index with shards on both nodes + final String indexName = randomName(); + final Settings indexSettings = Settings.builder().put(indexSettings()).put(SETTING_NUMBER_OF_SHARDS, 5).build(); + logger.info("--> create random index {}", indexName); + createIndex(indexName, indexSettings); + indexRandom( + true, + client().prepareIndex(indexName).setId("1").setSource("field1", "the quick brown fox jumps"), + client().prepareIndex(indexName).setId("2").setSource("field1", "quick brown"), + client().prepareIndex(indexName).setId("3").setSource("field1", "quick"), + client().prepareIndex(indexName).setId("4").setSource("field1", "lazy"), + client().prepareIndex(indexName).setId("5").setSource("field1", "dog") + ); + assertHitCount(client().prepareSearch(indexName).setSize(0).get(), 5); + + // empty snapshot completes successfully for both repos because it does not involve any data + final String snapshotName = randomName(); + logger.info("--> create snapshot {}:{}", repoName, snapshotName); + CreateSnapshotResponse createSnapshotResponse = client().admin() + .cluster() + .prepareCreateSnapshot(repoName, snapshotName) + .setIndices(indexName + "other*") + .setWaitForCompletion(true) + .get(); + assertThat( + createSnapshotResponse.getSnapshotInfo().successfulShards(), + equalTo(createSnapshotResponse.getSnapshotInfo().totalShards()) + ); + assertThat(createSnapshotResponse.getSnapshotInfo().successfulShards(), equalTo(0)); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_HASH_USER_METADATA_KEY)) + ); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_SALT_USER_METADATA_KEY)) + ); + + // snapshot is PARTIAL because it includes shards on nodes with a different repository KEK + final String snapshotName2 = snapshotName + "2"; + CreateSnapshotResponse incompleteSnapshotResponse = client().admin() + .cluster() + .prepareCreateSnapshot(repoName, snapshotName2) + .setWaitForCompletion(true) + .setIndices(indexName) + .get(); + assertThat(incompleteSnapshotResponse.getSnapshotInfo().state(), equalTo(SnapshotState.PARTIAL)); + assertTrue( + incompleteSnapshotResponse.getSnapshotInfo() + .shardFailures() + .stream() + .allMatch(shardFailure -> shardFailure.reason().contains("Repository password mismatch")) + ); + final Set nodesWithFailures = incompleteSnapshotResponse.getSnapshotInfo() + .shardFailures() + .stream() + .map(sf -> sf.nodeId()) + .collect(Collectors.toSet()); + assertThat(nodesWithFailures.size(), equalTo(1)); + final ClusterStateResponse clusterState = client().admin().cluster().prepareState().clear().setNodes(true).get(); + assertThat(clusterState.getState().nodes().get(nodesWithFailures.iterator().next()).getName(), equalTo(otherNode)); + } + + public void testWrongRepositoryPassword() throws Exception { + final String repositoryName = randomName(); + final Settings repositorySettings = repositorySettings(repositoryName); + final String goodPassword = randomAlphaOfLength(20); + final String wrongPassword = randomAlphaOfLength(19); + MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + goodPassword + ); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(1); + internalCluster().startNodes(2, Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + ensureStableCluster(2); + createRepository(repositoryName, repositorySettings, true); + // create empty smapshot + final String snapshotName = randomName(); + logger.info("--> create empty snapshot {}:{}", repositoryName, snapshotName); + CreateSnapshotResponse createSnapshotResponse = client().admin() + .cluster() + .prepareCreateSnapshot(repositoryName, snapshotName) + .setWaitForCompletion(true) + .get(); + assertThat( + createSnapshotResponse.getSnapshotInfo().successfulShards(), + equalTo(createSnapshotResponse.getSnapshotInfo().totalShards()) + ); + assertThat(createSnapshotResponse.getSnapshotInfo().successfulShards(), equalTo(0)); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_HASH_USER_METADATA_KEY)) + ); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_SALT_USER_METADATA_KEY)) + ); + // restart master node and fill in a wrong password + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + wrongPassword + ); + Set nodesWithWrongPassword = new HashSet<>(); + do { + String masterNodeName = internalCluster().getMasterName(); + logger.info("--> restart master node {}", masterNodeName); + internalCluster().restartNode(masterNodeName, new InternalTestCluster.RestartCallback()); + nodesWithWrongPassword.add(masterNodeName); + ensureStableCluster(2); + } while (false == nodesWithWrongPassword.contains(internalCluster().getMasterName())); + // maybe recreate the repository + if (randomBoolean()) { + deleteRepository(repositoryName); + createRepository(repositoryName, repositorySettings, false); + } + // all repository operations return "repository password is incorrect", but the repository does not move to the corrupted state + final BlobStoreRepository blobStoreRepository = (BlobStoreRepository) internalCluster().getCurrentMasterNodeInstance( + RepositoriesService.class + ).repository(repositoryName); + RepositoryException e = expectThrows( + RepositoryException.class, + () -> PlainActionFuture.get( + f -> blobStoreRepository.threadPool().generic().execute(ActionRunnable.wrap(f, blobStoreRepository::getRepositoryData)) + ) + ); + assertThat(e.getCause().getMessage(), containsString("repository password is incorrect")); + e = expectThrows( + RepositoryException.class, + () -> client().admin().cluster().prepareCreateSnapshot(repositoryName, snapshotName + "2").setWaitForCompletion(true).get() + ); + assertThat(e.getCause().getMessage(), containsString("repository password is incorrect")); + GetSnapshotsResponse getSnapshotResponse = client().admin().cluster().prepareGetSnapshots(repositoryName).get(); + assertThat(getSnapshotResponse.getSuccessfulResponses().keySet(), empty()); + assertThat(getSnapshotResponse.getFailedResponses().keySet(), contains(repositoryName)); + assertThat( + getSnapshotResponse.getFailedResponses().get(repositoryName).getCause().getMessage(), + containsString("repository password is incorrect") + ); + e = expectThrows( + RepositoryException.class, + () -> client().admin().cluster().prepareRestoreSnapshot(repositoryName, snapshotName).setWaitForCompletion(true).get() + ); + assertThat(e.getCause().getMessage(), containsString("repository password is incorrect")); + e = expectThrows( + RepositoryException.class, + () -> client().admin().cluster().prepareDeleteSnapshot(repositoryName, snapshotName).get() + ); + assertThat(e.getCause().getMessage(), containsString("repository password is incorrect")); + // restart master node and fill in the good password + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + goodPassword + ); + do { + String masterNodeName = internalCluster().getMasterName(); + logger.info("--> restart master node {}", masterNodeName); + internalCluster().restartNode(masterNodeName, new InternalTestCluster.RestartCallback()); + nodesWithWrongPassword.remove(masterNodeName); + ensureStableCluster(2); + } while (nodesWithWrongPassword.contains(internalCluster().getMasterName())); + // ensure get snapshot works + getSnapshotResponse = client().admin().cluster().prepareGetSnapshots(repositoryName).get(); + assertThat(getSnapshotResponse.getFailedResponses().keySet(), empty()); + assertThat(getSnapshotResponse.getSuccessfulResponses().keySet(), contains(repositoryName)); + } + + public void testSnapshotFailsForMasterFailoverWithWrongPassword() throws Exception { + final String repoName = randomName(); + final Settings repoSettings = repositorySettings(repoName); + final String goodPass = randomAlphaOfLength(20); + final String wrongPass = randomAlphaOfLength(19); + MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repoName).getKey(), + goodPass + ); + logger.info("--> start 4 nodes"); + internalCluster().setBootstrapMasterNodeIndex(0); + final String masterNode = internalCluster().startMasterOnlyNodes( + 1, + Settings.builder().setSecureSettings(secureSettingsWithPassword).build() + ).get(0); + final String otherNode = internalCluster().startDataOnlyNodes( + 1, + Settings.builder().setSecureSettings(secureSettingsWithPassword).build() + ).get(0); + ensureStableCluster(2); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repoName).getKey(), + wrongPass + ); + internalCluster().startMasterOnlyNodes(2, Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + ensureStableCluster(4); + assertThat(internalCluster().getMasterName(), equalTo(masterNode)); + + logger.debug("--> creating repository [name: {}, verify: {}, settings: {}]", repoName, false, repoSettings); + assertAcked( + client().admin().cluster().preparePutRepository(repoName).setType(repositoryType()).setVerify(false).setSettings(repoSettings) + ); + // create index with just one shard on the "other" data node + final String indexName = randomName(); + final Settings indexSettings = Settings.builder() + .put(indexSettings()) + .put("index.routing.allocation.include._name", otherNode) + .put(SETTING_NUMBER_OF_SHARDS, 1) + .build(); + logger.info("--> create random index {}", indexName); + createIndex(indexName, indexSettings); + indexRandom( + true, + client().prepareIndex(indexName).setId("1").setSource("field1", "the quick brown fox jumps"), + client().prepareIndex(indexName).setId("2").setSource("field1", "quick brown"), + client().prepareIndex(indexName).setId("3").setSource("field1", "quick"), + client().prepareIndex(indexName).setId("4").setSource("field1", "lazy"), + client().prepareIndex(indexName).setId("5").setSource("field1", "dog") + ); + assertHitCount(client().prepareSearch(indexName).setSize(0).get(), 5); + + // block shard snapshot on the data node + final LocalStateEncryptedRepositoryPlugin.TestEncryptedRepository otherNodeEncryptedRepo = + (LocalStateEncryptedRepositoryPlugin.TestEncryptedRepository) internalCluster().getInstance( + RepositoriesService.class, + otherNode + ).repository(repoName); + otherNodeEncryptedRepo.blockSnapshotShard(); + + final String snapshotName = randomName(); + logger.info("--> create snapshot {}:{}", repoName, snapshotName); + client().admin().cluster().prepareCreateSnapshot(repoName, snapshotName).setIndices(indexName).setWaitForCompletion(false).get(); + + // stop master + internalCluster().stopRandomNode(InternalTestCluster.nameFilter(masterNode)); + ensureStableCluster(3); + + otherNodeEncryptedRepo.unblockSnapshotShard(); + + // the failover master has the wrong password, snapshot fails + logger.info("--> waiting for completion"); + expectThrows(SnapshotMissingException.class, () -> { waitForCompletion(repoName, snapshotName, TimeValue.timeValueSeconds(60)); }); + } + + protected String randomName() { + return randomAlphaOfLength(randomIntBetween(1, 10)).toLowerCase(Locale.ROOT); + } + + protected String repositoryType() { + return EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME; + } + + protected Settings repositorySettings(String repositoryName) { + return Settings.builder() + .put("compress", randomBoolean()) + .put(EncryptedRepositoryPlugin.DELEGATE_TYPE_SETTING.getKey(), FsRepository.TYPE) + .put(EncryptedRepositoryPlugin.PASSWORD_NAME_SETTING.getKey(), repositoryName) + .put("location", randomRepoPath()) + .build(); + } + + protected String createRepository(final String name) { + return createRepository(name, true); + } + + protected String createRepository(final String name, final boolean verify) { + return createRepository(name, repositorySettings(name), verify); + } + + protected String createRepository(final String name, final Settings settings, final boolean verify) { + logger.debug("--> creating repository [name: {}, verify: {}, settings: {}]", name, verify, settings); + assertAcked( + client().admin().cluster().preparePutRepository(name).setType(repositoryType()).setVerify(verify).setSettings(settings) + ); + + internalCluster().getDataOrMasterNodeInstances(RepositoriesService.class).forEach(repositories -> { + assertThat(repositories.repository(name), notNullValue()); + assertThat(repositories.repository(name), instanceOf(BlobStoreRepository.class)); + assertThat(repositories.repository(name).isReadOnly(), is(settings.getAsBoolean("readonly", false))); + }); + + return name; + } + + protected void deleteRepository(final String name) { + logger.debug("--> deleting repository [name: {}]", name); + assertAcked(client().admin().cluster().prepareDeleteRepository(name)); + + internalCluster().getDataOrMasterNodeInstances(RepositoriesService.class).forEach(repositories -> { + RepositoryMissingException e = expectThrows(RepositoryMissingException.class, () -> repositories.repository(name)); + assertThat(e.repository(), equalTo(name)); + }); + } + + private void assertSuccessfulRestore(RestoreSnapshotResponse response) { + assertThat(response.getRestoreInfo().successfulShards(), greaterThan(0)); + assertThat(response.getRestoreInfo().successfulShards(), equalTo(response.getRestoreInfo().totalShards())); + } + + private void assertSuccessfulSnapshot(CreateSnapshotResponse response) { + assertThat(response.getSnapshotInfo().successfulShards(), greaterThan(0)); + assertThat(response.getSnapshotInfo().successfulShards(), equalTo(response.getSnapshotInfo().totalShards())); + assertThat(response.getSnapshotInfo().userMetadata(), not(hasKey(EncryptedRepository.PASSWORD_HASH_USER_METADATA_KEY))); + assertThat(response.getSnapshotInfo().userMetadata(), not(hasKey(EncryptedRepository.PASSWORD_SALT_USER_METADATA_KEY))); + } + + public SnapshotInfo waitForCompletion(String repository, String snapshotName, TimeValue timeout) throws InterruptedException { + long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < timeout.millis()) { + List snapshotInfos = client().admin() + .cluster() + .prepareGetSnapshots(repository) + .setSnapshots(snapshotName) + .get() + .getSnapshots(repository); + assertThat(snapshotInfos.size(), equalTo(1)); + if (snapshotInfos.get(0).state().completed()) { + // Make sure that snapshot clean up operations are finished + ClusterStateResponse stateResponse = client().admin().cluster().prepareState().get(); + SnapshotsInProgress snapshotsInProgress = stateResponse.getState().custom(SnapshotsInProgress.TYPE); + if (snapshotsInProgress == null) { + return snapshotInfos.get(0); + } else { + boolean found = false; + for (SnapshotsInProgress.Entry entry : snapshotsInProgress.entries()) { + final Snapshot curr = entry.snapshot(); + if (curr.getRepository().equals(repository) && curr.getSnapshotId().getName().equals(snapshotName)) { + found = true; + break; + } + } + if (found == false) { + return snapshotInfos.get(0); + } + } + } + Thread.sleep(100); + } + fail("Timeout!!!"); + return null; + } +} diff --git a/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedS3BlobStoreRepositoryIntegTests.java b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedS3BlobStoreRepositoryIntegTests.java new file mode 100644 index 0000000000000..96aaed3c096e5 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedS3BlobStoreRepositoryIntegTests.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.license.License; +import org.elasticsearch.license.LicenseService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.repositories.s3.S3BlobStoreRepositoryTests; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.DEK_ROOT_CONTAINER; +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.getEncryptedBlobByteLength; +import static org.hamcrest.Matchers.hasSize; + +public final class EncryptedS3BlobStoreRepositoryIntegTests extends S3BlobStoreRepositoryTests { + private static List repositoryNames; + + @BeforeClass + private static void preGenerateRepositoryNames() { + List names = new ArrayList<>(); + for (int i = 0; i < 32; i++) { + names.add("test-repo-" + i); + } + repositoryNames = Collections.synchronizedList(names); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), License.LicenseType.TRIAL.getTypeName()); + MockSecureSettings superSecureSettings = (MockSecureSettings) settingsBuilder.getSecureSettings(); + superSecureSettings.merge(nodeSecureSettings()); + return settingsBuilder.build(); + } + + protected MockSecureSettings nodeSecureSettings() { + MockSecureSettings secureSettings = new MockSecureSettings(); + for (String repositoryName : repositoryNames) { + secureSettings.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repositoryName + " ".repeat(14 - repositoryName.length()) // pad to the minimum pass length of 112 bits (14) + ); + } + return secureSettings; + } + + @Override + protected String randomRepositoryName() { + return repositoryNames.remove(randomIntBetween(0, repositoryNames.size() - 1)); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateEncryptedRepositoryPlugin.class, TestS3RepositoryPlugin.class); + } + + @Override + protected String repositoryType() { + return EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME; + } + + @Override + protected Settings repositorySettings(String repositoryName) { + return Settings.builder() + .put(super.repositorySettings(repositoryName)) + .put(EncryptedRepositoryPlugin.DELEGATE_TYPE_SETTING.getKey(), "s3") + .put(EncryptedRepositoryPlugin.PASSWORD_NAME_SETTING.getKey(), repositoryName) + .build(); + } + + @Override + protected void assertEmptyRepo(Map blobsMap) { + List blobs = blobsMap.keySet() + .stream() + .filter(blob -> false == blob.contains("index")) + .filter(blob -> false == blob.contains(DEK_ROOT_CONTAINER)) // encryption metadata "leaks" + .collect(Collectors.toList()); + assertThat("Only index blobs should remain in repository but found " + blobs, blobs, hasSize(0)); + } + + @Override + protected long blobLengthFromContentLength(long contentLength) { + return getEncryptedBlobByteLength(contentLength); + } + + @Override + public void testEnforcedCooldownPeriod() { + // this test is not applicable for the encrypted repository because it verifies behavior which pertains to snapshots that must + // be created before the encrypted repository was introduced, hence no such encrypted snapshots can possibly exist + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/AESKeyUtils.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/AESKeyUtils.java new file mode 100644 index 0000000000000..92a128d93848c --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/AESKeyUtils.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.settings.SecureString; + +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.SecretKeyFactory; +import javax.crypto.spec.PBEKeySpec; +import javax.crypto.spec.SecretKeySpec; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import java.security.Key; +import java.util.Base64; + +public final class AESKeyUtils { + public static final int KEY_LENGTH_IN_BYTES = 32; // 256-bit AES key + public static final int WRAPPED_KEY_LENGTH_IN_BYTES = KEY_LENGTH_IN_BYTES + 8; // https://www.ietf.org/rfc/rfc3394.txt section 2.2 + // parameter for the KDF function, it's a funny and unusual iter count larger than 60k + private static final int KDF_ITER = 61616; + // the KDF algorithm that generate the symmetric key given the password + private static final String KDF_ALGO = "PBKDF2WithHmacSHA512"; + // The Id of any AES SecretKey is the AES-Wrap-ciphertext of this fixed 32 byte wide array. + // Key wrapping encryption is deterministic (same plaintext generates the same ciphertext) + // and the probability that two different keys map the same plaintext to the same ciphertext is very small + // (2^-256, much lower than the UUID collision of 2^-128), assuming AES is indistinguishable from a pseudorandom permutation. + private static final byte[] KEY_ID_PLAINTEXT = "wrapping known text forms key id".getBytes(StandardCharsets.UTF_8); + + public static byte[] wrap(SecretKey wrappingKey, SecretKey keyToWrap) throws GeneralSecurityException { + assert "AES".equals(wrappingKey.getAlgorithm()); + assert "AES".equals(keyToWrap.getAlgorithm()); + Cipher c = Cipher.getInstance("AESWrap"); + c.init(Cipher.WRAP_MODE, wrappingKey); + return c.wrap(keyToWrap); + } + + public static SecretKey unwrap(SecretKey wrappingKey, byte[] keyToUnwrap) throws GeneralSecurityException { + assert "AES".equals(wrappingKey.getAlgorithm()); + assert keyToUnwrap.length == WRAPPED_KEY_LENGTH_IN_BYTES; + Cipher c = Cipher.getInstance("AESWrap"); + c.init(Cipher.UNWRAP_MODE, wrappingKey); + Key unwrappedKey = c.unwrap(keyToUnwrap, "AES", Cipher.SECRET_KEY); + return new SecretKeySpec(unwrappedKey.getEncoded(), "AES"); // make sure unwrapped key is "AES" + } + + /** + * Computes the ID of the given AES {@code SecretKey}. + * The ID can be published as it does not leak any information about the key. + * Different {@code SecretKey}s have different IDs with a very high probability. + *

+ * The ID is the ciphertext of a known plaintext, using the AES Wrap cipher algorithm. + * AES Wrap algorithm is deterministic, i.e. encryption using the same key, of the same plaintext, generates the same ciphertext. + * Moreover, the ciphertext reveals no information on the key, and the probability of collision of ciphertexts given different + * keys is statistically negligible. + */ + public static String computeId(SecretKey secretAESKey) throws GeneralSecurityException { + byte[] ciphertextOfKnownPlaintext = wrap(secretAESKey, new SecretKeySpec(KEY_ID_PLAINTEXT, "AES")); + return new String(Base64.getUrlEncoder().withoutPadding().encode(ciphertextOfKnownPlaintext), StandardCharsets.UTF_8); + } + + public static SecretKey generatePasswordBasedKey(SecureString password, String salt) throws GeneralSecurityException { + return generatePasswordBasedKey(password, salt.getBytes(StandardCharsets.UTF_8)); + } + + public static SecretKey generatePasswordBasedKey(SecureString password, byte[] salt) throws GeneralSecurityException { + PBEKeySpec keySpec = new PBEKeySpec(password.getChars(), salt, KDF_ITER, KEY_LENGTH_IN_BYTES * Byte.SIZE); + SecretKeyFactory keyFactory = SecretKeyFactory.getInstance(KDF_ALGO); + SecretKey secretKey = keyFactory.generateSecret(keySpec); + SecretKeySpec secret = new SecretKeySpec(secretKey.getEncoded(), "AES"); + return secret; + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStream.java new file mode 100644 index 0000000000000..0e178ed974472 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStream.java @@ -0,0 +1,547 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Objects; + +/** + * A {@code BufferOnMarkInputStream} adds the {@code mark} and {@code reset} functionality to another input stream. + * All the bytes read or skipped following a {@link #mark(int)} call are also stored in a fixed-size internal array + * so they can be replayed following a {@link #reset()} call. The size of the internal buffer is specified at construction + * time. It is an error (throws {@code IllegalArgumentException}) to specify a larger {@code readlimit} value as an argument + * to a {@code mark} call. + *

+ * Unlike the {@link java.io.BufferedInputStream} this only buffers upon a {@link #mark(int)} call, + * i.e. if {@code mark} is never called this is equivalent to a bare pass-through {@link FilterInputStream}. + * Moreover, this does not buffer in advance, so the amount of bytes read from this input stream, at any time, is equal to the amount + * read from the underlying stream (provided that reset has not been called, in which case bytes are replayed from the internal buffer + * and no bytes are read from the underlying stream). + *

+ * Close will also close the underlying stream and any subsequent {@code read}, {@code skip}, {@code available} and + * {@code reset} calls will throw {@code IOException}s. + *

+ * This is NOT thread-safe, multiple threads sharing a single instance must synchronize access. + */ +public final class BufferOnMarkInputStream extends InputStream { + + /** + * the underlying input stream supplying the actual bytes to read + */ + final InputStream source; + /** + * The fixed capacity buffer used to store the bytes following a {@code mark} call on the input stream, + * and which are then replayed after the {@code reset} call. + * The buffer permits appending bytes which can then be read, possibly multiple times, by also + * supporting the mark and reset operations on its own. + * Reading will not discard the bytes just read. Subsequent reads will return the + * next bytes, but the bytes can be replayed by reading after calling {@code reset}. + * The {@code mark} operation is used to adjust the position of the reset return position to the current + * read position and also discard the bytes read before. + */ + final RingBuffer ringBuffer; // package-protected for tests + /** + * {@code true} when the result of a read or a skip from the underlying source stream must also be stored in the buffer + */ + boolean storeToBuffer; // package-protected for tests + /** + * {@code true} when the returned bytes must come from the buffer and not from the underlying source stream + */ + boolean replayFromBuffer; // package-protected for tests + /** + * {@code true} when this stream is closed and any further calls throw IOExceptions + */ + boolean closed; // package-protected for tests + + /** + * Creates a {@code BufferOnMarkInputStream} that buffers a maximum of {@code bufferSize} elements + * from the wrapped input stream {@code source} in order to support {@code mark} and {@code reset}. + * The {@code bufferSize} is the maximum value for the {@code mark} readlimit argument. + * + * @param source the underlying input buffer + * @param bufferSize the number of bytes that can be stored after a call to mark + */ + public BufferOnMarkInputStream(InputStream source, int bufferSize) { + this.source = source; + this.ringBuffer = new RingBuffer(bufferSize); + this.storeToBuffer = this.replayFromBuffer = false; + this.closed = false; + } + + /** + * Reads up to {@code len} bytes of data into an array of bytes from this + * input stream. If {@code len} is zero, then no bytes are read and {@code 0} + * is returned; otherwise, there is an attempt to read at least one byte. + * If the contents of the stream must be replayed following a {@code reset} + * call, the call will return buffered bytes which have been returned in a previous + * call. Otherwise it forwards the read call to the underlying source input stream. + * If no byte is available because there are no more bytes to replay following + * a reset (if a reset was called) and the underlying stream is exhausted, the + * value {@code -1} is returned; otherwise, at least one byte is read and stored + * into {@code b}, starting at offset {@code off}. + * + * @param b the buffer into which the data is read. + * @param off the start offset in the destination array {@code b} + * @param len the maximum number of bytes read. + * @return the total number of bytes read into the buffer, or + * {@code -1} if there is no more data because the end of + * the stream has been reached. + * @throws NullPointerException If {@code b} is {@code null}. + * @throws IndexOutOfBoundsException If {@code off} is negative, + * {@code len} is negative, or {@code len} is greater than + * {@code b.length - off} + * @throws IOException if this stream has been closed or an I/O error occurs on the underlying stream. + * @see java.io.InputStream#read(byte[], int, int) + */ + @Override + public int read(byte[] b, int off, int len) throws IOException { + ensureOpen(); + Objects.checkFromIndexSize(off, len, b.length); + if (len == 0) { + return 0; + } + // firstly try reading any buffered bytes in case this read call is part of a rewind following a reset call + if (replayFromBuffer) { + int bytesRead = ringBuffer.read(b, off, len); + if (bytesRead == 0) { + // rewinding is complete, no more bytes to replay + replayFromBuffer = false; + } else { + return bytesRead; + } + } + int bytesRead = source.read(b, off, len); + if (bytesRead <= 0) { + return bytesRead; + } + // if mark has been previously called, buffer all the read bytes + if (storeToBuffer) { + if (bytesRead > ringBuffer.getAvailableToWriteByteCount()) { + // can not fully write to buffer + // invalidate mark + storeToBuffer = false; + // empty buffer + ringBuffer.clear(); + } else { + ringBuffer.write(b, off, bytesRead); + } + } + return bytesRead; + } + + /** + * Reads the next byte of data from this input stream. The value + * byte is returned as an {@code int} in the range + * {@code 0} to {@code 255}. If no byte is available + * because the end of the stream has been reached, the value + * {@code -1} is returned. The end of the stream is reached if the + * end of the underlying stream is reached, and reset has not been + * called or there are no more bytes to replay following a reset. + * This method blocks until input data is available, the end of + * the stream is detected, or an exception is thrown. + * + * @return the next byte of data, or {@code -1} if the end of the + * stream is reached. + * @exception IOException if this stream has been closed or an I/O error occurs on the underlying stream. + * @see BufferOnMarkInputStream#read(byte[], int, int) + */ + @Override + public int read() throws IOException { + ensureOpen(); + byte[] arr = new byte[1]; + int readResult = read(arr, 0, arr.length); + if (readResult == -1) { + return -1; + } + return arr[0]; + } + + /** + * Skips over and discards {@code n} bytes of data from the + * input stream. The {@code skip} method may, for a variety of + * reasons, end up skipping over some smaller number of bytes, + * possibly {@code 0}. The actual number of bytes skipped is + * returned. + * + * @param n the number of bytes to be skipped. + * @return the actual number of bytes skipped. + * @throws IOException if this stream is closed, or if {@code in.skip(n)} throws an IOException or, + * in the case that {@code mark} is called, if BufferOnMarkInputStream#read(byte[], int, int) throws an IOException + */ + @Override + public long skip(long n) throws IOException { + ensureOpen(); + if (n <= 0) { + return 0; + } + if (false == storeToBuffer) { + // integrity check of the replayFromBuffer state variable + if (replayFromBuffer) { + throw new IllegalStateException("Reset cannot be called without a preceding mark invocation"); + } + // if mark has not been called, no storing to the buffer is required + return source.skip(n); + } + long remaining = n; + int size = (int) Math.min(2048, remaining); + byte[] skipBuffer = new byte[size]; + while (remaining > 0) { + // skipping translates to a read so that the skipped bytes are stored in the buffer, + // so they can possibly be replayed after a reset + int bytesRead = read(skipBuffer, 0, (int) Math.min(size, remaining)); + if (bytesRead < 0) { + break; + } + remaining -= bytesRead; + } + return n - remaining; + } + + /** + * Returns an estimate of the number of bytes that can be read (or + * skipped over) from this input stream without blocking by the next + * caller of a method for this input stream. The next caller might be + * the same thread or another thread. A single read or skip of this + * many bytes will not block, but may read or skip fewer bytes. + * + * @return an estimate of the number of bytes that can be read (or skipped + * over) from this input stream without blocking. + * @exception IOException if this stream is closed or if {@code in.available()} throws an IOException + */ + @Override + public int available() throws IOException { + ensureOpen(); + int bytesAvailable = 0; + if (replayFromBuffer) { + bytesAvailable += ringBuffer.getAvailableToReadByteCount(); + } + bytesAvailable += source.available(); + return bytesAvailable; + } + + /** + * Tests if this input stream supports the {@code mark} and {@code reset} methods. + * This always returns {@code true}. + */ + @Override + public boolean markSupported() { + return true; + } + + /** + * Marks the current position in this input stream. A subsequent call to + * the {@code reset} method repositions this stream at the last marked + * position so that subsequent reads re-read the same bytes. The bytes + * read or skipped following a {@code mark} call will be buffered internally + * and any previously buffered bytes are discarded. + *

+ * The {@code readlimit} arguments tells this input stream to + * allow that many bytes to be read before the mark position can be + * invalidated. The {@code readlimit} argument value must be smaller than + * the {@code bufferSize} constructor argument value, as returned by + * {@link #getMaxMarkReadlimit()}. + *

+ * The invalidation of the mark position when the read count exceeds the read + * limit is not currently enforced. A mark position is invalidated when the + * read count exceeds the maximum read limit, as returned by + * {@link #getMaxMarkReadlimit()}. + * + * @param readlimit the maximum limit of bytes that can be read before + * the mark position can be invalidated. + * @see BufferOnMarkInputStream#reset() + * @see java.io.InputStream#mark(int) + */ + @Override + public void mark(int readlimit) { + // readlimit is otherwise ignored but this defensively fails if the caller is expecting to be able to mark/reset more than this + // instance can accommodate in the fixed ring buffer + if (readlimit > ringBuffer.getBufferSize()) { + throw new IllegalArgumentException( + "Readlimit value [" + readlimit + "] exceeds the maximum value of [" + ringBuffer.getBufferSize() + "]" + ); + } else if (readlimit < 0) { + throw new IllegalArgumentException("Readlimit value [" + readlimit + "] cannot be negative"); + } + if (closed) { + return; + } + // signal that further read or skipped bytes must be stored to the buffer + storeToBuffer = true; + if (replayFromBuffer) { + // the mark operation while replaying after a reset + // this only discards the previously buffered bytes before the current position + // as well as updates the mark position in the buffer + ringBuffer.mark(); + } else { + // any previously stored bytes are discarded because mark only has to retain bytes from this position on + ringBuffer.clear(); + } + } + + /** + * Repositions this stream to the position at the time the {@code mark} method was last called on this input stream. + * It throws an {@code IOException} if {@code mark} has not yet been called on this instance. + * Internally, this resets the buffer to the last mark position and signals that further reads (and skips) + * on this input stream must return bytes from the buffer and not from the underlying source stream. + * + * @throws IOException if the stream has been closed or the number of bytes + * read since the last mark call exceeded {@link #getMaxMarkReadlimit()} + * @see java.io.InputStream#mark(int) + */ + @Override + public void reset() throws IOException { + ensureOpen(); + if (false == storeToBuffer) { + throw new IOException("Mark not called or has been invalidated"); + } + // signal that further reads/skips must be satisfied from the buffer and not from the underlying source stream + replayFromBuffer = true; + // position the buffer's read pointer back to the last mark position + ringBuffer.reset(); + } + + /** + * Closes this input stream as well as the underlying stream. + * + * @exception IOException if an I/O error occurs while closing the underlying stream. + */ + @Override + public void close() throws IOException { + if (false == closed) { + closed = true; + source.close(); + } + } + + /** + * Returns the maximum value for the {@code readlimit} argument of the {@link #mark(int)} method. + * This is the value of the {@code bufferSize} constructor argument and represents the maximum number + * of bytes that can be internally buffered (so they can be replayed after the reset call). + */ + public int getMaxMarkReadlimit() { + return ringBuffer.getBufferSize(); + } + + private void ensureOpen() throws IOException { + if (closed) { + throw new IOException("Stream has been closed"); + } + } + + /** + * This buffer is used to store all the bytes read or skipped after the last {@link BufferOnMarkInputStream#mark(int)} + * invocation. + *

+ * The latest bytes written to the ring buffer are appended following the previous ones. + * Reading back the bytes advances an internal pointer so that subsequent read calls return subsequent bytes. + * However, read bytes are not discarded. The same bytes can be re-read following the {@link #reset()} invocation. + * {@link #reset()} permits re-reading the bytes since the last {@link #mark()}} call, or since the buffer instance + * has been created or the {@link #clear()} method has been invoked. + * Calling {@link #mark()} will discard all bytes read before, and calling {@link #clear()} will discard all the + * bytes (new bytes must be written otherwise reading will return {@code 0} bytes). + */ + static class RingBuffer { + + /** + * This holds the size of the buffer which is lazily allocated on the first {@link #write(byte[], int, int)} invocation + */ + private final int bufferSize; + /** + * The array used to store the bytes to be replayed upon a reset call. + */ + byte[] buffer; // package-protected for tests + /** + * The start offset (inclusive) for the bytes that must be re-read after a reset call. This offset is advanced + * by invoking {@link #mark()} + */ + int head; // package-protected for tests + /** + * The end offset (exclusive) for the bytes that must be re-read after a reset call. This offset is advanced + * by writing to the ring buffer. + */ + int tail; // package-protected for tests + /** + * The offset of the bytes to return on the next read call. This offset is advanced by reading from the ring buffer. + */ + int position; // package-protected for tests + + /** + * Creates a new ring buffer instance that can store a maximum of {@code bufferSize} bytes. + * More bytes are stored by writing to the ring buffer, and bytes are discarded from the buffer by the + * {@code mark} and {@code reset} method invocations. + */ + RingBuffer(int bufferSize) { + if (bufferSize <= 0) { + throw new IllegalArgumentException("The buffersize constructor argument must be a strictly positive value"); + } + this.bufferSize = bufferSize; + } + + /** + * Returns the maximum number of bytes that this buffer can store. + */ + int getBufferSize() { + return bufferSize; + } + + /** + * Rewind back to the read position of the last {@link #mark()} or {@link #reset()}. The next + * {@link RingBuffer#read(byte[], int, int)} call will return the same bytes that the read + * call after the last {@link #mark()} did. + */ + void reset() { + position = head; + } + + /** + * Mark the current read position. Any previously read bytes are discarded from the ring buffer, + * i.e. they cannot be re-read, but this frees up space for writing other bytes. + * All the following {@link RingBuffer#read(byte[], int, int)} calls will revert back to this position. + */ + void mark() { + head = position; + } + + /** + * Empties out the ring buffer, discarding all the bytes written to it, i.e. any following read calls don't + * return any bytes. + */ + void clear() { + head = position = tail = 0; + } + + /** + * Copies up to {@code len} bytes from the ring buffer and places them in the {@code b} array starting at offset {@code off}. + * This advances the internal pointer of the ring buffer so that a subsequent call will return the following bytes, not the + * same ones (see {@link #reset()}). + * Exactly {@code len} bytes are copied from the ring buffer, but no more than {@link #getAvailableToReadByteCount()}; i.e. + * if {@code len} is greater than the value returned by {@link #getAvailableToReadByteCount()} this reads all the remaining + * available bytes (which could be {@code 0}). + * This returns the exact count of bytes read (the minimum of {@code len} and the value of {@code #getAvailableToReadByteCount}). + * + * @param b the array where to place the bytes read + * @param off the offset in the array where to start placing the bytes read (i.e. first byte is stored at b[off]) + * @param len the maximum number of bytes to read + * @return the number of bytes actually read + */ + int read(byte[] b, int off, int len) { + Objects.requireNonNull(b); + Objects.checkFromIndexSize(off, len, b.length); + if (position == tail || len == 0) { + return 0; + } + // the number of bytes to read + final int readLength; + if (position <= tail) { + readLength = Math.min(len, tail - position); + } else { + // the ring buffer contains elements that wrap around the end of the array + readLength = Math.min(len, buffer.length - position); + } + System.arraycopy(buffer, position, b, off, readLength); + // update the internal pointer with the bytes read + position += readLength; + if (position == buffer.length) { + // pointer wrap around + position = 0; + // also read the remaining bytes after the wrap around + return readLength + read(b, off + readLength, len - readLength); + } + return readLength; + } + + /** + * Copies exactly {@code len} bytes from the array {@code b}, starting at offset {@code off}, into the ring buffer. + * The bytes are appended after the ones written in the same way by a previous call, and are available to + * {@link #read(byte[], int, int)} immediately. + * This throws {@code IllegalArgumentException} if the ring buffer does not have enough space left. + * To get the available capacity left call {@link #getAvailableToWriteByteCount()}. + * + * @param b the array from which to copy the bytes into the ring buffer + * @param off the offset of the first element to copy + * @param len the number of elements to copy + */ + void write(byte[] b, int off, int len) { + Objects.requireNonNull(b); + Objects.checkFromIndexSize(off, len, b.length); + // allocate internal buffer lazily + if (buffer == null && len > 0) { + // "+ 1" for the full-buffer sentinel element + buffer = new byte[bufferSize + 1]; + head = position = tail = 0; + } + if (len > getAvailableToWriteByteCount()) { + throw new IllegalArgumentException("Not enough remaining space in the ring buffer"); + } + while (len > 0) { + final int writeLength; + if (head <= tail) { + writeLength = Math.min(len, buffer.length - tail - (head == 0 ? 1 : 0)); + } else { + writeLength = Math.min(len, head - tail - 1); + } + if (writeLength <= 0) { + throw new IllegalStateException("No space left in the ring buffer"); + } + System.arraycopy(b, off, buffer, tail, writeLength); + tail += writeLength; + off += writeLength; + len -= writeLength; + if (tail == buffer.length) { + tail = 0; + // tail wrap-around overwrites head + if (head == 0) { + throw new IllegalStateException("Possible overflow of the ring buffer"); + } + } + } + } + + /** + * Returns the number of bytes that can be written to this ring buffer before it becomes full + * and will not accept further writes. Be advised that reading (see {@link #read(byte[], int, int)}) + * does not free up space because bytes can be re-read multiple times (see {@link #reset()}); + * ring buffer space can be reclaimed by calling {@link #mark()} or {@link #clear()} + */ + int getAvailableToWriteByteCount() { + if (buffer == null) { + return bufferSize; + } + if (head == tail) { + return buffer.length - 1; + } else if (head < tail) { + return buffer.length - tail + head - 1; + } else { + return head - tail - 1; + } + } + + /** + * Returns the number of bytes that can be read from this ring buffer before it becomes empty + * and all subsequent {@link #read(byte[], int, int)} calls will return {@code 0}. Writing + * more bytes (see {@link #write(byte[], int, int)}) will obviously increase the number of + * bytes available to read. Calling {@link #reset()} will also increase the available byte + * count because the following reads will go over again the same bytes since the last + * {@code mark} call. + */ + int getAvailableToReadByteCount() { + if (buffer == null) { + return 0; + } + if (head <= tail) { + return tail - position; + } else if (position >= head) { + return buffer.length - position + tail; + } else { + return tail - position; + } + } + + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/ChainingInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/ChainingInputStream.java new file mode 100644 index 0000000000000..cc86f772d53c7 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/ChainingInputStream.java @@ -0,0 +1,426 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import java.io.IOException; +import java.io.InputStream; +import java.io.SequenceInputStream; +import java.util.Objects; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.io.Streams; +import org.elasticsearch.core.internal.io.IOUtils; + +/** + * A {@code ChainingInputStream} concatenates multiple component input streams into a + * single input stream. + * It starts reading from the first input stream until it's exhausted, whereupon + * it closes it and starts reading from the next one, until the last component input + * stream is exhausted. + *

+ * The implementing subclass provides the component input streams by implementing the + * {@link #nextComponent(InputStream)} method. This method receives the instance of the + * current input stream, which has been exhausted, and must return the next input stream, + * or {@code null} if there are no more component streams. + * The {@code ChainingInputStream} assumes ownership of the newly generated component input + * stream, i.e. components should not be used by other callers and they will be closed + * when they are exhausted or when the {@code ChainingInputStream} is closed. + *

+ * This stream does support {@code mark} and {@code reset} but it expects that the component + * streams also support it. When {@code mark} is invoked on the chaining input stream, the + * call is forwarded to the current input stream component and a reference to that component + * is stored internally. A {@code reset} invocation on the chaining input stream will then make the + * stored component the current component and will then call the {@code reset} on it. + * The {@link #nextComponent(InputStream)} method must be able to generate the same components + * anew, starting from the component of the {@code reset} call. + * If the component input streams do not support {@code mark}/{@code reset} or + * {@link #nextComponent(InputStream)} cannot generate the same component multiple times, + * the implementing subclass must override {@link #markSupported()} to return {@code false}. + *

+ * The {@code close} call will close the current component input stream and any subsequent {@code read}, + * {@code skip}, {@code available} and {@code reset} calls will throw {@code IOException}s. + *

+ * The {@code ChainingInputStream} is similar in purpose to the {@link java.io.SequenceInputStream}, + * with the addition of {@code mark}/{@code reset} support. + *

+ * This is NOT thread-safe, multiple threads sharing a single instance must synchronize access. + */ +public abstract class ChainingInputStream extends InputStream { + + private static final Logger LOGGER = LogManager.getLogger(ChainingInputStream.class); + + /** + * value for the current input stream when there are no subsequent streams remaining, i.e. when + * {@link #nextComponent(InputStream)} returns {@code null} + */ + protected static final InputStream EXHAUSTED_MARKER = InputStream.nullInputStream(); // protected for tests + + /** + * The instance of the currently in use component input stream, + * i.e. the instance currently servicing the read and skip calls on the {@code ChainingInputStream} + */ + protected InputStream currentIn; // protected for tests + /** + * The instance of the component input stream at the time of the last {@code mark} call. + */ + protected InputStream markIn; // protected for tests + /** + * {@code true} if {@link #close()} has been called; any subsequent {@code read}, {@code skip} + * {@code available} and {@code reset} calls will throw {@code IOException}s + */ + private boolean closed; + + /** + * Returns a new {@link ChainingInputStream} that concatenates the bytes to be read from the first + * input stream with the bytes from the second input stream. The stream arguments must support + * the {@code mark} and {@code reset} operations; otherwise use {@link SequenceInputStream}. + * + * @param first the input stream supplying the first bytes of the returned {@link ChainingInputStream} + * @param second the input stream supplying the bytes after the {@code first} input stream has been exhausted + */ + public static ChainingInputStream chain(InputStream first, InputStream second) { + if (false == Objects.requireNonNull(first).markSupported()) { + throw new IllegalArgumentException("The first component input stream does not support mark"); + } + if (false == Objects.requireNonNull(second).markSupported()) { + throw new IllegalArgumentException("The second component input stream does not support mark"); + } + // components can be reused, and the {@code ChainingInputStream} eagerly closes components after every use + // "first" and "second" are closed when the returned {@code ChainingInputStream} is closed + final InputStream firstComponent = Streams.noCloseStream(first); + final InputStream secondComponent = Streams.noCloseStream(second); + // be sure to remember the start of components because they might be reused + firstComponent.mark(Integer.MAX_VALUE); + secondComponent.mark(Integer.MAX_VALUE); + + return new ChainingInputStream() { + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + // when returning the next component, start from its beginning + firstComponent.reset(); + return firstComponent; + } else if (currentComponentIn == firstComponent) { + // when returning the next component, start from its beginning + secondComponent.reset(); + return secondComponent; + } else if (currentComponentIn == secondComponent) { + return null; + } else { + throw new IllegalStateException("Unexpected component input stream"); + } + } + + @Override + public void close() throws IOException { + IOUtils.close(super::close, first, second); + } + }; + } + + /** + * This method is responsible for generating the component input streams. + * It is passed the current input stream and must return the successive one, + * or {@code null} if the current component is the last one. + * It is passed the {@code null} value at the very start, when no component + * input stream has yet been generated. + * The successive input stream returns the bytes (during reading) that should + * logically follow the bytes that have been previously returned by the passed-in + * {@code currentComponentIn}; i.e. the first {@code read} call on the next + * component returns the byte logically following the last byte of the previous + * component. + * In order to support {@code mark}/{@code reset} this method must be able + * to generate the successive input stream given any of the previously generated + * ones, i.e. implementors must not assume that the passed-in argument is the + * instance last returned by this method. Therefore, implementors must identify + * the bytes that the passed-in component generated and must return a new + * {@code InputStream} which returns the bytes that logically follow, even if + * the same sequence has been previously returned by another component. + * If this is not possible, and the implementation + * can only generate the component input streams once, it must override + * {@link #nextComponent(InputStream)} to return {@code false}. + */ + abstract @Nullable InputStream nextComponent(@Nullable InputStream currentComponentIn) throws IOException; + + /** + * Reads the next byte of data from this chaining input stream. + * The value byte is returned as an {@code int} in the range + * {@code 0} to {@code 255}. If no byte is available + * because the end of the stream has been reached, the value + * {@code -1} is returned. The end of the chaining input stream + * is reached when the end of the last component stream is reached. + * This method blocks until input data is available (possibly + * asking for the next input stream component), the end of + * the stream is detected, or an exception is thrown. + * + * @return the next byte of data, or {@code -1} if the end of the + * stream is reached. + * @exception IOException if this stream has been closed or + * an I/O error occurs on the current component stream. + * @see ChainingInputStream#read(byte[], int, int) + */ + @Override + public int read() throws IOException { + ensureOpen(); + do { + int byteVal = currentIn == null ? -1 : currentIn.read(); + if (byteVal != -1) { + return byteVal; + } + } while (nextIn()); + return -1; + } + + /** + * Reads up to {@code len} bytes of data into an array of bytes from this + * chaining input stream. If {@code len} is zero, then no bytes are read + * and {@code 0} is returned; otherwise, there is an attempt to read at least one byte. + * The {@code read} call is forwarded to the current component input stream. + * If the current component input stream is exhausted, the next one is obtained + * by invoking {@link #nextComponent(InputStream)} and the {@code read} call is + * forwarded to that. If the current component is exhausted + * and there is no subsequent component the value {@code -1} is returned; + * otherwise, at least one byte is read and stored into {@code b}, starting at + * offset {@code off}. + * + * @param b the buffer into which the data is read. + * @param off the start offset in the destination array {@code b} + * @param len the maximum number of bytes read. + * @return the total number of bytes read into the buffer, or + * {@code -1} if there is no more data because the current + * input stream component is exhausted and there is no next one + * {@link #nextComponent(InputStream)} retuns {@code null}. + * @throws NullPointerException If {@code b} is {@code null}. + * @throws IndexOutOfBoundsException If {@code off} is negative, + * {@code len} is negative, or {@code len} is greater than + * {@code b.length - off} + * @throws IOException if this stream has been closed or an I/O error + * occurs on the current component input stream. + * @see java.io.InputStream#read(byte[], int, int) + */ + @Override + public int read(byte[] b, int off, int len) throws IOException { + ensureOpen(); + Objects.checkFromIndexSize(off, len, b.length); + if (len == 0) { + return 0; + } + do { + int bytesRead = currentIn == null ? -1 : currentIn.read(b, off, len); + if (bytesRead != -1) { + return bytesRead; + } + } while (nextIn()); + return -1; + } + + /** + * Skips over and discards {@code n} bytes of data from the + * chaining input stream. If {@code n} is negative or {@code 0}, + * the value {@code 0} is returned and no bytes are skipped. + * The {@code skip} method will skip exactly {@code n} bytes, + * possibly generating the next component input streams and + * recurring to {@code read} if {@code skip} on the current + * component does not make progress (returns 0). + * The actual number of bytes skipped, which can be smaller than + * {@code n}, is returned. + * + * @param n the number of bytes to be skipped. + * @return the actual number of bytes skipped. + * @throws IOException if this stream is closed, or if + * {@code currentComponentIn.skip(n)} throws an IOException + */ + @Override + public long skip(long n) throws IOException { + ensureOpen(); + if (n <= 0) { + return 0; + } + if (currentIn == null) { + nextIn(); + } + long bytesRemaining = n; + while (bytesRemaining > 0) { + long bytesSkipped = currentIn.skip(bytesRemaining); + if (bytesSkipped == 0) { + int byteRead = read(); + if (byteRead == -1) { + break; + } else { + bytesRemaining--; + } + } else { + bytesRemaining -= bytesSkipped; + } + } + return n - bytesRemaining; + } + + /** + * Returns an estimate of the number of bytes that can be read (or + * skipped over) from this chaining input stream without blocking by the next + * caller of a method for this stream. The next caller might be + * the same thread or another thread. A single read or skip of this + * many bytes will not block, but may read or skip fewer bytes. + *

+ * This simply forwards the {@code available} call to the current + * component input stream, so the returned value is a conservative + * lower bound of the available byte count; i.e. it's possible that + * subsequent component streams have available bytes but this method + * only returns the available bytes of the current component. + * + * @return an estimate of the number of bytes that can be read (or skipped + * over) from this input stream without blocking. + * @exception IOException if this stream is closed or if + * {@code currentIn.available()} throws an IOException + */ + @Override + public int available() throws IOException { + ensureOpen(); + if (currentIn == null) { + nextIn(); + } + return currentIn.available(); + } + + /** + * Tests if this chaining input stream supports the {@code mark} and + * {@code reset} methods. By default this returns {@code true} but there + * are some requirements for how components are generated (see + * {@link #nextComponent(InputStream)}), in which case, if the implementer + * cannot satisfy them, it should override this to return {@code false}. + */ + @Override + public boolean markSupported() { + return true; + } + + /** + * Marks the current position in this input stream. A subsequent call to + * the {@code reset} method repositions this stream at the last marked + * position so that subsequent reads re-read the same bytes. + *

+ * The {@code readlimit} arguments tells this input stream to + * allow that many bytes to be read before the mark position can be + * invalidated. + *

+ * The {@code mark} call is forwarded to the current component input + * stream and a reference to it is stored internally. + * + * @param readlimit the maximum limit of bytes that can be read before + * the mark position can be invalidated. + * @see BufferOnMarkInputStream#reset() + * @see java.io.InputStream#mark(int) + */ + @Override + public void mark(int readlimit) { + if (markSupported() && false == closed) { + // closes any previously stored mark input stream + if (markIn != null && markIn != EXHAUSTED_MARKER && currentIn != markIn) { + try { + markIn.close(); + } catch (IOException e) { + // an IOException on a component input stream close is not important + LOGGER.info("IOException while closing a marked component input stream during a mark", e); + } + } + // stores the current input stream to be reused in case of a reset + markIn = currentIn; + if (markIn != null && markIn != EXHAUSTED_MARKER) { + markIn.mark(readlimit); + } + } + } + + /** + * Repositions this stream to the position at the time the + * {@code mark} method was last called on this chaining input stream, + * or at the beginning if the {@code mark} method was never called. + * Subsequent read calls will return the same bytes in the same + * order since the point of the {@code mark} call. Naturally, + * {@code mark} can be invoked at any moment, even after a + * {@code reset}. + *

+ * The previously stored reference to the current component during the + * {@code mark} invocation is made the new current component and then + * the {@code reset} call is forwarded to it. The next internal call to + * {@link #nextComponent(InputStream)} will use this component, so + * the {@link #nextComponent(InputStream)} must not assume monotonous + * arguments. + * + * @throws IOException if the stream has been closed or the number of bytes + * read since the last mark call exceeded the + * {@code readLimit} parameter + * @see java.io.InputStream#mark(int) + */ + @Override + public void reset() throws IOException { + ensureOpen(); + if (false == markSupported()) { + throw new IOException("Mark/reset not supported"); + } + if (currentIn != null && currentIn != EXHAUSTED_MARKER && currentIn != markIn) { + try { + currentIn.close(); + } catch (IOException e) { + // an IOException on a component input stream close is not important + LOGGER.info("IOException while closing the current component input stream during a reset", e); + } + } + currentIn = markIn; + if (currentIn != null && currentIn != EXHAUSTED_MARKER) { + currentIn.reset(); + } + } + + /** + * Closes this chaining input stream, closing the current component stream as well + * as any internally stored reference of a component during a {@code mark} call. + * + * @exception IOException if an I/O error occurs while closing the current or the marked stream. + */ + @Override + public void close() throws IOException { + if (false == closed) { + closed = true; + if (currentIn != null && currentIn != EXHAUSTED_MARKER) { + currentIn.close(); + } + if (markIn != null && markIn != currentIn && markIn != EXHAUSTED_MARKER) { + markIn.close(); + } + } + } + + private void ensureOpen() throws IOException { + if (closed) { + throw new IOException("Stream is closed"); + } + } + + private boolean nextIn() throws IOException { + if (currentIn == EXHAUSTED_MARKER) { + return false; + } + // close the current component, but only if it is not saved because of mark + if (currentIn != null && currentIn != markIn) { + currentIn.close(); + } + currentIn = nextComponent(currentIn); + if (currentIn == null) { + currentIn = EXHAUSTED_MARKER; + return false; + } + if (markSupported() && false == currentIn.markSupported()) { + throw new IllegalStateException("Component input stream must support mark"); + } + return true; + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/CountingInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/CountingInputStream.java new file mode 100644 index 0000000000000..91245beaa8707 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/CountingInputStream.java @@ -0,0 +1,115 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Objects; + +/** + * A {@code CountingInputStream} wraps another input stream and counts the number of bytes + * that have been read or skipped. + *

+ * This input stream does no buffering on its own and only supports {@code mark} and + * {@code reset} if the underlying wrapped stream supports it. + *

+ * If the stream supports {@code mark} and {@code reset} the byte count is also reset to the + * value that it had on the last {@code mark} call, thereby not counting the same bytes twice. + *

+ * If the {@code closeSource} constructor argument is {@code true}, closing this + * stream will also close the wrapped input stream. Apart from closing the wrapped + * stream in this case, the {@code close} method does nothing else. + */ +public final class CountingInputStream extends InputStream { + + private final InputStream source; + private final boolean closeSource; + long count; // package-protected for tests + long mark; // package-protected for tests + boolean closed; // package-protected for tests + + /** + * Wraps another input stream, counting the number of bytes read. + * + * @param source the input stream to be wrapped + * @param closeSource {@code true} if closing this stream will also close the wrapped stream + */ + public CountingInputStream(InputStream source, boolean closeSource) { + this.source = Objects.requireNonNull(source); + this.closeSource = closeSource; + this.count = 0L; + this.mark = -1L; + this.closed = false; + } + + /** Returns the number of bytes read. */ + public long getCount() { + return count; + } + + @Override + public int read() throws IOException { + int result = source.read(); + if (result != -1) { + count++; + } + return result; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + int result = source.read(b, off, len); + if (result != -1) { + count += result; + } + return result; + } + + @Override + public long skip(long n) throws IOException { + long result = source.skip(n); + count += result; + return result; + } + + @Override + public int available() throws IOException { + return source.available(); + } + + @Override + public boolean markSupported() { + return source.markSupported(); + } + + @Override + public synchronized void mark(int readlimit) { + source.mark(readlimit); + mark = count; + } + + @Override + public synchronized void reset() throws IOException { + if (false == source.markSupported()) { + throw new IOException("Mark not supported"); + } + if (mark == -1L) { + throw new IOException("Mark not set"); + } + count = mark; + source.reset(); + } + + @Override + public void close() throws IOException { + if (false == closed) { + closed = true; + if (closeSource) { + source.close(); + } + } + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStream.java new file mode 100644 index 0000000000000..77fc979e00094 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStream.java @@ -0,0 +1,176 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.util.ByteUtils; +import org.elasticsearch.core.internal.io.IOUtils; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.GCMParameterSpec; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.Objects; + +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.GCM_IV_LENGTH_IN_BYTES; +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + +/** + * A {@code DecryptionPacketsInputStream} wraps an encrypted input stream and decrypts + * its contents. This is designed (and tested) to decrypt only the encryption format that + * {@link EncryptionPacketsInputStream} generates. No decrypted bytes are returned before + * they are authenticated. + *

+ * The same parameters, namely {@code secretKey} and {@code packetLength}, + * which have been used during encryption, must also be used for decryption, + * otherwise decryption will fail. + *

+ * This implementation buffers the encrypted packet in memory. The maximum packet size it can + * accommodate is {@link EncryptedRepository#MAX_PACKET_LENGTH_IN_BYTES}. + *

+ * This implementation does not support {@code mark} and {@code reset}. + *

+ * The {@code close} call will close the decryption input stream and any subsequent {@code read}, + * {@code skip}, {@code available} and {@code reset} calls will throw {@code IOException}s. + *

+ * This is NOT thread-safe, multiple threads sharing a single instance must synchronize access. + * + * @see EncryptionPacketsInputStream + */ +public final class DecryptionPacketsInputStream extends ChainingInputStream { + + private final InputStream source; + private final SecretKey secretKey; + private final int packetLength; + private final byte[] packetBuffer; + + private boolean hasNext; + private long counter; + + /** + * Computes and returns the length of the plaintext given the {@code ciphertextLength} and the {@code packetLength} + * used during encryption. + * Each ciphertext packet is prepended by the Initilization Vector and has the Authentication Tag appended. + * Decryption is 1:1, and the ciphertext is not padded, but stripping away the IV and the AT amounts to a shorter + * plaintext compared to the ciphertext. + * + * @see EncryptionPacketsInputStream#getEncryptionLength(long, int) + */ + public static long getDecryptionLength(long ciphertextLength, int packetLength) { + long encryptedPacketLength = packetLength + GCM_TAG_LENGTH_IN_BYTES + GCM_IV_LENGTH_IN_BYTES; + long completePackets = ciphertextLength / encryptedPacketLength; + long decryptedSize = completePackets * packetLength; + if (ciphertextLength % encryptedPacketLength != 0) { + decryptedSize += (ciphertextLength % encryptedPacketLength) - GCM_IV_LENGTH_IN_BYTES - GCM_TAG_LENGTH_IN_BYTES; + } + return decryptedSize; + } + + public DecryptionPacketsInputStream(InputStream source, SecretKey secretKey, int packetLength) { + this.source = Objects.requireNonNull(source); + this.secretKey = Objects.requireNonNull(secretKey); + if (packetLength <= 0 || packetLength >= EncryptedRepository.MAX_PACKET_LENGTH_IN_BYTES) { + throw new IllegalArgumentException("Invalid packet length [" + packetLength + "]"); + } + this.packetLength = packetLength; + this.packetBuffer = new byte[packetLength + GCM_TAG_LENGTH_IN_BYTES]; + this.hasNext = true; + this.counter = EncryptedRepository.PACKET_START_COUNTER; + } + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn != null && currentComponentIn.read() != -1) { + throw new IllegalStateException("Stream for previous packet has not been fully processed"); + } + if (false == hasNext) { + return null; + } + PrefixInputStream packetInputStream = new PrefixInputStream( + source, + packetLength + GCM_IV_LENGTH_IN_BYTES + GCM_TAG_LENGTH_IN_BYTES, + false + ); + int currentPacketLength = decrypt(packetInputStream); + // only the last packet is shorter, so this must be the last packet + if (currentPacketLength != packetLength) { + hasNext = false; + } + return new ByteArrayInputStream(packetBuffer, 0, currentPacketLength); + } + + @Override + public boolean markSupported() { + return false; + } + + @Override + public void mark(int readlimit) {} + + @Override + public void reset() throws IOException { + throw new IOException("Mark/reset not supported"); + } + + @Override + public void close() throws IOException { + IOUtils.close(super::close, source); + } + + private int decrypt(PrefixInputStream packetInputStream) throws IOException { + // read only the IV prefix into the packet buffer + int ivLength = packetInputStream.readNBytes(packetBuffer, 0, GCM_IV_LENGTH_IN_BYTES); + if (ivLength != GCM_IV_LENGTH_IN_BYTES) { + throw new IOException("Packet heading IV error. Unexpected length [" + ivLength + "]."); + } + // extract the counter from the packet IV and validate it (that the packet is in order) + // skips the first 4 bytes in the packet IV, which contain the encryption nonce, which cannot be explicitly validated + // because the nonce is not passed in during decryption, but it is implicitly because it is part of the IV, + // when GCM validates the packet authn tag + long packetIvCounter = ByteUtils.readLongLE(packetBuffer, Integer.BYTES); + if (packetIvCounter != counter) { + throw new IOException("Packet counter mismatch. Expecting [" + counter + "], but got [" + packetIvCounter + "]."); + } + // counter increment for the subsequent packet + counter++; + // counter wrap around + if (counter == EncryptedRepository.PACKET_START_COUNTER) { + throw new IOException("Maximum packet count limit exceeded"); + } + // cipher used to decrypt only the current packetInputStream + Cipher packetCipher = getPacketDecryptionCipher(packetBuffer); + // read the rest of the packet, reusing the packetBuffer + int packetLength = packetInputStream.readNBytes(packetBuffer, 0, packetBuffer.length); + if (packetLength < GCM_TAG_LENGTH_IN_BYTES) { + throw new IOException("Encrypted packet is too short"); + } + try { + // in-place decryption of the whole packet and return decrypted length + return packetCipher.doFinal(packetBuffer, 0, packetLength, packetBuffer); + } catch (ShortBufferException | IllegalBlockSizeException | BadPaddingException e) { + throw new IOException("Exception during packet decryption", e); + } + } + + private Cipher getPacketDecryptionCipher(byte[] packet) throws IOException { + GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(GCM_TAG_LENGTH_IN_BYTES * Byte.SIZE, packet, 0, GCM_IV_LENGTH_IN_BYTES); + try { + Cipher packetCipher = Cipher.getInstance(EncryptedRepository.DATA_ENCRYPTION_SCHEME); + packetCipher.init(Cipher.DECRYPT_MODE, secretKey, gcmParameterSpec); + return packetCipher; + } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | InvalidAlgorithmParameterException e) { + throw new IOException("Exception during packet cipher initialisation", e); + } + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptedRepository.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptedRepository.java new file mode 100644 index 0000000000000..951e157070678 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptedRepository.java @@ -0,0 +1,697 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.IndexCommit; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.RepositoryMetadata; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.CheckedFunction; +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.blobstore.BlobContainer; +import org.elasticsearch.common.blobstore.BlobMetadata; +import org.elasticsearch.common.blobstore.BlobPath; +import org.elasticsearch.common.blobstore.BlobStore; +import org.elasticsearch.common.blobstore.DeleteResult; +import org.elasticsearch.common.blobstore.support.AbstractBlobContainer; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.cache.Cache; +import org.elasticsearch.common.cache.CacheBuilder; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.io.Streams; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.snapshots.IndexShardSnapshotStatus; +import org.elasticsearch.index.store.Store; +import org.elasticsearch.indices.recovery.RecoverySettings; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.repositories.IndexId; +import org.elasticsearch.repositories.RepositoryData; +import org.elasticsearch.repositories.RepositoryException; +import org.elasticsearch.repositories.RepositoryStats; +import org.elasticsearch.repositories.ShardGenerations; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; +import org.elasticsearch.snapshots.SnapshotId; +import org.elasticsearch.snapshots.SnapshotInfo; + +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.NoSuchFileException; +import java.security.GeneralSecurityException; +import java.security.SecureRandom; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.function.Function; +import java.util.function.Supplier; + +public class EncryptedRepository extends BlobStoreRepository { + static final Logger logger = LogManager.getLogger(EncryptedRepository.class); + // the following constants are fixed by definition + static final int GCM_TAG_LENGTH_IN_BYTES = 16; + static final int GCM_IV_LENGTH_IN_BYTES = 12; + static final int AES_BLOCK_LENGTH_IN_BYTES = 128; + // the following constants require careful thought before changing because they will break backwards compatibility + static final String DATA_ENCRYPTION_SCHEME = "AES/GCM/NoPadding"; + static final long PACKET_START_COUNTER = Long.MIN_VALUE; + static final int MAX_PACKET_LENGTH_IN_BYTES = 8 << 20; // 8MB + // this should be smaller than {@code #MAX_PACKET_LENGTH_IN_BYTES} and it's what {@code EncryptionPacketsInputStream} uses + // during encryption and what {@code DecryptionPacketsInputStream} expects during decryption (it is not configurable) + static final int PACKET_LENGTH_IN_BYTES = 64 * (1 << 10); // 64KB + // the path of the blob container holding all the DEKs + // this is relative to the root base path holding the encrypted blobs (i.e. the repository root base path) + static final String DEK_ROOT_CONTAINER = ".encryption-metadata"; // package private for tests + static final int DEK_ID_LENGTH = 22; // {@code org.elasticsearch.common.UUIDS} length + + // the snapshot metadata (residing in the cluster state for the lifetime of the snapshot) + // contains the salted hash of the repository password as present on the master node (which starts the snapshot operation). + // The hash is verified on each data node, before initiating the actual shard files snapshot, as well + // as on the master node that finalizes the snapshot (which could be a different master node from the one that started + // the operation if a master failover occurred during the snapshot). + // This ensures that all participating nodes in the snapshot operation agree on the value of the key encryption key, so that + // all the data included in a snapshot is encrypted using the same password. + static final String PASSWORD_HASH_USER_METADATA_KEY = EncryptedRepository.class.getName() + ".repositoryPasswordHash"; + static final String PASSWORD_SALT_USER_METADATA_KEY = EncryptedRepository.class.getName() + ".repositoryPasswordSalt"; + private static final int DEK_CACHE_WEIGHT = 2048; + + // this is the repository instance to which all blob reads and writes are forwarded to (it stores both the encrypted blobs, as well + // as the associated encrypted DEKs) + private final BlobStoreRepository delegatedRepository; + // every data blob is encrypted with its randomly generated AES key (DEK) + private final Supplier> dekGenerator; + // license is checked before every snapshot operations; protected non-final for tests + protected Supplier licenseStateSupplier; + private final SecureString repositoryPassword; + private final String localRepositoryPasswordHash; + private final String localRepositoryPasswordSalt; + private volatile String validatedLocalRepositoryPasswordHash; + private final Cache dekCache; + + /** + * Returns the byte length (i.e. the storage size) of an encrypted blob, given the length of the blob's plaintext contents. + * + * @see EncryptionPacketsInputStream#getEncryptionLength(long, int) + */ + public static long getEncryptedBlobByteLength(long plaintextBlobByteLength) { + return (long) DEK_ID_LENGTH /* UUID byte length */ + + EncryptionPacketsInputStream.getEncryptionLength(plaintextBlobByteLength, PACKET_LENGTH_IN_BYTES); + } + + protected EncryptedRepository( + RepositoryMetadata metadata, + NamedXContentRegistry namedXContentRegistry, + ClusterService clusterService, + BigArrays bigArrays, + RecoverySettings recoverySettings, + BlobStoreRepository delegatedRepository, + Supplier licenseStateSupplier, + SecureString repositoryPassword + ) throws GeneralSecurityException { + super( + metadata, + namedXContentRegistry, + clusterService, + bigArrays, + recoverySettings, + BlobPath.cleanPath() /* the encrypted repository uses a hardcoded empty + base blob path but the base path setting is honored for the delegated repository */ + ); + this.delegatedRepository = delegatedRepository; + this.dekGenerator = createDEKGenerator(); + this.licenseStateSupplier = licenseStateSupplier; + this.repositoryPassword = repositoryPassword; + // the salt used to generate an irreversible "hash"; it is generated randomly but it's fixed for the lifetime of the + // repository solely for efficiency reasons + this.localRepositoryPasswordSalt = UUIDs.randomBase64UUID(); + // the "hash" of the repository password from the local node is not actually a hash but the ciphertext of a + // known-plaintext using a key derived from the repository password using a random salt + this.localRepositoryPasswordHash = AESKeyUtils.computeId( + AESKeyUtils.generatePasswordBasedKey(repositoryPassword, localRepositoryPasswordSalt) + ); + // a "hash" computed locally is also locally trusted (trivially) + this.validatedLocalRepositoryPasswordHash = this.localRepositoryPasswordHash; + // stores decrypted DEKs; DEKs are reused to encrypt/decrypt multiple independent blobs + this.dekCache = CacheBuilder.builder().setMaximumWeight(DEK_CACHE_WEIGHT).build(); + if (isReadOnly() != delegatedRepository.isReadOnly()) { + throw new RepositoryException( + metadata.name(), + "Unexpected fatal internal error", + new IllegalStateException("The encrypted repository must be read-only iff the delegate repository is read-only") + ); + } + } + + @Override + public RepositoryStats stats() { + return this.delegatedRepository.stats(); + } + + /** + * The repository hook method which populates the snapshot metadata with the salted password hash of the repository on the (master) + * node that starts of the snapshot operation. All the other actions associated with the same snapshot operation will first verify + * that the local repository password checks with the hash from the snapshot metadata. + *

+ * In addition, if the installed license does not comply with the "encrypted snapshots" feature, this method throws an exception, + * which aborts the snapshot operation. + * + * See {@link org.elasticsearch.repositories.Repository#adaptUserMetadata(Map)}. + * + * @param userMetadata the snapshot metadata as received from the calling user + * @return the snapshot metadata containing the salted password hash of the node initializing the snapshot + */ + @Override + public Map adaptUserMetadata(Map userMetadata) { + // because populating the snapshot metadata must be done before the actual snapshot is first initialized, + // we take the opportunity to validate the license and abort if non-compliant + if (false == licenseStateSupplier.get().isAllowed(XPackLicenseState.Feature.ENCRYPTED_SNAPSHOT)) { + throw LicenseUtils.newComplianceException("encrypted snapshots"); + } + Map snapshotUserMetadata = new HashMap<>(); + if (userMetadata != null) { + snapshotUserMetadata.putAll(userMetadata); + } + // fill in the hash of the repository password, which is then checked before every snapshot operation + // (i.e. {@link #snapshotShard} and {@link #finalizeSnapshot}) to ensure that all participating nodes + // in the snapshot operation use the same repository password + snapshotUserMetadata.put(PASSWORD_SALT_USER_METADATA_KEY, localRepositoryPasswordSalt); + snapshotUserMetadata.put(PASSWORD_HASH_USER_METADATA_KEY, localRepositoryPasswordHash); + logger.trace( + "Snapshot metadata for local repository password [{}] and [{}]", + localRepositoryPasswordSalt, + localRepositoryPasswordHash + ); + // do not wrap in Map.of; we have to be able to modify the map (remove the added entries) when finalizing the snapshot + return snapshotUserMetadata; + } + + @Override + public void finalizeSnapshot( + ShardGenerations shardGenerations, + long repositoryStateId, + Metadata clusterMetadata, + SnapshotInfo snapshotInfo, + Version repositoryMetaVersion, + Function stateTransformer, + ActionListener listener + ) { + try { + validateLocalRepositorySecret(snapshotInfo.userMetadata()); + } catch (RepositoryException passwordValidationException) { + listener.onFailure(passwordValidationException); + return; + } finally { + // remove the repository password hash (and salt) from the snapshot metadata so that it is not displayed in the API response + // to the user + snapshotInfo.userMetadata().remove(PASSWORD_HASH_USER_METADATA_KEY); + snapshotInfo.userMetadata().remove(PASSWORD_SALT_USER_METADATA_KEY); + } + super.finalizeSnapshot( + shardGenerations, + repositoryStateId, + clusterMetadata, + snapshotInfo, + repositoryMetaVersion, + stateTransformer, + listener + ); + } + + @Override + public void snapshotShard( + Store store, + MapperService mapperService, + SnapshotId snapshotId, + IndexId indexId, + IndexCommit snapshotIndexCommit, + String shardStateIdentifier, + IndexShardSnapshotStatus snapshotStatus, + Version repositoryMetaVersion, + Map userMetadata, + ActionListener listener + ) { + try { + validateLocalRepositorySecret(userMetadata); + } catch (RepositoryException passwordValidationException) { + listener.onFailure(passwordValidationException); + return; + } + super.snapshotShard( + store, + mapperService, + snapshotId, + indexId, + snapshotIndexCommit, + shardStateIdentifier, + snapshotStatus, + repositoryMetaVersion, + userMetadata, + listener + ); + } + + @Override + protected BlobStore createBlobStore() { + final Supplier> blobStoreDEKGenerator; + if (isReadOnly()) { + // make sure that a read-only repository can't encrypt anything + blobStoreDEKGenerator = () -> { + throw new RepositoryException( + metadata.name(), + "Unexpected fatal internal error", + new IllegalStateException("DEKs are required for encryption but this is a read-only repository") + ); + }; + } else { + blobStoreDEKGenerator = this.dekGenerator; + } + return new EncryptedBlobStore( + delegatedRepository.blobStore(), + delegatedRepository.basePath(), + metadata.name(), + this::generateKEK, + blobStoreDEKGenerator, + dekCache + ); + } + + @Override + protected void doStart() { + this.delegatedRepository.start(); + super.doStart(); + } + + @Override + protected void doStop() { + super.doStop(); + this.delegatedRepository.stop(); + } + + @Override + protected void doClose() { + super.doClose(); + this.delegatedRepository.close(); + } + + private Supplier> createDEKGenerator() throws GeneralSecurityException { + // DEK and DEK Ids MUST be generated randomly (with independent random instances) + // the rand algo is not pinned so that it goes well with various providers (eg FIPS) + // TODO maybe we can make this a setting for rigurous users + final SecureRandom dekSecureRandom = new SecureRandom(); + final SecureRandom dekIdSecureRandom = new SecureRandom(); + final KeyGenerator dekGenerator = KeyGenerator.getInstance(DATA_ENCRYPTION_SCHEME.split("/")[0]); + dekGenerator.init(AESKeyUtils.KEY_LENGTH_IN_BYTES * Byte.SIZE, dekSecureRandom); + return () -> { + final BytesReference dekId = new BytesArray(UUIDs.randomBase64UUID(dekIdSecureRandom)); + final SecretKey dek = dekGenerator.generateKey(); + logger.debug("Repository [{}] generated new DEK [{}]", metadata.name(), dekId); + return new Tuple<>(dekId, dek); + }; + } + + // pkg-private for tests + Tuple generateKEK(String dekId) { + try { + // we rely on the DEK Id being generated randomly so it can be used as a salt + final SecretKey kek = AESKeyUtils.generatePasswordBasedKey(repositoryPassword, dekId); + final String kekId = AESKeyUtils.computeId(kek); + logger.debug("Repository [{}] computed KEK [{}] for DEK [{}]", metadata.name(), kekId, dekId); + return new Tuple<>(kekId, kek); + } catch (GeneralSecurityException e) { + throw new RepositoryException(metadata.name(), "Failure to generate KEK to wrap the DEK [" + dekId + "]", e); + } + } + + /** + * Called before the shard snapshot and finalize operations, on the data and master nodes. This validates that the repository + * password on the master node that started the snapshot operation is identical to the repository password on the local node. + * + * @param snapshotUserMetadata the snapshot metadata containing the repository password hash to assert + * @throws RepositoryException if the repository password hash on the local node mismatches the master's + */ + private void validateLocalRepositorySecret(Map snapshotUserMetadata) throws RepositoryException { + assert snapshotUserMetadata != null; + assert snapshotUserMetadata.get(PASSWORD_HASH_USER_METADATA_KEY) instanceof String; + final String masterRepositoryPasswordId = (String) snapshotUserMetadata.get(PASSWORD_HASH_USER_METADATA_KEY); + if (false == masterRepositoryPasswordId.equals(validatedLocalRepositoryPasswordHash)) { + assert snapshotUserMetadata.get(PASSWORD_SALT_USER_METADATA_KEY) instanceof String; + final String masterRepositoryPasswordIdSalt = (String) snapshotUserMetadata.get(PASSWORD_SALT_USER_METADATA_KEY); + final String computedRepositoryPasswordId; + try { + computedRepositoryPasswordId = AESKeyUtils.computeId( + AESKeyUtils.generatePasswordBasedKey(repositoryPassword, masterRepositoryPasswordIdSalt) + ); + } catch (Exception e) { + throw new RepositoryException(metadata.name(), "Unexpected fatal internal error", e); + } + if (computedRepositoryPasswordId.equals(masterRepositoryPasswordId)) { + this.validatedLocalRepositoryPasswordHash = computedRepositoryPasswordId; + } else { + throw new RepositoryException( + metadata.name(), + "Repository password mismatch. The local node's repository password, from the keystore setting [" + + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace( + EncryptedRepositoryPlugin.PASSWORD_NAME_SETTING.get(metadata.settings()) + ).getKey() + + "], is different compared to the elected master node's which started the snapshot operation" + ); + } + } + } + + // pkg-private for tests + static final class EncryptedBlobStore implements BlobStore { + private final BlobStore delegatedBlobStore; + private final BlobPath delegatedBasePath; + private final String repositoryName; + private final Function> getKEKforDEK; + private final Cache dekCache; + private final CheckedSupplier singleUseDEKSupplier; + + EncryptedBlobStore( + BlobStore delegatedBlobStore, + BlobPath delegatedBasePath, + String repositoryName, + Function> getKEKforDEK, + Supplier> dekGenerator, + Cache dekCache + ) { + this.delegatedBlobStore = delegatedBlobStore; + this.delegatedBasePath = delegatedBasePath; + this.repositoryName = repositoryName; + this.getKEKforDEK = getKEKforDEK; + this.dekCache = dekCache; + this.singleUseDEKSupplier = SingleUseKey.createSingleUseKeySupplier(() -> { + Tuple newDEK = dekGenerator.get(); + // store the newly generated DEK before making it available + storeDEK(newDEK.v1().utf8ToString(), newDEK.v2()); + return newDEK; + }); + } + + // pkg-private for tests + SecretKey getDEKById(String dekId) throws IOException { + try { + return dekCache.computeIfAbsent(dekId, ignored -> loadDEK(dekId)); + } catch (ExecutionException e) { + // some exception types are to be expected + if (e.getCause() instanceof IOException) { + throw (IOException) e.getCause(); + } else if (e.getCause() instanceof ElasticsearchException) { + throw (ElasticsearchException) e.getCause(); + } else { + throw new RepositoryException(repositoryName, "Unexpected exception retrieving DEK [" + dekId + "]", e); + } + } + } + + private SecretKey loadDEK(String dekId) throws IOException { + final BlobPath dekBlobPath = delegatedBasePath.add(DEK_ROOT_CONTAINER).add(dekId); + logger.debug("Repository [{}] loading wrapped DEK [{}] from blob path {}", repositoryName, dekId, dekBlobPath); + final BlobContainer dekBlobContainer = delegatedBlobStore.blobContainer(dekBlobPath); + final Tuple kekTuple = getKEKforDEK.apply(dekId); + final String kekId = kekTuple.v1(); + final SecretKey kek = kekTuple.v2(); + logger.trace("Repository [{}] using KEK [{}] to unwrap DEK [{}]", repositoryName, kekId, dekId); + final byte[] encryptedDEKBytes = new byte[AESKeyUtils.WRAPPED_KEY_LENGTH_IN_BYTES]; + try (InputStream encryptedDEKInputStream = dekBlobContainer.readBlob(kekId)) { + final int bytesRead = Streams.readFully(encryptedDEKInputStream, encryptedDEKBytes); + if (bytesRead != AESKeyUtils.WRAPPED_KEY_LENGTH_IN_BYTES) { + throw new RepositoryException( + repositoryName, + "Wrapped DEK [" + dekId + "] has smaller length [" + bytesRead + "] than expected" + ); + } + if (encryptedDEKInputStream.read() != -1) { + throw new RepositoryException(repositoryName, "Wrapped DEK [" + dekId + "] is larger than expected"); + } + } catch (NoSuchFileException e) { + // do NOT throw IOException when the DEK does not exist, as this is a decryption problem, and IOExceptions + // can move the repository in the corrupted state + throw new ElasticsearchException( + "Failure to read and decrypt DEK [" + + dekId + + "] from " + + dekBlobContainer.path() + + ". Most likely the repository password is incorrect, where previous " + + "snapshots have used a different password.", + e + ); + } + logger.trace("Repository [{}] successfully read DEK [{}] from path {} {}", repositoryName, dekId, dekBlobPath, kekId); + try { + final SecretKey dek = AESKeyUtils.unwrap(kek, encryptedDEKBytes); + logger.debug("Repository [{}] successfully loaded DEK [{}] from path {} {}", repositoryName, dekId, dekBlobPath, kekId); + return dek; + } catch (GeneralSecurityException e) { + throw new RepositoryException( + repositoryName, + "Failure to AES unwrap the DEK [" + + dekId + + "]. " + + "Most likely the encryption metadata in the repository has been corrupted", + e + ); + } + } + + // pkg-private for tests + void storeDEK(String dekId, SecretKey dek) throws IOException { + final BlobPath dekBlobPath = delegatedBasePath.add(DEK_ROOT_CONTAINER).add(dekId); + logger.debug("Repository [{}] storing wrapped DEK [{}] under blob path {}", repositoryName, dekId, dekBlobPath); + final BlobContainer dekBlobContainer = delegatedBlobStore.blobContainer(dekBlobPath); + final Tuple kek = getKEKforDEK.apply(dekId); + logger.trace("Repository [{}] using KEK [{}] to wrap DEK [{}]", repositoryName, kek.v1(), dekId); + final byte[] encryptedDEKBytes; + try { + encryptedDEKBytes = AESKeyUtils.wrap(kek.v2(), dek); + if (encryptedDEKBytes.length != AESKeyUtils.WRAPPED_KEY_LENGTH_IN_BYTES) { + throw new RepositoryException( + repositoryName, + "Wrapped DEK [" + dekId + "] has unexpected length [" + encryptedDEKBytes.length + "]" + ); + } + } catch (GeneralSecurityException e) { + // throw unchecked ElasticsearchException; IOExceptions are interpreted differently and can move the repository in the + // corrupted state + throw new RepositoryException(repositoryName, "Failure to AES wrap the DEK [" + dekId + "]", e); + } + logger.trace("Repository [{}] successfully wrapped DEK [{}]", repositoryName, dekId); + dekBlobContainer.writeBlobAtomic(kek.v1(), new BytesArray(encryptedDEKBytes), true); + logger.debug("Repository [{}] successfully stored DEK [{}] under path {} {}", repositoryName, dekId, dekBlobPath, kek.v1()); + } + + @Override + public BlobContainer blobContainer(BlobPath path) { + final Iterator pathIterator = path.iterator(); + BlobPath delegatedBlobContainerPath = delegatedBasePath; + while (pathIterator.hasNext()) { + delegatedBlobContainerPath = delegatedBlobContainerPath.add(pathIterator.next()); + } + final BlobContainer delegatedBlobContainer = delegatedBlobStore.blobContainer(delegatedBlobContainerPath); + return new EncryptedBlobContainer(path, repositoryName, delegatedBlobContainer, singleUseDEKSupplier, this::getDEKById); + } + + @Override + public void close() { + // do NOT close delegatedBlobStore; it will be closed when the inner delegatedRepository is closed + } + } + + private static final class EncryptedBlobContainer extends AbstractBlobContainer { + private final String repositoryName; + private final BlobContainer delegatedBlobContainer; + // supplier for the DEK used for encryption (snapshot) + private final CheckedSupplier singleUseDEKSupplier; + // retrieves the DEK required for decryption (restore) + private final CheckedFunction getDEKById; + + EncryptedBlobContainer( + BlobPath path, // this path contains the {@code EncryptedRepository#basePath} which, importantly, is empty + String repositoryName, + BlobContainer delegatedBlobContainer, + CheckedSupplier singleUseDEKSupplier, + CheckedFunction getDEKById + ) { + super(path); + this.repositoryName = repositoryName; + final String rootPathElement = path.iterator().hasNext() ? path.iterator().next() : null; + if (DEK_ROOT_CONTAINER.equals(rootPathElement)) { + throw new RepositoryException(repositoryName, "Cannot descend into the DEK blob container " + path); + } + this.delegatedBlobContainer = delegatedBlobContainer; + this.singleUseDEKSupplier = singleUseDEKSupplier; + this.getDEKById = getDEKById; + } + + @Override + public boolean blobExists(String blobName) throws IOException { + return delegatedBlobContainer.blobExists(blobName); + } + + /** + * Returns a new {@link InputStream} for the given {@code blobName} that can be used to read the contents of the blob. + * The returned {@code InputStream} transparently handles the decryption of the blob contents, by first working out + * the blob name of the associated DEK id, reading and decrypting the DEK (given the repository password, unless the DEK is + * already cached because it had been used for other blobs before), and lastly reading and decrypting the data blob, + * in a streaming fashion, by employing the {@link DecryptionPacketsInputStream}. + * The {@code DecryptionPacketsInputStream} does not return un-authenticated data. + * + * @param blobName The name of the blob to get an {@link InputStream} for. + */ + @Override + public InputStream readBlob(String blobName) throws IOException { + // This MIGHT require two concurrent readBlob connections if the DEK is not already in the cache and if the encrypted blob + // is large enough so that the underlying network library keeps the connection open after reading the prepended DEK ID. + // Arguably this is a problem only under lab conditions, when the storage service is saturated only by the first read + // connection of the pair, so that the second read connection (for the DEK) can not be fulfilled. + // In this case the second connection will time-out which will trigger the closing of the first one, therefore + // allowing other pair connections to complete. + // In this situation the restore process should slowly make headway, albeit under read-timeout exceptions + final InputStream encryptedDataInputStream = delegatedBlobContainer.readBlob(blobName); + try { + // read the DEK Id (fixed length) which is prepended to the encrypted blob + final byte[] dekIdBytes = new byte[DEK_ID_LENGTH]; + final int bytesRead = Streams.readFully(encryptedDataInputStream, dekIdBytes); + if (bytesRead != DEK_ID_LENGTH) { + throw new RepositoryException(repositoryName, "The encrypted blob [" + blobName + "] is too small [" + bytesRead + "]"); + } + final String dekId = new String(dekIdBytes, StandardCharsets.UTF_8); + // might open a connection to read and decrypt the DEK, but most likely it will be served from cache + final SecretKey dek = getDEKById.apply(dekId); + // read and decrypt the rest of the blob + return new DecryptionPacketsInputStream(encryptedDataInputStream, dek, PACKET_LENGTH_IN_BYTES); + } catch (Exception e) { + try { + encryptedDataInputStream.close(); + } catch (IOException closeEx) { + e.addSuppressed(closeEx); + } + throw e; + } + } + + @Override + public InputStream readBlob(String blobName, long position, long length) throws IOException { + throw new UnsupportedOperationException("Not yet implemented"); + } + + /** + * Reads the blob content from the input stream and writes it to the container in a new blob with the given name. + * If {@code failIfAlreadyExists} is {@code true} and a blob with the same name already exists, the write operation will fail; + * otherwise, if {@code failIfAlreadyExists} is {@code false} the blob is overwritten. + * The contents are encrypted in a streaming fashion. The DEK (encryption key) is randomly generated and reused for encrypting + * subsequent blobs such that the same IV is not reused together with the same key. + * The DEK encryption key is separately stored in a different blob, which is encrypted with the repository key. + * + * @param blobName + * The name of the blob to write the contents of the input stream to. + * @param inputStream + * The input stream from which to retrieve the bytes to write to the blob. + * @param blobSize + * The size of the blob to be written, in bytes. The actual number of bytes written to the storage service is larger + * because of encryption and authentication overhead. It is implementation dependent whether this value is used + * in writing the blob to the repository. + * @param failIfAlreadyExists + * whether to throw a FileAlreadyExistsException if the given blob already exists + */ + @Override + public void writeBlob(String blobName, InputStream inputStream, long blobSize, boolean failIfAlreadyExists) throws IOException { + // reuse, but possibly generate and store a new DEK + final SingleUseKey singleUseNonceAndDEK = singleUseDEKSupplier.get(); + final BytesReference dekIdBytes = singleUseNonceAndDEK.getKeyId(); + if (dekIdBytes.length() != DEK_ID_LENGTH) { + throw new RepositoryException( + repositoryName, + "Unexpected fatal internal error", + new IllegalStateException("Unexpected DEK Id length [" + dekIdBytes.length() + "]") + ); + } + final long encryptedBlobSize = getEncryptedBlobByteLength(blobSize); + try ( + InputStream encryptedInputStream = ChainingInputStream.chain( + dekIdBytes.streamInput(), + new EncryptionPacketsInputStream( + inputStream, + singleUseNonceAndDEK.getKey(), + singleUseNonceAndDEK.getNonce(), + PACKET_LENGTH_IN_BYTES + ) + ) + ) { + delegatedBlobContainer.writeBlob(blobName, encryptedInputStream, encryptedBlobSize, failIfAlreadyExists); + } + } + + @Override + public void writeBlobAtomic(String blobName, BytesReference bytes, boolean failIfAlreadyExists) throws IOException { + // the encrypted repository does not offer an alternative implementation for atomic writes + // fallback to regular write + writeBlob(blobName, bytes, failIfAlreadyExists); + } + + @Override + public DeleteResult delete() throws IOException { + return delegatedBlobContainer.delete(); + } + + @Override + public void deleteBlobsIgnoringIfNotExists(List blobNames) throws IOException { + delegatedBlobContainer.deleteBlobsIgnoringIfNotExists(blobNames); + } + + @Override + public Map listBlobs() throws IOException { + return delegatedBlobContainer.listBlobs(); + } + + @Override + public Map listBlobsByPrefix(String blobNamePrefix) throws IOException { + return delegatedBlobContainer.listBlobsByPrefix(blobNamePrefix); + } + + @Override + public Map children() throws IOException { + final Map childEncryptedBlobContainers = delegatedBlobContainer.children(); + final Map resultBuilder = new HashMap<>(childEncryptedBlobContainers.size()); + for (Map.Entry childBlobContainer : childEncryptedBlobContainers.entrySet()) { + if (childBlobContainer.getKey().equals(DEK_ROOT_CONTAINER) && false == path().iterator().hasNext()) { + // do not descend into the DEK blob container + continue; + } + // get an encrypted blob container for each child + // Note that the encryption metadata blob container might be missing + resultBuilder.put( + childBlobContainer.getKey(), + new EncryptedBlobContainer( + path().add(childBlobContainer.getKey()), + repositoryName, + childBlobContainer.getValue(), + singleUseDEKSupplier, + getDEKById + ) + ); + } + return Map.copyOf(resultBuilder); + } + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryPlugin.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryPlugin.java new file mode 100644 index 0000000000000..020507cbbb2ab --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryPlugin.java @@ -0,0 +1,199 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.Build; +import org.elasticsearch.cluster.metadata.RepositoryMetadata; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureSetting; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.env.Environment; +import org.elasticsearch.indices.recovery.RecoverySettings; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.RepositoryPlugin; +import org.elasticsearch.repositories.Repository; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; +import org.elasticsearch.xpack.core.XPackPlugin; + +import java.security.GeneralSecurityException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; + +public class EncryptedRepositoryPlugin extends Plugin implements RepositoryPlugin { + + private static final Boolean ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED; + static { + final String property = System.getProperty("es.encrypted_repository_feature_flag_registered"); + if (Build.CURRENT.isSnapshot() && property != null) { + throw new IllegalArgumentException("es.encrypted_repository_feature_flag_registered is only supported in non-snapshot builds"); + } + if ("true".equals(property)) { + ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED = true; + } else if ("false".equals(property)) { + ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED = false; + } else if (property == null) { + ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED = null; + } else { + throw new IllegalArgumentException( + "expected es.encrypted_repository_feature_flag_registered to be unset or [true|false] but was [" + property + "]" + ); + } + } + + static final Logger logger = LogManager.getLogger(EncryptedRepositoryPlugin.class); + static final String REPOSITORY_TYPE_NAME = "encrypted"; + // TODO add at least hdfs, and investigate supporting all `BlobStoreRepository` implementations + static final List SUPPORTED_ENCRYPTED_TYPE_NAMES = Arrays.asList("fs", "gcs", "azure", "s3"); + static final Setting.AffixSetting ENCRYPTION_PASSWORD_SETTING = Setting.affixKeySetting( + "repository.encrypted.", + "password", + key -> SecureSetting.secureString(key, null) + ); + static final Setting DELEGATE_TYPE_SETTING = Setting.simpleString("delegate_type", ""); + static final Setting PASSWORD_NAME_SETTING = Setting.simpleString("password_name", ""); + + // "protected" because it is overloaded for tests + protected XPackLicenseState getLicenseState() { + return XPackPlugin.getSharedLicenseState(); + } + + @Override + public List> getSettings() { + return List.of(ENCRYPTION_PASSWORD_SETTING); + } + + @Override + public Map getRepositories( + Environment env, + NamedXContentRegistry registry, + ClusterService clusterService, + BigArrays bigArrays, + RecoverySettings recoverySettings + ) { + // load all the passwords from the keystore in memory because the keystore is not readable when the repository is created + final Map repositoryPasswordsMapBuilder = new HashMap<>(); + for (String passwordName : ENCRYPTION_PASSWORD_SETTING.getNamespaces(env.settings())) { + Setting passwordSetting = ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(passwordName); + repositoryPasswordsMapBuilder.put(passwordName, passwordSetting.get(env.settings())); + logger.debug("Loaded repository password [{}] from the node keystore", passwordName); + } + final Map repositoryPasswordsMap = Map.copyOf(repositoryPasswordsMapBuilder); + + if (false == Build.CURRENT.isSnapshot() + && (ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED == null || ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED == false)) { + return Map.of(); + } + + return Collections.singletonMap(REPOSITORY_TYPE_NAME, new Repository.Factory() { + + @Override + public Repository create(RepositoryMetadata metadata) { + throw new UnsupportedOperationException(); + } + + @Override + public Repository create(RepositoryMetadata metadata, Function typeLookup) throws Exception { + final String delegateType = DELEGATE_TYPE_SETTING.get(metadata.settings()); + if (Strings.hasLength(delegateType) == false) { + throw new IllegalArgumentException("Repository setting [" + DELEGATE_TYPE_SETTING.getKey() + "] must be set"); + } + if (REPOSITORY_TYPE_NAME.equals(delegateType)) { + throw new IllegalArgumentException( + "Cannot encrypt an already encrypted repository. [" + + DELEGATE_TYPE_SETTING.getKey() + + "] must not be equal to [" + + REPOSITORY_TYPE_NAME + + "]" + ); + } + final Repository.Factory factory = typeLookup.apply(delegateType); + if (null == factory || false == SUPPORTED_ENCRYPTED_TYPE_NAMES.contains(delegateType)) { + throw new IllegalArgumentException( + "Unsupported delegate repository type [" + delegateType + "] for setting [" + DELEGATE_TYPE_SETTING.getKey() + "]" + ); + } + final String repositoryPasswordName = PASSWORD_NAME_SETTING.get(metadata.settings()); + if (Strings.hasLength(repositoryPasswordName) == false) { + throw new IllegalArgumentException("Repository setting [" + PASSWORD_NAME_SETTING.getKey() + "] must be set"); + } + final SecureString repositoryPassword = repositoryPasswordsMap.get(repositoryPasswordName); + if (repositoryPassword == null) { + throw new IllegalArgumentException( + "Secure setting [" + + ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryPasswordName).getKey() + + "] must be set" + ); + } + final Repository delegatedRepository = factory.create( + new RepositoryMetadata(metadata.name(), delegateType, metadata.settings()) + ); + if (false == (delegatedRepository instanceof BlobStoreRepository) || delegatedRepository instanceof EncryptedRepository) { + throw new IllegalArgumentException("Unsupported delegate repository type [" + DELEGATE_TYPE_SETTING.getKey() + "]"); + } + if (false == getLicenseState().checkFeature(XPackLicenseState.Feature.ENCRYPTED_SNAPSHOT)) { + logger.warn( + new ParameterizedMessage( + "Encrypted snapshots are not allowed for the currently installed license [{}]." + + " Snapshots to the [{}] encrypted repository are not permitted." + + " All the other operations, including restore, work without restrictions.", + getLicenseState().getOperationMode().description(), + metadata.name() + ), + LicenseUtils.newComplianceException("encrypted snapshots") + ); + } + return createEncryptedRepository( + metadata, + registry, + clusterService, + bigArrays, + recoverySettings, + (BlobStoreRepository) delegatedRepository, + () -> getLicenseState(), + repositoryPassword + ); + } + }); + } + + // protected for tests + protected EncryptedRepository createEncryptedRepository( + RepositoryMetadata metadata, + NamedXContentRegistry registry, + ClusterService clusterService, + BigArrays bigArrays, + RecoverySettings recoverySettings, + BlobStoreRepository delegatedRepository, + Supplier licenseStateSupplier, + SecureString repoPassword + ) throws GeneralSecurityException { + return new EncryptedRepository( + metadata, + registry, + clusterService, + bigArrays, + recoverySettings, + delegatedRepository, + licenseStateSupplier, + repoPassword + ); + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStream.java new file mode 100644 index 0000000000000..a08ea4216c94d --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStream.java @@ -0,0 +1,198 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.core.internal.io.IOUtils; + +import javax.crypto.Cipher; +import javax.crypto.CipherInputStream; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import javax.crypto.spec.GCMParameterSpec; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.SequenceInputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.Objects; + +/** + * An {@code EncryptionPacketsInputStream} wraps another input stream and encrypts its contents. + * The method of encryption is AES/GCM/NoPadding, which is a type of authenticated encryption. + * The encryption works packet wise, i.e. the stream is segmented into fixed-size byte packets + * which are separately encrypted using a unique {@link Cipher}. As an exception, only the last + * packet will have a different size, possibly zero. Note that the encrypted packets are + * larger compared to the plaintext packets, because they contain a 16 byte length trailing + * authentication tag. The resulting encrypted and authenticated packets are assembled back into + * the resulting stream. + *

+ * The packets are encrypted using the same {@link SecretKey} but using a different initialization + * vector. The IV of each packet is 12 bytes wide and is comprised of a 4-byte integer {@code nonce}, + * the same for every packet in the stream, and a monotonically increasing 8-byte integer counter. + * The caller must assure that the same {@code nonce} is not reused for other encrypted streams + * using the same {@code secretKey}. The counter from the IV identifies the position of the packet + * in the encrypted stream, so that packets cannot be reordered without breaking the decryption. + * When assembling the encrypted stream, the IV is prepended to the corresponding packet's ciphertext. + *

+ * The packet length is preferably a large multiple (typically 128) of the AES block size (128 bytes), + * but any positive integer value smaller than {@link EncryptedRepository#MAX_PACKET_LENGTH_IN_BYTES} + * is valid. A larger packet length incurs smaller relative size overhead because the 12 byte wide IV + * and the 16 byte wide authentication tag are constant no matter the packet length. A larger packet + * length also exposes more opportunities for the JIT compilation of the AES encryption loop. But + * {@code mark} will buffer up to packet length bytes, and, more importantly, decryption might + * need to allocate a memory buffer the size of the packet in order to assure that no un-authenticated + * decrypted ciphertext is returned. The decryption procedure is the primary factor that limits the + * packet length. + *

+ * This input stream supports the {@code mark} and {@code reset} operations, but only if the wrapped + * stream supports them as well. A {@code mark} call will trigger the memory buffering of the current + * packet and will also trigger a {@code mark} call on the wrapped input stream on the next + * packet boundary. Upon a {@code reset} call, the buffered packet will be replayed and new packets + * will be generated starting from the marked packet boundary on the wrapped stream. + *

+ * The {@code close} call will close the encryption input stream and any subsequent {@code read}, + * {@code skip}, {@code available} and {@code reset} calls will throw {@code IOException}s. + *

+ * This is NOT thread-safe, multiple threads sharing a single instance must synchronize access. + * + * @see DecryptionPacketsInputStream + */ +public final class EncryptionPacketsInputStream extends ChainingInputStream { + + private final SecretKey secretKey; + private final int packetLength; + private final ByteBuffer packetIv; + private final int encryptedPacketLength; + + final InputStream source; // package-protected for tests + long counter; // package-protected for tests + Long markCounter; // package-protected for tests + int markSourceOnNextPacket; // package-protected for tests + + /** + * Computes and returns the length of the ciphertext given the {@code plaintextLength} and the {@code packetLength} + * used during encryption. + * The plaintext is segmented into packets of equal {@code packetLength} length, with the exception of the last + * packet which is shorter and can have a length of {@code 0}. Encryption is packet-wise and is 1:1, with no padding. + * But each encrypted packet is prepended by the Initilization Vector and appended the Authentication Tag, including + * the last packet, so when pieced together will amount to a longer resulting ciphertext. + * + * @see DecryptionPacketsInputStream#getDecryptionLength(long, int) + */ + public static long getEncryptionLength(long plaintextLength, int packetLength) { + return plaintextLength + (plaintextLength / packetLength + 1) * (EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES + + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES); + } + + public EncryptionPacketsInputStream(InputStream source, SecretKey secretKey, int nonce, int packetLength) { + this.source = Objects.requireNonNull(source); + this.secretKey = Objects.requireNonNull(secretKey); + if (packetLength <= 0 || packetLength >= EncryptedRepository.MAX_PACKET_LENGTH_IN_BYTES) { + throw new IllegalArgumentException("Invalid packet length [" + packetLength + "]"); + } + this.packetLength = packetLength; + this.packetIv = ByteBuffer.allocate(EncryptedRepository.GCM_IV_LENGTH_IN_BYTES).order(ByteOrder.LITTLE_ENDIAN); + // nonce takes the first 4 bytes of the IV + this.packetIv.putInt(0, nonce); + this.encryptedPacketLength = packetLength + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + this.counter = EncryptedRepository.PACKET_START_COUNTER; + this.markCounter = null; + this.markSourceOnNextPacket = -1; + } + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + // the last packet input stream is the only one shorter than encryptedPacketLength + if (currentComponentIn != null && ((CountingInputStream) currentComponentIn).getCount() < encryptedPacketLength) { + // there are no more packets + return null; + } + // If the enclosing stream has a mark set, + // then apply it to the source input stream when we reach a packet boundary + if (markSourceOnNextPacket != -1) { + source.mark(markSourceOnNextPacket); + markSourceOnNextPacket = -1; + } + // create the new packet + InputStream encryptionInputStream = new PrefixInputStream(source, packetLength, false); + // the counter takes up the last 8 bytes of the packet IV (12 byte wide) + // the first 4 bytes are used by the nonce (which is the same for every packet IV) + packetIv.putLong(Integer.BYTES, counter++); + // counter wrap around + if (counter == EncryptedRepository.PACKET_START_COUNTER) { + throw new IOException("Maximum packet count limit exceeded"); + } + Cipher packetCipher = getPacketEncryptionCipher(secretKey, packetIv.array()); + encryptionInputStream = new CipherInputStream(encryptionInputStream, packetCipher); + encryptionInputStream = new SequenceInputStream(new ByteArrayInputStream(packetIv.array()), encryptionInputStream); + encryptionInputStream = new BufferOnMarkInputStream(encryptionInputStream, encryptedPacketLength); + return new CountingInputStream(encryptionInputStream, false); + } + + // remove after https://github.com/elastic/elasticsearch/pull/66769 is merged in + @Override + public int available() throws IOException { + return 0; + } + + @Override + public boolean markSupported() { + return source.markSupported(); + } + + @Override + public void mark(int readlimit) { + if (markSupported()) { + if (readlimit <= 0) { + throw new IllegalArgumentException("Mark readlimit must be a positive integer"); + } + // handles the packet-wise part of the marking operation + super.mark(encryptedPacketLength); + // saves the counter used to generate packet IVs + markCounter = counter; + // stores the flag used to mark the source input stream at packet boundary + markSourceOnNextPacket = readlimit; + } + } + + @Override + public void reset() throws IOException { + if (false == markSupported()) { + throw new IOException("Mark/reset not supported"); + } + if (markCounter == null) { + throw new IOException("Mark no set"); + } + super.reset(); + counter = markCounter; + if (markSourceOnNextPacket == -1) { + source.reset(); + } + } + + @Override + public void close() throws IOException { + IOUtils.close(super::close, source); + } + + private static Cipher getPacketEncryptionCipher(SecretKey secretKey, byte[] packetIv) throws IOException { + GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES * Byte.SIZE, packetIv); + try { + Cipher packetCipher = Cipher.getInstance(EncryptedRepository.DATA_ENCRYPTION_SCHEME); + packetCipher.init(Cipher.ENCRYPT_MODE, secretKey, gcmParameterSpec); + return packetCipher; + } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | InvalidAlgorithmParameterException e) { + throw new IOException(e); + } + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/PrefixInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/PrefixInputStream.java new file mode 100644 index 0000000000000..873ffc319e176 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/PrefixInputStream.java @@ -0,0 +1,150 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Objects; + +/** + * A {@code PrefixInputStream} wraps another input stream and exposes + * only the first bytes of it. Reading from the wrapping + * {@code PrefixInputStream} consumes the underlying stream. The stream + * is exhausted when {@code prefixLength} bytes have been read, or the underlying + * stream is exhausted before that. + *

+ * Only if the {@code closeSource} constructor argument is {@code true}, the + * closing of this stream will also close the underlying input stream. + * Any subsequent {@code read}, {@code skip} and {@code available} calls + * will throw {@code IOException}s. + */ +public final class PrefixInputStream extends InputStream { + + /** + * The underlying stream of which only a prefix is returned + */ + private final InputStream source; + /** + * The length in bytes of the prefix. + * This is the maximum number of bytes that can be read from this stream, + * but fewer bytes can be read if the wrapped source stream itself contains fewer bytes + */ + private final int prefixLength; + /** + * The current count of bytes read from this stream. + * This starts of as {@code 0} and is always smaller or equal to {@code prefixLength}. + */ + private int count; + /** + * whether closing this stream must also close the underlying stream + */ + private boolean closeSource; + /** + * flag signalling if this stream has been closed + */ + private boolean closed; + + public PrefixInputStream(InputStream source, int prefixLength, boolean closeSource) { + if (prefixLength < 0) { + throw new IllegalArgumentException("The prefixLength constructor argument must be a positive integer"); + } + this.source = source; + this.prefixLength = prefixLength; + this.count = 0; + this.closeSource = closeSource; + this.closed = false; + } + + @Override + public int read() throws IOException { + ensureOpen(); + if (remainingPrefixByteCount() <= 0) { + return -1; + } + int byteVal = source.read(); + if (byteVal == -1) { + return -1; + } + count++; + return byteVal; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + ensureOpen(); + Objects.checkFromIndexSize(off, len, b.length); + if (len == 0) { + return 0; + } + if (remainingPrefixByteCount() <= 0) { + return -1; + } + int readSize = Math.min(len, remainingPrefixByteCount()); + int bytesRead = source.read(b, off, readSize); + if (bytesRead == -1) { + return -1; + } + count += bytesRead; + return bytesRead; + } + + @Override + public long skip(long n) throws IOException { + ensureOpen(); + if (n <= 0 || remainingPrefixByteCount() <= 0) { + return 0; + } + long bytesToSkip = Math.min(n, remainingPrefixByteCount()); + assert bytesToSkip > 0; + long bytesSkipped = source.skip(bytesToSkip); + count += bytesSkipped; + return bytesSkipped; + } + + @Override + public int available() throws IOException { + ensureOpen(); + return Math.min(remainingPrefixByteCount(), source.available()); + } + + @Override + public boolean markSupported() { + return false; + } + + @Override + public void mark(int readlimit) { + // mark and reset are not supported + } + + @Override + public void reset() throws IOException { + throw new IOException("mark/reset not supported"); + } + + @Override + public void close() throws IOException { + if (closed) { + return; + } + closed = true; + if (closeSource) { + source.close(); + } + } + + private int remainingPrefixByteCount() { + return prefixLength - count; + } + + private void ensureOpen() throws IOException { + if (closed) { + throw new IOException("Stream has been closed"); + } + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/SingleUseKey.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/SingleUseKey.java new file mode 100644 index 0000000000000..fe4729e8acec1 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/SingleUseKey.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; + +import javax.crypto.SecretKey; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Container class for a {@code SecretKey} with a unique identifier, and a 4-byte wide {@code Integer} nonce, that can be used for a + * single encryption operation. Use {@link #createSingleUseKeySupplier(CheckedSupplier)} to obtain a {@code Supplier} that returns + * a new {@link SingleUseKey} instance on every invocation. The number of unique {@code SecretKey}s (and their associated identifiers) + * generated is minimized and, at the same time, ensuring that a given {@code nonce} is not reused with the same key. + */ +final class SingleUseKey { + private static final Logger logger = LogManager.getLogger(SingleUseKey.class); + static final int MIN_NONCE = Integer.MIN_VALUE; + static final int MAX_NONCE = Integer.MAX_VALUE; + private static final int MAX_ATTEMPTS = 9; + private static final SingleUseKey EXPIRED_KEY = new SingleUseKey(null, null, MAX_NONCE); + + private final BytesReference keyId; + private final SecretKey key; + private final int nonce; + + // for tests use only! + SingleUseKey(BytesReference KeyId, SecretKey Key, int nonce) { + this.keyId = KeyId; + this.key = Key; + this.nonce = nonce; + } + + public BytesReference getKeyId() { + return keyId; + } + + public SecretKey getKey() { + return key; + } + + public int getNonce() { + return nonce; + } + + /** + * Returns a {@code CheckedSupplier} of {@code SingleUseKey}s so that no two instances contain the same key and nonce pair. + * The current implementation increments the {@code nonce} while keeping the key constant, until the {@code nonce} space + * is exhausted, at which moment a new key is generated and the {@code nonce} is reset back. + * + * @param keyGenerator supplier for the key and the key id + */ + static CheckedSupplier createSingleUseKeySupplier( + CheckedSupplier, T> keyGenerator + ) { + final AtomicReference keyCurrentlyInUse = new AtomicReference<>(EXPIRED_KEY); + return internalSingleUseKeySupplier(keyGenerator, keyCurrentlyInUse); + } + + // for tests use only, the {@code keyCurrentlyInUse} must not be exposed to caller code + static CheckedSupplier internalSingleUseKeySupplier( + CheckedSupplier, T> keyGenerator, + AtomicReference keyCurrentlyInUse + ) { + final Object lock = new Object(); + return () -> { + for (int attemptNo = 0; attemptNo < MAX_ATTEMPTS; attemptNo++) { + final SingleUseKey nonceAndKey = keyCurrentlyInUse.getAndUpdate( + prev -> prev.nonce < MAX_NONCE ? new SingleUseKey(prev.keyId, prev.key, prev.nonce + 1) : EXPIRED_KEY + ); + if (nonceAndKey.nonce < MAX_NONCE) { + // this is the commonly used code path, where just the nonce is incremented + logger.trace( + () -> new ParameterizedMessage("Key with id [{}] reused with nonce [{}]", nonceAndKey.keyId, nonceAndKey.nonce) + ); + return nonceAndKey; + } else { + // this is the infrequent code path, where a new key is generated and the nonce is reset back + logger.trace( + () -> new ParameterizedMessage("Try to generate a new key to replace the key with id [{}]", nonceAndKey.keyId) + ); + synchronized (lock) { + if (keyCurrentlyInUse.get().nonce == MAX_NONCE) { + final Tuple newKey = keyGenerator.get(); + logger.debug(() -> new ParameterizedMessage("New key with id [{}] has been generated", newKey.v1())); + keyCurrentlyInUse.set(new SingleUseKey(newKey.v1(), newKey.v2(), MIN_NONCE)); + } + } + } + } + throw new IllegalStateException("Failure to generate new key"); + }; + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/repository-encrypted/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..7f75e2af67c6e --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,8 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +grant { +}; diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/AESKeyUtilsTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/AESKeyUtilsTests.java new file mode 100644 index 0000000000000..d30bc34a0f237 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/AESKeyUtilsTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.test.ESTestCase; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; + +import java.security.InvalidKeyException; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public class AESKeyUtilsTests extends ESTestCase { + + public void testWrapUnwrap() throws Exception { + byte[] keyToWrapBytes = randomByteArrayOfLength(AESKeyUtils.KEY_LENGTH_IN_BYTES); + SecretKey keyToWrap = new SecretKeySpec(keyToWrapBytes, "AES"); + byte[] wrappingKeyBytes = randomByteArrayOfLength(AESKeyUtils.KEY_LENGTH_IN_BYTES); + SecretKey wrappingKey = new SecretKeySpec(wrappingKeyBytes, "AES"); + byte[] wrappedKey = AESKeyUtils.wrap(wrappingKey, keyToWrap); + assertThat(wrappedKey.length, equalTo(AESKeyUtils.WRAPPED_KEY_LENGTH_IN_BYTES)); + SecretKey unwrappedKey = AESKeyUtils.unwrap(wrappingKey, wrappedKey); + assertThat(unwrappedKey, equalTo(keyToWrap)); + } + + public void testComputeId() throws Exception { + byte[] key1Bytes = randomByteArrayOfLength(AESKeyUtils.KEY_LENGTH_IN_BYTES); + SecretKey key1 = new SecretKeySpec(key1Bytes, "AES"); + byte[] key2Bytes = randomByteArrayOfLength(AESKeyUtils.KEY_LENGTH_IN_BYTES); + SecretKey key2 = new SecretKeySpec(key2Bytes, "AES"); + assertThat(AESKeyUtils.computeId(key1), not(equalTo(AESKeyUtils.computeId(key2)))); + assertThat(AESKeyUtils.computeId(key1), equalTo(AESKeyUtils.computeId(key1))); + assertThat(AESKeyUtils.computeId(key2), equalTo(AESKeyUtils.computeId(key2))); + } + + public void testFailedWrapUnwrap() throws Exception { + byte[] toWrapBytes = new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7 }; + SecretKey keyToWrap = new SecretKeySpec(toWrapBytes, "AES"); + byte[] wrapBytes = new byte[] { 0, 0, 0, 0, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 0, 0, 0, 0 }; + SecretKey wrappingKey = new SecretKeySpec(wrapBytes, "AES"); + byte[] wrappedKey = AESKeyUtils.wrap(wrappingKey, keyToWrap); + for (int i = 0; i < wrappedKey.length; i++) { + wrappedKey[i] ^= 0xFFFFFFFF; + expectThrows(InvalidKeyException.class, () -> AESKeyUtils.unwrap(wrappingKey, wrappedKey)); + wrappedKey[i] ^= 0xFFFFFFFF; + } + } +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStreamTests.java new file mode 100644 index 0000000000000..239d04268e145 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStreamTests.java @@ -0,0 +1,853 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.BeforeClass; + +import java.io.ByteArrayInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class BufferOnMarkInputStreamTests extends ESTestCase { + + private static byte[] testArray; + + @BeforeClass + static void createTestArray() throws Exception { + testArray = new byte[128]; + for (int i = 0; i < testArray.length; i++) { + testArray[i] = (byte) i; + } + } + + public void testResetWithoutMarkFails() throws Exception { + Tuple mockSourceTuple = getMockInfiniteInputStream(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), 1 + Randomness.get().nextInt(1024)); + // maybe read some bytes + test.readNBytes(randomFrom(0, randomInt(31))); + IOException e = expectThrows(IOException.class, () -> { test.reset(); }); + assertThat(e.getMessage(), Matchers.is("Mark not called or has been invalidated")); + } + + public void testMarkAndBufferReadLimitsCheck() throws Exception { + Tuple mockSourceTuple = getMockInfiniteInputStream(); + int bufferSize = randomIntBetween(1, 1024); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + assertThat(test.getMaxMarkReadlimit(), Matchers.is(bufferSize)); + // maybe read some bytes + test.readNBytes(randomFrom(0, randomInt(32))); + int wrongLargeReadLimit = bufferSize + randomIntBetween(1, 8); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> { test.mark(wrongLargeReadLimit); }); + assertThat( + e.getMessage(), + Matchers.is("Readlimit value [" + wrongLargeReadLimit + "] exceeds the maximum value of [" + bufferSize + "]") + ); + e = expectThrows(IllegalArgumentException.class, () -> { test.mark(-1 - randomInt(1)); }); + assertThat(e.getMessage(), Matchers.containsString("cannot be negative")); + e = expectThrows(IllegalArgumentException.class, () -> { new BufferOnMarkInputStream(mock(InputStream.class), 0 - randomInt(1)); }); + assertThat(e.getMessage(), Matchers.is("The buffersize constructor argument must be a strictly positive value")); + } + + public void testCloseRejectsSuccessiveCalls() throws Exception { + int bufferSize = 3 + Randomness.get().nextInt(128); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + // maybe read some bytes + test.readNBytes(randomFrom(0, Randomness.get().nextInt(32))); + test.close(); + int bytesReadBefore = bytesRead.get(); + IOException e = expectThrows(IOException.class, () -> { test.read(); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { + byte[] b = new byte[1 + Randomness.get().nextInt(32)]; + test.read(b, 0, 1 + Randomness.get().nextInt(b.length)); + }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { test.skip(1 + Randomness.get().nextInt(32)); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { test.available(); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { test.reset(); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + int bytesReadAfter = bytesRead.get(); + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(0)); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + } + + public void testBufferingUponMark() throws Exception { + int bufferSize = randomIntBetween(3, 128); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + // read without mark, should be a simple pass-through with the same byte count + int bytesReadBefore = bytesRead.get(); + assertThat(test.read(), Matchers.not(-1)); + int bytesReadAfter = bytesRead.get(); + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(1)); + int readLen = randomIntBetween(1, 8); + bytesReadBefore = bytesRead.get(); + if (randomBoolean()) { + test.readNBytes(readLen); + } else { + skipNBytes(test, readLen); + } + bytesReadAfter = bytesRead.get(); + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert no buffering + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read one byte + bytesReadBefore = bytesRead.get(); + assertThat(test.read(), Matchers.not(-1)); + bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(1)); + // assert byte is buffered + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - 1)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(1)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // read more bytes, up to buffer size bytes + readLen = randomIntBetween(1, bufferSize - 1); + bytesReadBefore = bytesRead.get(); + if (randomBoolean()) { + test.readNBytes(readLen); + } else { + skipNBytes(test, readLen); + } + bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert byte is buffered + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - 1 - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(1 + readLen)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + assertThat(test.storeToBuffer, Matchers.is(true)); + } + + public void testMarkInvalidation() throws Exception { + int bufferSize = randomIntBetween(3, 128); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test.storeToBuffer, Matchers.is(false)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read all bytes to fill the mark buffer + int bytesReadBefore = bytesRead.get(); + // read enough to populate the full buffer space + int readLen = bufferSize; + if (randomBoolean()) { + test.readNBytes(readLen); + } else { + skipNBytes(test, readLen); + } + int bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert byte is buffered + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(bufferSize)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + assertThat(test.storeToBuffer, Matchers.is(true)); + // read another one byte + bytesReadBefore = bytesRead.get(); + assertThat(test.read(), Matchers.not(-1)); + bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(1)); + // assert mark is invalidated and no buffering is further performed + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + assertThat(test.storeToBuffer, Matchers.is(false)); + // read more bytes + bytesReadBefore = bytesRead.get(); + readLen = randomIntBetween(1, 2 * bufferSize); + if (randomBoolean()) { + test.readNBytes(readLen); + } else { + skipNBytes(test, readLen); + } + bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert byte again is NOT buffered + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test.storeToBuffer, Matchers.is(false)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // assert reset does not work any more + IOException e = expectThrows(IOException.class, () -> { test.reset(); }); + assertThat(e.getMessage(), Matchers.is("Mark not called or has been invalidated")); + } + + public void testConsumeBufferUponReset() throws Exception { + int bufferSize = randomIntBetween(3, 128); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + // maybe read some bytes + test.readNBytes(randomFrom(0, randomInt(32))); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read less than bufferSize bytes + int bytesReadBefore = bytesRead.get(); + int readLen = randomIntBetween(1, bufferSize); + if (randomBoolean()) { + test.readNBytes(readLen); + } else { + skipNBytes(test, readLen); + } + int bytesReadAfter = bytesRead.get(); + // assert bytes are "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert buffer is populated + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // reset + test.reset(); + assertThat(test.replayFromBuffer, Matchers.is(true)); + assertThat(test.storeToBuffer, Matchers.is(true)); + // read again, from buffer this time + bytesReadBefore = bytesRead.get(); + int readLen2 = randomIntBetween(1, readLen); + if (randomBoolean()) { + test.readNBytes(readLen2); + } else { + skipNBytes(test, readLen2); + } + bytesReadAfter = bytesRead.get(); + // assert bytes are replayed from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(0)); + // assert buffer is consumed + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - readLen2)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(true)); + } + + public void testInvalidateMarkAfterReset() throws Exception { + int bufferSize = randomIntBetween(3, 128); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + // maybe read some bytes + test.readNBytes(randomFrom(0, randomInt(32))); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read less than bufferSize bytes + int bytesReadBefore = bytesRead.get(); + int readLen = randomIntBetween(1, bufferSize); + if (randomBoolean()) { + test.readNBytes(readLen); + } else { + skipNBytes(test, readLen); + } + int bytesReadAfter = bytesRead.get(); + // assert bytes are "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert buffer is populated + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // reset + test.reset(); + // assert signal for replay from buffer is toggled + assertThat(test.replayFromBuffer, Matchers.is(true)); + assertThat(test.storeToBuffer, Matchers.is(true)); + // assert bytes are still buffered + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + // read again, from buffer this time + bytesReadBefore = bytesRead.get(); + // read all bytes from the buffer + int readLen2 = readLen; + if (randomBoolean()) { + test.readNBytes(readLen2); + } else { + skipNBytes(test, readLen2); + } + bytesReadAfter = bytesRead.get(); + // assert bytes are replayed from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(0)); + // assert buffer is consumed + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(true)); + // read on, from the stream, until the mark buffer is full + bytesReadBefore = bytesRead.get(); + // read the remaining bytes to fill the buffer + int readLen3 = bufferSize - readLen; + if (randomBoolean()) { + test.readNBytes(readLen3); + } else { + skipNBytes(test, readLen3); + } + bytesReadAfter = bytesRead.get(); + // assert bytes are "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen3)); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen3)); + assertThat(test.storeToBuffer, Matchers.is(true)); + if (readLen3 > 0) { + assertThat(test.replayFromBuffer, Matchers.is(false)); + } else { + assertThat(test.replayFromBuffer, Matchers.is(true)); + } + // read more bytes + bytesReadBefore = bytesRead.get(); + int readLen4 = randomIntBetween(1, 2 * bufferSize); + if (randomBoolean()) { + test.readNBytes(readLen4); + } else { + skipNBytes(test, readLen4); + } + bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen4)); + // assert mark reset + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test.storeToBuffer, Matchers.is(false)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // assert reset does not work anymore + IOException e = expectThrows(IOException.class, () -> { test.reset(); }); + assertThat(e.getMessage(), Matchers.is("Mark not called or has been invalidated")); + } + + public void testMarkAfterResetWhileReplayingBuffer() throws Exception { + int bufferSize = randomIntBetween(8, 16); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + // maybe read some bytes + test.readNBytes(randomFrom(0, randomInt(32))); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read less than bufferSize bytes + int bytesReadBefore = bytesRead.get(); + int readLen = randomIntBetween(1, bufferSize); + if (randomBoolean()) { + test.readNBytes(readLen); + } else { + skipNBytes(test, readLen); + } + int bytesReadAfter = bytesRead.get(); + // assert bytes are "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert buffer is populated + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // reset + test.reset(); + assertThat(test.replayFromBuffer, Matchers.is(true)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + // read bytes after reset + for (int readLen2 = 1; readLen2 <= readLen; readLen2++) { + Tuple mockSourceTuple2 = getMockInfiniteInputStream(); + BufferOnMarkInputStream cloneTest = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + cloneBufferOnMarkStream(cloneTest, test); + AtomicInteger bytesRead2 = mockSourceTuple2.v1(); + // read again, from buffer this time, less than before + bytesReadBefore = bytesRead2.get(); + if (randomBoolean()) { + cloneTest.readNBytes(readLen2); + } else { + skipNBytes(cloneTest, readLen2); + } + bytesReadAfter = bytesRead2.get(); + // assert bytes are replayed from the buffer, and not read from the stream + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(0)); + // assert buffer is consumed + assertThat(cloneTest.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(cloneTest.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - readLen2)); + assertThat(cloneTest.storeToBuffer, Matchers.is(true)); + assertThat(cloneTest.replayFromBuffer, Matchers.is(true)); + // mark inside the buffer after reset + cloneTest.mark(randomIntBetween(1, bufferSize)); + assertThat(cloneTest.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen + readLen2)); + assertThat(cloneTest.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - readLen2)); + assertThat(cloneTest.storeToBuffer, Matchers.is(true)); + assertThat(cloneTest.replayFromBuffer, Matchers.is(true)); + // read until the buffer is filled + for (int readLen3 = 1; readLen3 <= readLen - readLen2; readLen3++) { + Tuple mockSourceTuple3 = getMockInfiniteInputStream(); + BufferOnMarkInputStream cloneTest3 = new BufferOnMarkInputStream(mockSourceTuple3.v2(), bufferSize); + cloneBufferOnMarkStream(cloneTest3, cloneTest); + AtomicInteger bytesRead3 = mockSourceTuple3.v1(); + // read again from buffer, after the mark inside the buffer + bytesReadBefore = bytesRead3.get(); + if (randomBoolean()) { + cloneTest3.readNBytes(readLen3); + } else { + skipNBytes(cloneTest3, readLen3); + } + bytesReadAfter = bytesRead3.get(); + // assert bytes are replayed from the buffer, and not read from the stream + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(0)); + // assert buffer is consumed completely + assertThat(cloneTest3.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen + readLen2)); + assertThat(cloneTest3.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - readLen2 - readLen3)); + assertThat(cloneTest3.storeToBuffer, Matchers.is(true)); + assertThat(cloneTest3.replayFromBuffer, Matchers.is(true)); + } + // read beyond the buffer can supply, but not more than it can accommodate + for (int readLen3 = readLen - readLen2 + 1; readLen3 <= bufferSize - readLen2; readLen3++) { + Tuple mockSourceTuple3 = getMockInfiniteInputStream(); + BufferOnMarkInputStream cloneTest3 = new BufferOnMarkInputStream(mockSourceTuple3.v2(), bufferSize); + cloneBufferOnMarkStream(cloneTest3, cloneTest); + AtomicInteger bytesRead3 = mockSourceTuple3.v1(); + // read again from buffer, after the mark inside the buffer + bytesReadBefore = bytesRead3.get(); + if (randomBoolean()) { + cloneTest3.readNBytes(readLen3); + } else { + skipNBytes(cloneTest3, readLen3); + } + bytesReadAfter = bytesRead3.get(); + // assert bytes are PARTLY replayed, PARTLY read from the stream + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen3 + readLen2 - readLen)); + // assert buffer is appended and fully replayed + assertThat(cloneTest3.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen3)); + assertThat(cloneTest3.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen3 + readLen2 - readLen)); + assertThat(cloneTest3.storeToBuffer, Matchers.is(true)); + assertThat(cloneTest3.replayFromBuffer, Matchers.is(false)); + } + } + } + + public void testMarkAfterResetAfterReplayingBuffer() throws Exception { + int bufferSize = randomIntBetween(8, 16); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + // maybe read some bytes + test.readNBytes(randomFrom(0, randomInt(32))); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read less than bufferSize bytes + int bytesReadBefore = bytesRead.get(); + int readLen = randomIntBetween(1, bufferSize); + if (randomBoolean()) { + test.readNBytes(readLen); + } else { + skipNBytes(test, readLen); + } + int bytesReadAfter = bytesRead.get(); + // assert bytes are "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert buffer is populated + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // reset + test.reset(); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(true)); + for (int readLen2 = readLen + 1; readLen2 <= bufferSize; readLen2++) { + Tuple mockSourceTuple2 = getMockInfiniteInputStream(); + BufferOnMarkInputStream test2 = new BufferOnMarkInputStream(mockSourceTuple2.v2(), bufferSize); + cloneBufferOnMarkStream(test2, test); + AtomicInteger bytesRead2 = mockSourceTuple2.v1(); + // read again, more than before + bytesReadBefore = bytesRead2.get(); + if (randomBoolean()) { + test2.readNBytes(readLen2); + } else { + skipNBytes(test2, readLen2); + } + bytesReadAfter = bytesRead2.get(); + // assert bytes are PARTLY replayed, PARTLY read from the stream + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen2 - readLen)); + // assert buffer is appended and fully replayed + assertThat(test2.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen2)); + assertThat(test2.storeToBuffer, Matchers.is(true)); + assertThat(test2.replayFromBuffer, Matchers.is(false)); + // mark + test2.mark(randomIntBetween(1, bufferSize)); + assertThat(test2.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test2.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test2.storeToBuffer, Matchers.is(true)); + assertThat(test2.replayFromBuffer, Matchers.is(false)); + } + } + + public void testNoMockSimpleMarkResetAtBeginning() throws Exception { + for (int length = 1; length <= 8; length++) { + for (int mark = 1; mark <= length; mark++) { + try (BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), mark)) { + in.mark(mark); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(mark)); + byte[] test1 = in.readNBytes(mark); + assertArray(0, test1); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + in.reset(); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + byte[] test2 = in.readNBytes(mark); + assertArray(0, test2); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + } + } + } + } + + public void testNoMockMarkResetAtBeginning() throws Exception { + for (int length = 1; length <= 8; length++) { + try (BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), length)) { + in.mark(length); + // increasing length read/reset + for (int readLen = 1; readLen <= length; readLen++) { + byte[] test1 = in.readNBytes(readLen); + assertArray(0, test1); + in.reset(); + } + } + try (BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), length)) { + in.mark(length); + // decreasing length read/reset + for (int readLen = length; readLen >= 1; readLen--) { + byte[] test1 = in.readNBytes(readLen); + assertArray(0, test1); + in.reset(); + } + } + } + } + + public void testNoMockSimpleMarkResetEverywhere() throws Exception { + for (int length = 1; length <= 10; length++) { + for (int offset = 0; offset < length; offset++) { + for (int mark = 1; mark <= length - offset; mark++) { + try ( + BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), mark) + ) { + // skip first offset bytes + in.readNBytes(offset); + in.mark(mark); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(mark)); + byte[] test1 = in.readNBytes(mark); + assertArray(offset, test1); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + in.reset(); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + byte[] test2 = in.readNBytes(mark); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + assertArray(offset, test2); + } + } + } + } + } + + public void testNoMockMarkResetEverywhere() throws Exception { + for (int length = 1; length <= 8; length++) { + for (int offset = 0; offset < length; offset++) { + try ( + BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), length) + ) { + // skip first offset bytes + in.readNBytes(offset); + in.mark(length); + // increasing read lengths + for (int readLen = 1; readLen <= length - offset; readLen++) { + byte[] test = in.readNBytes(readLen); + assertArray(offset, test); + in.reset(); + } + } + try ( + BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), length) + ) { + // skip first offset bytes + in.readNBytes(offset); + in.mark(length); + // decreasing read lengths + for (int readLen = length - offset; readLen >= 1; readLen--) { + byte[] test = in.readNBytes(readLen); + assertArray(offset, test); + in.reset(); + } + } + } + } + } + + public void testNoMockDoubleMarkEverywhere() throws Exception { + for (int length = 1; length <= 16; length++) { + for (int offset = 0; offset < length; offset++) { + for (int readLen = 1; readLen <= length - offset; readLen++) { + for (int markLen = 1; markLen <= length - offset; markLen++) { + try ( + BufferOnMarkInputStream in = new BufferOnMarkInputStream( + new NoMarkByteArrayInputStream(testArray, 0, length), + length + ) + ) { + in.readNBytes(offset); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + // first mark + in.mark(length - offset); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + byte[] test = in.readNBytes(readLen); + assertArray(offset, test); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + // reset to first + in.reset(); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + // advance before/after the first read length + test = in.readNBytes(markLen); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - Math.max(readLen, markLen))); + if (markLen <= readLen) { + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - markLen)); + } else { + assertThat(in.replayFromBuffer, Matchers.is(false)); + } + assertArray(offset, test); + // second mark + in.mark(length - offset - markLen); + if (markLen <= readLen) { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen + markLen)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - markLen)); + } else { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + } + for (int readLen2 = 1; readLen2 <= length - offset - markLen; readLen2++) { + byte[] test2 = in.readNBytes(readLen2); + if (markLen + readLen2 <= readLen) { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen + markLen)); + assertThat(in.replayFromBuffer, Matchers.is(true)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - markLen - readLen2)); + } else { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen2)); + assertThat(in.replayFromBuffer, Matchers.is(false)); + } + assertArray(offset + markLen, test2); + in.reset(); + assertThat(in.replayFromBuffer, Matchers.is(true)); + if (markLen + readLen2 <= readLen) { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen + markLen)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - markLen)); + } else { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen2)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen2)); + } + } + } + } + } + } + } + } + + public void testNoMockMarkWithoutReset() throws Exception { + int maxMark = 8; + BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, testArray.length), maxMark); + int offset = 0; + while (offset < testArray.length) { + int readLen = Math.min(1 + Randomness.get().nextInt(maxMark), testArray.length - offset); + in.mark(Randomness.get().nextInt(readLen)); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(maxMark)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + byte[] test = in.readNBytes(readLen); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(maxMark - readLen)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertArray(offset, test); + offset += readLen; + } + } + + public void testNoMockThreeMarkResetMarkSteps() throws Exception { + int length = randomIntBetween(8, 16); + int stepLen = randomIntBetween(4, 8); + BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), stepLen); + testMarkResetMarkStep(in, 0, length, stepLen, 2); + } + + private void testMarkResetMarkStep(BufferOnMarkInputStream stream, int offset, int length, int stepLen, int step) throws Exception { + stream.mark(stepLen); + for (int readLen = 1; readLen <= Math.min(stepLen, length - offset); readLen++) { + for (int markLen = 1; markLen <= Math.min(stepLen, length - offset); markLen++) { + BufferOnMarkInputStream cloneStream = cloneBufferOnMarkStream(stream); + // read ahead + byte[] test = cloneStream.readNBytes(readLen); + assertArray(offset, test); + // reset back + cloneStream.reset(); + // read ahead different length + test = cloneStream.readNBytes(markLen); + assertArray(offset, test); + if (step > 0) { + testMarkResetMarkStep(cloneStream, offset + markLen, length, stepLen, step - 1); + } + } + } + } + + private BufferOnMarkInputStream cloneBufferOnMarkStream(BufferOnMarkInputStream orig) { + int origOffset = ((NoMarkByteArrayInputStream) orig.source).getPos(); + int origLen = ((NoMarkByteArrayInputStream) orig.source).getCount(); + BufferOnMarkInputStream cloneStream = new BufferOnMarkInputStream( + new NoMarkByteArrayInputStream(testArray, origOffset, origLen - origOffset), + orig.ringBuffer.getBufferSize() + ); + if (orig.ringBuffer.buffer != null) { + cloneStream.ringBuffer.buffer = Arrays.copyOf(orig.ringBuffer.buffer, orig.ringBuffer.buffer.length); + } else { + cloneStream.ringBuffer.buffer = null; + } + cloneStream.ringBuffer.head = orig.ringBuffer.head; + cloneStream.ringBuffer.tail = orig.ringBuffer.tail; + cloneStream.ringBuffer.position = orig.ringBuffer.position; + cloneStream.storeToBuffer = orig.storeToBuffer; + cloneStream.replayFromBuffer = orig.replayFromBuffer; + cloneStream.closed = orig.closed; + return cloneStream; + } + + private void cloneBufferOnMarkStream(BufferOnMarkInputStream clone, BufferOnMarkInputStream orig) { + if (orig.ringBuffer.buffer != null) { + clone.ringBuffer.buffer = Arrays.copyOf(orig.ringBuffer.buffer, orig.ringBuffer.buffer.length); + } else { + clone.ringBuffer.buffer = null; + } + clone.ringBuffer.head = orig.ringBuffer.head; + clone.ringBuffer.tail = orig.ringBuffer.tail; + clone.ringBuffer.position = orig.ringBuffer.position; + clone.storeToBuffer = orig.storeToBuffer; + clone.replayFromBuffer = orig.replayFromBuffer; + clone.closed = orig.closed; + } + + private void assertArray(int offset, byte[] test) { + for (int i = 0; i < test.length; i++) { + Assert.assertThat(test[i], Matchers.is(testArray[offset + i])); + } + } + + private Tuple getMockInfiniteInputStream() throws IOException { + InputStream mockSource = mock(InputStream.class); + AtomicInteger bytesRead = new AtomicInteger(0); + when(mockSource.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())).thenAnswer( + invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + int bytesCount = 1 + Randomness.get().nextInt(len); + bytesRead.addAndGet(bytesCount); + return bytesCount; + } + } + ); + when(mockSource.read()).thenAnswer(invocationOnMock -> { + bytesRead.incrementAndGet(); + return Randomness.get().nextInt(256); + }); + when(mockSource.skip(org.mockito.Matchers.anyLong())).thenAnswer(invocationOnMock -> { + final long n = (long) invocationOnMock.getArguments()[0]; + if (n <= 0) { + return 0; + } + int bytesSkipped = 1 + Randomness.get().nextInt(Math.toIntExact(n)); + bytesRead.addAndGet(bytesSkipped); + return bytesSkipped; + }); + when(mockSource.available()).thenReturn(1 + Randomness.get().nextInt(32)); + when(mockSource.markSupported()).thenReturn(false); + return new Tuple<>(bytesRead, mockSource); + } + + private static void skipNBytes(InputStream in, long n) throws IOException { + if (n > 0) { + long ns = in.skip(n); + if (ns >= 0 && ns < n) { // skipped too few bytes + // adjust number to skip + n -= ns; + // read until requested number skipped or EOS reached + while (n > 0 && in.read() != -1) { + n--; + } + // if not enough skipped, then EOFE + if (n != 0) { + throw new EOFException(); + } + } else if (ns != n) { // skipped negative or too many bytes + throw new IOException("Unable to skip exactly"); + } + } + } + + static class NoMarkByteArrayInputStream extends ByteArrayInputStream { + + NoMarkByteArrayInputStream(byte[] buf) { + super(buf); + } + + NoMarkByteArrayInputStream(byte[] buf, int offset, int length) { + super(buf, offset, length); + } + + int getPos() { + return pos; + } + + int getCount() { + return count; + } + + @Override + public void mark(int readlimit) {} + + @Override + public boolean markSupported() { + return false; + } + + @Override + public void reset() { + throw new IllegalStateException("Mark not called or has been invalidated"); + } + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/ChainingInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/ChainingInputStreamTests.java new file mode 100644 index 0000000000000..c664f29ffbbfc --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/ChainingInputStreamTests.java @@ -0,0 +1,1170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.mockito.Mockito; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ChainingInputStreamTests extends ESTestCase { + + public void testChainComponentsWhenUsingFactoryMethod() throws Exception { + InputStream input1 = mock(InputStream.class); + when(input1.markSupported()).thenReturn(true); + when(input1.read()).thenReturn(randomIntBetween(0, 255)); + InputStream input2 = mock(InputStream.class); + when(input2.markSupported()).thenReturn(true); + when(input2.read()).thenReturn(randomIntBetween(0, 255)); + + ChainingInputStream chain = ChainingInputStream.chain(input1, input2); + + chain.read(); + verify(input1).read(); + verify(input2, times(0)).read(); + + when(input1.read()).thenReturn(-1); + chain.read(); + verify(input1, times(2)).read(); + verify(input1, times(0)).close(); + verify(input2).read(); + + when(input2.read()).thenReturn(-1); + chain.read(); + verify(input1, times(2)).read(); + verify(input2, times(2)).read(); + verify(input1, times(0)).close(); + verify(input2, times(0)).close(); + + chain.close(); + verify(input1).close(); + verify(input2).close(); + } + + public void testMarkAndResetWhenUsingFactoryMethod() throws Exception { + InputStream input1 = mock(InputStream.class); + when(input1.markSupported()).thenReturn(true); + when(input1.read()).thenReturn(randomIntBetween(0, 255)); + InputStream input2 = mock(InputStream.class); + when(input2.markSupported()).thenReturn(true); + when(input2.read()).thenReturn(randomIntBetween(0, 255)); + + ChainingInputStream chain = ChainingInputStream.chain(input1, input2); + verify(input1, times(1)).mark(anyInt()); + verify(input2, times(1)).mark(anyInt()); + + // mark at the beginning + chain.mark(randomIntBetween(1, 32)); + verify(input1, times(1)).mark(anyInt()); + verify(input2, times(1)).mark(anyInt()); + + verify(input1, times(0)).reset(); + chain.read(); + verify(input1, times(1)).reset(); + chain.reset(); + verify(input1, times(0)).close(); + verify(input1, times(1)).reset(); + chain.read(); + verify(input1, times(2)).reset(); + + // mark at the first component + chain.mark(randomIntBetween(1, 32)); + verify(input1, times(2)).mark(anyInt()); + verify(input2, times(1)).mark(anyInt()); + + when(input1.read()).thenReturn(-1); + chain.read(); + verify(input1, times(0)).close(); + chain.reset(); + verify(input1, times(3)).reset(); + + chain.read(); + verify(input2, times(2)).reset(); + + // mark at the second component + chain.mark(randomIntBetween(1, 32)); + verify(input1, times(2)).mark(anyInt()); + verify(input2, times(2)).mark(anyInt()); + + when(input2.read()).thenReturn(-1); + chain.read(); + verify(input1, times(0)).close(); + verify(input2, times(0)).close(); + chain.reset(); + verify(input2, times(3)).reset(); + + chain.close(); + verify(input1, times(1)).close(); + verify(input2, times(1)).close(); + } + + public void testSkipWithinComponent() throws Exception { + byte[] b1 = randomByteArrayOfLength(randomIntBetween(2, 16)); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return new ByteArrayInputStream(b1); + } else { + return null; + } + } + }; + int prefix = randomIntBetween(0, b1.length - 2); + test.readNBytes(prefix); + // skip less bytes than the component has + int nSkip1 = randomInt(b1.length - prefix); + long nSkip = test.skip(nSkip1); + assertThat((int) nSkip, Matchers.is(nSkip1)); + int nSkip2 = b1.length - prefix - nSkip1 + randomIntBetween(1, 8); + // skip more bytes than the component has + nSkip = test.skip(nSkip2); + assertThat((int) nSkip, Matchers.is(b1.length - prefix - nSkip1)); + } + + public void testSkipAcrossComponents() throws Exception { + byte[] b1 = randomByteArrayOfLength(randomIntBetween(1, 16)); + byte[] b2 = randomByteArrayOfLength(randomIntBetween(1, 16)); + ChainingInputStream test = new ChainingInputStream() { + final Iterator iter = List.of(new ByteArrayInputStream(b1), new ByteArrayInputStream(b2)).iterator(); + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (iter.hasNext()) { + return iter.next(); + } else { + return null; + } + } + }; + long skipArg = b1.length + randomIntBetween(1, b2.length); + long nSkip = test.skip(skipArg); + assertThat(nSkip, Matchers.is(skipArg)); + byte[] rest = test.readAllBytes(); + assertThat((long) rest.length, Matchers.is(b1.length + b2.length - nSkip)); + for (int i = rest.length - 1; i >= 0; i--) { + assertThat(rest[i], Matchers.is(b2[i + (int) nSkip - b1.length])); + } + } + + public void testEmptyChain() throws Exception { + // chain is empty because it doesn't have any components + ChainingInputStream emptyStream = newEmptyStream(false); + assertThat(emptyStream.read(), Matchers.is(-1)); + emptyStream = newEmptyStream(false); + byte[] b = randomByteArrayOfLength(randomIntBetween(1, 8)); + int off = randomInt(b.length - 1); + assertThat(emptyStream.read(b, off, b.length - off), Matchers.is(-1)); + emptyStream = newEmptyStream(false); + assertThat(emptyStream.available(), Matchers.is(0)); + emptyStream = newEmptyStream(false); + assertThat(emptyStream.skip(randomIntBetween(1, 32)), Matchers.is(0L)); + // chain is empty because all its components are empty + emptyStream = newEmptyStream(true); + assertThat(emptyStream.read(), Matchers.is(-1)); + emptyStream = newEmptyStream(true); + b = randomByteArrayOfLength(randomIntBetween(1, 8)); + off = randomInt(b.length - 1); + assertThat(emptyStream.read(b, off, b.length - off), Matchers.is(-1)); + emptyStream = newEmptyStream(true); + assertThat(emptyStream.available(), Matchers.is(0)); + emptyStream = newEmptyStream(true); + assertThat(emptyStream.skip(randomIntBetween(1, 32)), Matchers.is(0L)); + } + + public void testClose() throws Exception { + ChainingInputStream test1 = newEmptyStream(randomBoolean()); + test1.close(); + IOException e = expectThrows(IOException.class, () -> { test1.read(); }); + assertThat(e.getMessage(), Matchers.is("Stream is closed")); + ChainingInputStream test2 = newEmptyStream(randomBoolean()); + test2.close(); + byte[] b = randomByteArrayOfLength(randomIntBetween(2, 9)); + int off = randomInt(b.length - 2); + e = expectThrows(IOException.class, () -> { test2.read(b, off, randomInt(b.length - off - 1)); }); + assertThat(e.getMessage(), Matchers.is("Stream is closed")); + ChainingInputStream test3 = newEmptyStream(randomBoolean()); + test3.close(); + e = expectThrows(IOException.class, () -> { test3.skip(randomInt(31)); }); + assertThat(e.getMessage(), Matchers.is("Stream is closed")); + ChainingInputStream test4 = newEmptyStream(randomBoolean()); + test4.close(); + e = expectThrows(IOException.class, () -> { test4.available(); }); + assertThat(e.getMessage(), Matchers.is("Stream is closed")); + ChainingInputStream test5 = newEmptyStream(randomBoolean()); + test5.close(); + e = expectThrows(IOException.class, () -> { test5.reset(); }); + assertThat(e.getMessage(), Matchers.is("Stream is closed")); + ChainingInputStream test6 = newEmptyStream(randomBoolean()); + test6.close(); + try { + test6.mark(randomInt()); + } catch (Exception e1) { + assumeNoException("mark on a closed stream should not throw", e1); + } + } + + public void testInitialComponentArgumentIsNull() throws Exception { + AtomicReference initialInputStream = new AtomicReference<>(); + AtomicBoolean nextCalled = new AtomicBoolean(false); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + initialInputStream.set(currentComponentIn); + nextCalled.set(true); + return null; + } + }; + assertThat(test.read(), Matchers.is(-1)); + assertThat(nextCalled.get(), Matchers.is(true)); + assertThat(initialInputStream.get(), Matchers.nullValue()); + } + + public void testChaining() throws Exception { + int componentCount = randomIntBetween(2, 9); + ByteBuffer testSource = ByteBuffer.allocate(componentCount); + TestInputStream[] sourceComponents = new TestInputStream[componentCount]; + for (int i = 0; i < sourceComponents.length; i++) { + byte[] b = randomByteArrayOfLength(randomInt(1)); + testSource.put(b); + sourceComponents[i] = new TestInputStream(b); + } + ChainingInputStream test = new ChainingInputStream() { + int i = 0; + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (i == 0) { + assertThat(currentComponentIn, Matchers.nullValue()); + return sourceComponents[i++]; + } else if (i < sourceComponents.length) { + assertThat(((TestInputStream) currentComponentIn).closed.get(), Matchers.is(true)); + assertThat(currentComponentIn, Matchers.is(sourceComponents[i - 1])); + return sourceComponents[i++]; + } else if (i == sourceComponents.length) { + assertThat(((TestInputStream) currentComponentIn).closed.get(), Matchers.is(true)); + assertThat(currentComponentIn, Matchers.is(sourceComponents[i - 1])); + i++; + return null; + } else { + throw new IllegalStateException(); + } + } + + @Override + public boolean markSupported() { + return false; + } + }; + byte[] testArr = test.readAllBytes(); + byte[] ref = testSource.array(); + // testArr and ref should be equal, but ref might have trailing zeroes + for (int i = 0; i < testArr.length; i++) { + assertThat(testArr[i], Matchers.is(ref[i])); + } + } + + public void testEmptyInputStreamComponents() throws Exception { + // leading single empty stream + Tuple test = testEmptyComponentsInChain(3, Arrays.asList(0)); + byte[] result = test.v1().readAllBytes(); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // leading double empty streams + test = testEmptyComponentsInChain(3, Arrays.asList(0, 1)); + result = test.v1().readAllBytes(); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // trailing single empty stream + test = testEmptyComponentsInChain(3, Arrays.asList(2)); + result = test.v1().readAllBytes(); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // trailing double empty stream + test = testEmptyComponentsInChain(3, Arrays.asList(1, 2)); + result = test.v1().readAllBytes(); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // middle single empty stream + test = testEmptyComponentsInChain(3, Arrays.asList(1)); + result = test.v1().readAllBytes(); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // leading and trailing empty streams + test = testEmptyComponentsInChain(3, Arrays.asList(0, 2)); + result = test.v1().readAllBytes(); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // all streams are empty + test = testEmptyComponentsInChain(3, Arrays.asList(0, 1, 2)); + result = test.v1().readAllBytes(); + assertThat(result.length, Matchers.is(0)); + } + + public void testNullComponentTerminatesChain() throws Exception { + TestInputStream[] sourceComponents = new TestInputStream[3]; + TestInputStream[] chainComponents = new TestInputStream[5]; + byte[] b1 = randomByteArrayOfLength(randomIntBetween(1, 2)); + sourceComponents[0] = new TestInputStream(b1); + sourceComponents[1] = null; + byte[] b2 = randomByteArrayOfLength(randomIntBetween(1, 2)); + sourceComponents[2] = new TestInputStream(b2); + ChainingInputStream test = new ChainingInputStream() { + int i = 0; + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + chainComponents[i] = (TestInputStream) currentComponentIn; + if (i < sourceComponents.length) { + return sourceComponents[i++]; + } else { + i++; + return null; + } + } + + @Override + public boolean markSupported() { + return false; + } + }; + assertThat(test.readAllBytes(), Matchers.equalTo(b1)); + assertThat(chainComponents[0], Matchers.nullValue()); + assertThat(chainComponents[1], Matchers.is(sourceComponents[0])); + assertThat(chainComponents[1].closed.get(), Matchers.is(true)); + assertThat(chainComponents[2], Matchers.nullValue()); + assertThat(chainComponents[3], Matchers.nullValue()); + } + + public void testCallsForwardToCurrentComponent() throws Exception { + InputStream mockCurrentIn = mock(InputStream.class); + when(mockCurrentIn.markSupported()).thenReturn(true); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return mockCurrentIn; + } else { + throw new IllegalStateException(); + } + } + }; + // verify "byte-wise read" is proxied to the current component stream + when(mockCurrentIn.read()).thenReturn(randomInt(255)); + test.read(); + verify(mockCurrentIn).read(); + // verify "array read" is proxied to the current component stream + when(mockCurrentIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + // partial read return + int bytesCount = randomIntBetween(1, len); + return bytesCount; + } + }); + byte[] b = randomByteArrayOfLength(randomIntBetween(2, 33)); + int len = randomIntBetween(1, b.length - 1); + int offset = randomInt(b.length - len - 1); + test.read(b, offset, len); + verify(mockCurrentIn).read(Mockito.eq(b), Mockito.eq(offset), Mockito.eq(len)); + // verify "skip" is proxied to the current component stream + long skipCount = randomIntBetween(1, 3); + test.skip(skipCount); + verify(mockCurrentIn).skip(Mockito.eq(skipCount)); + // verify "available" is proxied to the current component stream + test.available(); + verify(mockCurrentIn).available(); + } + + public void testEmptyReadAsksForNext() throws Exception { + InputStream mockCurrentIn = mock(InputStream.class); + when(mockCurrentIn.markSupported()).thenReturn(true); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + return mockCurrentIn; + } + }; + test.currentIn = InputStream.nullInputStream(); + when(mockCurrentIn.read()).thenReturn(randomInt(255)); + test.read(); + verify(mockCurrentIn).read(); + // test "array read" + test.currentIn = InputStream.nullInputStream(); + when(mockCurrentIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + int bytesCount = randomIntBetween(1, len); + return bytesCount; + } + }); + byte[] b = new byte[randomIntBetween(2, 33)]; + int len = randomIntBetween(1, b.length - 1); + int offset = randomInt(b.length - len - 1); + test.read(b, offset, len); + verify(mockCurrentIn).read(Mockito.eq(b), Mockito.eq(offset), Mockito.eq(len)); + } + + public void testReadAll() throws Exception { + byte[] b = randomByteArrayOfLength(randomIntBetween(2, 33)); + int splitIdx = randomInt(b.length - 2); + ByteArrayInputStream first = new ByteArrayInputStream(b, 0, splitIdx + 1); + ByteArrayInputStream second = new ByteArrayInputStream(b, splitIdx + 1, b.length - splitIdx - 1); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentElementIn) throws IOException { + if (currentElementIn == null) { + return first; + } else if (currentElementIn == first) { + return second; + } else if (currentElementIn == second) { + return null; + } else { + throw new IllegalArgumentException(); + } + } + }; + byte[] result = test.readAllBytes(); + assertThat(result, Matchers.equalTo(b)); + } + + public void testMarkAtBeginning() throws Exception { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + when(mockIn.read()).thenAnswer(invocationOnMock -> randomInt(255)); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return mockIn; + } else { + return null; + } + } + }; + assertThat(test.currentIn, Matchers.nullValue()); + // mark at the beginning + assertThat(test.markIn, Matchers.nullValue()); + test.mark(randomInt(63)); + assertThat(test.markIn, Matchers.nullValue()); + // another mark is a no-op + test.mark(randomInt(63)); + assertThat(test.markIn, Matchers.nullValue()); + // read does not change the marK + test.read(); + assertThat(test.currentIn, Matchers.is(mockIn)); + // mark reference is still unchanged + assertThat(test.markIn, Matchers.nullValue()); + // read reaches end + when(mockIn.read()).thenReturn(-1); + test.read(); + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + verify(mockIn).close(); + // mark reference is still unchanged + assertThat(test.markIn, Matchers.nullValue()); + } + + public void testMarkAtEnding() throws Exception { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + when(mockIn.read()).thenAnswer(invocationOnMock -> randomFrom(-1, randomInt(255))); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return mockIn; + } else { + return null; + } + } + }; + // read all bytes + while (test.read() != -1) { + } + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark is null (beginning) + assertThat(test.markIn, Matchers.nullValue()); + test.mark(randomInt(255)); + assertThat(test.markIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // another mark is a no-op + test.mark(randomInt(255)); + assertThat(test.markIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + } + + public void testSingleMarkAnywhere() throws Exception { + Supplier mockInputStreamSupplier = () -> { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + try { + when(mockIn.read()).thenAnswer(invocationOnMock -> randomFrom(-1, randomInt(1))); + when(mockIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + if (randomBoolean()) { + return -1; + } else { + // partial read return + return randomIntBetween(1, len); + } + } + }); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return mockIn; + }; + AtomicBoolean chainingInputStreamEOF = new AtomicBoolean(false); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (chainingInputStreamEOF.get()) { + return null; + } else { + return mockInputStreamSupplier.get(); + } + } + }; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(4, 16); i++) { + test.readNBytes(randomInt(63)); + } + InputStream currentIn = test.currentIn; + int readLimit = randomInt(63); + test.mark(readLimit); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(currentIn).mark(Mockito.eq(readLimit)); + // mark again, same position + int readLimit2 = randomInt(63); + test.mark(readLimit2); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + if (readLimit != readLimit2) { + verify(currentIn).mark(Mockito.eq(readLimit2)); + } else { + verify(currentIn, times(2)).mark(Mockito.eq(readLimit)); + } + // read more (possibly moving on to a new component) + test.readNBytes(randomInt(63)); + // mark does not budge + assertThat(test.markIn, Matchers.is(currentIn)); + // read until the end + chainingInputStreamEOF.set(true); + test.readAllBytes(); + // current component is at the end + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark is still put + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.markIn, never()).close(); + // but close also closes the mark + test.close(); + verify(test.markIn).close(); + } + + public void testMarkOverwritesPreviousMark() throws Exception { + AtomicBoolean chainingInputStreamEOF = new AtomicBoolean(false); + Supplier mockInputStreamSupplier = () -> { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + try { + // single byte read never returns "-1" so it never advances component + when(mockIn.read()).thenAnswer(invocationOnMock -> randomInt(255)); + when(mockIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + if (randomBoolean()) { + return -1; + } else { + // partial read return + return randomIntBetween(1, len); + } + } + }); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return mockIn; + }; + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (chainingInputStreamEOF.get()) { + return null; + } else { + return mockInputStreamSupplier.get(); + } + } + }; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(4, 16); i++) { + test.readNBytes(randomInt(63)); + } + InputStream currentIn = test.currentIn; + int readLimit = randomInt(63); + test.mark(readLimit); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.markIn).mark(Mockito.eq(readLimit)); + // read more within the same component + for (int i = 0; i < randomIntBetween(4, 16); i++) { + test.read(); + } + // mark does not budge + assertThat(test.markIn, Matchers.is(currentIn)); + // mark again + int readLimit2 = randomInt(63); + test.mark(readLimit2); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(currentIn, never()).close(); + if (readLimit != readLimit2) { + verify(currentIn).mark(Mockito.eq(readLimit2)); + } else { + verify(currentIn, times(2)).mark(Mockito.eq(readLimit)); + } + // read more while switching the component + for (int i = 0; i < randomIntBetween(4, 16) || test.currentIn == currentIn; i++) { + test.readNBytes(randomInt(63)); + } + // mark does not budge + assertThat(test.markIn, Matchers.is(currentIn)); + // mark again + readLimit = randomInt(63); + test.mark(readLimit); + assertThat(test.markIn, Matchers.is(test.currentIn)); + // previous mark closed + verify(currentIn).close(); + verify(test.markIn).mark(Mockito.eq(readLimit)); + InputStream markIn = test.markIn; + // read until the end + chainingInputStreamEOF.set(true); + test.readAllBytes(); + // current component is at the end + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark is still put + assertThat(test.markIn, Matchers.is(markIn)); + verify(test.markIn, never()).close(); + // mark at the end + readLimit = randomInt(63); + test.mark(readLimit); + assertThat(test.markIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + verify(markIn).close(); + } + + public void testResetAtBeginning() throws Exception { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + when(mockIn.read()).thenAnswer(invocationOnMock -> randomInt(255)); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return mockIn; + } else { + return null; + } + } + }; + assertThat(test.currentIn, Matchers.nullValue()); + assertThat(test.markIn, Matchers.nullValue()); + if (randomBoolean()) { + // mark at the beginning + test.mark(randomInt(63)); + assertThat(test.markIn, Matchers.nullValue()); + } + // reset immediately + test.reset(); + assertThat(test.currentIn, Matchers.nullValue()); + // read does not change the marK + test.read(); + assertThat(test.currentIn, Matchers.is(mockIn)); + // mark reference is still unchanged + assertThat(test.markIn, Matchers.nullValue()); + // reset back to beginning + test.reset(); + verify(mockIn).close(); + assertThat(test.currentIn, Matchers.nullValue()); + // read reaches end + when(mockIn.read()).thenReturn(-1); + test.read(); + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark reference is still unchanged + assertThat(test.markIn, Matchers.nullValue()); + // reset back to beginning + test.reset(); + assertThat(test.currentIn, Matchers.nullValue()); + } + + public void testResetAtEnding() throws Exception { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + when(mockIn.read()).thenAnswer(invocationOnMock -> randomFrom(-1, randomInt(255))); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return mockIn; + } else { + return null; + } + } + }; + // read all bytes + while (test.read() != -1) { + } + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark is null (beginning) + assertThat(test.markIn, Matchers.nullValue()); + test.mark(randomInt(255)); + assertThat(test.markIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // reset + test.reset(); + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + assertThat(test.read(), Matchers.is(-1)); + // another mark is a no-op + test.mark(randomInt(255)); + assertThat(test.markIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + assertThat(test.read(), Matchers.is(-1)); + } + + public void testResetForSingleMarkAnywhere() throws Exception { + Supplier mockInputStreamSupplier = () -> { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + try { + // single byte read never returns "-1" so it never advances component + when(mockIn.read()).thenAnswer(invocationOnMock -> randomInt(255)); + when(mockIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + if (randomBoolean()) { + return -1; + } else { + // partial read return + return randomIntBetween(1, len); + } + } + }); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return mockIn; + }; + AtomicBoolean chainingInputStreamEOF = new AtomicBoolean(false); + AtomicReference nextComponentArg = new AtomicReference<>(); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (nextComponentArg.get() != null) { + Assert.assertThat(currentComponentIn, Matchers.is(nextComponentArg.get())); + nextComponentArg.set(null); + } + if (chainingInputStreamEOF.get()) { + return null; + } else { + return mockInputStreamSupplier.get(); + } + } + }; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(4, 16); i++) { + test.readNBytes(randomInt(63)); + } + InputStream currentIn = test.currentIn; + int readLimit = randomInt(63); + test.mark(readLimit); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(currentIn).mark(Mockito.eq(readLimit)); + // read more without moving to a new component + for (int i = 0; i < randomIntBetween(4, 16); i++) { + test.read(); + } + // first reset + test.reset(); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.currentIn, never()).close(); + verify(test.currentIn).reset(); + // read more, moving on to a new component + for (int i = 0; i < randomIntBetween(4, 16) || test.currentIn == currentIn; i++) { + test.readNBytes(randomInt(63)); + } + // mark does not budge + assertThat(test.markIn, Matchers.is(currentIn)); + assertThat(test.currentIn, Matchers.not(currentIn)); + InputStream lastCurrentIn = test.currentIn; + // second reset + test.reset(); + verify(lastCurrentIn).close(); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.currentIn, times(2)).reset(); + // assert the "nextComponent" argument + nextComponentArg.set(currentIn); + // read more, moving on to a new component + for (int i = 0; i < randomIntBetween(4, 16) || test.currentIn == currentIn; i++) { + test.readNBytes(randomInt(63)); + } + // read until the end + chainingInputStreamEOF.set(true); + test.readAllBytes(); + // current component is at the end + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark is still put + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.markIn, never()).close(); + // reset when stream is at the end + test.reset(); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.currentIn, times(3)).reset(); + // assert the "nextComponent" argument + nextComponentArg.set(currentIn); + // read more to verify that current component is passed as nextComponent argument + test.readAllBytes(); + } + + public void testResetForDoubleMarkAnywhere() throws Exception { + Supplier mockInputStreamSupplier = () -> { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + try { + // single byte read never returns "-1" so it never advances component + when(mockIn.read()).thenAnswer(invocationOnMock -> randomInt(255)); + when(mockIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + if (randomBoolean()) { + return -1; + } else { + // partial read return + return randomIntBetween(1, len); + } + } + }); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return mockIn; + }; + AtomicBoolean chainingInputStreamEOF = new AtomicBoolean(false); + AtomicReference nextComponentArg = new AtomicReference<>(); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (nextComponentArg.get() != null) { + Assert.assertThat(currentComponentIn, Matchers.is(nextComponentArg.get())); + nextComponentArg.set(null); + } + if (chainingInputStreamEOF.get()) { + return null; + } else { + return mockInputStreamSupplier.get(); + } + } + }; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(4, 16); i++) { + test.readNBytes(randomInt(63)); + } + InputStream currentIn = test.currentIn; + int readLimit = randomInt(63); + // first mark + test.mark(readLimit); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(currentIn).mark(Mockito.eq(readLimit)); + // possibly skips over several components + for (int i = 0; i < randomIntBetween(1, 2); i++) { + test.readNBytes(randomInt(63)); + } + InputStream lastCurrentIn = test.currentIn; + // second mark + readLimit = randomInt(63); + test.mark(readLimit); + if (lastCurrentIn != currentIn) { + verify(currentIn).close(); + } + assertThat(test.currentIn, Matchers.is(lastCurrentIn)); + assertThat(test.markIn, Matchers.is(lastCurrentIn)); + verify(lastCurrentIn).mark(Mockito.eq(readLimit)); + currentIn = lastCurrentIn; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(1, 2); i++) { + test.readNBytes(randomInt(63)); + } + lastCurrentIn = test.currentIn; + // reset + test.reset(); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + if (lastCurrentIn != currentIn) { + verify(lastCurrentIn).close(); + } + verify(currentIn).reset(); + // assert the "nextComponet" arg is the current component + nextComponentArg.set(currentIn); + // possibly skips over several components + for (int i = 0; i < randomIntBetween(4, 16); i++) { + test.readNBytes(randomInt(63)); + } + lastCurrentIn = test.currentIn; + // third mark after reset + readLimit = randomInt(63); + test.mark(readLimit); + if (lastCurrentIn != currentIn) { + verify(currentIn).close(); + } + assertThat(test.currentIn, Matchers.is(lastCurrentIn)); + assertThat(test.markIn, Matchers.is(lastCurrentIn)); + verify(lastCurrentIn).mark(Mockito.eq(readLimit)); + nextComponentArg.set(lastCurrentIn); + currentIn = lastCurrentIn; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(1, 2); i++) { + test.readNBytes(randomInt(63)); + } + lastCurrentIn = test.currentIn; + // reset after mark after reset + test.reset(); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + if (lastCurrentIn != currentIn) { + verify(lastCurrentIn).close(); + } + verify(currentIn).reset(); + } + + public void testMarkAfterResetNoMock() throws Exception { + int len = randomIntBetween(8, 15); + byte[] b = randomByteArrayOfLength(len); + for (int p = 0; p <= len; p++) { + for (int mark1 = 0; mark1 < len; mark1++) { + for (int offset1 = 0; offset1 < len - mark1; offset1++) { + for (int mark2 = 0; mark2 < len - mark1; mark2++) { + for (int offset2 = 0; offset2 < len - mark1 - mark2; offset2++) { + final int pivot = p; + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return new TestInputStream(b, 0, pivot, 1); + } else if (((TestInputStream) currentComponentIn).label == 1) { + return new TestInputStream(b, pivot, len - pivot, 2); + } else if (((TestInputStream) currentComponentIn).label == 2) { + return null; + } else { + throw new IllegalStateException(); + } + } + }; + // read "mark1" bytes + byte[] pre = test.readNBytes(mark1); + for (int i = 0; i < pre.length; i++) { + assertThat(pre[i], Matchers.is(b[i])); + } + // first mark + test.mark(len); + // read "offset" bytes + byte[] span1 = test.readNBytes(offset1); + for (int i = 0; i < span1.length; i++) { + assertThat(span1[i], Matchers.is(b[mark1 + i])); + } + // reset back to "mark1" offset + test.reset(); + // read/replay "mark2" bytes + byte[] span2 = test.readNBytes(mark2); + for (int i = 0; i < span2.length; i++) { + assertThat(span2[i], Matchers.is(b[mark1 + i])); + } + // second mark + test.mark(len); + byte[] span3 = test.readNBytes(offset2); + for (int i = 0; i < span3.length; i++) { + assertThat(span3[i], Matchers.is(b[mark1 + mark2 + i])); + } + // reset to second mark + test.reset(); + // read rest of bytes + byte[] span4 = test.readAllBytes(); + for (int i = 0; i < span4.length; i++) { + assertThat(span4[i], Matchers.is(b[mark1 + mark2 + i])); + } + } + } + } + } + } + } + + private byte[] concatenateArrays(byte[] b1, byte[] b2) { + byte[] result = new byte[b1.length + b2.length]; + System.arraycopy(b1, 0, result, 0, b1.length); + System.arraycopy(b2, 0, result, b1.length, b2.length); + return result; + } + + private Tuple testEmptyComponentsInChain(int componentCount, List emptyComponentIndices) + throws Exception { + byte[] result = new byte[0]; + InputStream[] sourceComponents = new InputStream[componentCount]; + for (int i = 0; i < componentCount; i++) { + if (emptyComponentIndices.contains(i)) { + sourceComponents[i] = InputStream.nullInputStream(); + } else { + byte[] b = randomByteArrayOfLength(randomIntBetween(1, 8)); + sourceComponents[i] = new ByteArrayInputStream(b); + result = concatenateArrays(result, b); + } + } + return new Tuple<>(new ChainingInputStream() { + int i = 0; + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (i < sourceComponents.length) { + return sourceComponents[i++]; + } else { + return null; + } + } + + @Override + public boolean markSupported() { + return false; + } + }, result); + } + + private ChainingInputStream newEmptyStream(boolean hasEmptyComponents) { + if (hasEmptyComponents) { + final Iterator iterator = Arrays.asList( + randomArray(1, 5, ByteArrayInputStream[]::new, () -> new ByteArrayInputStream(new byte[0])) + ).iterator(); + return new ChainingInputStream() { + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (iterator.hasNext()) { + return iterator.next(); + } else { + return null; + } + } + }; + } else { + return new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentElementIn) throws IOException { + return null; + } + }; + } + } + + static class TestInputStream extends InputStream { + + final byte[] b; + final int label; + final int len; + int i = 0; + int mark = -1; + final AtomicBoolean closed = new AtomicBoolean(false); + + TestInputStream(byte[] b) { + this(b, 0, b.length, 0); + } + + TestInputStream(byte[] b, int label) { + this(b, 0, b.length, label); + } + + TestInputStream(byte[] b, int offset, int len, int label) { + this.b = b; + this.i = offset; + this.len = len; + this.label = label; + } + + @Override + public int read() throws IOException { + if (b == null || i >= len) { + return -1; + } + return b[i++] & 0xFF; + } + + @Override + public void close() throws IOException { + closed.set(true); + } + + @Override + public void mark(int readlimit) { + this.mark = i; + } + + @Override + public void reset() { + this.i = this.mark; + } + + @Override + public boolean markSupported() { + return true; + } + + } +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/CountingInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/CountingInputStreamTests.java new file mode 100644 index 0000000000000..894e6b4cb088c --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/CountingInputStreamTests.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class CountingInputStreamTests extends ESTestCase { + + private static byte[] testArray; + + @BeforeClass + static void createTestArray() throws Exception { + testArray = new byte[32]; + for (int i = 0; i < testArray.length; i++) { + testArray[i] = (byte) i; + } + } + + public void testWrappedMarkAndClose() throws Exception { + AtomicBoolean isClosed = new AtomicBoolean(false); + InputStream mockIn = mock(InputStream.class); + doAnswer(new Answer() { + public Void answer(InvocationOnMock invocation) { + isClosed.set(true); + return null; + } + }).when(mockIn).close(); + new CountingInputStream(mockIn, true).close(); + assertThat(isClosed.get(), Matchers.is(true)); + isClosed.set(false); + new CountingInputStream(mockIn, false).close(); + assertThat(isClosed.get(), Matchers.is(false)); + when(mockIn.markSupported()).thenAnswer(invocationOnMock -> { return false; }); + assertThat(new CountingInputStream(mockIn, randomBoolean()).markSupported(), Matchers.is(false)); + when(mockIn.markSupported()).thenAnswer(invocationOnMock -> { return true; }); + assertThat(new CountingInputStream(mockIn, randomBoolean()).markSupported(), Matchers.is(true)); + } + + public void testSimpleCountForRead() throws Exception { + CountingInputStream test = new CountingInputStream(new ByteArrayInputStream(testArray), randomBoolean()); + assertThat(test.getCount(), Matchers.is(0L)); + int readLen = Randomness.get().nextInt(testArray.length); + test.readNBytes(readLen); + assertThat(test.getCount(), Matchers.is((long) readLen)); + readLen = testArray.length - readLen; + test.readNBytes(readLen); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + test.close(); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + } + + public void testSimpleCountForSkip() throws Exception { + CountingInputStream test = new CountingInputStream(new ByteArrayInputStream(testArray), randomBoolean()); + assertThat(test.getCount(), Matchers.is(0L)); + int skipLen = Randomness.get().nextInt(testArray.length); + test.skip(skipLen); + assertThat(test.getCount(), Matchers.is((long) skipLen)); + skipLen = testArray.length - skipLen; + test.readNBytes(skipLen); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + test.close(); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + } + + public void testCountingForMarkAndReset() throws Exception { + CountingInputStream test = new CountingInputStream(new ByteArrayInputStream(testArray), randomBoolean()); + assertThat(test.getCount(), Matchers.is(0L)); + assertThat(test.markSupported(), Matchers.is(true)); + int offset1 = Randomness.get().nextInt(testArray.length - 1); + if (randomBoolean()) { + test.skip(offset1); + } else { + test.read(new byte[offset1]); + } + assertThat(test.getCount(), Matchers.is((long) offset1)); + test.mark(testArray.length); + int offset2 = 1 + Randomness.get().nextInt(testArray.length - offset1 - 1); + if (randomBoolean()) { + test.skip(offset2); + } else { + test.read(new byte[offset2]); + } + assertThat(test.getCount(), Matchers.is((long) offset1 + offset2)); + test.reset(); + assertThat(test.getCount(), Matchers.is((long) offset1)); + int offset3 = Randomness.get().nextInt(offset2); + if (randomBoolean()) { + test.skip(offset3); + } else { + test.read(new byte[offset3]); + } + assertThat(test.getCount(), Matchers.is((long) offset1 + offset3)); + test.reset(); + assertThat(test.getCount(), Matchers.is((long) offset1)); + test.readAllBytes(); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + test.close(); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + } + + public void testCountingForMarkAfterReset() throws Exception { + CountingInputStream test = new CountingInputStream(new ByteArrayInputStream(testArray), randomBoolean()); + assertThat(test.getCount(), Matchers.is(0L)); + assertThat(test.markSupported(), Matchers.is(true)); + int offset1 = Randomness.get().nextInt(testArray.length - 1); + if (randomBoolean()) { + test.skip(offset1); + } else { + test.read(new byte[offset1]); + } + assertThat(test.getCount(), Matchers.is((long) offset1)); + test.mark(testArray.length); + int offset2 = 1 + Randomness.get().nextInt(testArray.length - offset1 - 1); + if (randomBoolean()) { + test.skip(offset2); + } else { + test.read(new byte[offset2]); + } + assertThat(test.getCount(), Matchers.is((long) offset1 + offset2)); + test.reset(); + assertThat(test.getCount(), Matchers.is((long) offset1)); + int offset3 = Randomness.get().nextInt(offset2); + if (randomBoolean()) { + test.skip(offset3); + } else { + test.read(new byte[offset3]); + } + test.mark(testArray.length); + assertThat(test.getCount(), Matchers.is((long) offset1 + offset3)); + int offset4 = Randomness.get().nextInt(testArray.length - offset1 - offset3); + if (randomBoolean()) { + test.skip(offset4); + } else { + test.read(new byte[offset4]); + } + assertThat(test.getCount(), Matchers.is((long) offset1 + offset3 + offset4)); + test.reset(); + assertThat(test.getCount(), Matchers.is((long) offset1 + offset3)); + test.readAllBytes(); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + test.close(); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStreamTests.java new file mode 100644 index 0000000000000..66470ce874007 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStreamTests.java @@ -0,0 +1,198 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; +import java.io.ByteArrayInputStream; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.security.SecureRandom; +import java.util.Arrays; + +public class DecryptionPacketsInputStreamTests extends ESTestCase { + + public void testSuccessEncryptAndDecryptSmallPacketLength() throws Exception { + int len = 8 + Randomness.get().nextInt(8); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey secretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + for (int packetLen : Arrays.asList(1, 2, 3, 4)) { + testEncryptAndDecryptSuccess(plainBytes, secretKey, nonce, packetLen); + } + } + + public void testSuccessEncryptAndDecryptLargePacketLength() throws Exception { + int len = 256 + Randomness.get().nextInt(256); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey secretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + for (int packetLen : Arrays.asList(len - 1, len - 2, len - 3, len - 4)) { + testEncryptAndDecryptSuccess(plainBytes, secretKey, nonce, packetLen); + } + } + + public void testSuccessEncryptAndDecryptTypicalPacketLength() throws Exception { + int len = 1024 + Randomness.get().nextInt(512); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey secretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + for (int packetLen : Arrays.asList(128, 256, 512)) { + testEncryptAndDecryptSuccess(plainBytes, secretKey, nonce, packetLen); + } + } + + public void testFailureEncryptAndDecryptWrongKey() throws Exception { + int len = 256 + Randomness.get().nextInt(256); + // 2-3 packets + int packetLen = 1 + Randomness.get().nextInt(len / 2); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey encryptSecretKey = generateSecretKey(); + SecretKey decryptSecretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + byte[] encryptedBytes; + try ( + InputStream in = new EncryptionPacketsInputStream( + new ByteArrayInputStream(plainBytes, 0, len), + encryptSecretKey, + nonce, + packetLen + ) + ) { + encryptedBytes = in.readAllBytes(); + } + try (InputStream in = new DecryptionPacketsInputStream(new ByteArrayInputStream(encryptedBytes), decryptSecretKey, packetLen)) { + IOException e = expectThrows(IOException.class, () -> { in.readAllBytes(); }); + assertThat(e.getMessage(), Matchers.is("Exception during packet decryption")); + } + } + + public void testFailureEncryptAndDecryptAlteredCiphertext() throws Exception { + int len = 8 + Randomness.get().nextInt(8); + // one packet + int packetLen = len + Randomness.get().nextInt(8); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey secretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + byte[] encryptedBytes; + try (InputStream in = new EncryptionPacketsInputStream(new ByteArrayInputStream(plainBytes, 0, len), secretKey, nonce, packetLen)) { + encryptedBytes = in.readAllBytes(); + } + for (int i = EncryptedRepository.GCM_IV_LENGTH_IN_BYTES; i < EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + len + + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; i++) { + for (int j = 0; j < 8; j++) { + // flip bit + encryptedBytes[i] ^= (1 << j); + // fail decryption + try (InputStream in = new DecryptionPacketsInputStream(new ByteArrayInputStream(encryptedBytes), secretKey, packetLen)) { + IOException e = expectThrows(IOException.class, () -> { in.readAllBytes(); }); + assertThat(e.getMessage(), Matchers.is("Exception during packet decryption")); + } + // flip bit back + encryptedBytes[i] ^= (1 << j); + } + } + } + + public void testFailureEncryptAndDecryptAlteredCiphertextIV() throws Exception { + int len = 8 + Randomness.get().nextInt(8); + int packetLen = 4 + Randomness.get().nextInt(4); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey secretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + byte[] encryptedBytes; + try (InputStream in = new EncryptionPacketsInputStream(new ByteArrayInputStream(plainBytes, 0, len), secretKey, nonce, packetLen)) { + encryptedBytes = in.readAllBytes(); + } + assertThat(encryptedBytes.length, Matchers.is((int) EncryptionPacketsInputStream.getEncryptionLength(len, packetLen))); + int encryptedPacketLen = EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + packetLen + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + for (int i = 0; i < encryptedBytes.length; i += encryptedPacketLen) { + for (int j = 0; j < EncryptedRepository.GCM_IV_LENGTH_IN_BYTES; j++) { + for (int k = 0; k < 8; k++) { + // flip bit + encryptedBytes[i + j] ^= (1 << k); + try ( + InputStream in = new DecryptionPacketsInputStream(new ByteArrayInputStream(encryptedBytes), secretKey, packetLen) + ) { + IOException e = expectThrows(IOException.class, () -> { in.readAllBytes(); }); + if (j < Integer.BYTES) { + assertThat(e.getMessage(), Matchers.startsWith("Exception during packet decryption")); + } else { + assertThat(e.getMessage(), Matchers.startsWith("Packet counter mismatch")); + } + } + // flip bit back + encryptedBytes[i + j] ^= (1 << k); + } + } + } + } + + private void testEncryptAndDecryptSuccess(byte[] plainBytes, SecretKey secretKey, int nonce, int packetLen) throws Exception { + for (int len = 0; len <= plainBytes.length; len++) { + byte[] encryptedBytes; + try ( + InputStream in = new EncryptionPacketsInputStream(new ByteArrayInputStream(plainBytes, 0, len), secretKey, nonce, packetLen) + ) { + encryptedBytes = in.readAllBytes(); + } + assertThat((long) encryptedBytes.length, Matchers.is(EncryptionPacketsInputStream.getEncryptionLength(len, packetLen))); + byte[] decryptedBytes; + try ( + InputStream in = new DecryptionPacketsInputStream( + new ReadLessFilterInputStream(new ByteArrayInputStream(encryptedBytes)), + secretKey, + packetLen + ) + ) { + decryptedBytes = in.readAllBytes(); + } + assertThat(decryptedBytes.length, Matchers.is(len)); + assertThat( + (long) decryptedBytes.length, + Matchers.is(DecryptionPacketsInputStream.getDecryptionLength(encryptedBytes.length, packetLen)) + ); + for (int i = 0; i < len; i++) { + assertThat(decryptedBytes[i], Matchers.is(plainBytes[i])); + } + } + } + + // input stream that reads less bytes than asked to, testing that packet-wide reads don't rely on `read` calls for memory buffers which + // always return the same number of bytes they are asked to + private static class ReadLessFilterInputStream extends FilterInputStream { + + protected ReadLessFilterInputStream(InputStream in) { + super(in); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return 0; + } + return super.read(b, off, randomIntBetween(1, len)); + } + } + + private SecretKey generateSecretKey() throws Exception { + KeyGenerator keyGen = KeyGenerator.getInstance("AES"); + keyGen.init(256, new SecureRandom()); + return keyGen.generateKey(); + } +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryTests.java new file mode 100644 index 0000000000000..6bb46603f7ba9 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryTests.java @@ -0,0 +1,175 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.cluster.metadata.RepositoryMetadata; +import org.elasticsearch.cluster.service.ClusterApplierService; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.blobstore.BlobContainer; +import org.elasticsearch.common.blobstore.BlobPath; +import org.elasticsearch.common.blobstore.BlobStore; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.indices.recovery.RecoverySettings; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.repositories.RepositoryException; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.Before; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyBoolean; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class EncryptedRepositoryTests extends ESTestCase { + + private SecureString repoPassword; + private BlobPath delegatedPath; + private BlobStore delegatedBlobStore; + private BlobStoreRepository delegatedRepository; + private RepositoryMetadata repositoryMetadata; + private EncryptedRepository encryptedRepository; + private EncryptedRepository.EncryptedBlobStore encryptedBlobStore; + private Map blobsMap; + + @Before + public void setUpMocks() throws Exception { + this.repoPassword = new SecureString(randomAlphaOfLength(20).toCharArray()); + this.delegatedPath = randomFrom( + BlobPath.cleanPath(), + BlobPath.cleanPath().add(randomAlphaOfLength(8)), + BlobPath.cleanPath().add(randomAlphaOfLength(4)).add(randomAlphaOfLength(4)) + ); + this.delegatedBlobStore = mock(BlobStore.class); + this.delegatedRepository = mock(BlobStoreRepository.class); + when(delegatedRepository.blobStore()).thenReturn(delegatedBlobStore); + when(delegatedRepository.basePath()).thenReturn(delegatedPath); + this.repositoryMetadata = new RepositoryMetadata( + randomAlphaOfLength(4), + EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME, + Settings.EMPTY + ); + ClusterApplierService clusterApplierService = mock(ClusterApplierService.class); + when(clusterApplierService.threadPool()).thenReturn(mock(ThreadPool.class)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterApplierService()).thenReturn(clusterApplierService); + this.encryptedRepository = new EncryptedRepository( + repositoryMetadata, + mock(NamedXContentRegistry.class), + clusterService, + mock(BigArrays.class), + mock(RecoverySettings.class), + delegatedRepository, + () -> mock(XPackLicenseState.class), + repoPassword + ); + this.encryptedBlobStore = (EncryptedRepository.EncryptedBlobStore) encryptedRepository.createBlobStore(); + this.blobsMap = new HashMap<>(); + doAnswer(invocationOnMockBlobStore -> { + BlobPath blobPath = ((BlobPath) invocationOnMockBlobStore.getArguments()[0]); + BlobContainer blobContainer = mock(BlobContainer.class); + // write atomic + doAnswer(invocationOnMockBlobContainer -> { + String DEKId = ((String) invocationOnMockBlobContainer.getArguments()[0]); + BytesReference DEKBytesReference = ((BytesReference) invocationOnMockBlobContainer.getArguments()[1]); + this.blobsMap.put(blobPath.add(DEKId), BytesReference.toBytes(DEKBytesReference)); + return null; + }).when(blobContainer).writeBlobAtomic(any(String.class), any(BytesReference.class), anyBoolean()); + // read + doAnswer(invocationOnMockBlobContainer -> { + String DEKId = ((String) invocationOnMockBlobContainer.getArguments()[0]); + return new ByteArrayInputStream(blobsMap.get(blobPath.add(DEKId))); + }).when(blobContainer).readBlob(any(String.class)); + return blobContainer; + }).when(this.delegatedBlobStore).blobContainer(any(BlobPath.class)); + } + + public void testStoreDEKSuccess() throws Exception { + String DEKId = randomAlphaOfLengthBetween(16, 32); // at least 128 bits because of FIPS + SecretKey DEK = new SecretKeySpec(randomByteArrayOfLength(32), "AES"); + + encryptedBlobStore.storeDEK(DEKId, DEK); + + Tuple KEK = encryptedRepository.generateKEK(DEKId); + assertThat(blobsMap.keySet(), contains(delegatedPath.add(EncryptedRepository.DEK_ROOT_CONTAINER).add(DEKId).add(KEK.v1()))); + byte[] wrappedKey = blobsMap.values().iterator().next(); + SecretKey unwrappedKey = AESKeyUtils.unwrap(KEK.v2(), wrappedKey); + assertThat(unwrappedKey.getEncoded(), equalTo(DEK.getEncoded())); + } + + public void testGetDEKSuccess() throws Exception { + String DEKId = randomAlphaOfLengthBetween(16, 32); // at least 128 bits because of FIPS + SecretKey DEK = new SecretKeySpec(randomByteArrayOfLength(32), "AES"); + Tuple KEK = encryptedRepository.generateKEK(DEKId); + + byte[] wrappedDEK = AESKeyUtils.wrap(KEK.v2(), DEK); + blobsMap.put(delegatedPath.add(EncryptedRepository.DEK_ROOT_CONTAINER).add(DEKId).add(KEK.v1()), wrappedDEK); + + SecretKey loadedDEK = encryptedBlobStore.getDEKById(DEKId); + assertThat(loadedDEK.getEncoded(), equalTo(DEK.getEncoded())); + } + + public void testGetTamperedDEKFails() throws Exception { + String DEKId = randomAlphaOfLengthBetween(16, 32); // at least 128 bits because of FIPS + SecretKey DEK = new SecretKeySpec("01234567890123456789012345678901".getBytes(StandardCharsets.UTF_8), "AES"); + Tuple KEK = encryptedRepository.generateKEK(DEKId); + + byte[] wrappedDEK = AESKeyUtils.wrap(KEK.v2(), DEK); + int tamperPos = randomIntBetween(0, wrappedDEK.length - 1); + wrappedDEK[tamperPos] ^= 0xFF; + blobsMap.put(delegatedPath.add(EncryptedRepository.DEK_ROOT_CONTAINER).add(DEKId).add(KEK.v1()), wrappedDEK); + + RepositoryException e = expectThrows(RepositoryException.class, () -> encryptedBlobStore.getDEKById(DEKId)); + assertThat(e.repository(), equalTo(repositoryMetadata.name())); + assertThat(e.getMessage(), containsString("Failure to AES unwrap the DEK")); + } + + public void testGetDEKIOException() { + doAnswer(invocationOnMockBlobStore -> { + BlobPath blobPath = ((BlobPath) invocationOnMockBlobStore.getArguments()[0]); + BlobContainer blobContainer = mock(BlobContainer.class); + // read + doAnswer(invocationOnMockBlobContainer -> { throw new IOException("Tested IOException"); }).when(blobContainer) + .readBlob(any(String.class)); + return blobContainer; + }).when(this.delegatedBlobStore).blobContainer(any(BlobPath.class)); + IOException e = expectThrows(IOException.class, () -> encryptedBlobStore.getDEKById("this must be at least 16")); + assertThat(e.getMessage(), containsString("Tested IOException")); + } + + public void testGenerateKEK() { + String id1 = "fixed identifier 1"; + String id2 = "fixed identifier 2"; + Tuple KEK1 = encryptedRepository.generateKEK(id1); + Tuple KEK2 = encryptedRepository.generateKEK(id2); + assertThat(KEK1.v1(), not(equalTo(KEK2.v1()))); + assertThat(KEK1.v2(), not(equalTo(KEK2.v2()))); + Tuple sameKEK1 = encryptedRepository.generateKEK(id1); + assertThat(KEK1.v1(), equalTo(sameKEK1.v1())); + assertThat(KEK1.v2(), equalTo(sameKEK1.v2())); + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStreamTests.java new file mode 100644 index 0000000000000..e2d8a9b0fd64d --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStreamTests.java @@ -0,0 +1,536 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import org.mockito.Mockito; + +import javax.crypto.Cipher; +import javax.crypto.CipherInputStream; +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; +import javax.crypto.spec.GCMParameterSpec; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.security.SecureRandom; +import java.util.Arrays; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class EncryptionPacketsInputStreamTests extends ESTestCase { + + private static int TEST_ARRAY_SIZE = 1 << 20; + private static byte[] testPlaintextArray; + private static SecretKey secretKey; + + @BeforeClass + static void createSecretKeyAndTestArray() throws Exception { + try { + KeyGenerator keyGen = KeyGenerator.getInstance("AES"); + keyGen.init(256, new SecureRandom()); + secretKey = keyGen.generateKey(); + } catch (Exception e) { + throw new RuntimeException(e); + } + testPlaintextArray = new byte[TEST_ARRAY_SIZE]; + Randomness.get().nextBytes(testPlaintextArray); + } + + public void testEmpty() throws Exception { + int packetSize = 1 + Randomness.get().nextInt(2048); + testEncryptPacketWise(0, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSingleByteSize() throws Exception { + testEncryptPacketWise(1, 1, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(1, 2, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(1, 3, new DefaultBufferedReadAllStrategy()); + int packetSize = 4 + Randomness.get().nextInt(2046); + testEncryptPacketWise(1, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSizeSmallerThanPacketSize() throws Exception { + int packetSize = 3 + Randomness.get().nextInt(2045); + testEncryptPacketWise(packetSize - 1, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(packetSize - 2, packetSize, new DefaultBufferedReadAllStrategy()); + int size = 1 + Randomness.get().nextInt(packetSize - 1); + testEncryptPacketWise(size, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSizeEqualToPacketSize() throws Exception { + int packetSize = 1 + Randomness.get().nextInt(2048); + testEncryptPacketWise(packetSize, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSizeLargerThanPacketSize() throws Exception { + int packetSize = 1 + Randomness.get().nextInt(2048); + testEncryptPacketWise(packetSize + 1, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(packetSize + 2, packetSize, new DefaultBufferedReadAllStrategy()); + int size = packetSize + 3 + Randomness.get().nextInt(packetSize); + testEncryptPacketWise(size, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSizeMultipleOfPacketSize() throws Exception { + int packetSize = 1 + Randomness.get().nextInt(512); + testEncryptPacketWise(2 * packetSize, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(3 * packetSize, packetSize, new DefaultBufferedReadAllStrategy()); + int packetCount = 4 + Randomness.get().nextInt(12); + testEncryptPacketWise(packetCount * packetSize, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSizeAlmostMultipleOfPacketSize() throws Exception { + int packetSize = 3 + Randomness.get().nextInt(510); + int packetCount = 2 + Randomness.get().nextInt(15); + testEncryptPacketWise(packetCount * packetSize - 1, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(packetCount * packetSize - 2, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(packetCount * packetSize + 1, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(packetCount * packetSize + 2, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testShortPacketSizes() throws Exception { + int packetCount = 2 + Randomness.get().nextInt(15); + testEncryptPacketWise(2 + Randomness.get().nextInt(15), 1, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(4 + Randomness.get().nextInt(30), 2, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(6 + Randomness.get().nextInt(45), 3, new DefaultBufferedReadAllStrategy()); + } + + public void testPacketSizeMultipleOfAESBlockSize() throws Exception { + int packetSize = 1 + Randomness.get().nextInt(8); + testEncryptPacketWise( + 1 + Randomness.get().nextInt(packetSize * EncryptedRepository.AES_BLOCK_LENGTH_IN_BYTES), + packetSize * EncryptedRepository.AES_BLOCK_LENGTH_IN_BYTES, + new DefaultBufferedReadAllStrategy() + ); + testEncryptPacketWise( + packetSize * EncryptedRepository.AES_BLOCK_LENGTH_IN_BYTES + Randomness.get().nextInt(8192), + packetSize * EncryptedRepository.AES_BLOCK_LENGTH_IN_BYTES, + new DefaultBufferedReadAllStrategy() + ); + } + + public void testMarkAndResetPacketBoundaryNoMock() throws Exception { + int packetSize = 3 + Randomness.get().nextInt(512); + int size = 4 * packetSize + Randomness.get().nextInt(512); + int plaintextOffset = Randomness.get().nextInt(testPlaintextArray.length - size + 1); + int nonce = Randomness.get().nextInt(); + final byte[] referenceCiphertextArray; + try ( + InputStream encryptionInputStream = new EncryptionPacketsInputStream( + new ByteArrayInputStream(testPlaintextArray, plaintextOffset, size), + secretKey, + nonce, + packetSize + ) + ) { + referenceCiphertextArray = encryptionInputStream.readAllBytes(); + } + assertThat((long) referenceCiphertextArray.length, Matchers.is(EncryptionPacketsInputStream.getEncryptionLength(size, packetSize))); + int encryptedPacketSize = packetSize + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + try ( + InputStream encryptionInputStream = new EncryptionPacketsInputStream( + new ByteArrayInputStream(testPlaintextArray, plaintextOffset, size), + secretKey, + nonce, + packetSize + ) + ) { + // mark at the beginning + encryptionInputStream.mark(encryptedPacketSize - 1); + byte[] test = encryptionInputStream.readNBytes(1 + Randomness.get().nextInt(encryptedPacketSize)); + assertSubArray(referenceCiphertextArray, 0, test, 0, test.length); + // reset at the beginning + encryptionInputStream.reset(); + // read packet fragment + test = encryptionInputStream.readNBytes(1 + Randomness.get().nextInt(encryptedPacketSize)); + assertSubArray(referenceCiphertextArray, 0, test, 0, test.length); + // reset at the beginning + encryptionInputStream.reset(); + // read complete packet + test = encryptionInputStream.readNBytes(encryptedPacketSize); + assertSubArray(referenceCiphertextArray, 0, test, 0, test.length); + // mark at the second packet boundary + encryptionInputStream.mark(Integer.MAX_VALUE); + // read more than one packet + test = encryptionInputStream.readNBytes(encryptedPacketSize + 1 + Randomness.get().nextInt(encryptedPacketSize)); + assertSubArray(referenceCiphertextArray, encryptedPacketSize, test, 0, test.length); + // reset at the second packet boundary + encryptionInputStream.reset(); + int middlePacketOffset = Randomness.get().nextInt(encryptedPacketSize); + test = encryptionInputStream.readNBytes(middlePacketOffset); + assertSubArray(referenceCiphertextArray, encryptedPacketSize, test, 0, test.length); + // read up to the third packet boundary + test = encryptionInputStream.readNBytes(encryptedPacketSize - middlePacketOffset); + assertSubArray(referenceCiphertextArray, encryptedPacketSize + middlePacketOffset, test, 0, test.length); + // mark at the third packet boundary + encryptionInputStream.mark(Integer.MAX_VALUE); + test = encryptionInputStream.readAllBytes(); + assertSubArray(referenceCiphertextArray, 2 * encryptedPacketSize, test, 0, test.length); + encryptionInputStream.reset(); + test = encryptionInputStream.readNBytes( + 1 + Randomness.get().nextInt(referenceCiphertextArray.length - 2 * encryptedPacketSize) + ); + assertSubArray(referenceCiphertextArray, 2 * encryptedPacketSize, test, 0, test.length); + } + } + + public void testMarkResetInsidePacketNoMock() throws Exception { + int packetSize = 3 + Randomness.get().nextInt(64); + int encryptedPacketSize = EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + packetSize + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + int size = 3 * packetSize + Randomness.get().nextInt(64); + byte[] bytes = new byte[size]; + Randomness.get().nextBytes(bytes); + int nonce = Randomness.get().nextInt(); + EncryptionPacketsInputStream test = new EncryptionPacketsInputStream(new TestInputStream(bytes), secretKey, nonce, packetSize); + int offset1 = 1 + Randomness.get().nextInt(encryptedPacketSize - 1); + // read past the first packet + test.readNBytes(encryptedPacketSize + offset1); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE + 2)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset1)); + assertThat(test.markCounter, Matchers.nullValue()); + int readLimit = 1 + Randomness.get().nextInt(packetSize); + // first mark + test.mark(readLimit); + assertThat(test.markCounter, Matchers.is(test.counter)); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + assertThat(test.markIn, Matchers.is(test.currentIn)); + assertThat(((CountingInputStream) test.markIn).mark, Matchers.is((long) offset1)); + assertThat(((TestInputStream) test.source).mark, Matchers.is(-1)); + // read before packet is complete + test.readNBytes(1 + Randomness.get().nextInt(encryptedPacketSize - offset1)); + assertThat(((TestInputStream) test.source).mark, Matchers.is(-1)); + // reset + test.reset(); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + assertThat(test.counter, Matchers.is(test.markCounter)); + assertThat(test.currentIn, Matchers.is(test.markIn)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset1)); + // read before the packet is complete + int offset2 = 1 + Randomness.get().nextInt(encryptedPacketSize - offset1); + test.readNBytes(offset2); + assertThat(((TestInputStream) test.source).mark, Matchers.is(-1)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset1 + offset2)); + // second mark + readLimit = 1 + Randomness.get().nextInt(packetSize); + test.mark(readLimit); + assertThat(((TestInputStream) test.source).mark, Matchers.is(-1)); + assertThat(test.markCounter, Matchers.is(test.counter)); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + assertThat(test.markIn, Matchers.is(test.currentIn)); + assertThat(((CountingInputStream) test.markIn).mark, Matchers.is((long) offset1 + offset2)); + } + + public void testMarkResetAcrossPacketsNoMock() throws Exception { + int packetSize = 3 + Randomness.get().nextInt(64); + int encryptedPacketSize = EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + packetSize + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + int size = 3 * packetSize + Randomness.get().nextInt(64); + byte[] bytes = new byte[size]; + Randomness.get().nextBytes(bytes); + int nonce = Randomness.get().nextInt(); + EncryptionPacketsInputStream test = new EncryptionPacketsInputStream(new TestInputStream(bytes), secretKey, nonce, packetSize); + int readLimit = 2 * size + Randomness.get().nextInt(4096); + // mark at the beginning + test.mark(readLimit); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE)); + assertThat(test.markCounter, Matchers.is(Long.MIN_VALUE)); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + assertThat(test.markIn, Matchers.nullValue()); + // read past the first packet + int offset1 = 1 + Randomness.get().nextInt(encryptedPacketSize); + test.readNBytes(encryptedPacketSize + offset1); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + assertThat(((TestInputStream) test.source).mark, Matchers.is(0)); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE + 2)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset1)); + assertThat(test.markCounter, Matchers.is(Long.MIN_VALUE)); + assertThat(test.markIn, Matchers.nullValue()); + // reset at the beginning + test.reset(); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE)); + assertThat(test.currentIn, Matchers.nullValue()); + assertThat(((TestInputStream) test.source).off, Matchers.is(0)); + // read past the first two packets + int offset2 = 1 + Randomness.get().nextInt(encryptedPacketSize); + test.readNBytes(2 * encryptedPacketSize + offset2); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + assertThat(((TestInputStream) test.source).mark, Matchers.is(0)); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE + 3)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset2)); + assertThat(test.markCounter, Matchers.is(Long.MIN_VALUE)); + assertThat(test.markIn, Matchers.nullValue()); + // mark inside the third packet + test.mark(readLimit); + assertThat(test.markCounter, Matchers.is(Long.MIN_VALUE + 3)); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset2)); + assertThat(test.markIn, Matchers.is(test.currentIn)); + assertThat(((CountingInputStream) test.markIn).mark, Matchers.is((long) offset2)); + // read until the end + test.readAllBytes(); + assertThat(test.markCounter, Matchers.is(Long.MIN_VALUE + 3)); + assertThat(test.counter, Matchers.not(Long.MIN_VALUE + 3)); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + assertThat(test.markIn, Matchers.not(test.currentIn)); + assertThat(((CountingInputStream) test.markIn).mark, Matchers.is((long) offset2)); + // reset + test.reset(); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE + 3)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset2)); + assertThat(test.markIn, Matchers.is(test.currentIn)); + } + + public void testMarkAfterResetNoMock() throws Exception { + int packetSize = 4 + Randomness.get().nextInt(4); + int plainLen = packetSize + 1 + Randomness.get().nextInt(packetSize - 1); + int plaintextOffset = Randomness.get().nextInt(testPlaintextArray.length - plainLen + 1); + int nonce = Randomness.get().nextInt(); + final byte[] referenceCiphertextArray; + try ( + InputStream encryptionInputStream = new EncryptionPacketsInputStream( + new ByteArrayInputStream(testPlaintextArray, plaintextOffset, plainLen), + secretKey, + nonce, + packetSize + ) + ) { + referenceCiphertextArray = encryptionInputStream.readAllBytes(); + } + int encryptedLen = referenceCiphertextArray.length; + assertThat((long) encryptedLen, Matchers.is(EncryptionPacketsInputStream.getEncryptionLength(plainLen, packetSize))); + for (int mark1 = 0; mark1 < encryptedLen; mark1++) { + for (int offset1 = 0; offset1 < encryptedLen - mark1; offset1++) { + int mark2 = Randomness.get().nextInt(encryptedLen - mark1); + int offset2 = Randomness.get().nextInt(encryptedLen - mark1 - mark2); + EncryptionPacketsInputStream test = new EncryptionPacketsInputStream( + new ByteArrayInputStream(testPlaintextArray, plaintextOffset, plainLen), + secretKey, + nonce, + packetSize + ); + // read "mark1" bytes + byte[] pre = test.readNBytes(mark1); + for (int i = 0; i < pre.length; i++) { + assertThat(pre[i], Matchers.is(referenceCiphertextArray[i])); + } + // first mark + test.mark(encryptedLen); + // read "offset" bytes + byte[] span1 = test.readNBytes(offset1); + for (int i = 0; i < span1.length; i++) { + assertThat(span1[i], Matchers.is(referenceCiphertextArray[mark1 + i])); + } + // reset back to "mark1" offset + test.reset(); + // read/replay "mark2" bytes + byte[] span2 = test.readNBytes(mark2); + for (int i = 0; i < span2.length; i++) { + assertThat(span2[i], Matchers.is(referenceCiphertextArray[mark1 + i])); + } + // second mark + test.mark(encryptedLen); + byte[] span3 = test.readNBytes(offset2); + for (int i = 0; i < span3.length; i++) { + assertThat(span3[i], Matchers.is(referenceCiphertextArray[mark1 + mark2 + i])); + } + // reset to second mark + test.reset(); + // read rest of bytes + byte[] span4 = test.readAllBytes(); + for (int i = 0; i < span4.length; i++) { + assertThat(span4[i], Matchers.is(referenceCiphertextArray[mark1 + mark2 + i])); + } + } + } + } + + public void testMark() throws Exception { + InputStream mockSource = mock(InputStream.class); + when(mockSource.markSupported()).thenAnswer(invocationOnMock -> true); + EncryptionPacketsInputStream test = new EncryptionPacketsInputStream( + mockSource, + mock(SecretKey.class), + Randomness.get().nextInt(), + 1 + Randomness.get().nextInt(32) + ); + int readLimit = 1 + Randomness.get().nextInt(4096); + InputStream mockMarkIn = mock(InputStream.class); + test.markIn = mockMarkIn; + InputStream mockCurrentIn = mock(InputStream.class); + test.currentIn = mockCurrentIn; + test.counter = Randomness.get().nextLong(); + test.markCounter = Randomness.get().nextLong(); + test.markSourceOnNextPacket = Randomness.get().nextInt(); + // mark + test.mark(readLimit); + verify(mockMarkIn).close(); + assertThat(test.markIn, Matchers.is(mockCurrentIn)); + verify(test.markIn).mark(Mockito.anyInt()); + assertThat(test.currentIn, Matchers.is(mockCurrentIn)); + assertThat(test.markCounter, Matchers.is(test.counter)); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + } + + public void testReset() throws Exception { + InputStream mockSource = mock(InputStream.class); + when(mockSource.markSupported()).thenAnswer(invocationOnMock -> true); + EncryptionPacketsInputStream test = new EncryptionPacketsInputStream( + mockSource, + mock(SecretKey.class), + Randomness.get().nextInt(), + 1 + Randomness.get().nextInt(32) + ); + InputStream mockMarkIn = mock(InputStream.class); + test.markIn = mockMarkIn; + InputStream mockCurrentIn = mock(InputStream.class); + test.currentIn = mockCurrentIn; + test.counter = Randomness.get().nextLong(); + test.markCounter = Randomness.get().nextLong(); + // source requires reset as well + test.markSourceOnNextPacket = -1; + // reset + test.reset(); + verify(mockCurrentIn).close(); + assertThat(test.currentIn, Matchers.is(mockMarkIn)); + verify(test.currentIn).reset(); + assertThat(test.markIn, Matchers.is(mockMarkIn)); + assertThat(test.counter, Matchers.is(test.markCounter)); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + verify(mockSource).reset(); + } + + private void testEncryptPacketWise(int size, int packetSize, ReadStrategy readStrategy) throws Exception { + int encryptedPacketSize = packetSize + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + int plaintextOffset = Randomness.get().nextInt(testPlaintextArray.length - size + 1); + int nonce = Randomness.get().nextInt(); + long counter = EncryptedRepository.PACKET_START_COUNTER; + try ( + InputStream encryptionInputStream = new EncryptionPacketsInputStream( + new ByteArrayInputStream(testPlaintextArray, plaintextOffset, size), + secretKey, + nonce, + packetSize + ) + ) { + byte[] ciphertextArray = readStrategy.readAll(encryptionInputStream); + assertThat((long) ciphertextArray.length, Matchers.is(EncryptionPacketsInputStream.getEncryptionLength(size, packetSize))); + for (int ciphertextOffset = 0; ciphertextOffset < ciphertextArray.length; ciphertextOffset += encryptedPacketSize) { + ByteBuffer ivBuffer = ByteBuffer.wrap(ciphertextArray, ciphertextOffset, EncryptedRepository.GCM_IV_LENGTH_IN_BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + assertThat(ivBuffer.getInt(), Matchers.is(nonce)); + assertThat(ivBuffer.getLong(), Matchers.is(counter++)); + GCMParameterSpec gcmParameterSpec = new GCMParameterSpec( + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES * Byte.SIZE, + Arrays.copyOfRange(ciphertextArray, ciphertextOffset, ciphertextOffset + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES) + ); + Cipher packetCipher = Cipher.getInstance(EncryptedRepository.DATA_ENCRYPTION_SCHEME); + packetCipher.init(Cipher.DECRYPT_MODE, secretKey, gcmParameterSpec); + try ( + InputStream packetDecryptionInputStream = new CipherInputStream( + new ByteArrayInputStream( + ciphertextArray, + ciphertextOffset + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES, + packetSize + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES + ), + packetCipher + ) + ) { + byte[] decryptedCiphertext = packetDecryptionInputStream.readAllBytes(); + int decryptedPacketSize = size <= packetSize ? size : packetSize; + assertThat(decryptedCiphertext.length, Matchers.is(decryptedPacketSize)); + assertSubArray(decryptedCiphertext, 0, testPlaintextArray, plaintextOffset, decryptedPacketSize); + size -= decryptedPacketSize; + plaintextOffset += decryptedPacketSize; + } + } + } + } + + private void assertSubArray(byte[] arr1, int offset1, byte[] arr2, int offset2, int length) { + Objects.checkFromIndexSize(offset1, length, arr1.length); + Objects.checkFromIndexSize(offset2, length, arr2.length); + for (int i = 0; i < length; i++) { + assertThat("Mismatch at index [" + i + "]", arr1[offset1 + i], Matchers.is(arr2[offset2 + i])); + } + } + + interface ReadStrategy { + byte[] readAll(InputStream inputStream) throws IOException; + } + + static class DefaultBufferedReadAllStrategy implements ReadStrategy { + @Override + public byte[] readAll(InputStream inputStream) throws IOException { + return inputStream.readAllBytes(); + } + } + + static class TestInputStream extends InputStream { + + final byte[] b; + final int label; + final int len; + int off = 0; + int mark = -1; + final AtomicBoolean closed = new AtomicBoolean(false); + + TestInputStream(byte[] b) { + this(b, 0, b.length, 0); + } + + TestInputStream(byte[] b, int label) { + this(b, 0, b.length, label); + } + + TestInputStream(byte[] b, int offset, int len, int label) { + this.b = b; + this.off = offset; + this.len = len; + this.label = label; + } + + @Override + public int read() throws IOException { + if (b == null || off >= len) { + return -1; + } + return b[off++] & 0xFF; + } + + @Override + public void close() throws IOException { + closed.set(true); + } + + @Override + public void mark(int readlimit) { + this.mark = off; + } + + @Override + public void reset() { + this.off = this.mark; + } + + @Override + public boolean markSupported() { + return true; + } + + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/LocalStateEncryptedRepositoryPlugin.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/LocalStateEncryptedRepositoryPlugin.java new file mode 100644 index 0000000000000..7e6080ceac476 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/LocalStateEncryptedRepositoryPlugin.java @@ -0,0 +1,153 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import org.apache.lucene.index.IndexCommit; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.metadata.RepositoryMetadata; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.snapshots.IndexShardSnapshotStatus; +import org.elasticsearch.index.store.Store; +import org.elasticsearch.indices.recovery.RecoverySettings; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.repositories.IndexId; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; +import org.elasticsearch.snapshots.SnapshotId; +import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin; + +import java.nio.file.Path; +import java.security.GeneralSecurityException; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Supplier; + +public final class LocalStateEncryptedRepositoryPlugin extends LocalStateCompositeXPackPlugin { + + final EncryptedRepositoryPlugin encryptedRepositoryPlugin; + + public LocalStateEncryptedRepositoryPlugin(final Settings settings, final Path configPath) { + super(settings, configPath); + final LocalStateEncryptedRepositoryPlugin thisVar = this; + + encryptedRepositoryPlugin = new EncryptedRepositoryPlugin() { + + @Override + protected XPackLicenseState getLicenseState() { + return thisVar.getLicenseState(); + } + + @Override + protected EncryptedRepository createEncryptedRepository( + RepositoryMetadata metadata, + NamedXContentRegistry registry, + ClusterService clusterService, + BigArrays bigArrays, + RecoverySettings recoverySettings, + BlobStoreRepository delegatedRepository, + Supplier licenseStateSupplier, + SecureString repoPassword + ) throws GeneralSecurityException { + return new TestEncryptedRepository( + metadata, + registry, + clusterService, + bigArrays, + recoverySettings, + delegatedRepository, + licenseStateSupplier, + repoPassword + ); + } + }; + plugins.add(encryptedRepositoryPlugin); + } + + static class TestEncryptedRepository extends EncryptedRepository { + private final Lock snapshotShardLock = new ReentrantLock(); + private final Condition snapshotShardCondition = snapshotShardLock.newCondition(); + private final AtomicBoolean snapshotShardBlock = new AtomicBoolean(false); + + TestEncryptedRepository( + RepositoryMetadata metadata, + NamedXContentRegistry registry, + ClusterService clusterService, + BigArrays bigArrays, + RecoverySettings recoverySettings, + BlobStoreRepository delegatedRepository, + Supplier licenseStateSupplier, + SecureString repoPassword + ) throws GeneralSecurityException { + super(metadata, registry, clusterService, bigArrays, recoverySettings, delegatedRepository, licenseStateSupplier, repoPassword); + } + + @Override + public void snapshotShard( + Store store, + MapperService mapperService, + SnapshotId snapshotId, + IndexId indexId, + IndexCommit snapshotIndexCommit, + String shardStateIdentifier, + IndexShardSnapshotStatus snapshotStatus, + Version repositoryMetaVersion, + Map userMetadata, + ActionListener listener + ) { + snapshotShardLock.lock(); + try { + while (snapshotShardBlock.get()) { + snapshotShardCondition.await(); + } + super.snapshotShard( + store, + mapperService, + snapshotId, + indexId, + snapshotIndexCommit, + shardStateIdentifier, + snapshotStatus, + repositoryMetaVersion, + userMetadata, + listener + ); + } catch (InterruptedException e) { + listener.onFailure(e); + } finally { + snapshotShardLock.unlock(); + } + } + + void blockSnapshotShard() { + snapshotShardLock.lock(); + try { + snapshotShardBlock.set(true); + snapshotShardCondition.signalAll(); + } finally { + snapshotShardLock.unlock(); + } + } + + void unblockSnapshotShard() { + snapshotShardLock.lock(); + try { + snapshotShardBlock.set(false); + snapshotShardCondition.signalAll(); + } finally { + snapshotShardLock.unlock(); + } + } + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/PrefixInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/PrefixInputStreamTests.java new file mode 100644 index 0000000000000..d9063325936ff --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/PrefixInputStreamTests.java @@ -0,0 +1,222 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class PrefixInputStreamTests extends ESTestCase { + + public void testZeroLength() throws Exception { + Tuple mockTuple = getMockBoundedInputStream(0); + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), 1 + Randomness.get().nextInt(32), randomBoolean()); + assertThat(test.available(), Matchers.is(0)); + assertThat(test.read(), Matchers.is(-1)); + assertThat(test.skip(1 + Randomness.get().nextInt(32)), Matchers.is(0L)); + } + + public void testClose() throws Exception { + int boundedLength = 1 + Randomness.get().nextInt(256); + Tuple mockTuple = getMockBoundedInputStream(boundedLength); + int prefixLength = Randomness.get().nextInt(boundedLength); + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), prefixLength, randomBoolean()); + test.close(); + int byteCountBefore = mockTuple.v1().get(); + IOException e = expectThrows(IOException.class, () -> { test.read(); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { + byte[] b = new byte[1 + Randomness.get().nextInt(32)]; + test.read(b, 0, 1 + Randomness.get().nextInt(b.length)); + }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { test.skip(1 + Randomness.get().nextInt(32)); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { test.available(); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + int byteCountAfter = mockTuple.v1().get(); + assertThat(byteCountBefore - byteCountAfter, Matchers.is(0)); + // test closeSource parameter + AtomicBoolean isClosed = new AtomicBoolean(false); + InputStream mockIn = mock(InputStream.class); + doAnswer(new Answer() { + public Void answer(InvocationOnMock invocation) { + isClosed.set(true); + return null; + } + }).when(mockIn).close(); + new PrefixInputStream(mockIn, 1 + Randomness.get().nextInt(32), true).close(); + assertThat(isClosed.get(), Matchers.is(true)); + isClosed.set(false); + new PrefixInputStream(mockIn, 1 + Randomness.get().nextInt(32), false).close(); + assertThat(isClosed.get(), Matchers.is(false)); + } + + public void testAvailable() throws Exception { + AtomicInteger available = new AtomicInteger(0); + int boundedLength = 1 + Randomness.get().nextInt(256); + InputStream mockIn = mock(InputStream.class); + when(mockIn.available()).thenAnswer(invocationOnMock -> { return available.get(); }); + PrefixInputStream test = new PrefixInputStream(mockIn, boundedLength, randomBoolean()); + assertThat(test.available(), Matchers.is(0)); + available.set(Randomness.get().nextInt(boundedLength)); + assertThat(test.available(), Matchers.is(available.get())); + available.set(boundedLength + 1 + Randomness.get().nextInt(boundedLength)); + assertThat(test.available(), Matchers.is(boundedLength)); + } + + public void testReadPrefixLength() throws Exception { + int boundedLength = 1 + Randomness.get().nextInt(256); + Tuple mockTuple = getMockBoundedInputStream(boundedLength); + int prefixLength = Randomness.get().nextInt(boundedLength); + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), prefixLength, randomBoolean()); + int byteCountBefore = mockTuple.v1().get(); + byte[] b = test.readAllBytes(); + int byteCountAfter = mockTuple.v1().get(); + assertThat(b.length, Matchers.is(prefixLength)); + assertThat(byteCountBefore - byteCountAfter, Matchers.is(prefixLength)); + assertThat(test.read(), Matchers.is(-1)); + assertThat(test.available(), Matchers.is(0)); + assertThat(mockTuple.v2().read(), Matchers.not(-1)); + } + + public void testSkipPrefixLength() throws Exception { + int boundedLength = 1 + Randomness.get().nextInt(256); + Tuple mockTuple = getMockBoundedInputStream(boundedLength); + int prefixLength = Randomness.get().nextInt(boundedLength); + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), prefixLength, randomBoolean()); + int byteCountBefore = mockTuple.v1().get(); + skipNBytes(test, prefixLength); + int byteCountAfter = mockTuple.v1().get(); + assertThat(byteCountBefore - byteCountAfter, Matchers.is(prefixLength)); + assertThat(test.read(), Matchers.is(-1)); + assertThat(test.available(), Matchers.is(0)); + assertThat(mockTuple.v2().read(), Matchers.not(-1)); + } + + public void testReadShorterWrapped() throws Exception { + int boundedLength = 1 + Randomness.get().nextInt(256); + Tuple mockTuple = getMockBoundedInputStream(boundedLength); + int prefixLength = boundedLength; + if (randomBoolean()) { + prefixLength += 1 + Randomness.get().nextInt(boundedLength); + } + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), prefixLength, randomBoolean()); + int byteCountBefore = mockTuple.v1().get(); + byte[] b = test.readAllBytes(); + int byteCountAfter = mockTuple.v1().get(); + assertThat(b.length, Matchers.is(boundedLength)); + assertThat(byteCountBefore - byteCountAfter, Matchers.is(boundedLength)); + assertThat(test.read(), Matchers.is(-1)); + assertThat(test.available(), Matchers.is(0)); + assertThat(mockTuple.v2().read(), Matchers.is(-1)); + assertThat(mockTuple.v2().available(), Matchers.is(0)); + } + + public void testSkipShorterWrapped() throws Exception { + int boundedLength = 1 + Randomness.get().nextInt(256); + Tuple mockTuple = getMockBoundedInputStream(boundedLength); + final int prefixLength; + if (randomBoolean()) { + prefixLength = boundedLength + 1 + Randomness.get().nextInt(boundedLength); + } else { + prefixLength = boundedLength; + } + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), prefixLength, randomBoolean()); + int byteCountBefore = mockTuple.v1().get(); + if (prefixLength == boundedLength) { + skipNBytes(test, prefixLength); + } else { + expectThrows(EOFException.class, () -> { skipNBytes(test, prefixLength); }); + } + int byteCountAfter = mockTuple.v1().get(); + assertThat(byteCountBefore - byteCountAfter, Matchers.is(boundedLength)); + assertThat(test.read(), Matchers.is(-1)); + assertThat(test.available(), Matchers.is(0)); + assertThat(mockTuple.v2().read(), Matchers.is(-1)); + assertThat(mockTuple.v2().available(), Matchers.is(0)); + } + + private Tuple getMockBoundedInputStream(int bound) throws IOException { + InputStream mockSource = mock(InputStream.class); + AtomicInteger bytesRemaining = new AtomicInteger(bound); + when(mockSource.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())).thenAnswer( + invocationOnMock -> { + final byte[] b = (byte[]) invocationOnMock.getArguments()[0]; + final int off = (int) invocationOnMock.getArguments()[1]; + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + if (bytesRemaining.get() <= 0) { + return -1; + } + int bytesCount = 1 + Randomness.get().nextInt(Math.min(len, bytesRemaining.get())); + bytesRemaining.addAndGet(-bytesCount); + return bytesCount; + } + } + ); + when(mockSource.read()).thenAnswer(invocationOnMock -> { + if (bytesRemaining.get() <= 0) { + return -1; + } + bytesRemaining.decrementAndGet(); + return Randomness.get().nextInt(256); + }); + when(mockSource.skip(org.mockito.Matchers.anyLong())).thenAnswer(invocationOnMock -> { + final long n = (long) invocationOnMock.getArguments()[0]; + if (n <= 0 || bytesRemaining.get() <= 0) { + return 0; + } + int bytesSkipped = 1 + Randomness.get().nextInt(Math.min(bytesRemaining.get(), Math.toIntExact(n))); + bytesRemaining.addAndGet(-bytesSkipped); + return bytesSkipped; + }); + when(mockSource.available()).thenAnswer(invocationOnMock -> { + if (bytesRemaining.get() <= 0) { + return 0; + } + return 1 + Randomness.get().nextInt(bytesRemaining.get()); + }); + when(mockSource.markSupported()).thenReturn(false); + return new Tuple<>(bytesRemaining, mockSource); + } + + private static void skipNBytes(InputStream in, long n) throws IOException { + if (n > 0) { + long ns = in.skip(n); + if (ns >= 0 && ns < n) { // skipped too few bytes + // adjust number to skip + n -= ns; + // read until requested number skipped or EOS reached + while (n > 0 && in.read() != -1) { + n--; + } + // if not enough skipped, then EOFE + if (n != 0) { + throw new EOFException(); + } + } else if (ns != n) { // skipped negative or too many bytes + throw new IOException("Unable to skip exactly"); + } + } + } +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/SingleUseKeyTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/SingleUseKeyTests.java new file mode 100644 index 0000000000000..034cc41a84888 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/SingleUseKeyTests.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.Before; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.contains; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SingleUseKeyTests extends ESTestCase { + + byte[] testKeyPlaintext; + SecretKey testKey; + BytesReference testKeyId; + + @Before + public void setUpMocks() { + testKeyPlaintext = randomByteArrayOfLength(32); + testKey = new SecretKeySpec(testKeyPlaintext, "AES"); + testKeyId = new BytesArray(randomAlphaOfLengthBetween(2, 32)); + } + + public void testNewKeySupplier() throws Exception { + CheckedSupplier singleUseKeySupplier = SingleUseKey.createSingleUseKeySupplier( + () -> new Tuple<>(testKeyId, testKey) + ); + SingleUseKey generatedSingleUseKey = singleUseKeySupplier.get(); + assertThat(generatedSingleUseKey.getKeyId(), equalTo(testKeyId)); + assertThat(generatedSingleUseKey.getNonce(), equalTo(SingleUseKey.MIN_NONCE)); + assertThat(generatedSingleUseKey.getKey().getEncoded(), equalTo(testKeyPlaintext)); + } + + public void testNonceIncrement() throws Exception { + int nonce = randomIntBetween(SingleUseKey.MIN_NONCE, SingleUseKey.MAX_NONCE - 2); + SingleUseKey singleUseKey = new SingleUseKey(testKeyId, testKey, nonce); + AtomicReference keyCurrentlyInUse = new AtomicReference<>(singleUseKey); + @SuppressWarnings("unchecked") + CheckedSupplier, IOException> keyGenerator = mock(CheckedSupplier.class); + CheckedSupplier singleUseKeySupplier = SingleUseKey.internalSingleUseKeySupplier( + keyGenerator, + keyCurrentlyInUse + ); + SingleUseKey generatedSingleUseKey = singleUseKeySupplier.get(); + assertThat(generatedSingleUseKey.getKeyId(), equalTo(testKeyId)); + assertThat(generatedSingleUseKey.getNonce(), equalTo(nonce)); + assertThat(generatedSingleUseKey.getKey().getEncoded(), equalTo(testKeyPlaintext)); + SingleUseKey generatedSingleUseKey2 = singleUseKeySupplier.get(); + assertThat(generatedSingleUseKey2.getKeyId(), equalTo(testKeyId)); + assertThat(generatedSingleUseKey2.getNonce(), equalTo(nonce + 1)); + assertThat(generatedSingleUseKey2.getKey().getEncoded(), equalTo(testKeyPlaintext)); + verifyZeroInteractions(keyGenerator); + } + + public void testConcurrentWrapAround() throws Exception { + int nThreads = 3; + TestThreadPool testThreadPool = new TestThreadPool( + "SingleUserKeyTests#testConcurrentWrapAround", + Settings.builder() + .put("thread_pool." + ThreadPool.Names.GENERIC + ".size", nThreads) + .put("thread_pool." + ThreadPool.Names.GENERIC + ".queue_size", 1) + .build() + ); + int nonce = SingleUseKey.MAX_NONCE; + SingleUseKey singleUseKey = new SingleUseKey(null, null, nonce); + + AtomicReference keyCurrentlyInUse = new AtomicReference<>(singleUseKey); + @SuppressWarnings("unchecked") + CheckedSupplier, IOException> keyGenerator = mock(CheckedSupplier.class); + when(keyGenerator.get()).thenReturn(new Tuple<>(testKeyId, testKey)); + CheckedSupplier singleUseKeySupplier = SingleUseKey.internalSingleUseKeySupplier( + keyGenerator, + keyCurrentlyInUse + ); + List generatedKeys = new ArrayList<>(nThreads); + for (int i = 0; i < nThreads; i++) { + generatedKeys.add(null); + } + for (int i = 0; i < nThreads; i++) { + final int resultIdx = i; + testThreadPool.generic().execute(() -> { + try { + generatedKeys.set(resultIdx, singleUseKeySupplier.get()); + } catch (IOException e) { + fail(); + } + }); + } + terminate(testThreadPool); + verify(keyGenerator, times(1)).get(); + assertThat(keyCurrentlyInUse.get().getNonce(), equalTo(SingleUseKey.MIN_NONCE + nThreads)); + assertThat(generatedKeys.stream().map(suk -> suk.getKey()).collect(Collectors.toSet()).size(), equalTo(1)); + assertThat( + generatedKeys.stream().map(suk -> suk.getKey().getEncoded()).collect(Collectors.toSet()).iterator().next(), + equalTo(testKeyPlaintext) + ); + assertThat(generatedKeys.stream().map(suk -> suk.getKeyId()).collect(Collectors.toSet()).iterator().next(), equalTo(testKeyId)); + assertThat(generatedKeys.stream().map(suk -> suk.getNonce()).collect(Collectors.toSet()).size(), equalTo(nThreads)); + assertThat( + generatedKeys.stream().map(suk -> suk.getNonce()).collect(Collectors.toSet()), + contains(SingleUseKey.MIN_NONCE, SingleUseKey.MIN_NONCE + 1, SingleUseKey.MIN_NONCE + 2) + ); + } + + public void testNonceWrapAround() throws Exception { + int nonce = SingleUseKey.MAX_NONCE; + SingleUseKey singleUseKey = new SingleUseKey(testKeyId, testKey, nonce); + AtomicReference keyCurrentlyInUse = new AtomicReference<>(singleUseKey); + byte[] newTestKeyPlaintext = randomByteArrayOfLength(32); + SecretKey newTestKey = new SecretKeySpec(newTestKeyPlaintext, "AES"); + BytesReference newTestKeyId = new BytesArray(randomAlphaOfLengthBetween(2, 32)); + CheckedSupplier singleUseKeySupplier = SingleUseKey.internalSingleUseKeySupplier( + () -> new Tuple<>(newTestKeyId, newTestKey), + keyCurrentlyInUse + ); + SingleUseKey generatedSingleUseKey = singleUseKeySupplier.get(); + assertThat(generatedSingleUseKey.getKeyId(), equalTo(newTestKeyId)); + assertThat(generatedSingleUseKey.getNonce(), equalTo(SingleUseKey.MIN_NONCE)); + assertThat(generatedSingleUseKey.getKey().getEncoded(), equalTo(newTestKeyPlaintext)); + } + + public void testGeneratorException() { + int nonce = SingleUseKey.MAX_NONCE; + SingleUseKey singleUseKey = new SingleUseKey(null, null, nonce); + AtomicReference keyCurrentlyInUse = new AtomicReference<>(singleUseKey); + CheckedSupplier singleUseKeySupplier = SingleUseKey.internalSingleUseKeySupplier( + () -> { throw new IOException("expected exception"); }, + keyCurrentlyInUse + ); + expectThrows(IOException.class, () -> singleUseKeySupplier.get()); + } +} diff --git a/x-pack/plugin/repository-encrypted/src/test/resources/rest-api-spec/test/repository_encrypted/10_basic.yml b/x-pack/plugin/repository-encrypted/src/test/resources/rest-api-spec/test/repository_encrypted/10_basic.yml new file mode 100644 index 0000000000000..858ba3e21e3ae --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/resources/rest-api-spec/test/repository_encrypted/10_basic.yml @@ -0,0 +1,16 @@ +# Integration tests for repository-encrypted +# +"Plugin repository-encrypted is loaded": + - skip: + reason: "contains is a newly added assertion" + features: contains + - do: + cluster.state: {} + + # Get master node id + - set: { master_node: master } + + - do: + nodes.info: {} + + - contains: { nodes.$master.plugins: { name: repository-encrypted } }