diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/context/DefaultDriverContext.java b/core/src/main/java/com/datastax/oss/driver/internal/core/context/DefaultDriverContext.java index cc725994d7c..10fc409b8e3 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/context/DefaultDriverContext.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/context/DefaultDriverContext.java @@ -85,6 +85,7 @@ import com.datastax.oss.driver.internal.core.session.PoolManager; import com.datastax.oss.driver.internal.core.session.RequestProcessor; import com.datastax.oss.driver.internal.core.session.RequestProcessorRegistry; +import com.datastax.oss.driver.internal.core.session.SessionRegistry; import com.datastax.oss.driver.internal.core.ssl.JdkSslHandlerFactory; import com.datastax.oss.driver.internal.core.ssl.SslHandlerFactory; import com.datastax.oss.driver.internal.core.tracker.MultiplexingRequestTracker; @@ -226,6 +227,7 @@ public class DefaultDriverContext implements InternalDriverContext { private final LazyReference> lifecycleListenersRef = new LazyReference<>("lifecycleListeners", this::buildLifecycleListeners, cycleDetector); + private static SessionRegistry sessionRegistry; private final DriverConfig config; private final DriverConfigLoader configLoader; private final ChannelPoolFactory channelPoolFactory = new ChannelPoolFactory(); @@ -335,6 +337,14 @@ public DefaultDriverContext( .build()); } + public SessionRegistry getSessionRegistry() { + return sessionRegistry; + } + + public static void setSessionRegistry(SessionRegistry registry) { + sessionRegistry = registry; + } + /** * Builds a map of options to send in a Startup message. * diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/session/DefaultSession.java b/core/src/main/java/com/datastax/oss/driver/internal/core/session/DefaultSession.java index c9fee86f2c1..7319ae579bf 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/session/DefaultSession.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/session/DefaultSession.java @@ -39,6 +39,7 @@ import com.datastax.oss.driver.api.core.session.Request; import com.datastax.oss.driver.api.core.type.reflect.GenericType; import com.datastax.oss.driver.internal.core.channel.DriverChannel; +import com.datastax.oss.driver.internal.core.context.DefaultDriverContext; import com.datastax.oss.driver.internal.core.context.InternalDriverContext; import com.datastax.oss.driver.internal.core.context.LifecycleListener; import com.datastax.oss.driver.internal.core.metadata.DefaultNode; @@ -336,6 +337,10 @@ private class SingleThreaded { private boolean forceCloseWasCalled; private SingleThreaded(InternalDriverContext context, Set contactPoints) { + if (context instanceof DefaultDriverContext) { + SessionRegistry sessionRegistry = ((DefaultDriverContext) context).getSessionRegistry(); + if (sessionRegistry != null) sessionRegistry.registerSession(this); + } this.context = context; this.nodeStateManager = new NodeStateManager(context); this.initialContactPoints = contactPoints; @@ -656,6 +661,10 @@ private void warnIfFailed(CompletionStage stage) { } private void closePolicies() { + if (context instanceof DefaultDriverContext) { + SessionRegistry sessionRegistry = ((DefaultDriverContext) context).getSessionRegistry(); + if (sessionRegistry != null) sessionRegistry.closeSession(this); + } // This is a bit tricky: we might be closing the session because of an initialization error. // This error might have been triggered by a policy failing to initialize. If we try to access // the policy here to close it, it will fail again. So make sure we ignore that error and diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/session/SessionRegistry.java b/core/src/main/java/com/datastax/oss/driver/internal/core/session/SessionRegistry.java new file mode 100644 index 00000000000..b42eaf50730 --- /dev/null +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/session/SessionRegistry.java @@ -0,0 +1,13 @@ +package com.datastax.oss.driver.internal.core.session; + +import com.datastax.oss.driver.internal.core.context.DefaultDriverContext; + +public abstract class SessionRegistry { + public SessionRegistry() { + DefaultDriverContext.setSessionRegistry(this); + } + + public abstract void registerSession(Object session); + + public abstract void closeSession(Object session); +} diff --git a/osgi-tests/src/test/java/com/datastax/oss/driver/internal/osgi/support/CcmPaxExam.java b/osgi-tests/src/test/java/com/datastax/oss/driver/internal/osgi/support/CcmPaxExam.java index 77e82ff2fee..ea96dcec679 100644 --- a/osgi-tests/src/test/java/com/datastax/oss/driver/internal/osgi/support/CcmPaxExam.java +++ b/osgi-tests/src/test/java/com/datastax/oss/driver/internal/osgi/support/CcmPaxExam.java @@ -18,6 +18,7 @@ package com.datastax.oss.driver.internal.osgi.support; import com.datastax.oss.driver.api.testinfra.requirement.BackendRequirementRule; +import com.datastax.oss.driver.api.testinfra.session.SessionTracker; import org.junit.AssumptionViolatedException; import org.junit.runner.Description; import org.junit.runner.notification.Failure; @@ -26,7 +27,6 @@ import org.ops4j.pax.exam.junit.PaxExam; public class CcmPaxExam extends PaxExam { - public CcmPaxExam(Class klass) throws InitializationError { super(klass); } @@ -36,7 +36,12 @@ public void run(RunNotifier notifier) { Description description = getDescription(); if (BackendRequirementRule.meetsDescriptionRequirements(description)) { - super.run(notifier); + try { + SessionTracker.testStarted(description.getClassName(), description.getMethodName()); + super.run(notifier); + } finally { + SessionTracker.testEnded(description.getClassName(), description.getMethodName()); + } } else { // requirements not met, throw reasoning assumption to skip test AssumptionViolatedException e = diff --git a/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/ccm/CustomCcmRule.java b/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/ccm/CustomCcmRule.java index 5ea1bf7ed3c..3d1f0dee3f8 100644 --- a/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/ccm/CustomCcmRule.java +++ b/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/ccm/CustomCcmRule.java @@ -17,7 +17,10 @@ */ package com.datastax.oss.driver.api.testinfra.ccm; +import com.datastax.oss.driver.api.testinfra.session.SessionTracker; import java.util.concurrent.atomic.AtomicReference; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,6 +42,24 @@ public class CustomCcmRule extends BaseCcmRule { super(ccmBridge); } + @Override + public Statement apply(Statement base, Description description) { + final Statement statement = super.apply(base, description); + return new Statement() { + final Statement original = statement; + + @Override + public void evaluate() throws Throwable { + try { + SessionTracker.testStarted(description.getClassName(), description.getMethodName()); + original.evaluate(); + } finally { + SessionTracker.testEnded(description.getClassName(), description.getMethodName()); + } + } + }; + } + @Override protected void before() { if (CURRENT.get() == null && CURRENT.compareAndSet(null, this)) { diff --git a/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/session/SessionTracker.java b/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/session/SessionTracker.java new file mode 100644 index 00000000000..0218f8a8234 --- /dev/null +++ b/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/session/SessionTracker.java @@ -0,0 +1,73 @@ +package com.datastax.oss.driver.api.testinfra.session; + +import com.datastax.oss.driver.internal.core.session.SessionRegistry; +import java.lang.ref.WeakReference; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.stream.Collectors; + +public class SessionTracker { + static final TestSessionRegistry sessionRegistry = new TestSessionRegistry(); + + private static final Set runningTests = new ConcurrentSkipListSet<>(); + + public static void testStarted(String className, String methodName) { + runningTests.add(String.format("%s.%s", className, methodName)); + } + + public static void testEnded(String className, String methodName) { + runningTests.remove(String.format("%s.%s", className, methodName)); + if (runningTests.isEmpty()) { + List activeSessions = + sessionRegistry.getActiveSessionsAndForget(); + if (!activeSessions.isEmpty()) { + throw new IllegalStateException( + String.format( + "There are active sessions, created in following tests: %s", + activeSessions.stream() + .flatMap(s -> s.sourceTests.stream()) + .collect(Collectors.toList()))); + } + } + } + + private static class TestSessionRegistry extends SessionRegistry { + protected TestSessionRegistry() { + super(); + } + + public static class SessionRecord { + final WeakReference session; + final Set sourceTests; + + SessionRecord(WeakReference session, Set sourceTests) { + this.session = session; + this.sourceTests = sourceTests; + } + } + + private static final List sessions = new CopyOnWriteArrayList<>(); + + @Override + public void registerSession(Object session) { + sessions.add( + new SessionRecord( + new WeakReference<>(session), runningTests.stream().collect(Collectors.toSet()))); + } + + @Override + public void closeSession(Object session) { + sessions.removeIf(s -> s.session == session); + } + + public List getActiveSessionsAndForget() { + // Purge known sessions + sessions.removeIf(ref -> ref.session.get() == null); + return sessions.stream() + .filter(ref -> ref.session.get() == null) + .collect(Collectors.toList()); + } + } +}