Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Token Authorization by default #3163

Merged
merged 29 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions docs/token_authorization_api.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# TorchServe token authorization API

## Setup
1. Download the jar files from [Maven](https://mvnrepository.com/artifact/org.pytorch/torchserve-endpoint-plugin)
2. Enable token authorization by adding the `--plugins-path /path/to/the/jar/files` flag at start up with the path leading to the downloaded jar files.
Torchserve now supports token authorization by default.
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
udaij12 marked this conversation as resolved.
Show resolved Hide resolved

## How to disable Token Authorization
1. Set global environment variable `TS_DISABLE_TOKEN_AUTHORIZATION=true`
2. Add `--disable-token` to command line when running TorchServe.
3. Add `disable_token_authorization=true` to config.properties file
udaij12 marked this conversation as resolved.
Show resolved Hide resolved

## Configuration
1. Torchserve will enable token authorization if the plugin is provided. Expected log statement `[INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager - Loading plugin for endpoint token`
1. Torchserve will enable token authorization by default. Expected log statement `main org.pytorch.serve.http.TokenAuthorizationHandler - TOKEN CLASS IMPORTED SUCCESSFULLY`
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
2. In the current working directory a file `key_file.json` will be generated.
1. Example key file:

Expand Down Expand Up @@ -45,9 +48,7 @@
## Customization
Torchserve offers various ways to customize the token authorization to allow owners to reach the desired result.
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
1. Time to expiration is set to default at 60 minutes but can be changed in the config.properties by adding `token_expiration_min`. Ex:`token_expiration_min=30`
2. The token authorization code is consolidated in the plugin and thus can be changed without impacting the frontend or end result. The only thing the user cannot change is:
1. The urlPattern for the plugin must be 'token' and the class name must not change
2. The `generateKeyFile`, `checkTokenAuthorization`, and `setTime` functions return type and signature must not change. However, the code in the functions can be modified depending on user necessity.


## Notes
1. DO NOT MODIFY THE KEY FILE. Modifying the key file might impact reading and writing to the file thus preventing new keys from properly being displayed in the file.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ public static void main(String[] args) {
ConfigManager.Arguments arguments = new ConfigManager.Arguments(cmd);
ConfigManager.init(arguments);
ConfigManager configManager = ConfigManager.getInstance();
configManager.setupToken();
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
PluginsManager.getInstance().initialize();
MetricCache.init();
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
package org.pytorch.serve.http;

import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.QueryStringDecoder;
import java.lang.reflect.*;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.PosixFilePermission;
import java.nio.file.attribute.PosixFilePermissions;
import java.security.SecureRandom;
import java.time.Instant;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.InvalidKeyException;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.workflow.WorkflowException;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.util.TokenType;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.slf4j.Logger;
Expand All @@ -24,7 +40,7 @@ public class TokenAuthorizationHandler extends HttpRequestHandlerChain {
private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class);
private static TokenType tokenType;
private static Boolean tokenEnabled = false;
private static Class<?> tokenClass;
private static Token tokenClass;
private static Object tokenObject;
private static Double timeToExpirationMinutes = 60.0;
udaij12 marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -44,36 +60,46 @@ public void handleRequest(
if (tokenEnabled) {
if (tokenType == TokenType.MANAGEMENT) {
if (req.toString().contains("/token")) {
checkTokenAuthorization(req, "token");
try {
checkTokenAuthorization(req, "token");
String resp = tokenClass.updateKeyFile(req);
NettyUtils.sendJsonResponse(ctx, resp);
return;
} catch (Exception e) {
logger.error("TOKEN CLASS UPDATED UNSUCCESSFULLY");
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
} else {
checkTokenAuthorization(req, "management");
chain.handleRequest(ctx, req, decoder, segments);
}
} else if (tokenType == TokenType.INFERENCE) {
checkTokenAuthorization(req, "inference");
chain.handleRequest(ctx, req, decoder, segments);
}
} else {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this clause? Would we not get a ResourceNotFoundException() even if we don't handle this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to prevent the token api from being called when token authorization is disabled.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @namannandan is referring to that the last element in the chain will throw the exception:

httpRequestHandlerChain.setNextHandler(invalidRequestHandler);
as not other Handler should handle "token".

if (tokenType == TokenType.MANAGEMENT && req.toString().contains("/token")) {
throw new ResourceNotFoundException();
}
chain.handleRequest(ctx, req, decoder, segments);
}
chain.handleRequest(ctx, req, decoder, segments);
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
}

public static void setupTokenClass() {
try {
tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token");
tokenObject = tokenClass.getDeclaredConstructor().newInstance();
Method method = tokenClass.getMethod("setTime", Double.class);
tokenClass = new Token();
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
Double time = ConfigManager.getInstance().getTimeToExpiration();
String home = ConfigManager.getInstance().getModelServerHome();
tokenClass.setFilePath(home);
if (time != 0.0) {
timeToExpirationMinutes = time;
}
method.invoke(tokenObject, timeToExpirationMinutes);
method = tokenClass.getMethod("generateKeyFile", String.class);
if ((boolean) method.invoke(tokenObject, "token")) {
tokenClass.setTime(timeToExpirationMinutes);
if (tokenClass.generateKeyFile("token")) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if key file creation failed. We should not silently start the server without token auth in this case.

logger.info("TOKEN CLASS IMPORTED SUCCESSFULLY");
}
} catch (NoSuchMethodException
| IllegalAccessException
| InstantiationException
| InvocationTargetException
| ClassNotFoundException e) {
} catch (Exception e) {
e.printStackTrace();
logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: We could remove this since we no longer use the plugin approach.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kept it so that when TorchServe starts users can see if the keyfile was created successfully if enabled.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They will be able to see this through line 98

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whats more important is that we fail in case the token auth setup could not be completed. The exception message need to be adapted here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also better to be more specific about which exception to catch.

throw new IllegalStateException("Unable to import token class", e);
Expand All @@ -84,20 +110,208 @@ public static void setupTokenClass() {
private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException {

try {
Method method =
tokenClass.getMethod(
"checkTokenAuthorization",
io.netty.handler.codec.http.FullHttpRequest.class,
String.class);
boolean result = (boolean) (method.invoke(tokenObject, req, type));
boolean result = tokenClass.checkTokenAuthorization(req, type);
if (!result) {
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
e.printStackTrace();
} catch (Exception e) {
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
}
}

class Token {
private static String apiKey;
private static String managementKey;
private static String inferenceKey;
private static Instant managementExpirationTimeMinutes;
private static Instant inferenceExpirationTimeMinutes;
private static Double timeToExpirationMinutes;
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
private SecureRandom secureRandom = new SecureRandom();
private Base64.Encoder baseEncoder = Base64.getUrlEncoder();
private String fileName = "key_file.json";
private String filePath = "";

public String updateKeyFile(FullHttpRequest req) throws IOException {
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
String queryResponse = parseQuery(req);
String test = "";
if ("management".equals(queryResponse)) {
generateKeyFile("management");
} else if ("inference".equals(queryResponse)) {
generateKeyFile("inference");
} else {
test = "{\n\t\"Error\": " + queryResponse + "\n}\n";
}
return test;
}

// parses query and either returns management/inference or a wrong type error
public String parseQuery(FullHttpRequest req) {
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
QueryStringDecoder decoder = new QueryStringDecoder(req.uri());
Map<String, List<String>> parameters = decoder.parameters();
List<String> values = parameters.get("type");
if (values != null && !values.isEmpty()) {
if ("management".equals(values.get(0)) || "inference".equals(values.get(0))) {
return values.get(0);
} else {
return "WRONG TYPE";
}
}
return "NO TYPE PROVIDED";
}

public String generateKey() {
byte[] randomBytes = new byte[6];
secureRandom.nextBytes(randomBytes);
return baseEncoder.encodeToString(randomBytes);
}

public Instant generateTokenExpiration() {
long secondsToAdd = (long) (timeToExpirationMinutes * 60);
return Instant.now().plusSeconds(secondsToAdd);
}

public void setFilePath(String path) {
filePath = path;
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
}

// generates a key file with new keys depending on the parameter provided
public boolean generateKeyFile(String type) throws IOException {
String userDirectory = filePath + "/" + fileName;
File file = new File(userDirectory);
if (!file.createNewFile() && !file.exists()) {
return false;
}
if (apiKey == null) {
apiKey = generateKey();
}
switch (type) {
case "management":
managementKey = generateKey();
managementExpirationTimeMinutes = generateTokenExpiration();
break;
case "inference":
inferenceKey = generateKey();
inferenceExpirationTimeMinutes = generateTokenExpiration();
break;
default:
managementKey = generateKey();
inferenceKey = generateKey();
inferenceExpirationTimeMinutes = generateTokenExpiration();
managementExpirationTimeMinutes = generateTokenExpiration();
}

JsonObject parentObject = new JsonObject();

JsonObject managementObject = new JsonObject();
managementObject.addProperty("key", managementKey);
managementObject.addProperty("expiration time", managementExpirationTimeMinutes.toString());
parentObject.add("management", managementObject);

JsonObject inferenceObject = new JsonObject();
inferenceObject.addProperty("key", inferenceKey);
inferenceObject.addProperty("expiration time", inferenceExpirationTimeMinutes.toString());
parentObject.add("inference", inferenceObject);

JsonObject apiObject = new JsonObject();
apiObject.addProperty("key", apiKey);
parentObject.add("API", apiObject);

Files.write(
Paths.get(fileName),
new GsonBuilder()
.setPrettyPrinting()
.create()
.toJson(parentObject)
.getBytes(StandardCharsets.UTF_8));

if (!setFilePermissions()) {
try {
Files.delete(Paths.get(fileName));
} catch (IOException e) {
return false;
}
return false;
}
return true;
}

public boolean setFilePermissions() {
Path path = Paths.get(fileName);
try {
Set<PosixFilePermission> permissions = PosixFilePermissions.fromString("rw-------");
Files.setPosixFilePermissions(path, permissions);
} catch (Exception e) {
return false;
}
return true;
}

// checks the token provided in the http with the saved keys depening on parameters
public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
String key;
Instant expiration;
switch (type) {
case "token":
key = apiKey;
expiration = null;
break;
case "management":
key = managementKey;
expiration = managementExpirationTimeMinutes;
break;
default:
key = inferenceKey;
expiration = inferenceExpirationTimeMinutes;
}

String tokenBearer = req.headers().get("Authorization");
if (tokenBearer == null) {
return false;
}
String[] arrOfStr = tokenBearer.split(" ", 2);
if (arrOfStr.length == 1) {
return false;
}
String token = arrOfStr[1];

if (token.equals(key)) {
if (expiration != null && isTokenExpired(expiration)) {
return false;
}
} else {
return false;
}
return true;
}

public boolean isTokenExpired(Instant expirationTime) {
return !(Instant.now().isBefore(expirationTime));
}

public String getManagementKey() {
return managementKey;
}

public String getInferenceKey() {
return inferenceKey;
}

public String getKey() {
return apiKey;
}

public Instant getInferenceExpirationTime() {
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
return inferenceExpirationTimeMinutes;
}

public Instant getManagementExpirationTime() {
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
return managementExpirationTimeMinutes;
}

public void setTime(Double time) {
udaij12 marked this conversation as resolved.
Show resolved Hide resolved
timeToExpirationMinutes = time;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import java.util.Map;
import java.util.ServiceLoader;
import org.pytorch.serve.http.InvalidPluginException;
import org.pytorch.serve.http.TokenAuthorizationHandler;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.servingsdk.annotations.Endpoint;
import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
Expand All @@ -31,9 +30,6 @@ public void initialize() {
logger.info("Initializing plugins manager...");
inferenceEndpoints = initInferenceEndpoints();
managementEndpoints = initManagementEndpoints();
if (managementEndpoints.containsKey("token")) {
TokenAuthorizationHandler.setupTokenClass();
}
}

private boolean validateEndpointPlugin(Annotation a, EndpointTypes type) {
Expand Down
Loading
Loading