Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSockets Next: send ping message from the server automatically #40207

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketServerException;
import io.quarkus.websockets.next.WebSocketsRuntimeConfig;
import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig;
import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback;
import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback.MessageType;
import io.quarkus.websockets.next.runtime.Codecs;
Expand Down Expand Up @@ -383,7 +383,7 @@ private void validateOnPongMessage(Callback callback) {
"@OnPongMessage callback must return void or Uni<Void>: " + callbackToString(callback.method));
}
Type messageType = callback.argumentType(MessageCallbackArgument::isMessage);
if (!messageType.name().equals(WebSocketDotNames.BUFFER)) {
if (messageType == null || !messageType.name().equals(WebSocketDotNames.BUFFER)) {
throw new WebSocketServerException(
"@OnPongMessage callback must accept exactly one message parameter of type io.vertx.core.buffer.Buffer: "
+ callbackToString(callback.method));
Expand Down Expand Up @@ -478,10 +478,10 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint,
.build();

MethodCreator constructor = endpointCreator.getConstructorCreator(WebSocketConnection.class,
Codecs.class, WebSocketsRuntimeConfig.class, ContextSupport.class);
Codecs.class, WebSocketsServerRuntimeConfig.class, ContextSupport.class);
constructor.invokeSpecialMethod(
MethodDescriptor.ofConstructor(WebSocketEndpointBase.class, WebSocketConnection.class,
Codecs.class, WebSocketsRuntimeConfig.class, ContextSupport.class),
Codecs.class, WebSocketsServerRuntimeConfig.class, ContextSupport.class),
constructor.getThis(), constructor.getMethodParam(0), constructor.getMethodParam(1),
constructor.getMethodParam(2), constructor.getMethodParam(3));
constructor.returnNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class MaxMessageSizeTest {
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Echo.class, WSClient.class);
}).overrideConfigKey("quarkus.websockets-next.max-message-size", "10");
}).overrideConfigKey("quarkus.websockets-next.server.max-message-size", "10");

@Inject
Vertx vertx;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package io.quarkus.websockets.next.test.pingpong;

import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OnPongMessage;
import io.quarkus.websockets.next.WebSocket;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.WebSocketClient;

public class AutoPingIntervalTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class);
}).overrideConfigKey("quarkus.websockets-next.server.auto-ping-interval", "200ms");

@Inject
Vertx vertx;

@TestHTTPResource("end")
URI endUri;

@Test
public void testPingPong() throws InterruptedException, ExecutionException {
WebSocketClient client = vertx.createWebSocketClient();
try {
CountDownLatch connectedLatch = new CountDownLatch(1);
client
.connect(endUri.getPort(), endUri.getHost(), endUri.getPath())
.onComplete(r -> {
if (r.succeeded()) {
connectedLatch.countDown();
} else {
throw new IllegalStateException(r.cause());
}
});
assertTrue(connectedLatch.await(5, TimeUnit.SECONDS));
// The pong message should be sent by the client automatically and should be identical to the ping message
assertTrue(Endpoint.PONG.await(5, TimeUnit.SECONDS));
} finally {
client.close().toCompletionStage().toCompletableFuture().get();
}
}

@WebSocket(path = "/end")
public static class Endpoint {

static final CountDownLatch PONG = new CountDownLatch(3);

@OnOpen
public String open() {
return "ok";
}

@OnPongMessage
void pong(Buffer data) {
PONG.countDown();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class SubprotocolSelectedTest {
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class, WSClient.class);
}).overrideConfigKey("quarkus.websockets-next.supported-subprotocols", "oak,larch");
}).overrideConfigKey("quarkus.websockets-next.server.supported-subprotocols", "oak,larch");

@Inject
Vertx vertx;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.quarkus.websockets.next;

import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
Expand All @@ -10,9 +11,9 @@
import io.smallrye.config.WithDefault;
import io.vertx.core.http.HttpServerOptions;

