Skip to content

Commit

Permalink
Use itself (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
vrslev authored Jul 7, 2024
1 parent 8659d33 commit fbcf13f
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 32 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
},
"[typescript]": {
"editor.defaultFormatter": "biomejs.biome"
}
},
"auto-typing-final.import-style": "final"
}
3 changes: 2 additions & 1 deletion auto_typing_final/finder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Final

from ast_grep_py import Config, SgNode

Expand Down Expand Up @@ -178,7 +179,7 @@ class ImportsResult:


def find_imports_of_identifier_in_scope(root: SgNode, module_name: str, identifier_name: str) -> ImportsResult: # noqa: C901
result = ImportsResult(module_aliases={module_name}, has_from_import=False)
result: Final = ImportsResult(module_aliases={module_name}, has_from_import=False)

for node in root.find_all(any=[{"kind": "import_statement"}, {"kind": "import_from_statement"}]):
if _is_inside_inner_function_or_class(root, node) or node == root:
Expand Down
20 changes: 10 additions & 10 deletions auto_typing_final/lsp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from collections.abc import Iterable
from importlib.metadata import version
from typing import cast, get_args
from typing import Final, cast, get_args

import attr
import cattrs
Expand Down Expand Up @@ -42,7 +42,7 @@ def make_import_text_edit(import_text: str) -> lsp.TextEdit:


def make_text_edit(edit: Edit) -> lsp.TextEdit:
node_range = edit.node.range()
node_range: Final = edit.node.range()
return lsp.TextEdit(
range=lsp.Range(
start=lsp.Position(line=node_range.start.line, character=node_range.start.column),
Expand All @@ -55,7 +55,7 @@ def make_text_edit(edit: Edit) -> lsp.TextEdit:
def make_diagnostics(source: str) -> Iterable[lsp.Diagnostic]:
if not IMPORT_CONFIG:
return
result = make_replacements(root=SgRoot(source, "python").root(), import_config=IMPORT_CONFIG)
result: Final = make_replacements(root=SgRoot(source, "python").root(), import_config=IMPORT_CONFIG)

for replacement in result.replacements:
if replacement.operation_type == AddFinal:
Expand Down Expand Up @@ -86,7 +86,7 @@ def make_diagnostics(source: str) -> Iterable[lsp.Diagnostic]:
def make_fixall_text_edits(source: str) -> Iterable[lsp.TextEdit]:
if not IMPORT_CONFIG:
return
result = make_replacements(root=SgRoot(source, "python").root(), import_config=IMPORT_CONFIG)
result: Final = make_replacements(root=SgRoot(source, "python").root(), import_config=IMPORT_CONFIG)

for replacement in result.replacements:
for edit in replacement.edits:
Expand Down Expand Up @@ -147,7 +147,7 @@ def workspace_did_change_configuration(params: lsp.DidChangeConfigurationParams)
def did_open_did_save_did_change(
params: lsp.DidOpenTextDocumentParams | lsp.DidSaveTextDocumentParams | lsp.DidChangeTextDocumentParams,
) -> None:
text_document = LSP_SERVER.workspace.get_text_document(params.text_document.uri)
text_document: Final = LSP_SERVER.workspace.get_text_document(params.text_document.uri)
LSP_SERVER.publish_diagnostics(text_document.uri, diagnostics=list(make_diagnostics(text_document.source)))


Expand All @@ -163,12 +163,12 @@ def did_close(params: lsp.DidCloseTextDocumentParams) -> None:
),
)
def code_action(params: lsp.CodeActionParams) -> list[lsp.CodeAction] | None:
requested_kinds = params.context.only or {lsp.CodeActionKind.QuickFix, lsp.CodeActionKind.SourceFixAll}
actions: list[lsp.CodeAction] = []
requested_kinds: Final = params.context.only or {lsp.CodeActionKind.QuickFix, lsp.CodeActionKind.SourceFixAll}
actions: Final[list[lsp.CodeAction]] = []

