diff --git a/src/jsMain/kotlin/assembly/AsmConstruct.kt b/src/jsMain/kotlin/assembly/AsmConstruct.kt index f47d79e..cafe5e2 100644 --- a/src/jsMain/kotlin/assembly/AsmConstruct.kt +++ b/src/jsMain/kotlin/assembly/AsmConstruct.kt @@ -1,5 +1,8 @@ package assembly +import kotlinx.serialization.Serializable + +@Serializable sealed class AsmConstruct { protected fun indent(level: Int): String = " ".repeat(level) } diff --git a/src/jsMain/kotlin/assembly/CodeEmitter.kt b/src/jsMain/kotlin/assembly/CodeEmitter.kt index 31fe25f..0d6f7d3 100644 --- a/src/jsMain/kotlin/assembly/CodeEmitter.kt +++ b/src/jsMain/kotlin/assembly/CodeEmitter.kt @@ -1,11 +1,23 @@ package assembly +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json class CodeEmitter { // TODO make it based on the device instead hard coded private val useLinuxPrefix = true + @kotlinx.serialization.Serializable + data class RawInstruction(val code: String, val sourceId: String?) + + @kotlinx.serialization.Serializable + data class RawFunction( + val name: String, + val body: List, + val stackSize: Int + ) fun emit(program: AsmProgram): String = program.functions.joinToString("\n\n") { emitFunction(it) } + fun emitRaw(program: AsmProgram): String = program.functions.joinToString("\n\n") { emitFunctionRaw(it) } private fun emitFunction(function: AsmFunction): String { val functionName = formatLabel(function.name) val bodyAsm = function.body.joinToString("\n") { emitInstruction(it) } @@ -24,6 +36,39 @@ class CodeEmitter { } } + private fun emitFunctionRaw(function: AsmFunction): String { + val bodyRaw = function.body.map { emitInstructionRaw(it) } + val rawFunc = RawFunction(function.name, bodyRaw, function.stackSize) + return Json.encodeToString(rawFunc) // returns valid JSON + } + + private fun emitInstructionRaw(instruction: Instruction): RawInstruction { + val indent = " " + return when (instruction) { + is Call -> RawInstruction("call ${formatLabel(instruction.identifier)}", instruction.sourceId.toString()) + is Push -> { + val operand = emitOperand(instruction.operand, size = OperandSize.QUAD) + RawInstruction("push $operand", instruction.sourceId.toString()) + } + is DeAllocateStack -> RawInstruction("addq rsp, ${instruction.size}", instruction.sourceId.toString()) + is Mov -> RawInstruction("mov ${emitOperand(instruction.dest)}, ${emitOperand(instruction.src)}", instruction.sourceId.toString()) + is AsmUnary -> RawInstruction("${instruction.op.text} ${emitOperand(instruction.dest)}", instruction.sourceId.toString()) + is AsmBinary -> RawInstruction("${instruction.op.text} ${emitOperand(instruction.dest)}, ${emitOperand(instruction.src)}", instruction.sourceId.toString()) + is Cmp -> RawInstruction("cmp ${emitOperand(instruction.dest)}, ${emitOperand(instruction.src)}", instruction.sourceId.toString()) + is Idiv -> RawInstruction("idiv ${emitOperand(instruction.divisor)}", instruction.sourceId.toString()) + is AllocateStack -> RawInstruction("subq rsp, ${instruction.size}", instruction.sourceId.toString()) + is Cdq -> RawInstruction("cdq", instruction.sourceId.toString()) + is Label -> RawInstruction("${formatLabel(instruction.name)}:", instruction.sourceId.toString()) + is Jmp -> RawInstruction("jmp ${formatLabel(instruction.label.name)}", instruction.sourceId.toString()) + is JmpCC -> RawInstruction("j${instruction.condition.text} ${formatLabel(instruction.label.name)}", instruction.sourceId.toString()) + is SetCC -> { + val destOperand = emitOperand(instruction.dest, size = OperandSize.BYTE) + RawInstruction("set${instruction.condition.text} $destOperand", instruction.sourceId.toString()) + } + is Ret -> RawInstruction("ret", instruction.sourceId.toString()) + } + } + private fun emitInstruction(instruction: Instruction): String { val indent = " " return when (instruction) { @@ -33,7 +78,6 @@ class CodeEmitter { "${indent}push $operand" } is DeAllocateStack -> "${indent}addq rsp, ${instruction.size}" - is Mov -> "${indent}mov ${emitOperand(instruction.dest)}, ${emitOperand(instruction.src)}" is AsmUnary -> "${indent}${instruction.op.text} ${emitOperand(instruction.dest)}" is AsmBinary -> "${indent}${instruction.op.text} ${emitOperand(instruction.dest)}, ${emitOperand(instruction.src)}" @@ -44,15 +88,20 @@ class CodeEmitter { is Label -> formatLabel(instruction.name) + ":" is Jmp -> "${indent}jmp ${formatLabel(instruction.label.name)}" is JmpCC -> "${indent}j${instruction.condition.text} ${formatLabel(instruction.label.name)}" - is SetCC -> { val destOperand = emitOperand(instruction.dest, size = OperandSize.BYTE) "${indent}set${instruction.condition.text} $destOperand" } - is Ret -> "" + } + } - else -> throw NotImplementedError("Emission for ${instruction::class.simpleName} not implemented.") + private fun emitOperandRaw(operand: Operand): String { + return when (operand) { + is Imm -> "Imm(value=${operand.value})" + is Stack -> "Stack(offset=${operand.offset})" + is Register -> "Register(name=${operand.name})" + is Pseudo -> "Pseudo(name=${operand.name})" } } diff --git a/src/jsMain/kotlin/assembly/Functions.kt b/src/jsMain/kotlin/assembly/Functions.kt index bd95917..6edcddc 100644 --- a/src/jsMain/kotlin/assembly/Functions.kt +++ b/src/jsMain/kotlin/assembly/Functions.kt @@ -1,7 +1,13 @@ package assembly +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable sealed class Function : AsmConstruct() +@Serializable +@SerialName("AsmFunction") data class AsmFunction( val name: String, var body: List, diff --git a/src/jsMain/kotlin/assembly/InstructionFixer.kt b/src/jsMain/kotlin/assembly/InstructionFixer.kt index 195e2b0..b96491b 100644 --- a/src/jsMain/kotlin/assembly/InstructionFixer.kt +++ b/src/jsMain/kotlin/assembly/InstructionFixer.kt @@ -15,22 +15,22 @@ class InstructionFixer { // Idiv cannot take an immediate value directly. instruction is Idiv && instruction.divisor is Imm -> { listOf( - Mov(instruction.divisor, Register(HardwareRegister.R10D)), - Idiv(Register(HardwareRegister.R10D)) + Mov(instruction.divisor, Register(HardwareRegister.R10D), instruction.sourceId), + Idiv(Register(HardwareRegister.R10D), instruction.sourceId) ) } instruction is Push && instruction.operand is Stack -> { listOf( - Mov(instruction.operand, Register(HardwareRegister.EAX)), // Use a caller-saved register - Push(Register(HardwareRegister.EAX)) + Mov(instruction.operand, Register(HardwareRegister.EAX), instruction.sourceId), // Use a caller-saved register + Push(Register(HardwareRegister.EAX), instruction.sourceId) ) } instruction is Mov && instruction.src is Stack && instruction.dest is Stack -> { listOf( - Mov(instruction.src, Register(HardwareRegister.R10D)), - Mov(Register(HardwareRegister.R10D), instruction.dest) + Mov(instruction.src, Register(HardwareRegister.R10D), instruction.sourceId), + Mov(Register(HardwareRegister.R10D), instruction.dest, instruction.sourceId) ) } @@ -40,8 +40,8 @@ class InstructionFixer { instruction.src is Stack && instruction.dest is Stack -> { listOf( - Mov(instruction.src, Register(HardwareRegister.R10D)), - AsmBinary(instruction.op, Register(HardwareRegister.R10D), instruction.dest) + Mov(instruction.src, Register(HardwareRegister.R10D), instruction.sourceId), + AsmBinary(instruction.op, Register(HardwareRegister.R10D), instruction.dest, instruction.sourceId) ) } @@ -49,25 +49,25 @@ class InstructionFixer { instruction.op == AsmBinaryOp.MUL && instruction.dest is Stack -> { listOf( - Mov(instruction.dest, Register(HardwareRegister.R11D)), - AsmBinary(instruction.op, instruction.src, Register(HardwareRegister.R11D)), - Mov(Register(HardwareRegister.R11D), instruction.dest) + Mov(instruction.dest, Register(HardwareRegister.R11D), instruction.sourceId), + AsmBinary(instruction.op, instruction.src, Register(HardwareRegister.R11D), instruction.sourceId), + Mov(Register(HardwareRegister.R11D), instruction.dest, instruction.sourceId) ) } // `cmp` cannot be memory-to-memory. instruction is Cmp && instruction.src is Stack && instruction.dest is Stack -> { listOf( - Mov(instruction.src, Register(HardwareRegister.R10D)), - Cmp(Register(HardwareRegister.R10D), instruction.dest) + Mov(instruction.src, Register(HardwareRegister.R10D), instruction.sourceId), + Cmp(Register(HardwareRegister.R10D), instruction.dest, instruction.sourceId) ) } // The destination of `cmp` cannot be an immediate. instruction is Cmp && instruction.dest is Imm -> { listOf( - Mov(instruction.dest, Register(HardwareRegister.R11D)), - Cmp(instruction.src, Register(HardwareRegister.R11D)) + Mov(instruction.dest, Register(HardwareRegister.R11D), instruction.sourceId), + Cmp(instruction.src, Register(HardwareRegister.R11D), instruction.sourceId) ) } @@ -83,7 +83,8 @@ class InstructionFixer { val finalInstructions = if (stackSpace > 0) { - listOf(AllocateStack(stackSpace)) + fixedInstructions + // Stack allocation is a function-level operation, not tied to a specific source instruction + listOf(AllocateStack(stackSpace, "")) + fixedInstructions } else { fixedInstructions } diff --git a/src/jsMain/kotlin/assembly/Instructions.kt b/src/jsMain/kotlin/assembly/Instructions.kt index ab4f8db..d1e843e 100644 --- a/src/jsMain/kotlin/assembly/Instructions.kt +++ b/src/jsMain/kotlin/assembly/Instructions.kt @@ -1,12 +1,25 @@ package assembly -sealed class Instruction : AsmConstruct() +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable -object Ret : Instruction() +@Serializable +sealed class Instruction() : AsmConstruct() { + abstract val sourceId: String +} + +@Serializable +@SerialName("Ret") +data class Ret( + override val sourceId: String = "" +) : Instruction() +@Serializable +@SerialName("Mov") data class Mov( val src: Operand, - val dest: Operand + val dest: Operand, + override val sourceId: String = "" ) : Instruction() enum class AsmUnaryOp( @@ -24,56 +37,93 @@ enum class AsmBinaryOp( MUL("imul") } +@Serializable +@SerialName("AsmUnary") data class AsmUnary( val op: AsmUnaryOp, - val dest: Operand + val dest: Operand, + override val sourceId: String = "" ) : Instruction() +@Serializable +@SerialName("AsmBinary") data class AsmBinary( val op: AsmBinaryOp, val src: Operand, - val dest: Operand + val dest: Operand, + override val sourceId: String = "" ) : Instruction() +@Serializable +@SerialName("Idiv") data class Idiv( - val divisor: Operand + val divisor: Operand, + override val sourceId: String = "" ) : Instruction() // Convert Doubleword 32 to Quadword 64 -object Cdq : Instruction() +@Serializable +@SerialName("Cdq") +data class Cdq( + override val sourceId: String = "" +) : Instruction() +@Serializable +@SerialName("AllocateStack") data class AllocateStack( - val size: Int + val size: Int, + override val sourceId: String = "" ) : Instruction() +@Serializable +@SerialName("DeAllocateStack") data class DeAllocateStack( - val size: Int + val size: Int, + override val sourceId: String = "" ) : Instruction() +@Serializable +@SerialName("Push") data class Push( - val operand: Operand + val operand: Operand, + override val sourceId: String = "" ) : Instruction() +@Serializable +@SerialName("Call") data class Call( - val identifier: String + val identifier: String, + override val sourceId: String = "" ) : Instruction() +@Serializable +@SerialName("Label") data class Label( - val name: String + val name: String, + override val sourceId: String = "" ) : Instruction() +@Serializable +@SerialName("Jmp") data class Jmp( - val label: Label + val label: Label, + override val sourceId: String = "" ) : Instruction() +@Serializable +@SerialName("JmpCC") data class JmpCC( val condition: ConditionCode, - val label: Label + val label: Label, + override val sourceId: String = "" ) : Instruction() +@Serializable +@SerialName("Cmp") data class Cmp( val src: Operand, - val dest: Operand + val dest: Operand, + override val sourceId: String = "" ) : Instruction() enum class ConditionCode( @@ -87,7 +137,10 @@ enum class ConditionCode( GE("ge") // Greater or Equal } +@Serializable +@SerialName("SetCC") data class SetCC( val condition: ConditionCode, - val dest: Operand + val dest: Operand, + override val sourceId: String = "" ) : Instruction() diff --git a/src/jsMain/kotlin/assembly/Operands.kt b/src/jsMain/kotlin/assembly/Operands.kt index 03056c5..2bbf489 100644 --- a/src/jsMain/kotlin/assembly/Operands.kt +++ b/src/jsMain/kotlin/assembly/Operands.kt @@ -1,19 +1,31 @@ package assembly +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable sealed class Operand : AsmConstruct() +@Serializable +@SerialName("Imm") data class Imm( val value: Int ) : Operand() +@Serializable +@SerialName("Register") data class Register( val name: HardwareRegister ) : Operand() +@Serializable +@SerialName("Pseudo") data class Pseudo( val name: String ) : Operand() +@Serializable +@SerialName("Stack") data class Stack( val offset: Int ) : Operand() diff --git a/src/jsMain/kotlin/assembly/Programs.kt b/src/jsMain/kotlin/assembly/Programs.kt index 98989f1..17f857d 100644 --- a/src/jsMain/kotlin/assembly/Programs.kt +++ b/src/jsMain/kotlin/assembly/Programs.kt @@ -1,7 +1,13 @@ package assembly +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable sealed class Program : AsmConstruct() +@Serializable +@SerialName("AsmProgram") data class AsmProgram( val functions: List ) : Program() diff --git a/src/jsMain/kotlin/assembly/PseudoEliminator.kt b/src/jsMain/kotlin/assembly/PseudoEliminator.kt index bcbd60f..a0b79af 100644 --- a/src/jsMain/kotlin/assembly/PseudoEliminator.kt +++ b/src/jsMain/kotlin/assembly/PseudoEliminator.kt @@ -26,14 +26,14 @@ class PseudoEliminator { val newInstructions = function.body.map { instruction -> when (instruction) { - is Mov -> Mov(replace(instruction.src), replace(instruction.dest)) - is AsmUnary -> AsmUnary(instruction.op, replace(instruction.dest)) - is AsmBinary -> AsmBinary(instruction.op, replace(instruction.src), replace(instruction.dest)) - is Cmp -> Cmp(replace(instruction.src), replace(instruction.dest)) - is SetCC -> SetCC(instruction.condition, replace(instruction.dest)) - is Push -> Push(replace(instruction.operand)) + is Mov -> Mov(replace(instruction.src), replace(instruction.dest), instruction.sourceId) + is AsmUnary -> AsmUnary(instruction.op, replace(instruction.dest), instruction.sourceId) + is AsmBinary -> AsmBinary(instruction.op, replace(instruction.src), replace(instruction.dest), instruction.sourceId) + is Cmp -> Cmp(replace(instruction.src), replace(instruction.dest), instruction.sourceId) + is SetCC -> SetCC(instruction.condition, replace(instruction.dest), instruction.sourceId) + is Push -> Push(replace(instruction.operand), instruction.sourceId) is Call -> instruction - is Idiv -> Idiv(replace(instruction.divisor)) + is Idiv -> Idiv(replace(instruction.divisor), instruction.sourceId) else -> instruction } } diff --git a/src/jsMain/kotlin/export/ASTExport.kt b/src/jsMain/kotlin/export/ASTExport.kt index b83b026..90f38e2 100644 --- a/src/jsMain/kotlin/export/ASTExport.kt +++ b/src/jsMain/kotlin/export/ASTExport.kt @@ -1,8 +1,7 @@ package export -import kotlinx.serialization.encodeToString -import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive import parser.AssignmentExpression @@ -40,18 +39,33 @@ fun createJsonNode( type: String, label: String, children: JsonObject, - edgeLabels: Boolean = false -): String { - val jsonNode = - JsonObject( + edgeLabels: Boolean = false, + location: parser.SourceLocation? = null, + id: String? = null +): JsonObject { + val nodeMap = mutableMapOf( + "type" to JsonPrimitive(type), + "label" to JsonPrimitive(label), + "children" to children, + "edgeLabels" to JsonPrimitive(edgeLabels) + ) + + location?.let { + nodeMap["location"] = JsonObject( mapOf( - "type" to JsonPrimitive(type), - "label" to JsonPrimitive(label), - "children" to children, - "edgeLabels" to JsonPrimitive(edgeLabels) + "startLine" to JsonPrimitive(it.startLine), + "startCol" to JsonPrimitive(it.startCol), + "endLine" to JsonPrimitive(it.endLine), + "endCol" to JsonPrimitive(it.endCol) ) ) - return Json.encodeToString(jsonNode) + } + + id?.let { + nodeMap["id"] = JsonPrimitive(it) + } + + return JsonObject(nodeMap) } @OptIn(ExperimentalJsExport::class) @@ -66,195 +80,195 @@ enum class NodeType { Declaration } -class ASTExport : Visitor { - override fun visit(node: SimpleProgram): String { - val decls = JsonArray(node.functionDeclaration.map { JsonPrimitive(it.accept(this)) }) - return createJsonNode(NodeType.Program.name, "Program", JsonObject(mapOf("declarations" to decls))) +class ASTExport : Visitor { + override fun visit(node: SimpleProgram): JsonObject { + val decls = JsonArray(node.functionDeclaration.map { it.accept(this) }) + return createJsonNode(NodeType.Program.name, "Program", JsonObject(mapOf("declarations" to decls)), false, node.location, node.id) } - override fun visit(node: ReturnStatement): String { - val children = JsonObject(mapOf("expression" to JsonPrimitive(node.expression.accept(this)))) - return createJsonNode(NodeType.Statement.name, "ReturnStatement", children) + override fun visit(node: ReturnStatement): JsonObject { + val children = JsonObject(mapOf("expression" to node.expression.accept(this))) + return createJsonNode(NodeType.Statement.name, "ReturnStatement", children, false, node.location, node.id) } - override fun visit(node: ExpressionStatement): String { - val children = JsonObject(mapOf("expression" to JsonPrimitive(node.expression.accept(this)))) - return createJsonNode(NodeType.Statement.name, "ExpressionStatement", children) + override fun visit(node: ExpressionStatement): JsonObject { + val children = JsonObject(mapOf("expression" to node.expression.accept(this))) + return createJsonNode(NodeType.Statement.name, "ExpressionStatement", children, false, node.location, node.id) } - override fun visit(node: NullStatement): String = createJsonNode(NodeType.Statement.name, "NullStatement", JsonObject(emptyMap())) + override fun visit(node: NullStatement): JsonObject = createJsonNode(NodeType.Statement.name, "NullStatement", JsonObject(emptyMap()), false, node.location, node.id) - override fun visit(node: BreakStatement): String = createJsonNode(NodeType.Statement.name, "BreakStatement", JsonObject(emptyMap())) + override fun visit(node: BreakStatement): JsonObject = createJsonNode(NodeType.Statement.name, "BreakStatement", JsonObject(emptyMap()), false, node.location, node.id) - override fun visit(node: ContinueStatement): String = createJsonNode(NodeType.Statement.name, "continue", JsonObject(emptyMap())) + override fun visit(node: ContinueStatement): JsonObject = createJsonNode(NodeType.Statement.name, "continue", JsonObject(emptyMap()), false, node.location, node.id) - override fun visit(node: WhileStatement): String { + override fun visit(node: WhileStatement): JsonObject { val children = JsonObject( mapOf( - "cond" to JsonPrimitive(node.condition.accept(this)), - "body" to JsonPrimitive(node.body.accept(this)) + "cond" to node.condition.accept(this), + "body" to node.body.accept(this) ) ) - return createJsonNode(NodeType.Statement.name, "WhileLoop", children, edgeLabels = true) + return createJsonNode(NodeType.Statement.name, "WhileLoop", children, true, node.location, node.id) } - override fun visit(node: DoWhileStatement): String { + override fun visit(node: DoWhileStatement): JsonObject { val children = JsonObject( mapOf( - "body" to JsonPrimitive(node.body.accept(this)), - "cond" to JsonPrimitive(node.condition.accept(this)) + "body" to node.body.accept(this), + "cond" to node.condition.accept(this) ) ) - return createJsonNode(NodeType.Statement.name, "DoWhileLoop", children, edgeLabels = true) + return createJsonNode(NodeType.Statement.name, "DoWhileLoop", children, true, node.location, node.id) } - override fun visit(node: ForStatement): String { + override fun visit(node: ForStatement): JsonObject { val childrenMap = mutableMapOf( - "init" to JsonPrimitive(node.init.accept(this)) + "init" to node.init.accept(this) ) - node.condition?.let { childrenMap["cond"] = JsonPrimitive(it.accept(this)) } - node.post?.let { childrenMap["post"] = JsonPrimitive(it.accept(this)) } - childrenMap["body"] = JsonPrimitive(node.body.accept(this)) + node.condition?.let { childrenMap["cond"] = it.accept(this) } + node.post?.let { childrenMap["post"] = it.accept(this) } + childrenMap["body"] = node.body.accept(this) - return createJsonNode(NodeType.Statement.name, "ForLoop", JsonObject(childrenMap), edgeLabels = true) + return createJsonNode(NodeType.Statement.name, "ForLoop", JsonObject(childrenMap), true, node.location, node.id) } - override fun visit(node: InitDeclaration): String { - val children = JsonObject(mapOf("declaration" to JsonPrimitive(node.varDeclaration.accept(this)))) - return createJsonNode(NodeType.ASTNode.name, "Declaration", children) + override fun visit(node: InitDeclaration): JsonObject { + val children = JsonObject(mapOf("declaration" to node.varDeclaration.accept(this))) + return createJsonNode(NodeType.ASTNode.name, "Declaration", children, false, node.location, node.id) } - override fun visit(node: InitExpression): String { + override fun visit(node: InitExpression): JsonObject { val childrenMap = mutableMapOf() - node.expression?.let { childrenMap["expression"] = JsonPrimitive(it.accept(this)) } - return createJsonNode(NodeType.Expression.name, "Expression", JsonObject(childrenMap)) + node.expression?.let { childrenMap["expression"] = it.accept(this) } + return createJsonNode(NodeType.Expression.name, "Expression", JsonObject(childrenMap), false, node.location, node.id) } - override fun visit(node: FunctionDeclaration): String { - val childrenMap = mutableMapOf("name" to JsonPrimitive(node.name)) - node.body?.let { childrenMap["body"] = JsonPrimitive(it.accept(this)) } - return createJsonNode(NodeType.Function.name, "Function(${node.name})", JsonObject(childrenMap), edgeLabels = true) + override fun visit(node: FunctionDeclaration): JsonObject { + val childrenMap = mutableMapOf("name" to JsonPrimitive(node.name)) + node.body?.let { childrenMap["body"] = it.accept(this) } + return createJsonNode(NodeType.Function.name, "Function(${node.name})", JsonObject(childrenMap), true, node.location, node.id) } - override fun visit(node: VarDecl): String { - val children = JsonObject(mapOf("variableDeclaration" to JsonPrimitive(node.varDecl.accept(this)))) - return createJsonNode(NodeType.Declaration.name, "VarDeclaration", children) + override fun visit(node: VarDecl): JsonObject { + val children = JsonObject(mapOf("variableDeclaration" to node.varDecl.accept(this))) + return createJsonNode(NodeType.Declaration.name, "VarDeclaration", children, false, node.location, node.id) } - override fun visit(node: FunDecl): String { - val children = JsonObject(mapOf("functionDeclaration" to JsonPrimitive(node.funDecl.accept(this)))) - return createJsonNode(NodeType.Declaration.name, "FuncDeclaration", children) + override fun visit(node: FunDecl): JsonObject { + val children = JsonObject(mapOf("functionDeclaration" to node.funDecl.accept(this))) + return createJsonNode(NodeType.Declaration.name, "FuncDeclaration", children, false, node.location, node.id) } - override fun visit(node: VariableExpression): String = - createJsonNode(NodeType.Expression.name, "Variable(${node.name})", JsonObject(emptyMap())) + override fun visit(node: VariableExpression): JsonObject = + createJsonNode(NodeType.Expression.name, "Variable(${node.name})", JsonObject(emptyMap()), false, node.location, node.id) - override fun visit(node: UnaryExpression): String { + override fun visit(node: UnaryExpression): JsonObject { val children = JsonObject( mapOf( "operator" to JsonPrimitive(node.operator.toString()), - "expression" to JsonPrimitive(node.expression.accept(this)) + "expression" to node.expression.accept(this) ) ) - return createJsonNode(NodeType.Expression.name, "UnaryExpression(${node.operator.type})", children) + return createJsonNode(NodeType.Expression.name, "UnaryExpression(${node.operator.type})", children, location = node.location, id = node.id) } - override fun visit(node: BinaryExpression): String { + override fun visit(node: BinaryExpression): JsonObject { val children = JsonObject( mapOf( - "left" to JsonPrimitive(node.left.accept(this)), - "right" to JsonPrimitive(node.right.accept(this)) + "left" to node.left.accept(this), + "right" to node.right.accept(this) ) ) - return createJsonNode(NodeType.Expression.name, "BinaryExpression(${node.operator.type})", children, edgeLabels = true) + return createJsonNode(NodeType.Expression.name, "BinaryExpression(${node.operator.type})", children, edgeLabels = true, location = node.location, id = node.id) } - override fun visit(node: IntExpression): String = createJsonNode(NodeType.Expression.name, "Int(${node.value})", JsonObject(emptyMap())) + override fun visit(node: IntExpression): JsonObject = createJsonNode(NodeType.Expression.name, "Int(${node.value})", JsonObject(emptyMap()), false, node.location, node.id) - override fun visit(node: IfStatement): String { + override fun visit(node: IfStatement): JsonObject { val childrenMap = mutableMapOf( - "cond" to JsonPrimitive(node.condition.accept(this)), - "then" to JsonPrimitive(node.then.accept(this)) + "cond" to node.condition.accept(this), + "then" to node.then.accept(this) ) - node._else?.let { childrenMap["else"] = JsonPrimitive(it.accept(this)) } - return createJsonNode(NodeType.Statement.name, "IfStatement", JsonObject(childrenMap), edgeLabels = true) + node._else?.let { childrenMap["else"] = it.accept(this) } + return createJsonNode(NodeType.Statement.name, "IfStatement", JsonObject(childrenMap), true, node.location, node.id) } - override fun visit(node: ConditionalExpression): String { + override fun visit(node: ConditionalExpression): JsonObject { val children = JsonObject( mapOf( - "cond" to JsonPrimitive(node.codition.accept(this)), - "then" to JsonPrimitive(node.thenExpression.accept(this)), - "else" to JsonPrimitive(node.elseExpression.accept(this)) + "cond" to node.codition.accept(this), + "then" to node.thenExpression.accept(this), + "else" to node.elseExpression.accept(this) ) ) - return createJsonNode(NodeType.Expression.name, "ConditionalExpression", children, edgeLabels = true) + return createJsonNode(NodeType.Expression.name, "ConditionalExpression", children, true, node.location, node.id) } - override fun visit(node: GotoStatement): String { + override fun visit(node: GotoStatement): JsonObject { val children = JsonObject(mapOf("targetLabel" to JsonPrimitive(node.label))) - return createJsonNode(NodeType.Statement.name, "Goto(${node.label})", children, edgeLabels = true) + return createJsonNode(NodeType.Statement.name, "Goto(${node.label})", children, true, node.location, node.id) } - override fun visit(node: LabeledStatement): String { + override fun visit(node: LabeledStatement): JsonObject { val children = JsonObject( mapOf( "label" to JsonPrimitive(node.label), - "statement" to JsonPrimitive(node.statement.accept(this)) + "statement" to node.statement.accept(this) ) ) - return createJsonNode(NodeType.Statement.name, "LabeledStatement(${node.label})", children, edgeLabels = true) + return createJsonNode(NodeType.Statement.name, "LabeledStatement(${node.label})", children, true, node.location, node.id) } - override fun visit(node: AssignmentExpression): String { + override fun visit(node: AssignmentExpression): JsonObject { val children = JsonObject( mapOf( - "lvalue" to JsonPrimitive(node.lvalue.accept(this)), - "rvalue" to JsonPrimitive(node.rvalue.accept(this)) + "lvalue" to node.lvalue.accept(this), + "rvalue" to node.rvalue.accept(this) ) ) - return createJsonNode(NodeType.Expression.name, "Assignment", children, edgeLabels = true) + return createJsonNode(NodeType.Expression.name, "Assignment", children, true, node.location, node.id) } - override fun visit(node: VariableDeclaration): String { - val childrenMap = mutableMapOf("name" to JsonPrimitive(node.name)) - node.init?.let { childrenMap["initializer"] = JsonPrimitive(it.accept(this)) } - return createJsonNode(NodeType.Declaration.name, "Declaration(${node.name})", JsonObject(childrenMap)) + override fun visit(node: VariableDeclaration): JsonObject { + val childrenMap = mutableMapOf("name" to JsonPrimitive(node.name)) + node.init?.let { childrenMap["initializer"] = it.accept(this) } + return createJsonNode(NodeType.Declaration.name, "Declaration(${node.name})", JsonObject(childrenMap), false, node.location, node.id) } - override fun visit(node: S): String = node.statement.accept(this) + override fun visit(node: S): JsonObject = node.statement.accept(this) - override fun visit(node: D): String = - when (node.declaration) { - is VarDecl -> node.declaration.accept(this) - is FunDecl -> node.declaration.accept(this) - } + override fun visit(node: D): JsonObject = when (node.declaration) { + is VarDecl -> node.declaration.accept(this) + is FunDecl -> node.declaration.accept(this) + is VariableDeclaration -> node.declaration.accept(this) + } - override fun visit(node: Block): String { - val blockItems = node.block.map { it.accept(this) } - val children = JsonObject(mapOf("block" to JsonArray(blockItems.map { JsonPrimitive(it) }))) - return createJsonNode(NodeType.Block.name, "Block", children) + override fun visit(node: Block): JsonObject { + val blockItems = node.items.map { it.accept(this) } + val children = JsonObject(mapOf("block" to JsonArray(blockItems))) + return createJsonNode(NodeType.Block.name, "Block", children, false, node.location, node.id) } - override fun visit(node: CompoundStatement): String = node.block.accept(this) + override fun visit(node: CompoundStatement): JsonObject = node.block.accept(this) - override fun visit(node: FunctionCall): String { + override fun visit(node: FunctionCall): JsonObject { val children = JsonObject( mapOf( "name" to JsonPrimitive(node.name), - "arguments" to JsonArray(node.arguments.map { JsonPrimitive(it.accept(this)) }) + "arguments" to JsonArray(node.arguments.map { it.accept(this) }) ) ) - return createJsonNode(NodeType.Function.name, "FuncCall(${node.name})", children) + return createJsonNode(NodeType.Function.name, "FuncCall(${node.name})", children, false, node.location, node.id) } } diff --git a/src/jsMain/kotlin/export/CompilationOutput.kt b/src/jsMain/kotlin/export/CompilationOutput.kt index 8e80722..5c578d2 100644 --- a/src/jsMain/kotlin/export/CompilationOutput.kt +++ b/src/jsMain/kotlin/export/CompilationOutput.kt @@ -21,7 +21,8 @@ sealed class CompilationOutput { data class LexerOutput( override val stage: String = CompilerStage.LEXER.name.lowercase(), val tokens: String? = null, - override val errors: Array + override val errors: Array, + val sourceLocation: SourceLocationInfo? = null ) : CompilationOutput() @OptIn(ExperimentalJsExport::class) @@ -31,7 +32,8 @@ data class LexerOutput( data class ParserOutput( override val stage: String = CompilerStage.PARSER.name.lowercase(), val ast: String? = null, - override val errors: Array + override val errors: Array, + val sourceLocation: SourceLocationInfo? = null ) : CompilationOutput() @OptIn(ExperimentalJsExport::class) @@ -40,8 +42,11 @@ data class ParserOutput( @SerialName("TackyOutput") data class TackyOutput( override val stage: String = CompilerStage.TACKY.name.lowercase(), + val tacky: String? = null, + val tackyPretty: String? = null, - override val errors: Array + override val errors: Array, + val sourceLocation: SourceLocationInfo? = null ) : CompilationOutput() @OptIn(ExperimentalJsExport::class) @@ -51,9 +56,22 @@ data class TackyOutput( data class AssemblyOutput( override val stage: String = CompilerStage.ASSEMBLY.name.lowercase(), val assembly: String? = null, - override val errors: Array + val rawAssembly: String? = null, + override val errors: Array, + val sourceLocation: SourceLocationInfo? = null ) : CompilationOutput() +@OptIn(ExperimentalJsExport::class) +@JsExport +@Serializable +data class SourceLocationInfo( + val startLine: Int, + val startColumn: Int, + val endLine: Int, + val endColumn: Int, + val totalLines: Int +) + @OptIn(ExperimentalJsExport::class) @JsExport @Serializable diff --git a/src/jsMain/kotlin/export/CompilerExport.kt b/src/jsMain/kotlin/export/CompilerExport.kt index 503350c..c909a79 100644 --- a/src/jsMain/kotlin/export/CompilerExport.kt +++ b/src/jsMain/kotlin/export/CompilerExport.kt @@ -17,40 +17,65 @@ import tacky.TackyProgram @OptIn(ExperimentalJsExport::class) @JsExport class CompilerExport { + + private fun calculateSourceLocationInfo(code: String): SourceLocationInfo { + val lines = code.split('\n') + val totalLines = lines.size + val lastLine = lines.lastOrNull() ?: "" + val endColumn = lastLine.length + 1 + + return SourceLocationInfo( + startLine = 1, + startColumn = 1, + endLine = totalLines, + endColumn = endColumn, + totalLines = totalLines + ) + } + fun exportCompilationResults(code: String): String { val outputs = mutableListOf() val overallErrors = mutableListOf() val codeEmitter = CodeEmitter() + val sourceLocationInfo = calculateSourceLocationInfo(code) + try { val tokens = CompilerWorkflow.take(code) Lexer(code) outputs.add( LexerOutput( - tokens = tokens.toJsonString(), - errors = emptyArray() + tokens = exportTokens(tokens), + errors = emptyArray(), + sourceLocation = sourceLocationInfo ) ) val ast = CompilerWorkflow.take(tokens) outputs.add( ParserOutput( errors = emptyArray(), - ast = ast.accept(ASTExport()) + ast = Json.encodeToString(ast.accept(ASTExport())), + sourceLocation = sourceLocationInfo ) ) val tacky = CompilerWorkflow.take(ast) val tackyProgram = tacky as? TackyProgram outputs.add( TackyOutput( + tacky = if (tackyProgram != null) Json.encodeToString(tackyProgram) else null, tackyPretty = tackyProgram?.toPseudoCode(), - errors = emptyArray() + errors = emptyArray(), + sourceLocation = sourceLocationInfo ) ) val asm = CompilerWorkflow.take(tacky) val finalAssemblyString = codeEmitter.emit(asm as AsmProgram) + val rawAssembly = codeEmitter.emitRaw(asm as AsmProgram) outputs.add( AssemblyOutput( errors = emptyArray(), - assembly = finalAssemblyString + assembly = finalAssemblyString, + rawAssembly = rawAssembly, + sourceLocation = sourceLocationInfo ) ) } catch (e: CompilationException) { @@ -70,10 +95,10 @@ class CompilerExport { ) overallErrors.add(error) when (stage) { - CompilerStage.LEXER -> outputs.add(LexerOutput(errors = arrayOf(error))) - CompilerStage.PARSER -> outputs.add(ParserOutput(errors = arrayOf(error))) - CompilerStage.TACKY -> outputs.add(TackyOutput(errors = arrayOf(error))) - CompilerStage.ASSEMBLY -> outputs.add(AssemblyOutput(errors = arrayOf(error))) + CompilerStage.LEXER -> outputs.add(LexerOutput(errors = arrayOf(error), sourceLocation = sourceLocationInfo)) + CompilerStage.PARSER -> outputs.add(ParserOutput(errors = arrayOf(error), sourceLocation = sourceLocationInfo)) + CompilerStage.TACKY -> outputs.add(TackyOutput(errors = arrayOf(error), sourceLocation = sourceLocationInfo)) + CompilerStage.ASSEMBLY -> outputs.add(AssemblyOutput(errors = arrayOf(error), sourceLocation = sourceLocationInfo)) } } catch (e: Exception) { // Fallback for any unexpected runtime errors @@ -86,9 +111,9 @@ class CompilerExport { overallErrors.add(error) // ensure we return four stages while (outputs.size < 3) { - outputs.add(ParserOutput(errors = emptyArray())) + outputs.add(ParserOutput(errors = emptyArray(), sourceLocation = sourceLocationInfo)) } - outputs.add(AssemblyOutput(errors = arrayOf(error))) + outputs.add(AssemblyOutput(errors = arrayOf(error), sourceLocation = sourceLocationInfo)) } val result = @@ -101,19 +126,27 @@ class CompilerExport { return result.toJsonString() } - fun List.toJsonString(): String { - val jsonTokens = - this.map { token -> - JsonObject( - mapOf( - "line" to JsonPrimitive(token.line), - "column" to JsonPrimitive(token.column), - "type" to JsonPrimitive(token.type.toString()), - "lexeme" to JsonPrimitive(token.lexeme) + fun exportTokens(tokens: List): String = tokens.toJsonString() +} + +fun List.toJsonString(): String { + val jsonTokens = + this.map { token -> + JsonObject( + mapOf( + "type" to JsonPrimitive(token.type.toString()), + "lexeme" to JsonPrimitive(token.lexeme), + "location" to JsonObject( + mapOf( + "startLine" to JsonPrimitive(token.startLine), + "startCol" to JsonPrimitive(token.startColumn), + "endLine" to JsonPrimitive(token.endLine), + "endCol" to JsonPrimitive(token.endColumn) + ) ) ) - } + ) + } - return Json.encodeToString(JsonArray(jsonTokens)) - } + return Json.encodeToString(JsonArray(jsonTokens)) } diff --git a/src/jsMain/kotlin/lexer/Lexer.kt b/src/jsMain/kotlin/lexer/Lexer.kt index 0908aff..543922c 100644 --- a/src/jsMain/kotlin/lexer/Lexer.kt +++ b/src/jsMain/kotlin/lexer/Lexer.kt @@ -97,10 +97,19 @@ sealed class TokenType { data class Token( val type: TokenType, val lexeme: String, - val line: Int, - val column: Int + val startLine: Int, + val startColumn: Int, + val endLine: Int, + val endColumn: Int ) { - override fun equals(other: Any?): Boolean = other is Token && other.type == this.type && other.lexeme == this.lexeme + override fun equals(other: Any?): Boolean = + other is Token && + other.type == this.type && + other.lexeme == this.lexeme && + other.startLine == this.startLine && + other.startColumn == this.startColumn && + other.endLine == this.endLine && + other.endColumn == this.endColumn } class Lexer( @@ -133,7 +142,7 @@ class Lexer( start = current scanToken() } - tokens.add(Token(TokenType.EOF, "", line, current - lineStart + 1)) + tokens.add(Token(TokenType.EOF, "", line, current - lineStart + 1, line, current - lineStart + 1)) return tokens } @@ -264,7 +273,17 @@ class Lexer( private fun addToken(type: TokenType) { val text = source.substring(start, current) - val column = start - lineStart + 1 - tokens.add(Token(type, text, line, column)) + val startCol = start - lineStart + 1 + val endCol = current - lineStart + tokens.add( + Token( + type, + text, + startLine = line, + startColumn = startCol, + endLine = line, + endColumn = endCol + ) + ) } } diff --git a/src/jsMain/kotlin/parser/ASTNode.kt b/src/jsMain/kotlin/parser/ASTNode.kt index a640265..df079ab 100644 --- a/src/jsMain/kotlin/parser/ASTNode.kt +++ b/src/jsMain/kotlin/parser/ASTNode.kt @@ -1,5 +1,9 @@ package parser -sealed class ASTNode { +import kotlin.random.Random + +data class SourceLocation(val startLine: Int, val startCol: Int, val endLine: Int, val endCol: Int) + +sealed class ASTNode(open val location: SourceLocation, open val id: String = Random.nextLong().toString()) { abstract fun accept(visitor: Visitor): T } diff --git a/src/jsMain/kotlin/parser/BlockItems.kt b/src/jsMain/kotlin/parser/BlockItems.kt index 1986516..60bb64b 100644 --- a/src/jsMain/kotlin/parser/BlockItems.kt +++ b/src/jsMain/kotlin/parser/BlockItems.kt @@ -1,20 +1,22 @@ package parser -sealed class Statement : ASTNode() +sealed class Statement(location: SourceLocation) : ASTNode(location) data class ReturnStatement( - val expression: Expression -) : Statement() { + val expression: Expression, + override val location: SourceLocation +) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class ExpressionStatement( - val expression: Expression -) : Statement() { + val expression: Expression, + override val location: SourceLocation +) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } -class NullStatement : Statement() { +class NullStatement(override val location: SourceLocation) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) override fun equals(other: Any?): Boolean = other is NullStatement @@ -23,30 +25,34 @@ class NullStatement : Statement() { } data class BreakStatement( - var label: String = "" -) : Statement() { + var label: String = "", + override val location: SourceLocation +) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class ContinueStatement( - var label: String = "" -) : Statement() { + var label: String = "", + override val location: SourceLocation +) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class WhileStatement( val condition: Expression, val body: Statement, - var label: String = "" -) : Statement() { + var label: String = "", + override val location: SourceLocation +) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class DoWhileStatement( val condition: Expression, val body: Statement, - var label: String = "" -) : Statement() { + var label: String = "", + override val location: SourceLocation +) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } @@ -55,89 +61,98 @@ data class ForStatement( val condition: Expression?, val post: Expression?, val body: Statement, - var label: String = "" -) : Statement() { + var label: String = "", + override val location: SourceLocation +) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } -sealed class ForInit : ASTNode() +sealed class ForInit(location: SourceLocation) : ASTNode(location) data class InitDeclaration( - val varDeclaration: VariableDeclaration -) : ForInit() { + val varDeclaration: VariableDeclaration, + override val location: SourceLocation +) : ForInit(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class InitExpression( - val expression: Expression? -) : ForInit() { + val expression: Expression?, + override val location: SourceLocation +) : ForInit(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } class IfStatement( val condition: Expression, val then: Statement, - val _else: Statement? -) : Statement() { + val _else: Statement?, + override val location: SourceLocation +) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } class GotoStatement( - val label: String -) : Statement() { + val label: String, + override val location: SourceLocation +) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } class LabeledStatement( val label: String, - val statement: Statement -) : Statement() { + val statement: Statement, + override val location: SourceLocation +) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } -sealed class Declaration : ASTNode() +sealed class Declaration(location: SourceLocation) : ASTNode(location) data class VariableDeclaration( val name: String, - val init: Expression? -) : ASTNode() { + val init: Expression?, + override val location: SourceLocation +) : Declaration(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class VarDecl( val varDecl: VariableDeclaration -) : Declaration() { +) : Declaration(location = varDecl.location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class FunDecl( val funDecl: FunctionDeclaration -) : Declaration() { +) : Declaration(location = funDecl.location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } -sealed class BlockItem : ASTNode() +sealed class BlockItem(location: SourceLocation) : ASTNode(location) data class S( val statement: Statement -) : BlockItem() { +) : BlockItem(location = statement.location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class D( val declaration: Declaration -) : BlockItem() { +) : BlockItem(declaration.location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class CompoundStatement( - val block: Block -) : Statement() { + val block: Block, + override val location: SourceLocation +) : Statement(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class Block( - val block: List -) : ASTNode() { + val items: List, + override val location: SourceLocation +) : ASTNode(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } diff --git a/src/jsMain/kotlin/parser/Expressions.kt b/src/jsMain/kotlin/parser/Expressions.kt index 5a743fd..6dbe0d8 100644 --- a/src/jsMain/kotlin/parser/Expressions.kt +++ b/src/jsMain/kotlin/parser/Expressions.kt @@ -2,53 +2,60 @@ package parser import lexer.Token -sealed class Expression : ASTNode() +sealed class Expression(location: SourceLocation) : ASTNode(location) data class IntExpression( - val value: Int -) : Expression() { + val value: Int, + override val location: SourceLocation +) : Expression(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class VariableExpression( - val name: String -) : Expression() { + val name: String, + override val location: SourceLocation +) : Expression(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class UnaryExpression( val operator: Token, - val expression: Expression -) : Expression() { + val expression: Expression, + override val location: SourceLocation +) : Expression(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class BinaryExpression( val left: Expression, val operator: Token, - val right: Expression -) : Expression() { + val right: Expression, + override val location: SourceLocation +) : Expression(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class AssignmentExpression( val lvalue: VariableExpression, - val rvalue: Expression -) : Expression() { + val rvalue: Expression, + override val location: SourceLocation +) : Expression(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class ConditionalExpression( val codition: Expression, val thenExpression: Expression, - val elseExpression: Expression -) : Expression() { + val elseExpression: Expression, + override val location: SourceLocation +) : Expression(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } data class FunctionCall( val name: String, - val arguments: List -) : Expression() { + val arguments: List, + override val location: SourceLocation +) : Expression(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } diff --git a/src/jsMain/kotlin/parser/FunctionDeclaration.kt b/src/jsMain/kotlin/parser/FunctionDeclaration.kt index 858482a..c478a32 100644 --- a/src/jsMain/kotlin/parser/FunctionDeclaration.kt +++ b/src/jsMain/kotlin/parser/FunctionDeclaration.kt @@ -5,7 +5,8 @@ package parser data class FunctionDeclaration( val name: String, val params: List, - val body: Block? -) : ASTNode() { + val body: Block?, + override val location: SourceLocation +) : ASTNode(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } diff --git a/src/jsMain/kotlin/parser/Parser.kt b/src/jsMain/kotlin/parser/Parser.kt index 4921ce1..75ba66a 100644 --- a/src/jsMain/kotlin/parser/Parser.kt +++ b/src/jsMain/kotlin/parser/Parser.kt @@ -40,17 +40,22 @@ class Parser { private fun parseProgram(tokens: MutableList): SimpleProgram { val declarations = mutableListOf() + val startLine = 0 + val startColumn = 0 while (tokens.isNotEmpty() && tokens.first().type != TokenType.EOF) { declarations.add(parseFunctionDeclaration(tokens)) } + val endLine = declarations.last().location.endLine + val endColumn = declarations.last().location.endCol return SimpleProgram( - functionDeclaration = declarations + functionDeclaration = declarations, + location = SourceLocation(startLine, startColumn, endLine, endColumn) ) } private fun parseFunctionDeclaration(tokens: MutableList): FunctionDeclaration { - expect(TokenType.KEYWORD_INT, tokens) + val func = expect(TokenType.KEYWORD_INT, tokens) val name = parseIdentifier(tokens) expect(TokenType.LEFT_PAREN, tokens) val params = mutableListOf() @@ -63,19 +68,25 @@ class Parser { } else { tokens.removeFirst() // consume 'void' } - expect(TokenType.RIGHT_PAREN, tokens) + val endParan = expect(TokenType.RIGHT_PAREN, tokens) val body: Block? + val endLine: Int + val endColumn: Int if (tokens.firstOrNull()?.type == TokenType.LEFT_BRACK) { body = parseBlock(tokens) + endLine = body.location.endLine + endColumn = body.location.endCol } else { expect(TokenType.SEMICOLON, tokens) body = null + endLine = endParan.endLine + endColumn = endParan.endColumn } - return FunctionDeclaration(name, params, body) + return FunctionDeclaration(name, params, body, SourceLocation(func.startLine, func.startColumn, endLine, endColumn)) } - private fun parseFunctionDeclarationFromBody(tokens: MutableList, name: String): FunctionDeclaration { + private fun parseFunctionDeclarationFromBody(tokens: MutableList, name: String, location: SourceLocation): FunctionDeclaration { expect(TokenType.LEFT_PAREN, tokens) val params = mutableListOf() if (tokens.firstOrNull()?.type != TokenType.KEYWORD_VOID) { @@ -87,43 +98,48 @@ class Parser { } else { tokens.removeFirst() // consume 'void' } - expect(TokenType.RIGHT_PAREN, tokens) + val end = expect(TokenType.RIGHT_PAREN, tokens) val body: Block? + val finalLocation: SourceLocation if (tokens.firstOrNull()?.type == TokenType.LEFT_BRACK) { // Parse function body, we will throw exception in semantic pass body = parseBlock(tokens) + finalLocation = SourceLocation(location.startLine, location.startCol, body.location.endLine, body.location.endCol) } else { expect(TokenType.SEMICOLON, tokens) body = null + finalLocation = SourceLocation(location.startLine, location.startCol, end.endLine, end.endColumn) } - return FunctionDeclaration(name, params, body) + return FunctionDeclaration(name, params, body, finalLocation) } private fun parseBlock(tokens: MutableList): Block { val body = mutableListOf() - expect(TokenType.LEFT_BRACK, tokens) + val start = expect(TokenType.LEFT_BRACK, tokens) while (tokens.firstOrNull()?.type != TokenType.RIGHT_BRACK) { body.add(parseBlockItem(tokens)) } - expect(TokenType.RIGHT_BRACK, tokens) - return Block(body) + val end = expect(TokenType.RIGHT_BRACK, tokens) + return Block(body, SourceLocation(start.startLine, start.startColumn, end.endLine, end.endColumn)) } private fun parseBlockItem(tokens: MutableList): BlockItem = if (tokens.firstOrNull()?.type == TokenType.KEYWORD_INT) { val lookaheadTokens = tokens.toMutableList() - expect(TokenType.KEYWORD_INT, lookaheadTokens) + val start = expect(TokenType.KEYWORD_INT, lookaheadTokens) val name = parseIdentifier(lookaheadTokens) if (lookaheadTokens.firstOrNull()?.type == TokenType.LEFT_PAREN) { expect(TokenType.KEYWORD_INT, tokens) val actualName = parseIdentifier(tokens) - D(FunDecl(parseFunctionDeclarationFromBody(tokens, actualName))) + D(FunDecl(parseFunctionDeclarationFromBody(tokens, actualName, SourceLocation(start.startLine, start.startColumn, start.endLine, start.endColumn)))) } else { - expect(TokenType.KEYWORD_INT, tokens) + val end = expect(TokenType.KEYWORD_INT, tokens) val actualName = parseIdentifier(tokens) - D(VarDecl(parseVariableDeclaration(tokens, actualName))) + // check if end is right later + val location = SourceLocation(start.startLine, start.startColumn, end.endLine, end.endColumn) + D(VarDecl(parseVariableDeclaration(tokens, actualName, location))) } } else { S(parseStatement(tokens)) @@ -131,15 +147,17 @@ class Parser { private fun parseVariableDeclaration( tokens: MutableList, - name: String + name: String, + location: SourceLocation ): VariableDeclaration { var init: Expression? = null if (tokens.firstOrNull()?.type == TokenType.ASSIGN) { tokens.removeFirst() // consume '=' init = parseExpression(0, tokens) } - expect(TokenType.SEMICOLON, tokens) - return VariableDeclaration(name, init) + val end = expect(TokenType.SEMICOLON, tokens) + val finalLocation = SourceLocation(location.startLine, location.startCol, end.endLine, end.endColumn) + return VariableDeclaration(name, init, finalLocation) } private fun expect( @@ -155,8 +173,8 @@ class Parser { throw UnexpectedTokenException( expected = expected.toString(), actual = token.type.toString(), - line = token.line, - column = token.column + line = token.startLine, + column = token.startColumn ) } @@ -169,8 +187,8 @@ class Parser { throw UnexpectedTokenException( expected = TokenType.IDENTIFIER.toString(), actual = token.type.toString(), - line = token.line, - column = token.column + line = token.startLine, + column = token.startColumn ) } @@ -182,75 +200,81 @@ class Parser { val secondToken = if (tokens.size > 1) tokens[1] else null when (firstToken.type) { TokenType.IF -> { - tokens.removeFirst() + val ifToken = expect(TokenType.IF, tokens) expect(TokenType.LEFT_PAREN, tokens) val condition = parseExpression(tokens = tokens) expect(TokenType.RIGHT_PAREN, tokens) val thenStatement = parseStatement(tokens) var elseStatement: Statement? = null + var endLine = thenStatement.location.endLine + var endCol = thenStatement.location.endCol if (tokens.firstOrNull()?.type == TokenType.ELSE) { tokens.removeFirst() elseStatement = parseStatement(tokens) + endLine = elseStatement.location.endLine + endCol = elseStatement.location.endCol } - return IfStatement(condition, thenStatement, elseStatement) + return IfStatement(condition, thenStatement, elseStatement, SourceLocation(ifToken.startLine, ifToken.startColumn, endLine, endCol)) } TokenType.KEYWORD_RETURN -> { - tokens.removeFirst() + val returnToken = expect(TokenType.KEYWORD_RETURN, tokens) val expression = parseExpression(tokens = tokens) - expect(TokenType.SEMICOLON, tokens) + val semicolonToken = expect(TokenType.SEMICOLON, tokens) return ReturnStatement( - expression = expression + expression = expression, + SourceLocation(returnToken.startLine, returnToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn) ) } TokenType.GOTO -> { - tokens.removeFirst() + val gotoToken = expect(TokenType.GOTO, tokens) val label = parseIdentifier(tokens) - expect(TokenType.SEMICOLON, tokens) - return GotoStatement(label) + val semicolonToken = expect(TokenType.SEMICOLON, tokens) + return GotoStatement(label, SourceLocation(gotoToken.startLine, gotoToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn)) } TokenType.IDENTIFIER -> { // Handle labeled statements: IDENTIFIER followed by COLON if (secondToken?.type == TokenType.COLON) { - val labelName = parseIdentifier(tokens) + val labelToken = expect(TokenType.IDENTIFIER, tokens) + val labelName = labelToken.lexeme expect(TokenType.COLON, tokens) val statement = parseStatement(tokens) - return LabeledStatement(labelName, statement) + return LabeledStatement(labelName, statement, SourceLocation(labelToken.startLine, labelToken.startColumn, statement.location.endLine, statement.location.endCol)) } else { // Not a label, parse as expression statement by delegating to default branch val expression = parseOptionalExpression(tokens = tokens, followedByType = TokenType.SEMICOLON) - return if (expression != null) ExpressionStatement(expression) else NullStatement() + return if (expression != null) ExpressionStatement(expression, expression.location) else NullStatement(SourceLocation(0, 0, 0, 0)) } } TokenType.KEYWORD_BREAK -> { - tokens.removeFirst() - expect(TokenType.SEMICOLON, tokens) - return BreakStatement() + val breakToken = expect(TokenType.KEYWORD_BREAK, tokens) + val semicolonToken = expect(TokenType.SEMICOLON, tokens) + return BreakStatement("", SourceLocation(breakToken.startLine, breakToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn)) } TokenType.KEYWORD_CONTINUE -> { - tokens.removeFirst() - expect(TokenType.SEMICOLON, tokens) - return ContinueStatement() + val continueToken = expect(TokenType.KEYWORD_CONTINUE, tokens) + val semicolonToken = expect(TokenType.SEMICOLON, tokens) + return ContinueStatement("", SourceLocation(continueToken.startLine, continueToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn)) } TokenType.KEYWORD_WHILE -> { - tokens.removeFirst() + val whileToken = expect(TokenType.KEYWORD_WHILE, tokens) expect(TokenType.LEFT_PAREN, tokens) val condition = parseExpression(tokens = tokens) expect(TokenType.RIGHT_PAREN, tokens) val body = parseStatement(tokens) - return WhileStatement(condition, body) + return WhileStatement(condition, body, "", SourceLocation(whileToken.startLine, whileToken.startColumn, body.location.endLine, body.location.endCol)) } TokenType.KEYWORD_DO -> { - tokens.removeFirst() + val doToken = expect(TokenType.KEYWORD_DO, tokens) val body = parseStatement(tokens) expect(TokenType.KEYWORD_WHILE, tokens) expect(TokenType.LEFT_PAREN, tokens) val condition = parseExpression(tokens = tokens) expect(TokenType.RIGHT_PAREN, tokens) - expect(TokenType.SEMICOLON, tokens) - return DoWhileStatement(condition, body) + val semicolonToken = expect(TokenType.SEMICOLON, tokens) + return DoWhileStatement(condition, body, "", SourceLocation(doToken.startLine, doToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn)) } TokenType.KEYWORD_FOR -> { - tokens.removeFirst() + val forToken = expect(TokenType.KEYWORD_FOR, tokens) expect(TokenType.LEFT_PAREN, tokens) val init = parseForInit(tokens) val condition = parseOptionalExpression(tokens = tokens, followedByType = TokenType.SEMICOLON) @@ -260,29 +284,31 @@ class Parser { init = init, condition = condition, post = post, - body = body + body = body, + label = "", + SourceLocation(forToken.startLine, forToken.startColumn, body.location.endLine, body.location.endCol) ) } TokenType.LEFT_BRACK -> { val body = parseBlock(tokens) - return CompoundStatement(body) + return CompoundStatement(body, body.location) } else -> { val expression = parseOptionalExpression(tokens = tokens, followedByType = TokenType.SEMICOLON) - return if (expression != null) ExpressionStatement(expression) else NullStatement() + return if (expression != null) ExpressionStatement(expression, expression.location) else NullStatement(SourceLocation(0, 0, 0, 0)) } } } private fun parseForInit(tokens: MutableList): ForInit { if (tokens.firstOrNull()?.type == TokenType.KEYWORD_INT) { - expect(TokenType.KEYWORD_INT, tokens) + val start = expect(TokenType.KEYWORD_INT, tokens) val name = parseIdentifier(tokens) - val declaration = parseVariableDeclaration(tokens, name) - return InitDeclaration(declaration) + val declaration = parseVariableDeclaration(tokens, name, SourceLocation(start.startLine, start.startColumn, start.endLine, start.endColumn)) + return InitDeclaration(declaration, SourceLocation(start.startLine, start.startColumn, declaration.location.endLine, declaration.location.endCol)) } val expression = parseOptionalExpression(tokens = tokens, followedByType = TokenType.SEMICOLON) - return InitExpression(expression) + return InitExpression(expression, expression?.location ?: SourceLocation(0, 0, 0, 0)) } private fun parseExpression( @@ -305,20 +331,21 @@ class Parser { throw InvalidLValueException() } val right = parseExpression(prec, tokens) - AssignmentExpression(left, right) + AssignmentExpression(left, right, SourceLocation(left.location.startLine, left.location.startCol, right.location.endLine, right.location.endCol)) } TokenType.QUESTION_MARK -> { val thenExpression = parseExpression(prec, tokens) expect(TokenType.COLON, tokens) val elseExpression = parseExpression(prec, tokens) - return ConditionalExpression(left, thenExpression, elseExpression) + return ConditionalExpression(left, thenExpression, elseExpression, SourceLocation(left.location.startLine, left.location.startCol, elseExpression.location.endLine, elseExpression.location.endCol)) } else -> { val right = parseExpression(prec + 1, tokens) BinaryExpression( left = left, operator = op, - right = right + right = right, + SourceLocation(left.location.startLine, left.location.startCol, right.location.endLine, right.location.endCol) ) } } @@ -345,31 +372,31 @@ class Parser { when (nextToken.type) { TokenType.INT_LITERAL -> { nextToken = tokens.removeFirst() - return IntExpression(value = nextToken.lexeme.toInt()) + return IntExpression(value = nextToken.lexeme.toInt(), SourceLocation(nextToken.startLine, nextToken.startColumn, nextToken.endLine, nextToken.endColumn)) } TokenType.IDENTIFIER -> { nextToken = tokens.removeFirst() if (tokens.firstOrNull()?.type == TokenType.LEFT_PAREN) { // function call - tokens.removeFirst() // consume '(' + val leftParen = tokens.removeFirst() // consume '(' val args = mutableListOf() if (tokens.firstOrNull()?.type != TokenType.RIGHT_PAREN) { do { args.add(parseExpression(0, tokens)) } while (tokens.firstOrNull()?.type == TokenType.COMMA && tokens.removeFirst().type == TokenType.COMMA) } - expect(TokenType.RIGHT_PAREN, tokens) - return FunctionCall(nextToken.lexeme, args) + val rightParen = expect(TokenType.RIGHT_PAREN, tokens) + return FunctionCall(nextToken.lexeme, args, SourceLocation(nextToken.startLine, nextToken.startColumn, rightParen.endLine, rightParen.endColumn)) } else { // It's a variable - return VariableExpression(nextToken.lexeme) + return VariableExpression(nextToken.lexeme, SourceLocation(nextToken.startLine, nextToken.startColumn, nextToken.endLine, nextToken.endColumn)) } } TokenType.TILDE, TokenType.NEGATION, TokenType.NOT -> { val operator = tokens.removeFirst() val factor = parseFactor(tokens) - return UnaryExpression(operator = operator, expression = factor) + return UnaryExpression(operator = operator, expression = factor, SourceLocation(operator.startLine, operator.startColumn, factor.location.endLine, factor.location.endCol)) } TokenType.LEFT_PAREN -> { expect(TokenType.LEFT_PAREN, tokens) @@ -383,8 +410,8 @@ class Parser { expected = "${TokenType.INT_LITERAL}, ${TokenType.IDENTIFIER}, unary operator, ${TokenType.LEFT_PAREN}", actual = nToken.type.toString(), - line = nToken.line, - column = nToken.column + line = nToken.startLine, + column = nToken.startColumn ) } } diff --git a/src/jsMain/kotlin/parser/Programs.kt b/src/jsMain/kotlin/parser/Programs.kt index 2d6f3d9..05694cc 100644 --- a/src/jsMain/kotlin/parser/Programs.kt +++ b/src/jsMain/kotlin/parser/Programs.kt @@ -1,9 +1,10 @@ package parser -sealed class Program : ASTNode() +sealed class Program(location: SourceLocation) : ASTNode(location) data class SimpleProgram( - val functionDeclaration: List -) : Program() { + val functionDeclaration: List, + override val location: SourceLocation +) : Program(location) { override fun accept(visitor: Visitor): T = visitor.visit(this) } diff --git a/src/jsMain/kotlin/semanticAnalysis/IdentifierResolution.kt b/src/jsMain/kotlin/semanticAnalysis/IdentifierResolution.kt index 3a83dee..a406266 100644 --- a/src/jsMain/kotlin/semanticAnalysis/IdentifierResolution.kt +++ b/src/jsMain/kotlin/semanticAnalysis/IdentifierResolution.kt @@ -100,17 +100,17 @@ class IdentifierResolution : Visitor { override fun visit(node: SimpleProgram): ASTNode { val newDecls = node.functionDeclaration.map { it.accept(this) as FunctionDeclaration } - return SimpleProgram(newDecls) + return SimpleProgram(newDecls, node.location) } override fun visit(node: ReturnStatement): ASTNode { val exp = node.expression.accept(this) as Expression - return ReturnStatement(exp) + return ReturnStatement(exp, node.location) } override fun visit(node: ExpressionStatement): ASTNode { val exp = node.expression.accept(this) as Expression - return ExpressionStatement(exp) + return ExpressionStatement(exp, node.location) } override fun visit(node: NullStatement): ASTNode = node @@ -122,13 +122,13 @@ class IdentifierResolution : Visitor { override fun visit(node: WhileStatement): ASTNode { val cond = node.condition.accept(this) as Expression val newBody = node.body.accept(this) as Statement - return WhileStatement(cond, newBody, node.label) + return WhileStatement(cond, newBody, node.label, node.location) } override fun visit(node: DoWhileStatement): ASTNode { val cond = node.condition.accept(this) as Expression val newBody = node.body.accept(this) as Statement - return DoWhileStatement(cond, newBody, node.label) + return DoWhileStatement(cond, newBody, node.label, node.location) } override fun visit(node: ForStatement): ASTNode { @@ -142,17 +142,17 @@ class IdentifierResolution : Visitor { leaveScope() - return ForStatement(newInit, newCond, newPost, newBody) + return ForStatement(newInit, newCond, newPost, newBody, node.label, node.location) } override fun visit(node: InitDeclaration): ASTNode { val newDecl = node.varDeclaration.accept(this) as VariableDeclaration - return InitDeclaration(newDecl) + return InitDeclaration(newDecl, node.location) } override fun visit(node: InitExpression): ASTNode { val newExp = node.expression?.accept(this) as Expression? - return InitExpression(newExp) + return InitExpression(newExp, node.location) } override fun visit(node: FunctionDeclaration): ASTNode { @@ -164,7 +164,7 @@ class IdentifierResolution : Visitor { throw NestedFunctionException() } else { declare(node.name, hasLinkage = true) - return FunctionDeclaration(node.name, node.params, null) + return FunctionDeclaration(node.name, node.params, null, node.location) } } else { declare(node.name, hasLinkage = true) @@ -179,24 +179,24 @@ class IdentifierResolution : Visitor { leaveScope() - return FunctionDeclaration(node.name, newParams, newBody) + return FunctionDeclaration(node.name, newParams, newBody, node.location) } } override fun visit(node: VariableExpression): ASTNode { val symbol = resolve(node.name) - return VariableExpression(symbol.uniqueName) + return VariableExpression(symbol.uniqueName, node.location) } override fun visit(node: UnaryExpression): ASTNode { val exp = node.expression.accept(this) as Expression - return UnaryExpression(node.operator, exp) + return UnaryExpression(node.operator, exp, node.location) } override fun visit(node: BinaryExpression): ASTNode { val left = node.left.accept(this) as Expression val right = node.right.accept(this) as Expression - return BinaryExpression(left, node.operator, right) + return BinaryExpression(left, node.operator, right, node.location) } override fun visit(node: IntExpression): ASTNode = node @@ -205,27 +205,27 @@ class IdentifierResolution : Visitor { val condition = node.condition.accept(this) as Expression val thenStatement = node.then.accept(this) as Statement val elseStatement = node._else?.accept(this) as Statement? - return IfStatement(condition, thenStatement, elseStatement) + return IfStatement(condition, thenStatement, elseStatement, node.location) } override fun visit(node: ConditionalExpression): ASTNode { val condition = node.codition.accept(this) as Expression val thenExpression = node.thenExpression.accept(this) as Expression val elseExpression = node.elseExpression.accept(this) as Expression - return ConditionalExpression(condition, thenExpression, elseExpression) + return ConditionalExpression(condition, thenExpression, elseExpression, node.location) } override fun visit(node: GotoStatement): ASTNode = node override fun visit(node: LabeledStatement): ASTNode { val statement = node.statement.accept(this) as Statement - return LabeledStatement(node.label, statement) + return LabeledStatement(node.label, statement, node.location) } override fun visit(node: AssignmentExpression): ASTNode { val lvalue = node.lvalue.accept(this) as VariableExpression val rvalue = node.rvalue.accept(this) as Expression - return AssignmentExpression(lvalue, rvalue) + return AssignmentExpression(lvalue, rvalue, node.location) } override fun visit(node: VariableDeclaration): ASTNode { @@ -233,7 +233,7 @@ class IdentifierResolution : Visitor { val uniqueName = declare(node.name, hasLinkage = false) - return VariableDeclaration(uniqueName, newInit) + return VariableDeclaration(uniqueName, newInit, node.location) } override fun visit(node: S): ASTNode { @@ -267,20 +267,20 @@ class IdentifierResolution : Visitor { override fun visit(node: Block): ASTNode { enterScope() - val newItems = node.block.map { it.accept(this) as BlockItem } + val newItems = node.items.map { it.accept(this) as BlockItem } leaveScope() - return Block(newItems) + return Block(newItems, node.location) } override fun visit(node: CompoundStatement): ASTNode { val newBlock = node.block.accept(this) as Block - return CompoundStatement(newBlock) + return CompoundStatement(newBlock, node.location) } override fun visit(node: FunctionCall): ASTNode { val symbol = resolve(node.name) val newArgs = node.arguments.map { it.accept(this) as Expression } // The unique name for a function is just its original name. - return FunctionCall(symbol.uniqueName, newArgs) + return FunctionCall(symbol.uniqueName, newArgs, node.location) } } diff --git a/src/jsMain/kotlin/semanticAnalysis/LabelCollector.kt b/src/jsMain/kotlin/semanticAnalysis/LabelCollector.kt index ba515b6..12c1a26 100644 --- a/src/jsMain/kotlin/semanticAnalysis/LabelCollector.kt +++ b/src/jsMain/kotlin/semanticAnalysis/LabelCollector.kt @@ -135,7 +135,7 @@ class LabelCollector : Visitor { override fun visit(node: GotoStatement) {} override fun visit(node: Block) { - node.block.forEach { it.accept(this) } + node.items.forEach { it.accept(this) } } override fun visit(node: CompoundStatement) { @@ -250,7 +250,7 @@ class LabelCollector : Visitor { } override fun visit(node: Block) { - node.block.forEach { it.accept(this) } + node.items.forEach { it.accept(this) } } override fun visit(node: CompoundStatement) { diff --git a/src/jsMain/kotlin/semanticAnalysis/LoopLabeling.kt b/src/jsMain/kotlin/semanticAnalysis/LoopLabeling.kt index 40d48da..bc1eb91 100644 --- a/src/jsMain/kotlin/semanticAnalysis/LoopLabeling.kt +++ b/src/jsMain/kotlin/semanticAnalysis/LoopLabeling.kt @@ -166,7 +166,7 @@ class LoopLabeling : Visitor { } override fun visit(node: Block) { - node.block.forEach { it.accept(this) } + node.items.forEach { it.accept(this) } } override fun visit(node: CompoundStatement) { diff --git a/src/jsMain/kotlin/semanticAnalysis/TypeChecker.kt b/src/jsMain/kotlin/semanticAnalysis/TypeChecker.kt index 13f0cda..a58c97e 100644 --- a/src/jsMain/kotlin/semanticAnalysis/TypeChecker.kt +++ b/src/jsMain/kotlin/semanticAnalysis/TypeChecker.kt @@ -172,7 +172,7 @@ class TypeChecker : Visitor { } override fun visit(node: Block) { - node.block.forEach { it.accept(this) } + node.items.forEach { it.accept(this) } } override fun visit(node: CompoundStatement) { diff --git a/src/jsMain/kotlin/tacky/Instructions.kt b/src/jsMain/kotlin/tacky/Instructions.kt index b3b9967..69da434 100644 --- a/src/jsMain/kotlin/tacky/Instructions.kt +++ b/src/jsMain/kotlin/tacky/Instructions.kt @@ -1,10 +1,20 @@ package tacky -sealed class TackyInstruction : TackyConstruct() +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +@Serializable +sealed class TackyInstruction() : TackyConstruct() { + abstract val sourceId: String +} + +@Serializable +@SerialName("TackyRet") data class TackyRet( - val value: TackyVal + val value: TackyVal, + override val sourceId: String = "" ) : TackyInstruction() { + override fun toPseudoCode(indentationLevel: Int): String = "${indent(indentationLevel)}return ${value.toPseudoCode()}" } @@ -16,10 +26,13 @@ enum class TackyUnaryOP( NOT("!") } +@Serializable +@SerialName("TackyUnary") data class TackyUnary( val operator: TackyUnaryOP, val src: TackyVal, - val dest: TackyVar + val dest: TackyVar, + override val sourceId: String = "" ) : TackyInstruction() { override fun toPseudoCode(indentationLevel: Int): String = "${indent(indentationLevel)}${dest.toPseudoCode()} = ${operator.text}${src.toPseudoCode()}" @@ -41,52 +54,79 @@ enum class TackyBinaryOP( NOT_EQUAL("!=") } +@Serializable +@SerialName("TackyBinary") data class TackyBinary( val operator: TackyBinaryOP, val src1: TackyVal, val src2: TackyVal, - val dest: TackyVar + val dest: TackyVar, + override val sourceId: String = "" ) : TackyInstruction() { override fun toPseudoCode(indentationLevel: Int): String = "${indent(indentationLevel)}${dest.toPseudoCode()} = ${src1.toPseudoCode()} ${operator.text} ${src2.toPseudoCode()}" } +@Serializable +@SerialName("TackyCopy") data class TackyCopy( val src: TackyVal, - val dest: TackyVar + val dest: TackyVar, + override val sourceId: String = "" ) : TackyInstruction() { override fun toPseudoCode(indentationLevel: Int): String = "${indent(indentationLevel)}${dest.toPseudoCode()} = ${src.toPseudoCode()}" } +@Serializable +@SerialName("TackyJump") data class TackyJump( - val target: TackyLabel + val target: TackyLabel, + override val sourceId: String = "" ) : TackyInstruction() { override fun toPseudoCode(indentationLevel: Int): String = "${indent(indentationLevel)}goto ${target.name}" } +@Serializable +@SerialName("JumpIfZero") data class JumpIfZero( val condition: TackyVal, - val target: TackyLabel + val target: TackyLabel, + override val sourceId: String = "" ) : TackyInstruction() { override fun toPseudoCode(indentationLevel: Int): String = "${indent(indentationLevel)}if (${condition.toPseudoCode()} == 0) goto ${target.name}" } +@Serializable +@SerialName("JumpIfNotZero") data class JumpIfNotZero( val condition: TackyVal, - val target: TackyLabel + val target: TackyLabel, + override val sourceId: String = "" ) : TackyInstruction() { override fun toPseudoCode(indentationLevel: Int): String = "${indent(indentationLevel)}if (${condition.toPseudoCode()} != 0) goto ${target.name}" } +@Serializable +@SerialName("TackyFunCall") data class TackyFunCall( val funName: String, val args: List, - val dest: TackyVar + val dest: TackyVar, + override val sourceId: String = "" ) : TackyInstruction() { override fun toPseudoCode(indentationLevel: Int): String { val argString = args.joinToString(", ") { it.toPseudoCode() } return "${indent(indentationLevel)}${dest.toPseudoCode()} = $funName($argString)" } } + +@Serializable +@SerialName("TackyLabel") +data class TackyLabel( + val name: String, + override val sourceId: String = "" +) : TackyInstruction() { + override fun toPseudoCode(indentationLevel: Int): String = "$name:" +} diff --git a/src/jsMain/kotlin/tacky/Tacky.kt b/src/jsMain/kotlin/tacky/Tacky.kt index f1a6699..2954c2f 100644 --- a/src/jsMain/kotlin/tacky/Tacky.kt +++ b/src/jsMain/kotlin/tacky/Tacky.kt @@ -1,41 +1,49 @@ package tacky -sealed class TackyConstruct { +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +sealed class TackyConstruct() { abstract fun toPseudoCode(indentationLevel: Int = 0): String protected fun indent(level: Int): String = " ".repeat(level) } -sealed class TackyVal : TackyConstruct() +@Serializable +sealed class TackyVal() : TackyConstruct() +@Serializable +@SerialName("TackyConstant") data class TackyConstant( val value: Int ) : TackyVal() { override fun toPseudoCode(indentationLevel: Int): String = value.toString() } +@Serializable +@SerialName("TackyVar") data class TackyVar( val name: String ) : TackyVal() { override fun toPseudoCode(indentationLevel: Int): String = name } -data class TackyLabel( - val name: String -) : TackyInstruction() { - override fun toPseudoCode(indentationLevel: Int): String = "$name:" -} - +@Serializable +@SerialName("TackyProgram") data class TackyProgram( val functions: List ) : TackyConstruct() { override fun toPseudoCode(indentationLevel: Int): String = functions.joinToString("\n\n") { it.toPseudoCode(indentationLevel) } } +@Serializable +@SerialName("TackyFunction") data class TackyFunction( val name: String, val args: List, - val body: List + val body: List, + val sourceId: String = "" ) : TackyConstruct() { override fun toPseudoCode(indentationLevel: Int): String { val paramString = args.joinToString(", ") diff --git a/src/jsMain/kotlin/tacky/TackyGenVisitor.kt b/src/jsMain/kotlin/tacky/TackyGenVisitor.kt index 4541a9b..151572f 100644 --- a/src/jsMain/kotlin/tacky/TackyGenVisitor.kt +++ b/src/jsMain/kotlin/tacky/TackyGenVisitor.kt @@ -39,7 +39,7 @@ class TackyGenVisitor : Visitor { private fun newTemporary(): TackyVar = TackyVar("tmp.${tempCounter++}") - private fun newLabel(base: String): TackyLabel = TackyLabel(".L_${base}_${labelCounter++}") + private fun newLabel(base: String, sourceId: String = ""): TackyLabel = TackyLabel(".L_${base}_${labelCounter++}", sourceId) private val currentInstructions = mutableListOf() @@ -98,7 +98,7 @@ class TackyGenVisitor : Visitor { override fun visit(node: ReturnStatement): TackyConstruct { val value = node.expression.accept(this) as TackyVal - val instr = TackyRet(value) + val instr = TackyRet(value, node.id) currentInstructions += instr return instr } @@ -111,56 +111,56 @@ class TackyGenVisitor : Visitor { override fun visit(node: NullStatement): TackyConstruct? = null override fun visit(node: BreakStatement): TackyConstruct? { - val breakLabel = TackyLabel("break_${node.label}") - currentInstructions += TackyJump(breakLabel) + val breakLabel = TackyLabel("break_${node.label}", node.id) + currentInstructions += TackyJump(breakLabel, node.id) return null } override fun visit(node: ContinueStatement): TackyConstruct? { - val continueLabel = TackyLabel("continue_${node.label}") - currentInstructions += TackyJump(continueLabel) + val continueLabel = TackyLabel("continue_${node.label}", node.id) + currentInstructions += TackyJump(continueLabel, node.id) return null } override fun visit(node: WhileStatement): TackyConstruct? { - val continueLabel = TackyLabel("continue_${node.label}") - val breakLabel = TackyLabel("break_${node.label}") + val continueLabel = TackyLabel("continue_${node.label}", node.id) + val breakLabel = TackyLabel("break_${node.label}", node.id) currentInstructions += continueLabel val condition = node.condition.accept(this) as TackyVal - currentInstructions += JumpIfZero(condition, breakLabel) + currentInstructions += JumpIfZero(condition, breakLabel, node.id) node.body.accept(this) - currentInstructions += TackyJump(continueLabel) + currentInstructions += TackyJump(continueLabel, node.id) currentInstructions += breakLabel return null } override fun visit(node: DoWhileStatement): TackyConstruct? { - val startLabel = TackyLabel("start_${node.label}") - val continueLabel = TackyLabel("continue_${node.label}") - val breakLabel = TackyLabel("break_${node.label}") + val startLabel = TackyLabel("start_${node.label}", node.id) + val continueLabel = TackyLabel("continue_${node.label}", node.id) + val breakLabel = TackyLabel("break_${node.label}", node.id) currentInstructions += startLabel node.body.accept(this) currentInstructions += continueLabel val condition = node.condition.accept(this) as TackyVal - currentInstructions += JumpIfNotZero(condition, startLabel) + currentInstructions += JumpIfNotZero(condition, startLabel, node.id) currentInstructions += breakLabel return null } override fun visit(node: ForStatement): TackyConstruct? { - val startLabel = TackyLabel("start_${node.label}") - val continueLabel = TackyLabel("continue_${node.label}") - val breakLabel = TackyLabel("break_${node.label}") + val startLabel = TackyLabel("start_${node.label}", node.id) + val continueLabel = TackyLabel("continue_${node.label}", node.id) + val breakLabel = TackyLabel("break_${node.label}", node.id) node.init.accept(this) currentInstructions += startLabel if (node.condition != null) { val condition = node.condition.accept(this) as TackyVal - currentInstructions += JumpIfZero(condition, breakLabel) + currentInstructions += JumpIfZero(condition, breakLabel, node.id) } node.body.accept(this) currentInstructions += continueLabel node.post?.accept(this) - currentInstructions += TackyJump(startLabel) + currentInstructions += TackyJump(startLabel, node.id) currentInstructions += breakLabel return null } @@ -183,9 +183,9 @@ class TackyGenVisitor : Visitor { node.body?.accept(this) if (currentInstructions.lastOrNull() !is TackyRet) { - currentInstructions += TackyRet(TackyConstant(0)) + currentInstructions += TackyRet(TackyConstant(0), node.id) } - return TackyFunction(functionName, functionParams, currentInstructions.toList()) + return TackyFunction(functionName, functionParams, currentInstructions.toList(), node.id) } override fun visit(node: VarDecl): TackyConstruct? { @@ -204,42 +204,42 @@ class TackyGenVisitor : Visitor { val src = node.expression.accept(this) as TackyVal val dst = newTemporary() val op = convertUnaryOp(node.operator.type) - currentInstructions += TackyUnary(op, src, dst) + currentInstructions += TackyUnary(op, src, dst, node.id) return dst } override fun visit(node: BinaryExpression): TackyConstruct { when (node.operator.type) { TokenType.AND -> { - val falseLabel = newLabel("and_false") - val endLabel = newLabel("and_end") + val falseLabel = newLabel("and_false", node.id) + val endLabel = newLabel("and_end", node.id) val resultVar = newTemporary() val left = node.left.accept(this) as TackyVal - currentInstructions += JumpIfZero(left, falseLabel) + currentInstructions += JumpIfZero(left, falseLabel, node.id) val right = node.right.accept(this) as TackyVal - currentInstructions += JumpIfZero(right, falseLabel) - currentInstructions += TackyCopy(TackyConstant(1), resultVar) - currentInstructions += TackyJump(endLabel) + currentInstructions += JumpIfZero(right, falseLabel, node.id) + currentInstructions += TackyCopy(TackyConstant(1), resultVar, node.id) + currentInstructions += TackyJump(endLabel, node.id) currentInstructions += falseLabel - currentInstructions += TackyCopy(TackyConstant(0), resultVar) + currentInstructions += TackyCopy(TackyConstant(0), resultVar, node.id) currentInstructions += endLabel return resultVar } TokenType.OR -> { - val trueLabel = newLabel("or_true") - val endLabel = newLabel("or_end") + val trueLabel = newLabel("or_true", node.id) + val endLabel = newLabel("or_end", node.id) val resultVar = newTemporary() val left = node.left.accept(this) as TackyVal - currentInstructions += JumpIfNotZero(left, trueLabel) + currentInstructions += JumpIfNotZero(left, trueLabel, node.id) val right = node.right.accept(this) as TackyVal - currentInstructions += JumpIfNotZero(right, trueLabel) - currentInstructions += TackyCopy(TackyConstant(0), resultVar) - currentInstructions += TackyJump(endLabel) + currentInstructions += JumpIfNotZero(right, trueLabel, node.id) + currentInstructions += TackyCopy(TackyConstant(0), resultVar, node.id) + currentInstructions += TackyJump(endLabel, node.id) currentInstructions += trueLabel - currentInstructions += TackyCopy(TackyConstant(1), resultVar) + currentInstructions += TackyCopy(TackyConstant(1), resultVar, node.id) currentInstructions += endLabel return resultVar @@ -250,7 +250,7 @@ class TackyGenVisitor : Visitor { val op = convertBinaryOp(node.operator.type) val dst = newTemporary() - currentInstructions += TackyBinary(operator = op, src1 = src1, src2 = src2, dest = dst) + currentInstructions += TackyBinary(operator = op, src1 = src1, src2 = src2, dest = dst, sourceId = node.id) return dst } } @@ -259,18 +259,18 @@ class TackyGenVisitor : Visitor { override fun visit(node: IntExpression): TackyConstruct = TackyConstant(node.value) override fun visit(node: IfStatement): TackyConstruct? { - val endLabel = newLabel("end") + val endLabel = newLabel("end", node.id) val condition = node.condition.accept(this) as TackyVal if (node._else == null) { - currentInstructions += JumpIfZero(condition, endLabel) + currentInstructions += JumpIfZero(condition, endLabel, node.id) node.then.accept(this) currentInstructions += endLabel } else { - val elseLabel = newLabel("else_label") - currentInstructions += JumpIfZero(condition, elseLabel) + val elseLabel = newLabel("else_label", node.id) + currentInstructions += JumpIfZero(condition, elseLabel, node.id) node.then.accept(this) - currentInstructions += TackyJump(endLabel) + currentInstructions += TackyJump(endLabel, node.id) currentInstructions += elseLabel node._else.accept(this) currentInstructions += endLabel @@ -281,31 +281,31 @@ class TackyGenVisitor : Visitor { override fun visit(node: ConditionalExpression): TackyConstruct? { val resultVar = newTemporary() - val elseLabel = newLabel("cond_else") - val endLabel = newLabel("cond_end") + val elseLabel = newLabel("cond_else", node.id) + val endLabel = newLabel("cond_end", node.id) val conditionResult = node.codition.accept(this) as TackyVal - currentInstructions += JumpIfZero(conditionResult, elseLabel) + currentInstructions += JumpIfZero(conditionResult, elseLabel, node.id) val thenResult = node.thenExpression.accept(this) as TackyVal - currentInstructions += TackyCopy(thenResult, resultVar) - currentInstructions += TackyJump(endLabel) + currentInstructions += TackyCopy(thenResult, resultVar, node.id) + currentInstructions += TackyJump(endLabel, node.id) currentInstructions += elseLabel val elseResult = node.elseExpression.accept(this) as TackyVal - currentInstructions += TackyCopy(elseResult, resultVar) + currentInstructions += TackyCopy(elseResult, resultVar, node.id) currentInstructions += endLabel return resultVar } override fun visit(node: GotoStatement): TackyConstruct? { - currentInstructions += TackyJump(TackyLabel(node.label)) + currentInstructions += TackyJump(TackyLabel(node.label, node.id), node.id) return null } override fun visit(node: LabeledStatement): TackyConstruct? { - // val label = newLabel(node.label) - currentInstructions += TackyLabel(node.label) + val label = node.label + currentInstructions += TackyLabel(node.label, node.id) node.statement.accept(this) return null } @@ -313,7 +313,7 @@ class TackyGenVisitor : Visitor { override fun visit(node: AssignmentExpression): TackyConstruct { val rvalue = node.rvalue.accept(this) as TackyVal val dest = TackyVar(node.lvalue.name) - currentInstructions += TackyCopy(rvalue, dest) + currentInstructions += TackyCopy(rvalue, dest, node.id) return dest } @@ -321,7 +321,7 @@ class TackyGenVisitor : Visitor { if (node.init != null) { val initVal = node.init.accept(this) as TackyVal // The `node.name` is already the unique name from IdentifierResolution - currentInstructions += TackyCopy(initVal, TackyVar(node.name)) + currentInstructions += TackyCopy(initVal, TackyVar(node.name), node.id) } return null } @@ -337,7 +337,7 @@ class TackyGenVisitor : Visitor { } override fun visit(node: Block): TackyConstruct? { - node.block.forEach { it.accept(this) } + node.items.forEach { it.accept(this) } return null } @@ -349,7 +349,7 @@ class TackyGenVisitor : Visitor { override fun visit(node: FunctionCall): TackyConstruct? { val args = node.arguments.map { it.accept(this) as TackyVal } val dest = newTemporary() - currentInstructions += TackyFunCall(node.name, args, dest) + currentInstructions += TackyFunCall(node.name, args, dest, node.id) return dest } } diff --git a/src/jsMain/kotlin/tacky/TackyToAsm.kt b/src/jsMain/kotlin/tacky/TackyToAsm.kt index 7cdb5d8..8871554 100644 --- a/src/jsMain/kotlin/tacky/TackyToAsm.kt +++ b/src/jsMain/kotlin/tacky/TackyToAsm.kt @@ -36,13 +36,13 @@ class TackyToAsm { } private fun convertFunction(tackyFunc: TackyFunction): AsmFunction { - val paramSetupInstructions = generateParamSetup(tackyFunc.args) + val paramSetupInstructions = generateParamSetup(tackyFunc.args, tackyFunc.sourceId) // Convert the rest of the body as before val bodyInstructions = tackyFunc.body.flatMap { convertInstruction(it) } return AsmFunction(tackyFunc.name, paramSetupInstructions + bodyInstructions) } - private fun generateParamSetup(params: List): List { + private fun generateParamSetup(params: List, sourceId: String): List { val instructions = mutableListOf() val argRegisters = listOf( @@ -58,12 +58,12 @@ class TackyToAsm { if (index < argRegisters.size) { val srcReg = Register(argRegisters[index]) val destPseudo = Pseudo(paramName) - instructions.add(Mov(srcReg, destPseudo)) + instructions.add(Mov(srcReg, destPseudo, sourceId)) } else { val stackOffset = 16 + (index - argRegisters.size) * 8 val srcStack = Stack(stackOffset) val destPseudo = Pseudo(paramName) - instructions.add(Mov(srcStack, destPseudo)) + instructions.add(Mov(srcStack, destPseudo, sourceId)) } } return instructions @@ -73,8 +73,8 @@ class TackyToAsm { when (tackyInstr) { is TackyRet -> { listOf( - Mov(convertVal(tackyInstr.value), Register(HardwareRegister.EAX)), - Ret + Mov(convertVal(tackyInstr.value), Register(HardwareRegister.EAX), tackyInstr.sourceId), + Ret(tackyInstr.sourceId) ) } is TackyUnary -> @@ -83,16 +83,16 @@ class TackyToAsm { val src = convertVal(tackyInstr.src) val dest = convertVal(tackyInstr.dest) listOf( - Cmp(Imm(0), src), - Mov(Imm(0), dest), // Zero out destination - SetCC(ConditionCode.E, dest) // Set if equal to zero + Cmp(Imm(0), src, tackyInstr.sourceId), + Mov(Imm(0), dest, tackyInstr.sourceId), // Zero out destination + SetCC(ConditionCode.E, dest, tackyInstr.sourceId) // Set if equal to zero ) } else -> { val destOperand = convertVal(tackyInstr.dest) listOf( - Mov(convertVal(tackyInstr.src), destOperand), - AsmUnary(convertOp(tackyInstr.operator), destOperand) + Mov(convertVal(tackyInstr.src), destOperand, tackyInstr.sourceId), + AsmUnary(convertOp(tackyInstr.operator), destOperand, tackyInstr.sourceId) ) } } @@ -105,32 +105,32 @@ class TackyToAsm { when (tackyInstr.operator) { TackyBinaryOP.ADD -> listOf( - Mov(src1, dest), - AsmBinary(AsmBinaryOp.ADD, src2, dest) + Mov(src1, dest, tackyInstr.sourceId), + AsmBinary(AsmBinaryOp.ADD, src2, dest, tackyInstr.sourceId) ) TackyBinaryOP.MULTIPLY -> listOf( - Mov(src1, dest), - AsmBinary(AsmBinaryOp.MUL, src2, dest) + Mov(src1, dest, tackyInstr.sourceId), + AsmBinary(AsmBinaryOp.MUL, src2, dest, tackyInstr.sourceId) ) TackyBinaryOP.SUBTRACT -> listOf( - Mov(src1, dest), - AsmBinary(AsmBinaryOp.SUB, src2, dest) + Mov(src1, dest, tackyInstr.sourceId), + AsmBinary(AsmBinaryOp.SUB, src2, dest, tackyInstr.sourceId) ) TackyBinaryOP.DIVIDE -> listOf( - Mov(src1, Register(HardwareRegister.EAX)), // Dividend in EAX - Cdq, - Idiv(src2), // Divisor - Mov(Register(HardwareRegister.EAX), dest) // Quotient result is in EAX + Mov(src1, Register(HardwareRegister.EAX), tackyInstr.sourceId), // Dividend in EAX + Cdq(tackyInstr.sourceId), + Idiv(src2, tackyInstr.sourceId), // Divisor + Mov(Register(HardwareRegister.EAX), dest, tackyInstr.sourceId) // Quotient result is in EAX ) TackyBinaryOP.REMAINDER -> listOf( - Mov(src1, Register(HardwareRegister.EAX)), - Cdq, - Idiv(src2), - Mov(Register(HardwareRegister.EDX), dest) // Remainder result is in EDX + Mov(src1, Register(HardwareRegister.EAX), tackyInstr.sourceId), + Cdq(tackyInstr.sourceId), + Idiv(src2, tackyInstr.sourceId), + Mov(Register(HardwareRegister.EDX), dest, tackyInstr.sourceId) // Remainder result is in EDX ) TackyBinaryOP.EQUAL, TackyBinaryOP.NOT_EQUAL, TackyBinaryOP.GREATER, @@ -146,31 +146,35 @@ class TackyToAsm { else -> throw IllegalStateException("Unreachable: This case is logically impossible.") } listOf( - Cmp(src2, src1), - Mov(Imm(0), dest), // Zero out destination - SetCC(condition, dest) // conditionally set the low byte to 0/1 + Cmp(src2, src1, tackyInstr.sourceId), + Mov(Imm(0), dest, tackyInstr.sourceId), // Zero out destination + SetCC(condition, dest, tackyInstr.sourceId) // conditionally set the low byte to 0/1 ) } } } - is JumpIfNotZero -> + is JumpIfZero -> { + val condition = convertVal(tackyInstr.condition) + val target = Label(tackyInstr.target.name, tackyInstr.sourceId) listOf( - Cmp(Imm(0), convertVal(tackyInstr.condition)), - JmpCC(ConditionCode.NE, Label(tackyInstr.target.name)) + Cmp(Imm(0), condition, tackyInstr.sourceId), + JmpCC(ConditionCode.E, target, tackyInstr.sourceId) ) - - is JumpIfZero -> + } + is JumpIfNotZero -> { + val condition = convertVal(tackyInstr.condition) + val target = Label(tackyInstr.target.name, tackyInstr.sourceId) listOf( - Cmp(Imm(0), convertVal(tackyInstr.condition)), - JmpCC(ConditionCode.E, Label(tackyInstr.target.name)) + Cmp(Imm(0), condition, tackyInstr.sourceId), + JmpCC(ConditionCode.NE, target, tackyInstr.sourceId) ) + } + is TackyCopy -> listOf(Mov(convertVal(tackyInstr.src), convertVal(tackyInstr.dest), tackyInstr.sourceId)) - is TackyCopy -> listOf(Mov(convertVal(tackyInstr.src), convertVal(tackyInstr.dest))) - - is TackyJump -> listOf(Jmp(Label(tackyInstr.target.name))) + is TackyJump -> listOf(Jmp(Label(tackyInstr.target.name, tackyInstr.sourceId), tackyInstr.sourceId)) - is TackyLabel -> listOf(Label(tackyInstr.name)) + is TackyLabel -> listOf(Label(tackyInstr.name, "")) is TackyFunCall -> { val instructions = mutableListOf() @@ -190,36 +194,36 @@ class TackyToAsm { // Adjust stack alignment val stackPadding = if (stackArgs.size % 2 != 0) 8 else 0 if (stackPadding > 0) { - instructions.add(AllocateStack(stackPadding)) + instructions.add(AllocateStack(stackPadding, tackyInstr.sourceId)) } // Pass arguments on the stack in reverse order stackArgs.asReversed().forEach { arg -> val asmArg = convertVal(arg) if (asmArg is Stack) { - instructions.add(Mov(asmArg, Register(HardwareRegister.EAX))) - instructions.add(Push(Register(HardwareRegister.EAX))) + instructions.add(Mov(asmArg, Register(HardwareRegister.EAX), tackyInstr.sourceId)) + instructions.add(Push(Register(HardwareRegister.EAX), tackyInstr.sourceId)) } else { - instructions.add(Push(asmArg)) + instructions.add(Push(asmArg, tackyInstr.sourceId)) } } // Pass arguments in registers registerArgs.forEachIndexed { index, arg -> val asmArg = convertVal(arg) - instructions.add(Mov(asmArg, Register(argRegisters[index]))) + instructions.add(Mov(asmArg, Register(argRegisters[index]), tackyInstr.sourceId)) } - instructions.add(Call(tackyInstr.funName)) + instructions.add(Call(tackyInstr.funName, tackyInstr.sourceId)) // Clean up stack val bytesToRemove = stackArgs.size * 8 + stackPadding if (bytesToRemove > 0) { - instructions.add(DeAllocateStack(bytesToRemove)) + instructions.add(DeAllocateStack(bytesToRemove, tackyInstr.sourceId)) } // Retrieve return value - instructions.add(Mov(Register(HardwareRegister.EAX), convertVal(tackyInstr.dest))) + instructions.add(Mov(Register(HardwareRegister.EAX), convertVal(tackyInstr.dest), tackyInstr.sourceId)) instructions } diff --git a/src/jsTest/kotlin/export/ASTExportTest.kt b/src/jsTest/kotlin/export/ASTExportTest.kt new file mode 100644 index 0000000..7c488d9 --- /dev/null +++ b/src/jsTest/kotlin/export/ASTExportTest.kt @@ -0,0 +1,74 @@ +package export + +import kotlinx.serialization.json.JsonObject +import parser.IntExpression +import parser.SourceLocation +import parser.VariableExpression +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class ASTExportTest { + + @Test + fun testASTExportIncludesLocationAndId() { + // Create a simple test AST node with location and ID + val location = SourceLocation(1, 1, 1, 5) + val intExpr = IntExpression(42, location) + + // Export the AST + val export = ASTExport() + val jsonResult = intExpr.accept(export) + + // Use the JSON result directly (it's already a JsonObject) + val json = jsonResult + + // Verify that location and ID are included + assertTrue(json.containsKey("location"), "JSON should contain location information") + assertTrue(json.containsKey("id"), "JSON should contain ID information") + + // Check location details + val loc = json["location"] as JsonObject + assertEquals(1, loc["startLine"]?.toString()?.toInt(), "Start line should be 1") + assertEquals(1, loc["startCol"]?.toString()?.toInt(), "Start column should be 1") + assertEquals(1, loc["endLine"]?.toString()?.toInt(), "End line should be 1") + assertEquals(5, loc["endCol"]?.toString()?.toInt(), "End column should be 5") + + // Check that ID is not empty + val id = json["id"]?.toString()?.removeSurrounding("\"") + assertTrue(!id.isNullOrEmpty(), "ID should not be empty") + + println("Test passed! Location and ID are included in AST export:") + println("Location: startLine=1, startCol=1, endLine=1, endCol=5") + println("ID: $id") + } + + @Test + fun testASTExportForVariableExpression() { + // Create a variable expression with location + val location = SourceLocation(2, 5, 2, 8) + val varExpr = VariableExpression("test", location) + + // Export the AST + val export = ASTExport() + val jsonResult = varExpr.accept(export) + + // Use the JSON result directly (it's already a JsonObject) + val json = jsonResult + + // Verify basic structure + assertEquals("Expression", json["type"]?.toString()?.removeSurrounding("\"")) + assertEquals("Variable(test)", json["label"]?.toString()?.removeSurrounding("\"")) + + // Verify location and ID are included + assertTrue(json.containsKey("location"), "Variable expression should have location") + assertTrue(json.containsKey("id"), "Variable expression should have ID") + + // Check location details + val loc = json["location"] as JsonObject + assertEquals(2, loc["startLine"]?.toString()?.toInt(), "Start line should be 2") + assertEquals(5, loc["startCol"]?.toString()?.toInt(), "Start column should be 5") + assertEquals(2, loc["endLine"]?.toString()?.toInt(), "End line should be 2") + assertEquals(8, loc["endCol"]?.toString()?.toInt(), "End column should be 8") + } +} diff --git a/src/jsTest/kotlin/export/CompilationOutputTest.kt b/src/jsTest/kotlin/export/CompilationOutputTest.kt new file mode 100644 index 0000000..77f7b0b --- /dev/null +++ b/src/jsTest/kotlin/export/CompilationOutputTest.kt @@ -0,0 +1,112 @@ +package export + +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class CompilationOutputTest { + + @Test + fun testSourceLocationInfoStructure() { + val sourceLocation = SourceLocationInfo( + startLine = 1, + startColumn = 1, + endLine = 5, + endColumn = 10, + totalLines = 5 + ) + + // Test that all fields are accessible + assertEquals(1, sourceLocation.startLine) + assertEquals(1, sourceLocation.startColumn) + assertEquals(5, sourceLocation.endLine) + assertEquals(10, sourceLocation.endColumn) + assertEquals(5, sourceLocation.totalLines) + } + + @Test + fun testLexerOutputIncludesSourceLocation() { + val sourceLocation = SourceLocationInfo(1, 1, 3, 15, 3) + val lexerOutput = LexerOutput( + tokens = "[]", + errors = emptyArray(), + sourceLocation = sourceLocation + ) + + assertEquals("lexer", lexerOutput.stage) + assertEquals("[]", lexerOutput.tokens) + assertTrue(lexerOutput.errors.isEmpty()) + assertEquals(sourceLocation, lexerOutput.sourceLocation) + } + + @Test + fun testParserOutputIncludesSourceLocation() { + val sourceLocation = SourceLocationInfo(1, 1, 3, 15, 3) + val parserOutput = ParserOutput( + ast = "{}", + errors = emptyArray(), + sourceLocation = sourceLocation + ) + + assertEquals("parser", parserOutput.stage) + assertEquals("{}", parserOutput.ast) + assertTrue(parserOutput.errors.isEmpty()) + assertEquals(sourceLocation, parserOutput.sourceLocation) + } + + @Test + fun testTackyOutputIncludesSourceLocation() { + val sourceLocation = SourceLocationInfo(1, 1, 3, 15, 3) + val tackyOutput = TackyOutput( + tackyPretty = "tacky code", + errors = emptyArray(), + sourceLocation = sourceLocation + ) + + assertEquals("tacky", tackyOutput.stage) + assertEquals("tacky code", tackyOutput.tackyPretty) + assertTrue(tackyOutput.errors.isEmpty()) + assertEquals(sourceLocation, tackyOutput.sourceLocation) + } + + @Test + fun testAssemblyOutputIncludesSourceLocation() { + val sourceLocation = SourceLocationInfo(1, 1, 3, 15, 3) + val assemblyOutput = AssemblyOutput( + assembly = "mov eax, 1", + errors = emptyArray(), + sourceLocation = sourceLocation + ) + + assertEquals("assembly", assemblyOutput.stage) + assertEquals("mov eax, 1", assemblyOutput.assembly) + assertTrue(assemblyOutput.errors.isEmpty()) + assertEquals(sourceLocation, assemblyOutput.sourceLocation) + } + + @Test + fun testSourceLocationInfoSerialization() { + val sourceLocation = SourceLocationInfo(1, 1, 5, 20, 5) + val lexerOutput = LexerOutput( + tokens = "[]", + errors = emptyArray(), + sourceLocation = sourceLocation + ) + + // Test that the output can be serialized to JSON + val result = CompilationResult( + outputs = arrayOf(lexerOutput), + overallSuccess = true, + overallErrors = emptyArray() + ) + + val jsonString = result.toJsonString() + assertTrue(jsonString.isNotEmpty(), "JSON string should not be empty") + + // Parse back to verify structure + val json = Json.parseToJsonElement(jsonString) as JsonObject + assertTrue(json.containsKey("outputs"), "JSON should contain outputs") + } +} diff --git a/src/jsTest/kotlin/export/CompilerExportTest.kt b/src/jsTest/kotlin/export/CompilerExportTest.kt index ce58c17..b94689e 100644 --- a/src/jsTest/kotlin/export/CompilerExportTest.kt +++ b/src/jsTest/kotlin/export/CompilerExportTest.kt @@ -1,53 +1,82 @@ package export +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonObject +import lexer.Token +import lexer.TokenType import kotlin.test.Test -import kotlin.test.assertNotNull +import kotlin.test.assertEquals import kotlin.test.assertTrue class CompilerExportTest { - private val compilerExport = CompilerExport() @Test - fun `test successful compilation`() { - val code = - """ - int main(void) { - return 5; - } - """.trimIndent() - - val result = compilerExport.exportCompilationResults(code) - assertNotNull(result) - assertTrue(result.isNotEmpty()) - - // Verify it's valid JSON - assertTrue(result.contains("outputs")) - assertTrue(result.contains("overallSuccess")) - assertTrue(result.contains("overallErrors")) - } + fun testTokenExportIncludesCompleteLocationInfo() { + // Create a test token with complete location information + val token = Token( + type = TokenType.IDENTIFIER, + lexeme = "testVar", + startLine = 2, + startColumn = 5, + endLine = 2, + endColumn = 12 + ) - @Test - fun `test compilation with syntax error`() { - val code = - """ - int main(void) { - return 5 - } - """.trimIndent() - // Missing semicolon - - val result = compilerExport.exportCompilationResults(code) - assertNotNull(result) - assertTrue(result.isNotEmpty()) - - // Should contain error information - assertTrue(result.contains("overallErrors")) + // Export the token + val jsonResult = listOf(token).toJsonString() + + // Parse the JSON result + val jsonArray = Json.parseToJsonElement(jsonResult) as JsonArray + val tokenJson = jsonArray[0] as JsonObject + + // Verify basic token information + assertEquals("IDENTIFIER", tokenJson["type"]?.toString()?.removeSurrounding("\"")) + assertEquals("testVar", tokenJson["lexeme"]?.toString()?.removeSurrounding("\"")) + + // Verify that complete location information is included + assertTrue(tokenJson.containsKey("location"), "Token JSON should contain location information") + + val location = tokenJson["location"] as JsonObject + assertEquals(2, location["startLine"]?.toString()?.toInt(), "Start line should be 2") + assertEquals(5, location["startCol"]?.toString()?.toInt(), "Start column should be 5") + assertEquals(2, location["endLine"]?.toString()?.toInt(), "End line should be 2") + assertEquals(12, location["endCol"]?.toString()?.toInt(), "End column should be 12") + + println("Test passed! Token export includes complete location information:") + println("Location: startLine=2, startCol=5, endLine=2, endCol=12") } @Test - fun `test empty code`() { - val result = compilerExport.exportCompilationResults("") - assertNotNull(result) - assertTrue(result.isNotEmpty()) + fun testTokenExportStructure() { + // Create multiple tokens to test array structure + val tokens = listOf( + Token(TokenType.KEYWORD_INT, "int", 1, 1, 1, 3), + Token(TokenType.IDENTIFIER, "main", 1, 5, 1, 8), + Token(TokenType.LEFT_PAREN, "(", 1, 9, 1, 9) + ) + + // Export the tokens + val jsonResult = tokens.toJsonString() + + // Parse the JSON result + val jsonArray = Json.parseToJsonElement(jsonResult) as JsonArray + + // Verify array structure + assertEquals(3, jsonArray.size, "Should have 3 tokens") + + // Verify each token has the expected structure + for (i in 0 until jsonArray.size) { + val tokenJson = jsonArray[i] as JsonObject + assertTrue(tokenJson.containsKey("type"), "Token $i should have type") + assertTrue(tokenJson.containsKey("lexeme"), "Token $i should have lexeme") + assertTrue(tokenJson.containsKey("location"), "Token $i should have location") + + val location = tokenJson["location"] as JsonObject + assertTrue(location.containsKey("startLine"), "Token $i location should have startLine") + assertTrue(location.containsKey("startCol"), "Token $i location should have startCol") + assertTrue(location.containsKey("endLine"), "Token $i location should have endLine") + assertTrue(location.containsKey("endCol"), "Token $i location should have endCol") + } } } diff --git a/src/jsTest/kotlin/integration/CompilerTestSuite.kt b/src/jsTest/kotlin/integration/CompilerTestSuite.kt index e5bade6..bd3c3b3 100644 --- a/src/jsTest/kotlin/integration/CompilerTestSuite.kt +++ b/src/jsTest/kotlin/integration/CompilerTestSuite.kt @@ -9,8 +9,440 @@ import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertIs +import kotlin.test.assertTrue class CompilerTestSuite { + + private fun assertAstEquals(expected: SimpleProgram, actual: SimpleProgram, message: String) { + // Compare function declarations + assertEquals(expected.functionDeclaration.size, actual.functionDeclaration.size, "$message - Function count mismatch") + + expected.functionDeclaration.forEachIndexed { index, expectedFunc -> + val actualFunc = actual.functionDeclaration[index] + assertEquals(expectedFunc.name, actualFunc.name, "$message - Function name mismatch at index $index") + assertEquals(expectedFunc.params, actualFunc.params, "$message - Function params mismatch at index $index") + + // Compare function bodies (ignoring SourceLocation) + if (expectedFunc.body != null && actualFunc.body != null) { + assertBlockEquals(expectedFunc.body, actualFunc.body, "$message - Function body mismatch at index $index") + } else { + assertEquals(expectedFunc.body, actualFunc.body, "$message - Function body null mismatch at index $index") + } + } + } + + private fun assertBlockEquals(expected: parser.Block, actual: parser.Block, message: String) { + assertEquals(expected.items.size, actual.items.size, "$message - Block item count mismatch") + + expected.items.forEachIndexed { index, expectedItem -> + val actualItem = actual.items[index] + when (expectedItem) { + is parser.S -> { + assertIs(actualItem, "$message - Expected S but got ${actualItem::class.simpleName} at index $index") + assertStatementEquals(expectedItem.statement, actualItem.statement, "$message - Statement mismatch at index $index") + } + is parser.D -> { + assertIs(actualItem, "$message - Expected D but got ${actualItem::class.simpleName} at index $index") + assertDeclarationEquals(expectedItem.declaration, actualItem.declaration, "$message - Declaration mismatch at index $index") + } + else -> assertEquals(expectedItem, actualItem, "$message - Block item mismatch at index $index") + } + } + } + + private fun assertStatementEquals(expected: parser.Statement, actual: parser.Statement, message: String) { + when (expected) { + is parser.ReturnStatement -> { + assertIs(actual, "$message - Expected ReturnStatement but got ${actual::class.simpleName}") + assertExpressionEquals(expected.expression, actual.expression, "$message - Return expression mismatch") + } + is parser.ExpressionStatement -> { + assertIs(actual, "$message - Expected ExpressionStatement but got ${actual::class.simpleName}") + assertExpressionEquals(expected.expression, actual.expression, "$message - Expression mismatch") + } + is parser.NullStatement -> { + assertIs(actual, "$message - Expected NullStatement but got ${actual::class.simpleName}") + } + is parser.WhileStatement -> { + assertIs(actual, "$message - Expected WhileStatement but got ${actual::class.simpleName}") + assertExpressionEquals(expected.condition, actual.condition, "$message - While condition mismatch") + assertStatementEquals(expected.body, actual.body, "$message - While body mismatch") + } + is parser.ForStatement -> { + assertIs(actual, "$message - Expected ForStatement but got ${actual::class.simpleName}") + assertForInitEquals(expected.init, actual.init, "$message - For init mismatch") + if (expected.condition != null && actual.condition != null) { + assertExpressionEquals(expected.condition, actual.condition, "$message - For condition mismatch") + } else { + assertEquals(expected.condition, actual.condition, "$message - For condition null mismatch") + } + if (expected.post != null && actual.post != null) { + assertExpressionEquals(expected.post, actual.post, "$message - For post mismatch") + } else { + assertEquals(expected.post, actual.post, "$message - For post null mismatch") + } + assertStatementEquals(expected.body, actual.body, "$message - For body mismatch") + } + is parser.DoWhileStatement -> { + assertIs(actual, "$message - Expected DoWhileStatement but got ${actual::class.simpleName}") + assertStatementEquals(expected.body, actual.body, "$message - DoWhile body mismatch") + assertExpressionEquals(expected.condition, actual.condition, "$message - DoWhile condition mismatch") + } + is parser.IfStatement -> { + assertIs(actual, "$message - Expected IfStatement but got ${actual::class.simpleName}") + assertExpressionEquals(expected.condition, actual.condition, "$message - If condition mismatch") + assertStatementEquals(expected.then, actual.then, "$message - If then mismatch") + if (expected._else != null && actual._else != null) { + assertStatementEquals(expected._else, actual._else, "$message - If else mismatch") + } else { + assertEquals(expected._else, actual._else, "$message - If else null mismatch") + } + } + is parser.BreakStatement -> { + assertIs(actual, "$message - Expected BreakStatement but got ${actual::class.simpleName}") + assertEquals(expected.label, actual.label, "$message - Break label mismatch") + } + is parser.ContinueStatement -> { + assertIs(actual, "$message - Expected ContinueStatement but got ${actual::class.simpleName}") + assertEquals(expected.label, actual.label, "$message - Continue label mismatch") + } + is parser.GotoStatement -> { + assertIs(actual, "$message - Expected GotoStatement but got ${actual::class.simpleName}") + assertEquals(expected.label, actual.label, "$message - Goto label mismatch") + } + is parser.LabeledStatement -> { + assertIs(actual, "$message - Expected LabeledStatement but got ${actual::class.simpleName}") + assertEquals(expected.label, actual.label, "$message - LabeledStatement label mismatch") + assertStatementEquals(expected.statement, actual.statement, "$message - LabeledStatement statement mismatch") + } + is parser.CompoundStatement -> { + assertIs(actual, "$message - Expected CompoundStatement but got ${actual::class.simpleName}") + assertBlockEquals(expected.block, actual.block, "$message - CompoundStatement block mismatch") + } + else -> assertEquals(expected, actual, "$message - Statement type mismatch") + } + } + + private fun assertExpressionEquals(expected: parser.Expression, actual: parser.Expression, message: String) { + when (expected) { + is parser.IntExpression -> { + assertIs(actual, "$message - Expected IntExpression but got ${actual::class.simpleName}") + assertEquals(expected.value, actual.value, "$message - Int value mismatch") + } + is parser.VariableExpression -> { + assertIs(actual, "$message - Expected VariableExpression but got ${actual::class.simpleName}") + assertEquals(expected.name, actual.name, "$message - Variable name mismatch") + } + is parser.BinaryExpression -> { + assertIs(actual, "$message - Expected BinaryExpression but got ${actual::class.simpleName}") + assertExpressionEquals(expected.left, actual.left, "$message - Binary left mismatch") + assertEquals(expected.operator, actual.operator, "$message - Binary operator mismatch") + assertExpressionEquals(expected.right, actual.right, "$message - Binary right mismatch") + } + is parser.UnaryExpression -> { + assertIs(actual, "$message - Expected UnaryExpression but got ${actual::class.simpleName}") + assertEquals(expected.operator, actual.operator, "$message - Unary operator mismatch") + assertExpressionEquals(expected.expression, actual.expression, "$message - Unary expression mismatch") + } + is parser.AssignmentExpression -> { + assertIs(actual, "$message - Expected AssignmentExpression but got ${actual::class.simpleName}") + assertExpressionEquals(expected.lvalue, actual.lvalue, "$message - Assignment lvalue mismatch") + assertExpressionEquals(expected.rvalue, actual.rvalue, "$message - Assignment rvalue mismatch") + } + is parser.FunctionCall -> { + assertIs(actual, "$message - Expected FunctionCall but got ${actual::class.simpleName}") + assertEquals(expected.name, actual.name, "$message - Function call name mismatch") + assertEquals(expected.arguments.size, actual.arguments.size, "$message - Function call argument count mismatch") + expected.arguments.forEachIndexed { index, expectedArg -> + assertExpressionEquals(expectedArg, actual.arguments[index], "$message - Function call argument $index mismatch") + } + } + is parser.ConditionalExpression -> { + assertIs(actual, "$message - Expected ConditionalExpression but got ${actual::class.simpleName}") + assertExpressionEquals(expected.codition, actual.codition, "$message - Conditional condition mismatch") + assertExpressionEquals(expected.thenExpression, actual.thenExpression, "$message - Conditional then mismatch") + assertExpressionEquals(expected.elseExpression, actual.elseExpression, "$message - Conditional else mismatch") + } + else -> assertEquals(expected, actual, "$message - Expression type mismatch") + } + } + + private fun assertForInitEquals(expected: parser.ForInit, actual: parser.ForInit, message: String) { + when (expected) { + is parser.InitDeclaration -> { + assertIs(actual, "$message - Expected InitDeclaration but got ${actual::class.simpleName}") + assertDeclarationEquals(expected.varDeclaration, actual.varDeclaration, "$message - InitDeclaration variable mismatch") + } + is parser.InitExpression -> { + assertIs(actual, "$message - Expected InitExpression but got ${actual::class.simpleName}") + if (expected.expression != null && actual.expression != null) { + assertExpressionEquals(expected.expression, actual.expression, "$message - InitExpression mismatch") + } else { + assertEquals(expected.expression, actual.expression, "$message - InitExpression null mismatch") + } + } + else -> assertEquals(expected, actual, "$message - ForInit type mismatch") + } + } + + private fun assertTackyEquals(expected: TackyProgram, actual: TackyProgram, message: String) { + assertEquals(expected.functions.size, actual.functions.size, "$message - Tacky function count mismatch") + + expected.functions.forEachIndexed { index, expectedFunc -> + val actualFunc = actual.functions[index] + assertEquals(expectedFunc.name, actualFunc.name, "$message - Tacky function name mismatch at index $index") + assertEquals(expectedFunc.args, actualFunc.args, "$message - Tacky function args mismatch at index $index") + assertEquals(expectedFunc.body.size, actualFunc.body.size, "$message - Tacky function body size mismatch at index $index") + + expectedFunc.body.forEachIndexed { instrIndex, expectedInstr -> + val actualInstr = actualFunc.body[instrIndex] + assertTackyInstructionEquals(expectedInstr, actualInstr, "$message - Tacky instruction mismatch at function $index, instruction $instrIndex") + } + } + } + + private fun assertTackyInstructionEquals(expected: tacky.TackyInstruction, actual: tacky.TackyInstruction, message: String) { + // Check sourceId for all instructions + if (expected.sourceId.isEmpty()) { + assertTrue(actual.sourceId.isNotEmpty(), "$message - TackyInstruction sourceId should not be empty") + } else { + assertEquals(expected.sourceId, actual.sourceId, "$message - TackyInstruction sourceId mismatch") + } + + when (expected) { + is tacky.TackyRet -> { + assertIs(actual, "$message - Expected TackyRet but got ${actual::class.simpleName}") + assertTackyValueEquals(expected.value, actual.value, "$message - TackyRet value mismatch") + } + is tacky.TackyBinary -> { + assertIs(actual, "$message - Expected TackyBinary but got ${actual::class.simpleName}") + assertEquals(expected.operator, actual.operator, "$message - TackyBinary operator mismatch") + assertTackyValueEquals(expected.src1, actual.src1, "$message - TackyBinary src1 mismatch") + assertTackyValueEquals(expected.src2, actual.src2, "$message - TackyBinary src2 mismatch") + assertTackyValueEquals(expected.dest, actual.dest, "$message - TackyBinary dest mismatch") + } + is tacky.TackyUnary -> { + assertIs(actual, "$message - Expected TackyUnary but got ${actual::class.simpleName}") + assertEquals(expected.operator, actual.operator, "$message - TackyUnary operator mismatch") + assertTackyValueEquals(expected.src, actual.src, "$message - TackyUnary src mismatch") + assertTackyValueEquals(expected.dest, actual.dest, "$message - TackyUnary dest mismatch") + } + is tacky.TackyCopy -> { + assertIs(actual, "$message - Expected TackyCopy but got ${actual::class.simpleName}") + assertTackyValueEquals(expected.src, actual.src, "$message - TackyCopy src mismatch") + assertTackyValueEquals(expected.dest, actual.dest, "$message - TackyCopy dest mismatch") + } + is tacky.TackyFunCall -> { + assertIs(actual, "$message - Expected TackyFunCall but got ${actual::class.simpleName}") + assertEquals(expected.funName, actual.funName, "$message - TackyFunCall name mismatch") + assertEquals(expected.args.size, actual.args.size, "$message - TackyFunCall args size mismatch") + expected.args.forEachIndexed { index, expectedArg -> + assertTackyValueEquals(expectedArg, actual.args[index], "$message - TackyFunCall arg $index mismatch") + } + assertTackyValueEquals(expected.dest, actual.dest, "$message - TackyFunCall dest mismatch") + } + is tacky.TackyJump -> { + assertIs(actual, "$message - Expected TackyJump but got ${actual::class.simpleName}") + assertTackyLabelEquals(expected.target, actual.target, "$message - TackyJump target mismatch") + } + is tacky.JumpIfZero -> { + assertIs(actual, "$message - Expected JumpIfZero but got ${actual::class.simpleName}") + assertTackyValueEquals(expected.condition, actual.condition, "$message - JumpIfZero condition mismatch") + assertTackyLabelEquals(expected.target, actual.target, "$message - JumpIfZero target mismatch") + } + is tacky.JumpIfNotZero -> { + assertIs(actual, "$message - Expected JumpIfNotZero but got ${actual::class.simpleName}") + assertTackyValueEquals(expected.condition, actual.condition, "$message - JumpIfNotZero condition mismatch") + assertTackyLabelEquals(expected.target, actual.target, "$message - JumpIfNotZero target mismatch") + } + is tacky.TackyLabel -> { + assertIs(actual, "$message - Expected TackyLabel but got ${actual::class.simpleName}") + assertTackyLabelEquals(expected, actual, "$message - TackyLabel mismatch") + } + else -> assertEquals(expected, actual, "$message - Tacky instruction type mismatch") + } + } + + private fun assertTackyLabelEquals(expected: tacky.TackyLabel, actual: tacky.TackyLabel, message: String) { + assertEquals(expected.name, actual.name, "$message - TackyLabel name mismatch") + // For test data with empty sourceId, we expect actual sourceId to be non-empty + if (expected.sourceId.isEmpty()) { + assertTrue(actual.sourceId.isNotEmpty(), "$message - TackyLabel sourceId should not be empty") + } else { + assertEquals(expected.sourceId, actual.sourceId, "$message - TackyLabel sourceId mismatch") + } + } + + private fun assertTackyValueEquals(expected: tacky.TackyVal, actual: tacky.TackyVal, message: String) { + when (expected) { + is tacky.TackyConstant -> { + assertIs(actual, "$message - Expected TackyConstant but got ${actual::class.simpleName}") + assertEquals(expected.value, actual.value, "$message - TackyConstant value mismatch") + } + is tacky.TackyVar -> { + assertIs(actual, "$message - Expected TackyVar but got ${actual::class.simpleName}") + assertEquals(expected.name, actual.name, "$message - TackyVar name mismatch") + } + else -> assertEquals(expected, actual, "$message - Tacky value type mismatch") + } + } + + private fun assertAssemblyEquals(expected: AsmProgram, actual: AsmProgram, message: String) { + assertEquals(expected.functions.size, actual.functions.size, "$message - Assembly function count mismatch") + + expected.functions.forEachIndexed { index, expectedFunc -> + val actualFunc = actual.functions[index] + assertEquals(expectedFunc.name, actualFunc.name, "$message - Assembly function name mismatch at index $index") + assertEquals(expectedFunc.body.size, actualFunc.body.size, "$message - Assembly function body size mismatch at index $index") + assertEquals(expectedFunc.stackSize, actualFunc.stackSize, "$message - Assembly function stack size mismatch at index $index") + + expectedFunc.body.forEachIndexed { instrIndex, expectedInstr -> + val actualInstr = actualFunc.body[instrIndex] + assertAssemblyInstructionEquals(expectedInstr, actualInstr, "$message - Assembly instruction mismatch at function $index, instruction $instrIndex") + } + } + } + + private fun assertAssemblyInstructionEquals(expected: assembly.Instruction, actual: assembly.Instruction, message: String) { + when (expected) { + is assembly.Mov -> { + assertIs(actual, "$message - Expected Mov but got ${actual::class.simpleName}") + assertAssemblyOperandEquals(expected.src, actual.src, "$message - Mov src mismatch") + assertAssemblyOperandEquals(expected.dest, actual.dest, "$message - Mov dest mismatch") + } + is assembly.AsmUnary -> { + assertIs(actual, "$message - Expected AsmUnary but got ${actual::class.simpleName}") + assertEquals(expected.op, actual.op, "$message - AsmUnary operator mismatch") + assertAssemblyOperandEquals(expected.dest, actual.dest, "$message - AsmUnary dest mismatch") + } + is assembly.AsmBinary -> { + assertIs(actual, "$message - Expected AsmBinary but got ${actual::class.simpleName}") + assertEquals(expected.op, actual.op, "$message - AsmBinary operator mismatch") + assertAssemblyOperandEquals(expected.src, actual.src, "$message - AsmBinary src mismatch") + assertAssemblyOperandEquals(expected.dest, actual.dest, "$message - AsmBinary dest mismatch") + } + is assembly.Idiv -> { + assertIs(actual, "$message - Expected Idiv but got ${actual::class.simpleName}") + assertAssemblyOperandEquals(expected.divisor, actual.divisor, "$message - Idiv divisor mismatch") + } + is assembly.Cdq -> { + assertIs(actual, "$message - Expected Cdq but got ${actual::class.simpleName}") + } + is assembly.AllocateStack -> { + assertIs(actual, "$message - Expected AllocateStack but got ${actual::class.simpleName}") + assertEquals(expected.size, actual.size, "$message - AllocateStack size mismatch") + } + is assembly.DeAllocateStack -> { + assertIs(actual, "$message - Expected DeAllocateStack but got ${actual::class.simpleName}") + assertEquals(expected.size, actual.size, "$message - DeAllocateStack size mismatch") + } + is assembly.Push -> { + assertIs(actual, "$message - Expected Push but got ${actual::class.simpleName}") + assertAssemblyOperandEquals(expected.operand, actual.operand, "$message - Push operand mismatch") + } + is assembly.Call -> { + assertIs(actual, "$message - Expected Call but got ${actual::class.simpleName}") + assertEquals(expected.identifier, actual.identifier, "$message - Call identifier mismatch") + } + is assembly.Label -> { + assertIs(actual, "$message - Expected Label but got ${actual::class.simpleName}") + assertEquals(expected.name, actual.name, "$message - Label name mismatch") + } + is assembly.Jmp -> { + assertIs(actual, "$message - Expected Jmp but got ${actual::class.simpleName}") + assertAssemblyLabelEquals(expected.label, actual.label, "$message - Jmp label mismatch") + } + is assembly.JmpCC -> { + assertIs(actual, "$message - Expected JmpCC but got ${actual::class.simpleName}") + assertEquals(expected.condition, actual.condition, "$message - JmpCC condition mismatch") + assertAssemblyLabelEquals(expected.label, actual.label, "$message - JmpCC label mismatch") + } + is assembly.Cmp -> { + assertIs(actual, "$message - Expected Cmp but got ${actual::class.simpleName}") + assertAssemblyOperandEquals(expected.src, actual.src, "$message - Cmp src mismatch") + assertAssemblyOperandEquals(expected.dest, actual.dest, "$message - Cmp dest mismatch") + } + is assembly.SetCC -> { + assertIs(actual, "$message - Expected SetCC but got ${actual::class.simpleName}") + assertEquals(expected.condition, actual.condition, "$message - SetCC condition mismatch") + assertAssemblyOperandEquals(expected.dest, actual.dest, "$message - SetCC dest mismatch") + } + is assembly.Ret -> { + assertIs(actual, "$message - Expected Ret but got ${actual::class.simpleName}") + } + else -> assertEquals(expected, actual, "$message - Assembly instruction type mismatch") + } + } + + private fun assertAssemblyLabelEquals(expected: assembly.Label, actual: assembly.Label, message: String) { + assertEquals(expected.name, actual.name, "$message - Assembly Label name mismatch") + } + + private fun assertAssemblyOperandEquals(expected: assembly.Operand, actual: assembly.Operand, message: String) { + when (expected) { + is assembly.Imm -> { + assertIs(actual, "$message - Expected Imm but got ${actual::class.simpleName}") + assertEquals(expected.value, actual.value, "$message - Imm value mismatch") + } + is assembly.Register -> { + assertIs(actual, "$message - Expected Register but got ${actual::class.simpleName}") + assertEquals(expected.name, actual.name, "$message - Register name mismatch") + } + is assembly.Stack -> { + assertIs(actual, "$message - Expected Stack but got ${actual::class.simpleName}") + assertEquals(expected.offset, actual.offset, "$message - Stack offset mismatch") + } + is assembly.Pseudo -> { + assertIs(actual, "$message - Expected Pseudo but got ${actual::class.simpleName}") + assertEquals(expected.name, actual.name, "$message - Pseudo name mismatch") + } + else -> assertEquals(expected, actual, "$message - Assembly operand type mismatch") + } + } + + private fun assertFunctionDeclarationEquals(expected: parser.FunctionDeclaration, actual: parser.FunctionDeclaration, message: String) { + assertEquals(expected.name, actual.name, "$message - Function declaration name mismatch") + assertEquals(expected.params, actual.params, "$message - Function declaration params mismatch") + if (expected.body != null && actual.body != null) { + assertBlockEquals(expected.body, actual.body, "$message - Function declaration body mismatch") + } else { + assertEquals(expected.body, actual.body, "$message - Function declaration body null mismatch") + } + } + + private fun assertDeclarationEquals(expected: parser.Declaration, actual: parser.Declaration, message: String) { + when (expected) { + is parser.VarDecl -> { + assertIs(actual, "$message - Expected VarDecl but got ${actual::class.simpleName}") + assertDeclarationEquals(expected.varDecl, actual.varDecl, "$message - VarDecl variable mismatch") + } + is parser.FunDecl -> { + assertIs(actual, "$message - Expected FunDecl but got ${actual::class.simpleName}") + assertFunctionDeclarationEquals(expected.funDecl, actual.funDecl, "$message - FunDecl function mismatch") + } + is parser.VariableDeclaration -> { + assertIs(actual, "$message - Expected VariableDeclaration but got ${actual::class.simpleName}") + assertEquals(expected.name, actual.name, "$message - Variable declaration name mismatch") + if (expected.init != null && actual.init != null) { + assertExpressionEquals(expected.init, actual.init, "$message - Variable declaration init mismatch") + } else { + assertEquals(expected.init, actual.init, "$message - Variable declaration init null mismatch") + } + } + is parser.FunctionDeclaration -> { + assertIs(actual, "$message - Expected FunctionDeclaration but got ${actual::class.simpleName}") + assertEquals(expected.name, actual.name, "$message - Function declaration name mismatch") + assertEquals(expected.params, actual.params, "$message - Function declaration params mismatch") + if (expected.body != null && actual.body != null) { + assertBlockEquals(expected.body, actual.body, "$message - Function declaration body mismatch") + } else { + assertEquals(expected.body, actual.body, "$message - Function declaration body null mismatch") + } + } + else -> assertEquals(expected, actual, "$message - Declaration type mismatch") + } + } + @Test fun testValidPrograms() { ValidTestCases.testCases.forEachIndexed { index, testCase -> @@ -31,15 +463,17 @@ class CompilerTestSuite { // Parser stage val ast = CompilerWorkflow.take(tokens) assertIs(ast) + val simpleProgram = ast as SimpleProgram if (testCase.expectedAst != null) { - assertEquals( - expected = testCase.expectedAst, - actual = ast, + assertIs(testCase.expectedAst, "Expected AST should be SimpleProgram") + assertAstEquals( + expected = testCase.expectedAst as SimpleProgram, + actual = simpleProgram, message = """ |Test case $index failed with: |Expected:${testCase.expectedAst} - |Actual: $ast + |Actual: $simpleProgram """.trimMargin() ) } @@ -48,7 +482,7 @@ class CompilerTestSuite { val tacky = CompilerWorkflow.take(ast) assertIs(tacky) if (testCase.expectedTacky != null) { - assertEquals( + assertTackyEquals( expected = testCase.expectedTacky, actual = tacky, message = @@ -64,7 +498,7 @@ class CompilerTestSuite { val asm = CompilerWorkflow.take(tacky) assertIs(asm) if (testCase.expectedAssembly != null) { - assertEquals( + assertAssemblyEquals( expected = testCase.expectedAssembly, actual = asm, message = diff --git a/src/jsTest/kotlin/integration/ValidTestCases.kt b/src/jsTest/kotlin/integration/ValidTestCases.kt index 0f71683..3a0e6b8 100644 --- a/src/jsTest/kotlin/integration/ValidTestCases.kt +++ b/src/jsTest/kotlin/integration/ValidTestCases.kt @@ -40,6 +40,7 @@ import parser.NullStatement import parser.ReturnStatement import parser.S import parser.SimpleProgram +import parser.SourceLocation import parser.UnaryExpression import parser.VarDecl import parser.VariableDeclaration @@ -61,6 +62,9 @@ import tacky.TackyUnary import tacky.TackyUnaryOP import tacky.TackyVar +// Helper constant for test locations +val TEST_LOCATION = SourceLocation(1, 1, 1, 1) + data class ValidTestCase( val title: String? = null, val code: String, @@ -78,67 +82,72 @@ object ValidTestCases { code = "int main(void) \n { return (5 - 3) * 4 + ~(-5) / 6 % 3; }", expectedTokenList = listOf( - Token(TokenType.KEYWORD_INT, "int", 1, 1), - Token(TokenType.IDENTIFIER, "main", 1, 5), - Token(TokenType.LEFT_PAREN, "(", 1, 9), - Token(TokenType.KEYWORD_VOID, "void", 1, 10), - Token(TokenType.RIGHT_PAREN, ")", 1, 14), - Token(TokenType.LEFT_BRACK, "{", 2, 2), - Token(TokenType.KEYWORD_RETURN, "return", 2, 4), - Token(TokenType.LEFT_PAREN, "(", 2, 11), - Token(TokenType.INT_LITERAL, "5", 2, 12), - Token(TokenType.NEGATION, "-", 2, 14), - Token(TokenType.INT_LITERAL, "3", 2, 16), - Token(TokenType.RIGHT_PAREN, ")", 2, 17), - Token(TokenType.MULTIPLY, "*", 2, 19), - Token(TokenType.INT_LITERAL, "4", 2, 21), - Token(TokenType.PLUS, "+", 2, 23), - Token(TokenType.TILDE, "~", 2, 25), - Token(TokenType.LEFT_PAREN, "(", 2, 26), - Token(TokenType.NEGATION, "-", 2, 27), - Token(TokenType.INT_LITERAL, "5", 2, 28), - Token(TokenType.RIGHT_PAREN, ")", 2, 29), - Token(TokenType.DIVIDE, "/", 2, 31), - Token(TokenType.INT_LITERAL, "6", 2, 33), - Token(TokenType.REMAINDER, "%", 2, 35), - Token(TokenType.INT_LITERAL, "3", 2, 37), - Token(TokenType.SEMICOLON, ";", 2, 38), - Token(TokenType.RIGHT_BRACK, "}", 2, 40), - Token(TokenType.EOF, "", 2, 41) + Token(TokenType.KEYWORD_INT, "int", 1, 1, 1, 3), + Token(TokenType.IDENTIFIER, "main", 1, 5, 1, 8), + Token(TokenType.LEFT_PAREN, "(", 1, 9, 1, 9), + Token(TokenType.KEYWORD_VOID, "void", 1, 10, 1, 13), + Token(TokenType.RIGHT_PAREN, ")", 1, 14, 1, 14), + Token(TokenType.LEFT_BRACK, "{", 2, 2, 2, 2), + Token(TokenType.KEYWORD_RETURN, "return", 2, 4, 2, 9), + Token(TokenType.LEFT_PAREN, "(", 2, 11, 2, 11), + Token(TokenType.INT_LITERAL, "5", 2, 12, 2, 12), + Token(TokenType.NEGATION, "-", 2, 14, 2, 14), + Token(TokenType.INT_LITERAL, "3", 2, 16, 2, 16), + Token(TokenType.RIGHT_PAREN, ")", 2, 17, 2, 17), + Token(TokenType.MULTIPLY, "*", 2, 19, 2, 19), + Token(TokenType.INT_LITERAL, "4", 2, 21, 2, 21), + Token(TokenType.PLUS, "+", 2, 23, 2, 23), + Token(TokenType.TILDE, "~", 2, 25, 2, 25), + Token(TokenType.LEFT_PAREN, "(", 2, 26, 2, 26), + Token(TokenType.NEGATION, "-", 2, 27, 2, 27), + Token(TokenType.INT_LITERAL, "5", 2, 28, 2, 28), + Token(TokenType.RIGHT_PAREN, ")", 2, 29, 2, 29), + Token(TokenType.DIVIDE, "/", 2, 31, 2, 31), + Token(TokenType.INT_LITERAL, "6", 2, 33, 2, 33), + Token(TokenType.REMAINDER, "%", 2, 35, 2, 35), + Token(TokenType.INT_LITERAL, "3", 2, 37, 2, 37), + Token(TokenType.SEMICOLON, ";", 2, 38, 2, 38), + Token(TokenType.RIGHT_BRACK, "}", 2, 40, 2, 40), + Token(TokenType.EOF, "", 2, 41, 2, 41) ), expectedAst = SimpleProgram( + location = TEST_LOCATION, functionDeclaration = listOf( FunctionDeclaration( + location = TEST_LOCATION, name = "main", params = emptyList(), body = Block( - listOf( + items = listOf( S( ReturnStatement( + location = TEST_LOCATION, expression = BinaryExpression( left = BinaryExpression( left = BinaryExpression( - left = IntExpression(5), - operator = Token(TokenType.NEGATION, "-", 2, 14), - right = IntExpression(3) + left = IntExpression(5, location = TEST_LOCATION), + operator = Token(TokenType.NEGATION, "-", 2, 14, 2, 14), + right = IntExpression(3, location = TEST_LOCATION), + location = TEST_LOCATION ), - operator = Token(TokenType.MULTIPLY, "*", 2, 19), - right = IntExpression(4) + operator = Token(TokenType.MULTIPLY, "*", 2, 19, 2, 19), + right = IntExpression(4, location = TEST_LOCATION), + location = TEST_LOCATION ), - operator = Token(TokenType.PLUS, "+", 2, 23), + operator = Token(TokenType.PLUS, "+", 2, 23, 2, 23), right = BinaryExpression( left = BinaryExpression( left = UnaryExpression( - operator = Token(TokenType.TILDE, "~", 2, 25), + operator = Token(TokenType.TILDE, "~", 2, 25, 2, 25), expression = UnaryExpression( operator = @@ -146,21 +155,29 @@ object ValidTestCases { TokenType.NEGATION, "-", 2, + 27, + 2, 27 ), - expression = IntExpression(5) - ) + expression = IntExpression(5, location = TEST_LOCATION), + location = TEST_LOCATION + ), + location = TEST_LOCATION ), - operator = Token(TokenType.DIVIDE, "/", 2, 31), - right = IntExpression(6) + operator = Token(TokenType.DIVIDE, "/", 2, 31, 2, 31), + right = IntExpression(6, location = TEST_LOCATION), + location = TEST_LOCATION ), - operator = Token(TokenType.REMAINDER, "%", 2, 35), - right = IntExpression(3) - ) + operator = Token(TokenType.REMAINDER, "%", 2, 35, 2, 35), + right = IntExpression(3, location = TEST_LOCATION), + location = TEST_LOCATION + ), + location = TEST_LOCATION ) ) ) - ) + ), + location = TEST_LOCATION ) ) ) @@ -204,7 +221,7 @@ object ValidTestCases { stackSize = 56, // 7 temporary variables * 8 bytes = 56 body = listOf( - AllocateStack(64), + AllocateStack(64, ""), // tmp.0 = 5 - 3 Mov(Imm(5), Stack(-8)), AsmBinary(AsmBinaryOp.SUB, Imm(3), Stack(-8)), @@ -223,13 +240,13 @@ object ValidTestCases { AsmUnary(AsmUnaryOp.NOT, Stack(-32)), // tmp.4 = tmp.3 / 6 Mov(Stack(-32), Register(HardwareRegister.EAX)), - Cdq, + Cdq(""), Mov(Imm(6), Register(HardwareRegister.R10D)), Idiv(Register(HardwareRegister.R10D)), Mov(Register(HardwareRegister.EAX), Stack(-40)), // tmp.5 = tmp.4 % 3 Mov(Stack(-40), Register(HardwareRegister.EAX)), - Cdq, + Cdq(""), Mov(Imm(3), Register(HardwareRegister.R10D)), Idiv(Register(HardwareRegister.R10D)), Mov(Register(HardwareRegister.EDX), Stack(-48)), @@ -240,7 +257,7 @@ object ValidTestCases { AsmBinary(AsmBinaryOp.ADD, Register(HardwareRegister.R10D), Stack(-56)), // return tmp.6 Mov(Stack(-56), Register(HardwareRegister.EAX)), - Ret + Ret("") // The implicit return 0 ) ) @@ -307,7 +324,7 @@ object ValidTestCases { stackSize = 40, // 5 temporary variables * 8 bytes = 40 body = listOf( - AllocateStack(48), + AllocateStack(48, ""), // tmp.1 = (1 == 0) Mov(Imm(1), Register(HardwareRegister.R11D)), Cmp(Imm(0), Register(HardwareRegister.R11D)), @@ -353,7 +370,7 @@ object ValidTestCases { Label(".L_or_end_1"), // return tmp.0 Mov(Stack(-40), Register(HardwareRegister.EAX)), - Ret + Ret("") ) ) ) @@ -370,66 +387,66 @@ object ValidTestCases { |; |} """.trimMargin(), - expectedTokenList = - listOf( - Token(TokenType.KEYWORD_INT, "int", 1, 1), - Token(TokenType.IDENTIFIER, "main", 1, 5), - Token(TokenType.LEFT_PAREN, "(", 1, 9), - Token(TokenType.KEYWORD_VOID, "void", 1, 10), - Token(TokenType.RIGHT_PAREN, ")", 1, 14), - Token(TokenType.LEFT_BRACK, "{", 1, 16), + expectedTokenList = listOf( + Token(TokenType.KEYWORD_INT, "int", 1, 1, 1, 3), + Token(TokenType.IDENTIFIER, "main", 1, 5, 1, 8), + Token(TokenType.LEFT_PAREN, "(", 1, 9, 1, 9), + Token(TokenType.KEYWORD_VOID, "void", 1, 10, 1, 13), + Token(TokenType.RIGHT_PAREN, ")", 1, 14, 1, 14), + Token(TokenType.LEFT_BRACK, "{", 1, 16, 1, 16), // int b; - Token(TokenType.KEYWORD_INT, "int", 2, 1), - Token(TokenType.IDENTIFIER, "b", 2, 5), - Token(TokenType.SEMICOLON, ";", 2, 6), + Token(TokenType.KEYWORD_INT, "int", 2, 1, 2, 3), Token(TokenType.IDENTIFIER, "b", 2, 5, 2, 5), + Token(TokenType.SEMICOLON, ";", 2, 6, 2, 6), // int a = 10 + 1; - Token(TokenType.KEYWORD_INT, "int", 3, 1), - Token(TokenType.IDENTIFIER, "a", 3, 5), - Token(TokenType.ASSIGN, "=", 3, 7), - Token(TokenType.INT_LITERAL, "10", 3, 9), - Token(TokenType.PLUS, "+", 3, 12), - Token(TokenType.INT_LITERAL, "1", 3, 14), - Token(TokenType.SEMICOLON, ";", 3, 15), + Token(TokenType.KEYWORD_INT, "int", 3, 1, 3, 3), Token(TokenType.IDENTIFIER, "a", 3, 5, 3, 5), + Token(TokenType.ASSIGN, "=", 3, 7, 3, 7), + Token(TokenType.INT_LITERAL, "10", 3, 9, 3, 10), Token(TokenType.PLUS, "+", 3, 12, 3, 12), + Token(TokenType.INT_LITERAL, "1", 3, 14, 3, 14), + Token(TokenType.SEMICOLON, ";", 3, 15, 3, 15), // b = (a=2) * 2; - Token(TokenType.IDENTIFIER, "b", 4, 1), - Token(TokenType.ASSIGN, "=", 4, 3), - Token(TokenType.LEFT_PAREN, "(", 4, 5), - Token(TokenType.IDENTIFIER, "a", 4, 6), - Token(TokenType.ASSIGN, "=", 4, 7), - Token(TokenType.INT_LITERAL, "2", 4, 8), - Token(TokenType.RIGHT_PAREN, ")", 4, 9), - Token(TokenType.MULTIPLY, "*", 4, 11), - Token(TokenType.INT_LITERAL, "2", 4, 13), - Token(TokenType.SEMICOLON, ";", 4, 14), + Token(TokenType.IDENTIFIER, "b", 4, 1, 4, 1), + Token(TokenType.ASSIGN, "=", 4, 3, 4, 3), + Token(TokenType.LEFT_PAREN, "(", 4, 5, 4, 5), + Token(TokenType.IDENTIFIER, "a", 4, 6, 4, 6), + Token(TokenType.ASSIGN, "=", 4, 7, 4, 7), + Token(TokenType.INT_LITERAL, "2", 4, 8, 4, 8), + Token(TokenType.RIGHT_PAREN, ")", 4, 9, 4, 9), + Token(TokenType.MULTIPLY, "*", 4, 11, 4, 11), + Token(TokenType.INT_LITERAL, "2", 4, 13, 4, 13), + Token(TokenType.SEMICOLON, ";", 4, 14, 4, 14), // return b; - Token(TokenType.KEYWORD_RETURN, "return", 5, 1), - Token(TokenType.IDENTIFIER, "b", 5, 8), - Token(TokenType.SEMICOLON, ";", 5, 9), - Token(TokenType.SEMICOLON, ";", 6, 1), - Token(TokenType.RIGHT_BRACK, "}", 7, 1), - Token(TokenType.EOF, "", 7, 2) + Token(TokenType.KEYWORD_RETURN, "return", 5, 1, 5, 6), + Token(TokenType.IDENTIFIER, "b", 5, 8, 5, 8), + Token(TokenType.SEMICOLON, ";", 5, 9, 5, 9), + Token(TokenType.SEMICOLON, ";", 6, 1, 6, 1), + Token(TokenType.RIGHT_BRACK, "}", 7, 1, 7, 1), + Token(TokenType.EOF, "", 7, 2, 7, 2) ), expectedAst = SimpleProgram( + location = TEST_LOCATION, functionDeclaration = listOf( FunctionDeclaration( + location = TEST_LOCATION, name = "main", params = emptyList(), - body = - Block( - block = + body = Block( + location = TEST_LOCATION, + items = listOf( - D(VarDecl(VariableDeclaration(name = "b.0", init = null))), + D(VarDecl(VariableDeclaration(location = TEST_LOCATION, name = "b.0", init = null))), D( VarDecl( VariableDeclaration( + location = TEST_LOCATION, name = "a.1", init = BinaryExpression( - left = IntExpression(10), - operator = Token(TokenType.PLUS, "+", 3, 12), - right = IntExpression(1) + left = IntExpression(10, location = TEST_LOCATION), + operator = Token(TokenType.PLUS, "+", 3, 12, 3, 12), + right = IntExpression(1, location = TEST_LOCATION), + location = TEST_LOCATION ) ) ) @@ -437,26 +454,29 @@ object ValidTestCases { S( ExpressionStatement( AssignmentExpression( - lvalue = VariableExpression("b.0"), - rvalue = + VariableExpression("b.0", location = TEST_LOCATION), BinaryExpression( - left = AssignmentExpression( - lvalue = VariableExpression("a.1"), - rvalue = IntExpression(2) + VariableExpression("a.1", location = TEST_LOCATION), + IntExpression(2, location = TEST_LOCATION), + location = TEST_LOCATION ), - operator = Token(TokenType.MULTIPLY, "*", 4, 11), - right = IntExpression(2) - ) - ) + Token(TokenType.MULTIPLY, "*", 4, 11, 4, 11), + IntExpression(2, location = TEST_LOCATION), + location = TEST_LOCATION + ), + location = TEST_LOCATION + ), + location = TEST_LOCATION ) ), S( ReturnStatement( - expression = VariableExpression("b.0") + location = TEST_LOCATION, + expression = VariableExpression("b.0", location = TEST_LOCATION) ) ), - S(NullStatement()) + S(NullStatement(location = TEST_LOCATION)) ) ) ) @@ -537,7 +557,7 @@ object ValidTestCases { stackSize = 16, body = listOf( - AllocateStack(16), + AllocateStack(16, ""), // int a = 0; Mov(Imm(0), Stack(-8)), // tmp.0 = a == 0; @@ -549,18 +569,18 @@ object ValidTestCases { JmpCC(ConditionCode.E, Label(".L_else_label_1")), // return 10; Mov(Imm(10), Register(HardwareRegister.EAX)), - Ret, + Ret(""), Jmp(Label(".L_end_0")), // .L_else_label_1: Label(".L_else_label_1"), // return 20; Mov(Imm(20), Register(HardwareRegister.EAX)), - Ret, + Ret(""), // .L_end_0: Label(".L_end_0"), // implicit return 0 Mov(Imm(0), Register(HardwareRegister.EAX)), - Ret + Ret("") ) ) ) @@ -621,7 +641,7 @@ object ValidTestCases { stackSize = 24, // 1 variable (a) + 2 temporaries (tmp.0, tmp.1) = 3 * 8 = 24 body = listOf( - AllocateStack(32), // 12 rounded up to nearest 16 + AllocateStack(32, ""), // 12 rounded up to nearest 16 // a = 0 Mov(Imm(0), Stack(-8)), // Stack slot for a.0 // start: @@ -646,7 +666,7 @@ object ValidTestCases { Label(".L_end_0"), // return a Mov(Stack(-8), Register(HardwareRegister.EAX)), - Ret + Ret("") ) ) ) @@ -703,7 +723,7 @@ object ValidTestCases { stackSize = 24, body = listOf( - AllocateStack(32), + AllocateStack(32, ""), // Parameter setup: a.0 -> Stack(-8), b.1 -> Stack(-16) Mov(Register(HardwareRegister.EDI), Stack(-8)), Mov(Register(HardwareRegister.ESI), Stack(-16)), @@ -714,7 +734,7 @@ object ValidTestCases { AsmBinary(AsmBinaryOp.ADD, Register(HardwareRegister.R10D), Stack(-24)), // return tmp.0 Mov(Stack(-24), Register(HardwareRegister.EAX)), - Ret + Ret("") ) ), AsmFunction( @@ -722,7 +742,7 @@ object ValidTestCases { stackSize = 16, body = listOf( - AllocateStack(16), + AllocateStack(16, ""), // Call add(5, 3) Mov(Imm(5), Register(HardwareRegister.EDI)), Mov(Imm(3), Register(HardwareRegister.ESI)), @@ -734,7 +754,7 @@ object ValidTestCases { Mov(Register(HardwareRegister.R10D), Stack(-16)), // return result Mov(Stack(-16), Register(HardwareRegister.EAX)), - Ret + Ret("") ) ) ) @@ -797,7 +817,7 @@ object ValidTestCases { listOf( // return 5; Mov(Imm(5), Register(HardwareRegister.EAX)), - Ret + Ret("") ) ), AsmFunction( @@ -805,16 +825,16 @@ object ValidTestCases { stackSize = 8, body = listOf( - AllocateStack(16), + AllocateStack(16, ""), // Implicit return 0 Mov(Imm(0), Register(HardwareRegister.EAX)), - Ret, + Ret(""), // Call foo() (no arguments) Call("foo"), // Store result and return it Mov(Register(HardwareRegister.EAX), Stack(-8)), Mov(Stack(-8), Register(HardwareRegister.EAX)), - Ret + Ret("") ) ) ) @@ -840,221 +860,230 @@ object ValidTestCases { """.trimMargin(), expectedTokenList = listOf( - Token(TokenType.KEYWORD_INT, "int", 1, 1), - Token(TokenType.IDENTIFIER, "main", 1, 5), - Token(TokenType.LEFT_PAREN, "(", 1, 9), - Token(TokenType.KEYWORD_VOID, "void", 1, 10), - Token(TokenType.RIGHT_PAREN, ")", 1, 14), - Token(TokenType.LEFT_BRACK, "{", 1, 16), + Token(TokenType.KEYWORD_INT, "int", 1, 1, 1, 3), + Token(TokenType.IDENTIFIER, "main", 1, 5, 1, 8), + Token(TokenType.LEFT_PAREN, "(", 1, 9, 1, 9), + Token(TokenType.KEYWORD_VOID, "void", 1, 10, 1, 13), + Token(TokenType.RIGHT_PAREN, ")", 1, 14, 1, 14), + Token(TokenType.LEFT_BRACK, "{", 1, 16, 1, 16), // int a = 10; - Token(TokenType.KEYWORD_INT, "int", 2, 1), - Token(TokenType.IDENTIFIER, "a", 2, 5), - Token(TokenType.ASSIGN, "=", 2, 7), - Token(TokenType.INT_LITERAL, "10", 2, 9), - Token(TokenType.SEMICOLON, ";", 2, 11), + Token(TokenType.KEYWORD_INT, "int", 2, 1, 2, 3), + Token(TokenType.IDENTIFIER, "a", 2, 5, 2, 5), + Token(TokenType.ASSIGN, "=", 2, 7, 2, 7), + Token(TokenType.INT_LITERAL, "10", 2, 9, 2, 10), + Token(TokenType.SEMICOLON, ";", 2, 11, 2, 11), // while(a > 0) - Token(TokenType.KEYWORD_WHILE, "while", 3, 1), - Token(TokenType.LEFT_PAREN, "(", 3, 6), - Token(TokenType.IDENTIFIER, "a", 3, 7), - Token(TokenType.GREATER, ">", 3, 9), - Token(TokenType.INT_LITERAL, "0", 3, 11), - Token(TokenType.RIGHT_PAREN, ")", 3, 12), + Token(TokenType.KEYWORD_WHILE, "while", 3, 1, 3, 5), + Token(TokenType.LEFT_PAREN, "(", 3, 6, 3, 6), + Token(TokenType.IDENTIFIER, "a", 3, 7, 3, 7), + Token(TokenType.GREATER, ">", 3, 9, 3, 9), + Token(TokenType.INT_LITERAL, "0", 3, 11, 3, 11), + Token(TokenType.RIGHT_PAREN, ")", 3, 12, 3, 12), // a = a - 1; - Token(TokenType.IDENTIFIER, "a", 4, 4), - Token(TokenType.ASSIGN, "=", 4, 6), - Token(TokenType.IDENTIFIER, "a", 4, 8), - Token(TokenType.NEGATION, "-", 4, 10), - Token(TokenType.INT_LITERAL, "1", 4, 12), - Token(TokenType.SEMICOLON, ";", 4, 13), + Token(TokenType.IDENTIFIER, "a", 4, 4, 4, 4), + Token(TokenType.ASSIGN, "=", 4, 6, 4, 6), + Token(TokenType.IDENTIFIER, "a", 4, 8, 4, 8), + Token(TokenType.NEGATION, "-", 4, 10, 4, 10), + Token(TokenType.INT_LITERAL, "1", 4, 12, 4, 12), + Token(TokenType.SEMICOLON, ";", 4, 13, 4, 13), // for(int i = 0; i < 10; i = i + 1) - Token(TokenType.KEYWORD_FOR, "for", 5, 1), - Token(TokenType.LEFT_PAREN, "(", 5, 4), - Token(TokenType.KEYWORD_INT, "int", 5, 5), - Token(TokenType.IDENTIFIER, "i", 5, 9), - Token(TokenType.ASSIGN, "=", 5, 11), - Token(TokenType.INT_LITERAL, "0", 5, 13), - Token(TokenType.SEMICOLON, ";", 5, 14), - Token(TokenType.IDENTIFIER, "i", 5, 16), - Token(TokenType.LESS, "<", 5, 18), - Token(TokenType.INT_LITERAL, "10", 5, 20), - Token(TokenType.SEMICOLON, ";", 5, 22), - Token(TokenType.IDENTIFIER, "i", 5, 24), - Token(TokenType.ASSIGN, "=", 5, 26), - Token(TokenType.IDENTIFIER, "i", 5, 28), - Token(TokenType.PLUS, "+", 5, 30), - Token(TokenType.INT_LITERAL, "1", 5, 32), - Token(TokenType.RIGHT_PAREN, ")", 5, 33), + Token(TokenType.KEYWORD_FOR, "for", 5, 1, 5, 3), + Token(TokenType.LEFT_PAREN, "(", 5, 4, 5, 4), + Token(TokenType.KEYWORD_INT, "int", 5, 5, 5, 7), + Token(TokenType.IDENTIFIER, "i", 5, 9, 5, 9), + Token(TokenType.ASSIGN, "=", 5, 11, 5, 11), + Token(TokenType.INT_LITERAL, "0", 5, 13, 5, 13), + Token(TokenType.SEMICOLON, ";", 5, 14, 5, 14), + Token(TokenType.IDENTIFIER, "i", 5, 16, 5, 16), + Token(TokenType.LESS, "<", 5, 18, 5, 18), + Token(TokenType.INT_LITERAL, "10", 5, 20, 5, 21), + Token(TokenType.SEMICOLON, ";", 5, 22, 5, 22), + Token(TokenType.IDENTIFIER, "i", 5, 24, 5, 24), + Token(TokenType.ASSIGN, "=", 5, 26, 5, 26), + Token(TokenType.IDENTIFIER, "i", 5, 28, 5, 28), + Token(TokenType.PLUS, "+", 5, 30, 5, 30), + Token(TokenType.INT_LITERAL, "1", 5, 32, 5, 32), + Token(TokenType.RIGHT_PAREN, ")", 5, 33, 5, 33), // a = a + 1; - Token(TokenType.IDENTIFIER, "a", 6, 4), - Token(TokenType.ASSIGN, "=", 6, 6), - Token(TokenType.IDENTIFIER, "a", 6, 8), - Token(TokenType.PLUS, "+", 6, 10), - Token(TokenType.INT_LITERAL, "1", 6, 12), - Token(TokenType.SEMICOLON, ";", 6, 13), + Token(TokenType.IDENTIFIER, "a", 6, 4, 6, 4), + Token(TokenType.ASSIGN, "=", 6, 6, 6, 6), + Token(TokenType.IDENTIFIER, "a", 6, 8, 6, 8), + Token(TokenType.PLUS, "+", 6, 10, 6, 10), + Token(TokenType.INT_LITERAL, "1", 6, 12, 6, 12), + Token(TokenType.SEMICOLON, ";", 6, 13, 6, 13), // for(;a > 0;) - Token(TokenType.KEYWORD_FOR, "for", 7, 1), - Token(TokenType.LEFT_PAREN, "(", 7, 4), - Token(TokenType.SEMICOLON, ";", 7, 5), - Token(TokenType.IDENTIFIER, "a", 7, 6), - Token(TokenType.GREATER, ">", 7, 8), - Token(TokenType.INT_LITERAL, "0", 7, 10), - Token(TokenType.SEMICOLON, ";", 7, 11), - Token(TokenType.RIGHT_PAREN, ")", 7, 12), + Token(TokenType.KEYWORD_FOR, "for", 7, 1, 7, 3), + Token(TokenType.LEFT_PAREN, "(", 7, 4, 7, 4), + Token(TokenType.SEMICOLON, ";", 7, 5, 7, 5), + Token(TokenType.IDENTIFIER, "a", 7, 6, 7, 6), + Token(TokenType.GREATER, ">", 7, 8, 7, 8), + Token(TokenType.INT_LITERAL, "0", 7, 10, 7, 10), + Token(TokenType.SEMICOLON, ";", 7, 11, 7, 11), + Token(TokenType.RIGHT_PAREN, ")", 7, 12, 7, 12), // a = a + 1; - Token(TokenType.IDENTIFIER, "a", 8, 4), - Token(TokenType.ASSIGN, "=", 8, 6), - Token(TokenType.IDENTIFIER, "a", 8, 8), - Token(TokenType.PLUS, "+", 8, 10), - Token(TokenType.INT_LITERAL, "1", 8, 12), - Token(TokenType.SEMICOLON, ";", 8, 13), + Token(TokenType.IDENTIFIER, "a", 8, 4, 8, 4), + Token(TokenType.ASSIGN, "=", 8, 6, 8, 6), + Token(TokenType.IDENTIFIER, "a", 8, 8, 8, 8), + Token(TokenType.PLUS, "+", 8, 10, 8, 10), + Token(TokenType.INT_LITERAL, "1", 8, 12, 8, 12), + Token(TokenType.SEMICOLON, ";", 8, 13, 8, 13), // do - Token(TokenType.KEYWORD_DO, "do", 9, 1), + Token(TokenType.KEYWORD_DO, "do", 9, 1, 9, 2), // a = a + 1; - Token(TokenType.IDENTIFIER, "a", 10, 4), - Token(TokenType.ASSIGN, "=", 10, 6), - Token(TokenType.IDENTIFIER, "a", 10, 8), - Token(TokenType.PLUS, "+", 10, 10), - Token(TokenType.INT_LITERAL, "1", 10, 12), - Token(TokenType.SEMICOLON, ";", 10, 13), + Token(TokenType.IDENTIFIER, "a", 10, 4, 10, 4), + Token(TokenType.ASSIGN, "=", 10, 6, 10, 6), + Token(TokenType.IDENTIFIER, "a", 10, 8, 10, 8), + Token(TokenType.PLUS, "+", 10, 10, 10, 10), + Token(TokenType.INT_LITERAL, "1", 10, 12, 10, 12), + Token(TokenType.SEMICOLON, ";", 10, 13, 10, 13), // while(a > 0); - Token(TokenType.KEYWORD_WHILE, "while", 11, 1), - Token(TokenType.LEFT_PAREN, "(", 11, 6), - Token(TokenType.IDENTIFIER, "a", 11, 7), - Token(TokenType.GREATER, ">", 11, 9), - Token(TokenType.INT_LITERAL, "0", 11, 11), - Token(TokenType.RIGHT_PAREN, ")", 11, 12), - Token(TokenType.SEMICOLON, ";", 11, 13), + Token(TokenType.KEYWORD_WHILE, "while", 11, 1, 11, 5), + Token(TokenType.LEFT_PAREN, "(", 11, 6, 11, 6), + Token(TokenType.IDENTIFIER, "a", 11, 7, 11, 7), + Token(TokenType.GREATER, ">", 11, 9, 11, 9), + Token(TokenType.INT_LITERAL, "0", 11, 11, 11, 11), + Token(TokenType.RIGHT_PAREN, ")", 11, 12, 11, 12), + Token(TokenType.SEMICOLON, ";", 11, 13, 11, 13), // return 0; - Token(TokenType.KEYWORD_RETURN, "return", 12, 1), - Token(TokenType.INT_LITERAL, "0", 12, 8), - Token(TokenType.SEMICOLON, ";", 12, 9), - Token(TokenType.RIGHT_BRACK, "}", 13, 1), - Token(TokenType.EOF, "", 13, 2) + Token(TokenType.KEYWORD_RETURN, "return", 12, 1, 12, 6), + Token(TokenType.INT_LITERAL, "0", 12, 8, 12, 8), + Token(TokenType.SEMICOLON, ";", 12, 9, 12, 9), + Token(TokenType.RIGHT_BRACK, "}", 13, 1, 13, 1), + Token(TokenType.EOF, "", 13, 2, 13, 2) ), expectedAst = SimpleProgram( - functionDeclaration = - listOf( + location = TEST_LOCATION, + functionDeclaration = listOf( FunctionDeclaration( + location = TEST_LOCATION, name = "main", params = emptyList(), - body = - Block( - block = + body = Block( + location = TEST_LOCATION, + items = listOf( - D(VarDecl(VariableDeclaration(name = "a.0", init = IntExpression(10)))), + D(VarDecl(VariableDeclaration(location = TEST_LOCATION, name = "a.0", init = IntExpression(10, location = TEST_LOCATION)))), S( WhileStatement( - condition = BinaryExpression( - left = VariableExpression("a.0"), - operator = Token(TokenType.GREATER, ">", 3, 9), - right = IntExpression(0) + VariableExpression("a.0", location = TEST_LOCATION), + Token(TokenType.GREATER, ">", 3, 9, 3, 9), + IntExpression(0, location = TEST_LOCATION), + location = TEST_LOCATION ), - body = ExpressionStatement( AssignmentExpression( - lvalue = VariableExpression("a.0"), - rvalue = + VariableExpression("a.0", location = TEST_LOCATION), BinaryExpression( - left = VariableExpression("a.0"), - operator = Token(TokenType.NEGATION, "-", 4, 10), - right = IntExpression(1) - ) - ) + VariableExpression("a.0", location = TEST_LOCATION), + Token(TokenType.NEGATION, "-", 4, 10, 4, 10), + IntExpression(1, location = TEST_LOCATION), + location = TEST_LOCATION + ), + location = TEST_LOCATION + ), + location = TEST_LOCATION ), - label = "loop.0" + "loop.0", + location = TEST_LOCATION ) ), S( ForStatement( - init = InitDeclaration( - VariableDeclaration(name = "i.1", init = IntExpression(0)) + VariableDeclaration("i.1", IntExpression(0, location = TEST_LOCATION), location = TEST_LOCATION), + location = TEST_LOCATION ), - condition = BinaryExpression( - left = VariableExpression("i.1"), - operator = Token(TokenType.LESS, "<", 5, 18), - right = IntExpression(10) + VariableExpression("i.1", location = TEST_LOCATION), + Token(TokenType.LESS, "<", 5, 18, 5, 18), + IntExpression(10, location = TEST_LOCATION), + location = TEST_LOCATION ), - post = AssignmentExpression( - lvalue = VariableExpression("i.1"), - rvalue = + VariableExpression("i.1", location = TEST_LOCATION), BinaryExpression( - left = VariableExpression("i.1"), - operator = Token(TokenType.PLUS, "+", 5, 30), - right = IntExpression(1) - ) + VariableExpression("i.1", location = TEST_LOCATION), + Token(TokenType.PLUS, "+", 5, 30, 5, 30), + IntExpression(1, location = TEST_LOCATION), + location = TEST_LOCATION + ), + location = TEST_LOCATION ), - body = ExpressionStatement( AssignmentExpression( - lvalue = VariableExpression("a.0"), - rvalue = + VariableExpression("a.0", location = TEST_LOCATION), BinaryExpression( - left = VariableExpression("a.0"), - operator = Token(TokenType.PLUS, "+", 6, 10), - right = IntExpression(1) - ) - ) + VariableExpression("a.0", location = TEST_LOCATION), + Token(TokenType.PLUS, "+", 6, 10, 6, 10), + IntExpression(1, location = TEST_LOCATION), + location = TEST_LOCATION + ), + location = TEST_LOCATION + ), + location = TEST_LOCATION ), - label = "loop.1" + "loop.1", + location = TEST_LOCATION ) ), S( ForStatement( - init = InitExpression( - expression = null + null, + location = TEST_LOCATION ), - condition = BinaryExpression( - left = VariableExpression("a.0"), - operator = Token(TokenType.GREATER, ">", 7, 8), - right = IntExpression(0) + VariableExpression("a.0", location = TEST_LOCATION), + Token(TokenType.GREATER, ">", 7, 8, 7, 8), + IntExpression(0, location = TEST_LOCATION), + location = TEST_LOCATION ), - post = null, - body = + null, ExpressionStatement( AssignmentExpression( - lvalue = VariableExpression("a.0"), - rvalue = + VariableExpression("a.0", location = TEST_LOCATION), BinaryExpression( - left = VariableExpression("a.0"), - operator = Token(TokenType.PLUS, "+", 8, 10), - right = IntExpression(1) - ) - ) + VariableExpression("a.0", location = TEST_LOCATION), + Token(TokenType.PLUS, "+", 8, 10, 8, 10), + IntExpression(1, location = TEST_LOCATION), + location = TEST_LOCATION + ), + location = TEST_LOCATION + ), + location = TEST_LOCATION ), - label = "loop.2" + "loop.2", + location = TEST_LOCATION ) ), S( DoWhileStatement( - condition = BinaryExpression( - left = VariableExpression("a.0"), - operator = Token(TokenType.GREATER, ">", 11, 9), - right = IntExpression(0) + VariableExpression("a.0", location = TEST_LOCATION), + Token(TokenType.GREATER, ">", 11, 9, 11, 9), + IntExpression(0, location = TEST_LOCATION), + location = TEST_LOCATION ), - body = ExpressionStatement( AssignmentExpression( - lvalue = VariableExpression("a.0"), - rvalue = + VariableExpression("a.0", location = TEST_LOCATION), BinaryExpression( - left = VariableExpression("a.0"), - operator = Token(TokenType.PLUS, "+", 10, 10), - right = IntExpression(1) - ) - ) + VariableExpression("a.0", location = TEST_LOCATION), + Token(TokenType.PLUS, "+", 10, 10, 10, 10), + IntExpression(1, location = TEST_LOCATION), + location = TEST_LOCATION + ), + location = TEST_LOCATION + ), + location = TEST_LOCATION ), - label = "loop.3" + "loop.3", + location = TEST_LOCATION ) ), - S(ReturnStatement(expression = IntExpression(0))) + S(ReturnStatement(location = TEST_LOCATION, expression = IntExpression(0, location = TEST_LOCATION))) ) ) ) diff --git a/src/jsTest/kotlin/parser/LabelAnalysisTest.kt b/src/jsTest/kotlin/parser/LabelAnalysisTest.kt index 9683988..bfd7f8a 100644 --- a/src/jsTest/kotlin/parser/LabelAnalysisTest.kt +++ b/src/jsTest/kotlin/parser/LabelAnalysisTest.kt @@ -7,6 +7,9 @@ import kotlin.test.Test import kotlin.test.assertFailsWith import kotlin.test.assertTrue +// Helper constant for test locations +private val DUMMY_LOC = SourceLocation(1, 1, 1, 1) + class LabelAnalysisTest { private val labelAnalysis = LabelCollector.LabelAnalysis() @@ -15,33 +18,35 @@ class LabelAnalysisTest { // Arrange: A program with a forward jump and a backward jump. val ast: ASTNode = SimpleProgram( - functionDeclaration = - listOf( + functionDeclaration = listOf( FunctionDeclaration( name = "main", params = emptyList(), - body = - Block( - block = - listOf( - S(GotoStatement("end")), + body = Block( + items = listOf( + S(GotoStatement("end", DUMMY_LOC)), S( LabeledStatement( label = "start", - statement = ExpressionStatement(IntExpression(1)) + statement = ExpressionStatement(IntExpression(1, DUMMY_LOC), DUMMY_LOC), + location = DUMMY_LOC ) ), - S(GotoStatement("start")), + S(GotoStatement("start", DUMMY_LOC)), S( LabeledStatement( label = "end", - statement = ReturnStatement(IntExpression(0)) + statement = ReturnStatement(IntExpression(0, DUMMY_LOC), DUMMY_LOC), + location = DUMMY_LOC ) ) - ) - ) + ), + location = DUMMY_LOC + ), + location = DUMMY_LOC ) - ) + ), + location = DUMMY_LOC ) // Act & Assert: This should complete without throwing an exception. @@ -54,31 +59,33 @@ class LabelAnalysisTest { // Arrange: A program where the same label is defined twice. val ast: ASTNode = SimpleProgram( - functionDeclaration = - listOf( + functionDeclaration = listOf( FunctionDeclaration( name = "main", params = emptyList(), - body = - Block( - block = - listOf( + body = Block( + items = listOf( S( LabeledStatement( label = "my_label", - statement = NullStatement() + statement = NullStatement(DUMMY_LOC), + location = DUMMY_LOC ) ), S( LabeledStatement( label = "my_label", - statement = ReturnStatement(IntExpression(0)) + statement = ReturnStatement(IntExpression(0, DUMMY_LOC), DUMMY_LOC), + location = DUMMY_LOC ) ) - ) - ) + ), + location = DUMMY_LOC + ), + location = DUMMY_LOC ) - ) + ), + location = DUMMY_LOC ) // Act & Assert: Expect the analysis to fail with the specific exception. @@ -92,21 +99,21 @@ class LabelAnalysisTest { // Arrange: A program with a goto that targets a non-existent label. val ast: ASTNode = SimpleProgram( - functionDeclaration = - listOf( + functionDeclaration = listOf( FunctionDeclaration( name = "main", params = emptyList(), - body = - Block( - block = - listOf( - S(GotoStatement("missing_label")), - S(ReturnStatement(IntExpression(0))) - ) - ) + body = Block( + items = listOf( + S(GotoStatement("missing_label", DUMMY_LOC)), + S(ReturnStatement(IntExpression(0, DUMMY_LOC), DUMMY_LOC)) + ), + location = DUMMY_LOC + ), + location = DUMMY_LOC ) - ) + ), + location = DUMMY_LOC ) // Act & Assert: Expect the analysis to fail with the specific exception. @@ -120,32 +127,35 @@ class LabelAnalysisTest { // Arrange: A program where labels are nested inside an if statement. val ast: ASTNode = SimpleProgram( - listOf( + functionDeclaration = listOf( FunctionDeclaration( name = "main", params = emptyList(), - body = - Block( - listOf( + body = Block( + items = listOf( S( IfStatement( - condition = IntExpression(1), - then = - LabeledStatement( + condition = IntExpression(1, DUMMY_LOC), + then = LabeledStatement( label = "nested_label", - statement = ReturnStatement(IntExpression(1)) + statement = ReturnStatement(IntExpression(1, DUMMY_LOC), DUMMY_LOC), + location = DUMMY_LOC ), - _else = null + _else = null, + location = DUMMY_LOC ) ), - S(GotoStatement("nested_label")) - ) - ) + S(GotoStatement("nested_label", DUMMY_LOC)) + ), + location = DUMMY_LOC + ), + location = DUMMY_LOC ) - ) + ), + location = DUMMY_LOC ) - // Act & Assert: This should complete successfully without throwing an exception. + // Act & Assert: This should complete successfully. labelAnalysis.analyze(ast) assertTrue(true, "Analysis of nested labels should complete successfully.") } diff --git a/src/jsTest/kotlin/parser/VariableResolutionTest.kt b/src/jsTest/kotlin/parser/VariableResolutionTest.kt index 0efb97b..cc2ae57 100644 --- a/src/jsTest/kotlin/parser/VariableResolutionTest.kt +++ b/src/jsTest/kotlin/parser/VariableResolutionTest.kt @@ -7,116 +7,127 @@ import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith +// Helper constant for test locations +val VAR_TEST_LOCATION = SourceLocation(1, 1, 1, 1) + class VariableResolutionTest { + private val identifierResolution = IdentifierResolution() + @Test fun testVariableRenamingAndResolution() { - val ast: ASTNode = + val ast = SimpleProgram( - functionDeclaration = - listOf( + functionDeclaration = listOf( FunctionDeclaration( name = "main", params = emptyList(), - body = - Block( - listOf( - D(VarDecl(VariableDeclaration(name = "a", init = null))), + body = Block( + items = listOf( + D(VarDecl(VariableDeclaration("a", null, VAR_TEST_LOCATION))), S( ExpressionStatement( AssignmentExpression( - lvalue = VariableExpression("a"), - rvalue = IntExpression(1) - ) + VariableExpression("a", VAR_TEST_LOCATION), + IntExpression(1, VAR_TEST_LOCATION), + VAR_TEST_LOCATION + ), + VAR_TEST_LOCATION ) ), - S(ReturnStatement(expression = VariableExpression("a"))) - ) - ) + S(ReturnStatement(VariableExpression("a", VAR_TEST_LOCATION), VAR_TEST_LOCATION)) + ), + VAR_TEST_LOCATION + ), + VAR_TEST_LOCATION ) - ) + ), + VAR_TEST_LOCATION ) - val resolved = IdentifierResolution().analyze(ast as SimpleProgram) as SimpleProgram + // Act + val resolved = identifierResolution.analyze(ast) - val expected: ASTNode = + // Assert: Also add the dummy location to the expected output + val expected = SimpleProgram( - functionDeclaration = - listOf( + functionDeclaration = listOf( FunctionDeclaration( name = "main", params = emptyList(), - body = - Block( - block = - listOf( - D(VarDecl(VariableDeclaration(name = "a.0", init = null))), + body = Block( + items = listOf( + D(VarDecl(VariableDeclaration("a.0", null, VAR_TEST_LOCATION))), S( ExpressionStatement( AssignmentExpression( - lvalue = VariableExpression("a.0"), - rvalue = IntExpression(1) - ) + VariableExpression("a.0", VAR_TEST_LOCATION), + IntExpression(1, VAR_TEST_LOCATION), + VAR_TEST_LOCATION + ), + VAR_TEST_LOCATION ) ), - S(ReturnStatement(expression = VariableExpression("a.0"))) - ) - ) + S(ReturnStatement(VariableExpression("a.0", VAR_TEST_LOCATION), VAR_TEST_LOCATION)) + ), + VAR_TEST_LOCATION + ), + VAR_TEST_LOCATION ) - ) + ), + VAR_TEST_LOCATION ) + assertEquals(expected, resolved) } @Test fun testDuplicateDeclarationThrows() { - val ast: ASTNode = + val ast = SimpleProgram( - functionDeclaration = - listOf( + functionDeclaration = listOf( FunctionDeclaration( - name = "main", - params = emptyList(), - body = + "main", + emptyList(), Block( - block = listOf( - D(VarDecl(VariableDeclaration(name = "a", init = null))), - D(VarDecl(VariableDeclaration(name = "a", init = null))) - ) - ) + D(VarDecl(VariableDeclaration("a", null, VAR_TEST_LOCATION))), + D(VarDecl(VariableDeclaration("a", null, VAR_TEST_LOCATION))) + ), + VAR_TEST_LOCATION + ), + VAR_TEST_LOCATION ) - ) + ), + VAR_TEST_LOCATION ) - // Act & Assert assertFailsWith { - IdentifierResolution().analyze(ast as SimpleProgram) + identifierResolution.analyze(ast) } } @Test fun testUndeclaredVariableThrows() { - val ast: ASTNode = + val ast = SimpleProgram( - functionDeclaration = - listOf( + functionDeclaration = listOf( FunctionDeclaration( - name = "main", - params = emptyList(), - body = + "main", + emptyList(), Block( - block = listOf( - S(ReturnStatement(expression = VariableExpression("x"))) - ) - ) + S(ReturnStatement(VariableExpression("x", VAR_TEST_LOCATION), VAR_TEST_LOCATION)) + ), + VAR_TEST_LOCATION + ), + VAR_TEST_LOCATION ) - ) + ), + VAR_TEST_LOCATION ) - // Act & Assert assertFailsWith { - IdentifierResolution().analyze(ast as SimpleProgram) + identifierResolution.analyze(ast) } } }