diff --git a/batteries/syntax/ast_types.luau b/batteries/syntax/ast_types.luau index 8992b3ed..f840d0d8 100644 --- a/batteries/syntax/ast_types.luau +++ b/batteries/syntax/ast_types.luau @@ -229,7 +229,7 @@ export type AstStatLocal = { export type AstStatFor = { tag: "for", ["for"]: Token<"for">, - variable: Token, + variable: AstLocal, equals: Token<"=">, from: AstExpr, toComma: Token<",">, @@ -251,6 +251,21 @@ export type AstStatForIn = { body: AstStatBlock, ["end"]: Token<"end">, } + +export type AstStatAssign = { + tag: "assign", + variables: Punctuated, + equals: Token<"=">, + values: Punctuated, +} + +export type AstStatCompoundAssign = { + tag: "compoundassign", + variable: AstExpr, + operand: Token, -- TODO: enforce token type, + value: AstExpr, +} + export type AstStat = | AstStatBlock | AstStatIf @@ -263,5 +278,7 @@ export type AstStat = | AstStatLocal | AstStatFor | AstStatForIn + | AstStatAssign + | AstStatCompoundAssign return {} diff --git a/batteries/syntax/visitor.luau b/batteries/syntax/visitor.luau index 688bdee5..633739df 100644 --- a/batteries/syntax/visitor.luau +++ b/batteries/syntax/visitor.luau @@ -11,6 +11,8 @@ type Visitor = { visitLocalDeclaration: (T.AstStatLocal) -> boolean, visitFor: (T.AstStatFor) -> boolean, visitForIn: (T.AstStatForIn) -> boolean, + visitAssign: (T.AstStatAssign) -> boolean, + visitCompoundAssign: (T.AstStatCompoundAssign) -> boolean, visitLocalReference: (T.AstExprLocal) -> boolean, visitGlobal: (T.AstExprGlobal) -> boolean, @@ -46,6 +48,8 @@ local defaultVisitor: Visitor = { visitLocalDeclaration = alwaysVisit :: any, visitFor = alwaysVisit :: any, visitForIn = alwaysVisit :: any, + visitAssign = alwaysVisit :: any, + visitCompoundAssign = alwaysVisit :: any, visitLocalReference = alwaysVisit :: any, visitGlobal = alwaysVisit :: any, @@ -185,6 +189,22 @@ local function visitForIn(node: T.AstStatForIn, visitor: Visitor) end end +local function visitAssign(node: T.AstStatAssign, visitor: Visitor) + if visitor.visitAssign(node) then + visitPunctuated(node.variables, visitor, visitExpression) + visitToken(node.equals, visitor) + visitPunctuated(node.values, visitor, visitExpression) + end +end + +local function visitCompoundAssign(node: T.AstStatCompoundAssign, visitor: Visitor) + if visitor.visitCompoundAssign(node) then + visitExpression(node.variable, visitor) + visitToken(node.operand, visitor) + visitExpression(node.value, visitor) + end +end + local function visitString(node: T.AstExprConstantString, visitor: Visitor) if visitor.visitString(node) then visitor.visitToken(node) @@ -388,6 +408,10 @@ function visitStatement(statement: T.AstStat, visitor: Visitor) visitFor(statement, visitor) elseif statement.tag == "forin" then visitForIn(statement, visitor) + elseif statement.tag == "assign" then + visitAssign(statement, visitor) + elseif statement.tag == "compoundassign" then + visitCompoundAssign(statement, visitor) else exhaustiveMatch(statement.tag) end diff --git a/luau/src/luau.cpp b/luau/src/luau.cpp index c4e5aa03..09e2f747 100644 --- a/luau/src/luau.cpp +++ b/luau/src/luau.cpp @@ -1241,14 +1241,22 @@ struct AstSerialize : public Luau::AstVisitor void serializeStat(Luau::AstStatAssign* node) { lua_rawcheckstack(L, 2); - lua_createtable(L, 0, preambleSize + 2); + lua_createtable(L, 0, preambleSize + 3); + + const auto cstNode = lookupCstNode(node); serializeNodePreamble(node, "assign"); - serializeExprs(node->vars); + serializePunctuated(node->vars, cstNode ? cstNode->varsCommaPositions : Luau::AstArray{}, ","); lua_setfield(L, -2, "variables"); - serializeExprs(node->values); + if (cstNode) + { + serializeToken(cstNode->equalsPosition, "="); + lua_setfield(L, -2, "equals"); + } + + serializePunctuated(node->values, cstNode ? cstNode->valuesCommaPositions : Luau::AstArray{}, ","); lua_setfield(L, -2, "values"); } @@ -1259,12 +1267,15 @@ struct AstSerialize : public Luau::AstVisitor serializeNodePreamble(node, "compoundassign"); - serialize(node->op); - lua_setfield(L, -2, "operand"); - node->var->visit(this); lua_setfield(L, -2, "variable"); + if (const auto cstNode = lookupCstNode(node)) + serializeToken(cstNode->opPosition, (Luau::toString(node->op) + "=").data()); + else + serialize(node->op); + lua_setfield(L, -2, "operand"); + node->value->visit(this); lua_setfield(L, -2, "value"); } diff --git a/tests/astSerializerTests/assignment-1.luau b/tests/astSerializerTests/assignment-1.luau new file mode 100644 index 00000000..d101dd0c --- /dev/null +++ b/tests/astSerializerTests/assignment-1.luau @@ -0,0 +1,5 @@ +x = 1 + +a, b = 1, true + +a, b, c.d.e[f][g][1], h:i().j[k]:l()[m] = true, false, 1, 4 diff --git a/tests/astSerializerTests/compound-assignment-1.luau b/tests/astSerializerTests/compound-assignment-1.luau new file mode 100644 index 00000000..0dca2b14 --- /dev/null +++ b/tests/astSerializerTests/compound-assignment-1.luau @@ -0,0 +1,24 @@ +local x = 1 +local y = 2 + +x += 5 +x -= 5 +x *= 5 +x /= 5 +x //= 5 +x %= 5 +x ^= 5 + +x += y +x -= y +x *= y +x /= y +x //= y +x %= y +x ^= y + +local str1 = "Hello, " +local str2 = "world!" + +str1 ..= "world!" +str1 ..= str2 diff --git a/tests/testAstSerializer.spec.luau b/tests/testAstSerializer.spec.luau index 8654c5c3..cf510d67 100644 --- a/tests/testAstSerializer.spec.luau +++ b/tests/testAstSerializer.spec.luau @@ -129,7 +129,9 @@ local function test_roundtrippableAst() "examples/parsing.luau", "examples/time_example.luau", "examples/writeFile.luau", + "tests/astSerializerTests/assignment-1.luau", "tests/astSerializerTests/break-continue-1.luau", + "tests/astSerializerTests/compound-assignment-1.luau", "tests/astSerializerTests/generic-for-loop-1.luau", "tests/astSerializerTests/numeric-for-loop-1.luau", "tests/astSerializerTests/while-1.luau",