diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java index de0a834ab33a9..c0216132ab1f1 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java @@ -66,7 +66,7 @@ import io.quarkus.websockets.next.WebSocket; import io.quarkus.websockets.next.WebSocketConnection; import io.quarkus.websockets.next.WebSocketServerException; -import io.quarkus.websockets.next.WebSocketsRuntimeConfig; +import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig; import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback; import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback.MessageType; import io.quarkus.websockets.next.runtime.Codecs; @@ -383,7 +383,7 @@ private void validateOnPongMessage(Callback callback) { "@OnPongMessage callback must return void or Uni: " + callbackToString(callback.method)); } Type messageType = callback.argumentType(MessageCallbackArgument::isMessage); - if (!messageType.name().equals(WebSocketDotNames.BUFFER)) { + if (messageType == null || !messageType.name().equals(WebSocketDotNames.BUFFER)) { throw new WebSocketServerException( "@OnPongMessage callback must accept exactly one message parameter of type io.vertx.core.buffer.Buffer: " + callbackToString(callback.method)); @@ -478,10 +478,10 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, .build(); MethodCreator constructor = endpointCreator.getConstructorCreator(WebSocketConnection.class, - Codecs.class, WebSocketsRuntimeConfig.class, ContextSupport.class); + Codecs.class, WebSocketsServerRuntimeConfig.class, ContextSupport.class); constructor.invokeSpecialMethod( MethodDescriptor.ofConstructor(WebSocketEndpointBase.class, WebSocketConnection.class, - Codecs.class, WebSocketsRuntimeConfig.class, ContextSupport.class), + Codecs.class, WebSocketsServerRuntimeConfig.class, ContextSupport.class), constructor.getThis(), constructor.getMethodParam(0), constructor.getMethodParam(1), constructor.getMethodParam(2), constructor.getMethodParam(3)); constructor.returnNull(); diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/maxmessagesize/MaxMessageSizeTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/maxmessagesize/MaxMessageSizeTest.java index 2ffe0778d69f7..bbaf616152555 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/maxmessagesize/MaxMessageSizeTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/maxmessagesize/MaxMessageSizeTest.java @@ -25,7 +25,7 @@ public class MaxMessageSizeTest { public static final QuarkusUnitTest test = new QuarkusUnitTest() .withApplicationRoot(root -> { root.addClasses(Echo.class, WSClient.class); - }).overrideConfigKey("quarkus.websockets-next.max-message-size", "10"); + }).overrideConfigKey("quarkus.websockets-next.server.max-message-size", "10"); @Inject Vertx vertx; diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/pingpong/AutoPingIntervalTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/pingpong/AutoPingIntervalTest.java new file mode 100644 index 0000000000000..b9abcd9f619c3 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/pingpong/AutoPingIntervalTest.java @@ -0,0 +1,77 @@ +package io.quarkus.websockets.next.test.pingpong; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnPongMessage; +import io.quarkus.websockets.next.WebSocket; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.WebSocketClient; + +public class AutoPingIntervalTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Endpoint.class); + }).overrideConfigKey("quarkus.websockets-next.server.auto-ping-interval", "200ms"); + + @Inject + Vertx vertx; + + @TestHTTPResource("end") + URI endUri; + + @Test + public void testPingPong() throws InterruptedException, ExecutionException { + WebSocketClient client = vertx.createWebSocketClient(); + try { + CountDownLatch connectedLatch = new CountDownLatch(1); + client + .connect(endUri.getPort(), endUri.getHost(), endUri.getPath()) + .onComplete(r -> { + if (r.succeeded()) { + connectedLatch.countDown(); + } else { + throw new IllegalStateException(r.cause()); + } + }); + assertTrue(connectedLatch.await(5, TimeUnit.SECONDS)); + // The pong message should be sent by the client automatically and should be identical to the ping message + assertTrue(Endpoint.PONG.await(5, TimeUnit.SECONDS)); + } finally { + client.close().toCompletionStage().toCompletableFuture().get(); + } + } + + @WebSocket(path = "/end") + public static class Endpoint { + + static final CountDownLatch PONG = new CountDownLatch(3); + + @OnOpen + public String open() { + return "ok"; + } + + @OnPongMessage + void pong(Buffer data) { + PONG.countDown(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolSelectedTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolSelectedTest.java index a9e52a296e574..b922be4955450 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolSelectedTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolSelectedTest.java @@ -27,7 +27,7 @@ public class SubprotocolSelectedTest { public static final QuarkusUnitTest test = new QuarkusUnitTest() .withApplicationRoot(root -> { root.addClasses(Endpoint.class, WSClient.class); - }).overrideConfigKey("quarkus.websockets-next.supported-subprotocols", "oak,larch"); + }).overrideConfigKey("quarkus.websockets-next.server.supported-subprotocols", "oak,larch"); @Inject Vertx vertx; diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsRuntimeConfig.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java similarity index 76% rename from extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsRuntimeConfig.java rename to extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java index e1c76dc33dde3..9566afca3ea5a 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsRuntimeConfig.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java @@ -1,5 +1,6 @@ package io.quarkus.websockets.next; +import java.time.Duration; import java.util.List; import java.util.Optional; import java.util.OptionalInt; @@ -10,9 +11,9 @@ import io.smallrye.config.WithDefault; import io.vertx.core.http.HttpServerOptions; -@ConfigMapping(prefix = "quarkus.websockets-next") +@ConfigMapping(prefix = "quarkus.websockets-next.server") @ConfigRoot(phase = ConfigPhase.RUN_TIME) -public interface WebSocketsRuntimeConfig { +public interface WebSocketsServerRuntimeConfig { /** * See The WebSocket Protocol @@ -39,4 +40,11 @@ public interface WebSocketsRuntimeConfig { */ OptionalInt maxMessageSize(); + /** + * The interval after which, when set, the server sends a ping message to a connected client automatically. + *

+ * Ping messages are not sent automatically by default. + */ + Optional autoPingInterval(); + } diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConnectionManager.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConnectionManager.java index 513b4755dc57d..fe7268a63a852 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConnectionManager.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConnectionManager.java @@ -40,7 +40,12 @@ void add(String endpoint, WebSocketConnection connection) { if (endpointToConnections.computeIfAbsent(endpoint, e -> ConcurrentHashMap.newKeySet()).add(connection)) { if (!listeners.isEmpty()) { for (ConnectionListener listener : listeners) { - listener.connectionAdded(endpoint, connection); + try { + listener.connectionAdded(endpoint, connection); + } catch (Exception e) { + LOG.warnf("Unable to call listener#connectionAdded() on [%s]: %s", listener.getClass(), + e.toString()); + } } } } @@ -53,7 +58,12 @@ void remove(String endpoint, WebSocketConnection connection) { if (connections.remove(connection)) { if (!listeners.isEmpty()) { for (ConnectionListener listener : listeners) { - listener.connectionRemoved(endpoint, connection.id()); + try { + listener.connectionRemoved(endpoint, connection.id()); + } catch (Exception e) { + LOG.warnf("Unable to call listener#connectionRemoved() on [%s]: %s", listener.getClass(), + e.toString()); + } } } } diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java index 8b7eda81b1461..9de3d34d70efd 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java @@ -13,6 +13,8 @@ import java.util.function.Predicate; import java.util.stream.Collectors; +import org.jboss.logging.Logger; + import io.quarkus.vertx.core.runtime.VertxBufferImpl; import io.quarkus.websockets.next.WebSocketConnection; import io.smallrye.mutiny.Uni; @@ -26,6 +28,8 @@ class WebSocketConnectionImpl implements WebSocketConnection { + private static final Logger LOG = Logger.getLogger(WebSocketConnectionImpl.class); + private final String generatedEndpointClass; private final String endpointId; @@ -106,6 +110,14 @@ public Uni sendPing(Buffer data) { return UniHelper.toUni(webSocket.writePing(data)); } + void sendAutoPing() { + webSocket.writePing(Buffer.buffer("ping")).onComplete(r -> { + if (r.failed()) { + LOG.warnf("Unable to send auto-ping for %s: %s", this, r.cause().toString()); + } + }); + } + @Override public Uni sendPong(Buffer data) { return UniHelper.toUni(webSocket.writePong(data)); diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java index 261de140f1683..3a7b4da9dfa33 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java @@ -17,7 +17,7 @@ import io.quarkus.virtual.threads.VirtualThreadsRecorder; import io.quarkus.websockets.next.WebSocket.ExecutionMode; import io.quarkus.websockets.next.WebSocketConnection; -import io.quarkus.websockets.next.WebSocketsRuntimeConfig; +import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig; import io.quarkus.websockets.next.runtime.ConcurrencyLimiter.PromiseComplete; import io.smallrye.mutiny.Multi; import io.smallrye.mutiny.Uni; @@ -41,7 +41,7 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint { private final ConcurrencyLimiter limiter; @SuppressWarnings("unused") - private final WebSocketsRuntimeConfig config; + private final WebSocketsServerRuntimeConfig config; private final ArcContainer container; @@ -51,7 +51,7 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint { private final Object beanInstance; public WebSocketEndpointBase(WebSocketConnection connection, Codecs codecs, - WebSocketsRuntimeConfig config, ContextSupport contextSupport) { + WebSocketsServerRuntimeConfig config, ContextSupport contextSupport) { this.connection = connection; this.codecs = codecs; this.limiter = executionMode() == ExecutionMode.SERIAL ? new ConcurrencyLimiter(connection) : null; diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketHttpServerOptionsCustomizer.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketHttpServerOptionsCustomizer.java index 5233fd4a1cc34..1ca59b18aec3a 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketHttpServerOptionsCustomizer.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketHttpServerOptionsCustomizer.java @@ -6,14 +6,14 @@ import jakarta.inject.Inject; import io.quarkus.vertx.http.HttpServerOptionsCustomizer; -import io.quarkus.websockets.next.WebSocketsRuntimeConfig; +import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig; import io.vertx.core.http.HttpServerOptions; @Dependent public class WebSocketHttpServerOptionsCustomizer implements HttpServerOptionsCustomizer { @Inject - WebSocketsRuntimeConfig config; + WebSocketsServerRuntimeConfig config; @Override public void customizeHttpServer(HttpServerOptions options) { diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java index c53d15645b01d..bca955cc02b4c 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -14,7 +14,7 @@ import io.quarkus.vertx.core.runtime.VertxCoreRecorder; import io.quarkus.websockets.next.WebSocketConnection; import io.quarkus.websockets.next.WebSocketServerException; -import io.quarkus.websockets.next.WebSocketsRuntimeConfig; +import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig; import io.quarkus.websockets.next.runtime.WebSocketSessionContext.SessionContextState; import io.smallrye.common.vertx.VertxContext; import io.smallrye.mutiny.Multi; @@ -34,9 +34,9 @@ public class WebSocketServerRecorder { static final String WEB_SOCKET_CONN_KEY = WebSocketConnection.class.getName(); - private final WebSocketsRuntimeConfig config; + private final WebSocketsServerRuntimeConfig config; - public WebSocketServerRecorder(WebSocketsRuntimeConfig config) { + public WebSocketServerRecorder(WebSocketsServerRuntimeConfig config) { this.config = config; } @@ -67,12 +67,13 @@ public Handler createEndpointHandler(String generatedEndpointCla public void handle(RoutingContext ctx) { Future future = ctx.request().toWebSocket(); future.onSuccess(ws -> { - Context context = VertxCoreRecorder.getVertx().get().getOrCreateContext(); + Vertx vertx = VertxCoreRecorder.getVertx().get(); + Context context = vertx.getOrCreateContext(); - WebSocketConnection connection = new WebSocketConnectionImpl(generatedEndpointClass, endpointId, ws, + WebSocketConnectionImpl connection = new WebSocketConnectionImpl(generatedEndpointClass, endpointId, ws, connectionManager, codecs, ctx); connectionManager.add(generatedEndpointClass, connection); - LOG.debugf("Connnected: %s", connection); + LOG.debugf("Connection created: %s", connection); // Initialize and capture the session context state that will be activated // during message processing @@ -216,6 +217,18 @@ public void handle(Void event) { }); }); + Long timerId; + if (config.autoPingInterval().isPresent()) { + timerId = vertx.setPeriodic(config.autoPingInterval().get().toMillis(), new Handler() { + @Override + public void handle(Long timerId) { + connection.sendAutoPing(); + } + }); + } else { + timerId = null; + } + ws.closeHandler(new Handler() { @Override public void handle(Void event) { @@ -229,6 +242,9 @@ public void handle(Void event) { LOG.errorf(r.cause(), "Unable to complete @OnClose callback: %s", connection); } connectionManager.remove(generatedEndpointClass, connection); + if (timerId != null) { + vertx.cancelTimer(timerId); + } }); } }); @@ -249,6 +265,7 @@ public void handle(Void event) { }); } }); + }); } }; @@ -307,7 +324,7 @@ public void handle(Void event) { } private WebSocketEndpoint createEndpoint(String endpointClassName, Context context, WebSocketConnection connection, - Codecs codecs, WebSocketsRuntimeConfig config, ContextSupport contextSupport) { + Codecs codecs, WebSocketsServerRuntimeConfig config, ContextSupport contextSupport) { try { ClassLoader cl = Thread.currentThread().getContextClassLoader(); if (cl == null) { @@ -318,7 +335,7 @@ private WebSocketEndpoint createEndpoint(String endpointClassName, Context conte .loadClass(endpointClassName); WebSocketEndpoint endpoint = (WebSocketEndpoint) endpointClazz .getDeclaredConstructor(WebSocketConnection.class, Codecs.class, - WebSocketsRuntimeConfig.class, ContextSupport.class) + WebSocketsServerRuntimeConfig.class, ContextSupport.class) .newInstance(connection, codecs, config, contextSupport); return endpoint; } catch (Exception e) {