Skip to content

Commit

Permalink
apacheGH-41947: [Java] Support catalog in JDBC driver with session op…
Browse files Browse the repository at this point in the history
…tions
  • Loading branch information
stevelorddremio committed Jun 11, 2024
1 parent 64b1109 commit 1c7ebdb
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ private static ArrowFlightSqlClientHandler createNewClientHandler(
.withCallOptions(config.toCallOption())
.withRetainCookies(config.retainCookies())
.withRetainAuth(config.retainAuth())
.withCatalog(config.getCatalog())
.build();
} catch (final SQLException e) {
try {
Expand Down Expand Up @@ -171,6 +172,7 @@ public Properties getClientInfo() {

@Override
public void close() throws SQLException {
clientHandler.close();
if (executorService != null) {
executorService.shutdown();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.CloseSessionRequest;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.flight.FlightEndpoint;
Expand All @@ -36,6 +38,10 @@
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.LocationSchemes;
import org.apache.arrow.flight.SessionOptionValue;
import org.apache.arrow.flight.SessionOptionValueFactory;
import org.apache.arrow.flight.SetSessionOptionsRequest;
import org.apache.arrow.flight.SetSessionOptionsResult;
import org.apache.arrow.flight.auth2.BearerCredentialWriter;
import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler;
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
Expand All @@ -54,22 +60,32 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A {@link FlightSqlClient} handler. */
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;

/**
* A {@link FlightSqlClient} handler.
*/
public final class ArrowFlightSqlClientHandler implements AutoCloseable {
private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFlightSqlClientHandler.class);
// JDBC connection string query parameter
private static final String CATALOG = "catalog";

private final FlightSqlClient sqlClient;
private final Set<CallOption> options = new HashSet<>();
private final Builder builder;
private final String catalog;

ArrowFlightSqlClientHandler(
final FlightSqlClient sqlClient,
final Builder builder,
final Collection<CallOption> credentialOptions) {
final FlightSqlClient sqlClient,
final Builder builder,
final Collection<CallOption> credentialOptions,
final String catalog) {
this.options.addAll(builder.options);
this.options.addAll(credentialOptions);
this.sqlClient = Preconditions.checkNotNull(sqlClient);
this.builder = builder;
this.catalog = catalog;
}

/**
Expand All @@ -80,9 +96,11 @@ public final class ArrowFlightSqlClientHandler implements AutoCloseable {
* @param options the {@link CallOption}s to persist in between subsequent client calls.
* @return a new {@link ArrowFlightSqlClientHandler}.
*/
public static ArrowFlightSqlClientHandler createNewHandler(
final FlightClient client, final Builder builder, final Collection<CallOption> options) {
return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options);
public static ArrowFlightSqlClientHandler createNewHandler(final FlightClient client,
final Builder builder,
final Collection<CallOption> options,
final String catalog) {
return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options, catalog);
}

/**
Expand Down Expand Up @@ -199,14 +217,23 @@ public FlightInfo getInfo(final String query) {

@Override
public void close() throws SQLException {
if (hasCatalog()) {
sqlClient.closeSession(new CloseSessionRequest(), getOptions());
}
try {
AutoCloseables.close(sqlClient);
} catch (final Exception e) {
throw new SQLException("Failed to clean up client resources.", e);
}
}

/** A prepared statement handler. */
private boolean hasCatalog() {
return !Strings.isNullOrEmpty(catalog);
}

/**
* A prepared statement handler.
*/
public interface PreparedStatement extends AutoCloseable {
/**
* Executes this {@link PreparedStatement}.
Expand Down Expand Up @@ -257,6 +284,21 @@ public interface PreparedStatement extends AutoCloseable {
* @return a new prepared statement.
*/
public PreparedStatement prepare(final String query) {
if (hasCatalog()) {
final SetSessionOptionsRequest setSessionOptionRequest =
new SetSessionOptionsRequest(ImmutableMap.<String, SessionOptionValue>builder()
.put(CATALOG, SessionOptionValueFactory.makeSessionOptionValue(catalog))
.build());
final SetSessionOptionsResult result = sqlClient.setSessionOptions(setSessionOptionRequest, getOptions());
if (result.hasErrors()) {
Map<String, SetSessionOptionsResult.Error> errors = result.getErrors();
for (Map.Entry<String, SetSessionOptionsResult.Error> error : errors.entrySet()) {
LOGGER.warn(error.toString());
}
throw new RuntimeException(String.format("Cannot set session option for catalog = %s", catalog));
}
}

final FlightSqlClient.PreparedStatement preparedStatement =
sqlClient.prepare(query, getOptions());
return new PreparedStatement() {
Expand Down Expand Up @@ -492,8 +534,10 @@ public static final class Builder {

@VisibleForTesting boolean retainAuth = true;

// These two middleware are for internal use within build() and should not be exposed by builder
// APIs.
@VisibleForTesting
String catalog;

// These two middleware are for internal use within build() and should not be exposed by builder APIs.
// Note that these middleware may not necessarily be registered.
@VisibleForTesting
ClientIncomingAuthHeaderMiddleware.Factory authFactory =
Expand Down Expand Up @@ -527,6 +571,7 @@ public Builder() {}
this.clientCertificatePath = original.clientCertificatePath;
this.clientKeyPath = original.clientKeyPath;
this.allocator = original.allocator;
this.catalog = original.catalog;

if (original.retainCookies) {
this.cookieFactory = original.cookieFactory;
Expand Down Expand Up @@ -762,6 +807,16 @@ public Builder withCallOptions(final Collection<CallOption> options) {
return this;
}

/**
* Sets the catalog for this handler.
* @param catalog the catalog
* @return this instance.
*/
public Builder withCatalog(final String catalog) {
this.catalog = catalog;
return this;
}

/**
* Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields.
*
Expand Down Expand Up @@ -841,7 +896,7 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
new CredentialCallOption(new BearerCredentialWriter(token)),
options.toArray(new CallOption[0])));
}
return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions);
return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions, catalog);

} catch (final IllegalArgumentException
| GeneralSecurityException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,14 @@ public boolean retainAuth() {
return ArrowFlightConnectionProperty.RETAIN_AUTH.getBoolean(properties);
}

/**
* The catalog to which a connection is made.
* @return the catalog.
*/
public String getCatalog() {
return ArrowFlightConnectionProperty.CATALOG.getString(properties);
}

/**
* Gets the {@link CallOption}s from this {@link ConnectionConfig}.
*
Expand Down Expand Up @@ -203,7 +211,8 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty {
THREAD_POOL_SIZE("threadPoolSize", 1, Type.NUMBER, false),
TOKEN("token", null, Type.STRING, false),
RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false),
RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false);
RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false),
CATALOG("catalog", null, Type.STRING, false);

private final String camelName;
private final Object defaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import static java.lang.Runtime.getRuntime;
import static java.util.Arrays.asList;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.CATALOG;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT;
Expand Down Expand Up @@ -104,9 +105,13 @@ public static List<Object[]> provideParameters() {
{
THREAD_POOL_SIZE,
RANDOM.nextInt(getRuntime().availableProcessors()),
(Function<ArrowFlightConnectionConfigImpl, ?>)
ArrowFlightConnectionConfigImpl::threadPoolSize
(Function<ArrowFlightConnectionConfigImpl, ?>) ArrowFlightConnectionConfigImpl::threadPoolSize
},
{
CATALOG,
"catalog",
(Function<ArrowFlightConnectionConfigImpl, ?>) ArrowFlightConnectionConfigImpl::getCatalog
},
});
});
}
}

0 comments on commit 1c7ebdb

Please sign in to comment.