From 1c7ebdbc81adf55527afca756cf3981279c2cb96 Mon Sep 17 00:00:00 2001 From: Steve Lord Date: Thu, 6 Jun 2024 16:42:20 -0700 Subject: [PATCH] GH-41947: [Java] Support catalog in JDBC driver with session options --- .../driver/jdbc/ArrowFlightConnection.java | 2 + .../client/ArrowFlightSqlClientHandler.java | 77 ++++++++++++++++--- .../ArrowFlightConnectionConfigImpl.java | 11 ++- .../ArrowFlightConnectionConfigImplTest.java | 11 ++- 4 files changed, 86 insertions(+), 15 deletions(-) diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index 24d72eb3f0832..c1b1c8f8e6add 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -112,6 +112,7 @@ private static ArrowFlightSqlClientHandler createNewClientHandler( .withCallOptions(config.toCallOption()) .withRetainCookies(config.retainCookies()) .withRetainAuth(config.retainAuth()) + .withCatalog(config.getCatalog()) .build(); } catch (final SQLException e) { try { @@ -171,6 +172,7 @@ public Properties getClientInfo() { @Override public void close() throws SQLException { + clientHandler.close(); if (executorService != null) { executorService.shutdown(); } diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index f3553ae2f01d7..e16bef8376225 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -25,9 +25,11 @@ import java.util.Collection; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils; import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.CloseSessionRequest; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightClientMiddleware; import org.apache.arrow.flight.FlightEndpoint; @@ -36,6 +38,10 @@ import org.apache.arrow.flight.FlightStatusCode; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.LocationSchemes; +import org.apache.arrow.flight.SessionOptionValue; +import org.apache.arrow.flight.SessionOptionValueFactory; +import org.apache.arrow.flight.SetSessionOptionsRequest; +import org.apache.arrow.flight.SetSessionOptionsResult; import org.apache.arrow.flight.auth2.BearerCredentialWriter; import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler; import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; @@ -54,22 +60,32 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** A {@link FlightSqlClient} handler. */ +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableMap; + +/** + * A {@link FlightSqlClient} handler. + */ public final class ArrowFlightSqlClientHandler implements AutoCloseable { private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFlightSqlClientHandler.class); + // JDBC connection string query parameter + private static final String CATALOG = "catalog"; private final FlightSqlClient sqlClient; private final Set options = new HashSet<>(); private final Builder builder; + private final String catalog; ArrowFlightSqlClientHandler( - final FlightSqlClient sqlClient, - final Builder builder, - final Collection credentialOptions) { + final FlightSqlClient sqlClient, + final Builder builder, + final Collection credentialOptions, + final String catalog) { this.options.addAll(builder.options); this.options.addAll(credentialOptions); this.sqlClient = Preconditions.checkNotNull(sqlClient); this.builder = builder; + this.catalog = catalog; } /** @@ -80,9 +96,11 @@ public final class ArrowFlightSqlClientHandler implements AutoCloseable { * @param options the {@link CallOption}s to persist in between subsequent client calls. * @return a new {@link ArrowFlightSqlClientHandler}. */ - public static ArrowFlightSqlClientHandler createNewHandler( - final FlightClient client, final Builder builder, final Collection options) { - return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options); + public static ArrowFlightSqlClientHandler createNewHandler(final FlightClient client, + final Builder builder, + final Collection options, + final String catalog) { + return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), builder, options, catalog); } /** @@ -199,6 +217,9 @@ public FlightInfo getInfo(final String query) { @Override public void close() throws SQLException { + if (hasCatalog()) { + sqlClient.closeSession(new CloseSessionRequest(), getOptions()); + } try { AutoCloseables.close(sqlClient); } catch (final Exception e) { @@ -206,7 +227,13 @@ public void close() throws SQLException { } } - /** A prepared statement handler. */ + private boolean hasCatalog() { + return !Strings.isNullOrEmpty(catalog); + } + + /** + * A prepared statement handler. + */ public interface PreparedStatement extends AutoCloseable { /** * Executes this {@link PreparedStatement}. @@ -257,6 +284,21 @@ public interface PreparedStatement extends AutoCloseable { * @return a new prepared statement. */ public PreparedStatement prepare(final String query) { + if (hasCatalog()) { + final SetSessionOptionsRequest setSessionOptionRequest = + new SetSessionOptionsRequest(ImmutableMap.builder() + .put(CATALOG, SessionOptionValueFactory.makeSessionOptionValue(catalog)) + .build()); + final SetSessionOptionsResult result = sqlClient.setSessionOptions(setSessionOptionRequest, getOptions()); + if (result.hasErrors()) { + Map errors = result.getErrors(); + for (Map.Entry error : errors.entrySet()) { + LOGGER.warn(error.toString()); + } + throw new RuntimeException(String.format("Cannot set session option for catalog = %s", catalog)); + } + } + final FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(query, getOptions()); return new PreparedStatement() { @@ -492,8 +534,10 @@ public static final class Builder { @VisibleForTesting boolean retainAuth = true; - // These two middleware are for internal use within build() and should not be exposed by builder - // APIs. + @VisibleForTesting + String catalog; + + // These two middleware are for internal use within build() and should not be exposed by builder APIs. // Note that these middleware may not necessarily be registered. @VisibleForTesting ClientIncomingAuthHeaderMiddleware.Factory authFactory = @@ -527,6 +571,7 @@ public Builder() {} this.clientCertificatePath = original.clientCertificatePath; this.clientKeyPath = original.clientKeyPath; this.allocator = original.allocator; + this.catalog = original.catalog; if (original.retainCookies) { this.cookieFactory = original.cookieFactory; @@ -762,6 +807,16 @@ public Builder withCallOptions(final Collection options) { return this; } + /** + * Sets the catalog for this handler. + * @param catalog the catalog + * @return this instance. + */ + public Builder withCatalog(final String catalog) { + this.catalog = catalog; + return this; + } + /** * Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields. * @@ -841,7 +896,7 @@ public ArrowFlightSqlClientHandler build() throws SQLException { new CredentialCallOption(new BearerCredentialWriter(token)), options.toArray(new CallOption[0]))); } - return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions); + return ArrowFlightSqlClientHandler.createNewHandler(client, this, credentialOptions, catalog); } catch (final IllegalArgumentException | GeneralSecurityException diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java index fcb53519da961..a991ac61b34fa 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java @@ -154,6 +154,14 @@ public boolean retainAuth() { return ArrowFlightConnectionProperty.RETAIN_AUTH.getBoolean(properties); } + /** + * The catalog to which a connection is made. + * @return the catalog. + */ + public String getCatalog() { + return ArrowFlightConnectionProperty.CATALOG.getString(properties); + } + /** * Gets the {@link CallOption}s from this {@link ConnectionConfig}. * @@ -203,7 +211,8 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty { THREAD_POOL_SIZE("threadPoolSize", 1, Type.NUMBER, false), TOKEN("token", null, Type.STRING, false), RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false), - RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false); + RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false), + CATALOG("catalog", null, Type.STRING, false); private final String camelName; private final Object defaultValue; diff --git a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java index 616f6e4b364f0..8ab460c250325 100644 --- a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java +++ b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java @@ -18,6 +18,7 @@ import static java.lang.Runtime.getRuntime; import static java.util.Arrays.asList; +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.CATALOG; import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST; import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD; import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT; @@ -104,9 +105,13 @@ public static List provideParameters() { { THREAD_POOL_SIZE, RANDOM.nextInt(getRuntime().availableProcessors()), - (Function) - ArrowFlightConnectionConfigImpl::threadPoolSize + (Function) ArrowFlightConnectionConfigImpl::threadPoolSize + }, + { + CATALOG, + "catalog", + (Function) ArrowFlightConnectionConfigImpl::getCatalog }, - }); + }); } }