Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/LWB-36_flex-broadcast-manager #55

Merged
merged 10 commits into from
Apr 22, 2024
8 changes: 4 additions & 4 deletions backend/ws-server/README.MD
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ In other case you will receive message with `ERROR` action:

```json
{
"action": "ERROR",
"timestamp": 1710622280.309209600,
"error": "You are not member of chat",
"path": "/chat/group/65fca83ea94fa31e8a726f9956/leave"
"action": "LEAVE",
"timestamp": 1713443892.540455500,
"senderId": 34,
"chatId": "6621128c182be4335fc0a6e5"
}
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ public Set<String> getUserSessions(Long userId) {
return members == null ? emptySet() : members;
}

@Override
public Set<String> getUserSessions(String customKey) {
final Set<String> members = redisTemplate.opsForSet().members(customKey);
return members == null ? emptySet() : members;
}

@Override
public boolean isMember(String chatId, Long userId) {
return TRUE.equals(
Expand Down Expand Up @@ -215,7 +221,11 @@ public void shareWithConsumer(String consumerId, String jsonMessage) {
}

private String userKey(Long userId) {
return "user:%d".formatted(userId);
return userKey(String.valueOf(userId));
}

private String userKey(String userId) {
return "user:%s".formatted(userId);
}

private String chatKey(String chatId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
public interface SessionRepository<T> {
Set<String> getUserSessions(T userId);

Set<String> getUserSessions(String customKey);

void saveSession(T userId, String sessionId);

void removeSession(T userId, String sessionId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ public Box<MemberMessage> join(@PathVariable String id,
}

@Broadcast
@Broadcast(value = "user:{senderId}", analyzeMessage = true)
@SubRoute("/{id}/leave")
public Box<ChatMessage> leaveChat(@PathVariable String id,
@NonNull UserPrincipal principal,
Expand Down Expand Up @@ -191,6 +192,7 @@ public Box<MemberMessage> addMember(@PathVariable String id,
}

@Broadcast
@Broadcast("user:{memberId}")
@SubRoute("/{id}/kick/{memberId}")
public Box<MemberMessage> removeMember(@PathVariable String id,
@PathVariable Long memberId,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,46 @@
package org.linkwave.ws.websocket.routing;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.linkwave.ws.repository.ChatRepository;
import org.linkwave.ws.repository.SessionRepository;
import org.linkwave.ws.websocket.jwt.UserPrincipal;
import org.linkwave.ws.websocket.routing.args.ArgumentResolverStrategy;
import org.linkwave.ws.websocket.routing.args.PathVariableResolverStrategy;
import org.linkwave.ws.websocket.routing.args.PayloadResolverStrategy;
import org.linkwave.ws.websocket.routing.broadcast.*;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.WebSocketSession;

import java.security.Principal;
import java.util.List;
import java.util.Map;

@Configuration
public class RoutingAutoConfig {

public static final String PATH_PARAM_NAME = "path";

@Bean
public BroadcastRepositoryResolver broadcastRepositoryResolver(
ChatRepository<Long, String> chatRepository
) {
return new BroadcastRepositoryResolverImpl(
chatRepository,
Map.of(
"user:{}", SessionRepository::getUserSessions,
"chat:{}", ChatRepository::getSessions
)
);
}

@Bean
public BroadcastManager broadcastManager(WebSocketMessageBroadcast messageBroadcast,
ChatRepository<Long, String> chatRepository,
BroadcastRepositoryResolver repositoryResolver) {
return new FlexBroadcastManager(messageBroadcast, chatRepository, repositoryResolver);
}

// argument resolvers registry
@Bean
public List<ArgumentResolverStrategy> argumentResolverStrategies(ObjectMapper objectMapper) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@

import org.linkwave.ws.websocket.routing.broadcast.BroadcastManager;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.annotation.*;

/**
* This annotation is used to enable message broadcast for route handler
* This annotation is used to enable message broadcast for route handler.
*
* @see BroadcastManager
* @see SubRoute
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Repeatable(Broadcasts.class)
public @interface Broadcast {

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.linkwave.ws.websocket.routing.bpp;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* Container for {@link Broadcast} annotations in order to
* support broadcast for different destinations.
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Broadcasts {

Broadcast[] value() default {};

}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public Object postProcessAfterInitialization(@NonNull Object bean,
}

// check broadcast options
verifyBroadcast(method);
final boolean broadcast = verifyBroadcast(method);

method.setAccessible(true);
routes.put(combinedPath, new RouteComponent(entry.getValue(), method));
Expand All @@ -77,7 +77,7 @@ public Object postProcessAfterInitialization(@NonNull Object bean,
sb.setLength(0);
sb.append(rootPath);

log.debug("Route [{}], broadcast: {}", combinedPath, method.isAnnotationPresent(Broadcast.class));
log.debug("Route [{}], broadcast: {}", combinedPath, broadcast);
}
}
sb.setLength(0);
Expand All @@ -87,29 +87,35 @@ public Object postProcessAfterInitialization(@NonNull Object bean,
return bean;
}

private void verifyBroadcast(@NonNull Method routeHandler) {
if (!routeHandler.isAnnotationPresent(Broadcast.class)) {
return;
private boolean verifyBroadcast(@NonNull Method routeHandler) {
final Broadcast[] annotations = routeHandler.getAnnotationsByType(Broadcast.class);
if (annotations.length == 0) {
return false;
}

if (routeHandler.getReturnType().equals(void.class)) {
throw new RuntimeException(
format("Route handler \"%s\" with broadcast has return type void", routeHandler.getName())
format(
"Route handler \"%s\" marked as broadcast has return type void",
"%s.%s".formatted(routeHandler.getDeclaringClass().getName(), routeHandler.getName())
)
);
}

String[] keyComponents = routeHandler.getAnnotation(Broadcast.class)
.value()
.trim()
.split(BroadcastManager.KEY_SEPARATOR);

if (keyComponents.length < 2) {
String errMsg = format(
"Broadcast annotation value incorrect format at route handler \"%s\"",
routeHandler.getName()
);
throw new RuntimeException(errMsg);
for (Broadcast annotation : annotations) {
final String[] keyComponents = annotation.value()
.trim()
.split(BroadcastManager.KEY_SEPARATOR);

if (keyComponents.length < 2) {
String errMsg = format(
"Broadcast annotation value incorrect format at route handler \"%s\"",
"%s.%s".formatted(routeHandler.getDeclaringClass().getName(), routeHandler.getName())
);
throw new RuntimeException(errMsg);
}
}
return true;
}

private Field getRoutesMapField(@NonNull Class<?> cls) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package org.linkwave.ws.websocket.routing.broadcast;

import org.linkwave.ws.websocket.routing.bpp.Broadcast;
import org.springframework.lang.NonNull;

import java.util.Set;

/**
* Defines which sessions should be retrieved that corresponds to<br/>
* key-pattern set in {@link Broadcast#value()}.
*/
public interface BroadcastRepositoryResolver {

/**
* Retrieves a set of sessions ids based on key-pattern.
* @param broadcastKeyPattern key-pattern set in {@link Broadcast#value()}
* @param resolvedKeyPattern key-pattern with resolved key variables
* @return set of sessions ids that matched the specified criteria
*/
Set<String> resolve(@NonNull String broadcastKeyPattern, @NonNull String resolvedKeyPattern);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package org.linkwave.ws.websocket.routing.broadcast;

import lombok.RequiredArgsConstructor;
import org.linkwave.ws.repository.ChatRepository;
import org.springframework.lang.NonNull;

import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiFunction;

import static org.linkwave.ws.utils.RouteUtils.isPathVariable;
import static org.linkwave.ws.websocket.routing.broadcast.BroadcastManager.KEY_SEPARATOR;

@RequiredArgsConstructor
public class BroadcastRepositoryResolverImpl implements BroadcastRepositoryResolver {

private final ChatRepository<Long, String> chatRepository;

/**
* If the field type seems a bit complicated, this is an example how this looks like:
* <pre> {@code
* return new BroadcastRepositoryResolverImpl(
* Map.of(
* "user:{}", SessionRepository::getUserSessions,
* "chat:{}", ChatRepository::getSessions
* )
* );
* }</pre>
*/
private final Map<
String,
BiFunction<ChatRepository<Long, String>, String, Set<String>>
> repositoryResolvers;

@Override
public Set<String> resolve(@NonNull String broadcastKeyPattern, @NonNull String resolvedKeyPattern) {
Objects.requireNonNull(broadcastKeyPattern);
Objects.requireNonNull(resolvedKeyPattern);

if (broadcastKeyPattern.equals(resolvedKeyPattern)) {
throw new IllegalArgumentException("Invalid broadcast keys");
}

return repositoryResolvers
.get(eraseKey(broadcastKeyPattern))
.apply(chatRepository, resolvedKeyPattern);
}

/**
* Erases key variables names from passed key-pattern.<br/>
* <b>Example:</b> For key-pattern {@code "chat:{id}"} it returns {@code "chat:{}"}.
*
* @param keyPattern non-null string that contains key pattern
* @return key pattern without its variables names
*/
@NonNull
private String eraseKey(@NonNull String keyPattern) {
final String[] components = keyPattern.trim().split(KEY_SEPARATOR);
final var sb = new StringBuilder();
for (String part : components) {
sb.append(isPathVariable(part) ? "{}" : part).append(KEY_SEPARATOR);
}
return sb.substring(0, sb.length() - 1);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package org.linkwave.ws.websocket.routing.broadcast;

import lombok.extern.slf4j.Slf4j;
import org.linkwave.ws.repository.ChatRepository;
import org.linkwave.ws.websocket.routing.bpp.Broadcast;
import org.springframework.lang.NonNull;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;

import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;

@Slf4j
public class FlexBroadcastManager extends SimpleBroadcastManager {

private final BroadcastRepositoryResolver repositoryResolver;

public FlexBroadcastManager(WebSocketMessageBroadcast messageBroadcast,
ChatRepository<Long, String> chatRepository,
BroadcastRepositoryResolver repositoryResolver) {
super(messageBroadcast, chatRepository);
this.repositoryResolver = repositoryResolver;
}

@Override
public void process(@NonNull Method routeHandler, @NonNull Map<String, String> pathVariables,
@NonNull Object message, @NonNull String serializedMessage) {

// if it's necessary to broadcast message
if (!isBroadcast(routeHandler)) {
return;
}

log.debug("-> process(): routeHandler=[{}.{}]",
routeHandler.getDeclaringClass().getSimpleName(),
routeHandler.getName()
);

Broadcast[] broadcasts = routeHandler.getAnnotationsByType(Broadcast.class);

if (broadcasts.length > 1) {
// remove duplicate key-patterns
final Map<String, Broadcast> broadcastMap = Arrays
.stream(broadcasts)
.collect(toMap(Broadcast::value, identity(), (b1, b2) -> b1));

if (broadcastMap.size() != broadcasts.length) {
log.warn("-> process(): found duplicate key-patterns");
broadcasts = broadcastMap.values().toArray(new Broadcast[0]);
}
}

for (Broadcast broadcastAnn : broadcasts) {

final String broadcastKeyPattern = broadcastAnn.value();
final String resolveBroadcastKey = resolveKey(
broadcastKeyPattern,
pathVariables,
broadcastAnn.analyzeMessage() ? message : null
);

// resolve sessions ids
final Set<String> members = repositoryResolver.resolve(broadcastKeyPattern, resolveBroadcastKey);

if (members.isEmpty()) {
log.debug("-> process({}): sessions not found", broadcastKeyPattern);
continue;
}

broadcast(broadcastAnn, members, serializedMessage);
}
}

@Override
public boolean isBroadcast(@NonNull Method routeHandler) {
return routeHandler.getAnnotationsByType(Broadcast.class).length != 0;
}
}
Loading
Loading