Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Token Authorization fixes #3192

Merged
merged 39 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
9b46cab
test
May 23, 2024
fc306f9
token authorization integration
May 29, 2024
80682e2
env false regression cpu ci
udaij12 May 29, 2024
56f7c35
testing ci
udaij12 May 29, 2024
107ddc2
testing newman
udaij12 May 29, 2024
0d984ab
fix newman tests
udaij12 May 29, 2024
d2ccecf
testing pytest
udaij12 May 29, 2024
fa8e3da
testing cmd arg
udaij12 May 29, 2024
40697ff
pytest fixes
udaij12 May 29, 2024
91505b3
fixing tests
udaij12 May 29, 2024
dd54734
doc update
udaij12 May 30, 2024
543dfac
spell check
udaij12 May 30, 2024
54319c5
fixing priority between config file and cmd
udaij12 May 30, 2024
9ff3749
test fixes
udaij12 May 30, 2024
dbfdd81
removing unneeded files
udaij12 May 30, 2024
b02c665
Delete unneeded files
udaij12 May 30, 2024
23260c2
review fixes
udaij12 May 31, 2024
6b4925f
Merge branch 'token' of https://github.com/pytorch/serve into token
udaij12 May 31, 2024
9cd6e72
removing comments
udaij12 May 31, 2024
ac0de19
adding doc clarification and new test
udaij12 Jun 4, 2024
963bb22
Merge branch 'master' into token
udaij12 Jun 4, 2024
996788a
Merge branch 'master' into token
udaij12 Jun 7, 2024
568d5bd
changes to docs
udaij12 Jun 7, 2024
5cb400d
adding new tests
udaij12 Jun 11, 2024
119444e
Merge branch 'master' into token
udaij12 Jun 11, 2024
1cb5a9d
fixing merge conflict
udaij12 Jun 11, 2024
3edceec
format fix
udaij12 Jun 11, 2024
22f3c47
format fixes
udaij12 Jun 11, 2024
b8da1a7
Merge branch 'master' into token
udaij12 Jun 12, 2024
c9c7b8d
addressing comments
udaij12 Jun 13, 2024
3f593d7
Merge branch 'token' of https://github.com/pytorch/serve into token
udaij12 Jun 13, 2024
b9f073c
Merge branch 'master' into token
udaij12 Jun 13, 2024
fba89ff
fixing merge conflict
udaij12 Jun 13, 2024
4e62ac4
fixing merge conflict
udaij12 Jun 13, 2024
07eb052
fixing merge conflict
udaij12 Jun 13, 2024
95d2f91
fix merge conflict
udaij12 Jun 13, 2024
c3dc699
doc update
udaij12 Jun 13, 2024
905a4db
fixing format
udaij12 Jun 13, 2024
4fd51ed
fix to benchmarks
udaij12 Jun 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/utils/system_under_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def start(self):
click.secho("*Starting local Torchserve instance...", fg="green")

