diff --git a/junit-platform-testkit/src/main/java/org/junit/platform/testkit/engine/TestExecutionResultConditions.java b/junit-platform-testkit/src/main/java/org/junit/platform/testkit/engine/TestExecutionResultConditions.java index 474b6e38b6a7..1ebbeabb875c 100644 --- a/junit-platform-testkit/src/main/java/org/junit/platform/testkit/engine/TestExecutionResultConditions.java +++ b/junit-platform-testkit/src/main/java/org/junit/platform/testkit/engine/TestExecutionResultConditions.java @@ -16,6 +16,7 @@ import static org.apiguardian.api.API.Status.MAINTAINED; import static org.junit.platform.commons.util.FunctionUtils.where; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.function.Predicate; @@ -166,8 +167,9 @@ private static Condition cause(Condition condition) { private static Condition rootCause(Condition condition) { Predicate predicate = throwable -> { Preconditions.notNull(throwable, "Throwable must not be null"); - Throwable cause = Preconditions.notNull(throwable.getCause(), "Throwable does not have a cause"); - return condition.matches(getRootCause(cause)); + Preconditions.notNull(throwable.getCause(), "Throwable does not have a cause"); + Throwable rootCause = getRootCause(throwable, new ArrayList<>()); + return condition.matches(rootCause); }; return new Condition<>(predicate, "throwable root cause matches %s", condition); } @@ -176,12 +178,20 @@ private static Condition rootCause(Condition condition) { * Get the root cause of the supplied {@link Throwable}, or the supplied * {@link Throwable} if it has no cause. */ - private static Throwable getRootCause(Throwable throwable) { + private static Throwable getRootCause(Throwable throwable, List causeChain) { + // If we have already seen the current Throwable, that means we have + // encountered recursion in the cause chain and therefore return the last + // Throwable in the cause chain, which was the root cause before the recursion. + if (causeChain.contains(throwable)) { + return causeChain.get(causeChain.size() - 1); + } Throwable cause = throwable.getCause(); if (cause == null) { return throwable; } - return getRootCause(cause); + // Track current Throwable before recursing. + causeChain.add(throwable); + return getRootCause(cause, causeChain); } private static Condition suppressed(int index, Condition condition) { diff --git a/platform-tests/src/test/java/org/junit/platform/testkit/engine/TestExecutionResultConditionsTests.java b/platform-tests/src/test/java/org/junit/platform/testkit/engine/TestExecutionResultConditionsTests.java index 0f83b6324c17..9ea2079aedbf 100644 --- a/platform-tests/src/test/java/org/junit/platform/testkit/engine/TestExecutionResultConditionsTests.java +++ b/platform-tests/src/test/java/org/junit/platform/testkit/engine/TestExecutionResultConditionsTests.java @@ -82,4 +82,42 @@ void rootCauseDoesNotMatchForRootCauseWithDifferentMessage() { assertThat(rootCauseCondition.matches(throwable)).isFalse(); } + @Test + void rootCauseMatchesForRootCauseWithExpectedMessageAndSingleLevelRecursiveCauseChain() { + RuntimeException rootCause = new RuntimeException(EXPECTED); + Throwable throwable = new Throwable(rootCause); + rootCause.initCause(throwable); + + assertThat(rootCauseCondition.matches(throwable)).isTrue(); + } + + @Test + void rootCauseDoesNotMatchForRootCauseWithDifferentMessageAndSingleLevelRecursiveCauseChain() { + RuntimeException rootCause = new RuntimeException(UNEXPECTED); + Throwable throwable = new Throwable(rootCause); + rootCause.initCause(throwable); + + assertThat(rootCauseCondition.matches(throwable)).isFalse(); + } + + @Test + void rootCauseMatchesForRootCauseWithExpectedMessageAndDoubleLevelRecursiveCauseChain() { + RuntimeException rootCause = new RuntimeException(EXPECTED); + Exception intermediateCause = new Exception("intermediate cause", rootCause); + Throwable throwable = new Throwable(intermediateCause); + rootCause.initCause(throwable); + + assertThat(rootCauseCondition.matches(throwable)).isTrue(); + } + + @Test + void rootCauseDoesNotMatchForRootCauseWithDifferentMessageAndDoubleLevelRecursiveCauseChain() { + RuntimeException rootCause = new RuntimeException(UNEXPECTED); + Exception intermediateCause = new Exception("intermediate cause", rootCause); + Throwable throwable = new Throwable(intermediateCause); + rootCause.initCause(throwable); + + assertThat(rootCauseCondition.matches(throwable)).isFalse(); + } + }