|
| 1 | +package fix |
| 2 | + |
| 3 | +import scalafix.v1._ |
| 4 | +import scala.meta._ |
| 5 | + |
| 6 | +/** |
| 7 | + * Rewrites `expect` and `expect.all` calls into `expect.eql` and `expect.same`. |
| 8 | + * |
| 9 | + * As of weaver `0.9.0`, `expect` and `expect.all` do not capture any |
| 10 | + * information on failure. `expect.eql` and `expect.same` have better error |
| 11 | + * messages. |
| 12 | + * |
| 13 | + * This rule can be applied for weaver versions `0.9.x` and above. |
| 14 | + */ |
| 15 | +class RewriteExpect |
| 16 | + extends SemanticRule("RewriteExpect") { |
| 17 | + |
| 18 | + override def fix(implicit doc: SemanticDocument): Patch = { |
| 19 | + val expectMethod = |
| 20 | + SymbolMatcher.normalized("weaver/Expectations.Helpers#expect.") |
| 21 | + val expectAllMethod = SymbolMatcher.normalized("weaver/ExpectMacro#all().") |
| 22 | + doc.tree.collect { |
| 23 | + case expectTree @ Term.Apply.After_4_6_0(expectMethod(_), |
| 24 | + Term.ArgClause(List(tree), _)) => |
| 25 | + // Matched `expect(tree)` |
| 26 | + rewrite(tree) match { |
| 27 | + case Some(next) => Patch.replaceTree(expectTree, next.toString) |
| 28 | + case None => Patch.empty |
| 29 | + } |
| 30 | + case expectAll @ Term.Apply.After_4_6_0(expectAllMethod(_), |
| 31 | + Term.ArgClause(trees, _)) => |
| 32 | + // Matched `expect.all(trees)` |
| 33 | + val (equalityAssertions, otherAssertions) = partition(trees)(tree => |
| 34 | + rewrite(tree) match { |
| 35 | + case Some(equality) => Left(equality) |
| 36 | + case None => Right(tree) |
| 37 | + }) |
| 38 | + equalityAssertions match { |
| 39 | + case firstEqAssertion :: remainingEqAssertions => |
| 40 | + val combinedEqAssertion = |
| 41 | + remainingEqAssertions.foldLeft(firstEqAssertion: Term) { |
| 42 | + (acc, cur) => |
| 43 | + q"$acc.and($cur)" |
| 44 | + } |
| 45 | + otherAssertions match { |
| 46 | + case Nil => |
| 47 | + // All assertions were == or ===. Remove the `expect.all` statement. |
| 48 | + Patch.replaceTree(expectAll, combinedEqAssertion.toString) |
| 49 | + case singleAssertion :: Nil => |
| 50 | + // A single assertion is not == or ===. Wrap this in `expect`. |
| 51 | + val combinedAssertion = |
| 52 | + q"$combinedEqAssertion.and(expect($singleAssertion))" |
| 53 | + Patch.replaceTree(expectAll, combinedAssertion.toString) |
| 54 | + case _ :: _ :: _ => |
| 55 | + // Several assertions are not == or ===. Wrap these in `expect.all`. |
| 56 | + val combinedAssertion = |
| 57 | + q"$combinedEqAssertion.and(expect.all(..$otherAssertions))" |
| 58 | + Patch.replaceTree(expectAll, combinedAssertion.toString) |
| 59 | + } |
| 60 | + case Nil => |
| 61 | + // `expect.all` didn't contain any == or === assertions. |
| 62 | + Patch.empty |
| 63 | + } |
| 64 | + }.asPatch |
| 65 | + } |
| 66 | + |
| 67 | + /** |
| 68 | + * Rewrites boolean assertions into `expect.same` and `expect` calls. |
| 69 | + * |
| 70 | + * For example: |
| 71 | + * - `a == b` is rewritten to `expect.same(a, b)` |
| 72 | + * - `a && b` is rewritten to `expect(a).and(expect(b))` |
| 73 | + * - `if (cond) a else b` is rewritten to |
| 74 | + * `if (cond) expect(a) else expect(b)` . |
| 75 | + */ |
| 76 | + def rewrite(tree: Tree)(implicit doc: SemanticDocument): Option[Term] = { |
| 77 | + val catsEqMethod = SymbolMatcher.normalized( |
| 78 | + "cats/syntax/EqOps#`===`().") + SymbolMatcher.normalized( |
| 79 | + "cats/syntax/EqOps#eqv().") |
| 80 | + |
| 81 | + tree match { |
| 82 | + case q"$lhs == $rhs" if !containsClues(tree) => |
| 83 | + val (expected, found) = inferExpectedAndFound(lhs, rhs) |
| 84 | + Some(q"expect.same($expected, $found)") |
| 85 | + case Term.ApplyInfix.After_4_6_0(lhs, |
| 86 | + catsEqMethod(_), |
| 87 | + _, |
| 88 | + Term.ArgClause(List(rhs), _)) |
| 89 | + if !containsClues(tree) => |
| 90 | + val (expected, found) = inferExpectedAndFound(lhs, rhs) |
| 91 | + Some(q"expect.eql($expected, $found)") |
| 92 | + case Term.Apply.After_4_6_0(Term.Select(lhs, catsEqMethod(_)), |
| 93 | + Term.ArgClause(List(rhs), _)) |
| 94 | + if !containsClues(tree) => |
| 95 | + val (expected, found) = inferExpectedAndFound(lhs, rhs) |
| 96 | + Some(q"expect.eql($expected, $found)") |
| 97 | + case q"$lhs && $rhs" => |
| 98 | + val nextLhs = rewrite(lhs).getOrElse(q"expect($lhs)") |
| 99 | + val nextRhs = rewrite(rhs).getOrElse(q"expect($rhs)") |
| 100 | + Some(q"$nextLhs.and($nextRhs)") |
| 101 | + case q"$lhs || $rhs" => |
| 102 | + val nextLhs = rewrite(lhs).getOrElse(q"expect($lhs)") |
| 103 | + val nextRhs = rewrite(rhs).getOrElse(q"expect($rhs)") |
| 104 | + Some(q"$nextLhs.or($nextRhs)") |
| 105 | + case q"true" => Some(q"success") |
| 106 | + case q"false" => Some(q"""failure("Assertion failed")""") |
| 107 | + case q"if ($cond) $lhs else $rhs" if !containsClues(cond) => |
| 108 | + val nextLhs = rewrite(lhs).getOrElse(q"expect($lhs)") |
| 109 | + val nextRhs = rewrite(rhs).getOrElse(q"expect($rhs)") |
| 110 | + Some(q"if ($cond) $nextLhs else $nextRhs") |
| 111 | + case q"$expr match { ..case $casesnel }" if !containsClues(expr) => |
| 112 | + // Rewrite assertions with `case _ => false` to use `matches` |
| 113 | + val wildcardFalse = casesnel.find { |
| 114 | + case p"case _ => false" => true |
| 115 | + case _ => false |
| 116 | + } |
| 117 | + wildcardFalse match { |
| 118 | + case Some(wildcardCase) if casesnel.size > 1 => |
| 119 | + val nextCases = rewriteCases(casesnel.filterNot(_ == wildcardCase)) |
| 120 | + Some(q"matches($expr) {..case $nextCases}") |
| 121 | + case _ => |
| 122 | + val nextCases = rewriteCases(casesnel) |
| 123 | + Some(q"$expr match {..case $nextCases }") |
| 124 | + } |
| 125 | + case _ => |
| 126 | + None |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + def rewriteCases(casesnel: List[Case])(implicit |
| 131 | + doc: SemanticDocument): List[Case] = { |
| 132 | + casesnel.map { caseTree => |
| 133 | + val nextExpr = |
| 134 | + rewrite(caseTree.body).getOrElse(q"expect(${caseTree.body})") |
| 135 | + p"case ${caseTree.pat} if ${caseTree.cond} => $nextExpr" |
| 136 | + } |
| 137 | + } |
| 138 | + |
| 139 | + /** |
| 140 | + * Checks is an assertion contains `clue(...)`. If so, it should not be |
| 141 | + * rewritten. |
| 142 | + */ |
| 143 | + def containsClues(tree: Tree)(implicit doc: SemanticDocument): Boolean = { |
| 144 | + val clueSymbol = |
| 145 | + SymbolMatcher.normalized("weaver/internals/ClueHelpers#clue().") |
| 146 | + tree.collect { |
| 147 | + case clueSymbol(_) => () |
| 148 | + }.nonEmpty |
| 149 | + |
| 150 | + } |
| 151 | + |
| 152 | + /** |
| 153 | + * Infers the order of `expected` and `found` parameters in `expect.same`. |
| 154 | + * |
| 155 | + * When converting from `expect(a == b)` to `expect.same(a, b)`, we do not |
| 156 | + * know which of `a` or `b` is the expected and found value. |
| 157 | + * |
| 158 | + * The expected value is likely to be: |
| 159 | + * - A literal e.g. `"hello-world"` |
| 160 | + * - An ADT e.g. `Pet.Cat` |
| 161 | + * - An term containing "expected" in its name e.g. `expectedValue` |
| 162 | + * - An expression containing literals e.g. `makeId(1)` |
| 163 | + */ |
| 164 | + def inferExpectedAndFound(left: Term, right: Term)(implicit |
| 165 | + doc: SemanticDocument): (Term, Term) = { |
| 166 | + def isAlgebraicDataType(tree: Tree): Boolean = tree.symbol.info.exists { |
| 167 | + info => |
| 168 | + info.isObject && info.isFinal |
| 169 | + } |
| 170 | + def containsLiterals(term: Term): Boolean = term.collect { |
| 171 | + case Lit(_) => () |
| 172 | + }.nonEmpty |
| 173 | + |
| 174 | + def startsWithCapital(term: Term): Boolean = { |
| 175 | + val firstLetter = term.syntax.head.toString |
| 176 | + firstLetter.capitalize == firstLetter |
| 177 | + } |
| 178 | + def hasFoundInName(term: Term): Boolean = { |
| 179 | + val foundKeywords = List("obtained", "actual", "result", "found") |
| 180 | + foundKeywords.exists(term.syntax.contains) |
| 181 | + } |
| 182 | + def hasExpectedInName(term: Term): Boolean = { |
| 183 | + term.syntax.contains("expected") |
| 184 | + } |
| 185 | + |
| 186 | + (left, right) match { |
| 187 | + case (Lit(_), _) => (left, right) |
| 188 | + case (_, Lit(_)) => (right, left) |
| 189 | + case _ if isAlgebraicDataType(right) && !isAlgebraicDataType(left) => |
| 190 | + (right, left) |
| 191 | + case _ if isAlgebraicDataType(left) && !isAlgebraicDataType(right) => |
| 192 | + (left, right) |
| 193 | + // Test for expected and found values using naming conventions instead |
| 194 | + case _ if hasExpectedInName(right) && !hasExpectedInName(left) => |
| 195 | + (right, left) |
| 196 | + case _ if hasExpectedInName(left) && !hasExpectedInName(right) => |
| 197 | + (left, right) |
| 198 | + case _ if hasFoundInName(right) && !hasFoundInName(left) => |
| 199 | + (left, right) |
| 200 | + case _ if hasFoundInName(left) && !hasFoundInName(right) => |
| 201 | + (right, left) |
| 202 | + // Assume the symbol is an expected ADT if it starts with a capital |
| 203 | + // Symbol information is not present in Scala 3 - see https://github.com/scalacenter/scalafix/issues/2054 |
| 204 | + // If the symbol-based ADT test cannot be performed, we perform an additional test based on naming conventions. |
| 205 | + case _ if startsWithCapital(right) && !startsWithCapital(left) => |
| 206 | + (right, left) |
| 207 | + case _ if startsWithCapital(left) && !startsWithCapital(right) => |
| 208 | + (left, right) |
| 209 | + case _ if containsLiterals(right) && !containsLiterals(left) => |
| 210 | + (right, left) |
| 211 | + case _ if containsLiterals(left) && !containsLiterals(right) => |
| 212 | + (left, right) |
| 213 | + case _ => |
| 214 | + (left, right) |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + def partition[A, B, C]( |
| 219 | + values: List[A])(f: A => Either[B, C]): (List[B], List[C]) = { |
| 220 | + val eithers = values.map(f) |
| 221 | + val lefts = eithers.collect { case Left(v) => v } |
| 222 | + val rights = eithers.collect { case Right(v) => v } |
| 223 | + (lefts, rights) |
| 224 | + } |
| 225 | +} |
0 commit comments