diff --git a/nostr-java-api/src/main/java/nostr/api/NostrSpringWebSocketClient.java b/nostr-java-api/src/main/java/nostr/api/NostrSpringWebSocketClient.java index f38e0c34f..a787fddb7 100644 --- a/nostr-java-api/src/main/java/nostr/api/NostrSpringWebSocketClient.java +++ b/nostr-java-api/src/main/java/nostr/api/NostrSpringWebSocketClient.java @@ -111,14 +111,15 @@ public List sendRequest(@NonNull List filtersList, @NonNull Str public List sendRequest(@NonNull Filters filters, @NonNull String subscriptionId) { createRequestClient(subscriptionId); - return clientMap.entrySet().stream().filter(entry -> - entry.getValue().getRelayName().equals(String.join(entry.getKey(), subscriptionId))) + return clientMap.entrySet().stream() + .filter(entry -> entry.getKey().endsWith(":" + subscriptionId)) .map(Entry::getValue) .map(webSocketClientHandler -> webSocketClientHandler.sendRequest( filters, webSocketClientHandler.getRelayName())) - .flatMap(List::stream).toList(); + .flatMap(List::stream) + .toList(); } @@ -157,17 +158,17 @@ public void close() throws IOException { } } + protected WebSocketClientHandler newWebSocketClientHandler(String relayName, String relayUri) { + return new WebSocketClientHandler(relayName, relayUri); + } + private void createRequestClient(String subscriptionId) { - if (clientMap.entrySet().stream() // if a request client doesn't yet exist for subscriptionId... - .noneMatch(entry -> - entry.getValue().getRelayName().equals(String.join(entry.getKey(), subscriptionId)))) { - clientMap.keySet().forEach(clientMapKey -> // ... create one for each relay and add it to the client map - clientMap.entrySet().stream().map(entry -> - new WebSocketClientHandler( - String.join(entry.getKey(), subscriptionId), - entry.getValue().getRelayUri())) - .toList().forEach(webSocketClientHandler -> - clientMap.put(clientMapKey, webSocketClientHandler))); - } + clientMap.entrySet().stream() + .filter(entry -> !entry.getKey().contains(":")) + .forEach(entry -> { + String requestKey = entry.getKey() + ":" + subscriptionId; + clientMap.computeIfAbsent(requestKey, + key -> newWebSocketClientHandler(requestKey, entry.getValue().getRelayUri())); + }); } } diff --git a/nostr-java-api/src/test/java/nostr/api/unit/NostrSpringWebSocketClientTest.java b/nostr-java-api/src/test/java/nostr/api/unit/NostrSpringWebSocketClientTest.java index a590c2192..a63bb788c 100644 --- a/nostr-java-api/src/test/java/nostr/api/unit/NostrSpringWebSocketClientTest.java +++ b/nostr-java-api/src/test/java/nostr/api/unit/NostrSpringWebSocketClientTest.java @@ -1,36 +1,83 @@ package nostr.api.unit; -import nostr.api.NostrIF; import nostr.api.NostrSpringWebSocketClient; -import nostr.id.Identity; -import org.junit.jupiter.api.BeforeEach; +import nostr.api.WebSocketClientHandler; import org.junit.jupiter.api.Test; import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; -import static org.junit.jupiter.api.Assertions.assertSame; +import sun.misc.Unsafe; + +import static org.junit.jupiter.api.Assertions.*; public class NostrSpringWebSocketClientTest { - @BeforeEach - void resetSingleton() throws Exception { - Field instance = NostrSpringWebSocketClient.class.getDeclaredField("INSTANCE"); - instance.setAccessible(true); - instance.set(null, null); + private static class TestClient extends NostrSpringWebSocketClient { + @Override + protected WebSocketClientHandler newWebSocketClientHandler(String relayName, String relayUri) { + try { + return createHandler(relayName, relayUri); + } catch (Exception e) { + throw new RuntimeException(e); + } + } } - @Test - void getInstanceShouldReturnSameInstance() { - NostrIF first = NostrSpringWebSocketClient.getInstance(); - NostrIF second = NostrSpringWebSocketClient.getInstance(); - assertSame(first, second, "Multiple calls should return the same instance"); + private static WebSocketClientHandler createHandler(String name, String uri) throws Exception { + Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + Unsafe unsafe = (Unsafe) theUnsafe.get(null); + WebSocketClientHandler handler = (WebSocketClientHandler) unsafe.allocateInstance(WebSocketClientHandler.class); + + Field relayName = WebSocketClientHandler.class.getDeclaredField("relayName"); + relayName.setAccessible(true); + relayName.set(handler, name); + + Field relayUri = WebSocketClientHandler.class.getDeclaredField("relayUri"); + relayUri.setAccessible(true); + relayUri.set(handler, uri); + + Field eventClient = WebSocketClientHandler.class.getDeclaredField("eventClient"); + eventClient.setAccessible(true); + eventClient.set(handler, null); + + Field requestClientMap = WebSocketClientHandler.class.getDeclaredField("requestClientMap"); + requestClientMap.setAccessible(true); + requestClientMap.set(handler, new ConcurrentHashMap<>()); + + return handler; } @Test - void getInstanceWithIdentityShouldReturnSameInstance() { - Identity identity = Identity.generateRandomIdentity(); - NostrIF first = NostrSpringWebSocketClient.getInstance(identity); - NostrIF second = NostrSpringWebSocketClient.getInstance(); - assertSame(first, second, "Calls with and without identity should return the same instance"); + void testMultipleSubscriptionsDoNotOverwriteHandlers() throws Exception { + NostrSpringWebSocketClient client = new TestClient(); + + Field field = NostrSpringWebSocketClient.class.getDeclaredField("clientMap"); + field.setAccessible(true); + @SuppressWarnings("unchecked") + Map map = (Map) field.get(client); + + map.put("relayA", createHandler("relayA", "ws://a")); + map.put("relayB", createHandler("relayB", "ws://b")); + + Method method = NostrSpringWebSocketClient.class.getDeclaredMethod("createRequestClient", String.class); + method.setAccessible(true); + + method.invoke(client, "sub1"); + assertEquals(4, map.size()); + WebSocketClientHandler handlerA1 = map.get("relayA:sub1"); + WebSocketClientHandler handlerB1 = map.get("relayB:sub1"); + assertNotNull(handlerA1); + assertNotNull(handlerB1); + + method.invoke(client, "sub2"); + assertEquals(6, map.size()); + assertSame(handlerA1, map.get("relayA:sub1")); + assertSame(handlerB1, map.get("relayB:sub1")); + assertNotNull(map.get("relayA:sub2")); + assertNotNull(map.get("relayB:sub2")); } }