if lsp.CodeActionKind.QuickFix in requested_kinds:
text_document = LSP_SERVER.workspace.get_text_document(params.text_document.uri)
our_diagnostics = [
text_document: Final = LSP_SERVER.workspace.get_text_document(params.text_document.uri)
our_diagnostics: Final = [
diagnostic for diagnostic in params.context.diagnostics if diagnostic.source == LSP_SERVER.name
]

Expand Down Expand Up @@ -210,7 +210,7 @@ def code_action(params: lsp.CodeActionParams) -> list[lsp.CodeAction] | None:

@LSP_SERVER.feature(lsp.CODE_ACTION_RESOLVE)
def resolve_code_action(params: lsp.CodeAction) -> lsp.CodeAction:
text_document = LSP_SERVER.workspace.get_text_document(cast(str, params.data))
text_document: Final = LSP_SERVER.workspace.get_text_document(cast(str, params.data))
params.edit = lsp.WorkspaceEdit(
document_changes=[
lsp.TextDocumentEdit(
Expand Down
14 changes: 7 additions & 7 deletions auto_typing_final/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
from collections.abc import Iterable
from difflib import unified_diff
from pathlib import Path
from typing import get_args
from typing import Final, get_args

from ast_grep_py import SgRoot

from auto_typing_final.transform import IMPORT_STYLES_TO_IMPORT_CONFIGS, ImportConfig, ImportStyle, make_replacements


def transform_file_content(source: str, import_config: ImportConfig) -> str:
root = SgRoot(source, "python").root()
result = make_replacements(root, import_config)
new_text = root.commit_edits(
root: Final = SgRoot(source, "python").root()
result: Final = make_replacements(root, import_config)
new_text: Final = root.commit_edits(
[edit.node.replace(edit.new_text) for replacement in result.replacements for edit in replacement.edits]
)
return root.commit_edits([root.replace(f"{result.import_text}\n{new_text}")]) if result.import_text else new_text
Expand Down Expand Up @@ -42,13 +42,13 @@ def find_all_source_files(paths: list[Path]) -> Iterable[Path]:


def main() -> int:
parser = argparse.ArgumentParser()
parser: Final = argparse.ArgumentParser()
parser.add_argument("files", type=Path, nargs="*")
parser.add_argument("--check", action="store_true")
parser.add_argument("--import-style", type=str, choices=get_args(ImportStyle), default="typing-final")

args = parser.parse_args()
import_config = IMPORT_STYLES_TO_IMPORT_CONFIGS[args.import_style]
args: Final = parser.parse_args()
import_config: Final = IMPORT_STYLES_TO_IMPORT_CONFIGS[args.import_style]

has_changes = False

Expand Down
14 changes: 7 additions & 7 deletions auto_typing_final/transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Literal
from typing import Final, Literal

from ast_grep_py import SgNode

Expand Down Expand Up @@ -63,7 +63,7 @@ class RemoveFinal:


def _make_operation_from_assignments_to_one_name(nodes: list[SgNode]) -> Operation:
value_assignments: list[Definition] = []
value_assignments: Final[list[Definition]] = []
has_node_inside_loop = False

for node in nodes:
Expand Down Expand Up @@ -107,11 +107,11 @@ def _attribute_is_exact_identifier(node: SgNode, imports_result: ImportsResult,
def _strip_identifier_from_type_annotation(
node: SgNode, imports_result: ImportsResult, identifier_name: str
) -> str | None:
type_node_children = node.children()
type_node_children: Final = node.children()
if len(type_node_children) != 1:
return None
inner_type_node = type_node_children[0]
kind = inner_type_node.kind()
inner_type_node: Final = type_node_children[0]
kind: Final = inner_type_node.kind()

if kind == "subscript":
match tuple((child.kind(), child) for child in inner_type_node.children()):
Expand Down Expand Up @@ -182,9 +182,9 @@ class MakeReplacementsResult:


def make_replacements(root: SgNode, import_config: ImportConfig) -> MakeReplacementsResult:
replacements = []
replacements: Final = []
has_added_final = False
imports_result = find_imports_of_identifier_in_scope(root, module_name="typing", identifier_name="Final")
imports_result: Final = find_imports_of_identifier_in_scope(root, module_name="typing", identifier_name="Final")

for current_definitions in find_all_definitions_in_functions(root):
operation = _make_operation_from_assignments_to_one_name(current_definitions)
Expand Down
16 changes: 10 additions & 6 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Final

import pytest

from auto_typing_final.main import transform_file_content
Expand Down Expand Up @@ -39,16 +41,18 @@
],
)
def test_variants(import_config: ImportConfig, before: str, after: str) -> None:
source_function_content = "\n".join(f" {line.format(import_config.value)}" for line in before.splitlines())
source = f"""
source_function_content: Final = "\n".join(
f" {line.format(import_config.value)}" for line in before.splitlines()
)
source: Final = f"""
{import_config.import_text}
def foo():
{source_function_content}
"""

after_function_content = "\n".join(f" {line.format(import_config.value)}" for line in after.splitlines())
after_source = f"""
after_function_content: Final = "\n".join(f" {line.format(import_config.value)}" for line in after.splitlines())
after_source: Final = f"""
{import_config.import_text}
def foo():
Expand Down Expand Up @@ -588,7 +592,7 @@ def foo(self):
],
)
def test_transform_file_content(case: str) -> None:
import_config = IMPORT_STYLES_TO_IMPORT_CONFIGS["typing-final"]
import_config: Final = IMPORT_STYLES_TO_IMPORT_CONFIGS["typing-final"]
before, _, after = case.partition("---")
assert (
transform_file_content(f"{import_config.import_text}\n" + before.strip(), import_config=import_config)
Expand Down Expand Up @@ -653,7 +657,7 @@ def f():
)
def test_add_import(case: str) -> None:
before, _, after = case.partition("---")
import_config = IMPORT_STYLES_TO_IMPORT_CONFIGS["typing-final"]
import_config: Final = IMPORT_STYLES_TO_IMPORT_CONFIGS["typing-final"]
assert transform_file_content(before.strip(), import_config=import_config) == after.strip()


Expand Down

0 comments on commit fbcf13f

Please sign in to comment.