diff --git a/core/build.gradle.kts b/core/build.gradle.kts index bff72b2ce..4b92231bb 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -20,6 +20,19 @@ dependencies { implementation("io.grpc", "grpc-protobuf", Versions.gRPCVersion) implementation("io.grpc", "grpc-stub", Versions.gRPCVersion) implementation("javax.annotation", "javax.annotation-api", "1.3.2") + + // Test deps — pinned to versions still compatible with the Java 8 source target. + testImplementation("org.junit.jupiter:junit-jupiter:5.10.5") + testImplementation("org.mockito:mockito-core:4.11.0") + testImplementation("org.awaitility:awaitility:4.2.2") + testImplementation("io.netty", "netty-transport", Versions.nettyVersion) + testImplementation("io.netty", "netty-codec", Versions.nettyVersion) + testImplementation("com.squareup.okhttp3:mockwebserver:4.9.3") + testRuntimeOnly("org.junit.platform:junit-platform-launcher") +} + +tasks.test { + useJUnitPlatform() } // present on all platforms diff --git a/core/src/main/java/com/minekube/connect/addon/packethandler/ChannelInPacketHandler.java b/core/src/main/java/com/minekube/connect/addon/packethandler/ChannelInPacketHandler.java index 034c462f1..a125c7030 100644 --- a/core/src/main/java/com/minekube/connect/addon/packethandler/ChannelInPacketHandler.java +++ b/core/src/main/java/com/minekube/connect/addon/packethandler/ChannelInPacketHandler.java @@ -46,7 +46,7 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { packetHandlers.getPacketHandlers(msg.getClass())) { Object res = consumer.apply(ctx, msg, toServer); - if (!res.equals(msg)) { + if (res != msg) { packet = res; } } diff --git a/core/src/main/java/com/minekube/connect/addon/packethandler/ChannelOutPacketHandler.java b/core/src/main/java/com/minekube/connect/addon/packethandler/ChannelOutPacketHandler.java index 29d980ba2..4b00d416d 100644 --- a/core/src/main/java/com/minekube/connect/addon/packethandler/ChannelOutPacketHandler.java +++ b/core/src/main/java/com/minekube/connect/addon/packethandler/ChannelOutPacketHandler.java @@ -47,7 +47,7 @@ protected void encode(ChannelHandlerContext ctx, Object msg, List out) { packetHandlers.getPacketHandlers(msg.getClass())) { Object res = consumer.apply(ctx, msg, toServer); - if (!res.equals(msg)) { + if (res != msg) { packet = res; } } diff --git a/core/src/main/java/com/minekube/connect/inject/CommonPlatformInjector.java b/core/src/main/java/com/minekube/connect/inject/CommonPlatformInjector.java index 9dd033cc7..19bf4b79c 100644 --- a/core/src/main/java/com/minekube/connect/inject/CommonPlatformInjector.java +++ b/core/src/main/java/com/minekube/connect/inject/CommonPlatformInjector.java @@ -30,10 +30,9 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import java.net.SocketAddress; -import java.util.HashMap; -import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import lombok.AccessLevel; import lombok.Getter; @@ -65,10 +64,12 @@ public void shutdown() { } } + // Both registries are mutated from Netty I/O threads (per-channel) and iterated + // concurrently, so they must be thread-safe. @Getter(AccessLevel.PROTECTED) - private final Set injectedClients = new HashSet<>(); + private final Set injectedClients = ConcurrentHashMap.newKeySet(); - private final Map, InjectorAddon> addons = new HashMap<>(); + private final Map, InjectorAddon> addons = new ConcurrentHashMap<>(); protected boolean addInjectedClient(Channel channel) { return injectedClients.add(channel); diff --git a/core/src/main/java/com/minekube/connect/network/netty/LocalSession.java b/core/src/main/java/com/minekube/connect/network/netty/LocalSession.java index d89e72e25..0af982935 100644 --- a/core/src/main/java/com/minekube/connect/network/netty/LocalSession.java +++ b/core/src/main/java/com/minekube/connect/network/netty/LocalSession.java @@ -69,7 +69,7 @@ */ @RequiredArgsConstructor public final class LocalSession { - private static final int CONNECTION_TIMEOUT = (int) Duration.ofSeconds(30).toMillis(); + private static final int CONNECTION_TIMEOUT = (int) Duration.ofSeconds(10).toMillis(); private static DefaultEventLoopGroup DEFAULT_EVENT_LOOP_GROUP; private static EventLoopGroup PLATFORM_EVENT_LOOP_GROUP; // Platform-specific event loop group diff --git a/core/src/main/java/com/minekube/connect/network/netty/TunnelHandler.java b/core/src/main/java/com/minekube/connect/network/netty/TunnelHandler.java index 0f725425c..780ddcd90 100644 --- a/core/src/main/java/com/minekube/connect/network/netty/TunnelHandler.java +++ b/core/src/main/java/com/minekube/connect/network/netty/TunnelHandler.java @@ -32,6 +32,9 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; +import io.netty.channel.EventLoop; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; import lombok.RequiredArgsConstructor; @RequiredArgsConstructor @@ -39,11 +42,37 @@ class TunnelHandler implements Handler { private final ConnectLogger logger; private final Channel downstreamServerConn; // local server connection + // Coalesces flushes across an EventLoop tick: one flush() per batch of + // onReceive calls instead of one per packet. The CAS lives inside the + // write task so the flush is always enqueued after the write that needs + // it — scheduling the CAS outside the EventLoop races, because a later + // write can be enqueued behind an already-scheduled flush. + private final AtomicBoolean flushScheduled = new AtomicBoolean(false); + @Override public void onReceive(byte[] data) { - // TunnelService -> local session server -> downstream server - ByteBuf buf = Unpooled.wrappedBuffer(data); - downstreamServerConn.writeAndFlush(buf); + // TunnelService -> local session server -> downstream server. + // Allocate the ByteBuf inside the lambda so it isn't leaked if execute() + // rejects (event loop shutting down during proxy stop). + Channel ch = downstreamServerConn; + EventLoop el = ch.eventLoop(); + try { + el.execute(() -> { + ch.write(Unpooled.wrappedBuffer(data), ch.voidPromise()); + if (flushScheduled.compareAndSet(false, true)) { + try { + el.execute(() -> { + flushScheduled.set(false); + ch.flush(); + }); + } catch (RejectedExecutionException ignored) { + flushScheduled.set(false); + } + } + }); + } catch (RejectedExecutionException ignored) { + // Event loop is shutting down; the channel is going away anyway. + } } @Override @@ -63,7 +92,20 @@ public void onError(Throwable t) { @Override public void onClose() { - // disconnect from downstream server - downstreamServerConn.close(); + // Flush before closing: deferred writes from onReceive() may still be + // sitting in the channel's outbound buffer with the flush scheduled as + // a separate EventLoop task, so closing without a final flush can drop + // the last payload. + Channel ch = downstreamServerConn; + try { + ch.eventLoop().execute(() -> { + ch.flush(); + ch.close(); + }); + } catch (RejectedExecutionException ignored) { + // Event loop already shut down: close directly. Netty's close is + // thread-safe and a no-op on an already-closed channel. + ch.close(); + } } } diff --git a/core/src/main/java/com/minekube/connect/packet/PacketHandlersImpl.java b/core/src/main/java/com/minekube/connect/packet/PacketHandlersImpl.java index 40d7e287b..0d36172ca 100644 --- a/core/src/main/java/com/minekube/connect/packet/PacketHandlersImpl.java +++ b/core/src/main/java/com/minekube/connect/packet/PacketHandlersImpl.java @@ -29,21 +29,25 @@ import com.minekube.connect.api.packet.PacketHandlers; import com.minekube.connect.api.util.TriFunction; import io.netty.channel.ChannelHandlerContext; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CopyOnWriteArraySet; import lombok.AllArgsConstructor; import lombok.Getter; public final class PacketHandlersImpl implements PacketHandlers { - private final Map> handlers = new HashMap<>(); - private final Set> globalPacketHandlers = new HashSet<>(); - private final Map, Set>> packetHandlers = new HashMap<>(); + // CopyOnWriteArraySet for the per-class fanout: reads happen on every packet + // (hot path, must be lock-free); writes only on register/deregister. + private final Map> handlers = new ConcurrentHashMap<>(); + private final Set> globalPacketHandlers = + new CopyOnWriteArraySet<>(); + private final Map, Set>> packetHandlers = + new ConcurrentHashMap<>(); @Override public void register( @@ -55,10 +59,10 @@ public void register( return; } - handlers.computeIfAbsent(handler, $ -> new ArrayList<>()) + handlers.computeIfAbsent(handler, $ -> new CopyOnWriteArrayList<>()) .add(new HandlerEntry(packetClass, consumer)); - packetHandlers.computeIfAbsent(packetClass, $ -> new HashSet<>(globalPacketHandlers)) + packetHandlers.computeIfAbsent(packetClass, $ -> new CopyOnWriteArraySet<>(globalPacketHandlers)) .add(consumer); } @@ -70,7 +74,7 @@ public void registerAll(PacketHandler handler) { TriFunction packetHandler = handler::handle; - handlers.computeIfAbsent(handler, $ -> new ArrayList<>()) + handlers.computeIfAbsent(handler, $ -> new CopyOnWriteArrayList<>()) .add(new HandlerEntry(null, packetHandler)); globalPacketHandlers.add(packetHandler); @@ -88,13 +92,19 @@ public void deregister(PacketHandler handler) { List values = handlers.remove(handler); if (values != null) { for (HandlerEntry value : values) { - Set handlers = packetHandlers.get(value.getPacket()); - - if (handlers != null) { - handlers.removeIf(o -> o.equals(value.getHandler())); - if (handlers.isEmpty()) { - packetHandlers.remove(value.getPacket()); - } + // registerAll() stores HandlerEntry with packetClass == null. + // ConcurrentHashMap rejects null keys, so skip the per-class + // lookup for global handlers (the old HashMap returned null + // silently for the same case). + Class packetClass = value.getPacket(); + if (packetClass != null) { + // computeIfPresent atomically removes the entry only if it's + // still empty after our removal, so a concurrent register() + // that re-populates the set in between isn't dropped. + packetHandlers.computeIfPresent(packetClass, (k, set) -> { + set.removeIf(o -> o.equals(value.getHandler())); + return set.isEmpty() ? null : set; + }); } globalPacketHandlers.remove(value.getHandler()); diff --git a/core/src/main/java/com/minekube/connect/platform/command/CommandUtil.java b/core/src/main/java/com/minekube/connect/platform/command/CommandUtil.java index 8d250feab..beaa2489d 100644 --- a/core/src/main/java/com/minekube/connect/platform/command/CommandUtil.java +++ b/core/src/main/java/com/minekube/connect/platform/command/CommandUtil.java @@ -80,7 +80,7 @@ public abstract class CommandUtil { public @NonNull Collection getOnlineUsernames() { Collection usernames = new ArrayList<>(); - getOnlinePlayers().forEach(this::getUsernameFromSource); + getOnlinePlayers().forEach(player -> usernames.add(getUsernameFromSource(player))); return usernames; } diff --git a/core/src/main/java/com/minekube/connect/register/WatcherRegister.java b/core/src/main/java/com/minekube/connect/register/WatcherRegister.java index 68b2eef3c..b79f4f52c 100644 --- a/core/src/main/java/com/minekube/connect/register/WatcherRegister.java +++ b/core/src/main/java/com/minekube/connect/register/WatcherRegister.java @@ -40,9 +40,13 @@ import com.minekube.connect.watch.SessionProposal.State; import com.minekube.connect.watch.WatchClient; import com.minekube.connect.watch.Watcher; +import io.netty.util.concurrent.DefaultThreadFactory; import java.io.IOException; import java.time.Duration; -import java.util.Timer; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import okhttp3.WebSocket; @@ -56,13 +60,23 @@ public class WatcherRegister { @Inject private ConnectLogger logger; @Inject private SimpleConnectApi api; - private WebSocket ws; + // volatile: written from injection thread (start/stop) and read from the + // scheduler thread (retry) and OkHttp dispatcher (WatcherImpl callbacks). + private volatile WebSocket ws; private ExponentialBackOff backOffPolicy; private final AtomicBoolean started = new AtomicBoolean(); + // Lazily created in start() so a stop()/start() cycle reuses cleanly, + // and so the daemon thread isn't allocated if start() is never called. + // java.util.Timer would leak one OS thread per reconnect cycle. + private volatile ScheduledExecutorService scheduler; + private volatile ScheduledFuture retryFuture; + @Inject public void start() { if (started.compareAndSet(false, true)) { + scheduler = Executors.newSingleThreadScheduledExecutor( + new DefaultThreadFactory("connect-watcher-scheduler", true)); backOffPolicy = new ExponentialBackOff.Builder() .setInitialIntervalMillis(1000) // 1 second .setMaxElapsedTimeMillis(Integer.MAX_VALUE) // 24.8 days @@ -78,62 +92,58 @@ public void resetBackOff() { } public void stop() { + // Gate the whole teardown so a concurrent stop() races safely. + // A stop() before start() is a no-op (started is already false). + if (!started.compareAndSet(true, false)) { + return; + } + logger.info("Stopped watching for sessions"); + if (retryFuture != null) { + retryFuture.cancel(false); + retryFuture = null; + } + if (scheduler != null) { + scheduler.shutdownNow(); + scheduler = null; + } if (ws != null) { - if (started.compareAndSet(true, false)) { - logger.info("Stopped watching for sessions"); - } - if (timer != null) { - timer.cancel(); - timer = null; - } - if (retryTask != null) { - retryTask.cancel(); - retryTask = null; - } ws.close(1000, "watcher stopped"); ws = null; } } - private Timer timer; - private TimerTask retryTask; - private void retry() { - if (started.get()) { - if (retryTask != null) { - retryTask.cancel(); - } - if (timer == null) { - timer = new Timer(); - } - long millis; - try { - millis = backOffPolicy.nextBackOffMillis(); - if (millis == BackOff.STOP) { - stop(); - return; - } - } catch (IOException e) { - logger.error("nextBackOffMillis error", e); + if (!started.get()) { + return; + } + if (retryFuture != null) { + retryFuture.cancel(false); + } + long millis; + try { + millis = backOffPolicy.nextBackOffMillis(); + if (millis == BackOff.STOP) { + stop(); return; } - retryTask = new TimerTask(); - logger.info("Trying to reconnect in {}...", - Utils.humanReadableFormat(Duration.ofMillis(millis))); - timer.schedule(retryTask, millis); + } catch (IOException e) { + logger.error("nextBackOffMillis error", e); + return; } - } - - private class TimerTask extends java.util.TimerTask { - @Override - public void run() { + // Snapshot to avoid NPE if stop() races with a late callback that triggered retry(). + ScheduledExecutorService s = scheduler; + if (s == null) { + return; + } + logger.info("Trying to reconnect in {}...", + Utils.humanReadableFormat(Duration.ofMillis(millis))); + retryFuture = s.schedule(() -> { if (started.get()) { watch(); } - } + }, millis, TimeUnit.MILLISECONDS); } - private void watch() { if (ws != null) { ws.close(1000, "watcher is reconnecting"); @@ -146,8 +156,6 @@ private class WatcherImpl implements Watcher { @Override public void onOpen() { logger.translatedInfo("connect.watch.started"); - - // Reset the retry backoff after the connection is healthy for some seconds startResetBackOffTimer(); } @@ -172,7 +180,7 @@ public void onProposal(SessionProposal proposal) { return; } - if (logger.isDebug()) { // skipping a lot of proposal.toString operations + if (logger.isDebug()) { logger.debug("Received {}", proposal); } @@ -180,7 +188,6 @@ public void onProposal(SessionProposal proposal) { return; } - // Try establishing connection new LocalSession(logger, api, tunneler, platformInjector.getServerSocketAddress(), proposal @@ -205,29 +212,27 @@ public void onCompleted() { retry(); } - private Timer resetBackOffTimer; + private volatile ScheduledFuture resetBackOffFuture; void startResetBackOffTimer() { - if (resetBackOffTimer != null) { - resetBackOffTimer.cancel(); + cancelResetBackOffTimer(); + // Snapshot: a late onOpen after stop() can land here with scheduler == null. + ScheduledExecutorService s = scheduler; + if (s == null || !started.get()) { + return; } - resetBackOffTimer = new Timer(); - resetBackOffTimer.schedule(new TimerTask() { - @Override - public void run() { - if (started.get()) { - resetBackOff(); - } + resetBackOffFuture = s.schedule(() -> { + if (started.get()) { + resetBackOff(); } - }, Duration.ofSeconds(10).toMillis()); + }, Duration.ofSeconds(10).toMillis(), TimeUnit.MILLISECONDS); } void cancelResetBackOffTimer() { - if (resetBackOffTimer != null) { - resetBackOffTimer.cancel(); - resetBackOffTimer = null; + if (resetBackOffFuture != null) { + resetBackOffFuture.cancel(false); + resetBackOffFuture = null; } } - } } diff --git a/core/src/main/java/com/minekube/connect/tunnel/Tunneler.java b/core/src/main/java/com/minekube/connect/tunnel/Tunneler.java index 2aa23038b..66a63919d 100644 --- a/core/src/main/java/com/minekube/connect/tunnel/Tunneler.java +++ b/core/src/main/java/com/minekube/connect/tunnel/Tunneler.java @@ -106,8 +106,17 @@ public void onFailure(@NotNull WebSocket webSocket, @NotNull Throwable t, public void onMessage(@NotNull WebSocket webSocket, @NotNull ByteString bytes) { // handler.onReceive(bytes.toByteArray()); // would copy - // zero-copy get byte array - handler.onReceive(ReflectionUtils.getCastedValue(bytes, DATA)); + // Zero-copy fast path with safe fallback. DATA is null if the field + // can't be located (e.g. after Okio relocation); getCastedValue can + // also return null at runtime because IllegalAccessException is + // swallowed inside ReflectionUtils.getValue. Without this guard, + // onReceive(null) flows into Unpooled.wrappedBuffer(null), NPE-ing + // the OkHttp dispatcher thread and dropping every tunnel on it. + byte[] rawBytes = DATA != null ? ReflectionUtils.getCastedValue(bytes, DATA) : null; + if (rawBytes == null) { + rawBytes = bytes.toByteArray(); + } + handler.onReceive(rawBytes); } @Override diff --git a/core/src/main/java/com/minekube/connect/util/HttpUtils.java b/core/src/main/java/com/minekube/connect/util/HttpUtils.java index 423a7aa19..78f79de31 100644 --- a/core/src/main/java/com/minekube/connect/util/HttpUtils.java +++ b/core/src/main/java/com/minekube/connect/util/HttpUtils.java @@ -28,6 +28,7 @@ import com.google.common.collect.ImmutableList; import com.google.gson.Gson; import com.google.gson.JsonObject; +import io.netty.util.concurrent.DefaultThreadFactory; import java.io.InputStream; import java.io.InputStreamReader; import java.net.HttpURLConnection; @@ -48,7 +49,8 @@ // resources are properly closed and ignoring the original stack trace is intended @SuppressWarnings({"PMD.CloseResource", "PMD.PreserveStackTrace"}) public class HttpUtils { - private static final ExecutorService EXECUTOR_SERVICE = Executors.newSingleThreadExecutor(); + private static final ExecutorService EXECUTOR_SERVICE = Executors.newFixedThreadPool( + 4, new DefaultThreadFactory("connect-http-worker", true)); private static final Gson GSON = new Gson(); private static final String USER_AGENT = "Minekube/Connect"; diff --git a/core/src/test/java/com/minekube/connect/addon/packethandler/ChannelInPacketHandlerTest.java b/core/src/test/java/com/minekube/connect/addon/packethandler/ChannelInPacketHandlerTest.java new file mode 100644 index 000000000..76c9b6c53 --- /dev/null +++ b/core/src/test/java/com/minekube/connect/addon/packethandler/ChannelInPacketHandlerTest.java @@ -0,0 +1,76 @@ +package com.minekube.connect.addon.packethandler; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.mockito.Mockito.mock; + +import com.minekube.connect.api.packet.PacketHandler; +import com.minekube.connect.packet.PacketHandlersImpl; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +class ChannelInPacketHandlerTest { + + @Test + void passesOriginalMessageWhenNoHandlersRegistered() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + EmbeddedChannel channel = newChannel(packetHandlers); + String msg = new String("foo"); + + channel.writeInbound(msg); + Object forwarded = channel.readInbound(); + + assertSame(msg, forwarded); + } + + @Test + void passesOriginalMessageWhenHandlerReturnsSameReference() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + PacketHandler owner = mock(PacketHandler.class); + packetHandlers.register(owner, String.class, (ctx, packet, serverbound) -> packet); + EmbeddedChannel channel = newChannel(packetHandlers); + String msg = new String("foo"); + + channel.writeInbound(msg); + Object forwarded = channel.readInbound(); + + assertSame(msg, forwarded); + } + + @Test + void forwardsNewReferenceEvenWhenItEqualsOriginalMessage() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + PacketHandler owner = mock(PacketHandler.class); + packetHandlers.register(owner, String.class, (ctx, packet, serverbound) -> new String((String) packet)); + EmbeddedChannel channel = newChannel(packetHandlers); + String msg = new String("foo"); + + channel.writeInbound(msg); + Object forwarded = channel.readInbound(); + + assertEquals(msg, forwarded); + assertNotSame(msg, forwarded); + } + + @Test + void forwardsLastReplacementWhenMultipleHandlersReturnNewReferences() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + PacketHandler firstOwner = mock(PacketHandler.class); + PacketHandler secondOwner = mock(PacketHandler.class); + String firstReplacement = new String("first"); + String secondReplacement = new String("second"); + packetHandlers.register(firstOwner, String.class, (ctx, packet, serverbound) -> firstReplacement); + packetHandlers.register(secondOwner, String.class, (ctx, packet, serverbound) -> secondReplacement); + EmbeddedChannel channel = newChannel(packetHandlers); + + channel.writeInbound(new String("foo")); + Object forwarded = channel.readInbound(); + + assertSame(secondReplacement, forwarded); + } + + private static EmbeddedChannel newChannel(PacketHandlersImpl packetHandlers) { + return new EmbeddedChannel(new ChannelInPacketHandler(packetHandlers, true)); + } +} diff --git a/core/src/test/java/com/minekube/connect/addon/packethandler/ChannelOutPacketHandlerTest.java b/core/src/test/java/com/minekube/connect/addon/packethandler/ChannelOutPacketHandlerTest.java new file mode 100644 index 000000000..57512778a --- /dev/null +++ b/core/src/test/java/com/minekube/connect/addon/packethandler/ChannelOutPacketHandlerTest.java @@ -0,0 +1,80 @@ +package com.minekube.connect.addon.packethandler; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.minekube.connect.api.packet.PacketHandler; +import com.minekube.connect.packet.PacketHandlersImpl; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +class ChannelOutPacketHandlerTest { + @Test + void noHandlersRegisteredPassesThroughOriginalMessage() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + String msg = "foo"; + + Object outbound = writeOutbound(packetHandlers, msg); + + assertSame(msg, outbound); + } + + @Test + void handlerReturningSameReferencePassesThroughOriginalMessage() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + String msg = "foo"; + register(packetHandlers, msg); + + Object outbound = writeOutbound(packetHandlers, msg); + + assertSame(msg, outbound); + } + + @Test + void handlerReturningEqualNewReferenceForwardsNewReference() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + String msg = "foo"; + String replacement = new String("foo"); + register(packetHandlers, replacement); + + Object outbound = writeOutbound(packetHandlers, msg); + + assertEquals(msg, outbound); + assertNotSame(msg, outbound); + assertSame(replacement, outbound); + } + + @Test + void multipleHandlersForwardLastNewReference() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + String msg = "foo"; + String firstReplacement = new String("first"); + String secondReplacement = new String("second"); + register(packetHandlers, firstReplacement); + register(packetHandlers, secondReplacement); + + Object outbound = writeOutbound(packetHandlers, msg); + + assertSame(secondReplacement, outbound); + } + + private static void register(PacketHandlersImpl packetHandlers, Object result) { + PacketHandler owner = (ctx, packet, serverbound) -> packet; + packetHandlers.register(owner, String.class, (ctx, packet, serverbound) -> result); + } + + private static Object writeOutbound(PacketHandlersImpl packetHandlers, Object msg) { + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutPacketHandler(packetHandlers, true)); + try { + assertTrue(channel.writeOutbound(msg)); + Object outbound = channel.readOutbound(); + assertNull(channel.readOutbound()); + return outbound; + } finally { + channel.finishAndReleaseAll(); + } + } +} diff --git a/core/src/test/java/com/minekube/connect/inject/CommonPlatformInjectorTest.java b/core/src/test/java/com/minekube/connect/inject/CommonPlatformInjectorTest.java new file mode 100644 index 000000000..1cbff14bd --- /dev/null +++ b/core/src/test/java/com/minekube/connect/inject/CommonPlatformInjectorTest.java @@ -0,0 +1,153 @@ +package com.minekube.connect.inject; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +import com.minekube.connect.api.inject.InjectorAddon; +import io.netty.channel.Channel; +import io.netty.channel.embedded.EmbeddedChannel; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.Test; + +class CommonPlatformInjectorTest { + private static final int THREADS = 8; + private static final int CHANNELS_PER_THREAD = 100; + @Test + void addInjectedClientReturnsFalseForDuplicateChannel() { + TestInjector injector = new TestInjector(); + EmbeddedChannel channel = new EmbeddedChannel(); + assertTrue(injector.addClient(channel)); + assertFalse(injector.addClient(channel)); + assertEquals(1, injector.injectedClientCount()); + } + @Test + void removeInjectedClientRemovesChannel() { + TestInjector injector = new TestInjector(); + EmbeddedChannel channel = new EmbeddedChannel(); + injector.addClient(channel); + assertTrue(injector.removeClient(channel)); + assertFalse(injector.removeClient(channel)); + assertEquals(0, injector.injectedClientCount()); + } + @Test + void addInjectedClientAcceptsConcurrentAdds() throws Exception { + TestInjector injector = new TestInjector(); + AtomicReference failure = runConcurrently(() -> { + for (int i = 0; i < CHANNELS_PER_THREAD; i++) { + injector.addClient(new EmbeddedChannel()); + } + }); + assertNull(failure.get()); + assertEquals(THREADS * CHANNELS_PER_THREAD, injector.injectedClientCount()); + } + @Test + void injectedClientsCanBeSizedWhileConcurrentAddsRun() throws Exception { + TestInjector injector = new TestInjector(); + AtomicBoolean adding = new AtomicBoolean(true); + AtomicReference sizeFailure = new AtomicReference<>(); + Thread sizer = new Thread(() -> { + try { + while (adding.get()) { + injector.injectedClientCount(); + Thread.yield(); + } + } catch (Throwable throwable) { + sizeFailure.compareAndSet(null, throwable); + } + }); + AtomicReference addFailure; + try { + sizer.start(); + addFailure = runConcurrently(() -> { + for (int i = 0; i < CHANNELS_PER_THREAD; i++) { + injector.addClient(new EmbeddedChannel()); + } + }); + } finally { + adding.set(false); + sizer.join(TimeUnit.SECONDS.toMillis(5)); + } + assertFalse(sizer.isAlive()); + assertNull(addFailure.get()); + assertNull(sizeFailure.get()); + assertEquals(THREADS * CHANNELS_PER_THREAD, injector.injectedClientCount()); + } + @Test + void injectAddonsCallCanIterateWhileAddonsChange() throws Exception { + TestInjector injector = new TestInjector(); + InjectorAddon addon = mock(InjectorAddon.class); + EmbeddedChannel channel = new EmbeddedChannel(); + AtomicBoolean mutating = new AtomicBoolean(true); + AtomicReference iterationFailure = new AtomicReference<>(); + Thread iterator = new Thread(() -> { + try { + while (mutating.get()) { + injector.injectAddonsCall(channel, true); + Thread.yield(); + } + } catch (Throwable throwable) { + iterationFailure.compareAndSet(null, throwable); + } + }); + AtomicReference mutationFailure; + try { + iterator.start(); + mutationFailure = runConcurrently(() -> { + // 200 mutations × 8 threads = 1600 add/remove pairs — enough to + // provoke a CME if the data structures aren't thread-safe. + for (int i = 0; i < 200; i++) { + injector.addAddon(addon); + injector.removeAddon(addon.getClass()); + } + }); + } finally { + mutating.set(false); + iterator.join(TimeUnit.SECONDS.toMillis(5)); + } + assertFalse(iterator.isAlive()); + assertNull(mutationFailure.get()); + assertNull(iterationFailure.get()); + } + private static AtomicReference runConcurrently(CheckedRunnable task) + throws Exception { + ExecutorService executor = Executors.newFixedThreadPool(THREADS); + CountDownLatch start = new CountDownLatch(1); + AtomicReference failure = new AtomicReference<>(); + try { + for (int i = 0; i < THREADS; i++) { + executor.execute(() -> { + try { + start.await(); + task.run(); + } catch (Throwable throwable) { + failure.compareAndSet(null, throwable); + } + }); + } + start.countDown(); + executor.shutdown(); + assertTrue(executor.awaitTermination(10, TimeUnit.SECONDS)); + return failure; + } finally { + executor.shutdownNow(); + } + } + private interface CheckedRunnable { void run() throws Exception; } + private static final class TestInjector extends CommonPlatformInjector { + boolean addClient(Channel channel) { return addInjectedClient(channel); } + boolean removeClient(Channel channel) { return removeInjectedClient(channel); } + int injectedClientCount() { return getInjectedClients().size(); } + @Override + public boolean inject() { return true; } + @Override + public boolean isInjected() { return true; } + } +} diff --git a/core/src/test/java/com/minekube/connect/network/netty/TunnelHandlerTest.java b/core/src/test/java/com/minekube/connect/network/netty/TunnelHandlerTest.java new file mode 100644 index 000000000..ab46e8b29 --- /dev/null +++ b/core/src/test/java/com/minekube/connect/network/netty/TunnelHandlerTest.java @@ -0,0 +1,204 @@ +package com.minekube.connect.network.netty; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.minekube.connect.api.logger.ConnectLogger; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultEventLoop; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +class TunnelHandlerTest { + private DefaultEventLoop eventLoop; + private Channel channel; + private ChannelFuture closeFuture; + private final List events = new ArrayList<>(); + + @AfterEach + void shutdownEventLoop() throws Exception { + if (eventLoop != null) { + // Zero quiet period — tests don't need Netty's default 2s grace. + eventLoop.shutdownGracefully(0, 5, SECONDS).get(5, SECONDS); + } + } + + @Test + void onReceiveWritesPayloadAndFlushesOnce() throws Exception { + TunnelHandler handler = newHandler(); + byte[] payload = new byte[] {1, 2, 3}; + + runWithEventLoopBlocked(() -> handler.onReceive(payload)); + awaitEventLoop(); + + assertEventTypes(Event.WRITE, Event.FLUSH); + assertArrayEquals(payload, events.get(0).payload); + } + + @Test + void burstOfReceivesCoalescesIntoOneFlush() throws Exception { + TunnelHandler handler = newHandler(); + List payloads = new ArrayList<>(); + + runWithEventLoopBlocked(() -> { + for (int i = 0; i < 50; i++) { + byte[] payload = new byte[] {(byte) i, (byte) (i + 1)}; + payloads.add(payload); + handler.onReceive(payload); + } + }); + awaitEventLoop(); + + assertEquals(50, count(Event.WRITE)); + assertEquals(1, count(Event.FLUSH)); + List actualPayloads = writePayloads(); + assertEquals(50, actualPayloads.size()); + for (int i = 0; i < payloads.size(); i++) { + assertArrayEquals(payloads.get(i), actualPayloads.get(i)); + } + } + + @Test + void onCloseFlushesPendingWriteBeforeClosingChannel() throws Exception { + TunnelHandler handler = newHandler(); + byte[] payload = new byte[] {9, 8, 7}; + + runWithEventLoopBlocked(() -> { + handler.onReceive(payload); + handler.onClose(); + }); + awaitEventLoop(); + + // The CLOSE-driven flush must precede CLOSE so the payload reaches the + // wire before the channel is torn down. A trailing FLUSH from the + // deferred-write task is harmless (no-op on a closed channel in real Netty). + List types = eventTypes(); + int firstFlush = types.indexOf(Event.FLUSH); + int close = types.indexOf(Event.CLOSE); + assertEquals(Event.WRITE, types.get(0)); + assertTrue(firstFlush > 0, "expected a FLUSH before CLOSE, got: " + types); + assertTrue(close > firstFlush, "expected CLOSE after FLUSH, got: " + types); + assertArrayEquals(payload, events.get(0).payload); + assertTrue(closeFuture.isDone()); + } + + private TunnelHandler newHandler() { + eventLoop = new DefaultEventLoop(); + channel = mock(Channel.class); + closeFuture = mock(ChannelFuture.class); + + when(channel.eventLoop()).thenReturn(eventLoop); + when(channel.voidPromise()).thenReturn(mock(ChannelPromise.class)); + when(closeFuture.isDone()).thenReturn(true); + + doAnswer(invocation -> { + ByteBuf buf = invocation.getArgument(0); + try { + byte[] payload = new byte[buf.readableBytes()]; + buf.getBytes(buf.readerIndex(), payload); + events.add(new RecordedEvent(Event.WRITE, payload)); + return invocation.getArgument(1); + } finally { + buf.release(); + } + }).when(channel).write(any(ByteBuf.class), any(ChannelPromise.class)); + + doAnswer(invocation -> { + events.add(new RecordedEvent(Event.FLUSH)); + return channel; + }).when(channel).flush(); + + doAnswer(invocation -> { + events.add(new RecordedEvent(Event.CLOSE)); + return closeFuture; + }).when(channel).close(); + + return new TunnelHandler(mock(ConnectLogger.class), channel); + } + + private void awaitEventLoop() throws Exception { + eventLoop.submit(() -> null).get(5, SECONDS); + eventLoop.submit(() -> null).get(5, SECONDS); + } + + private void runWithEventLoopBlocked(Runnable task) throws Exception { + CountDownLatch started = new CountDownLatch(1); + CountDownLatch release = new CountDownLatch(1); + eventLoop.execute(() -> { + started.countDown(); + try { + assertTrue(release.await(5, SECONDS)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AssertionError(e); + } + }); + assertTrue(started.await(5, SECONDS)); + try { + task.run(); + } finally { + release.countDown(); + } + } + + private void assertEventTypes(Event... expected) { + assertEquals(Arrays.asList(expected), eventTypes()); + } + + private List eventTypes() { + List types = new ArrayList<>(); + for (RecordedEvent event : events) { + types.add(event.type); + } + return types; + } + + private List writePayloads() { + List payloads = new ArrayList<>(); + for (RecordedEvent event : events) { + if (event.type == Event.WRITE) { + payloads.add(event.payload); + } + } + return payloads; + } + + private int count(Event type) { + int count = 0; + for (RecordedEvent event : events) { + if (event.type == type) { + count++; + } + } + return count; + } + + private enum Event { WRITE, FLUSH, CLOSE } + + private static final class RecordedEvent { + private final Event type; + private final byte[] payload; + + private RecordedEvent(Event type) { + this(type, null); + } + + private RecordedEvent(Event type, byte[] payload) { + this.type = type; + this.payload = payload; + } + } +} diff --git a/core/src/test/java/com/minekube/connect/packet/PacketHandlersImplTest.java b/core/src/test/java/com/minekube/connect/packet/PacketHandlersImplTest.java new file mode 100644 index 000000000..c732d920c --- /dev/null +++ b/core/src/test/java/com/minekube/connect/packet/PacketHandlersImplTest.java @@ -0,0 +1,181 @@ +package com.minekube.connect.packet; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +import com.minekube.connect.api.packet.PacketHandler; +import com.minekube.connect.api.util.TriFunction; +import io.netty.channel.ChannelHandlerContext; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; + +class PacketHandlersImplTest { + private static final TriFunction STRING_CONSUMER = + (ctx, packet, serverbound) -> "string"; + private static final TriFunction INTEGER_CONSUMER = + (ctx, packet, serverbound) -> "integer"; + + @Test + void registerForPacketClassReturnsConsumerOnlyForThatClass() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + PacketHandler owner = mock(PacketHandler.class); + + packetHandlers.register(owner, String.class, STRING_CONSUMER); + + assertTrue(packetHandlers.getPacketHandlers(String.class).contains(STRING_CONSUMER)); + assertTrue(packetHandlers.getPacketHandlers(Integer.class).isEmpty()); + } + + @Test + void registerAllAddsGlobalHandlerToExistingPacketClasses() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + PacketHandler owner = mock(PacketHandler.class); + PacketHandler globalHandler = (ctx, packet, serverbound) -> "global"; + packetHandlers.register(owner, String.class, STRING_CONSUMER); + + packetHandlers.registerAll(globalHandler); + + Collection> handlers = + packetHandlers.getPacketHandlers(String.class); + assertTrue(handlers.contains(STRING_CONSUMER)); + assertTrue(anyHandlerReturns(handlers, "global")); + } + + @Test + void deregisterRemovesPacketClassHandlersCleanly() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + PacketHandler owner = mock(PacketHandler.class); + packetHandlers.register(owner, String.class, STRING_CONSUMER); + + packetHandlers.deregister(owner); + + assertTrue(packetHandlers.getPacketHandlers(String.class).isEmpty()); + assertFalse(packetHandlers.hasHandlers()); + } + + @Test + void deregisterGlobalHandlerDoesNotThrowAndRemovesItFromFuturePacketClasses() { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + PacketHandler globalHandler = (ctx, packet, serverbound) -> "global"; + PacketHandler owner = mock(PacketHandler.class); + packetHandlers.registerAll(globalHandler); + + assertDoesNotThrow(() -> packetHandlers.deregister(globalHandler)); + + packetHandlers.register(owner, String.class, STRING_CONSUMER); + + Collection> handlers = + packetHandlers.getPacketHandlers(String.class); + assertTrue(handlers.contains(STRING_CONSUMER)); + assertFalse(anyHandlerReturns(handlers, "global")); + } + + @Test + void concurrentRegisterAndDeregisterDoesNotLeaveRegisteredHandlers() throws Exception { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + int threads = 8; + int iterations = 200; + ExecutorService executor = Executors.newFixedThreadPool(threads); + CountDownLatch start = new CountDownLatch(1); + List> futures = new ArrayList<>(); + + try { + for (int thread = 0; thread < threads; thread++) { + futures.add(executor.submit(() -> { + start.await(); + for (int i = 0; i < iterations; i++) { + PacketHandler owner = newOwner(); + TriFunction consumer = + (ctx, packet, serverbound) -> packet; + packetHandlers.register(owner, String.class, consumer); + packetHandlers.deregister(owner); + } + return null; + })); + } + + start.countDown(); + + for (Future future : futures) { + future.get(5, TimeUnit.SECONDS); + } + + executor.shutdown(); + assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS)); + } finally { + executor.shutdownNow(); + } + + assertFalse(packetHandlers.hasHandlers()); + } + + @Test + void concurrentRegisterDuringDeregisterKeepsNewlyRegisteredHandler() throws Exception { + ExecutorService executor = Executors.newFixedThreadPool(2); + + try { + // 30 iterations catches the same races without paying Mockito mock + // construction cost 100×. + for (int iteration = 0; iteration < 30; iteration++) { + PacketHandlersImpl packetHandlers = new PacketHandlersImpl(); + PacketHandler handlerA = mock(PacketHandler.class); + PacketHandler handlerB = mock(PacketHandler.class); + CountDownLatch start = new CountDownLatch(1); + + Future registerA = executor.submit(() -> { + start.await(); + packetHandlers.register(handlerA, String.class, STRING_CONSUMER); + return null; + }); + Future registerBThenDeregisterA = executor.submit(() -> { + start.await(); + packetHandlers.register(handlerB, String.class, INTEGER_CONSUMER); + packetHandlers.deregister(handlerA); + return null; + }); + + start.countDown(); + registerA.get(5, TimeUnit.SECONDS); + registerBThenDeregisterA.get(5, TimeUnit.SECONDS); + + assertTrue( + packetHandlers.getPacketHandlers(String.class).contains(INTEGER_CONSUMER), + "consumer registered during deregister was dropped on iteration " + iteration); + } + + executor.shutdown(); + assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS)); + } finally { + executor.shutdownNow(); + } + } + + private static PacketHandler newOwner() { + return new PacketHandler() { + @Override + public Object handle(ChannelHandlerContext ctx, Object packet, boolean serverbound) { + return packet; + } + }; + } + + private static boolean anyHandlerReturns( + Collection> handlers, + Object expected) { + for (TriFunction handler : handlers) { + if (expected.equals(handler.apply(null, "packet", true))) { + return true; + } + } + return false; + } +} diff --git a/core/src/test/java/com/minekube/connect/platform/command/CommandUtilTest.java b/core/src/test/java/com/minekube/connect/platform/command/CommandUtilTest.java new file mode 100644 index 000000000..393dbdc9a --- /dev/null +++ b/core/src/test/java/com/minekube/connect/platform/command/CommandUtilTest.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2019-2022 GeyserMC. http://geysermc.org + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * @author GeyserMC + * @link https://github.com/GeyserMC/Floodgate + */ + +package com.minekube.connect.platform.command; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.minekube.connect.player.UserAudience; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.UUID; +import java.util.function.Function; +import org.junit.jupiter.api.Test; + +class CommandUtilTest { + @Test + void getOnlineUsernamesReturnsEmptyWhenThereAreNoOnlinePlayers() { + TestCommandUtil util = new TestCommandUtil(Collections.emptyList(), Object::toString); + + assertEquals(Collections.emptyList(), util.getOnlineUsernames()); + } + + @Test + void getOnlineUsernamesReturnsOneNamePerOnlinePlayerInOrder() { + TestCommandUtil util = new TestCommandUtil( + Arrays.asList("first", "second", "third"), + source -> "name-" + source); + + assertEquals(Arrays.asList("name-first", "name-second", "name-third"), util.getOnlineUsernames()); + } + + @Test + void getOnlineUsernamesPreservesNullUsernames() { + TestCommandUtil util = new TestCommandUtil( + Arrays.asList("first", "missing", "third"), + source -> "missing".equals(source) ? null : "name-" + source); + + assertEquals(Arrays.asList("name-first", null, "name-third"), util.getOnlineUsernames()); + } + + private static final class TestCommandUtil extends CommandUtil { + private final Collection onlinePlayers; + private final Function usernameResolver; + + private TestCommandUtil(Collection onlinePlayers, Function usernameResolver) { + super(null, null); + this.onlinePlayers = onlinePlayers; + this.usernameResolver = usernameResolver; + } + + @Override + public UserAudience getUserAudience(Object source) { + throw new UnsupportedOperationException(); + } + + @Override + protected String getUsernameFromSource(Object source) { + return usernameResolver.apply(source); + } + + @Override + protected UUID getUuidFromSource(Object source) { + throw new UnsupportedOperationException(); + } + + @Override + protected Collection getOnlinePlayers() { + return onlinePlayers; + } + + @Override + public Object getPlayerByUuid(UUID uuid) { + throw new UnsupportedOperationException(); + } + + @Override + public Object getPlayerByUsername(String username) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean hasPermission(Object player, String permission) { + throw new UnsupportedOperationException(); + } + + @Override + public void sendMessage(Object target, String message) { + throw new UnsupportedOperationException(); + } + + @Override + public void kickPlayer(Object player, String message) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/core/src/test/java/com/minekube/connect/register/WatcherRegisterTest.java b/core/src/test/java/com/minekube/connect/register/WatcherRegisterTest.java new file mode 100644 index 000000000..877d89291 --- /dev/null +++ b/core/src/test/java/com/minekube/connect/register/WatcherRegisterTest.java @@ -0,0 +1,188 @@ +package com.minekube.connect.register; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.minekube.connect.api.SimpleConnectApi; +import com.minekube.connect.api.inject.PlatformInjector; +import com.minekube.connect.api.logger.ConnectLogger; +import com.minekube.connect.tunnel.Tunneler; +import com.minekube.connect.watch.WatchClient; +import com.minekube.connect.watch.Watcher; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.Timer; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; +import okhttp3.WebSocket; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +class WatcherRegisterTest { + private WatcherRegister register; + + @AfterEach + void stopRegister() { + if (register != null) { + register.stop(); + } + } + + @Test + void constructorDoesNotCreateScheduler() throws Exception { + WatcherRegister register = new WatcherRegister(); + + assertNull(scheduler(register)); + } + + @Test + void startCreatesScheduler() throws Exception { + register = newRegister(); + + register.start(); + + ScheduledExecutorService scheduler = scheduler(register); + assertNotNull(scheduler); + assertFalse(scheduler.isShutdown()); + } + + @Test + void stopShutsDownSchedulerAndClearsField() throws Exception { + register = newRegister(); + register.start(); + ScheduledExecutorService scheduler = scheduler(register); + + register.stop(); + + assertNull(scheduler(register)); + assertTrue(scheduler.isShutdown()); + } + + @Test + void startAfterStopCreatesNewScheduler() throws Exception { + register = newRegister(); + register.start(); + ScheduledExecutorService firstScheduler = scheduler(register); + + register.stop(); + register.start(); + + ScheduledExecutorService secondScheduler = scheduler(register); + assertNotNull(secondScheduler); + assertNotSame(firstScheduler, secondScheduler); + assertTrue(firstScheduler.isShutdown()); + assertFalse(secondScheduler.isShutdown()); + } + + @Test + void stopBeforeStartDoesNotCreateScheduler() throws Exception { + WatcherRegister register = new WatcherRegister(); + + assertDoesNotThrow(register::stop); + + assertNull(scheduler(register)); + } + + @Test + void retryDoesNotThrowWhenSchedulerIsClearedAfterStop() throws Exception { + register = newRegister(); + register.start(); + register.stop(); + + // Force the race window: started=true but scheduler already cleared. + started(register).set(true); + + assertDoesNotThrow(() -> invokeRetry(register)); + assertNull(scheduler(register)); + } + + @Test + void resetBackOffTimerDoesNotThrowWhenSchedulerIsClearedAfterStop() throws Exception { + Fixture fixture = newFixture(); + register = fixture.register; + register.start(); + ArgumentCaptor watcher = ArgumentCaptor.forClass(Watcher.class); + verify(fixture.watchClient).watch(watcher.capture()); + + register.stop(); + + assertDoesNotThrow(() -> watcher.getValue().onOpen()); + assertNull(scheduler(register)); + } + + @Test + void watcherRegisterDoesNotDeclareTimerFields() { + assertNoTimerFields(WatcherRegister.class); + for (Class nestedClass : WatcherRegister.class.getDeclaredClasses()) { + assertNoTimerFields(nestedClass); + } + } + + private static WatcherRegister newRegister() throws Exception { + return newFixture().register; + } + + private static Fixture newFixture() throws Exception { + WatcherRegister register = new WatcherRegister(); + WatchClient watchClient = mock(WatchClient.class); + when(watchClient.watch(any(Watcher.class))).thenReturn(mock(WebSocket.class)); + + inject(register, "watchClient", watchClient); + inject(register, "tunneler", mock(Tunneler.class)); + inject(register, "platformInjector", mock(PlatformInjector.class)); + inject(register, "logger", mock(ConnectLogger.class)); + inject(register, "api", mock(SimpleConnectApi.class)); + return new Fixture(register, watchClient); + } + + private static void inject(WatcherRegister register, String fieldName, Object value) + throws Exception { + Field field = WatcherRegister.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(register, value); + } + + private static ScheduledExecutorService scheduler(WatcherRegister register) throws Exception { + Field field = WatcherRegister.class.getDeclaredField("scheduler"); + field.setAccessible(true); + return (ScheduledExecutorService) field.get(register); + } + + private static AtomicBoolean started(WatcherRegister register) throws Exception { + Field field = WatcherRegister.class.getDeclaredField("started"); + field.setAccessible(true); + return (AtomicBoolean) field.get(register); + } + + private static void invokeRetry(WatcherRegister register) throws Exception { + Method method = WatcherRegister.class.getDeclaredMethod("retry"); + method.setAccessible(true); + method.invoke(register); + } + + private static void assertNoTimerFields(Class type) { + for (Field field : type.getDeclaredFields()) { + assertFalse(Timer.class.isAssignableFrom(field.getType()), + type.getName() + "#" + field.getName() + " must not use java.util.Timer"); + } + } + + private static final class Fixture { + private final WatcherRegister register; + private final WatchClient watchClient; + + private Fixture(WatcherRegister register, WatchClient watchClient) { + this.register = register; + this.watchClient = watchClient; + } + } +} diff --git a/core/src/test/java/com/minekube/connect/tunnel/TunnelerTest.java b/core/src/test/java/com/minekube/connect/tunnel/TunnelerTest.java new file mode 100644 index 000000000..00f5f7804 --- /dev/null +++ b/core/src/test/java/com/minekube/connect/tunnel/TunnelerTest.java @@ -0,0 +1,129 @@ +package com.minekube.connect.tunnel; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.fail; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.concurrent.atomic.AtomicReference; +import okhttp3.OkHttpClient; +import okhttp3.Response; +import okhttp3.WebSocket; +import okhttp3.WebSocketListener; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okio.ByteString; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.Test; + +class TunnelerTest { + + @Test + void receivesBinaryFrameFromTunnelService() throws Exception { + byte[] sent = new byte[] {1, 2, 3, 4, 5}; + + byte[] received = receiveServerMessage(sent); + + assertArrayEquals(sent, received); + } + + @Test + void receivesBinaryFrameWhenByteStringDataFieldIsUnavailable() throws Exception { + byte[] sent = new byte[] {9, 8, 7, 6}; + Field dataField = Tunneler.class.getDeclaredField("DATA"); + dataField.setAccessible(true); + Field originalData = (Field) dataField.get(null); + + try { + setStaticField(dataField, null); + + byte[] received = receiveServerMessage(sent); + + assertArrayEquals(sent, received); + } finally { + setStaticField(dataField, originalData); + } + } + + private static byte[] receiveServerMessage(byte[] sent) throws Exception { + OkHttpClient client = new OkHttpClient(); + CapturingHandler handler = new CapturingHandler(); + TunnelConn conn = null; + + try (MockWebServer server = new MockWebServer()) { + server.enqueue(new MockResponse().withWebSocketUpgrade(new WebSocketListener() { + @Override + public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) { + webSocket.send(ByteString.of(sent)); + } + })); + server.start(); + + conn = new Tunneler(client).tunnel(webSocketUrl(server), "test-session", handler); + + return handler.awaitReceived(); + } finally { + if (conn != null) { + conn.close(); + } + client.dispatcher().cancelAll(); + client.dispatcher().executorService().shutdownNow(); + client.connectionPool().evictAll(); + } + } + + private static String webSocketUrl(MockWebServer server) { + return server.url("/tunnel").toString().replaceFirst("^http:", "ws:"); + } + + private static void setStaticField(Field field, Object value) throws Exception { + try { + field.set(null, value); + return; + } catch (IllegalAccessException ignored) { + } + + Object unsafe = unsafe(); + Method staticFieldBase = unsafe.getClass().getMethod("staticFieldBase", Field.class); + Method staticFieldOffset = unsafe.getClass().getMethod("staticFieldOffset", Field.class); + Method putObject = unsafe.getClass().getMethod("putObject", Object.class, long.class, Object.class); + Object base = staticFieldBase.invoke(unsafe, field); + long offset = (Long) staticFieldOffset.invoke(unsafe, field); + putObject.invoke(unsafe, base, offset, value); + } + + private static Object unsafe() throws Exception { + Field unsafe = Class.forName("sun.misc.Unsafe").getDeclaredField("theUnsafe"); + unsafe.setAccessible(true); + return unsafe.get(null); + } + + private static final class CapturingHandler implements TunnelConn.Handler { + private final AtomicReference received = new AtomicReference<>(); + private final AtomicReference error = new AtomicReference<>(); + + @Override + public void onReceive(byte[] data) { + received.set(data); + } + + @Override + public void onError(Throwable t) { + error.set(t); + } + + private byte[] awaitReceived() { + await().atMost(5, SECONDS).untilAsserted(() -> { + Throwable thrown = error.get(); + if (thrown != null) { + fail("Tunnel handler received an error", thrown); + } + assertNotNull(received.get()); + }); + return received.get(); + } + } +} diff --git a/core/src/test/java/com/minekube/connect/util/HttpUtilsTest.java b/core/src/test/java/com/minekube/connect/util/HttpUtilsTest.java new file mode 100644 index 000000000..71f532521 --- /dev/null +++ b/core/src/test/java/com/minekube/connect/util/HttpUtilsTest.java @@ -0,0 +1,73 @@ +package com.minekube.connect.util; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.lang.reflect.Field; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; + +class HttpUtilsTest { + private static final String THREAD_PREFIX = "connect-http-worker"; + + @Test + void executorRunsAtLeastFourTasksConcurrently() throws Exception { + ExecutorService executorService = executorService(); + CountDownLatch started = new CountDownLatch(4); + CountDownLatch release = new CountDownLatch(1); + AtomicInteger running = new AtomicInteger(); + AtomicInteger peak = new AtomicInteger(); + Future[] futures = new Future[4]; + + try { + for (int i = 0; i < futures.length; i++) { + futures[i] = executorService.submit(() -> { + int current = running.incrementAndGet(); + peak.accumulateAndGet(current, Math::max); + started.countDown(); + try { + release.await(1, TimeUnit.SECONDS); + } catch (InterruptedException exception) { + Thread.currentThread().interrupt(); + } finally { + running.decrementAndGet(); + } + }); + } + + assertTrue(started.await(1, TimeUnit.SECONDS)); + assertEquals(4, peak.get()); + } finally { + release.countDown(); + for (Future future : futures) { + if (future != null) { + future.get(1, TimeUnit.SECONDS); + } + } + } + } + + @Test + void executorUsesConnectHttpWorkerThreadNames() throws Exception { + Future threadName = executorService().submit(() -> Thread.currentThread().getName()); + + assertTrue(threadName.get(1, TimeUnit.SECONDS).startsWith(THREAD_PREFIX)); + } + + @Test + void executorUsesDaemonThreads() throws Exception { + Future daemon = executorService().submit(() -> Thread.currentThread().isDaemon()); + + assertTrue(daemon.get(1, TimeUnit.SECONDS)); + } + + private static ExecutorService executorService() throws Exception { + Field field = HttpUtils.class.getDeclaredField("EXECUTOR_SERVICE"); + field.setAccessible(true); + return (ExecutorService) field.get(null); + } +} diff --git a/spigot/src/main/java/com/minekube/connect/inject/spigot/SpigotInjector.java b/spigot/src/main/java/com/minekube/connect/inject/spigot/SpigotInjector.java index c852fb6ac..0e865c205 100644 --- a/spigot/src/main/java/com/minekube/connect/inject/spigot/SpigotInjector.java +++ b/spigot/src/main/java/com/minekube/connect/inject/spigot/SpigotInjector.java @@ -153,20 +153,22 @@ public void injectClient(ChannelFuture future) { future.channel().pipeline().addFirst("connect-init", new ChannelInboundHandlerAdapter() { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - super.channelRead(ctx, msg); - Channel channel = (Channel) msg; // only need to inject if is a local session & auth passthrough is disabled LocalSession.context(channel) .filter(context -> !context.getPlayer().getAuth().isPassthrough()) - .ifPresent( - $ -> channel.pipeline().addLast(new ChannelInitializer() { + .ifPresent($ -> channel.pipeline().addLast("connect-injector", + new ChannelInboundHandlerAdapter() { @Override - protected void initChannel(Channel channel) { - injectAddonsCall(channel, false); - addInjectedClient(channel); + public void channelActive(ChannelHandlerContext childCtx) + throws Exception { + injectAddonsCall(childCtx.channel(), false); + addInjectedClient(childCtx.channel()); + childCtx.pipeline().remove(this); + super.channelActive(childCtx); } })); + super.channelRead(ctx, msg); } }); }