Skip to content

Commit

Permalink
fix(sdk): allow SDK to handle protocols in addresses (#70)
Browse files Browse the repository at this point in the history
TDFs contain embedded URLs, some of which contain protocols. In order
for them to
work with GRPC we need to strip off the protocol.

The logic for ports is to use one if it is specified, otherwise we use
80 if the protocol is `http`,
otherwise use `443`.
  • Loading branch information
mkleene committed Jun 6, 2024
1 parent c1bbbb4 commit 97ae8ee
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 6 deletions.
44 changes: 38 additions & 6 deletions sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
import io.opentdf.platform.kas.PublicKeyRequest;
import io.opentdf.platform.kas.RewrapRequest;

import java.net.MalformedURLException;
import java.net.URL;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.function.Function;

import static java.lang.String.format;

public class KASClient implements SDK.KAS, AutoCloseable {

private final Function<String, ManagedChannel> channelFactory;
Expand Down Expand Up @@ -51,6 +55,33 @@ public String getPublicKey(Config.KASInfo kasInfo) {
.getPublicKey();
}

private String normalizeAddress(String urlString) {
URL url;
try {
url = new URL(urlString);
} catch (MalformedURLException e) {
// if there is no protocol then they either gave us
// a correct address or one we don't know how to fix
return urlString;
}

// otherwise we take the specified port or default
// based on whether the URL uses a scheme that
// implies TLS
int port;
if (url.getPort() == -1) {
if ("http".equals(url.getProtocol())) {
port = 80;
} else {
port = 443;
}
} else {
port = url.getPort();
}

return format("%s:%d", url.getHost(), port);
}

@Override
public synchronized void close() {
var entries = new ArrayList<>(stubs.values());
Expand Down Expand Up @@ -103,21 +134,22 @@ public byte[] unwrap(Manifest.KeyAccess keyAccess, String policy) {
private static class CacheEntry {
final ManagedChannel channel;
final AccessServiceGrpc.AccessServiceBlockingStub stub;

private CacheEntry(ManagedChannel channel, AccessServiceGrpc.AccessServiceBlockingStub stub) {
this.channel = channel;
this.stub = stub;
}
}

private synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(String url) {
if (!stubs.containsKey(url)) {
var channel = channelFactory.apply(url);
// make this protected so we can test the address normalization logic
synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(String url) {
var realAddress = normalizeAddress(url);
if (!stubs.containsKey(realAddress)) {
var channel = channelFactory.apply(realAddress);
var stub = AccessServiceGrpc.newBlockingStub(channel);
stubs.put(url, new CacheEntry(channel, stub));
stubs.put(realAddress, new CacheEntry(channel, stub));
}

return stubs.get(url).stub;
return stubs.get(realAddress).stub;
}
}

24 changes: 24 additions & 0 deletions sdk/src/test/java/io/opentdf/platform/sdk/KASClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.text.ParseException;
import java.util.Base64;
import java.util.Random;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static io.opentdf.platform.sdk.SDKBuilderTest.getRandomPort;
Expand Down Expand Up @@ -136,6 +137,29 @@ public void rewrap(RewrapRequest request, StreamObserver<RewrapResponse> respons
}
}

@Test
public void testAddressNormalization() {
var lastAddress = new AtomicReference<String>();
var dpopKeypair = CryptoUtils.generateRSAKeypair();
var dpopKey = new RSAKey.Builder((RSAPublicKey)dpopKeypair.getPublic()).privateKey(dpopKeypair.getPrivate()).build();
var kasClient = new KASClient(addr -> {
lastAddress.set(addr);
return ManagedChannelBuilder.forTarget(addr).build();
}, dpopKey);

var stub = kasClient.getStub("http://localhost:8080");
assertThat(lastAddress.get()).isEqualTo("localhost:8080");
var otherStub = kasClient.getStub("https://localhost:8080");
assertThat(lastAddress.get()).isEqualTo("localhost:8080");
assertThat(stub).isSameAs(otherStub);

kasClient.getStub("https://example.org");
assertThat(lastAddress.get()).isEqualTo("example.org:443");

kasClient.getStub("http://example.org");
assertThat(lastAddress.get()).isEqualTo("example.org:80");
}

private static Server startServer(AccessServiceGrpc.AccessServiceImplBase accessService) throws IOException {
return ServerBuilder
.forPort(getRandomPort())
Expand Down

0 comments on commit 97ae8ee

Please sign in to comment.