@ConfigMapping(prefix = "quarkus.websockets-next")
@ConfigMapping(prefix = "quarkus.websockets-next.server")
@ConfigRoot(phase = ConfigPhase.RUN_TIME)
public interface WebSocketsRuntimeConfig {
public interface WebSocketsServerRuntimeConfig {

/**
* See <a href="https://datatracker.ietf.org/doc/html/rfc6455#page-12">The WebSocket Protocol</a>
Expand All @@ -39,4 +40,11 @@ public interface WebSocketsRuntimeConfig {
*/
OptionalInt maxMessageSize();

/**
* The interval after which, when set, the server sends a ping message to a connected client automatically.
* <p>
* Ping messages are not sent automatically by default.
*/
Optional<Duration> autoPingInterval();
mkouba marked this conversation as resolved.
Show resolved Hide resolved

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ void add(String endpoint, WebSocketConnection connection) {
if (endpointToConnections.computeIfAbsent(endpoint, e -> ConcurrentHashMap.newKeySet()).add(connection)) {
if (!listeners.isEmpty()) {
for (ConnectionListener listener : listeners) {
listener.connectionAdded(endpoint, connection);
try {
listener.connectionAdded(endpoint, connection);
} catch (Exception e) {
LOG.warnf("Unable to call listener#connectionAdded() on [%s]: %s", listener.getClass(),
e.toString());
}
}
}
}
Expand All @@ -53,7 +58,12 @@ void remove(String endpoint, WebSocketConnection connection) {
if (connections.remove(connection)) {
if (!listeners.isEmpty()) {
for (ConnectionListener listener : listeners) {
listener.connectionRemoved(endpoint, connection.id());
try {
listener.connectionRemoved(endpoint, connection.id());
} catch (Exception e) {
LOG.warnf("Unable to call listener#connectionRemoved() on [%s]: %s", listener.getClass(),
e.toString());
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.jboss.logging.Logger;

import io.quarkus.vertx.core.runtime.VertxBufferImpl;
import io.quarkus.websockets.next.WebSocketConnection;
import io.smallrye.mutiny.Uni;
Expand All @@ -26,6 +28,8 @@

class WebSocketConnectionImpl implements WebSocketConnection {

private static final Logger LOG = Logger.getLogger(WebSocketConnectionImpl.class);

private final String generatedEndpointClass;

private final String endpointId;
Expand Down Expand Up @@ -106,6 +110,14 @@ public Uni<Void> sendPing(Buffer data) {
return UniHelper.toUni(webSocket.writePing(data));
}

void sendAutoPing() {
webSocket.writePing(Buffer.buffer("ping")).onComplete(r -> {
if (r.failed()) {
LOG.warnf("Unable to send auto-ping for %s: %s", this, r.cause().toString());
}
});
}

@Override
public Uni<Void> sendPong(Buffer data) {
return UniHelper.toUni(webSocket.writePong(data));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import io.quarkus.virtual.threads.VirtualThreadsRecorder;
import io.quarkus.websockets.next.WebSocket.ExecutionMode;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketsRuntimeConfig;
import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig;
import io.quarkus.websockets.next.runtime.ConcurrencyLimiter.PromiseComplete;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
Expand All @@ -41,7 +41,7 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint {
private final ConcurrencyLimiter limiter;

@SuppressWarnings("unused")
private final WebSocketsRuntimeConfig config;
private final WebSocketsServerRuntimeConfig config;

private final ArcContainer container;

Expand All @@ -51,7 +51,7 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint {
private final Object beanInstance;

public WebSocketEndpointBase(WebSocketConnection connection, Codecs codecs,
WebSocketsRuntimeConfig config, ContextSupport contextSupport) {
WebSocketsServerRuntimeConfig config, ContextSupport contextSupport) {
this.connection = connection;
this.codecs = codecs;
this.limiter = executionMode() == ExecutionMode.SERIAL ? new ConcurrencyLimiter(connection) : null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import jakarta.inject.Inject;

import io.quarkus.vertx.http.HttpServerOptionsCustomizer;
import io.quarkus.websockets.next.WebSocketsRuntimeConfig;
import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig;
import io.vertx.core.http.HttpServerOptions;

@Dependent
public class WebSocketHttpServerOptionsCustomizer implements HttpServerOptionsCustomizer {

@Inject
WebSocketsRuntimeConfig config;
WebSocketsServerRuntimeConfig config;

@Override
public void customizeHttpServer(HttpServerOptions options) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import io.quarkus.vertx.core.runtime.VertxCoreRecorder;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketServerException;
import io.quarkus.websockets.next.WebSocketsRuntimeConfig;
import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig;
import io.quarkus.websockets.next.runtime.WebSocketSessionContext.SessionContextState;
import io.smallrye.common.vertx.VertxContext;
import io.smallrye.mutiny.Multi;
Expand All @@ -34,9 +34,9 @@ public class WebSocketServerRecorder {

static final String WEB_SOCKET_CONN_KEY = WebSocketConnection.class.getName();

private final WebSocketsRuntimeConfig config;
private final WebSocketsServerRuntimeConfig config;

public WebSocketServerRecorder(WebSocketsRuntimeConfig config) {
public WebSocketServerRecorder(WebSocketsServerRuntimeConfig config) {
this.config = config;
}

Expand Down Expand Up @@ -67,12 +67,13 @@ public Handler<RoutingContext> createEndpointHandler(String generatedEndpointCla
public void handle(RoutingContext ctx) {
Future<ServerWebSocket> future = ctx.request().toWebSocket();
future.onSuccess(ws -> {
Context context = VertxCoreRecorder.getVertx().get().getOrCreateContext();
Vertx vertx = VertxCoreRecorder.getVertx().get();
Context context = vertx.getOrCreateContext();

WebSocketConnection connection = new WebSocketConnectionImpl(generatedEndpointClass, endpointId, ws,
WebSocketConnectionImpl connection = new WebSocketConnectionImpl(generatedEndpointClass, endpointId, ws,
connectionManager, codecs, ctx);
connectionManager.add(generatedEndpointClass, connection);
LOG.debugf("Connnected: %s", connection);
LOG.debugf("Connection created: %s", connection);

// Initialize and capture the session context state that will be activated
// during message processing
Expand Down Expand Up @@ -216,6 +217,18 @@ public void handle(Void event) {
});
});

Long timerId;
if (config.autoPingInterval().isPresent()) {
timerId = vertx.setPeriodic(config.autoPingInterval().get().toMillis(), new Handler<Long>() {
@Override
public void handle(Long timerId) {
connection.sendAutoPing();
}
});
} else {
timerId = null;
}

ws.closeHandler(new Handler<Void>() {
@Override
public void handle(Void event) {
Expand All @@ -229,6 +242,9 @@ public void handle(Void event) {
LOG.errorf(r.cause(), "Unable to complete @OnClose callback: %s", connection);
}
connectionManager.remove(generatedEndpointClass, connection);
if (timerId != null) {
vertx.cancelTimer(timerId);
}
});
}
});
Expand All @@ -249,6 +265,7 @@ public void handle(Void event) {
});
}
});

});
}
};
Expand Down Expand Up @@ -307,7 +324,7 @@ public void handle(Void event) {
}

private WebSocketEndpoint createEndpoint(String endpointClassName, Context context, WebSocketConnection connection,
Codecs codecs, WebSocketsRuntimeConfig config, ContextSupport contextSupport) {
Codecs codecs, WebSocketsServerRuntimeConfig config, ContextSupport contextSupport) {
try {
ClassLoader cl = Thread.currentThread().getContextClassLoader();
if (cl == null) {
Expand All @@ -318,7 +335,7 @@ private WebSocketEndpoint createEndpoint(String endpointClassName, Context conte
.loadClass(endpointClassName);
WebSocketEndpoint endpoint = (WebSocketEndpoint) endpointClazz
.getDeclaredConstructor(WebSocketConnection.class, Codecs.class,
WebSocketsRuntimeConfig.class, ContextSupport.class)
WebSocketsServerRuntimeConfig.class, ContextSupport.class)
.newInstance(connection, codecs, config, contextSupport);
return endpoint;
} catch (Exception e) {
Expand Down