diff --git a/extensions/aws-ai/src/main/java/ai/djl/aws/sagemaker/SageMaker.java b/extensions/aws-ai/src/main/java/ai/djl/aws/sagemaker/SageMaker.java index 4ee3f271e50..a62cf16e9b8 100644 --- a/extensions/aws-ai/src/main/java/ai/djl/aws/sagemaker/SageMaker.java +++ b/extensions/aws-ai/src/main/java/ai/djl/aws/sagemaker/SageMaker.java @@ -26,6 +26,7 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.waiters.WaiterResponse; import software.amazon.awssdk.regions.Region; @@ -152,34 +153,64 @@ public void deploy() throws IOException { logger.info("SageMaker endpoint {} created: {}", endpointName, endpointArn); } - /** Deletes the Amazon SageMaker endpoint. */ - public void deleteEndpoint() { - logger.info("Deleting SageMaker endpoint {} ...", endpointName); - DeleteEndpointRequest req = - DeleteEndpointRequest.builder().endpointName(endpointName).build(); - sageMaker.deleteEndpoint(req); - SageMakerWaiter waiter = sageMaker.waiter(); - DescribeEndpointRequest waitReq = - DescribeEndpointRequest.builder().endpointName(endpointConfigName).build(); - waiter.waitUntilEndpointDeleted(waitReq); - logger.info("SageMaker endpoint {} deleted.", endpointName); + /** + * Deletes the Amazon SageMaker endpoint. + * + * @param quietly true to suppress error + */ + public void deleteEndpoint(boolean quietly) { + try { + logger.info("Deleting SageMaker endpoint {} ...", endpointName); + DeleteEndpointRequest req = + DeleteEndpointRequest.builder().endpointName(endpointName).build(); + sageMaker.deleteEndpoint(req); + SageMakerWaiter waiter = sageMaker.waiter(); + DescribeEndpointRequest waitReq = + DescribeEndpointRequest.builder().endpointName(endpointConfigName).build(); + waiter.waitUntilEndpointDeleted(waitReq); + logger.info("SageMaker endpoint {} deleted.", endpointName); + } catch (SdkException e) { + if (!quietly) { + throw e; + } + } } - /** Deletes the endpoint configuration. */ - public void deleteEndpointConfig() { - DeleteEndpointConfigRequest req = - DeleteEndpointConfigRequest.builder() - .endpointConfigName(endpointConfigName) - .build(); - sageMaker.deleteEndpointConfig(req); - logger.info("SageMaker endpoint config {} deleted.", endpointConfigName); + /** + * Deletes the endpoint configuration. + * + * @param quietly true to suppress error + */ + public void deleteEndpointConfig(boolean quietly) { + try { + DeleteEndpointConfigRequest req = + DeleteEndpointConfigRequest.builder() + .endpointConfigName(endpointConfigName) + .build(); + sageMaker.deleteEndpointConfig(req); + logger.info("SageMaker endpoint config {} deleted.", endpointConfigName); + } catch (SdkException e) { + if (!quietly) { + throw e; + } + } } - /** Deletes the SageMaker model configuration. */ - public void deleteSageMakerModel() { - DeleteModelRequest req = DeleteModelRequest.builder().modelName(modelName).build(); - sageMaker.deleteModel(req); - logger.info("SageMaker model {} deleted.", modelName); + /** + * Deletes the SageMaker model configuration. + * + * @param quietly true to suppress error + */ + public void deleteSageMakerModel(boolean quietly) { + try { + DeleteModelRequest req = DeleteModelRequest.builder().modelName(modelName).build(); + sageMaker.deleteModel(req); + logger.info("SageMaker model {} deleted.", modelName); + } catch (SdkException e) { + if (!quietly) { + throw e; + } + } } /** @@ -298,7 +329,7 @@ private Path tar(Path dir) throws IOException { BufferedOutputStream bos = new BufferedOutputStream(os); GzipCompressorOutputStream zos = new GzipCompressorOutputStream(bos); TarArchiveOutputStream tos = new TarArchiveOutputStream(zos)) { - + tos.setBigNumberMode(TarArchiveOutputStream.BIGNUMBER_STAR); addToTar(dir, dir, tos); tos.finish(); } diff --git a/extensions/aws-ai/src/test/java/ai/djl/aws/sagemaker/SageMakerTest.java b/extensions/aws-ai/src/test/java/ai/djl/aws/sagemaker/SageMakerTest.java index 8554043c999..30ada2f7a55 100644 --- a/extensions/aws-ai/src/test/java/ai/djl/aws/sagemaker/SageMakerTest.java +++ b/extensions/aws-ai/src/test/java/ai/djl/aws/sagemaker/SageMakerTest.java @@ -28,7 +28,6 @@ import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.core.exception.SdkException; import java.io.IOException; import java.io.InputStream; @@ -50,36 +49,40 @@ public void testDeployModel() throws IOException, ModelException { Criteria criteria = Criteria.builder() .setTypes(NDList.class, NDList.class) - .optModelUrls("https://resources.djl.ai/test-models/mlp.tar.gz") + .optModelUrls( + "https://resources.djl.ai/test-models/pytorch/resnet18_jit.tar.gz") .build(); try (ZooModel model = criteria.loadModel()) { SageMaker sageMaker = SageMaker.builder() .setModel(model) .optBucketName("djl-sm-test") - .optModelName("resnet") - .optContainerImage("125045733377.dkr.ecr.us-east-1.amazonaws.com/djl") + .optModelName("resnet18-jit") + .optContainerImage( + "125045733377.dkr.ecr.us-east-1.amazonaws.com/djl-serving") .optExecutionRole( "arn:aws:iam::125045733377:role/service-role/DJLSageMaker-ExecutionRole-20210213T1027050") .build(); - sageMaker.deploy(); + try { + sageMaker.deploy(); - byte[] image; - Path imagePath = Paths.get("../../examples/src/test/resources/0.png"); - try (InputStream is = Files.newInputStream(imagePath)) { - image = Utils.toByteArray(is); + byte[] image; + Path imagePath = Paths.get("../../examples/src/test/resources/kitten.jpg"); + try (InputStream is = Files.newInputStream(imagePath)) { + image = Utils.toByteArray(is); + } + String ret = new String(sageMaker.invoke(image), StandardCharsets.UTF_8); + Type type = new TypeToken>() {}.getType(); + List list = JsonUtils.GSON.fromJson(ret, type); + String className = list.get(0).getClassName(); + Assert.assertEquals(className, "n02123159 tiger cat"); + } finally { + sageMaker.deleteEndpoint(true); + sageMaker.deleteEndpointConfig(true); + sageMaker.deleteSageMakerModel(true); } - String ret = new String(sageMaker.invoke(image), StandardCharsets.UTF_8); - Type type = new TypeToken>() {}.getType(); - List list = JsonUtils.GSON.fromJson(ret, type); - String className = list.get(0).getClassName(); - Assert.assertEquals(className, "0"); - - sageMaker.deleteEndpoint(); - sageMaker.deleteEndpointConfig(); - sageMaker.deleteSageMakerModel(); - } catch (SdkException e) { + } catch (SdkClientException e) { throw new SkipException("Skip tests that requires permission.", e); } }