Enable Token Authorization by default (#3163)
* test

* token authorization integration

* env false regression cpu ci

* testing ci

* testing newman

* fix newman tests

* testing pytest

* testing cmd arg

* pytest fixes

* fixing tests

* doc update

* spell check

* fixing priority between config file and cmd

* test fixes

* removing unneeded files

* Delete unneeded files

* review fixes

* removing comments

* adding doc clarification and new test

* changes to docs

* adding new tests

* fixing merge conflict

* format fix

* format fixes

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-9-227.us-west-2.compute.internal>
udaij12 and Ubuntu committed Jun 12, 2024
1 parent d29059f commit d622230
Showing 33 changed files with 4,718 additions and 294 deletions.
36 changes: 23 additions & 13 deletions docs/token_authorization_api.md
@@ -1,11 +1,28 @@
# TorchServe token authorization API

## Setup
1. Download the jar files from [Maven](https://mvnrepository.com/artifact/org.pytorch/torchserve-endpoint-plugin)
2. Enable token authorization by adding the `--plugins-path /path/to/the/jar/files` flag at start up with the path leading to the downloaded jar files.
TorchServe now enables token authorization by default.

## How to enable and disable token authorization
* Global environment variable: set `TS_DISABLE_TOKEN_AUTHORIZATION` to `true` to disable token authorization or to `false` to enable it. Environment variables are only read when `enable_envvars_config=true` is set in config.properties.
* Command line: the command line can only disable token authorization, by adding the `--disable-token` flag.
* Config properties file: set `disable_token_authorization` to `true` to disable token authorization or to `false` to enable it.

Priority among environment variables, the command line, and the config file follows the [TorchServe standard](https://github.com/pytorch/serve/blob/c74a29e8144bc12b84196775076b0e8cf3c5a6fc/docs/configuration.md#advanced-configuration)
* Example 1:
  * Config file: `disable_token_authorization=false`

    cmd line: `torchserve --start --ncs --model-store model_store --disable-token`

    Result: The config file enables token authorization but the command line disables it. The command line takes precedence, so token authorization is disabled.
* Example 2:
  * Config file: `disable_token_authorization=true`

    cmd line: `torchserve --start --ncs --model-store model_store`

    Result: The config file disables token authorization and the command line does not set it, so token authorization is disabled.
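
The two examples above can be read as a small decision rule. The following is a minimal sketch of that rule, not TorchServe's actual resolution code: the class `TokenPrecedenceSketch` and the method `isTokenAuthDisabled` are hypothetical names, and the environment variable is deliberately left out because the examples only cover the command line and the config file.

```java
import java.util.Properties;

/** Hypothetical helper for illustration; not part of TorchServe. */
final class TokenPrecedenceSketch {

    /**
     * Returns true when token authorization ends up disabled.
     *
     * @param disableTokenFlag whether --disable-token was passed on the command line
     * @param config           properties loaded from config.properties
     */
    static boolean isTokenAuthDisabled(boolean disableTokenFlag, Properties config) {
        if (disableTokenFlag) {
            // The flag can only disable, and the command line outranks the config file.
            return true;
        }
        // No flag: fall back to the config file; a missing key keeps the default (enabled).
        return Boolean.parseBoolean(
                config.getProperty("disable_token_authorization", "false"));
    }

    public static void main(String[] args) {
        // Example 1: config enables, command line disables.
        Properties example1 = new Properties();
        example1.setProperty("disable_token_authorization", "false");
        System.out.println(isTokenAuthDisabled(true, example1)); // prints "true"

        // Example 2: config disables, command line is silent.
        Properties example2 = new Properties();
        example2.setProperty("disable_token_authorization", "true");
        System.out.println(isTokenAuthDisabled(false, example2)); // prints "true"
    }
}
```

With no flag and no config entry the sketch returns `false`, which matches the default of token authorization being enabled.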

## Configuration
1. Torchserve will enable token authorization if the plugin is provided. Expected log statement `[INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager - Loading plugin for endpoint token`
1. Torchserve will enable token authorization by default. Expected log statement `main org.pytorch.serve.http.TokenAuthorizationHandler - Token Authorization Enabled`
2. In the current working directory a file `key_file.json` will be generated.
1. Example key file:

@@ -41,14 +58,7 @@

5. When users shut down the server the key_file will be deleted.


## Customization
Torchserve offers various ways to customize the token authorization to allow owners to reach the desired result.
1. Time to expiration is set to default at 60 minutes but can be changed in the config.properties by adding `token_expiration_min`. Ex:`token_expiration_min=30`
2. The token authorization code is consolidated in the plugin and thus can be changed without impacting the frontend or end result. The only thing the user cannot change is:
1. The urlPattern for the plugin must be 'token' and the class name must not change
2. The `generateKeyFile`, `checkTokenAuthorization`, and `setTime` functions return type and signature must not change. However, the code in the functions can be modified depending on user necessity.

## Notes
1. DO NOT MODIFY THE KEY FILE. Modifying the key file might impact reading and writing to the file thus preventing new keys from properly being displayed in the file.
2. 3 tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should not mess with models. The owner can also provide owners with the management key if owners want users to add and remove models.
2. The expiration time defaults to 60 minutes and can be changed by adding `token_expiration_min` to config.properties, e.g. `token_expiration_min=30`.
3. Three tokens give the owner the most flexibility and let them adapt the tokens to their use. The owner of the server can give users the inference token if users should only be able to run inferences against models that have already been loaded. The owner can also give users the management key if the owner wants users to add and remove models.
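
As an illustration of how a user who was handed the inference token might attach it to a request, here is a minimal client-side sketch using the JDK's `java.net.http` API. The handler in this commit reads the `Authorization` header and treats everything after the first space as the token. The `Bearer` scheme word, the URL, the port, the model name, and the class and method names below are assumptions for illustration only.

```java
import java.net.URI;
import java.net.http.HttpRequest;

/** Hypothetical client helper for illustration; not part of TorchServe. */
final class TokenClientSketch {

    /** Builds a GET request that carries the token in the Authorization header. */
    static HttpRequest withToken(String url, String token) {
        return HttpRequest.newBuilder(URI.create(url))
                // "<scheme> <token>": the server splits on the first space.
                .header("Authorization", "Bearer " + token)
                .GET()
                .build();
    }

    public static void main(String[] args) {
        // URL and token value are placeholders; nothing is sent over the network here.
        HttpRequest request =
                withToken("http://localhost:8080/predictions/my_model", "abc123");
        System.out.println(request.headers().firstValue("Authorization").orElse("<missing>"));
    }
}
```

The sketch only builds the request, so it can be run without a server. Sending it would need an `HttpClient`, and the token value would come from `key_file.json`.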
@@ -86,6 +86,7 @@ public static void main(String[] args) {
ConfigManager.Arguments arguments = new ConfigManager.Arguments(cmd);
ConfigManager.init(arguments);
ConfigManager configManager = ConfigManager.getInstance();
configManager.setupToken();
PluginsManager.getInstance().initialize();
MetricCache.init();
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
@@ -1,14 +1,30 @@
package org.pytorch.serve.http;

import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.QueryStringDecoder;
import java.lang.reflect.*;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.PosixFilePermission;
import java.nio.file.attribute.PosixFilePermissions;
import java.security.SecureRandom;
import java.time.Instant;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.InvalidKeyException;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.workflow.WorkflowException;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.util.TokenType;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.slf4j.Logger;
@@ -24,7 +40,7 @@ public class TokenAuthorizationHandler extends HttpRequestHandlerChain {
private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class);
private static TokenType tokenType;
private static Boolean tokenEnabled = false;
private static Class<?> tokenClass;
private static Token tokenClass;
private static Object tokenObject;
private static Double timeToExpirationMinutes = 60.0;

@@ -44,36 +60,44 @@ public void handleRequest(
if (tokenEnabled) {
if (tokenType == TokenType.MANAGEMENT) {
if (req.toString().contains("/token")) {
checkTokenAuthorization(req, "token");
try {
checkTokenAuthorization(req, "token");
String resp = tokenClass.updateKeyFile(req);
NettyUtils.sendJsonResponse(ctx, resp);
return;
} catch (Exception e) {
logger.error("TOKEN CLASS UPDATED UNSUCCESSFULLY");
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
} else {
checkTokenAuthorization(req, "management");
}
} else if (tokenType == TokenType.INFERENCE) {
checkTokenAuthorization(req, "inference");
}
} else {
if (tokenType == TokenType.MANAGEMENT && req.toString().contains("/token")) {
throw new ResourceNotFoundException();
}
}
chain.handleRequest(ctx, req, decoder, segments);
}

public static void setupTokenClass() {
try {
tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token");
tokenObject = tokenClass.getDeclaredConstructor().newInstance();
Method method = tokenClass.getMethod("setTime", Double.class);
tokenClass = new Token();
Double time = ConfigManager.getInstance().getTimeToExpiration();
String home = ConfigManager.getInstance().getModelServerHome();
tokenClass.setFilePath(home);
if (time != 0.0) {
timeToExpirationMinutes = time;
}
method.invoke(tokenObject, timeToExpirationMinutes);
method = tokenClass.getMethod("generateKeyFile", String.class);
if ((boolean) method.invoke(tokenObject, "token")) {
logger.info("TOKEN CLASS IMPORTED SUCCESSFULLY");
tokenClass.setTime(timeToExpirationMinutes);
if (tokenClass.generateKeyFile("token")) {
logger.info("Token Authorization Enabled");
}
} catch (NoSuchMethodException
| IllegalAccessException
| InstantiationException
| InvocationTargetException
| ClassNotFoundException e) {
} catch (Exception e) {
e.printStackTrace();
logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY");
throw new IllegalStateException("Unable to import token class", e);
@@ -84,20 +108,208 @@ public static void setupTokenClass() {
private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException {

try {
Method method =
tokenClass.getMethod(
"checkTokenAuthorization",
io.netty.handler.codec.http.FullHttpRequest.class,
String.class);
boolean result = (boolean) (method.invoke(tokenObject, req, type));
boolean result = tokenClass.checkTokenAuthorization(req, type);
if (!result) {
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
e.printStackTrace();
} catch (Exception e) {
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
}
}

class Token {
private static String apiKey;
private static String managementKey;
private static String inferenceKey;
private static Instant managementExpirationTimeMinutes;
private static Instant inferenceExpirationTimeMinutes;
private static Double timeToExpirationMinutes;
private SecureRandom secureRandom = new SecureRandom();
private Base64.Encoder baseEncoder = Base64.getUrlEncoder();
private String fileName = "key_file.json";
private String filePath = "";

public String updateKeyFile(FullHttpRequest req) throws IOException {
String queryResponse = parseQuery(req);
String test = "";
if ("management".equals(queryResponse)) {
generateKeyFile("management");
} else if ("inference".equals(queryResponse)) {
generateKeyFile("inference");
} else {
test = "{\n\t\"Error\": " + queryResponse + "\n}\n";
}
return test;
}

// parses query and either returns management/inference or a wrong type error
public String parseQuery(FullHttpRequest req) {
QueryStringDecoder decoder = new QueryStringDecoder(req.uri());
Map<String, List<String>> parameters = decoder.parameters();
List<String> values = parameters.get("type");
if (values != null && !values.isEmpty()) {
if ("management".equals(values.get(0)) || "inference".equals(values.get(0))) {
return values.get(0);
} else {
return "WRONG TYPE";
}
}
return "NO TYPE PROVIDED";
}

public String generateKey() {
byte[] randomBytes = new byte[6];
secureRandom.nextBytes(randomBytes);
return baseEncoder.encodeToString(randomBytes);
}

public Instant generateTokenExpiration() {
long secondsToAdd = (long) (timeToExpirationMinutes * 60);
return Instant.now().plusSeconds(secondsToAdd);
}

public void setFilePath(String path) {
filePath = path;
}

// generates a key file with new keys depending on the parameter provided
public boolean generateKeyFile(String type) throws IOException {
String userDirectory = filePath + "/" + fileName;
File file = new File(userDirectory);
if (!file.createNewFile() && !file.exists()) {
return false;
}
if (apiKey == null) {
apiKey = generateKey();
}
switch (type) {
case "management":
managementKey = generateKey();
managementExpirationTimeMinutes = generateTokenExpiration();
break;
case "inference":
inferenceKey = generateKey();
inferenceExpirationTimeMinutes = generateTokenExpiration();
break;
default:
managementKey = generateKey();
inferenceKey = generateKey();
inferenceExpirationTimeMinutes = generateTokenExpiration();
managementExpirationTimeMinutes = generateTokenExpiration();
}

JsonObject parentObject = new JsonObject();

JsonObject managementObject = new JsonObject();
managementObject.addProperty("key", managementKey);
managementObject.addProperty("expiration time", managementExpirationTimeMinutes.toString());
parentObject.add("management", managementObject);

JsonObject inferenceObject = new JsonObject();
inferenceObject.addProperty("key", inferenceKey);
inferenceObject.addProperty("expiration time", inferenceExpirationTimeMinutes.toString());
parentObject.add("inference", inferenceObject);

JsonObject apiObject = new JsonObject();
apiObject.addProperty("key", apiKey);
parentObject.add("API", apiObject);

Files.write(
Paths.get(fileName),
new GsonBuilder()
.setPrettyPrinting()
.create()
.toJson(parentObject)
.getBytes(StandardCharsets.UTF_8));

if (!setFilePermissions()) {
try {
Files.delete(Paths.get(fileName));
} catch (IOException e) {
return false;
}
return false;
}
return true;
}

public boolean setFilePermissions() {
Path path = Paths.get(fileName);
try {
Set<PosixFilePermission> permissions = PosixFilePermissions.fromString("rw-------");
Files.setPosixFilePermissions(path, permissions);
} catch (Exception e) {
return false;
}
return true;
}

// checks the token provided in the HTTP request against the saved keys, depending on the type parameter
public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
String key;
Instant expiration;
switch (type) {
case "token":
key = apiKey;
expiration = null;
break;
case "management":
key = managementKey;
expiration = managementExpirationTimeMinutes;
break;
default:
key = inferenceKey;
expiration = inferenceExpirationTimeMinutes;
}

String tokenBearer = req.headers().get("Authorization");
if (tokenBearer == null) {
return false;
}
String[] arrOfStr = tokenBearer.split(" ", 2);
if (arrOfStr.length == 1) {
return false;
}
String token = arrOfStr[1];

if (token.equals(key)) {
if (expiration != null && isTokenExpired(expiration)) {
return false;
}
} else {
return false;
}
return true;
}

public boolean isTokenExpired(Instant expirationTime) {
return !(Instant.now().isBefore(expirationTime));
}

public String getManagementKey() {
return managementKey;
}

public String getInferenceKey() {
return inferenceKey;
}

public String getKey() {
return apiKey;
}

public Instant getInferenceExpirationTime() {
return inferenceExpirationTimeMinutes;
}

public Instant getManagementExpirationTime() {
return managementExpirationTimeMinutes;
}

public void setTime(Double time) {
timeToExpirationMinutes = time;
}
}
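
The key handling in the `Token` class above can be condensed into a standalone sketch for experimentation. It mirrors three steps from the diff: generating a key from 6 random bytes with the URL-safe Base64 encoder, deriving an expiration instant from a number of minutes, and checking an `Authorization` header against a stored key. The class `TokenKeySketch` and its method names are hypothetical, and the current time is passed in as a parameter (an addition for testability, not something the original does).

```java
import java.security.SecureRandom;
import java.time.Instant;
import java.util.Base64;

/** Hypothetical condensed version of the Token logic; not part of TorchServe. */
final class TokenKeySketch {
    private static final SecureRandom RANDOM = new SecureRandom();

    /** 6 random bytes encode to exactly 8 URL-safe Base64 characters, with no padding. */
    static String generateKey() {
        byte[] bytes = new byte[6];
        RANDOM.nextBytes(bytes);
        return Base64.getUrlEncoder().encodeToString(bytes);
    }

    /** Expiration instant for a lifetime given in minutes (fractions allowed). */
    static Instant expiresAt(Instant now, double minutes) {
        return now.plusSeconds((long) (minutes * 60));
    }

    /** Returns the part after the first space, or null if the header is absent or has no space. */
    static String tokenFromHeader(String header) {
        if (header == null) {
            return null;
        }
        String[] parts = header.split(" ", 2);
        return parts.length == 2 ? parts[1] : null;
    }

    /** A null expiration means the key never expires, as for the API key in the diff. */
    static boolean isAuthorized(String header, String key, Instant expiration, Instant now) {
        String token = tokenFromHeader(header);
        if (token == null || !token.equals(key)) {
            return false;
        }
        return expiration == null || now.isBefore(expiration);
    }

    public static void main(String[] args) {
        String key = generateKey();
        Instant now = Instant.now();
        Instant expiration = expiresAt(now, 60.0);
        System.out.println(isAuthorized("Bearer " + key, key, expiration, now));
    }
}
```

As in the original `isTokenExpired`, a token whose expiration equals the current instant counts as expired, because the check requires the current time to be strictly before the expiration.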
@@ -5,7 +5,6 @@
import java.util.Map;
import java.util.ServiceLoader;
import org.pytorch.serve.http.InvalidPluginException;
import org.pytorch.serve.http.TokenAuthorizationHandler;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.servingsdk.annotations.Endpoint;
import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
@@ -31,9 +30,6 @@ public void initialize() {
logger.info("Initializing plugins manager...");
inferenceEndpoints = initInferenceEndpoints();
managementEndpoints = initManagementEndpoints();
if (managementEndpoints.containsKey("token")) {
TokenAuthorizationHandler.setupTokenClass();
}
}

private boolean validateEndpointPlugin(Annotation a, EndpointTypes type) {