diff --git a/.idea/misc.xml b/.idea/misc.xml
index 2c3c021..06d5bb2 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -4,7 +4,7 @@
-
+
\ No newline at end of file
diff --git a/README.md b/README.md
index 3588036..966eb68 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ parsing, and multiple other phases.
- ``jsMain``: contains the core compiler logic
- ``jsTest``: test directory for the main logic
-These two packages are compiled to JS and used for production
+These `js` packages are compiled to JavaScript and used for production.
- ``jvmMain`` and ``jvmTest``: generated automatically through building the project. These packages are copied versions of ``jsMain`` and ``jsTest`` without js-specific code and are used only to generate test coverage reports, since Kover (the plugin we use to generate test reports) only supports JVM-compatible Kotlin code.
@@ -39,4 +39,21 @@ To generate a test coverage report,
``./gradlew koverHtmlReport``
- _\* These two commands are also part of the build command_
\ No newline at end of file
+_All of these commands are also part of the build command_
+
+More test cases are found in the test suite of the book "Writing a C Compiler" by Nora Sandler. The test suite is also included in this project and can be run by following the steps below.
+
+
+1. Build the compiler
+
+``./gradlew build``
+
+2. Create a jar file
+
+``./gradlew createCompilerJar``
+
+3. Run the test script
+
+``cd src/resources/write_a_c_compiler-tests && ./test_compiler_kotlin.sh ../../../build/libs/compiler-1.0-SNAPSHOT.jar && cd ../../..``
+
+For more information, see the [test suite's README](https://github.com/nlsandler/write_a_c_compiler/blob/master/README.md).
\ No newline at end of file
diff --git a/build.gradle.kts b/build.gradle.kts
index 3bb7bf3..bb4efe1 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -130,3 +130,77 @@ tasks.named("build") {
tasks.named("koverHtmlReport") {
dependsOn("jsTest", "jvmTest")
}
+
+// Task to create the main class.
+// Generates a small Java entry point (compiler.CompilerMain) into the task's temporary directory
+// and compiles it against the JVM output of the Kotlin sources, so that createCompilerJar can
+// package it as the Main-Class of an executable JAR.
+tasks.register<JavaCompile>("compileMainClass") {
+    group = "build"
+    description = "Compiles the main class for the JAR"
+    val tempDir = temporaryDir
+    val mainClassFile = File(tempDir, "CompilerMain.java")
+    // Create the main class file during configuration, so the source file already exists
+    // when the `source` file tree below is evaluated.
+    val mainClassContent =
+        """
+        package compiler;
+
+        import java.io.File;
+        import java.nio.charset.StandardCharsets;
+        import java.nio.file.Files;
+
+        public class CompilerMain {
+            public static void main(String[] args) {
+                if (args.length == 0) {
+                    System.out.println("Usage: java -jar compiler.jar <source-file>");
+                    System.exit(1);
+                }
+
+                File inputFile = new File(args[0]);
+                if (!inputFile.exists()) {
+                    System.out.println("Error: File " + args[0] + " does not exist");
+                    System.exit(1);
+                }
+
+                try {
+                    // Decode explicitly as UTF-8 instead of relying on the platform default charset.
+                    String sourceCode = new String(Files.readAllBytes(inputFile.toPath()), StandardCharsets.UTF_8);
+                    CompilerWorkflow.Companion.fullCompile(sourceCode);
+                    System.exit(0);
+                } catch (Exception e) {
+                    System.err.println("Exception: " + e.getMessage());
+                    e.printStackTrace();
+                    System.exit(1);
+                }
+            }
+        }
+        """.trimIndent()
+
+    mainClassFile.writeText(mainClassContent)
+
+    source = fileTree(tempDir) { include("**/*.java") }
+    // layout.buildDirectory replaces the deprecated Project.buildDir; it resolves to the same directory.
+    destinationDirectory.set(layout.buildDirectory.dir("classes/java/main"))
+    classpath = kotlin.jvm().compilations["main"].runtimeDependencyFiles + files(kotlin.jvm().compilations["main"].output)
+
+    dependsOn("jvmMainClasses")
+}
+
+// Create executable JAR for JVM target.
+// Bundles the Kotlin JVM output, the generated compiler.CompilerMain class and all runtime
+// dependencies into a single JAR whose manifest points at compiler.CompilerMain.
+tasks.register<Jar>("createCompilerJar") {
+    group = "build"
+    description = "Creates an executable JAR for the JVM target"
+    from(kotlin.jvm().compilations["main"].output)
+    // layout.buildDirectory replaces the deprecated Project.buildDir; this is the directory
+    // compileMainClass writes CompilerMain.class into.
+    from(layout.buildDirectory.dir("classes/java/main")) {
+        include("compiler/CompilerMain.class")
+    }
+    archiveBaseName.set("compiler")
+    archiveClassifier.set("")
+    manifest {
+        attributes["Main-Class"] = "compiler.CompilerMain"
+    }
+    dependsOn("jvmMainClasses", "compileMainClass")
+    // Include all dependencies in the JAR (directories as-is, archives unpacked)
+    from(
+        kotlin
+            .jvm()
+            .compilations["main"]
+            .runtimeDependencyFiles
+            .map { if (it.isDirectory) it else zipTree(it) }
+    )
+    // The same entry can come from several dependencies; keep the first one encountered.
+    duplicatesStrategy = DuplicatesStrategy.EXCLUDE
+}
diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar
index 249e583..1b33c55 100644
Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ
diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties
index 623ed80..ca025c8 100644
--- a/gradle/wrapper/gradle-wrapper.properties
+++ b/gradle/wrapper/gradle-wrapper.properties
@@ -1,6 +1,7 @@
-#Mon Jul 21 19:49:00 CEST 2025
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.14-bin.zip
+networkTimeout=10000
+validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
diff --git a/gradlew b/gradlew
index 1b6c787..23d15a9 100755
--- a/gradlew
+++ b/gradlew
@@ -15,6 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+# SPDX-License-Identifier: Apache-2.0
+#
##############################################################################
#
@@ -55,7 +57,7 @@
# Darwin, MinGW, and NonStop.
#
# (3) This script is generated from the Groovy template
-# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
+# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# within the Gradle project.
#
# You can find Gradle at https://github.com/gradle/gradle/.
@@ -80,13 +82,11 @@ do
esac
done
-APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
-
-APP_NAME="Gradle"
+# This is normally unused
+# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
-
-# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
-DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036)
+APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s\n' "$PWD" ) || exit
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum
@@ -114,7 +114,7 @@ case "$( uname )" in #(
NONSTOP* ) nonstop=true ;;
esac
-CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+CLASSPATH="\\\"\\\""
# Determine the Java command to use to start the JVM.
@@ -133,22 +133,29 @@ location of your Java installation."
fi
else
JAVACMD=java
- which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+ if ! command -v java >/dev/null 2>&1
+ then
+ die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
+ fi
fi
# Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #(
max*)
+ # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
+ # shellcheck disable=SC2039,SC3045
MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit"
esac
case $MAX_FD in #(
'' | soft) :;; #(
*)
+ # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
+ # shellcheck disable=SC2039,SC3045
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
esac
@@ -193,18 +200,28 @@ if "$cygwin" || "$msys" ; then
done
fi
-# Collect all arguments for the java command;
-# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
-# shell script including quotes and variable substitutions, so put them in
-# double quotes to make sure that they get re-expanded; and
-# * put everything else in single quotes, so that it's not re-expanded.
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+
+# Collect all arguments for the java command:
+# * DEFAULT_JVM_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments,
+# and any embedded shellness will be escaped.
+# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be
+# treated as '${Hostname}' itself on the command line.
set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \
-classpath "$CLASSPATH" \
- org.gradle.wrapper.GradleWrapperMain \
+ -jar "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" \
"$@"
+# Stop when "xargs" is not available.
+if ! command -v xargs >/dev/null 2>&1
+then
+ die "xargs is not available"
+fi
+
# Use "xargs" to parse quoted args.
#
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
diff --git a/gradlew.bat b/gradlew.bat
index ac1b06f..db3a6ac 100644
--- a/gradlew.bat
+++ b/gradlew.bat
@@ -1,89 +1,94 @@
-@rem
-@rem Copyright 2015 the original author or authors.
-@rem
-@rem Licensed under the Apache License, Version 2.0 (the "License");
-@rem you may not use this file except in compliance with the License.
-@rem You may obtain a copy of the License at
-@rem
-@rem https://www.apache.org/licenses/LICENSE-2.0
-@rem
-@rem Unless required by applicable law or agreed to in writing, software
-@rem distributed under the License is distributed on an "AS IS" BASIS,
-@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-@rem See the License for the specific language governing permissions and
-@rem limitations under the License.
-@rem
-
-@if "%DEBUG%" == "" @echo off
-@rem ##########################################################################
-@rem
-@rem Gradle startup script for Windows
-@rem
-@rem ##########################################################################
-
-@rem Set local scope for the variables with windows NT shell
-if "%OS%"=="Windows_NT" setlocal
-
-set DIRNAME=%~dp0
-if "%DIRNAME%" == "" set DIRNAME=.
-set APP_BASE_NAME=%~n0
-set APP_HOME=%DIRNAME%
-
-@rem Resolve any "." and ".." in APP_HOME to make it shorter.
-for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
-
-@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
-set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
-
-@rem Find java.exe
-if defined JAVA_HOME goto findJavaFromJavaHome
-
-set JAVA_EXE=java.exe
-%JAVA_EXE% -version >NUL 2>&1
-if "%ERRORLEVEL%" == "0" goto execute
-
-echo.
-echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
-echo.
-echo Please set the JAVA_HOME variable in your environment to match the
-echo location of your Java installation.
-
-goto fail
-
-:findJavaFromJavaHome
-set JAVA_HOME=%JAVA_HOME:"=%
-set JAVA_EXE=%JAVA_HOME%/bin/java.exe
-
-if exist "%JAVA_EXE%" goto execute
-
-echo.
-echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
-echo.
-echo Please set the JAVA_HOME variable in your environment to match the
-echo location of your Java installation.
-
-goto fail
-
-:execute
-@rem Setup the command line
-
-set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
-
-
-@rem Execute Gradle
-"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
-
-:end
-@rem End local scope for the variables with windows NT shell
-if "%ERRORLEVEL%"=="0" goto mainEnd
-
-:fail
-rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
-rem the _cmd.exe /c_ return code!
-if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
-exit /b 1
-
-:mainEnd
-if "%OS%"=="Windows_NT" endlocal
-
-:omega
+@rem
+@rem Copyright 2015 the original author or authors.
+@rem
+@rem Licensed under the Apache License, Version 2.0 (the "License");
+@rem you may not use this file except in compliance with the License.
+@rem You may obtain a copy of the License at
+@rem
+@rem https://www.apache.org/licenses/LICENSE-2.0
+@rem
+@rem Unless required by applicable law or agreed to in writing, software
+@rem distributed under the License is distributed on an "AS IS" BASIS,
+@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+@rem See the License for the specific language governing permissions and
+@rem limitations under the License.
+@rem
+@rem SPDX-License-Identifier: Apache-2.0
+@rem
+
+@if "%DEBUG%"=="" @echo off
+@rem ##########################################################################
+@rem
+@rem Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+set DIRNAME=%~dp0
+if "%DIRNAME%"=="" set DIRNAME=.
+@rem This is normally unused
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Resolve any "." and ".." in APP_HOME to make it shorter.
+for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if %ERRORLEVEL% equ 0 goto execute
+
+echo. 1>&2
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2
+echo. 1>&2
+echo Please set the JAVA_HOME variable in your environment to match the 1>&2
+echo location of your Java installation. 1>&2
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto execute
+
+echo. 1>&2
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2
+echo. 1>&2
+echo Please set the JAVA_HOME variable in your environment to match the 1>&2
+echo location of your Java installation. 1>&2
+
+goto fail
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=
+
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %*
+
+:end
+@rem End local scope for the variables with windows NT shell
+if %ERRORLEVEL% equ 0 goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+set EXIT_CODE=%ERRORLEVEL%
+if %EXIT_CODE% equ 0 set EXIT_CODE=1
+if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
+exit /b %EXIT_CODE%
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
diff --git a/src/jsMain/kotlin/CompilerWorkflow.kt b/src/jsMain/kotlin/CompilerWorkflow.kt
index 2d6b66a..7808619 100644
--- a/src/jsMain/kotlin/CompilerWorkflow.kt
+++ b/src/jsMain/kotlin/CompilerWorkflow.kt
@@ -3,12 +3,15 @@ package compiler
import assembly.AsmConstruct
import assembly.InstructionFixer
import assembly.PseudoEliminator
+import assembly.TackyToAsm
import lexer.Lexer
import lexer.Token
import optimizations.ConstantFolding
import optimizations.ControlFlowGraph
+import optimizations.CopyPropagation
import optimizations.DeadStoreElimination
import optimizations.OptimizationType
+import optimizations.UnreachableCodeElimination
import parser.ASTNode
import parser.Parser
import parser.SimpleProgram
@@ -19,12 +22,12 @@ import semanticAnalysis.TypeChecker
import tacky.TackyConstruct
import tacky.TackyGenVisitor
import tacky.TackyProgram
-import tacky.TackyToAsm
enum class CompilerStage {
LEXER,
PARSER,
TACKY,
+ OPTIMIZATIONS,
ASSEMBLY
}
@@ -41,17 +44,30 @@ sealed class CompilerWorkflow {
private val pseudoEliminator = PseudoEliminator()
private val constantFolding = ConstantFolding()
private val deadStoreElimination = DeadStoreElimination()
+ private val copyPropagation = CopyPropagation()
+ private val unreachableCodeElimination = UnreachableCodeElimination()
fun fullCompile(code: String): Map {
val tokens = take(code)
val ast = take(tokens)
val tacky = take(ast)
- val asm = take(tacky as TackyProgram)
+ val optimizedTacky =
+ take(
+ tacky as TackyProgram,
+ listOf(
+ OptimizationType.B_CONSTANT_FOLDING,
+ OptimizationType.D_DEAD_STORE_ELIMINATION,
+ OptimizationType.C_UNREACHABLE_CODE_ELIMINATION,
+ OptimizationType.A_COPY_PROPAGATION
+ )
+ )
+ val asm = take(optimizedTacky)
return mapOf(
CompilerStage.LEXER to tokens,
CompilerStage.PARSER to ast,
CompilerStage.TACKY to tacky,
+ CompilerStage.OPTIMIZATIONS to optimizedTacky,
CompilerStage.ASSEMBLY to asm
)
}
@@ -75,17 +91,31 @@ sealed class CompilerWorkflow {
return tacky
}
-        fun take(tacky: TackyProgram, optimizations: Set<OptimizationType>): TackyProgram {
+        /**
+         * Runs the requested TACKY optimizations on a copy of [tackyProgram].
+         *
+         * For every function, a control-flow graph is rebuilt from the current body, each requested
+         * pass is applied in the enum's natural (`sorted()`) order, and the result is written back.
+         * This repeats until a round leaves the body unchanged or produces an empty body.
+         *
+         * @param tackyProgram the program to optimize; it is deep-copied and not modified.
+         * @param optimizations the passes to run; the order of this list does not affect the order of application.
+         * @return the optimized copy.
+         */
+        fun take(
+            tackyProgram: TackyProgram,
+            optimizations: List<OptimizationType>
+        ): TackyProgram {
+            val tacky = tackyProgram.deepCopy()
             tacky.functions.forEach {
-                var cfg = ControlFlowGraph().construct(it.name, it.body)
-                for (optimization in optimizations) {
-                    if (optimization == OptimizationType.CONSTANT_FOLDING) {
-                        cfg = constantFolding.apply(cfg)
-                    } else if (optimization == OptimizationType.DEAD_STORE_ELIMINATION) {
-                        cfg = deadStoreElimination.apply(cfg)
+                while (true) {
+                    var cfg = ControlFlowGraph().construct(it.name, it.body)
+                    for (optimization in optimizations.sorted()) {
+                        cfg =
+                            when (optimization) {
+                                OptimizationType.B_CONSTANT_FOLDING -> constantFolding.apply(cfg)
+                                OptimizationType.D_DEAD_STORE_ELIMINATION -> deadStoreElimination.apply(cfg)
+                                OptimizationType.A_COPY_PROPAGATION -> copyPropagation.apply(cfg)
+                                // Every remaining value is handled as unreachable-code elimination.
+                                else -> unreachableCodeElimination.apply(cfg)
+                            }
+                    }
+                    val optimizedBody = cfg.toInstructions()
+                    // Compare with the previous body BEFORE overwriting it: checking after the
+                    // assignment is always true and would stop after a single round.
+                    // NOTE(review): relies on structural equality of the instruction list - confirm
+                    // the instruction classes implement equals(), otherwise the loop may not converge.
+                    val reachedFixedPoint = optimizedBody == it.body || optimizedBody.isEmpty()
+                    it.body = optimizedBody
+                    if (reachedFixedPoint) {
+                        break
                     }
                 }
-                it.body = cfg.toInstructions()
             }
             return tacky
         }
diff --git a/src/jsMain/kotlin/assembly/CodeEmitter.kt b/src/jsMain/kotlin/assembly/CodeEmitter.kt
index 193fff9..3e9b2b0 100644
--- a/src/jsMain/kotlin/assembly/CodeEmitter.kt
+++ b/src/jsMain/kotlin/assembly/CodeEmitter.kt
@@ -21,6 +21,7 @@ class CodeEmitter {
private fun emitFunction(function: AsmFunction): String {
val functionName = formatLabel(function.name)
val bodyAsm = function.body.joinToString("\n") { emitInstruction(it) }
+ val endsWithRet = function.body.lastOrNull() is Ret
return buildString {
appendLine(" .globl $functionName")
@@ -30,9 +31,11 @@ class CodeEmitter {
if (bodyAsm.isNotEmpty()) {
appendLine(bodyAsm)
}
- appendLine(" mov rsp, rbp")
- appendLine(" pop rbp")
- append(" ret")
+ if (!endsWithRet) {
+ appendLine(" mov rsp, rbp")
+ appendLine(" pop rbp")
+ append(" ret")
+ }
}
}
@@ -68,7 +71,7 @@ class CodeEmitter {
RawInstruction("${indent}set${instruction.condition.text} $destOperand", instruction.sourceId)
}
- is Ret -> RawInstruction("", instruction.sourceId)
+ is Ret -> RawInstruction("${indent}mov rsp, rbp\n${indent}pop rbp\n${indent}ret", instruction.sourceId)
}
}
@@ -95,7 +98,7 @@ class CodeEmitter {
val destOperand = emitOperand(instruction.dest, size = OperandSize.BYTE)
"${indent}set${instruction.condition.text} $destOperand"
}
- is Ret -> ""
+ is Ret -> "${indent}mov rsp, rbp\n${indent}pop rbp\n${indent}ret"
}
}
diff --git a/src/jsMain/kotlin/tacky/TackyToAsm.kt b/src/jsMain/kotlin/assembly/TackyToAsm.kt
similarity index 92%
rename from src/jsMain/kotlin/tacky/TackyToAsm.kt
rename to src/jsMain/kotlin/assembly/TackyToAsm.kt
index 8871554..cae75f5 100644
--- a/src/jsMain/kotlin/tacky/TackyToAsm.kt
+++ b/src/jsMain/kotlin/assembly/TackyToAsm.kt
@@ -1,32 +1,22 @@
-package tacky
-
-import assembly.AllocateStack
-import assembly.AsmBinary
-import assembly.AsmBinaryOp
-import assembly.AsmFunction
-import assembly.AsmProgram
-import assembly.AsmUnary
-import assembly.AsmUnaryOp
-import assembly.Call
-import assembly.Cdq
-import assembly.Cmp
-import assembly.ConditionCode
-import assembly.DeAllocateStack
-import assembly.HardwareRegister
-import assembly.Idiv
-import assembly.Imm
-import assembly.Instruction
-import assembly.Jmp
-import assembly.JmpCC
-import assembly.Label
-import assembly.Mov
-import assembly.Operand
-import assembly.Pseudo
-import assembly.Push
-import assembly.Register
-import assembly.Ret
-import assembly.SetCC
-import assembly.Stack
+package assembly
+
+import tacky.JumpIfNotZero
+import tacky.JumpIfZero
+import tacky.TackyBinary
+import tacky.TackyBinaryOP
+import tacky.TackyConstant
+import tacky.TackyCopy
+import tacky.TackyFunCall
+import tacky.TackyFunction
+import tacky.TackyInstruction
+import tacky.TackyJump
+import tacky.TackyLabel
+import tacky.TackyProgram
+import tacky.TackyRet
+import tacky.TackyUnary
+import tacky.TackyUnaryOP
+import tacky.TackyVal
+import tacky.TackyVar
class TackyToAsm {
fun convert(tackyProgram: TackyProgram): AsmProgram {
@@ -42,7 +32,10 @@ class TackyToAsm {
return AsmFunction(tackyFunc.name, paramSetupInstructions + bodyInstructions)
}
- private fun generateParamSetup(params: List, sourceId: String): List {
+ private fun generateParamSetup(
+ params: List,
+ sourceId: String
+ ): List {
val instructions = mutableListOf()
val argRegisters =
listOf(
@@ -197,6 +190,12 @@ class TackyToAsm {
instructions.add(AllocateStack(stackPadding, tackyInstr.sourceId))
}
+ // Pass arguments in registers
+ registerArgs.forEachIndexed { index, arg ->
+ val asmArg = convertVal(arg)
+ instructions.add(Mov(asmArg, Register(argRegisters[index]), tackyInstr.sourceId))
+ }
+
// Pass arguments on the stack in reverse order
stackArgs.asReversed().forEach { arg ->
val asmArg = convertVal(arg)
@@ -208,12 +207,6 @@ class TackyToAsm {
}
}
- // Pass arguments in registers
- registerArgs.forEachIndexed { index, arg ->
- val asmArg = convertVal(arg)
- instructions.add(Mov(asmArg, Register(argRegisters[index]), tackyInstr.sourceId))
- }
-
instructions.add(Call(tackyInstr.funName, tackyInstr.sourceId))
// Clean up stack
diff --git a/src/jsMain/kotlin/exceptions/CompilationExceptions.kt b/src/jsMain/kotlin/exceptions/CompilationExceptions.kt
index cf39b54..8cf05f8 100644
--- a/src/jsMain/kotlin/exceptions/CompilationExceptions.kt
+++ b/src/jsMain/kotlin/exceptions/CompilationExceptions.kt
@@ -19,12 +19,21 @@ sealed class CompilationException(
}
}
-class LexicalException(
+// Lexer
+class InvalidCharacterException(
character: Char,
line: Int? = null,
column: Int? = null
-) : CompilationException("LexicalException(Invalid character '$character')", line, column)
+) : CompilationException("InvalidCharacterException('$character' is not a valid character)", line, column)
+class UnexpectedCharacterException(
+ expected: String,
+ actual: String,
+ line: Int? = null,
+ column: Int? = null
+) : CompilationException("UnexpectedCharacterException(Expected '$expected', got '$actual')", line, column)
+
+// Parser
class UnexpectedTokenException(
val expected: String,
val actual: String,
@@ -71,12 +80,6 @@ class InvalidStatementException(
column: Int? = null
) : CompilationException("InvalidStatementException($message)", line, column)
-class TackyException(
- operator: String,
- line: Int? = null,
- column: Int? = null
-) : CompilationException("TackyException(Invalid operator: $operator)", line, column)
-
class NestedFunctionException(
line: Int? = null,
column: Int? = null
@@ -92,19 +95,23 @@ class IncompatibleFuncDeclarationException(
name: String,
line: Int? = null,
column: Int? = null
-) : CompilationException("Function '$name' redeclared with a different number of parameters.", line, column)
+) : CompilationException(
+ "IncompatibleFuncDeclarationException(Function '$name' redeclared with a different number of parameters.)",
+ line,
+ column
+)
-class NotFunctionException(
+class NotAFunctionException(
name: String,
line: Int? = null,
column: Int? = null
-) : CompilationException("Cannot call '$name' because it is not a function.", line, column)
+) : CompilationException("NotAFunctionException(Cannot call '$name' because it is not a function.)", line, column)
-class NotVariableException(
+class NotAVariableException(
name: String,
line: Int? = null,
column: Int? = null
-) : CompilationException("Cannot use function '$name' as a variable.", line, column)
+) : CompilationException("NotAVariableException(Cannot use function '$name' as a variable.)", line, column)
class ArgumentCountException(
name: String,
@@ -112,10 +119,25 @@ class ArgumentCountException(
actual: Int,
line: Int? = null,
column: Int? = null
-) : CompilationException("Wrong number of arguments for function '$name'. Expected $expected, got $actual.", line, column)
+) : CompilationException(
+ "ArgumentCountException(Wrong number of arguments for function '$name'. Expected $expected, got $actual.)",
+ line,
+ column
+)
class IllegalStateException(
name: String,
line: Int? = null,
column: Int? = null
-) : CompilationException("Internal error: Variable '$name' should have been caught by IdentifierResolution.")
+) : CompilationException(
+ "IllegalStateException(Internal error: Variable '$name' should have been caught by IdentifierResolution.)",
+ line,
+ column
+)
+
+// TACKY
+class TackyException(
+ operator: String,
+ line: Int? = null,
+ column: Int? = null
+) : CompilationException("TackyException(Invalid operator: $operator)", line, column)
diff --git a/src/jsMain/kotlin/export/ASTExport.kt b/src/jsMain/kotlin/export/ASTExport.kt
index 90f38e2..bcb164a 100644
--- a/src/jsMain/kotlin/export/ASTExport.kt
+++ b/src/jsMain/kotlin/export/ASTExport.kt
@@ -4,6 +4,7 @@ import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
+import parser.ASTVisitor
import parser.AssignmentExpression
import parser.BinaryExpression
import parser.Block
@@ -32,7 +33,6 @@ import parser.UnaryExpression
import parser.VarDecl
import parser.VariableDeclaration
import parser.VariableExpression
-import parser.Visitor
import parser.WhileStatement
fun createJsonNode(
@@ -43,22 +43,24 @@ fun createJsonNode(
location: parser.SourceLocation? = null,
id: String? = null
): JsonObject {
- val nodeMap = mutableMapOf(
- "type" to JsonPrimitive(type),
- "label" to JsonPrimitive(label),
- "children" to children,
- "edgeLabels" to JsonPrimitive(edgeLabels)
- )
+ val nodeMap =
+ mutableMapOf(
+ "type" to JsonPrimitive(type),
+ "label" to JsonPrimitive(label),
+ "children" to children,
+ "edgeLabels" to JsonPrimitive(edgeLabels)
+ )
location?.let {
- nodeMap["location"] = JsonObject(
- mapOf(
- "startLine" to JsonPrimitive(it.startLine),
- "startCol" to JsonPrimitive(it.startCol),
- "endLine" to JsonPrimitive(it.endLine),
- "endCol" to JsonPrimitive(it.endCol)
+ nodeMap["location"] =
+ JsonObject(
+ mapOf(
+ "startLine" to JsonPrimitive(it.startLine),
+ "startCol" to JsonPrimitive(it.startCol),
+ "endLine" to JsonPrimitive(it.endLine),
+ "endCol" to JsonPrimitive(it.endCol)
+ )
)
- )
}
id?.let {
@@ -80,7 +82,7 @@ enum class NodeType {
Declaration
}
-class ASTExport : Visitor {
+class ASTExport : ASTVisitor {
override fun visit(node: SimpleProgram): JsonObject {
val decls = JsonArray(node.functionDeclaration.map { it.accept(this) })
return createJsonNode(NodeType.Program.name, "Program", JsonObject(mapOf("declarations" to decls)), false, node.location, node.id)
@@ -96,11 +98,14 @@ class ASTExport : Visitor {
return createJsonNode(NodeType.Statement.name, "ExpressionStatement", children, false, node.location, node.id)
}
- override fun visit(node: NullStatement): JsonObject = createJsonNode(NodeType.Statement.name, "NullStatement", JsonObject(emptyMap()), false, node.location, node.id)
+ override fun visit(node: NullStatement): JsonObject =
+ createJsonNode(NodeType.Statement.name, "NullStatement", JsonObject(emptyMap()), false, node.location, node.id)
- override fun visit(node: BreakStatement): JsonObject = createJsonNode(NodeType.Statement.name, "BreakStatement", JsonObject(emptyMap()), false, node.location, node.id)
+ override fun visit(node: BreakStatement): JsonObject =
+ createJsonNode(NodeType.Statement.name, "BreakStatement", JsonObject(emptyMap()), false, node.location, node.id)
- override fun visit(node: ContinueStatement): JsonObject = createJsonNode(NodeType.Statement.name, "continue", JsonObject(emptyMap()), false, node.location, node.id)
+ override fun visit(node: ContinueStatement): JsonObject =
+ createJsonNode(NodeType.Statement.name, "continue", JsonObject(emptyMap()), false, node.location, node.id)
override fun visit(node: WhileStatement): JsonObject {
val children =
@@ -126,7 +131,7 @@ class ASTExport : Visitor {
override fun visit(node: ForStatement): JsonObject {
val childrenMap =
- mutableMapOf(
+ mutableMapOf(
"init" to node.init.accept(this)
)
node.condition?.let { childrenMap["cond"] = it.accept(this) }
@@ -142,7 +147,7 @@ class ASTExport : Visitor {
}
override fun visit(node: InitExpression): JsonObject {
- val childrenMap = mutableMapOf()
+ val childrenMap = mutableMapOf()
node.expression?.let { childrenMap["expression"] = it.accept(this) }
return createJsonNode(NodeType.Expression.name, "Expression", JsonObject(childrenMap), false, node.location, node.id)
}
@@ -174,7 +179,13 @@ class ASTExport : Visitor {
"expression" to node.expression.accept(this)
)
)
- return createJsonNode(NodeType.Expression.name, "UnaryExpression(${node.operator.type})", children, location = node.location, id = node.id)
+ return createJsonNode(
+ NodeType.Expression.name,
+ "UnaryExpression(${node.operator.type})",
+ children,
+ location = node.location,
+ id = node.id
+ )
}
override fun visit(node: BinaryExpression): JsonObject {
@@ -185,10 +196,18 @@ class ASTExport : Visitor {
"right" to node.right.accept(this)
)
)
- return createJsonNode(NodeType.Expression.name, "BinaryExpression(${node.operator.type})", children, edgeLabels = true, location = node.location, id = node.id)
+ return createJsonNode(
+ NodeType.Expression.name,
+ "BinaryExpression(${node.operator.type})",
+ children,
+ edgeLabels = true,
+ location = node.location,
+ id = node.id
+ )
}
- override fun visit(node: IntExpression): JsonObject = createJsonNode(NodeType.Expression.name, "Int(${node.value})", JsonObject(emptyMap()), false, node.location, node.id)
+ override fun visit(node: IntExpression): JsonObject =
+ createJsonNode(NodeType.Expression.name, "Int(${node.value})", JsonObject(emptyMap()), false, node.location, node.id)
override fun visit(node: IfStatement): JsonObject {
val childrenMap =
@@ -204,7 +223,7 @@ class ASTExport : Visitor {
val children =
JsonObject(
mapOf(
- "cond" to node.codition.accept(this),
+ "cond" to node.condition.accept(this),
"then" to node.thenExpression.accept(this),
"else" to node.elseExpression.accept(this)
)
@@ -242,16 +261,24 @@ class ASTExport : Visitor {
override fun visit(node: VariableDeclaration): JsonObject {
val childrenMap = mutableMapOf("name" to JsonPrimitive(node.name))
node.init?.let { childrenMap["initializer"] = it.accept(this) }
- return createJsonNode(NodeType.Declaration.name, "Declaration(${node.name})", JsonObject(childrenMap), false, node.location, node.id)
+ return createJsonNode(
+ NodeType.Declaration.name,
+ "Declaration(${node.name})",
+ JsonObject(childrenMap),
+ false,
+ node.location,
+ node.id
+ )
}
override fun visit(node: S): JsonObject = node.statement.accept(this)
- override fun visit(node: D): JsonObject = when (node.declaration) {
- is VarDecl -> node.declaration.accept(this)
- is FunDecl -> node.declaration.accept(this)
- is VariableDeclaration -> node.declaration.accept(this)
- }
+ override fun visit(node: D): JsonObject =
+ when (node.declaration) {
+ is VarDecl -> node.declaration.accept(this)
+ is FunDecl -> node.declaration.accept(this)
+ is VariableDeclaration -> node.declaration.accept(this)
+ }
override fun visit(node: Block): JsonObject {
val blockItems = node.items.map { it.accept(this) }
diff --git a/src/jsMain/kotlin/export/CompilationOutput.kt b/src/jsMain/kotlin/export/CompilationOutput.kt
index a796299..9843a90 100644
--- a/src/jsMain/kotlin/export/CompilationOutput.kt
+++ b/src/jsMain/kotlin/export/CompilationOutput.kt
@@ -44,8 +44,19 @@ data class TackyOutput(
override val stage: String = CompilerStage.TACKY.name.lowercase(),
val tacky: String? = null,
val tackyPretty: String? = null,
+ override val errors: Array,
+ val sourceLocation: SourceLocationInfo? = null
+) : CompilationOutput()
+
+@OptIn(ExperimentalJsExport::class)
+@JsExport
+@Serializable
+@SerialName("OptimizationOutput")
+data class OptimizationOutput(
+ override val stage: String = CompilerStage.OPTIMIZATIONS.name.lowercase(),
val precomputedCFGs: String = "",
- val optimizations: Array = arrayOf("CONSTANT_FOLDING", "DEAD_STORE_ELIMINATION", "COPY_PROPAGATION", "UNREACHABLE_CODE_ELIMINATION"),
+ val optimizations: Array =
+ arrayOf("CONSTANT_FOLDING", "DEAD_STORE_ELIMINATION", "COPY_PROPAGATION", "UNREACHABLE_CODE_ELIMINATION"),
val functionNames: Array = emptyArray(),
override val errors: Array,
val sourceLocation: SourceLocationInfo? = null
@@ -59,6 +70,7 @@ data class AssemblyOutput(
override val stage: String = CompilerStage.ASSEMBLY.name.lowercase(),
val assembly: String? = null,
val rawAssembly: String? = null,
+ val precomputedAssembly: String = "",
override val errors: Array,
val sourceLocation: SourceLocationInfo? = null
) : CompilationOutput()
diff --git a/src/jsMain/kotlin/export/CompilerExport.kt b/src/jsMain/kotlin/export/CompilerExport.kt
index dfd287d..4ba351a 100644
--- a/src/jsMain/kotlin/export/CompilerExport.kt
+++ b/src/jsMain/kotlin/export/CompilerExport.kt
@@ -38,7 +38,8 @@ data class CFGEdge(
data class CFGExport(
val functionName: String,
val nodes: List,
- val edges: List
+ val edges: List,
+ val instructionCount: Int
)
@Serializable
@@ -48,10 +49,15 @@ data class CFGEntry(
val cfg: String
)
+@Serializable
+data class AssemblyEntry(
+ val optimizations: List,
+ val asmCode: String
+)
+
@OptIn(ExperimentalJsExport::class)
@JsExport
class CompilerExport {
-
private fun calculateSourceLocationInfo(code: String): SourceLocationInfo {
val lines = code.split('\n')
val totalLines = lines.size
@@ -96,21 +102,39 @@ class CompilerExport {
outputs.add(
TackyOutput(
tackyPretty = tackyProgram.toPseudoCode(),
- functionNames = tackyProgram.functions.map { it.name }.toTypedArray(),
- precomputedCFGs = precomputeAllCFGs(tackyProgram),
errors = emptyArray(),
tacky = Json.encodeToString(tackyProgram),
sourceLocation = sourceLocationInfo
)
)
- val asm = CompilerWorkflow.take(tacky as TackyProgram)
+ val cfgs = precomputeAllCFGs(tackyProgram)
+ outputs.add(
+ OptimizationOutput(
+ precomputedCFGs = cfgs,
+ functionNames = tackyProgram.functions.map { it.name }.toTypedArray(),
+ errors = emptyArray()
+ )
+ )
+ val optimizedTacky =
+ CompilerWorkflow.take(
+ tacky,
+ optimizations =
+ listOf(
+ OptimizationType.B_CONSTANT_FOLDING,
+ OptimizationType.D_DEAD_STORE_ELIMINATION,
+ OptimizationType.A_COPY_PROPAGATION,
+ OptimizationType.C_UNREACHABLE_CODE_ELIMINATION
+ )
+ )
+ val asm = CompilerWorkflow.take(optimizedTacky)
val finalAssemblyString = codeEmitter.emit(asm as AsmProgram)
- val rawAssembly = codeEmitter.emitRaw(asm as AsmProgram)
+ val rawAssembly = codeEmitter.emitRaw(asm)
outputs.add(
AssemblyOutput(
errors = emptyArray(),
assembly = finalAssemblyString,
rawAssembly = rawAssembly,
+ precomputedAssembly = precomputeAllAssembly(tackyProgram),
sourceLocation = sourceLocationInfo
)
)
@@ -135,6 +159,7 @@ class CompilerExport {
CompilerStage.PARSER -> outputs.add(ParserOutput(errors = arrayOf(error), sourceLocation = sourceLocationInfo))
CompilerStage.TACKY -> outputs.add(TackyOutput(errors = arrayOf(error), sourceLocation = sourceLocationInfo))
CompilerStage.ASSEMBLY -> outputs.add(AssemblyOutput(errors = arrayOf(error), sourceLocation = sourceLocationInfo))
+ CompilerStage.OPTIMIZATIONS -> outputs.add(OptimizationOutput(errors = arrayOf(error), sourceLocation = sourceLocationInfo))
}
} catch (e: Exception) {
// Fallback for any unexpected runtime errors
@@ -163,33 +188,65 @@ class CompilerExport {
}
private fun precomputeAllCFGs(program: TackyProgram): String {
- val allOptSets = generateOptimizationCombinations()
- val cfgs = program.functions.filter { it.body.isNotEmpty() }.flatMap { fn ->
- allOptSets.map { optSet ->
- try {
- val cfg = ControlFlowGraph().construct(fn.name, fn.body)
- val types = optSet.mapNotNull(optTypeMap::get).toSet()
- val optimized = OptimizationManager.applyOptimizations(cfg, types)
- CFGEntry(fn.name, optSet.sorted(), exportControlFlowGraph(optimized))
- } catch (_: Exception) {
- CFGEntry(fn.name, optSet.sorted(), createEmptyCFGJson(fn.name))
+ val allOptLists = generateOptimizationCombinations()
+ val cfgs =
+ program.functions.filter { it.body.isNotEmpty() }.flatMap { fn ->
+ allOptLists.map { optList ->
+ try {
+ val cfg = ControlFlowGraph().construct(fn.name, fn.body)
+ val types = optList.mapNotNull(optTypeMap::get)
+ val optimized = OptimizationManager.applyOptimizations(cfg, types)
+ CFGEntry(fn.name, optList, exportControlFlowGraph(optimized))
+ } catch (_: Exception) {
+ CFGEntry(fn.name, optList.sorted(), createEmptyCFGJson(fn.name))
+ }
}
}
- }
return Json.encodeToString(cfgs)
}
- private val optTypeMap = mapOf(
- "CONSTANT_FOLDING" to OptimizationType.CONSTANT_FOLDING,
- "DEAD_STORE_ELIMINATION" to OptimizationType.DEAD_STORE_ELIMINATION,
- "COPY_PROPAGATION" to OptimizationType.COPY_PROPAGATION,
- "UNREACHABLE_CODE_ELIMINATION" to OptimizationType.UNREACHABLE_CODE_ELIMINATION
- )
+    /**
+     * Builds the assembly listing of [program] for every optimization combination produced by
+     * [generateOptimizationCombinations] and returns the entries serialized as a JSON array.
+     *
+     * Functions with an empty body are passed through untouched; every other function is
+     * turned into a control-flow graph, optimized, and flattened back to instructions before
+     * the whole program is lowered and emitted.
+     */
+    private fun precomputeAllAssembly(program: TackyProgram): String {
+        val entries =
+            generateOptimizationCombinations().map { enabledNames ->
+                // Unknown names are silently dropped by the map lookup.
+                val passes = enabledNames.mapNotNull(optTypeMap::get)
+                val rewrittenFunctions =
+                    program.functions.map { fn ->
+                        if (fn.body.isEmpty()) {
+                            fn
+                        } else {
+                            val graph = ControlFlowGraph().construct(fn.name, fn.body)
+                            fn.copy(body = OptimizationManager.applyOptimizations(graph, passes).toInstructions())
+                        }
+                    }
+                val loweredProgram = CompilerWorkflow.take(TackyProgram(functions = rewrittenFunctions))
+                // One entry per optimization list, holding the assembly of the full program.
+                AssemblyEntry(enabledNames, CodeEmitter().emit(loweredProgram as AsmProgram))
+            }
+        return Json.encodeToString(entries)
+    }
+
+ private val optTypeMap =
+ mapOf(
+ "CONSTANT_FOLDING" to OptimizationType.B_CONSTANT_FOLDING,
+ "DEAD_STORE_ELIMINATION" to OptimizationType.D_DEAD_STORE_ELIMINATION,
+ "COPY_PROPAGATION" to OptimizationType.A_COPY_PROPAGATION,
+ "UNREACHABLE_CODE_ELIMINATION" to OptimizationType.C_UNREACHABLE_CODE_ELIMINATION
+ )
- private fun generateOptimizationCombinations(): List> {
- val opts = optTypeMap.keys.toList()
+ // Enumerates every subset of the optimization names registered in optTypeMap
+ // (2^n lists for n names, including the empty list).
+ // The names are sorted first and filterIndexed keeps that order, so each returned
+ // list is itself in sorted order.
+ private fun generateOptimizationCombinations(): List> {
+ val opts = optTypeMap.keys.sorted()
return (0 until (1 shl opts.size)).map { mask ->
- opts.filterIndexed { i, _ -> mask and (1 shl i) != 0 }.toSet()
+ // Bit i of mask decides whether opts[i] belongs to this combination.
+ opts.filterIndexed { i, _ -> mask and (1 shl i) != 0 }
}
}
@@ -198,57 +255,57 @@ class CompilerExport {
CFGExport(
functionName = fn,
nodes = listOf(CFGNode("entry", "Entry", "entry"), CFGNode("exit", "Exit", "exit")),
- edges = listOf(CFGEdge("entry", "exit"))
+ edges = listOf(CFGEdge("entry", "exit")),
+ instructionCount = 0
)
)
private fun exportControlFlowGraph(cfg: ControlFlowGraph): String {
val nodes = mutableListOf(CFGNode("entry", "Entry", "entry"))
- nodes += cfg.blocks.mapIndexed { i, block ->
- val id = "block_$i"
- val label = block.instructions.joinToString(";\n") { it.toPseudoCode(0) }.ifEmpty { "Empty Block" }
- CFGNode(id, label, "block")
- }
+ nodes +=
+ cfg.blocks.mapIndexed { i, block ->
+ val id = "block_$i"
+ val label = block.instructions.joinToString("\n") { it.toPseudoCode(0) }.ifEmpty { "Empty Block" }
+ CFGNode(id, label, "block")
+ }
nodes += CFGNode("exit", "Exit", "exit")
- val edges = cfg.edges.map { edge ->
- val fromId = when (edge.from) {
- is START -> "entry"
- is EXIT -> "exit"
- is Block -> {
- // Find block by ID instead of object equality
- val index = cfg.blocks.indexOfFirst { it.id == edge.from.id }
- if (index >= 0) "block_$index" else "unknown_block"
- }
- else -> "unknown_block"
- }
- val toId = when (edge.to) {
- is START -> "entry"
- is EXIT -> "exit"
- is Block -> {
- // Find block by ID instead of object equality
- val index = cfg.blocks.indexOfFirst { it.id == edge.to.id }
- if (index >= 0) "block_$index" else "unknown_block"
- }
- else -> "unknown_block"
+ val edges =
+ cfg.edges.map { edge ->
+ val fromId =
+ when (edge.from) {
+ is START -> "entry"
+ is EXIT -> "exit"
+ is Block -> {
+ // Find block by ID instead of object equality
+ val index = cfg.blocks.indexOfFirst { it.id == edge.from.id }
+ if (index >= 0) "block_$index" else "unknown_block"
+ }
+ }
+ val toId =
+ when (edge.to) {
+ is START -> "entry"
+ is EXIT -> "exit"
+ is Block -> {
+ // Find block by ID instead of object equality
+ val index = cfg.blocks.indexOfFirst { it.id == edge.to.id }
+ if (index >= 0) "block_$index" else "unknown_block"
+ }
+ }
+ CFGEdge(fromId, toId)
}
- CFGEdge(fromId, toId)
- }
- return Json.encodeToString(CFGExport(cfg.functionName ?: "unknown", nodes, edges))
- }
+ // Calculate total instruction count across all blocks
+ val instructionCount = cfg.blocks.sumOf { it.instructions.size }
- private fun Any.toId(cfg: ControlFlowGraph): String = when (this) {
- is START -> "entry"
- is EXIT -> "exit"
- is Block -> {
- val index = cfg.blocks.indexOf(this)
- if (index >= 0) "block_$index" else "unknown_block"
- }
- else -> "unknown"
+ return Json.encodeToString(CFGExport(cfg.functionName ?: "unknown", nodes, edges, instructionCount))
}
- fun getCFGForFunction(precomputed: String?, fn: String, enabledOpts: Array): String {
+ fun getCFGForFunction(
+ precomputed: String?,
+ fn: String,
+ enabledOpts: Array
+ ): String {
if (precomputed == null) return createEmptyCFGJson(fn)
val sortedOpts = enabledOpts.sorted()
return try {
@@ -268,14 +325,15 @@ fun List.toJsonString(): String {
mapOf(
"type" to JsonPrimitive(token.type.toString()),
"lexeme" to JsonPrimitive(token.lexeme),
- "location" to JsonObject(
- mapOf(
- "startLine" to JsonPrimitive(token.startLine),
- "startCol" to JsonPrimitive(token.startColumn),
- "endLine" to JsonPrimitive(token.endLine),
- "endCol" to JsonPrimitive(token.endColumn)
+ "location" to
+ JsonObject(
+ mapOf(
+ "startLine" to JsonPrimitive(token.startLine),
+ "startCol" to JsonPrimitive(token.startColumn),
+ "endLine" to JsonPrimitive(token.endLine),
+ "endCol" to JsonPrimitive(token.endColumn)
+ )
)
- )
)
)
}
diff --git a/src/jsMain/kotlin/lexer/Lexer.kt b/src/jsMain/kotlin/lexer/Lexer.kt
index 543922c..388e32c 100644
--- a/src/jsMain/kotlin/lexer/Lexer.kt
+++ b/src/jsMain/kotlin/lexer/Lexer.kt
@@ -1,6 +1,7 @@
package lexer
-import exceptions.LexicalException
+import exceptions.InvalidCharacterException
+import exceptions.UnexpectedCharacterException
sealed class TokenType {
// keywords
@@ -189,11 +190,15 @@ class Lexer(
'&' -> {
if (match('&')) {
addToken(TokenType.AND)
+ } else {
+ throw UnexpectedCharacterException(char.toString(), "&", line, current - lineStart)
}
}
'|' -> {
if (match('|')) {
addToken(TokenType.OR)
+ } else {
+ throw UnexpectedCharacterException(char.toString(), "|", line, current - lineStart)
}
}
'=' -> {
@@ -237,7 +242,7 @@ class Lexer(
} else if (isAlphabetic(char)) {
identifier()
} else {
- throw LexicalException(char, line, current - lineStart)
+ throw InvalidCharacterException(char, line, current - lineStart)
}
}
}
diff --git a/src/jsMain/kotlin/optimizations/ConstantFolding.kt b/src/jsMain/kotlin/optimizations/ConstantFolding.kt
index 8ba7237..6b20068 100644
--- a/src/jsMain/kotlin/optimizations/ConstantFolding.kt
+++ b/src/jsMain/kotlin/optimizations/ConstantFolding.kt
@@ -14,60 +14,107 @@ import tacky.TackyUnaryOP
import tacky.TackyVal
class ConstantFolding : Optimization() {
- override val optimizationType = OptimizationType.CONSTANT_FOLDING
+ override val optimizationType = OptimizationType.B_CONSTANT_FOLDING
+ /**
+  * Folds constant unary/binary operations and constant-condition jumps in every block of [cfg].
+  *
+  * When a conditional jump is folded the graph shape changes, so the matching edge is removed
+  * from both the block's successor list and `cfg.edges` (mutated in place):
+  *  - jump never taken (folded to `null`): the edge to the jump target is dropped;
+  *  - jump always taken (folded to [TackyJump]): the fall-through edge to the next block is
+  *    dropped, unless the next block is the jump target itself.
+  *
+  * Instructions that were not conditional jumps never touch the edges. In particular an
+  * already-unconditional [TackyJump] is returned unchanged by [foldInstruction] and must keep
+  * its edge even when it targets the next block.
+  *
+  * @return a copy of [cfg] whose blocks contain the folded instructions.
+  */
override fun apply(cfg: ControlFlowGraph): ControlFlowGraph {
- val optimizedBlocks = cfg.blocks.map { block ->
- val optimizedInstructions = block.instructions.mapNotNull { foldInstruction(it) }
- block.copy(instructions = optimizedInstructions)
- }
+     val optimizedBlocks =
+         cfg.blocks.map { block ->
+             // Removes one block -> successorId edge from the successor list and the edge list.
+             fun removeEdgeTo(successorId: Int) {
+                 block.successors.remove(successorId)
+                 cfg.edges
+                     .find { it.from.id == block.id && it.to.id == successorId }
+                     ?.let { cfg.edges.remove(it) }
+             }
+
+             val optimizedInstructions =
+                 block.instructions.mapNotNull { instruction ->
+                     val folded = foldInstruction(instruction)
+                     // Only a conditional jump has a target whose edges may need adjusting.
+                     val target =
+                         when (instruction) {
+                             is JumpIfZero -> instruction.target
+                             is JumpIfNotZero -> instruction.target
+                             else -> null
+                         }
+                     if (target != null) {
+                         // Match the label by name, the same way the edges were built; `any`
+                         // also copes with blocks that have no instructions left.
+                         val labelBlock =
+                             cfg.blocks.find { candidate ->
+                                 candidate.instructions.any { it is TackyLabel && it.name == target.name }
+                             }
+                         when (folded) {
+                             // Condition is constant and the jump is never taken.
+                             null -> labelBlock?.let { removeEdgeTo(it.id) }
+                             // Condition is constant and the jump is always taken.
+                             is TackyJump -> {
+                                 val nextBlock = cfg.blocks.find { it.id == block.id + 1 }
+                                 if (nextBlock != null && nextBlock.id != labelBlock?.id) {
+                                     removeEdgeTo(nextBlock.id)
+                                 }
+                             }
+                             // Condition is not constant: the jump and its edges stay as they are.
+                             else -> Unit
+                         }
+                     }
+                     folded
+                 }
+             block.copy(instructions = optimizedInstructions)
+         }
return cfg.copy(blocks = optimizedBlocks)
}
- private fun foldInstruction(instruction: TackyInstruction): TackyInstruction? = when (instruction) {
- is TackyUnary -> foldUnary(instruction)
- is TackyBinary -> foldBinary(instruction)
- is JumpIfZero -> foldJump(instruction.condition, expectZero = true, target = instruction.target)
- is JumpIfNotZero -> foldJump(instruction.condition, expectZero = false, target = instruction.target)
- else -> instruction
- }
+ private fun foldInstruction(instruction: TackyInstruction): TackyInstruction? =
+ when (instruction) {
+ is TackyUnary -> foldUnary(instruction)
+ is TackyBinary -> foldBinary(instruction)
+ is JumpIfZero -> foldJump(instruction.condition, expectZero = true, target = instruction.target)
+ is JumpIfNotZero -> foldJump(instruction.condition, expectZero = false, target = instruction.target)
+ else -> instruction
+ }
private fun foldUnary(inst: TackyUnary): TackyInstruction {
val src = inst.src as? TackyConstant ?: return inst
- val result = when (inst.operator) {
- TackyUnaryOP.COMPLEMENT -> src.value.inv()
- TackyUnaryOP.NEGATE -> -src.value
- TackyUnaryOP.NOT -> if (src.value == 0) 1 else 0
- }
- return TackyCopy(TackyConstant(result), inst.dest)
+ val result =
+ when (inst.operator) {
+ TackyUnaryOP.COMPLEMENT -> src.value.inv()
+ TackyUnaryOP.NEGATE -> -src.value
+ TackyUnaryOP.NOT -> if (src.value == 0) 1 else 0
+ }
+ return TackyCopy(TackyConstant(result), inst.dest, inst.sourceId)
}
private fun foldBinary(inst: TackyBinary): TackyInstruction {
val lhs = inst.src1 as? TackyConstant ?: return inst
val rhs = inst.src2 as? TackyConstant ?: return inst
- val result = when (inst.operator) {
- TackyBinaryOP.ADD -> lhs.value + rhs.value
- TackyBinaryOP.SUBTRACT -> lhs.value - rhs.value
- TackyBinaryOP.MULTIPLY -> lhs.value * rhs.value
- TackyBinaryOP.DIVIDE -> if (rhs.value == 0) return inst else lhs.value / rhs.value
- TackyBinaryOP.REMAINDER -> if (rhs.value == 0) return inst else lhs.value % rhs.value
- TackyBinaryOP.LESS -> if (lhs.value < rhs.value) 1 else 0
- TackyBinaryOP.GREATER -> if (lhs.value > rhs.value) 1 else 0
- TackyBinaryOP.LESS_EQUAL -> if (lhs.value <= rhs.value) 1 else 0
- TackyBinaryOP.GREATER_EQUAL -> if (lhs.value >= rhs.value) 1 else 0
- TackyBinaryOP.EQUAL -> if (lhs.value == rhs.value) 1 else 0
- TackyBinaryOP.NOT_EQUAL -> if (lhs.value != rhs.value) 1 else 0
- }
+ val result =
+ when (inst.operator) {
+ TackyBinaryOP.ADD -> lhs.value + rhs.value
+ TackyBinaryOP.SUBTRACT -> lhs.value - rhs.value
+ TackyBinaryOP.MULTIPLY -> lhs.value * rhs.value
+ TackyBinaryOP.DIVIDE -> if (rhs.value == 0) return inst else lhs.value / rhs.value
+ TackyBinaryOP.REMAINDER -> if (rhs.value == 0) return inst else lhs.value % rhs.value
+ TackyBinaryOP.LESS -> if (lhs.value < rhs.value) 1 else 0
+ TackyBinaryOP.GREATER -> if (lhs.value > rhs.value) 1 else 0
+ TackyBinaryOP.LESS_EQUAL -> if (lhs.value <= rhs.value) 1 else 0
+ TackyBinaryOP.GREATER_EQUAL -> if (lhs.value >= rhs.value) 1 else 0
+ TackyBinaryOP.EQUAL -> if (lhs.value == rhs.value) 1 else 0
+ TackyBinaryOP.NOT_EQUAL -> if (lhs.value != rhs.value) 1 else 0
+ }
- return TackyCopy(TackyConstant(result), inst.dest)
+ return TackyCopy(TackyConstant(result), inst.dest, inst.sourceId)
}
- private fun foldJump(condition: TackyVal, expectZero: Boolean, target: TackyLabel): TackyInstruction? {
- val constant = condition as? TackyConstant ?: return when (expectZero) {
- true -> JumpIfZero(condition, target)
- false -> JumpIfNotZero(condition, target)
- }
+ private fun foldJump(
+ condition: TackyVal,
+ expectZero: Boolean,
+ target: TackyLabel
+ ): TackyInstruction? {
+ val constant =
+ condition as? TackyConstant ?: return when (expectZero) {
+ true -> JumpIfZero(condition, target)
+ false -> JumpIfNotZero(condition, target)
+ }
return if ((constant.value == 0) == expectZero) {
TackyJump(target)
} else {
diff --git a/src/jsMain/kotlin/optimizations/ControlFlowGraph.kt b/src/jsMain/kotlin/optimizations/ControlFlowGraph.kt
index 9a05f74..670e2ca 100644
--- a/src/jsMain/kotlin/optimizations/ControlFlowGraph.kt
+++ b/src/jsMain/kotlin/optimizations/ControlFlowGraph.kt
@@ -14,14 +14,14 @@ sealed class CFGNode {
}
data class START(
- override val id: Int,
+ override val id: Int = -1,
override val successors: MutableList = mutableListOf()
) : CFGNode() {
override val predecessors: MutableList = mutableListOf()
}
data class EXIT(
- override val id: Int,
+ override val id: Int = -2,
override val predecessors: MutableList = mutableListOf()
) : CFGNode() {
override val successors: MutableList = mutableListOf()
@@ -34,15 +34,21 @@ data class Block(
override val successors: MutableList = mutableListOf()
) : CFGNode()
-data class Edge(val from: CFGNode, val to: CFGNode)
+data class Edge(
+ val from: CFGNode,
+ val to: CFGNode
+)
data class ControlFlowGraph(
val functionName: String? = null,
val root: CFGNode? = null,
val blocks: List = emptyList(),
- val edges: List = emptyList()
+ val edges: MutableList = mutableListOf()
) {
- fun construct(functionName: String, functionBody: List): ControlFlowGraph {
+ fun construct(
+ functionName: String,
+ functionBody: List
+ ): ControlFlowGraph {
val nodes = toBasicBlocks(functionBody)
val blocks = nodes.filterIsInstance()
val edges = buildEdges(nodes, blocks)
@@ -54,15 +60,14 @@ data class ControlFlowGraph(
)
}
- fun toInstructions(): List =
- blocks.flatMap { it.instructions }
+ fun toInstructions(): List = blocks.flatMap { it.instructions }
private fun toBasicBlocks(instructions: List): List {
val nodes = mutableListOf()
val current = mutableListOf()
var blockId = 0
- nodes += START(blockId++)
+ nodes += START()
for (inst in instructions) {
when (inst) {
@@ -86,23 +91,29 @@ data class ControlFlowGraph(
nodes += Block(blockId++, current.toList())
}
- nodes += EXIT(blockId++)
+ nodes += EXIT()
return nodes
}
- private fun buildEdges(nodes: List, blocks: List): List {
+ private fun buildEdges(
+ nodes: List,
+ blocks: List
+ ): MutableList {
val edges = mutableListOf()
val entry = nodes.filterIsInstance().firstOrNull()
val exit = nodes.filterIsInstance().firstOrNull()
- fun connect(from: CFGNode, to: CFGNode) {
+ fun connect(
+ from: CFGNode,
+ to: CFGNode
+ ) {
edges += Edge(from, to)
from.successors += to.id
to.predecessors += from.id
}
// entry -> first block
- blocks.firstOrNull()?.let { connect(entry!!, it) }
+ blocks.firstOrNull()?.let { connect(entry ?: return@let, it) }
for ((i, block) in blocks.withIndex()) {
val last = block.instructions.lastOrNull()
@@ -116,11 +127,12 @@ data class ControlFlowGraph(
}
is JumpIfZero, is JumpIfNotZero -> {
- val target = when (last) {
- is JumpIfZero -> last.target
- is JumpIfNotZero -> last.target
- else -> null
- }
+ val target =
+ when (last) {
+ is JumpIfZero -> last.target
+ is JumpIfNotZero -> last.target
+ else -> null
+ }
target?.let { t -> findBlockByLabel(blocks, t)?.let { connect(block, it) } }
next?.let { connect(block, next) }
}
@@ -138,6 +150,8 @@ data class ControlFlowGraph(
return edges
}
- private fun findBlockByLabel(blocks: List, label: TackyLabel): Block? =
- blocks.find { blk -> blk.instructions.any { it is TackyLabel && it.name == label.name } }
+ private fun findBlockByLabel(
+ blocks: List,
+ label: TackyLabel
+ ): Block? = blocks.find { blk -> blk.instructions.any { it is TackyLabel && it.name == label.name } }
}
diff --git a/src/jsMain/kotlin/optimizations/CopyPropagation.kt b/src/jsMain/kotlin/optimizations/CopyPropagation.kt
index 8fc6e79..18c227d 100644
--- a/src/jsMain/kotlin/optimizations/CopyPropagation.kt
+++ b/src/jsMain/kotlin/optimizations/CopyPropagation.kt
@@ -12,230 +12,104 @@ import tacky.TackyVal
import tacky.TackyVar
class CopyPropagation : Optimization() {
- // A map to store the set of copies reaching the *entry* of each block.
- private lateinit var inSets: MutableMap>
- // A map to store the set of copies reaching the *exit* of each block.
- private lateinit var outSets: MutableMap>
-
- // A map to store which copies reach each specific instruction. This is needed for the final rewrite.
- private val instructionReachingCopies = mutableMapOf>()
-
- override val optimizationType: OptimizationType = OptimizationType.COPY_PROPAGATION
+ override val optimizationType: OptimizationType = OptimizationType.A_COPY_PROPAGATION
+ /**
+  * Block-local copy propagation.
+  *
+  * Each block is scanned once with a fresh map from a variable name to the value it currently
+  * copies. Operands found in the map are replaced by the mapped value. Whenever an instruction
+  * writes a variable, the entry for that variable is replaced or removed AND every entry whose
+  * value reads that variable is removed, because those copies no longer describe the old value
+  * (e.g. `y = x; x = 1; return y` must not become `return x`).
+  *
+  * @return a copy of [cfg] whose blocks hold the rewritten instructions; edges are untouched.
+  */
override fun apply(cfg: ControlFlowGraph): ControlFlowGraph {
- val newBlocks = cfg.blocks.map { block ->
- val newInstructions = mutableListOf()
- val copyMap = mutableMapOf()
-
- for (instr in block.instructions) {
- when (instr) {
- is TackyCopy -> {
- if (instr.src is TackyVar && instr.src.name == instr.dest.name) {
- } else {
- val newSrc = if (instr.src is TackyVar && copyMap.containsKey(instr.src.name)) {
- copyMap[instr.src.name]!!
+ val newBlocks =
+ cfg.blocks.map { block ->
+ val newInstructions = mutableListOf<TackyInstruction>()
+ val copyMap = mutableMapOf<String, TackyVal>()
+
+ for (instr in block.instructions) {
+ when (instr) {
+ is TackyCopy -> {
+ if (instr.src is TackyVar && instr.src.name == instr.dest.name) {
} else {
- instr.src
+ val newSrc =
+ if (instr.src is TackyVar && copyMap.containsKey(instr.src.name)) {
+ copyMap[instr.src.name]!!
+ } else {
+ instr.src
+ }
+
+ // dest is overwritten: copies that read the old dest are stale.
+ copyMap.values.removeAll { it is TackyVar && it.name == instr.dest.name }
+ copyMap[instr.dest.name] = newSrc
+ newInstructions.add(TackyCopy(newSrc, instr.dest, instr.sourceId))
}
-
- copyMap[instr.dest.name] = newSrc
- newInstructions.add(TackyCopy(newSrc, instr.dest))
}
- }
- is TackyRet -> {
- val newValue = if (instr.value is TackyVar && copyMap.containsKey(instr.value.name)) {
- copyMap[instr.value.name]!!
- } else {
- instr.value
+ is TackyRet -> {
+ val newValue =
+ if (instr.value is TackyVar && copyMap.containsKey(instr.value.name)) {
+ copyMap[instr.value.name]!!
+ } else {
+ instr.value
+ }
+ newInstructions.add(TackyRet(newValue, instr.sourceId))
}
- newInstructions.add(TackyRet(newValue))
- }
- is TackyUnary -> {
- val newSrc = if (instr.src is TackyVar && copyMap.containsKey(instr.src.name)) {
- copyMap[instr.src.name]!!
- } else {
- instr.src
+ is TackyUnary -> {
+ val newSrc =
+ if (instr.src is TackyVar && copyMap.containsKey(instr.src.name)) {
+ copyMap[instr.src.name]!!
+ } else {
+ instr.src
+ }
+ copyMap.remove(instr.dest.name)
+ // dest is overwritten: copies that read the old dest are stale.
+ copyMap.values.removeAll { it is TackyVar && it.name == instr.dest.name }
+ newInstructions.add(TackyUnary(instr.operator, newSrc, instr.dest, instr.sourceId))
}
- copyMap.remove(instr.dest.name)
- newInstructions.add(TackyUnary(instr.operator, newSrc, instr.dest))
- }
- is TackyBinary -> {
- val newSrc1 = if (instr.src1 is TackyVar && copyMap.containsKey(instr.src1.name)) {
- copyMap[instr.src1.name]!!
- } else {
- instr.src1
+ is TackyBinary -> {
+ val newSrc1 =
+ if (instr.src1 is TackyVar && copyMap.containsKey(instr.src1.name)) {
+ copyMap[instr.src1.name]!!
+ } else {
+ instr.src1
+ }
+ val newSrc2 =
+ if (instr.src2 is TackyVar && copyMap.containsKey(instr.src2.name)) {
+ copyMap[instr.src2.name]!!
+ } else {
+ instr.src2
+ }
+ copyMap.remove(instr.dest.name)
+ // dest is overwritten: copies that read the old dest are stale.
+ copyMap.values.removeAll { it is TackyVar && it.name == instr.dest.name }
+ newInstructions.add(TackyBinary(instr.operator, newSrc1, newSrc2, instr.dest, instr.sourceId))
}
- val newSrc2 = if (instr.src2 is TackyVar && copyMap.containsKey(instr.src2.name)) {
- copyMap[instr.src2.name]!!
- } else {
- instr.src2
+ is TackyFunCall -> {
+ val newArgs =
+ instr.args.map { arg ->
+ if (arg is TackyVar && copyMap.containsKey(arg.name)) {
+ copyMap[arg.name]!!
+ } else {
+ arg
+ }
+ }
+ copyMap.remove(instr.dest.name)
+ // dest is overwritten: copies that read the old dest are stale.
+ copyMap.values.removeAll { it is TackyVar && it.name == instr.dest.name }
+ newInstructions.add(TackyFunCall(instr.funName, newArgs, instr.dest, instr.sourceId))
}
- copyMap.remove(instr.dest.name)
- newInstructions.add(TackyBinary(instr.operator, newSrc1, newSrc2, instr.dest))
- }
- is TackyFunCall -> {
- val newArgs = instr.args.map { arg ->
- if (arg is TackyVar && copyMap.containsKey(arg.name)) {
- copyMap[arg.name]!!
- } else {
- arg
- }
+ is JumpIfZero -> {
+ val newCondition =
+ if (instr.condition is TackyVar && copyMap.containsKey(instr.condition.name)) {
+ copyMap[instr.condition.name]!!
+ } else {
+ instr.condition
+ }
+ newInstructions.add(JumpIfZero(newCondition, instr.target, instr.sourceId))
}
- copyMap.remove(instr.dest.name)
- newInstructions.add(TackyFunCall(instr.funName, newArgs, instr.dest))
- }
- is JumpIfZero -> {
- val newCondition = if (instr.condition is TackyVar && copyMap.containsKey(instr.condition.name)) {
- copyMap[instr.condition.name]!!
- } else {
- instr.condition
+ is JumpIfNotZero -> {
+ val newCondition =
+ if (instr.condition is TackyVar && copyMap.containsKey(instr.condition.name)) {
+ copyMap[instr.condition.name]!!
+ } else {
+ instr.condition
+ }
+ newInstructions.add(JumpIfNotZero(newCondition, instr.target, instr.sourceId))
}
- newInstructions.add(JumpIfZero(newCondition, instr.target))
- }
- is JumpIfNotZero -> {
- val newCondition = if (instr.condition is TackyVar && copyMap.containsKey(instr.condition.name)) {
- copyMap[instr.condition.name]!!
- } else {
- instr.condition
+ else -> {
+ newInstructions.add(instr)
}
- newInstructions.add(JumpIfNotZero(newCondition, instr.target))
- }
- else -> {
- newInstructions.add(instr)
}
}
+ block.copy(instructions = newInstructions)
}
- block.copy(instructions = newInstructions)
- }
return cfg.copy(blocks = newBlocks)
}
-
- private fun runAnalysis(cfg: ControlFlowGraph) {
- outSets = mutableMapOf()
- instructionReachingCopies.clear()
-
- val allCopies = cfg.blocks.flatMap { it.instructions }.filterIsInstance().toSet()
-
- val worklist = cfg.blocks.toMutableList()
- cfg.blocks.forEach {
- outSets[it.id] = allCopies
- }
-
- while (worklist.isNotEmpty()) {
- val block = worklist.removeAt(0)
- val inSet = meet(block, allCopies)
- val newOut = transfer(block, inSet)
-
- if (newOut != outSets[block.id]) {
- outSets[block.id] = newOut
- block.successors.forEach { succId ->
- cfg.blocks.find { it.id == succId }?.let { successorBlock ->
- if (!worklist.contains(successorBlock)) {
- worklist.add(successorBlock)
- }
- }
- }
- }
- }
- }
-
- private fun meet(block: Block, allCopies: Set): Set {
- if (block.predecessors.all { it == 0 }) {
- return emptySet()
- }
-
- var incomingCopies: Set = allCopies
-
- for (predId in block.predecessors) {
- val predOutSet = outSets[predId]
- if (predOutSet != null) {
- incomingCopies = incomingCopies.intersect(predOutSet)
- }
- }
- return incomingCopies
- }
-
- private fun transfer(block: Block, inSet: Set): Set {
- var currentCopies = inSet.toMutableSet()
- for (instruction in block.instructions) {
- instructionReachingCopies[instruction] = currentCopies.toSet()
- val toRemove = mutableSetOf()
- when (instruction) {
- is TackyCopy -> {
- val destVar = instruction.dest.name
- // Kill all previous copies to or from the destination variable.
- currentCopies.forEach { if (it.src is TackyVar && it.src.name == destVar || it.dest.name == destVar) toRemove.add(it) }
- currentCopies.removeAll(toRemove)
- currentCopies.add(instruction)
- }
- is TackyUnary -> {
- val destVar = instruction.dest.name
- // Kill any copies to or from the destination variable.
- currentCopies.forEach { if (it.src is TackyVar && it.src.name == destVar || it.dest.name == destVar) toRemove.add(it) }
- currentCopies.removeAll(toRemove)
- }
- is TackyBinary -> {
- val destVar = instruction.dest.name
- // Kill any copies to or from the destination variable.
- currentCopies.forEach { if (it.src is TackyVar && it.src.name == destVar || it.dest.name == destVar) toRemove.add(it) }
- currentCopies.removeAll(toRemove)
- }
- is TackyFunCall -> {
- val destVar = instruction.dest.name
- // Kill any copies to or from the destination variable.
- currentCopies.forEach { if (it.src is TackyVar && it.src.name == destVar || it.dest.name == destVar) toRemove.add(it) }
- currentCopies.removeAll(toRemove)
- }
- else -> {}
- }
- }
- return currentCopies
- }
-
- private fun rewrite(instruction: TackyInstruction, reaching: Set): TackyInstruction {
- val substitutionMap = reaching.associate { it.dest.name to it.src }
-
- return when (instruction) {
- is TackyRet -> TackyRet(value = substitute(instruction.value, substitutionMap))
- is TackyUnary -> TackyUnary(
- operator = instruction.operator,
- src = substitute(instruction.src, substitutionMap),
- dest = instruction.dest
- )
- is TackyBinary -> TackyBinary(
- operator = instruction.operator,
- src1 = substitute(instruction.src1, substitutionMap),
- src2 = substitute(instruction.src2, substitutionMap),
- dest = instruction.dest
- )
- is TackyCopy -> TackyCopy(
- src = substitute(instruction.src, substitutionMap),
- dest = instruction.dest
- )
- is TackyFunCall -> TackyFunCall(
- funName = instruction.funName,
- args = instruction.args.map { substitute(it, substitutionMap) },
- dest = instruction.dest
- )
- is JumpIfZero -> JumpIfZero(
- condition = substitute(instruction.condition, substitutionMap),
- target = instruction.target
- )
- is JumpIfNotZero -> JumpIfNotZero(
- condition = substitute(instruction.condition, substitutionMap),
- target = instruction.target
- )
- else -> instruction
- }
- }
-
- private fun substitute(value: TackyVal, substitutionMap: Map): TackyVal {
- return if (value is TackyVar && substitutionMap.containsKey(value.name)) {
- substitutionMap.getValue(value.name)
- } else {
- value
- }
- }
}
diff --git a/src/jsMain/kotlin/optimizations/DeadStoreElimination.kt b/src/jsMain/kotlin/optimizations/DeadStoreElimination.kt
index 2ae30cc..9d688fd 100644
--- a/src/jsMain/kotlin/optimizations/DeadStoreElimination.kt
+++ b/src/jsMain/kotlin/optimizations/DeadStoreElimination.kt
@@ -5,7 +5,6 @@ import tacky.JumpIfZero
import tacky.TackyBinary
import tacky.TackyCopy
import tacky.TackyFunCall
-import tacky.TackyInstruction
import tacky.TackyJump
import tacky.TackyLabel
import tacky.TackyRet
@@ -13,133 +12,116 @@ import tacky.TackyUnary
import tacky.TackyVar
class DeadStoreElimination : Optimization() {
- override val optimizationType: OptimizationType = OptimizationType.DEAD_STORE_ELIMINATION
+ override val optimizationType: OptimizationType = OptimizationType.D_DEAD_STORE_ELIMINATION
override fun apply(cfg: ControlFlowGraph): ControlFlowGraph {
- val livenessAnalysis = LivenessAnalysis()
- val liveVariables = livenessAnalysis.analyze(cfg)
-
- val optimizedBlocks = cfg.blocks.map { block ->
- val optimizedInstructions = block.instructions.withIndex().filterNot { (idx, instr) ->
- isDeadStore(block.id, idx, instr, liveVariables)
- }.map { it.value }
-
- block.copy(instructions = optimizedInstructions)
- }
+ val liveness = LivenessAnalysis()
+ val liveAfter = liveness.analyze(cfg)
+
+ val optimizedBlocks =
+ cfg.blocks.map { block ->
+ val optimizedInstructions =
+ block.instructions
+ .withIndex()
+ .filterNot { (idx, instr) ->
+ when (instr) {
+ is TackyFunCall -> false
+ is TackyUnary, is TackyBinary, is TackyCopy -> {
+ val live = liveAfter[block.id to idx] ?: emptySet()
+ val dest =
+ when (instr) {
+ is TackyUnary -> instr.dest.name
+ is TackyBinary -> instr.dest.name
+ is TackyCopy -> instr.dest.name
+ else -> ""
+ }
+ dest !in live
+ }
+ else -> false
+ }
+ }.map { it.value }
+ block.copy(instructions = optimizedInstructions)
+ }
return cfg.copy(blocks = optimizedBlocks)
}
-
- internal fun isDeadStore(
- blockId: Int,
- idx: Int,
- instruction: TackyInstruction,
- liveVariables: Map, Set>
- ): Boolean {
- // Never eliminate function calls (side effects)
- if (instruction is TackyFunCall) return false
-
- // Only instructions with destinations are considered
- val dest = when (instruction) {
- is TackyUnary -> instruction.dest.name
- is TackyBinary -> instruction.dest.name
- is TackyCopy -> instruction.dest.name
- else -> return false
- }
-
- val liveAfter = liveVariables[blockId to idx] ?: emptySet()
- return dest !in liveAfter
- }
}
class LivenessAnalysis {
- fun analyze(cfg: ControlFlowGraph): Map, Set> {
- val allStaticVariables = extractStaticVariables(cfg)
- val blockOut = mutableMapOf>()
- val worklist = ArrayDeque()
-
- // init: all blocks start with empty live-out
- cfg.blocks.forEach { block ->
- blockOut[block.id] = emptySet()
- worklist.add(block.id)
- }
-
- // backward fixpoint
- while (worklist.isNotEmpty()) {
- val currentId = worklist.removeFirst()
- val currentBlock = cfg.blocks.find { it.id == currentId } ?: continue
-
- val succLive = currentBlock.successors.flatMap { succId ->
- blockOut[succId] ?: emptySet()
- }.toSet()
+ private val instructionAnnotations = mutableMapOf, Set>()
+ private val blockAnnotations = mutableMapOf>()
- if (succLive != blockOut[currentId]) {
- blockOut[currentId] = succLive
- currentBlock.predecessors.forEach { worklist.add(it) }
- }
+ fun analyze(cfg: ControlFlowGraph): Map, Set> {
+ for (block in cfg.blocks) {
+ blockAnnotations[block.id] = emptySet()
}
- // instruction-level liveness
- val instructionLiveVars = mutableMapOf, Set>()
-
- cfg.blocks.forEach { block ->
- var live = blockOut[block.id] ?: emptySet()
-
- block.instructions.withIndex().reversed().forEach { (idx, instr) ->
- instructionLiveVars[block.id to idx] = live
- live = transfer(instr, live, allStaticVariables)
+ val workList = cfg.blocks.map { it.id }.toMutableList()
+ while (workList.isNotEmpty()) {
+ val blockId = workList.removeFirst()
+ val block = cfg.blocks.find { it.id == blockId } ?: continue
+ val out = meet(block, emptySet())
+ val newIn = transfer(block, out)
+ if (newIn != blockAnnotations[block.id]) {
+ blockAnnotations[block.id] = newIn
+ workList.addAll(block.predecessors)
}
}
- return instructionLiveVars
+ return instructionAnnotations
}
- internal fun transfer(
- instruction: TackyInstruction,
+ private fun transfer(
+ block: Block,
liveAfter: Set,
- allStaticVariables: Set
+ staticVariables: Set = emptySet()
): Set {
val liveBefore = liveAfter.toMutableSet()
-
- when (instruction) {
- is TackyUnary -> {
- liveBefore.remove(instruction.dest.name)
- if (instruction.src is TackyVar) liveBefore.add(instruction.src.name)
- }
- is TackyBinary -> {
- liveBefore.remove(instruction.dest.name)
- if (instruction.src1 is TackyVar) liveBefore.add(instruction.src1.name)
- if (instruction.src2 is TackyVar) liveBefore.add(instruction.src2.name)
- }
- is TackyCopy -> {
- liveBefore.remove(instruction.dest.name)
- if (instruction.src is TackyVar) liveBefore.add(instruction.src.name)
- }
- is TackyFunCall -> {
- liveBefore.remove(instruction.dest.name)
- instruction.args.forEach { arg ->
- if (arg is TackyVar) liveBefore.add(arg.name)
+ block.instructions.withIndex().reversed().forEach { (idx, instruction) ->
+ instructionAnnotations[block.id to idx] = liveBefore.toSet()
+ when (instruction) {
+ is TackyUnary -> {
+ liveBefore.remove(instruction.dest.name)
+ if (instruction.src is TackyVar) liveBefore.add(instruction.src.name)
}
- liveBefore.addAll(allStaticVariables) // conservatively keep statics alive
- }
- is TackyRet -> {
- if (instruction.value is TackyVar) liveBefore.add(instruction.value.name)
- }
- is JumpIfZero -> {
- if (instruction.condition is TackyVar) liveBefore.add(instruction.condition.name)
- }
- is JumpIfNotZero -> {
- if (instruction.condition is TackyVar) liveBefore.add(instruction.condition.name)
+ is TackyBinary -> {
+ liveBefore.remove(instruction.dest.name)
+ if (instruction.src1 is TackyVar) liveBefore.add(instruction.src1.name)
+ if (instruction.src2 is TackyVar) liveBefore.add(instruction.src2.name)
+ }
+ is TackyCopy -> {
+ liveBefore.remove(instruction.dest.name)
+ if (instruction.src is TackyVar) liveBefore.add(instruction.src.name)
+ }
+ is TackyFunCall -> {
+ liveBefore.remove(instruction.dest.name)
+ instruction.args.forEach { arg ->
+ if (arg is TackyVar) liveBefore.add(arg.name)
+ }
+ }
+ is TackyRet -> {
+ if (instruction.value is TackyVar) liveBefore.add(instruction.value.name)
+ }
+ is JumpIfZero -> {
+ if (instruction.condition is TackyVar) liveBefore.add(instruction.condition.name)
+ }
+ is JumpIfNotZero -> {
+ if (instruction.condition is TackyVar) liveBefore.add(instruction.condition.name)
+ }
+ is TackyJump, is TackyLabel -> {}
}
- is TackyJump -> { /* no effect */ }
- is TackyLabel -> { /* no effect */ }
}
-
return liveBefore
}
- internal fun extractStaticVariables(cfg: ControlFlowGraph): Set {
- // stub: no statics for now
- return emptySet()
+ private fun meet(
+ block: Block,
+ allStaticVariables: Set
+ ): MutableSet {
+ val liveVariables = mutableSetOf()
+ for (suc in block.successors) {
+ liveVariables.addAll(blockAnnotations.getOrElse(suc) { emptySet() })
+ }
+ return liveVariables
}
}
diff --git a/src/jsMain/kotlin/optimizations/Optimization.kt b/src/jsMain/kotlin/optimizations/Optimization.kt
index 736d54c..7054d29 100644
--- a/src/jsMain/kotlin/optimizations/Optimization.kt
+++ b/src/jsMain/kotlin/optimizations/Optimization.kt
@@ -1,42 +1,31 @@
package optimizations
-import tacky.TackyProgram
-
enum class OptimizationType {
- CONSTANT_FOLDING,
- DEAD_STORE_ELIMINATION,
- UNREACHABLE_CODE_ELIMINATION,
- COPY_PROPAGATION
+ B_CONSTANT_FOLDING,
+ D_DEAD_STORE_ELIMINATION,
+ C_UNREACHABLE_CODE_ELIMINATION,
+ A_COPY_PROPAGATION
}
sealed class Optimization {
abstract val optimizationType: OptimizationType
+
abstract fun apply(cfg: ControlFlowGraph): ControlFlowGraph
}
object OptimizationManager {
- private val optimizations: Map = mapOf(
- OptimizationType.CONSTANT_FOLDING to ConstantFolding(),
- OptimizationType.DEAD_STORE_ELIMINATION to DeadStoreElimination(),
- OptimizationType.UNREACHABLE_CODE_ELIMINATION to UnreachableCodeElimination(),
- OptimizationType.COPY_PROPAGATION to CopyPropagation()
- )
-
- fun optimizeProgram(program: TackyProgram, enabledOptimizations: Set): TackyProgram {
- val optimizedFunctions = program.functions.map { function ->
- if (function.body.isEmpty()) {
- function
- } else {
- val cfg = ControlFlowGraph().construct(function.name, function.body)
- val optimizedCfg = applyOptimizations(cfg, enabledOptimizations)
- val optimizedInstructions = optimizedCfg.toInstructions()
- function.copy(body = optimizedInstructions)
- }
- }
- return program.copy(functions = optimizedFunctions)
- }
-
- fun applyOptimizations(cfg: ControlFlowGraph, enabledOptimizations: Set): ControlFlowGraph {
+ private val optimizations: Map =
+ mapOf(
+ OptimizationType.B_CONSTANT_FOLDING to ConstantFolding(),
+ OptimizationType.D_DEAD_STORE_ELIMINATION to DeadStoreElimination(),
+ OptimizationType.C_UNREACHABLE_CODE_ELIMINATION to UnreachableCodeElimination(),
+ OptimizationType.A_COPY_PROPAGATION to CopyPropagation()
+ )
+
+ fun applyOptimizations(
+ cfg: ControlFlowGraph,
+ enabledOptimizations: List
+ ): ControlFlowGraph {
var currentCfg = cfg
while (true) {
diff --git a/src/jsMain/kotlin/optimizations/UnreachableCodeElimination.kt b/src/jsMain/kotlin/optimizations/UnreachableCodeElimination.kt
index 88fff96..bf022e5 100644
--- a/src/jsMain/kotlin/optimizations/UnreachableCodeElimination.kt
+++ b/src/jsMain/kotlin/optimizations/UnreachableCodeElimination.kt
@@ -7,7 +7,7 @@ import tacky.TackyJump
import tacky.TackyLabel
class UnreachableCodeElimination : Optimization() {
- override val optimizationType: OptimizationType = OptimizationType.UNREACHABLE_CODE_ELIMINATION
+ override val optimizationType: OptimizationType = OptimizationType.C_UNREACHABLE_CODE_ELIMINATION
override fun apply(cfg: ControlFlowGraph): ControlFlowGraph {
var currentCfg = removeUnreachableBlocks(cfg)
@@ -38,19 +38,20 @@ class UnreachableCodeElimination : Optimization() {
val reachableBlocks = cfg.blocks.filter { it.id in reachableNodeIds }
- val reachableEdges = cfg.edges.filter { edge ->
- val fromReachable = edge.from.id in reachableNodeIds
- val toReachable = edge.to.id in reachableNodeIds
- val toExit = edge.to is EXIT
+ val reachableEdges =
+ cfg.edges.filter { edge ->
+ val fromReachable = edge.from.id in reachableNodeIds
+ val toReachable = edge.to.id in reachableNodeIds
+ val toExit = edge.to is EXIT
- fromReachable && (toReachable || toExit)
- }
+ fromReachable && (toReachable || toExit)
+ }
return ControlFlowGraph(
functionName = cfg.functionName,
root = cfg.root,
blocks = reachableBlocks,
- edges = reachableEdges
+ edges = reachableEdges.toMutableList()
)
}
@@ -86,10 +87,11 @@ class UnreachableCodeElimination : Optimization() {
}
// rebuild the blocks with the redundant jumps removed
- val newBlocks = cfg.blocks.map { oldBlock ->
- val newInstructions = oldBlock.instructions.filterNot { it in jumpsToRemove }
- Block(oldBlock.id, newInstructions, oldBlock.predecessors, oldBlock.successors)
- }
+ val newBlocks =
+ cfg.blocks.map { oldBlock ->
+ val newInstructions = oldBlock.instructions.filterNot { it in jumpsToRemove }
+ Block(oldBlock.id, newInstructions, oldBlock.predecessors, oldBlock.successors)
+ }
return ControlFlowGraph(
functionName = cfg.functionName,
@@ -135,10 +137,11 @@ class UnreachableCodeElimination : Optimization() {
return cfg
}
- val newBlocks = cfg.blocks.map { oldBlock ->
- val newInstructions = oldBlock.instructions.filterNot { it in labelsToRemove }
- Block(oldBlock.id, newInstructions, oldBlock.predecessors, oldBlock.successors)
- }
+ val newBlocks =
+ cfg.blocks.map { oldBlock ->
+ val newInstructions = oldBlock.instructions.filterNot { it in labelsToRemove }
+ Block(oldBlock.id, newInstructions, oldBlock.predecessors, oldBlock.successors)
+ }
return ControlFlowGraph(
functionName = cfg.functionName,
@@ -155,10 +158,11 @@ class UnreachableCodeElimination : Optimization() {
return cfg
}
- val blocksToRemove = cfg.blocks
- .filter { it.instructions.isEmpty() && it.successors.size <= 1 }
- .toMutableSet()
- val newEdges = cfg.edges.toMutableList()
+ val blocksToRemove =
+ cfg.blocks
+ .filter { it.instructions.isEmpty() && it.successors.size <= 1 }
+ .toMutableSet()
+ val newEdges = cfg.edges
val blocksToKeep = cfg.blocks.filter { it !in blocksToRemove }.toMutableList()
// map for easy search of nodes by their ID
@@ -190,20 +194,25 @@ class UnreachableCodeElimination : Optimization() {
functionName = cfg.functionName,
root = cfg.root,
blocks = blocksToKeep,
- edges = newEdges.distinct()
+ edges = newEdges.distinct().toMutableList()
)
}
- private fun findNodeById(cfg: ControlFlowGraph, nodeId: Int): CFGNode? {
+ private fun findNodeById(
+ cfg: ControlFlowGraph,
+ nodeId: Int
+ ): CFGNode? {
cfg.blocks.find { it.id == nodeId }?.let { return it }
cfg.root?.let { if (it.id == nodeId) return it }
return null
}
- private fun findBlockByLabel(blocks: List, label: TackyLabel): Block? {
- return blocks.find { block ->
+ private fun findBlockByLabel(
+ blocks: List,
+ label: TackyLabel
+ ): Block? =
+ blocks.find { block ->
val firstInstruction = block.instructions.firstOrNull()
firstInstruction is TackyLabel && firstInstruction.name == label.name
}
- }
}
diff --git a/src/jsMain/kotlin/parser/ASTNode.kt b/src/jsMain/kotlin/parser/ASTNode.kt
index df079ab..3364aaf 100644
--- a/src/jsMain/kotlin/parser/ASTNode.kt
+++ b/src/jsMain/kotlin/parser/ASTNode.kt
@@ -2,8 +2,16 @@ package parser
import kotlin.random.Random
-data class SourceLocation(val startLine: Int, val startCol: Int, val endLine: Int, val endCol: Int)
+data class SourceLocation(
+ val startLine: Int,
+ val startCol: Int,
+ val endLine: Int,
+ val endCol: Int
+)
-sealed class ASTNode(open val location: SourceLocation, open val id: String = Random.nextLong().toString()) {
- abstract fun accept(visitor: Visitor): T
+sealed class ASTNode(
+ open val location: SourceLocation,
+ open val id: String = Random.nextLong().toString()
+) {
+ abstract fun accept(visitor: ASTVisitor): T
}
diff --git a/src/jsMain/kotlin/parser/Visitor.kt b/src/jsMain/kotlin/parser/ASTVisitor.kt
similarity index 97%
rename from src/jsMain/kotlin/parser/Visitor.kt
rename to src/jsMain/kotlin/parser/ASTVisitor.kt
index 3cf9e83..532aab1 100644
--- a/src/jsMain/kotlin/parser/Visitor.kt
+++ b/src/jsMain/kotlin/parser/ASTVisitor.kt
@@ -1,6 +1,6 @@
package parser
-interface Visitor {
+interface ASTVisitor {
fun visit(node: SimpleProgram): T
fun visit(node: ReturnStatement): T
diff --git a/src/jsMain/kotlin/parser/BlockItems.kt b/src/jsMain/kotlin/parser/BlockItems.kt
index 60bb64b..d695c68 100644
--- a/src/jsMain/kotlin/parser/BlockItems.kt
+++ b/src/jsMain/kotlin/parser/BlockItems.kt
@@ -1,23 +1,27 @@
package parser
-sealed class Statement(location: SourceLocation) : ASTNode(location)
+sealed class Statement(
+ location: SourceLocation
+) : ASTNode(location)
data class ReturnStatement(
val expression: Expression,
override val location: SourceLocation
) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class ExpressionStatement(
val expression: Expression,
override val location: SourceLocation
) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
-class NullStatement(override val location: SourceLocation) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+class NullStatement(
+ override val location: SourceLocation
+) : Statement(location) {
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
override fun equals(other: Any?): Boolean = other is NullStatement
@@ -28,14 +32,14 @@ data class BreakStatement(
var label: String = "",
override val location: SourceLocation
) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class ContinueStatement(
var label: String = "",
override val location: SourceLocation
) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class WhileStatement(
@@ -44,7 +48,7 @@ data class WhileStatement(
var label: String = "",
override val location: SourceLocation
) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class DoWhileStatement(
@@ -53,7 +57,7 @@ data class DoWhileStatement(
var label: String = "",
override val location: SourceLocation
) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class ForStatement(
@@ -64,23 +68,25 @@ data class ForStatement(
var label: String = "",
override val location: SourceLocation
) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
-sealed class ForInit(location: SourceLocation) : ASTNode(location)
+sealed class ForInit(
+ location: SourceLocation
+) : ASTNode(location)
data class InitDeclaration(
val varDeclaration: VariableDeclaration,
override val location: SourceLocation
) : ForInit(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class InitExpression(
val expression: Expression?,
override val location: SourceLocation
) : ForInit(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
class IfStatement(
@@ -89,14 +95,14 @@ class IfStatement(
val _else: Statement?,
override val location: SourceLocation
) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
class GotoStatement(
val label: String,
override val location: SourceLocation
) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
class LabeledStatement(
@@ -104,55 +110,59 @@ class LabeledStatement(
val statement: Statement,
override val location: SourceLocation
) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
-sealed class Declaration(location: SourceLocation) : ASTNode(location)
+sealed class Declaration(
+ location: SourceLocation
+) : ASTNode(location)
data class VariableDeclaration(
val name: String,
val init: Expression?,
override val location: SourceLocation
) : Declaration(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class VarDecl(
val varDecl: VariableDeclaration
) : Declaration(location = varDecl.location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class FunDecl(
val funDecl: FunctionDeclaration
) : Declaration(location = funDecl.location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
-sealed class BlockItem(location: SourceLocation) : ASTNode(location)
+sealed class BlockItem(
+ location: SourceLocation
+) : ASTNode(location)
data class S(
val statement: Statement
) : BlockItem(location = statement.location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class D(
val declaration: Declaration
) : BlockItem(declaration.location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class CompoundStatement(
val block: Block,
override val location: SourceLocation
) : Statement(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class Block(
val items: List,
override val location: SourceLocation
) : ASTNode(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
diff --git a/src/jsMain/kotlin/parser/Expressions.kt b/src/jsMain/kotlin/parser/Expressions.kt
index 6dbe0d8..7a57f55 100644
--- a/src/jsMain/kotlin/parser/Expressions.kt
+++ b/src/jsMain/kotlin/parser/Expressions.kt
@@ -2,20 +2,22 @@ package parser
import lexer.Token
-sealed class Expression(location: SourceLocation) : ASTNode(location)
+sealed class Expression(
+ location: SourceLocation
+) : ASTNode(location)
data class IntExpression(
val value: Int,
override val location: SourceLocation
) : Expression(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class VariableExpression(
val name: String,
override val location: SourceLocation
) : Expression(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class UnaryExpression(
@@ -23,7 +25,7 @@ data class UnaryExpression(
val expression: Expression,
override val location: SourceLocation
) : Expression(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class BinaryExpression(
@@ -32,7 +34,7 @@ data class BinaryExpression(
val right: Expression,
override val location: SourceLocation
) : Expression(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class AssignmentExpression(
@@ -40,16 +42,16 @@ data class AssignmentExpression(
val rvalue: Expression,
override val location: SourceLocation
) : Expression(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class ConditionalExpression(
- val codition: Expression,
+ val condition: Expression,
val thenExpression: Expression,
val elseExpression: Expression,
override val location: SourceLocation
) : Expression(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
data class FunctionCall(
@@ -57,5 +59,5 @@ data class FunctionCall(
val arguments: List,
override val location: SourceLocation
) : Expression(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
diff --git a/src/jsMain/kotlin/parser/FunctionDeclaration.kt b/src/jsMain/kotlin/parser/FunctionDeclaration.kt
index c478a32..2c0e38a 100644
--- a/src/jsMain/kotlin/parser/FunctionDeclaration.kt
+++ b/src/jsMain/kotlin/parser/FunctionDeclaration.kt
@@ -8,5 +8,5 @@ data class FunctionDeclaration(
val body: Block?,
override val location: SourceLocation
) : ASTNode(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
diff --git a/src/jsMain/kotlin/parser/Parser.kt b/src/jsMain/kotlin/parser/Parser.kt
index 75ba66a..ecfe1a7 100644
--- a/src/jsMain/kotlin/parser/Parser.kt
+++ b/src/jsMain/kotlin/parser/Parser.kt
@@ -33,7 +33,12 @@ class Parser {
// After a full program is parsed, we must be at EOF
expect(TokenType.EOF, tokenSet)
if (!tokenSet.isEmpty()) {
- throw UnexpectedEndOfFileException()
+ throw UnexpectedTokenException(
+ line = tokenSet.first().startLine,
+ column = tokens.first().startColumn,
+ expected = TokenType.EOF.toString(),
+ actual = tokens.first().type.toString()
+ )
}
return ast
}
@@ -59,15 +64,16 @@ class Parser {
val name = parseIdentifier(tokens)
expect(TokenType.LEFT_PAREN, tokens)
val params = mutableListOf()
- if (tokens.firstOrNull()?.type != TokenType.KEYWORD_VOID) {
+ if (tokens.firstOrNull()?.type == TokenType.KEYWORD_VOID) {
+ tokens.removeFirst() // consume 'void'
+ } else if (tokens.firstOrNull()?.type == TokenType.KEYWORD_INT) {
// get params
do {
expect(TokenType.KEYWORD_INT, tokens)
params.add(parseIdentifier(tokens))
} while (tokens.firstOrNull()?.type == TokenType.COMMA && tokens.removeFirst().type == TokenType.COMMA)
- } else {
- tokens.removeFirst() // consume 'void'
}
+ // If neither void nor int, assume no parameters (empty parameter list)
val endParan = expect(TokenType.RIGHT_PAREN, tokens)
val body: Block?
val endLine: Int
@@ -86,18 +92,23 @@ class Parser {
return FunctionDeclaration(name, params, body, SourceLocation(func.startLine, func.startColumn, endLine, endColumn))
}
- private fun parseFunctionDeclarationFromBody(tokens: MutableList, name: String, location: SourceLocation): FunctionDeclaration {
+ private fun parseFunctionDeclarationFromBody(
+ tokens: MutableList,
+ name: String,
+ location: SourceLocation
+ ): FunctionDeclaration {
expect(TokenType.LEFT_PAREN, tokens)
val params = mutableListOf()
- if (tokens.firstOrNull()?.type != TokenType.KEYWORD_VOID) {
+ if (tokens.firstOrNull()?.type == TokenType.KEYWORD_VOID) {
+ tokens.removeFirst() // consume 'void'
+ } else if (tokens.firstOrNull()?.type == TokenType.KEYWORD_INT) {
// get params
do {
expect(TokenType.KEYWORD_INT, tokens)
params.add(parseIdentifier(tokens))
} while (tokens.firstOrNull()?.type == TokenType.COMMA && tokens.removeFirst().type == TokenType.COMMA)
- } else {
- tokens.removeFirst() // consume 'void'
}
+ // If neither void nor int, assume no parameters (empty parameter list)
val end = expect(TokenType.RIGHT_PAREN, tokens)
val body: Block?
val finalLocation: SourceLocation
@@ -128,12 +139,20 @@ class Parser {
if (tokens.firstOrNull()?.type == TokenType.KEYWORD_INT) {
val lookaheadTokens = tokens.toMutableList()
val start = expect(TokenType.KEYWORD_INT, lookaheadTokens)
- val name = parseIdentifier(lookaheadTokens)
+ parseIdentifier(lookaheadTokens)
if (lookaheadTokens.firstOrNull()?.type == TokenType.LEFT_PAREN) {
- expect(TokenType.KEYWORD_INT, tokens)
+ expect(TokenType.KEYWORD_INT, tokens) // consume the int keyword
val actualName = parseIdentifier(tokens)
- D(FunDecl(parseFunctionDeclarationFromBody(tokens, actualName, SourceLocation(start.startLine, start.startColumn, start.endLine, start.endColumn))))
+ D(
+ FunDecl(
+ parseFunctionDeclarationFromBody(
+ tokens,
+ actualName,
+ SourceLocation(start.startLine, start.startColumn, start.endLine, start.endColumn)
+ )
+ )
+ )
} else {
val end = expect(TokenType.KEYWORD_INT, tokens)
val actualName = parseIdentifier(tokens)
@@ -214,7 +233,12 @@ class Parser {
endLine = elseStatement.location.endLine
endCol = elseStatement.location.endCol
}
- return IfStatement(condition, thenStatement, elseStatement, SourceLocation(ifToken.startLine, ifToken.startColumn, endLine, endCol))
+ return IfStatement(
+ condition,
+ thenStatement,
+ elseStatement,
+ SourceLocation(ifToken.startLine, ifToken.startColumn, endLine, endCol)
+ )
}
TokenType.KEYWORD_RETURN -> {
val returnToken = expect(TokenType.KEYWORD_RETURN, tokens)
@@ -229,7 +253,10 @@ class Parser {
val gotoToken = expect(TokenType.GOTO, tokens)
val label = parseIdentifier(tokens)
val semicolonToken = expect(TokenType.SEMICOLON, tokens)
- return GotoStatement(label, SourceLocation(gotoToken.startLine, gotoToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn))
+ return GotoStatement(
+ label,
+ SourceLocation(gotoToken.startLine, gotoToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn)
+ )
}
TokenType.IDENTIFIER -> {
// Handle labeled statements: IDENTIFIER followed by COLON
@@ -238,22 +265,38 @@ class Parser {
val labelName = labelToken.lexeme
expect(TokenType.COLON, tokens)
val statement = parseStatement(tokens)
- return LabeledStatement(labelName, statement, SourceLocation(labelToken.startLine, labelToken.startColumn, statement.location.endLine, statement.location.endCol))
+ return LabeledStatement(
+ labelName,
+ statement,
+ SourceLocation(labelToken.startLine, labelToken.startColumn, statement.location.endLine, statement.location.endCol)
+ )
} else {
// Not a label, parse as expression statement by delegating to default branch
val expression = parseOptionalExpression(tokens = tokens, followedByType = TokenType.SEMICOLON)
- return if (expression != null) ExpressionStatement(expression, expression.location) else NullStatement(SourceLocation(0, 0, 0, 0))
+ return if (expression !=
+ null
+ ) {
+ ExpressionStatement(expression, expression.location)
+ } else {
+ NullStatement(SourceLocation(0, 0, 0, 0))
+ }
}
}
TokenType.KEYWORD_BREAK -> {
val breakToken = expect(TokenType.KEYWORD_BREAK, tokens)
val semicolonToken = expect(TokenType.SEMICOLON, tokens)
- return BreakStatement("", SourceLocation(breakToken.startLine, breakToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn))
+ return BreakStatement(
+ "",
+ SourceLocation(breakToken.startLine, breakToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn)
+ )
}
TokenType.KEYWORD_CONTINUE -> {
val continueToken = expect(TokenType.KEYWORD_CONTINUE, tokens)
val semicolonToken = expect(TokenType.SEMICOLON, tokens)
- return ContinueStatement("", SourceLocation(continueToken.startLine, continueToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn))
+ return ContinueStatement(
+ "",
+ SourceLocation(continueToken.startLine, continueToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn)
+ )
}
TokenType.KEYWORD_WHILE -> {
val whileToken = expect(TokenType.KEYWORD_WHILE, tokens)
@@ -261,7 +304,12 @@ class Parser {
val condition = parseExpression(tokens = tokens)
expect(TokenType.RIGHT_PAREN, tokens)
val body = parseStatement(tokens)
- return WhileStatement(condition, body, "", SourceLocation(whileToken.startLine, whileToken.startColumn, body.location.endLine, body.location.endCol))
+ return WhileStatement(
+ condition,
+ body,
+ "",
+ SourceLocation(whileToken.startLine, whileToken.startColumn, body.location.endLine, body.location.endCol)
+ )
}
TokenType.KEYWORD_DO -> {
val doToken = expect(TokenType.KEYWORD_DO, tokens)
@@ -271,7 +319,12 @@ class Parser {
val condition = parseExpression(tokens = tokens)
expect(TokenType.RIGHT_PAREN, tokens)
val semicolonToken = expect(TokenType.SEMICOLON, tokens)
- return DoWhileStatement(condition, body, "", SourceLocation(doToken.startLine, doToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn))
+ return DoWhileStatement(
+ condition,
+ body,
+ "",
+ SourceLocation(doToken.startLine, doToken.startColumn, semicolonToken.endLine, semicolonToken.endColumn)
+ )
}
TokenType.KEYWORD_FOR -> {
val forToken = expect(TokenType.KEYWORD_FOR, tokens)
@@ -295,7 +348,13 @@ class Parser {
}
else -> {
val expression = parseOptionalExpression(tokens = tokens, followedByType = TokenType.SEMICOLON)
- return if (expression != null) ExpressionStatement(expression, expression.location) else NullStatement(SourceLocation(0, 0, 0, 0))
+ return if (expression !=
+ null
+ ) {
+ ExpressionStatement(expression, expression.location)
+ } else {
+ NullStatement(SourceLocation(0, 0, 0, 0))
+ }
}
}
}
@@ -304,8 +363,12 @@ class Parser {
if (tokens.firstOrNull()?.type == TokenType.KEYWORD_INT) {
val start = expect(TokenType.KEYWORD_INT, tokens)
val name = parseIdentifier(tokens)
- val declaration = parseVariableDeclaration(tokens, name, SourceLocation(start.startLine, start.startColumn, start.endLine, start.endColumn))
- return InitDeclaration(declaration, SourceLocation(start.startLine, start.startColumn, declaration.location.endLine, declaration.location.endCol))
+ val declaration =
+ parseVariableDeclaration(tokens, name, SourceLocation(start.startLine, start.startColumn, start.endLine, start.endColumn))
+ return InitDeclaration(
+ declaration,
+ SourceLocation(start.startLine, start.startColumn, declaration.location.endLine, declaration.location.endCol)
+ )
}
val expression = parseOptionalExpression(tokens = tokens, followedByType = TokenType.SEMICOLON)
return InitExpression(expression, expression?.location ?: SourceLocation(0, 0, 0, 0))
@@ -328,16 +391,30 @@ class Parser {
when (nextType) {
TokenType.ASSIGN -> {
if (left !is VariableExpression) {
- throw InvalidLValueException()
+ throw InvalidLValueException(line = left.location.startLine, column = left.location.startCol)
}
val right = parseExpression(prec, tokens)
- AssignmentExpression(left, right, SourceLocation(left.location.startLine, left.location.startCol, right.location.endLine, right.location.endCol))
+ AssignmentExpression(
+ left,
+ right,
+ SourceLocation(left.location.startLine, left.location.startCol, right.location.endLine, right.location.endCol)
+ )
}
TokenType.QUESTION_MARK -> {
- val thenExpression = parseExpression(prec, tokens)
+ val thenExpression = parseExpression(tokens = tokens)
expect(TokenType.COLON, tokens)
val elseExpression = parseExpression(prec, tokens)
- return ConditionalExpression(left, thenExpression, elseExpression, SourceLocation(left.location.startLine, left.location.startCol, elseExpression.location.endLine, elseExpression.location.endCol))
+ return ConditionalExpression(
+ left,
+ thenExpression,
+ elseExpression,
+ SourceLocation(
+ left.location.startLine,
+ left.location.startCol,
+ elseExpression.location.endLine,
+ elseExpression.location.endCol
+ )
+ )
}
else -> {
val right = parseExpression(prec + 1, tokens)
@@ -372,13 +449,16 @@ class Parser {
when (nextToken.type) {
TokenType.INT_LITERAL -> {
nextToken = tokens.removeFirst()
- return IntExpression(value = nextToken.lexeme.toInt(), SourceLocation(nextToken.startLine, nextToken.startColumn, nextToken.endLine, nextToken.endColumn))
+ return IntExpression(
+ value = nextToken.lexeme.toInt(),
+ SourceLocation(nextToken.startLine, nextToken.startColumn, nextToken.endLine, nextToken.endColumn)
+ )
}
TokenType.IDENTIFIER -> {
nextToken = tokens.removeFirst()
if (tokens.firstOrNull()?.type == TokenType.LEFT_PAREN) {
// function call
- val leftParen = tokens.removeFirst() // consume '('
+ tokens.removeFirst() // consume '('
val args = mutableListOf()
if (tokens.firstOrNull()?.type != TokenType.RIGHT_PAREN) {
do {
@@ -386,17 +466,28 @@ class Parser {
} while (tokens.firstOrNull()?.type == TokenType.COMMA && tokens.removeFirst().type == TokenType.COMMA)
}
val rightParen = expect(TokenType.RIGHT_PAREN, tokens)
- return FunctionCall(nextToken.lexeme, args, SourceLocation(nextToken.startLine, nextToken.startColumn, rightParen.endLine, rightParen.endColumn))
+ return FunctionCall(
+ nextToken.lexeme,
+ args,
+ SourceLocation(nextToken.startLine, nextToken.startColumn, rightParen.endLine, rightParen.endColumn)
+ )
} else {
// It's a variable
- return VariableExpression(nextToken.lexeme, SourceLocation(nextToken.startLine, nextToken.startColumn, nextToken.endLine, nextToken.endColumn))
+ return VariableExpression(
+ nextToken.lexeme,
+ SourceLocation(nextToken.startLine, nextToken.startColumn, nextToken.endLine, nextToken.endColumn)
+ )
}
}
TokenType.TILDE, TokenType.NEGATION, TokenType.NOT -> {
val operator = tokens.removeFirst()
val factor = parseFactor(tokens)
- return UnaryExpression(operator = operator, expression = factor, SourceLocation(operator.startLine, operator.startColumn, factor.location.endLine, factor.location.endCol))
+ return UnaryExpression(
+ operator = operator,
+ expression = factor,
+ SourceLocation(operator.startLine, operator.startColumn, factor.location.endLine, factor.location.endCol)
+ )
}
TokenType.LEFT_PAREN -> {
expect(TokenType.LEFT_PAREN, tokens)
diff --git a/src/jsMain/kotlin/parser/Programs.kt b/src/jsMain/kotlin/parser/Programs.kt
index 05694cc..ddcc20e 100644
--- a/src/jsMain/kotlin/parser/Programs.kt
+++ b/src/jsMain/kotlin/parser/Programs.kt
@@ -1,10 +1,12 @@
package parser
-sealed class Program(location: SourceLocation) : ASTNode(location)
+sealed class Program(
+ location: SourceLocation
+) : ASTNode(location)
data class SimpleProgram(
val functionDeclaration: List,
override val location: SourceLocation
) : Program(location) {
- override fun accept(visitor: Visitor): T = visitor.visit(this)
+ override fun accept(visitor: ASTVisitor): T = visitor.visit(this)
}
diff --git a/src/jsMain/kotlin/semanticAnalysis/IdentifierResolution.kt b/src/jsMain/kotlin/semanticAnalysis/IdentifierResolution.kt
index a406266..f961447 100644
--- a/src/jsMain/kotlin/semanticAnalysis/IdentifierResolution.kt
+++ b/src/jsMain/kotlin/semanticAnalysis/IdentifierResolution.kt
@@ -4,6 +4,7 @@ import exceptions.DuplicateVariableDeclaration
import exceptions.MissingDeclarationException
import exceptions.NestedFunctionException
import parser.ASTNode
+import parser.ASTVisitor
import parser.AssignmentExpression
import parser.BinaryExpression
import parser.Block
@@ -36,7 +37,6 @@ import parser.UnaryExpression
import parser.VarDecl
import parser.VariableDeclaration
import parser.VariableExpression
-import parser.Visitor
import parser.WhileStatement
data class SymbolInfo(
@@ -44,7 +44,7 @@ data class SymbolInfo(
val hasLinkage: Boolean
)
-class IdentifierResolution : Visitor {
+class IdentifierResolution : ASTVisitor {
private var tempCounter = 0
private fun newTemporary(name: String): String = "$name.${tempCounter++}"
@@ -67,7 +67,9 @@ class IdentifierResolution : Visitor {
private fun declare(
name: String,
- hasLinkage: Boolean
+ hasLinkage: Boolean,
+ line: Int = -1,
+ col: Int = -1
): String {
val currentScope = scopeStack.last()
val existing = currentScope[name]
@@ -75,7 +77,7 @@ class IdentifierResolution : Visitor {
if (existing != null) {
// A redeclaration in the same scope is only okay if both have linkage.
if (!existing.hasLinkage || !hasLinkage) {
- throw DuplicateVariableDeclaration()
+ throw DuplicateVariableDeclaration(line, col)
}
// If both have linkage (e.g., two function declarations), it's okay.
return existing.uniqueName
@@ -89,13 +91,17 @@ class IdentifierResolution : Visitor {
private fun leaveScope() = scopeStack.removeAt(scopeStack.lastIndex)
- private fun resolve(name: String): SymbolInfo {
+ private fun resolve(
+ name: String,
+ line: Int = -1,
+ col: Int = -1
+ ): SymbolInfo {
scopeStack.asReversed().forEach { scope ->
if (scope.containsKey(name)) {
return scope.getValue(name)
}
}
- throw MissingDeclarationException(name)
+ throw MissingDeclarationException(name, line, col)
}
override fun visit(node: SimpleProgram): ASTNode {
@@ -161,19 +167,19 @@ class IdentifierResolution : Visitor {
// We're inside another function - check if this is a prototype or definition
if (node.body != null) {
// Function definition with body - not allowed inside other functions
- throw NestedFunctionException()
+ throw NestedFunctionException(node.location.startLine, node.location.startCol)
} else {
- declare(node.name, hasLinkage = true)
+ declare(node.name, hasLinkage = true, node.location.startLine, node.location.startCol)
return FunctionDeclaration(node.name, node.params, null, node.location)
}
} else {
- declare(node.name, hasLinkage = true)
+ declare(node.name, hasLinkage = true, node.location.startLine, node.location.startCol)
enterScope()
val newParams =
node.params.map { paramName ->
- declare(paramName, hasLinkage = false)
+ declare(paramName, hasLinkage = false, node.location.startLine, node.location.startCol)
}
val newBody = node.body?.accept(this) as Block?
@@ -209,7 +215,7 @@ class IdentifierResolution : Visitor {
}
override fun visit(node: ConditionalExpression): ASTNode {
- val condition = node.codition.accept(this) as Expression
+ val condition = node.condition.accept(this) as Expression
val thenExpression = node.thenExpression.accept(this) as Expression
val elseExpression = node.elseExpression.accept(this) as Expression
return ConditionalExpression(condition, thenExpression, elseExpression, node.location)
@@ -231,7 +237,7 @@ class IdentifierResolution : Visitor