ts_cmd = (
f"torchserve --start --model-store {self.execution_params['tmp_dir']}/model_store --model-api-enabled --disable-token"
f"torchserve --start --model-store {self.execution_params['tmp_dir']}/model_store --model-api-enabled --disable-token "
f"--workflow-store {self.execution_params['tmp_dir']}/wf_store "
f"--ts-config {self.execution_params['tmp_dir']}/benchmark/conf/{self.execution_params['config_properties_name']} "
f" > {self.execution_params['tmp_dir']}/benchmark/logs/model_metrics.log"
Expand Down Expand Up @@ -195,7 +195,7 @@ def start(self):
f"docker run {self.execution_params['docker_runtime']} {backend_profiling} --name ts --user root -p "
f"127.0.0.1:{inference_port}:{inference_port} -p 127.0.0.1:{management_port}:{management_port} "
f"-v {self.execution_params['tmp_dir']}:/tmp {enable_gpu} -itd {docker_image} "
f'"torchserve --start --model-store /home/model-server/model-store --model-api-enabled --disable-token'
f'"torchserve --start --model-store /home/model-server/model-store --model-api-enabled --disable-token '
f"\--workflow-store /home/model-server/wf-store "
f"--ts-config /tmp/benchmark/conf/{self.execution_params['config_properties_name']} > "
f'/tmp/benchmark/logs/model_metrics.log"'
Expand Down
10 changes: 6 additions & 4 deletions docs/token_authorization_api.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# TorchServe token authorization API

Torchserve now supports token authorization by default.
TorchServe now enforces token authorization by default


## How to set and disable Token Authorization
* Global environment variable: use `TS_DISABLE_TOKEN_AUTHORIZATION` and set to `true` to disable and `false` to enable token authorization. Note that `enable_envvars_config=true` must be set in config.properties for global environment variables to be used
* Command line: Command line can only be used to disable token authorization by adding the `--disable-token` flag.
* Config properties file: use `disable_token_authorization` and set to `true` to disable and `false` to enable token authorization.

Priority between env variables, cmd, and config file follows the following [TorchServer standard](https://github.com/pytorch/serve/blob/c74a29e8144bc12b84196775076b0e8cf3c5a6fc/docs/configuration.md#advanced-configuration)
Priority between env variables, cmd, and config file follows the following [TorchServer standard](https://github.com/pytorch/serve/blob/master/docs/configuration.md)

* Example 1:
* Config file: `disable_token_authorization=false`

Expand Down Expand Up @@ -48,7 +50,7 @@ Priority between env variables, cmd, and config file follows the following [Torc
2. Inference key: Used for inference APIs. Example:
`curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer FINhR1fj"`
3. API key: Used for the token authorization API. Check section 4 for API use.
4. The plugin also includes an API in order to generate a new key to replace either the management or inference key.
4. API in order to generate a new key to replace either the management or inference key.
1. Management Example:
`curl localhost:8081/token?type=management -H "Authorization: Bearer m4M-5IBY"` will replace the current management key in the key_file with a new one and will update the expiration time.
2. Inference example:
Expand All @@ -61,4 +63,4 @@ Priority between env variables, cmd, and config file follows the following [Torc
## Notes
1. DO NOT MODIFY THE KEY FILE. Modifying the key file might impact reading and writing to the file thus preventing new keys from properly being displayed in the file.
2. Time to expiration is set to default at 60 minutes but can be changed in the config.properties by adding `token_expiration_min`. Ex:`token_expiration_min=30`
3. 3 tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should only be able to run inferences against models that have already been loaded. The owner can also provide owners with the management key if owners want users to add and remove models.
3. Three tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should only be able to run inferences against models that have already been loaded. The owner can also provide owners with the management key if owners want users to add and remove models.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.grpcimpl.GRPCInterceptor;
import org.pytorch.serve.grpcimpl.GRPCServiceFactory;
import org.pytorch.serve.http.TokenAuthorizationHandler;
import org.pytorch.serve.http.messages.RegisterModelRequest;
import org.pytorch.serve.metrics.MetricCache;
import org.pytorch.serve.metrics.MetricManager;
Expand Down Expand Up @@ -86,7 +87,7 @@ public static void main(String[] args) {
ConfigManager.Arguments arguments = new ConfigManager.Arguments(cmd);
ConfigManager.init(arguments);
ConfigManager configManager = ConfigManager.getInstance();
configManager.setupToken();
TokenAuthorizationHandler.setupToken();
PluginsManager.getInstance().initialize();
MetricCache.init();
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ public class TokenAuthorizationHandler extends HttpRequestHandlerChain {
private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class);
private static TokenType tokenType;
private static Boolean tokenEnabled = false;
private static Token tokenClass;
private static Token token;
private static Object tokenObject;
private static Double timeToExpirationMinutes = 60.0;

/** Creates a new {@code InferenceRequestHandler} instance. */
public TokenAuthorizationHandler(TokenType type) {
Expand All @@ -62,11 +61,12 @@ public void handleRequest(
if (req.toString().contains("/token")) {
try {
checkTokenAuthorization(req, "token");
String resp = tokenClass.updateKeyFile(req);
String queryResponse = parseQuery(req);
String resp = token.updateKeyFile(queryResponse);
NettyUtils.sendJsonResponse(ctx, resp);
return;
} catch (Exception e) {
logger.error("TOKEN CLASS UPDATED UNSUCCESSFULLY");
logger.error("Key file updated unsuccessfully");
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
Expand All @@ -76,48 +76,60 @@ public void handleRequest(
} else if (tokenType == TokenType.INFERENCE) {
checkTokenAuthorization(req, "inference");
}
} else {
if (tokenType == TokenType.MANAGEMENT && req.toString().contains("/token")) {
throw new ResourceNotFoundException();
}
}
chain.handleRequest(ctx, req, decoder, segments);
}

public static void setupTokenClass() {
try {
tokenClass = new Token();
Double time = ConfigManager.getInstance().getTimeToExpiration();
String home = ConfigManager.getInstance().getModelServerHome();
tokenClass.setFilePath(home);
if (time != 0.0) {
timeToExpirationMinutes = time;
}
tokenClass.setTime(timeToExpirationMinutes);
if (tokenClass.generateKeyFile("token")) {
logger.info("Token Authorization Enabled");
public static void setupToken() {
if (!ConfigManager.getInstance().getDisableTokenAuthorization()) {
try {
token = new Token();
if (token.generateKeyFile("token")) {
logger.info("Token Authorization Enabled");
}
} catch (IOException e) {
e.printStackTrace();
logger.error("Token Authorization setup unsuccessfully");
throw new IllegalStateException("Token Authorization setup unsuccessfully", e);
}
} catch (Exception e) {
e.printStackTrace();
logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY");
throw new IllegalStateException("Unable to import token class", e);
tokenEnabled = true;
}
tokenEnabled = true;
}

private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException {
String tokenBearer = req.headers().get("Authorization");
if (tokenBearer == null) {
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
String[] arrOfStr = tokenBearer.split(" ", 2);
if (arrOfStr.length == 1) {
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
String currToken = arrOfStr[1];

try {
boolean result = tokenClass.checkTokenAuthorization(req, type);
if (!result) {
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
} catch (Exception e) {
boolean result = token.checkTokenAuthorization(currToken, type);
if (!result) {
throw new InvalidKeyException(
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
}
}

// parses query and either returns management/inference or a wrong type error
private String parseQuery(FullHttpRequest req) {
QueryStringDecoder decoder = new QueryStringDecoder(req.uri());
Map<String, List<String>> parameters = decoder.parameters();
List<String> values = parameters.get("type");
if (values != null && !values.isEmpty()) {
if ("management".equals(values.get(0)) || "inference".equals(values.get(0))) {
return values.get(0);
} else {
return "WRONG TYPE";
}
}
return "NO TYPE PROVIDED";
}
}

class Token {
Expand All @@ -126,14 +138,12 @@ class Token {
private static String inferenceKey;
private static Instant managementExpirationTimeMinutes;
private static Instant inferenceExpirationTimeMinutes;
private static Double timeToExpirationMinutes;
private SecureRandom secureRandom = new SecureRandom();
private Base64.Encoder baseEncoder = Base64.getUrlEncoder();
private String fileName = "key_file.json";
private String filePath = "";
private String filePath = ConfigManager.getInstance().getModelServerHome();

public String updateKeyFile(FullHttpRequest req) throws IOException {
String queryResponse = parseQuery(req);
public String updateKeyFile(String queryResponse) throws IOException {
String test = "";
if ("management".equals(queryResponse)) {
generateKeyFile("management");
Expand All @@ -145,36 +155,17 @@ public String updateKeyFile(FullHttpRequest req) throws IOException {
return test;
}

// parses query and either returns management/inference or a wrong type error
public String parseQuery(FullHttpRequest req) {
QueryStringDecoder decoder = new QueryStringDecoder(req.uri());
Map<String, List<String>> parameters = decoder.parameters();
List<String> values = parameters.get("type");
if (values != null && !values.isEmpty()) {
if ("management".equals(values.get(0)) || "inference".equals(values.get(0))) {
return values.get(0);
} else {
return "WRONG TYPE";
}
}
return "NO TYPE PROVIDED";
}

public String generateKey() {
byte[] randomBytes = new byte[6];
secureRandom.nextBytes(randomBytes);
return baseEncoder.encodeToString(randomBytes);
}

public Instant generateTokenExpiration() {
long secondsToAdd = (long) (timeToExpirationMinutes * 60);
long secondsToAdd = (long) (ConfigManager.getInstance().getTimeToExpiration() * 60);
return Instant.now().plusSeconds(secondsToAdd);
}

public void setFilePath(String path) {
filePath = path;
}

// generates a key file with new keys depending on the parameter provided
public boolean generateKeyFile(String type) throws IOException {
String userDirectory = filePath + "/" + fileName;
Expand Down Expand Up @@ -248,7 +239,7 @@ public boolean setFilePermissions() {
}

// checks the token provided in the http with the saved keys depening on parameters
public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
public boolean checkTokenAuthorization(String token, String type) {
String key;
Instant expiration;
switch (type) {
Expand All @@ -265,16 +256,6 @@ public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
expiration = inferenceExpirationTimeMinutes;
}

String tokenBearer = req.headers().get("Authorization");
if (tokenBearer == null) {
return false;
}
String[] arrOfStr = tokenBearer.split(" ", 2);
if (arrOfStr.length == 1) {
return false;
}
String token = arrOfStr[1];

if (token.equals(key)) {
if (expiration != null && isTokenExpired(expiration)) {
return false;
Expand All @@ -288,28 +269,4 @@ public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
public boolean isTokenExpired(Instant expirationTime) {
return !(Instant.now().isBefore(expirationTime));
}

public String getManagementKey() {
return managementKey;
}

public String getInferenceKey() {
return inferenceKey;
}

public String getKey() {
return apiKey;
}

public Instant getInferenceExpirationTime() {
return inferenceExpirationTimeMinutes;
}

public Instant getManagementExpirationTime() {
return managementExpirationTimeMinutes;
}

public void setTime(Double time) {
timeToExpirationMinutes = time;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.apache.commons.cli.Options;
import org.apache.commons.io.IOUtils;
import org.pytorch.serve.archive.model.Manifest;
import org.pytorch.serve.http.TokenAuthorizationHandler;
import org.pytorch.serve.metrics.MetricBuilder;
import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer;
import org.pytorch.serve.snapshot.SnapshotSerializerFactory;
Expand Down Expand Up @@ -450,14 +449,6 @@ public boolean isOpenInferenceProtocol() {
return Boolean.parseBoolean(prop.getProperty(TS_OPEN_INFERENCE_PROTOCOL, "false"));
}

public boolean setupToken() {
boolean disable_token_authorization = getDisableTokenAuthorization();
if (!disable_token_authorization) {
TokenAuthorizationHandler.setupTokenClass();
}
return true;
}

public boolean isGRPCSSLEnabled() {
return Boolean.parseBoolean(getProperty(TS_ENABLE_GRPC_SSL, "false"));
}
Expand Down Expand Up @@ -1001,7 +992,7 @@ public Double getTimeToExpiration() {
logger.error("Token expiration not a valid integer");
}
}
return 0.0;
return 60.0;
}

public String getTsHeaderKeySequenceId() {
Expand Down
Loading
Loading