Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ class AstCreator(
val node = createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any))
Seq(Ast(node))
case _ =>
logger.error("astForSingleLeftHandSideContext() All contexts mismatched.")
logger.error(s"astForSingleLeftHandSideContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())

}
Expand Down Expand Up @@ -340,7 +340,7 @@ class AstCreator(
case ctx: ChainedInvocationWithoutArgumentsPrimaryContext =>
astForChainedInvocationWithoutArgumentsPrimaryContext(ctx)
case _ =>
logger.error("astForPrimaryContext() All contexts mismatched.")
logger.error(s"astForPrimaryContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand All @@ -364,7 +364,7 @@ class AstCreator(
case ctx: MultipleAssignmentExpressionContext => astForMultipleAssignmentExpressionContext(ctx)
case ctx: IsDefinedExpressionContext => Seq(astForIsDefinedExpression(ctx))
case _ =>
logger.error("astForExpressionContext() All contexts mismatched.")
logger.error(s"astForExpressionContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -415,7 +415,7 @@ class AstCreator(
case ctx: RubyParser.SplattingOnlyIndexingArgumentsContext =>
astForSplattingArgumentContext(ctx.splattingArgument())
case _ =>
logger.error("astForIndexingArgumentsContext() All contexts mismatched.")
logger.error(s"astForIndexingArgumentsContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -627,7 +627,7 @@ class AstCreator(
case ctx: GroupedLeftHandSideOnlyMultipleLeftHandSideContext =>
astForGroupedLeftHandSideContext(ctx.groupedLeftHandSide())
case _ =>
logger.error("astForMultipleLeftHandSideContext() All contexts mismatched.")
logger.error(s"astForMultipleLeftHandSideContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -735,7 +735,7 @@ class AstCreator(
.withChildren(astForArguments(ctx.arguments()))
)
case _ =>
logger.error("astForInvocationWithoutParenthesesContext() All contexts mismatched.")
logger.error(s"astForInvocationWithoutParenthesesContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -973,7 +973,7 @@ class AstCreator(
case ctx: SimpleMethodNamePartContext => astForSimpleMethodNamePartContext(ctx)
case ctx: SingletonMethodNamePartContext => astForSingletonMethodNamePartContext(ctx)
case _ =>
logger.error("astForMethodNamePartContext() All contexts mismatched.")
logger.error(s"astForMethodNamePartContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -1050,7 +1050,7 @@ class AstCreator(
}

def astForBodyStatementContext(ctx: BodyStatementContext, addReturnNode: Boolean = false): Seq[Ast] = {
val compoundStatementAsts = astForCompoundStatement(ctx.compoundStatement())
val compoundStatementAsts = astForCompoundStatement(ctx.compoundStatement(), !addReturnNode)

val compoundStatementAstsWithReturn =
if (addReturnNode && compoundStatementAsts.size > 0) {
Expand Down Expand Up @@ -1320,7 +1320,7 @@ class AstCreator(
val primaryAsts = astForPrimaryContext(ctx.primary())
primaryAsts ++ methodNameAsts ++ argsAsts ++ doBlockAsts
case _ =>
logger.error("astForCommandWithDoBlockContext() All contexts mismatched.")
logger.error(s"astForCommandWithDoBlockContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down Expand Up @@ -1354,7 +1354,7 @@ class AstCreator(
case ctx: ChainedCommandWithDoBlockOnlyArgumentsWithParenthesesContext =>
astForChainedCommandWithDoBlockContext(ctx.chainedCommandWithDoBlock())
case _ =>
logger.error("astForArgumentsWithParenthesesContext() All contexts mismatched.")
logger.error(s"astForArgumentsWithParenthesesContext() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import io.joern.x2cpg.Ast
import io.joern.x2cpg.Defines.DynamicCallUnknownFullName
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators}
import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewCall, NewControlStructure, NewImport, NewLiteral}
import org.slf4j.LoggerFactory
import org.antlr.v4.runtime.ParserRuleContext

import scala.jdk.CollectionConverters.CollectionHasAsScala

trait AstForStatementsCreator {
this: AstCreator =>

private val logger = LoggerFactory.getLogger(this.getClass)
protected def astForAliasStatement(ctx: AliasStatementContext): Ast = {
val aliasName = ctx.definedMethodNameOrSymbol(0).getText.substring(1)
val methodName = ctx.definedMethodNameOrSymbol(1).getText.substring(1)
Expand Down Expand Up @@ -80,9 +82,13 @@ trait AstForStatementsCreator {
controlStructureAst(throwNode, rhs.headOption, lhs)
}

protected def astForCompoundStatement(ctx: CompoundStatementContext): Seq[Ast] = {
protected def astForCompoundStatement(ctx: CompoundStatementContext, packInBlock: Boolean = true): Seq[Ast] = {
val stmtAsts = Option(ctx.statements()).map(astForStatements).getOrElse(Seq())
Seq(blockAst(blockNode(ctx), stmtAsts.toList))
if (packInBlock) {
Seq(blockAst(blockNode(ctx), stmtAsts.toList))
} else {
stmtAsts
}
}

protected def astForStatements(ctx: StatementsContext): Seq[Ast] = {
Expand Down Expand Up @@ -110,7 +116,9 @@ trait AstForStatementsCreator {
case ctx: NotExpressionOrCommandContext => Seq(astForNotKeywordExpressionOrCommand(ctx))
case ctx: OrAndExpressionOrCommandContext => Seq(astForOrAndExpressionOrCommand(ctx))
case ctx: ExpressionExpressionOrCommandContext => astForExpressionContext(ctx.expression())
case _ => Seq(Ast())
case _ =>
logger.error(s"astForExpressionOrCommand() $filename, ${ctx.getText} All contexts mismatched.")
Seq(Ast())
}

protected def astForNotKeywordExpressionOrCommand(ctx: NotExpressionOrCommandContext): Ast = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,29 @@ class DataFlowTests extends DataFlowCodeToCpgSuite {
}
}

"Data flow for begin/rescue with sink in else" should {
val cpg = code("""
|x = 1
|begin
| puts "In begin"
|rescue SomeException
| puts "SomeException occurred"
|rescue => exceptionVar
| puts "Caught exception in variable #{exceptionVar}"
|rescue
| puts "Catch-all block"
|else
| puts x
|end
|""".stripMargin)

"find flows to the sink" in {
val source = cpg.identifier.name("x").l
val sink = cpg.call.name("puts").l
sink.reachableByFlows(source).size shouldBe 2
}
}

"Data flow for begin/rescue with sink in rescue" should {
val cpg = code("""
|x = 1
Expand Down Expand Up @@ -1010,6 +1033,70 @@ class DataFlowTests extends DataFlowCodeToCpgSuite {
}
}

"Data flow for begin/rescue with sink in ensure" should {
val cpg = code("""
|x = 1
|begin
| puts "in begin"
|rescue SomeException
| puts "SomeException occurred"
|rescue => exceptionVar
| puts "Caught exception in variable #{exceptionVar}"
|rescue
| puts "In rescue all"
|ensure
| puts x
|end
|
|""".stripMargin)

"find flows to the sink" in {
val source = cpg.identifier.name("x").l
val sink = cpg.call.name("puts").l
sink.reachableByFlows(source).size shouldBe 2
}
}

// parsing issue. comment out when fixed
"Data flow for begin/rescue with data flow through the exception" ignore {
val cpg = code("""
|x = "Exception message: "
|begin
|1/0
|rescue ZeroDivisionError => e
| y = x + e.message
| puts y
|end
|
|""".stripMargin)

"find flows to the sink" in {
val source = cpg.identifier.name("x").l
val sink = cpg.call.name("puts").l
sink.reachableByFlows(source).size shouldBe 2
}
}

"Data flow for begin/rescue with data flow through block with multiple exceptions being caught" should {
val cpg = code("""
|x = 1
|y = 10
|begin
|1/0
|rescue SystemCallError, ZeroDivisionError
| y = x + 100
|end
|
|puts y
|""".stripMargin)

"find flows to the sink" in {
val source = cpg.identifier.name("x").l
val sink = cpg.call.name("puts").l
sink.reachableByFlows(source).size shouldBe 2
}
}

"Data flow for begin/rescue with sink in function without begin" ignore {
val cpg = code("""
|def foo(arg)
Expand Down