diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..a1d20535 --- /dev/null +++ b/.gitignore @@ -0,0 +1,53 @@ +# Dynamic Library Sync Temp Directory +.lib_sync_temp/ + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +venv/ +ENV/ +env/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ + +# Logs +*.log + +# OS +.DS_Store +Thumbs.db + +# Sync state is tracked +!Libraries/.sync_state.json + diff --git a/LIBRARY_SYNC.md b/LIBRARY_SYNC.md new file mode 100644 index 00000000..6ced88ed --- /dev/null +++ b/LIBRARY_SYNC.md @@ -0,0 +1,376 @@ +# Dynamic Library Sync System + +## Overview + +The analyzer repository uses a **dynamic library sync system** to keep external libraries (autogenlib, serena, graph-sitter) up-to-date automatically. + +Instead of using traditional git submodules or manual copying, this system: +- ✅ **Auto-detects changes** in source repositories +- ✅ **Syncs only what's needed** (Python source files) +- ✅ **Filters out tests and build artifacts** +- ✅ **Maintains sync state** for efficient updates +- ✅ **Works offline** once initially cloned + +## Quick Start + +### Sync All Libraries + +```bash +# Initial sync or update all libraries +python sync_libraries.py + +# Force sync even if no changes detected +python sync_libraries.py --force + +# Check status without syncing +python sync_libraries.py --check +``` + +### Sync Specific Library + +```bash +# Sync only autogenlib +python sync_libraries.py --library autogenlib + +# Sync only serena +python sync_libraries.py --library serena + +# Sync only graph-sitter +python sync_libraries.py --library graph_sitter +``` + +### Validate Modules + +```bash +# Validate all modules and adapters +python validate_modules.py + +# Quick validation (imports only) +python validate_modules.py --quick + +# Verbose validation +python validate_modules.py --verbose +``` + +## Library Configuration + +The sync system is configured in `sync_libraries.py`: + +```python +LIBRARY_CONFIGS = { + "autogenlib": { + "repo_url": "https://github.com/Zeeeepa/autogenlib.git", + "source_path": "autogenlib", + "target_path": LIBRARIES_DIR / "autogenlib", + }, + "serena": { + "repo_url": "https://github.com/Zeeeepa/serena.git", + "source_path": "src/serena", + "target_path": LIBRARIES_DIR / "serena", + }, + "graph_sitter": { + "repo_url": "https://github.com/Zeeeepa/graph-sitter.git", + "source_path": "src", + "target_path": LIBRARIES_DIR / "graph_sitter_lib", + }, +} +``` + +## Directory Structure + +``` +analyzer/ +├── Libraries/ # Target directory for synced libraries +│ ├── autogenlib/ # Synced from Zeeeepa/autogenlib +│ ├── serena/ # Synced from Zeeeepa/serena +│ ├── graph_sitter_lib/ # Synced from Zeeeepa/graph-sitter +│ ├── .sync_state.json # Tracks last sync state +│ ├── analyzer.py # Core analyzer +│ ├── autogenlib_adapter.py # Autogenlib integration +│ ├── graph_sitter_adapter.py # Graph-sitter integration +│ ├── lsp_adapter.py # LSP integration +│ └── static_libs.py # Utilities +├── .lib_sync_temp/ # Temporary clones (gitignored) +│ ├── autogenlib/ +│ ├── serena/ +│ └── graph_sitter/ +├── sync_libraries.py # Sync script +├── validate_modules.py # Validation script +└── .gitignore # Excludes .lib_sync_temp/ +``` + +## How It Works + +### 1. 
Clone or Update + +The script clones repositories to `.lib_sync_temp/` or pulls latest changes if already cloned. + +### 2. Calculate Hashes + +It calculates MD5 hashes of source and target directories to detect changes: + +```python +def should_sync(self, source_dir: Path) -> bool: + source_hash = self.calculate_directory_hash(source_dir) + target_hash = self.calculate_directory_hash(self.target_path) + return source_hash != target_hash +``` + +### 3. Filtered Copy + +Only Python source files (`*.py`, `*.pyi`, `*.typed`) are copied, excluding: +- Test files (`*test*`) +- Bytecode (`__pycache__`, `*.pyc`) +- Build artifacts + +### 4. State Tracking + +Sync state is saved to `Libraries/.sync_state.json`: + +```json +{ + "last_sync": "2025-10-15T15:27:51.826000", + "results": { + "autogenlib": true, + "serena": true, + "graph_sitter": true + } +} +``` + +## Automated Sync + +### Git Hooks (Recommended) + +Create `.git/hooks/post-merge` to auto-sync after git pull: + +```bash +#!/bin/bash +echo "Running library sync..." +python3 sync_libraries.py +``` + +Make it executable: + +```bash +chmod +x .git/hooks/post-merge +``` + +### GitHub Actions (CI/CD) + +Add to `.github/workflows/sync-libraries.yml`: + +```yaml +name: Sync Libraries + +on: + schedule: + - cron: '0 0 * * *' # Daily at midnight + workflow_dispatch: # Manual trigger + +jobs: + sync: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + - name: Sync libraries + run: python sync_libraries.py + - name: Commit changes + run: | + git config user.name "Library Sync Bot" + git config user.email "bot@analyzer.com" + git add Libraries/ + git commit -m "chore: sync external libraries" || exit 0 + git push +``` + +### Cron Job (Server) + +Add to crontab for automated syncing: + +```bash +# Sync libraries every 6 hours +0 */6 * * * cd /path/to/analyzer && python3 sync_libraries.py +``` + +## Troubleshooting + +### Issue: "No module named 'X'" + +**Cause**: Missing Python dependencies +**Solution**: Install required packages + +```bash +pip install openai anthropic networkx pydantic fastapi +# or +pip install -r requirements.txt +``` + +### Issue: "Syntax error in static_libs.py" + +**Cause**: File corruption during manual edit +**Solution**: Re-sync the library + +```bash +python sync_libraries.py --force --library autogenlib +``` + +### Issue: "Git pull failed" + +**Cause**: Temp directory has uncommitted changes +**Solution**: The script auto-handles this by re-cloning + +### Issue: "Permission denied" + +**Cause**: File permissions +**Solution**: Ensure write permissions + +```bash +chmod -R u+w Libraries/ +``` + +## Best Practices + +### 1. **Regular Syncing** + +Run sync at least: +- After git pull +- Before starting new work +- Before running tests +- Before deployment + +### 2. **Validate After Sync** + +Always validate after syncing: + +```bash +python sync_libraries.py && python validate_modules.py +``` + +### 3. **Check Status First** + +Before syncing, check if updates are needed: + +```bash +python sync_libraries.py --check +``` + +### 4. **Use Version Control** + +The synced libraries in `Libraries/` are tracked in git, so: +- ✅ Changes are versioned +- ✅ Team members get same versions +- ✅ Can rollback if needed + +### 5. 
**Monitor Sync State** + +Check `.sync_state.json` to see when last synced: + +```bash +cat Libraries/.sync_state.json +``` + +## Advanced Usage + +### Custom Sync Configuration + +Modify `LIBRARY_CONFIGS` in `sync_libraries.py` to: +- Change source paths +- Add new libraries +- Adjust file patterns +- Change target locations + +### Programmatic Usage + +```python +from sync_libraries import SyncManager + +manager = SyncManager() + +# Sync all +manager.sync_all(force=True) + +# Sync one +manager.sync_one("autogenlib") + +# Check status +statuses = manager.check_all() +for name, status in statuses.items(): + print(f"{name}: {status['needs_sync']}") +``` + +### Custom Filters + +Add custom include/exclude patterns: + +```python +"autogenlib": { + # ... + "include_patterns": ["*.py", "*.pyi", "*.json"], + "exclude_patterns": ["*test*", "*__pycache__*", "*.pyc", "*deprecated*"], +} +``` + +## Contributing + +When contributing changes to the sync system: + +1. Test the sync script: + ```bash + python sync_libraries.py --check + python sync_libraries.py --force + ``` + +2. Validate all modules work: + ```bash + python validate_modules.py + ``` + +3. Update documentation if changing: + - Library configurations + - Sync behavior + - File patterns + +4. Commit both script and synced libraries: + ```bash + git add sync_libraries.py Libraries/ + git commit -m "feat: update library sync system" + ``` + +## FAQs + +**Q: Why not use git submodules?** +A: Submodules are complex, require specific git commands, and include unnecessary files (tests, docs, etc.). + +**Q: Why copy instead of symlink?** +A: Copying ensures: +- Cross-platform compatibility +- No broken links +- Clean git tracking +- Easier deployment + +**Q: How big are the synced libraries?** +A: Much smaller than full repos: +- autogenlib: ~8 files +- serena: ~37 files +- graph-sitter: ~650 files + +**Q: Can I manually edit synced files?** +A: Not recommended. Changes will be overwritten on next sync. Instead, contribute to the source repositories. + +**Q: How often should I sync?** +A: Daily for active development, weekly for stable projects. + +**Q: Does this work offline?** +A: Yes, once initially cloned, it uses cached repos in `.lib_sync_temp/`. + +--- + +**Version**: 1.0 +**Last Updated**: 2025-10-15 +**Maintainer**: Analyzer Team + diff --git a/Libraries/.sync_state.json b/Libraries/.sync_state.json new file mode 100644 index 00000000..f0928759 --- /dev/null +++ b/Libraries/.sync_state.json @@ -0,0 +1,8 @@ +{ + "last_sync": "2025-10-15T15:27:51.826049", + "results": { + "autogenlib": true, + "serena": true, + "graph_sitter": true + } +} \ No newline at end of file diff --git a/Libraries/autogenlib/LICENSE b/Libraries/autogenlib/LICENSE deleted file mode 100644 index be634293..00000000 --- a/Libraries/autogenlib/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2025 Egor Ternovoi - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. 
- -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/Libraries/autogenlib/README.md b/Libraries/autogenlib/README.md deleted file mode 100644 index ff0afa36..00000000 --- a/Libraries/autogenlib/README.md +++ /dev/null @@ -1,207 +0,0 @@ -# AutoGenLib - -> The only library you'll need ever. -> -> Import wisdom, export code. - -AutoGenLib is a Python library that automatically generates code on-the-fly using OpenAI's API. When you try to import a module or function that doesn't exist, AutoGenLib creates it for you based on a high-level description of what you need. - -[Video review of library](https://www.youtube.com/watch?v=x6ZBddPiGZE) - -## Features - -- **Dynamic Code Generation**: Import modules and functions that don't exist yet -- **Context-Aware**: New functions are generated with knowledge of existing code -- **Progressive Enhancement**: Add new functionality to existing modules seamlessly -- **No Default Caching**: Each import generates fresh code for more varied and creative results -- **Full Codebase Context**: LLM can see all previously generated modules for better consistency -- **Caller Code Analysis**: The LLM analyzes the actual code that's importing the module to better understand the context and requirements -- **Automatic Exception Handling**: All exceptions are sent to LLM to provide clear explanation and fixes for errors. - -## Installation - -```bash -pip install autogenlib -``` - -Or install from source: - -```bash -git clone https://github.com/cofob/autogenlib.git -cd autogenlib -pip install -e . -``` - -## Requirements - -- Python 3.12+ -- OpenAI API key - -## Quick Start - -Set OpenAI API key in `OPENAI_API_KEY` env variable. - -```python -# Import a function that doesn't exist yet - it will be automatically generated -from autogenlib.tokens import generate_token - -# Use the generated function -token = generate_token(length=32) -print(token) -``` - -## How It Works - -1. You initialize AutoGenLib with a description of what you need -2. 
When you import a module or function under the `autogenlib` namespace, the library: - - Checks if the module/function already exists - - If not, it analyzes the code that's performing the import to understand the context - - It sends a request to OpenAI's API with your description and the context - - The API generates the appropriate code - - The code is validated and executed - - The requested module/function becomes available for use - -## Examples - -### Generate a TOTP Generator - -```python -from autogenlib.totp import totp_generator - -print(totp_generator("SECRETKEY123")) -``` - -Add a Verification Function Later - -```python -# Later in your application, you need verification: -from autogenlib.totp import verify_totp -result = verify_totp("SECRETKEY123", "123456") -print(f"Verification result: {result}") -``` - -### Using Context-Awareness - -```python -# Import a function - AutoGenLib will see how your data is structured -from autogenlib.processor import get_highest_score - -# Define your data structure -data = [{"user": "Alice", "score": 95}, {"user": "Bob", "score": 82}] - -# The function will work with your data structure without you having to specify details -print(get_highest_score(data)) # Will correctly extract the highest score -``` - -### Create Multiple Modules - -```python -# You can use init function to additionally hint the purpose of your library -from autogenlib import init -init("Cryptographic utility library") - -# Generate encryption module -from autogenlib.encryption import encrypt_text, decrypt_text -encrypted = encrypt_text("Secret message", "password123") -decrypted = decrypt_text(encrypted, "password123") -print(decrypted) - -# Generate hashing module -from autogenlib.hashing import hash_password, verify_password -hashed = hash_password("my_secure_password") -is_valid = verify_password("my_secure_password", hashed) -print(f"Password valid: {is_valid}") -``` - -## Configuration - -### Setting the OpenAI API Key - -Set your OpenAI API key as an environment variable: - -```bash -export OPENAI_API_KEY="your-api-key-here" -# Optional -export OPENAI_API_BASE_URL="https://openrouter.ai/api/v1" # Use OpenRouter API -export OPENAI_MODEL="openai/gpt-4.1" -``` - -Or in your Python code (not recommended for production): - -```python -import os -os.environ["OPENAI_API_KEY"] = "your-api-key-here" -``` - -### Caching Behavior - -By default, AutoGenLib does not cache generated code. This means: - -- Each time you import a module, the LLM generates fresh code -- You get more varied and often funnier results due to LLM hallucinations -- The same import might produce different implementations across runs - -If you want to enable caching (for consistency or to reduce API calls): - -```python -from autogenlib import init -init("Library for data processing", enable_caching=True) -``` - -Or toggle caching at runtime: - -```python -from autogenlib import init, set_caching -init("Library for data processing") - -# Later in your code: -set_caching(True) # Enable caching -set_caching(False) # Disable caching -``` - -When caching is enabled, generated code is stored in `~/.autogenlib_cache`. 
- -## Limitations - -- Requires internet connection to generate new code -- Depends on OpenAI API availability -- Generated code quality depends on the clarity of your description -- Not suitable for production-critical code without review - -## Advanced Usage - -### Inspecting Generated Code - -You can inspect the code that was generated for a module: - -```python -from autogenlib.totp import totp_generator -import inspect -print(inspect.getsource(totp_generator)) -``` - -## How AutoGenLib Uses the OpenAI API - -AutoGenLib creates prompts for the OpenAI API that include: - -1. The description you provided -2. Any existing code in the module being enhanced -3. The full context of all previously generated modules -4. The code that's importing the module/function (new feature!) -5. The specific function or feature needed - -This comprehensive context helps the LLM generate code that's consistent with your existing codebase and fits perfectly with how you intend to use it. - -## Contributing - -Contributions are not welcome! This is just a fun PoC project. - -## License - -MIT License - ---- - -*Note: This library is meant for prototyping and experimentation. Always review automatically generated code before using it in production environments.* - -*Note: Of course 100% of the code of this library was generated via LLM* diff --git a/Libraries/autogenlib/autogenlib/__init__.py b/Libraries/autogenlib/__init__.py similarity index 96% rename from Libraries/autogenlib/autogenlib/__init__.py rename to Libraries/autogenlib/__init__.py index fc34393e..a344b88e 100644 --- a/Libraries/autogenlib/autogenlib/__init__.py +++ b/Libraries/autogenlib/__init__.py @@ -1,67 +1,67 @@ -"""Automatic code generation library using OpenAI.""" - -import sys -from ._finder import AutoLibFinder -from ._exception_handler import setup_exception_handler - - -_sentinel = object() - - -def init(desc=_sentinel, enable_exception_handler=None, enable_caching=None): - """Initialize autogenlib with a description of the functionality needed. - - Args: - desc (str): A description of the library you want to generate. - enable_exception_handler (bool): Whether to enable the global exception handler - that sends exceptions to LLM for fix suggestions. Default is True. - enable_caching (bool): Whether to enable caching of generated code. Default is False. - """ - # Update the global description - from . import _state - - if desc is not _sentinel: - _state.description = desc - if enable_exception_handler is not None: - _state.exception_handler_enabled = enable_exception_handler - if enable_caching is not None: - _state.caching_enabled = enable_caching - - # Set up exception handler if enabled - if _state.exception_handler_enabled: - from ._exception_handler import setup_exception_handler - - setup_exception_handler() - - # Add our custom finder to sys.meta_path if it's not already there - for finder in sys.meta_path: - if isinstance(finder, AutoLibFinder): - return - sys.meta_path.insert(0, AutoLibFinder()) - - -def set_exception_handler(enabled=True): - """Enable or disable the exception handler. - - Args: - enabled (bool): Whether to enable the exception handler. Default is True. - """ - from . import _state - - _state.exception_handler_enabled = enabled - - -def set_caching(enabled=True): - """Enable or disable caching. - - Args: - enabled (bool): Whether to enable caching. Default is True. - """ - from . 
import _state - - _state.caching_enabled = enabled - - -__all__ = ["init", "set_exception_handler", "setup_exception_handler", "set_caching"] - -init() +"""Automatic code generation library using OpenAI.""" + +import sys +from ._finder import AutoLibFinder +from ._exception_handler import setup_exception_handler + + +_sentinel = object() + + +def init(desc=_sentinel, enable_exception_handler=None, enable_caching=None): + """Initialize autogenlib with a description of the functionality needed. + + Args: + desc (str): A description of the library you want to generate. + enable_exception_handler (bool): Whether to enable the global exception handler + that sends exceptions to LLM for fix suggestions. Default is True. + enable_caching (bool): Whether to enable caching of generated code. Default is False. + """ + # Update the global description + from . import _state + + if desc is not _sentinel: + _state.description = desc + if enable_exception_handler is not None: + _state.exception_handler_enabled = enable_exception_handler + if enable_caching is not None: + _state.caching_enabled = enable_caching + + # Set up exception handler if enabled + if _state.exception_handler_enabled: + from ._exception_handler import setup_exception_handler + + setup_exception_handler() + + # Add our custom finder to sys.meta_path if it's not already there + for finder in sys.meta_path: + if isinstance(finder, AutoLibFinder): + return + sys.meta_path.insert(0, AutoLibFinder()) + + +def set_exception_handler(enabled=True): + """Enable or disable the exception handler. + + Args: + enabled (bool): Whether to enable the exception handler. Default is True. + """ + from . import _state + + _state.exception_handler_enabled = enabled + + +def set_caching(enabled=True): + """Enable or disable caching. + + Args: + enabled (bool): Whether to enable caching. Default is True. + """ + from . 
import _state + + _state.caching_enabled = enabled + + +__all__ = ["init", "set_exception_handler", "setup_exception_handler", "set_caching"] + +init() diff --git a/Libraries/autogenlib/autogenlib/_cache.py b/Libraries/autogenlib/_cache.py similarity index 96% rename from Libraries/autogenlib/autogenlib/_cache.py rename to Libraries/autogenlib/_cache.py index 01cb99a7..67b02c74 100644 --- a/Libraries/autogenlib/autogenlib/_cache.py +++ b/Libraries/autogenlib/_cache.py @@ -1,100 +1,100 @@ -"""Cache management for autogenlib generated code.""" - -import os -import hashlib -import json -from ._state import caching_enabled - - -def get_cache_dir(): - """Get the directory where cached files are stored.""" - cache_dir = os.path.join(os.path.expanduser("~"), ".autogenlib_cache") - os.makedirs(cache_dir, exist_ok=True) - return cache_dir - - -def get_cache_path(fullname): - """Get the path where the cached data for a module should be stored.""" - cache_dir = get_cache_dir() - - # Create a filename based on the module name - # Use only the first two parts of the fullname (e.g., autogenlib.totp) - # to ensure we're caching at the module level - module_name = ".".join(fullname.split(".")[:2]) - filename = hashlib.md5(module_name.encode()).hexdigest() + ".json" - return os.path.join(cache_dir, filename) - - -def get_cached_data(fullname): - """Get the cached data for a module if it exists.""" - if not caching_enabled: - return None - - cache_path = get_cache_path(fullname) - try: - with open(cache_path, "r") as f: - data = json.load(f) - return data - except (FileNotFoundError, json.JSONDecodeError): - return None - - -def get_cached_code(fullname): - """Get the cached code for a module if it exists.""" - if not caching_enabled: - return None - - data = get_cached_data(fullname) - if data: - return data.get("code") - return None - - -def get_cached_prompt(fullname): - """Get the cached initial prompt for a module if it exists.""" - if not caching_enabled: - return None - - data = get_cached_data(fullname) - if data: - return data.get("prompt") - return None - - -def cache_module(fullname, code, prompt): - """Cache the code and prompt for a module.""" - if not caching_enabled: - return - - cache_path = get_cache_path(fullname) - data = {"code": code, "prompt": prompt, "module_name": fullname} - with open(cache_path, "w") as f: - json.dump(data, f, indent=2) - - -def get_all_modules(): - """Get all cached modules.""" - if not caching_enabled: - return {} - - cache_dir = get_cache_dir() - modules = {} - - try: - for filename in os.listdir(cache_dir): - if filename.endswith(".json"): - filepath = os.path.join(cache_dir, filename) - try: - with open(filepath, "r") as f: - data = json.load(f) - # Extract module name from the data or use the filename - module_name = data.get( - "module_name", os.path.splitext(filename)[0] - ) - modules[module_name] = data - except (json.JSONDecodeError, IOError): - continue - except FileNotFoundError: - pass - - return modules +"""Cache management for autogenlib generated code.""" + +import os +import hashlib +import json +from ._state import caching_enabled + + +def get_cache_dir(): + """Get the directory where cached files are stored.""" + cache_dir = os.path.join(os.path.expanduser("~"), ".autogenlib_cache") + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + +def get_cache_path(fullname): + """Get the path where the cached data for a module should be stored.""" + cache_dir = get_cache_dir() + + # Create a filename based on the module name + # Use only the 
first two parts of the fullname (e.g., autogenlib.totp) + # to ensure we're caching at the module level + module_name = ".".join(fullname.split(".")[:2]) + filename = hashlib.md5(module_name.encode()).hexdigest() + ".json" + return os.path.join(cache_dir, filename) + + +def get_cached_data(fullname): + """Get the cached data for a module if it exists.""" + if not caching_enabled: + return None + + cache_path = get_cache_path(fullname) + try: + with open(cache_path, "r") as f: + data = json.load(f) + return data + except (FileNotFoundError, json.JSONDecodeError): + return None + + +def get_cached_code(fullname): + """Get the cached code for a module if it exists.""" + if not caching_enabled: + return None + + data = get_cached_data(fullname) + if data: + return data.get("code") + return None + + +def get_cached_prompt(fullname): + """Get the cached initial prompt for a module if it exists.""" + if not caching_enabled: + return None + + data = get_cached_data(fullname) + if data: + return data.get("prompt") + return None + + +def cache_module(fullname, code, prompt): + """Cache the code and prompt for a module.""" + if not caching_enabled: + return + + cache_path = get_cache_path(fullname) + data = {"code": code, "prompt": prompt, "module_name": fullname} + with open(cache_path, "w") as f: + json.dump(data, f, indent=2) + + +def get_all_modules(): + """Get all cached modules.""" + if not caching_enabled: + return {} + + cache_dir = get_cache_dir() + modules = {} + + try: + for filename in os.listdir(cache_dir): + if filename.endswith(".json"): + filepath = os.path.join(cache_dir, filename) + try: + with open(filepath, "r") as f: + data = json.load(f) + # Extract module name from the data or use the filename + module_name = data.get( + "module_name", os.path.splitext(filename)[0] + ) + modules[module_name] = data + except (json.JSONDecodeError, IOError): + continue + except FileNotFoundError: + pass + + return modules diff --git a/Libraries/autogenlib/autogenlib/_caller.py b/Libraries/autogenlib/_caller.py similarity index 97% rename from Libraries/autogenlib/autogenlib/_caller.py rename to Libraries/autogenlib/_caller.py index 0bf8ac85..8ef0b712 100644 --- a/Libraries/autogenlib/autogenlib/_caller.py +++ b/Libraries/autogenlib/_caller.py @@ -1,127 +1,127 @@ -"""Caller context extraction for autogenlib.""" - -import inspect -import os -import sys -from pathlib import Path -import traceback -from logging import getLogger - -logger = getLogger(__name__) - - -def get_caller_info(max_depth=10): - """ - Get information about the calling code. - - Args: - max_depth: Maximum number of frames to check in the stack. - - Returns: - dict: Information about the caller including filename and code. 
- """ - try: - # Get the current stack frames - stack = inspect.stack() - - # Debug stack information - logger.debug(f"Stack depth: {len(stack)}") - for i, frame_info in enumerate(stack[:max_depth]): - frame = frame_info.frame - filename = frame_info.filename - lineno = frame_info.lineno - function = frame_info.function - logger.debug(f"Frame {i}: {filename}:{lineno} in {function}") - - # Find the first frame that's not from autogenlib and is a real file - caller_frame = None - caller_filename = None - - for i, frame_info in enumerate( - stack[1:max_depth] - ): # Skip the first frame (our function) - filename = frame_info.filename - - # Skip if it's internal to Python - if filename.startswith("<") or not os.path.exists(filename): - continue - - # Skip if it's within our package - if "autogenlib" in filename and "_caller.py" not in filename: - continue - - # We found a suitable caller - caller_frame = frame_info.frame - caller_filename = filename - logger.debug(f"Found caller at frame {i + 1}: {filename}") - break - - if not caller_filename: - # Try a different approach - look for an importing file - for i, frame_info in enumerate(stack[1:max_depth]): - filename = frame_info.filename - - # Skip non-file frames - if filename.startswith("<") or not os.path.exists(filename): - continue - - # Check if this frame is doing an import - if ( - frame_info.function == "" - or "import" in frame_info.code_context[0].lower() - ): - caller_frame = frame_info.frame - caller_filename = filename - logger.debug(f"Found importing caller at frame {i + 1}: {filename}") - break - - # If we still didn't find a caller, use a simpler approach - if not caller_filename: - # Just use the top-level script - for frame_info in reversed(stack[:max_depth]): - filename = frame_info.filename - if os.path.exists(filename) and not filename.startswith("<"): - caller_filename = filename - logger.debug(f"Using top-level script as caller: {filename}") - break - - if not caller_filename: - logger.debug("No suitable caller file found") - return {"code": "", "filename": ""} - - # Read the file content - try: - with open(caller_filename, "r") as f: - code = f.read() - - # Get the relative path to make logs cleaner - try: - rel_path = Path(caller_filename).relative_to(Path.cwd()) - display_filename = str(rel_path) - except ValueError: - display_filename = caller_filename - - # Limit code size if it's too large to avoid excessive prompt size - MAX_CODE_SIZE = 8000 # Characters - if len(code) > MAX_CODE_SIZE: - logger.debug( - f"Truncating large caller file ({len(code)} chars) to {MAX_CODE_SIZE} chars" - ) - # Try to find a good place to cut (newline) - cut_point = code[:MAX_CODE_SIZE].rfind("\n") - if cut_point == -1: - cut_point = MAX_CODE_SIZE - code = code[:cut_point] + "\n\n# ... [file truncated due to size] ..." 
- - logger.debug( - f"Successfully extracted caller code from {display_filename} ({len(code)} chars)" - ) - - return {"code": code, "filename": display_filename} - except Exception as e: - logger.debug(f"Error reading caller file {caller_filename}: {e}") - return {"code": "", "filename": caller_filename} - except Exception as e: - logger.debug(f"Error getting caller info: {e}") - logger.debug(traceback.format_exc()) - return {"code": "", "filename": ""} +"""Caller context extraction for autogenlib.""" + +import inspect +import os +import sys +from pathlib import Path +import traceback +from logging import getLogger + +logger = getLogger(__name__) + + +def get_caller_info(max_depth=10): + """ + Get information about the calling code. + + Args: + max_depth: Maximum number of frames to check in the stack. + + Returns: + dict: Information about the caller including filename and code. + """ + try: + # Get the current stack frames + stack = inspect.stack() + + # Debug stack information + logger.debug(f"Stack depth: {len(stack)}") + for i, frame_info in enumerate(stack[:max_depth]): + frame = frame_info.frame + filename = frame_info.filename + lineno = frame_info.lineno + function = frame_info.function + logger.debug(f"Frame {i}: {filename}:{lineno} in {function}") + + # Find the first frame that's not from autogenlib and is a real file + caller_frame = None + caller_filename = None + + for i, frame_info in enumerate( + stack[1:max_depth] + ): # Skip the first frame (our function) + filename = frame_info.filename + + # Skip if it's internal to Python + if filename.startswith("<") or not os.path.exists(filename): + continue + + # Skip if it's within our package + if "autogenlib" in filename and "_caller.py" not in filename: + continue + + # We found a suitable caller + caller_frame = frame_info.frame + caller_filename = filename + logger.debug(f"Found caller at frame {i + 1}: {filename}") + break + + if not caller_filename: + # Try a different approach - look for an importing file + for i, frame_info in enumerate(stack[1:max_depth]): + filename = frame_info.filename + + # Skip non-file frames + if filename.startswith("<") or not os.path.exists(filename): + continue + + # Check if this frame is doing an import + if ( + frame_info.function == "" + or "import" in frame_info.code_context[0].lower() + ): + caller_frame = frame_info.frame + caller_filename = filename + logger.debug(f"Found importing caller at frame {i + 1}: {filename}") + break + + # If we still didn't find a caller, use a simpler approach + if not caller_filename: + # Just use the top-level script + for frame_info in reversed(stack[:max_depth]): + filename = frame_info.filename + if os.path.exists(filename) and not filename.startswith("<"): + caller_filename = filename + logger.debug(f"Using top-level script as caller: {filename}") + break + + if not caller_filename: + logger.debug("No suitable caller file found") + return {"code": "", "filename": ""} + + # Read the file content + try: + with open(caller_filename, "r") as f: + code = f.read() + + # Get the relative path to make logs cleaner + try: + rel_path = Path(caller_filename).relative_to(Path.cwd()) + display_filename = str(rel_path) + except ValueError: + display_filename = caller_filename + + # Limit code size if it's too large to avoid excessive prompt size + MAX_CODE_SIZE = 8000 # Characters + if len(code) > MAX_CODE_SIZE: + logger.debug( + f"Truncating large caller file ({len(code)} chars) to {MAX_CODE_SIZE} chars" + ) + # Try to find a good place to cut (newline) + 
cut_point = code[:MAX_CODE_SIZE].rfind("\n") + if cut_point == -1: + cut_point = MAX_CODE_SIZE + code = code[:cut_point] + "\n\n# ... [file truncated due to size] ..." + + logger.debug( + f"Successfully extracted caller code from {display_filename} ({len(code)} chars)" + ) + + return {"code": code, "filename": display_filename} + except Exception as e: + logger.debug(f"Error reading caller file {caller_filename}: {e}") + return {"code": "", "filename": caller_filename} + except Exception as e: + logger.debug(f"Error getting caller info: {e}") + logger.debug(traceback.format_exc()) + return {"code": "", "filename": ""} diff --git a/Libraries/autogenlib/autogenlib/_context.py b/Libraries/autogenlib/_context.py similarity index 96% rename from Libraries/autogenlib/autogenlib/_context.py rename to Libraries/autogenlib/_context.py index 16700cc3..f484fe12 100644 --- a/Libraries/autogenlib/autogenlib/_context.py +++ b/Libraries/autogenlib/_context.py @@ -1,55 +1,55 @@ -"""Context management for autogenlib modules.""" - -import ast - -# Store the context of each module -module_contexts = {} - - -def get_module_context(fullname): - """Get the context of a module.""" - return module_contexts.get(fullname, {}) - - -def set_module_context(fullname, code): - """Update the context of a module.""" - module_contexts[fullname] = { - "code": code, - "defined_names": extract_defined_names(code), - } - - -def extract_defined_names(code): - """Extract all defined names (functions, classes, variables) from the code.""" - try: - tree = ast.parse(code) - names = set() - - for node in ast.walk(tree): - if isinstance(node, ast.FunctionDef): - names.add(node.name) - elif isinstance(node, ast.ClassDef): - names.add(node.name) - elif isinstance(node, ast.Assign): - for target in node.targets: - if isinstance(target, ast.Name): - names.add(target.id) - - return names - except SyntaxError: - return set() - - -def is_name_defined(fullname): - """Check if a name is defined in its module.""" - if "." not in fullname: - return False - - module_path, name = fullname.rsplit(".", 1) - context = get_module_context(module_path) - - if not context: - # Module doesn't exist yet - return False - - return name in context.get("defined_names", set()) +"""Context management for autogenlib modules.""" + +import ast + +# Store the context of each module +module_contexts = {} + + +def get_module_context(fullname): + """Get the context of a module.""" + return module_contexts.get(fullname, {}) + + +def set_module_context(fullname, code): + """Update the context of a module.""" + module_contexts[fullname] = { + "code": code, + "defined_names": extract_defined_names(code), + } + + +def extract_defined_names(code): + """Extract all defined names (functions, classes, variables) from the code.""" + try: + tree = ast.parse(code) + names = set() + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + names.add(node.name) + elif isinstance(node, ast.ClassDef): + names.add(node.name) + elif isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name): + names.add(target.id) + + return names + except SyntaxError: + return set() + + +def is_name_defined(fullname): + """Check if a name is defined in its module.""" + if "." 
not in fullname: + return False + + module_path, name = fullname.rsplit(".", 1) + context = get_module_context(module_path) + + if not context: + # Module doesn't exist yet + return False + + return name in context.get("defined_names", set()) diff --git a/Libraries/autogenlib/autogenlib/_exception_handler.py b/Libraries/autogenlib/_exception_handler.py similarity index 97% rename from Libraries/autogenlib/autogenlib/_exception_handler.py rename to Libraries/autogenlib/_exception_handler.py index 382a858e..8d1e24fc 100644 --- a/Libraries/autogenlib/autogenlib/_exception_handler.py +++ b/Libraries/autogenlib/_exception_handler.py @@ -1,493 +1,493 @@ -"""Exception handling and LLM fix suggestions for autogenlib.""" - -import sys -import traceback -import os -from logging import getLogger -import openai -import time -import textwrap -import re - -from ._cache import get_cached_code, cache_module -from ._context import set_module_context -from ._state import description, exception_handler_enabled - -logger = getLogger(__name__) - - -def setup_exception_handler(): - """Set up the global exception handler.""" - # Store the original excepthook - original_excepthook = sys.excepthook - - # Define our custom exception hook - def custom_excepthook(exc_type, exc_value, exc_traceback): - if exception_handler_enabled: - handle_exception(exc_type, exc_value, exc_traceback) - # Call the original excepthook regardless - original_excepthook(exc_type, exc_value, exc_traceback) - - # Set our custom excepthook as the global handler - sys.excepthook = custom_excepthook - - -def handle_exception(exc_type, exc_value, exc_traceback): - """Handle an exception by sending it to the LLM for fix suggestions.""" - # Extract the traceback information - tb_frames = traceback.extract_tb(exc_traceback) - tb_str = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) - - # Determine the source of the exception - is_autogenlib_exception = False - module_name = None - source_code = None - source_file = None - - # Try to find the frame where the exception originated - for frame in tb_frames: - filename = frame.filename - lineno = frame.lineno - - # Check if this file is from an autogenlib module - if "" not in filename and filename != "": - # This is a real file - if filename.endswith(".py"): - source_file = filename - module_name_from_frame = None - - # Try to get the module name from the frame - frame_module = None - if hasattr(frame, "frame") and hasattr(frame.frame, "f_globals"): - module_name_from_frame = frame.frame.f_globals.get("__name__") - elif len(frame) > 3 and hasattr(frame[0], "f_globals"): - module_name_from_frame = frame[0].f_globals.get("__name__") - - if ( - module_name_from_frame - and module_name_from_frame.startswith("autogenlib.") - and module_name_from_frame != "autogenlib" - ): - # This is an autogenlib module - is_autogenlib_exception = True - module_name = module_name_from_frame - - # Get code from cache if it's an autogenlib module - if module_name.count(".") > 1: - module_name = ".".join(module_name.split(".")[:2]) - source_code = get_cached_code(module_name) - break - - # For non-autogenlib modules, try to read the source file - try: - with open(filename, "r") as f: - source_code = f.read() - module_name = module_name_from_frame or os.path.basename( - filename - ).replace(".py", "") - break - except: - pass - - # If we couldn't determine the source from the traceback, use the last frame - if not source_code and tb_frames: - last_frame = tb_frames[-1] - if hasattr(last_frame, 
"filename") and last_frame.filename: - filename = last_frame.filename - if ( - "" not in filename - and filename != "" - and filename.endswith(".py") - ): - try: - with open(filename, "r") as f: - source_code = f.read() - module_name = os.path.basename(filename).replace(".py", "") - except: - pass - - # If we still don't have source code but have a module name from an autogenlib module - if not source_code and module_name and module_name.startswith("autogenlib."): - source_code = get_cached_code(module_name) - is_autogenlib_exception = True - - # Check all loaded modules if we still don't have source code - if not source_code: - for loaded_module_name, loaded_module in list(sys.modules.items()): - if ( - loaded_module_name.startswith("autogenlib.") - and loaded_module_name != "autogenlib" - ): - try: - # Try to see if this module might be related to the exception - if ( - exc_type.__module__ == loaded_module_name - or loaded_module_name in tb_str - ): - module_name = loaded_module_name - if module_name.count(".") > 1: - module_name = ".".join(module_name.split(".")[:2]) - source_code = get_cached_code(module_name) - is_autogenlib_exception = True - break - except: - continue - - # If we still don't have any source code, try to extract it from any file mentioned in the traceback - if not source_code: - for line in tb_str.split("\n"): - if 'File "' in line and '.py"' in line: - try: - file_path = line.split('File "')[1].split('"')[0] - if os.path.exists(file_path) and file_path.endswith(".py"): - with open(file_path, "r") as f: - source_code = f.read() - module_name = os.path.basename(file_path).replace(".py", "") - source_file = file_path - break - except: - continue - - # If we still don't have source code, we'll just use the traceback - if not source_code: - source_code = "# Source code could not be determined" - module_name = "unknown" - - # Generate fix using LLM - fix_info = generate_fix( - module_name, - source_code, - exc_type, - exc_value, - tb_str, - is_autogenlib_exception, - source_file, - ) - - if fix_info and is_autogenlib_exception: - # For autogenlib modules, we can try to reload them automatically - fixed_code = fix_info.get("fixed_code") - if fixed_code: - # Cache the fixed code - cache_module(module_name, fixed_code, description) - - # Update the module context - set_module_context(module_name, fixed_code) - - # Reload the module with the fixed code - try: - if module_name in sys.modules: - # Execute the new code in the module's namespace - exec(fixed_code, sys.modules[module_name].__dict__) - logger.info(f"Module {module_name} has been fixed and reloaded") - - # Output a helpful message to the user - print("\n" + "=" * 80) - print(f"AutoGenLib fixed an error in module {module_name}") - print("The module has been reloaded with the fix.") - print("Please retry your operation.") - print("=" * 80 + "\n") - except Exception as e: - logger.error(f"Error reloading fixed module: {e}") - print("\n" + "=" * 80) - print(f"AutoGenLib attempted to fix an error in module {module_name}") - print(f"But encountered an error while reloading: {e}") - print("Please restart your application to apply the fix.") - print("=" * 80 + "\n") - elif fix_info: - # For external code, just display the fix suggestions - print("\n" + "=" * 80) - print(f"AutoGenLib detected an error in {module_name}") - if source_file: - print(f"File: {source_file}") - print(f"Error: {exc_type.__name__}: {exc_value}") - - # Display the fix suggestions - print("\nFix Suggestions:") - print("-" * 40) - if "explanation" in 
fix_info: - explanation = textwrap.fill(fix_info["explanation"], width=78) - print(explanation) - print("-" * 40) - - if "fixed_code" in fix_info: - print("Suggested fixed code:") - print("-" * 40) - if source_file: - print(f"# Apply this fix to {source_file}") - - # If we have specific changes, display them in a more readable format - if "changes" in fix_info: - for change in fix_info["changes"]: - print( - f"Line {change.get('line', '?')}: {change.get('description', '')}" - ) - if "original" in change and "new" in change: - print(f"Original: {change['original']}") - print(f"New: {change['new']}") - print() - else: - # Otherwise just print a snippet of the fixed code (first 20 lines) - fixed_code_lines = fix_info["fixed_code"].split("\n") - if len(fixed_code_lines) > 20: - print("\n".join(fixed_code_lines[:20])) - print("... (truncated for readability)") - else: - print(fix_info["fixed_code"]) - - print("=" * 80 + "\n") - - -def extract_python_code(response): - """ - Extract Python code from LLM response more robustly. - - Handles various ways code might be formatted in the response: - - Code blocks with ```python or ``` markers - - Multiple code blocks - - Indented code blocks - - Code without any markers - - Returns the cleaned Python code. - """ - # Check if response is already clean code (no markdown) - try: - compile(response, "", "exec") - return response - except SyntaxError: - pass - - # Try to extract code from markdown code blocks - code_block_pattern = r"```(?:python)?(.*?)```" - matches = re.findall(code_block_pattern, response, re.DOTALL) - - if matches: - # Join all code blocks and check if valid - extracted_code = "\n\n".join(match.strip() for match in matches) - try: - compile(extracted_code, "", "exec") - return extracted_code - except SyntaxError: - pass - - # If we get here, no valid code blocks were found - # Try to identify the largest Python-like chunk in the text - lines = response.split("\n") - code_lines = [] - current_code_chunk = [] - - for line in lines: - # Skip obvious non-code lines - if re.match( - r"^(#|Here's|I've|This|Note:|Remember:|Explanation:)", line.strip() - ): - # If we were collecting code, save the chunk - if current_code_chunk: - code_lines.extend(current_code_chunk) - current_code_chunk = [] - continue - - # Lines that likely indicate code - if re.match( - r"^(import|from|def|class|if|for|while|return|try|with|@|\s{4}| )", line - ): - current_code_chunk.append(line) - elif line.strip() == "" and current_code_chunk: - # Empty lines within code blocks are kept - current_code_chunk.append(line) - elif current_code_chunk: - # If we have a non-empty line that doesn't look like code but follows code - # we keep it in the current chunk (might be a variable assignment, etc.) 
- current_code_chunk.append(line) - - # Add any remaining code chunk - if current_code_chunk: - code_lines.extend(current_code_chunk) - - # Join all identified code lines - extracted_code = "\n".join(code_lines) - - # If we couldn't extract anything or it's invalid, return the original - # but the validator will likely reject it - if not extracted_code: - return response - - try: - compile(extracted_code, "", "exec") - return extracted_code - except SyntaxError: - # Last resort: try to use the whole response if it might be valid code - if "def " in response or "class " in response or "import " in response: - try: - compile(response, "", "exec") - return response - except SyntaxError: - pass - - # Log the issue - logger.warning("Could not extract valid Python code from response") - return response - - -def generate_fix( - module_name, - current_code, - exc_type, - exc_value, - traceback_str, - is_autogenlib=False, - source_file=None, -): - """Generate a fix for the exception using the LLM. - - Args: - module_name: Name of the module where the exception occurred - current_code: Current source code of the module - exc_type: Exception type - exc_value: Exception value - traceback_str: Formatted traceback string - is_autogenlib: Whether this is an autogenlib-generated module - source_file: Path to the source file (for non-autogenlib modules) - - Returns: - Dictionary containing fix information: - - fixed_code: The fixed code (if available) - - explanation: Explanation of the issue and fix - - changes: List of specific changes made (if available) - """ - try: - # Set API key from environment variable - api_key = os.environ.get("OPENAI_API_KEY") - if not api_key: - logger.error("Please set the OPENAI_API_KEY environment variable.") - return None - - base_url = os.environ.get("OPENAI_API_BASE_URL") - model = os.environ.get("OPENAI_MODEL", "gpt-4.1") - - # Initialize the OpenAI client - client = openai.OpenAI(api_key=api_key, base_url=base_url) - - # Create a system prompt for the LLM - system_prompt = """ - You are an expert Python developer and debugger specialized in fixing code errors. - - You meticulously analyze errors by: - 1. Tracing the execution flow to the exact point of failure - 2. Understanding the root cause, not just the symptoms - 3. Identifying edge cases that may have triggered the exception - 4. Looking for similar issues elsewhere in the code - - When creating fixes, you: - 1. Make the minimal changes necessary to resolve the issue - 2. Maintain consistency with the existing code style - 3. Add appropriate defensive programming - 4. Ensure type consistency and proper error handling - 5. Add brief comments explaining non-obvious fixes - - Your responses must be precise, direct, and immediately applicable. - """ - - # Create a user prompt for the LLM - user_prompt = f""" - DEBUGGING TASK: Fix a Python error in {module_name} - - MODULE DETAILS: - {"AUTO-GENERATED MODULE" if is_autogenlib else "USER CODE"} - {f"Source file: {source_file}" if source_file else ""} - - CURRENT CODE: - ```python - {current_code} - ``` - - ERROR DETAILS: - Type: {exc_type.__name__} - Message: {exc_value} - - TRACEBACK: - {traceback_str} - - {"REQUIRED RESPONSE FORMAT: Return ONLY complete fixed Python code. No explanations, comments, or markdown." if is_autogenlib else 'REQUIRED RESPONSE FORMAT: JSON with "explanation", "changes" (line-by-line fixes), and "fixed_code" fields.'} - - {"Remember: The module will be executed directly so your response must be valid Python code only." 
if is_autogenlib else "Remember: Be specific about what changes and why. Include line numbers for easy reference."} - """ - - # Call the OpenAI API - max_retries = 3 - for attempt in range(max_retries): - try: - response = client.chat.completions.create( - model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - max_tokens=5000, - temperature=0.3, # Lower temperature for more deterministic results - response_format={"type": "json_object"} - if not is_autogenlib - else None, - ) - - # Get the generated response - content = response.choices[0].message.content.strip() - - if is_autogenlib: - # For autogenlib modules, we expect just the fixed code - fixed_code = extract_python_code(content) - - # Validate the fixed code - try: - compile(fixed_code, "", "exec") - return {"fixed_code": fixed_code} - except Exception as e: - logger.warning(f"Generated fix contains syntax errors: {e}") - if attempt == max_retries - 1: - return None - time.sleep(1) # Wait before retry - else: - # For regular code, we expect a JSON response - try: - import json - - fix_info = json.loads(content) - - # Validate that we have at least some of the expected fields - if not any( - field in fix_info - for field in ["explanation", "changes", "fixed_code"] - ): - raise ValueError("Missing required fields in response") - - # If we have fixed code, validate it - if "fixed_code" in fix_info: - try: - compile(fix_info["fixed_code"], "", "exec") - except Exception as e: - logger.warning( - f"Generated fix contains syntax errors: {e}" - ) - # We'll still return it for user information, even if it has syntax errors - - return fix_info - except Exception as e: - logger.warning(f"Error parsing LLM response as JSON: {e}") - if attempt == max_retries - 1: - # If all attempts failed to parse as JSON, return a simplified response - return { - "explanation": "Error analyzing the code. 
Here's the raw LLM output:", - "fixed_code": content, - } - time.sleep(1) # Wait before retry - - except Exception as e: - logger.error(f"Error generating fix: {e}") - if attempt == max_retries - 1: - return None - time.sleep(1) # Wait before retry - - return None - except Exception as e: - logger.error(f"Error in generate_fix: {e}") - return None +"""Exception handling and LLM fix suggestions for autogenlib.""" + +import sys +import traceback +import os +from logging import getLogger +import openai +import time +import textwrap +import re + +from ._cache import get_cached_code, cache_module +from ._context import set_module_context +from ._state import description, exception_handler_enabled + +logger = getLogger(__name__) + + +def setup_exception_handler(): + """Set up the global exception handler.""" + # Store the original excepthook + original_excepthook = sys.excepthook + + # Define our custom exception hook + def custom_excepthook(exc_type, exc_value, exc_traceback): + if exception_handler_enabled: + handle_exception(exc_type, exc_value, exc_traceback) + # Call the original excepthook regardless + original_excepthook(exc_type, exc_value, exc_traceback) + + # Set our custom excepthook as the global handler + sys.excepthook = custom_excepthook + + +def handle_exception(exc_type, exc_value, exc_traceback): + """Handle an exception by sending it to the LLM for fix suggestions.""" + # Extract the traceback information + tb_frames = traceback.extract_tb(exc_traceback) + tb_str = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) + + # Determine the source of the exception + is_autogenlib_exception = False + module_name = None + source_code = None + source_file = None + + # Try to find the frame where the exception originated + for frame in tb_frames: + filename = frame.filename + lineno = frame.lineno + + # Check if this file is from an autogenlib module + if "" not in filename and filename != "": + # This is a real file + if filename.endswith(".py"): + source_file = filename + module_name_from_frame = None + + # Try to get the module name from the frame + frame_module = None + if hasattr(frame, "frame") and hasattr(frame.frame, "f_globals"): + module_name_from_frame = frame.frame.f_globals.get("__name__") + elif len(frame) > 3 and hasattr(frame[0], "f_globals"): + module_name_from_frame = frame[0].f_globals.get("__name__") + + if ( + module_name_from_frame + and module_name_from_frame.startswith("autogenlib.") + and module_name_from_frame != "autogenlib" + ): + # This is an autogenlib module + is_autogenlib_exception = True + module_name = module_name_from_frame + + # Get code from cache if it's an autogenlib module + if module_name.count(".") > 1: + module_name = ".".join(module_name.split(".")[:2]) + source_code = get_cached_code(module_name) + break + + # For non-autogenlib modules, try to read the source file + try: + with open(filename, "r") as f: + source_code = f.read() + module_name = module_name_from_frame or os.path.basename( + filename + ).replace(".py", "") + break + except: + pass + + # If we couldn't determine the source from the traceback, use the last frame + if not source_code and tb_frames: + last_frame = tb_frames[-1] + if hasattr(last_frame, "filename") and last_frame.filename: + filename = last_frame.filename + if ( + "" not in filename + and filename != "" + and filename.endswith(".py") + ): + try: + with open(filename, "r") as f: + source_code = f.read() + module_name = os.path.basename(filename).replace(".py", "") + except: + pass + + # If we 
still don't have source code but have a module name from an autogenlib module + if not source_code and module_name and module_name.startswith("autogenlib."): + source_code = get_cached_code(module_name) + is_autogenlib_exception = True + + # Check all loaded modules if we still don't have source code + if not source_code: + for loaded_module_name, loaded_module in list(sys.modules.items()): + if ( + loaded_module_name.startswith("autogenlib.") + and loaded_module_name != "autogenlib" + ): + try: + # Try to see if this module might be related to the exception + if ( + exc_type.__module__ == loaded_module_name + or loaded_module_name in tb_str + ): + module_name = loaded_module_name + if module_name.count(".") > 1: + module_name = ".".join(module_name.split(".")[:2]) + source_code = get_cached_code(module_name) + is_autogenlib_exception = True + break + except: + continue + + # If we still don't have any source code, try to extract it from any file mentioned in the traceback + if not source_code: + for line in tb_str.split("\n"): + if 'File "' in line and '.py"' in line: + try: + file_path = line.split('File "')[1].split('"')[0] + if os.path.exists(file_path) and file_path.endswith(".py"): + with open(file_path, "r") as f: + source_code = f.read() + module_name = os.path.basename(file_path).replace(".py", "") + source_file = file_path + break + except: + continue + + # If we still don't have source code, we'll just use the traceback + if not source_code: + source_code = "# Source code could not be determined" + module_name = "unknown" + + # Generate fix using LLM + fix_info = generate_fix( + module_name, + source_code, + exc_type, + exc_value, + tb_str, + is_autogenlib_exception, + source_file, + ) + + if fix_info and is_autogenlib_exception: + # For autogenlib modules, we can try to reload them automatically + fixed_code = fix_info.get("fixed_code") + if fixed_code: + # Cache the fixed code + cache_module(module_name, fixed_code, description) + + # Update the module context + set_module_context(module_name, fixed_code) + + # Reload the module with the fixed code + try: + if module_name in sys.modules: + # Execute the new code in the module's namespace + exec(fixed_code, sys.modules[module_name].__dict__) + logger.info(f"Module {module_name} has been fixed and reloaded") + + # Output a helpful message to the user + print("\n" + "=" * 80) + print(f"AutoGenLib fixed an error in module {module_name}") + print("The module has been reloaded with the fix.") + print("Please retry your operation.") + print("=" * 80 + "\n") + except Exception as e: + logger.error(f"Error reloading fixed module: {e}") + print("\n" + "=" * 80) + print(f"AutoGenLib attempted to fix an error in module {module_name}") + print(f"But encountered an error while reloading: {e}") + print("Please restart your application to apply the fix.") + print("=" * 80 + "\n") + elif fix_info: + # For external code, just display the fix suggestions + print("\n" + "=" * 80) + print(f"AutoGenLib detected an error in {module_name}") + if source_file: + print(f"File: {source_file}") + print(f"Error: {exc_type.__name__}: {exc_value}") + + # Display the fix suggestions + print("\nFix Suggestions:") + print("-" * 40) + if "explanation" in fix_info: + explanation = textwrap.fill(fix_info["explanation"], width=78) + print(explanation) + print("-" * 40) + + if "fixed_code" in fix_info: + print("Suggested fixed code:") + print("-" * 40) + if source_file: + print(f"# Apply this fix to {source_file}") + + # If we have specific changes, display them in 
a more readable format + if "changes" in fix_info: + for change in fix_info["changes"]: + print( + f"Line {change.get('line', '?')}: {change.get('description', '')}" + ) + if "original" in change and "new" in change: + print(f"Original: {change['original']}") + print(f"New: {change['new']}") + print() + else: + # Otherwise just print a snippet of the fixed code (first 20 lines) + fixed_code_lines = fix_info["fixed_code"].split("\n") + if len(fixed_code_lines) > 20: + print("\n".join(fixed_code_lines[:20])) + print("... (truncated for readability)") + else: + print(fix_info["fixed_code"]) + + print("=" * 80 + "\n") + + +def extract_python_code(response): + """ + Extract Python code from LLM response more robustly. + + Handles various ways code might be formatted in the response: + - Code blocks with ```python or ``` markers + - Multiple code blocks + - Indented code blocks + - Code without any markers + + Returns the cleaned Python code. + """ + # Check if response is already clean code (no markdown) + try: + compile(response, "", "exec") + return response + except SyntaxError: + pass + + # Try to extract code from markdown code blocks + code_block_pattern = r"```(?:python)?(.*?)```" + matches = re.findall(code_block_pattern, response, re.DOTALL) + + if matches: + # Join all code blocks and check if valid + extracted_code = "\n\n".join(match.strip() for match in matches) + try: + compile(extracted_code, "", "exec") + return extracted_code + except SyntaxError: + pass + + # If we get here, no valid code blocks were found + # Try to identify the largest Python-like chunk in the text + lines = response.split("\n") + code_lines = [] + current_code_chunk = [] + + for line in lines: + # Skip obvious non-code lines + if re.match( + r"^(#|Here's|I've|This|Note:|Remember:|Explanation:)", line.strip() + ): + # If we were collecting code, save the chunk + if current_code_chunk: + code_lines.extend(current_code_chunk) + current_code_chunk = [] + continue + + # Lines that likely indicate code + if re.match( + r"^(import|from|def|class|if|for|while|return|try|with|@|\s{4}| )", line + ): + current_code_chunk.append(line) + elif line.strip() == "" and current_code_chunk: + # Empty lines within code blocks are kept + current_code_chunk.append(line) + elif current_code_chunk: + # If we have a non-empty line that doesn't look like code but follows code + # we keep it in the current chunk (might be a variable assignment, etc.) + current_code_chunk.append(line) + + # Add any remaining code chunk + if current_code_chunk: + code_lines.extend(current_code_chunk) + + # Join all identified code lines + extracted_code = "\n".join(code_lines) + + # If we couldn't extract anything or it's invalid, return the original + # but the validator will likely reject it + if not extracted_code: + return response + + try: + compile(extracted_code, "", "exec") + return extracted_code + except SyntaxError: + # Last resort: try to use the whole response if it might be valid code + if "def " in response or "class " in response or "import " in response: + try: + compile(response, "", "exec") + return response + except SyntaxError: + pass + + # Log the issue + logger.warning("Could not extract valid Python code from response") + return response + + +def generate_fix( + module_name, + current_code, + exc_type, + exc_value, + traceback_str, + is_autogenlib=False, + source_file=None, +): + """Generate a fix for the exception using the LLM. 
+ + Args: + module_name: Name of the module where the exception occurred + current_code: Current source code of the module + exc_type: Exception type + exc_value: Exception value + traceback_str: Formatted traceback string + is_autogenlib: Whether this is an autogenlib-generated module + source_file: Path to the source file (for non-autogenlib modules) + + Returns: + Dictionary containing fix information: + - fixed_code: The fixed code (if available) + - explanation: Explanation of the issue and fix + - changes: List of specific changes made (if available) + """ + try: + # Set API key from environment variable + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + logger.error("Please set the OPENAI_API_KEY environment variable.") + return None + + base_url = os.environ.get("OPENAI_API_BASE_URL") + model = os.environ.get("OPENAI_MODEL", "gpt-4.1") + + # Initialize the OpenAI client + client = openai.OpenAI(api_key=api_key, base_url=base_url) + + # Create a system prompt for the LLM + system_prompt = """ + You are an expert Python developer and debugger specialized in fixing code errors. + + You meticulously analyze errors by: + 1. Tracing the execution flow to the exact point of failure + 2. Understanding the root cause, not just the symptoms + 3. Identifying edge cases that may have triggered the exception + 4. Looking for similar issues elsewhere in the code + + When creating fixes, you: + 1. Make the minimal changes necessary to resolve the issue + 2. Maintain consistency with the existing code style + 3. Add appropriate defensive programming + 4. Ensure type consistency and proper error handling + 5. Add brief comments explaining non-obvious fixes + + Your responses must be precise, direct, and immediately applicable. + """ + + # Create a user prompt for the LLM + user_prompt = f""" + DEBUGGING TASK: Fix a Python error in {module_name} + + MODULE DETAILS: + {"AUTO-GENERATED MODULE" if is_autogenlib else "USER CODE"} + {f"Source file: {source_file}" if source_file else ""} + + CURRENT CODE: + ```python + {current_code} + ``` + + ERROR DETAILS: + Type: {exc_type.__name__} + Message: {exc_value} + + TRACEBACK: + {traceback_str} + + {"REQUIRED RESPONSE FORMAT: Return ONLY complete fixed Python code. No explanations, comments, or markdown." if is_autogenlib else 'REQUIRED RESPONSE FORMAT: JSON with "explanation", "changes" (line-by-line fixes), and "fixed_code" fields.'} + + {"Remember: The module will be executed directly so your response must be valid Python code only." if is_autogenlib else "Remember: Be specific about what changes and why. 
Include line numbers for easy reference."} + """ + + # Call the OpenAI API + max_retries = 3 + for attempt in range(max_retries): + try: + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + max_tokens=5000, + temperature=0.3, # Lower temperature for more deterministic results + response_format={"type": "json_object"} + if not is_autogenlib + else None, + ) + + # Get the generated response + content = response.choices[0].message.content.strip() + + if is_autogenlib: + # For autogenlib modules, we expect just the fixed code + fixed_code = extract_python_code(content) + + # Validate the fixed code + try: + compile(fixed_code, "", "exec") + return {"fixed_code": fixed_code} + except Exception as e: + logger.warning(f"Generated fix contains syntax errors: {e}") + if attempt == max_retries - 1: + return None + time.sleep(1) # Wait before retry + else: + # For regular code, we expect a JSON response + try: + import json + + fix_info = json.loads(content) + + # Validate that we have at least some of the expected fields + if not any( + field in fix_info + for field in ["explanation", "changes", "fixed_code"] + ): + raise ValueError("Missing required fields in response") + + # If we have fixed code, validate it + if "fixed_code" in fix_info: + try: + compile(fix_info["fixed_code"], "", "exec") + except Exception as e: + logger.warning( + f"Generated fix contains syntax errors: {e}" + ) + # We'll still return it for user information, even if it has syntax errors + + return fix_info + except Exception as e: + logger.warning(f"Error parsing LLM response as JSON: {e}") + if attempt == max_retries - 1: + # If all attempts failed to parse as JSON, return a simplified response + return { + "explanation": "Error analyzing the code. 
Here's the raw LLM output:", + "fixed_code": content, + } + time.sleep(1) # Wait before retry + + except Exception as e: + logger.error(f"Error generating fix: {e}") + if attempt == max_retries - 1: + return None + time.sleep(1) # Wait before retry + + return None + except Exception as e: + logger.error(f"Error in generate_fix: {e}") + return None diff --git a/Libraries/autogenlib/autogenlib/_finder.py b/Libraries/autogenlib/_finder.py similarity index 97% rename from Libraries/autogenlib/autogenlib/_finder.py rename to Libraries/autogenlib/_finder.py index 619a7f25..dedd7574 100644 --- a/Libraries/autogenlib/autogenlib/_finder.py +++ b/Libraries/autogenlib/_finder.py @@ -1,239 +1,239 @@ -"""Import hook implementation for autogenlib.""" - -import sys -import importlib.abc -import importlib.machinery -import logging -import os -from ._state import description -from ._generator import generate_code -from ._cache import get_cached_code, cache_module -from ._context import get_module_context, set_module_context -from ._caller import get_caller_info - -logger = logging.getLogger(__name__) - - -class AutoLibFinder(importlib.abc.MetaPathFinder): - def __init__(self): - pass - - def find_spec(self, fullname, path, target=None): - # Only handle imports under the 'autogenlib' namespace, excluding autogenlib itself - if not fullname.startswith("autogenlib.") or fullname == "autogenlib": - return None - - if not description: - return None - - # Get caller code context - try: - caller_info = get_caller_info() - if caller_info.get("code"): - logger.debug(f"Got caller context from {caller_info.get('filename')}") - else: - logger.debug("No caller context available") - except Exception as e: - logger.warning(f"Error getting caller info: {e}") - caller_info = {"code": "", "filename": ""} - - # Parse the fullname into components and determine the module structure - parts = fullname.split(".") - - # Handle package structure (e.g., autogenlib.tokens.secure) - is_package = False - package_path = None - module_to_check = fullname - - if len(parts) > 2: - # This might be a nested package or a module within a package - parent_module_name = ".".join(parts[:-1]) # e.g., 'autogenlib.tokens' - - # Check if the parent module exists as a package - if parent_module_name in sys.modules: - parent_module = sys.modules[parent_module_name] - parent_path = getattr(parent_module, "__path__", None) - - if parent_path: - # Parent is a package - is_package = False - package_path = parent_path - - # We need to check if this is requesting a module that doesn't exist yet - # If the parent exists as a package, we'll create a module within it - module_to_check = fullname - - # Check if an attribute in the parent - attr_name = parts[-1] - if hasattr(parent_module, attr_name): - # The attribute exists, no need to generate code - return None - else: - # Parent module doesn't exist yet - # Start by generating the immediate parent package - parent_package_name = ".".join(parts[:2]) # e.g., 'autogenlib.tokens' - - # First ensure the parent package exists - if parent_package_name not in sys.modules: - # Generate the parent package - parent_code = generate_code( - description, parent_package_name, None, caller_info - ) - if parent_code: - # Cache the generated code with the prompt - cache_module(parent_package_name, parent_code, description) - # Update the module context - set_module_context(parent_package_name, parent_code) - - # Create a spec for the parent package - parent_loader = AutoLibLoader(parent_package_name, parent_code) - 
parent_spec = importlib.machinery.ModuleSpec( - parent_package_name, parent_loader, is_package=True - ) - - # Create and initialize the parent package - parent_module = importlib.util.module_from_spec(parent_spec) - sys.modules[parent_package_name] = parent_module - parent_spec.loader.exec_module(parent_module) - - # Set the __path__ attribute to make it a proper package - # This is crucial for nested imports to work - if not hasattr(parent_module, "__path__"): - parent_module.__path__ = [] - - # Now handle the subpackage or module - if len(parts) == 3: - # This is a direct submodule of the parent (e.g., autogenlib.tokens.secure) - is_package = False - module_to_check = fullname - else: - # This is a nested subpackage (e.g., autogenlib.tokens.secure.module) - # We need to create intermediate packages - current_pkg = ( - parts[0] + "." + parts[1] - ) # Start with autogenlib.tokens - - for i in range(2, len(parts) - 1): - sub_pkg = ( - current_pkg + "." + parts[i] - ) # e.g., autogenlib.tokens.secure - - if sub_pkg not in sys.modules: - # Generate and load this subpackage - sub_code = generate_code( - description, sub_pkg, None, caller_info - ) - if sub_code: - cache_module(sub_pkg, sub_code, description) - set_module_context(sub_pkg, sub_code) - - sub_loader = AutoLibLoader(sub_pkg, sub_code) - sub_spec = importlib.machinery.ModuleSpec( - sub_pkg, sub_loader, is_package=True - ) - - sub_module = importlib.util.module_from_spec(sub_spec) - sys.modules[sub_pkg] = sub_module - sub_spec.loader.exec_module(sub_module) - - if not hasattr(sub_module, "__path__"): - sub_module.__path__ = [] - - current_pkg = sub_pkg - - # Finally, set up for the actual module we want to import - is_package = False - module_to_check = fullname - else: - # Standard case: autogenlib.module - is_package = len(parts) == 2 - module_to_check = fullname - - # Handle attribute import (e.g., autogenlib.tokens.generate_token) - if len(parts) > 2: - module_name = ".".join(parts[:2]) # e.g., 'autogenlib.tokens' - attr_name = parts[-1] # e.g., 'generate_token' - - # Check if the module exists but is missing this attribute - if module_name in sys.modules: - module = sys.modules[module_name] - - # If the attribute doesn't exist, regenerate the module - if not hasattr(module, attr_name): - # Get the current module code - module_context = get_module_context(module_name) - current_code = module_context.get("code", "") - - # Generate updated code including the new function - new_code = generate_code( - description, fullname, current_code, caller_info - ) - if new_code: - # Update the cache and module - cache_module(module_name, new_code, description) - set_module_context(module_name, new_code) - - # Execute the new code in the module's namespace - exec(new_code, module.__dict__) - - # If the attribute exists now, return None to continue normal import - if hasattr(module, attr_name): - return None - - # Check if the module is already cached - code = get_cached_code(module_to_check) - - if code is None: - # Generate code using OpenAI's API with caller context - code = generate_code(description, module_to_check, None, caller_info) - if code is not None: - # Cache the generated code with the prompt - cache_module(module_to_check, code, description) - # Update the module context - set_module_context(module_to_check, code) - - if code is not None: - # Create a spec for the module - loader = AutoLibLoader(module_to_check, code) - spec = importlib.machinery.ModuleSpec( - module_to_check, loader, is_package=is_package - ) - - # Set origin for 
proper package handling - if is_package: - spec.submodule_search_locations = [] - - return spec - - return None - - -class AutoLibLoader(importlib.abc.Loader): - def __init__(self, fullname, code): - self.fullname = fullname - self.code = code - - def create_module(self, spec): - return None # Use the default module creation - - def exec_module(self, module): - # Set up package attributes if this is a package - if getattr(module.__spec__, "submodule_search_locations", None) is not None: - # This is a package - if not hasattr(module, "__path__"): - module.__path__ = [] - - # Create a virtual __init__.py for packages - if "__init__" not in self.code: - init_code = self.code - else: - init_code = self.code - - # Execute the code - exec(init_code, module.__dict__) - else: - # Regular module - exec(self.code, module.__dict__) - - # Update the module context - set_module_context(self.fullname, self.code) +"""Import hook implementation for autogenlib.""" + +import sys +import importlib.abc +import importlib.machinery +import logging +import os +from ._state import description +from ._generator import generate_code +from ._cache import get_cached_code, cache_module +from ._context import get_module_context, set_module_context +from ._caller import get_caller_info + +logger = logging.getLogger(__name__) + + +class AutoLibFinder(importlib.abc.MetaPathFinder): + def __init__(self): + pass + + def find_spec(self, fullname, path, target=None): + # Only handle imports under the 'autogenlib' namespace, excluding autogenlib itself + if not fullname.startswith("autogenlib.") or fullname == "autogenlib": + return None + + if not description: + return None + + # Get caller code context + try: + caller_info = get_caller_info() + if caller_info.get("code"): + logger.debug(f"Got caller context from {caller_info.get('filename')}") + else: + logger.debug("No caller context available") + except Exception as e: + logger.warning(f"Error getting caller info: {e}") + caller_info = {"code": "", "filename": ""} + + # Parse the fullname into components and determine the module structure + parts = fullname.split(".") + + # Handle package structure (e.g., autogenlib.tokens.secure) + is_package = False + package_path = None + module_to_check = fullname + + if len(parts) > 2: + # This might be a nested package or a module within a package + parent_module_name = ".".join(parts[:-1]) # e.g., 'autogenlib.tokens' + + # Check if the parent module exists as a package + if parent_module_name in sys.modules: + parent_module = sys.modules[parent_module_name] + parent_path = getattr(parent_module, "__path__", None) + + if parent_path: + # Parent is a package + is_package = False + package_path = parent_path + + # We need to check if this is requesting a module that doesn't exist yet + # If the parent exists as a package, we'll create a module within it + module_to_check = fullname + + # Check if an attribute in the parent + attr_name = parts[-1] + if hasattr(parent_module, attr_name): + # The attribute exists, no need to generate code + return None + else: + # Parent module doesn't exist yet + # Start by generating the immediate parent package + parent_package_name = ".".join(parts[:2]) # e.g., 'autogenlib.tokens' + + # First ensure the parent package exists + if parent_package_name not in sys.modules: + # Generate the parent package + parent_code = generate_code( + description, parent_package_name, None, caller_info + ) + if parent_code: + # Cache the generated code with the prompt + cache_module(parent_package_name, parent_code, 
description) + # Update the module context + set_module_context(parent_package_name, parent_code) + + # Create a spec for the parent package + parent_loader = AutoLibLoader(parent_package_name, parent_code) + parent_spec = importlib.machinery.ModuleSpec( + parent_package_name, parent_loader, is_package=True + ) + + # Create and initialize the parent package + parent_module = importlib.util.module_from_spec(parent_spec) + sys.modules[parent_package_name] = parent_module + parent_spec.loader.exec_module(parent_module) + + # Set the __path__ attribute to make it a proper package + # This is crucial for nested imports to work + if not hasattr(parent_module, "__path__"): + parent_module.__path__ = [] + + # Now handle the subpackage or module + if len(parts) == 3: + # This is a direct submodule of the parent (e.g., autogenlib.tokens.secure) + is_package = False + module_to_check = fullname + else: + # This is a nested subpackage (e.g., autogenlib.tokens.secure.module) + # We need to create intermediate packages + current_pkg = ( + parts[0] + "." + parts[1] + ) # Start with autogenlib.tokens + + for i in range(2, len(parts) - 1): + sub_pkg = ( + current_pkg + "." + parts[i] + ) # e.g., autogenlib.tokens.secure + + if sub_pkg not in sys.modules: + # Generate and load this subpackage + sub_code = generate_code( + description, sub_pkg, None, caller_info + ) + if sub_code: + cache_module(sub_pkg, sub_code, description) + set_module_context(sub_pkg, sub_code) + + sub_loader = AutoLibLoader(sub_pkg, sub_code) + sub_spec = importlib.machinery.ModuleSpec( + sub_pkg, sub_loader, is_package=True + ) + + sub_module = importlib.util.module_from_spec(sub_spec) + sys.modules[sub_pkg] = sub_module + sub_spec.loader.exec_module(sub_module) + + if not hasattr(sub_module, "__path__"): + sub_module.__path__ = [] + + current_pkg = sub_pkg + + # Finally, set up for the actual module we want to import + is_package = False + module_to_check = fullname + else: + # Standard case: autogenlib.module + is_package = len(parts) == 2 + module_to_check = fullname + + # Handle attribute import (e.g., autogenlib.tokens.generate_token) + if len(parts) > 2: + module_name = ".".join(parts[:2]) # e.g., 'autogenlib.tokens' + attr_name = parts[-1] # e.g., 'generate_token' + + # Check if the module exists but is missing this attribute + if module_name in sys.modules: + module = sys.modules[module_name] + + # If the attribute doesn't exist, regenerate the module + if not hasattr(module, attr_name): + # Get the current module code + module_context = get_module_context(module_name) + current_code = module_context.get("code", "") + + # Generate updated code including the new function + new_code = generate_code( + description, fullname, current_code, caller_info + ) + if new_code: + # Update the cache and module + cache_module(module_name, new_code, description) + set_module_context(module_name, new_code) + + # Execute the new code in the module's namespace + exec(new_code, module.__dict__) + + # If the attribute exists now, return None to continue normal import + if hasattr(module, attr_name): + return None + + # Check if the module is already cached + code = get_cached_code(module_to_check) + + if code is None: + # Generate code using OpenAI's API with caller context + code = generate_code(description, module_to_check, None, caller_info) + if code is not None: + # Cache the generated code with the prompt + cache_module(module_to_check, code, description) + # Update the module context + set_module_context(module_to_check, code) + + if code 
is not None: + # Create a spec for the module + loader = AutoLibLoader(module_to_check, code) + spec = importlib.machinery.ModuleSpec( + module_to_check, loader, is_package=is_package + ) + + # Set origin for proper package handling + if is_package: + spec.submodule_search_locations = [] + + return spec + + return None + + +class AutoLibLoader(importlib.abc.Loader): + def __init__(self, fullname, code): + self.fullname = fullname + self.code = code + + def create_module(self, spec): + return None # Use the default module creation + + def exec_module(self, module): + # Set up package attributes if this is a package + if getattr(module.__spec__, "submodule_search_locations", None) is not None: + # This is a package + if not hasattr(module, "__path__"): + module.__path__ = [] + + # Create a virtual __init__.py for packages + if "__init__" not in self.code: + init_code = self.code + else: + init_code = self.code + + # Execute the code + exec(init_code, module.__dict__) + else: + # Regular module + exec(self.code, module.__dict__) + + # Update the module context + set_module_context(self.fullname, self.code) diff --git a/Libraries/autogenlib/autogenlib/_generator.py b/Libraries/autogenlib/_generator.py similarity index 97% rename from Libraries/autogenlib/autogenlib/_generator.py rename to Libraries/autogenlib/_generator.py index 374e303c..7df101cd 100644 --- a/Libraries/autogenlib/autogenlib/_generator.py +++ b/Libraries/autogenlib/_generator.py @@ -1,356 +1,356 @@ -"""Code generation for autogenlib using OpenAI API.""" - -import openai -import os -import ast -import re -from ._cache import get_all_modules, get_cached_prompt -from logging import getLogger - -logger = getLogger(__name__) - - -def validate_code(code): - """Validate the generated code against PEP standards.""" - try: - # Check if the code is syntactically valid - ast.parse(code) - return True - except SyntaxError: - return False - - -def get_codebase_context(): - """Get the full codebase context for all cached modules.""" - modules = get_all_modules() - - if not modules: - return "" - - context = "Here is the existing codebase for reference:\n\n" - - for module_name, data in modules.items(): - if "code" in data: - context += f"# Module: {module_name}\n```python\n{data['code']}\n```\n\n" - - return context - - -def extract_python_code(response): - """ - Extract Python code from LLM response more robustly. - - Handles various ways code might be formatted in the response: - - Code blocks with ```python or ``` markers - - Multiple code blocks - - Indented code blocks - - Code without any markers - - Returns the cleaned Python code. 
- """ - # Check if response is already clean code (no markdown) - if validate_code(response): - return response - - # Try to extract code from markdown code blocks - code_block_pattern = r"```(?:python)?(.*?)```" - matches = re.findall(code_block_pattern, response, re.DOTALL) - - if matches: - # Join all code blocks and check if valid - extracted_code = "\n\n".join(match.strip() for match in matches) - if validate_code(extracted_code): - return extracted_code - - # If we get here, no valid code blocks were found - # Try to identify the largest Python-like chunk in the text - lines = response.split("\n") - code_lines = [] - current_code_chunk = [] - - for line in lines: - # Skip obvious non-code lines - if re.match( - r"^(#|Here's|I've|This|Note:|Remember:|Explanation:)", line.strip() - ): - # If we were collecting code, save the chunk - if current_code_chunk: - code_lines.extend(current_code_chunk) - current_code_chunk = [] - continue - - # Lines that likely indicate code - if re.match( - r"^(import|from|def|class|if|for|while|return|try|with|@|\s{4}| )", line - ): - current_code_chunk.append(line) - elif line.strip() == "" and current_code_chunk: - # Empty lines within code blocks are kept - current_code_chunk.append(line) - elif current_code_chunk: - # If we have a non-empty line that doesn't look like code but follows code - # we keep it in the current chunk (might be a variable assignment, etc.) - current_code_chunk.append(line) - - # Add any remaining code chunk - if current_code_chunk: - code_lines.extend(current_code_chunk) - - # Join all identified code lines - extracted_code = "\n".join(code_lines) - - # If we couldn't extract anything or it's invalid, return the original - # but the validator will likely reject it - if not extracted_code or not validate_code(extracted_code): - # Last resort: try to use the whole response if it might be valid code - if "def " in response or "class " in response or "import " in response: - if validate_code(response): - return response - - # Log the issue - logger.warning("Could not extract valid Python code from response") - logger.debug("Response: %s", response) - return response - - return extracted_code - - -def generate_code(description, fullname, existing_code=None, caller_info=None): - """Generate code using the OpenAI API.""" - parts = fullname.split(".") - if len(parts) < 2: - return None - - module_name = parts[1] - function_name = parts[2] if len(parts) > 2 else None - - # Get the cached prompt or use the provided description - module_to_check = ".".join(fullname.split(".")[:2]) # e.g., 'autogenlib.totp' - cached_prompt = get_cached_prompt(module_to_check) - current_description = cached_prompt or description - - # Get the full codebase context - codebase_context = get_codebase_context() - - # Add caller code context if available - caller_context = "" - if caller_info and caller_info.get("code"): - code = caller_info.get("code", "") - # Extract the most relevant parts of the code if possible - # Try to focus on the sections that use the requested module/function - relevant_parts = [] - module_parts = fullname.split(".") - - if len(module_parts) >= 2: - # Look for imports of this module - module_prefix = f"from {module_parts[0]}.{module_parts[1]}" - import_lines = [line for line in code.split("\n") if module_prefix in line] - if import_lines: - relevant_parts.extend(import_lines) - - # Look for usages of the imported functions - if len(module_parts) >= 3: - func_name = module_parts[2] - func_usage_lines = [ - line - for line in 
code.split("\n") - if func_name in line and not line.startswith(("import ", "from ")) - ] - if func_usage_lines: - relevant_parts.extend(func_usage_lines) - - # Include relevant parts if found, otherwise use the whole code - if relevant_parts: - caller_context = f""" - Here is the code that is importing and using this module/function: - ```python - # File: {caller_info.get("filename", "unknown")} - # --- Relevant snippets --- - {"\n".join(relevant_parts)} - ``` - - And here is the full context: - ```python - {code} - ``` - - Pay special attention to how the requested functionality will be used in the code snippets above. - """ - else: - caller_context = f""" - Here is the code that is importing this module/function: - ```python - # File: {caller_info.get("filename", "unknown")} - {code} - ``` - - Pay special attention to how the requested functionality will be used in this code. - """ - - logger.debug(f"Including caller context from {caller_info.get('filename')}") - - # Create a prompt for the OpenAI API - system_message = """ - You are an expert Python developer tasked with generating high-quality, production-ready Python modules. - - Follow these guidelines precisely: - - 1. CODE QUALITY: - - Write clean, efficient, and well-documented code with docstrings - - Follow PEP 8 style guidelines strictly - - Include type hints where appropriate (Python 3.12+ compatible) - - Add comprehensive error handling for edge cases - - Create descriptive variable names that clearly convey their purpose - - 2. UNDERSTANDING CONTEXT: - - Carefully analyze existing code to maintain consistency - - Match the naming conventions and patterns in related modules - - Ensure your implementation will work with the exact data structures shown in caller code - - Make reasonable assumptions when information is missing, but document those assumptions - - 3. RESPONSE FORMAT: - - ONLY provide clean Python code with no explanations outside of code comments - - Do NOT include markdown formatting, explanations, or any text outside the code - - Do NOT include ```python or ``` markers around your code - - Your entire response should be valid Python code that can be executed directly - - 4. IMPORTS: - - Use only Python standard library modules unless explicitly told otherwise - - If you need to import from within the library (autogenlib), do so as if those modules exist - - Format imports according to PEP 8 (stdlib, third-party, local) - - The code you generate will be directly executed by the Python interpreter, so it must be syntactically perfect. - """ - - if function_name and existing_code: - prompt = f""" - TASK: Extend an existing Python module named '{module_name}' with a new function/class. - - LIBRARY PURPOSE: - {current_description} - - EXISTING MODULE CODE: - ```python - {existing_code} - ``` - - CODEBASE CONTEXT: - {codebase_context} - - CALLER CONTEXT: - {caller_context} - - REQUIREMENTS: - Add a new {"class" if function_name[0].isupper() else "function"} named '{function_name}' that implements: - {description} - - IMPORTANT INSTRUCTIONS: - 1. Keep all existing functions and classes intact - 2. Follow the existing coding style for consistency - 3. Add comprehensive docstrings and comments where needed - 4. Include proper type hints and error handling - 5. Return ONLY the complete Python code for the entire module - 6. 
Do NOT include any explanations or markdown formatting in your response - """ - elif function_name: - prompt = f""" - TASK: Create a new Python module named '{module_name}' with a specific function/class. - - LIBRARY PURPOSE: - {current_description} - - CODEBASE CONTEXT: - {codebase_context} - - CALLER CONTEXT: - {caller_context} - - REQUIREMENTS: - Create a module that contains a {"class" if function_name[0].isupper() else "function"} named '{function_name}' that implements: - {description} - - IMPORTANT INSTRUCTIONS: - 1. Start with an appropriate module docstring summarizing the purpose - 2. Include comprehensive docstrings for all functions/classes - 3. Add proper type hints and error handling - 4. Return ONLY the complete Python code for the module - 5. Do NOT include any explanations or markdown formatting in your response - """ - else: - prompt = f""" - TASK: Create a new Python package module named '{module_name}'. - - LIBRARY PURPOSE: - {current_description} - - CODEBASE CONTEXT: - {codebase_context} - - CALLER CONTEXT: - {caller_context} - - REQUIREMENTS: - Implement functionality for: - {description} - - IMPORTANT INSTRUCTIONS: - 1. Create a well-structured module with appropriate functions and classes - 2. Start with a comprehensive module docstring - 3. Include proper docstrings, type hints, and error handling - 4. Return ONLY the complete Python code without any explanations - 5. Do NOT include file paths or any markdown formatting in your response - """ - - try: - # Set API key from environment variable - api_key = os.environ.get("OPENAI_API_KEY") - if not api_key: - raise ValueError("Please set the OPENAI_API_KEY environment variable.") - - base_url = os.environ.get("OPENAI_API_BASE_URL") - model = os.environ.get("OPENAI_MODEL", "gpt-4.1") - - # Initialize the OpenAI client - client = openai.OpenAI(api_key=api_key, base_url=base_url) - - logger.debug("Prompt: %s", prompt) - - # Call the OpenAI API - response = client.chat.completions.create( - model=model, - messages=[ - {"role": "system", "content": system_message}, - {"role": "user", "content": prompt}, - ], - temperature=0.1, - ) - - # Get the generated code - raw_response = response.choices[0].message.content.strip() - - logger.debug("Raw response: %s", raw_response) - - # Extract and clean the Python code from the response - code = extract_python_code(raw_response) - - logger.debug("Extracted code: %s", code) - - # Validate the code - if validate_code(code): - return code - else: - logger.error("Generated code is not valid. 
Attempting to fix...") - - # Try to clean up common issues - # Remove any additional text before or after code blocks - clean_code = re.sub(r'^.*?(?=(?:"""|\'\'\'))', "", code, flags=re.DOTALL) - - if validate_code(clean_code): - logger.info("Fixed code validation issues") - return clean_code - - logger.error("Generated code is not valid and could not be fixed") - return None - except Exception as e: - logger.error(f"Error generating code: {e}") - return None +"""Code generation for autogenlib using OpenAI API.""" + +import openai +import os +import ast +import re +from ._cache import get_all_modules, get_cached_prompt +from logging import getLogger + +logger = getLogger(__name__) + + +def validate_code(code): + """Validate the generated code against PEP standards.""" + try: + # Check if the code is syntactically valid + ast.parse(code) + return True + except SyntaxError: + return False + + +def get_codebase_context(): + """Get the full codebase context for all cached modules.""" + modules = get_all_modules() + + if not modules: + return "" + + context = "Here is the existing codebase for reference:\n\n" + + for module_name, data in modules.items(): + if "code" in data: + context += f"# Module: {module_name}\n```python\n{data['code']}\n```\n\n" + + return context + + +def extract_python_code(response): + """ + Extract Python code from LLM response more robustly. + + Handles various ways code might be formatted in the response: + - Code blocks with ```python or ``` markers + - Multiple code blocks + - Indented code blocks + - Code without any markers + + Returns the cleaned Python code. + """ + # Check if response is already clean code (no markdown) + if validate_code(response): + return response + + # Try to extract code from markdown code blocks + code_block_pattern = r"```(?:python)?(.*?)```" + matches = re.findall(code_block_pattern, response, re.DOTALL) + + if matches: + # Join all code blocks and check if valid + extracted_code = "\n\n".join(match.strip() for match in matches) + if validate_code(extracted_code): + return extracted_code + + # If we get here, no valid code blocks were found + # Try to identify the largest Python-like chunk in the text + lines = response.split("\n") + code_lines = [] + current_code_chunk = [] + + for line in lines: + # Skip obvious non-code lines + if re.match( + r"^(#|Here's|I've|This|Note:|Remember:|Explanation:)", line.strip() + ): + # If we were collecting code, save the chunk + if current_code_chunk: + code_lines.extend(current_code_chunk) + current_code_chunk = [] + continue + + # Lines that likely indicate code + if re.match( + r"^(import|from|def|class|if|for|while|return|try|with|@|\s{4}| )", line + ): + current_code_chunk.append(line) + elif line.strip() == "" and current_code_chunk: + # Empty lines within code blocks are kept + current_code_chunk.append(line) + elif current_code_chunk: + # If we have a non-empty line that doesn't look like code but follows code + # we keep it in the current chunk (might be a variable assignment, etc.) 
+ current_code_chunk.append(line) + + # Add any remaining code chunk + if current_code_chunk: + code_lines.extend(current_code_chunk) + + # Join all identified code lines + extracted_code = "\n".join(code_lines) + + # If we couldn't extract anything or it's invalid, return the original + # but the validator will likely reject it + if not extracted_code or not validate_code(extracted_code): + # Last resort: try to use the whole response if it might be valid code + if "def " in response or "class " in response or "import " in response: + if validate_code(response): + return response + + # Log the issue + logger.warning("Could not extract valid Python code from response") + logger.debug("Response: %s", response) + return response + + return extracted_code + + +def generate_code(description, fullname, existing_code=None, caller_info=None): + """Generate code using the OpenAI API.""" + parts = fullname.split(".") + if len(parts) < 2: + return None + + module_name = parts[1] + function_name = parts[2] if len(parts) > 2 else None + + # Get the cached prompt or use the provided description + module_to_check = ".".join(fullname.split(".")[:2]) # e.g., 'autogenlib.totp' + cached_prompt = get_cached_prompt(module_to_check) + current_description = cached_prompt or description + + # Get the full codebase context + codebase_context = get_codebase_context() + + # Add caller code context if available + caller_context = "" + if caller_info and caller_info.get("code"): + code = caller_info.get("code", "") + # Extract the most relevant parts of the code if possible + # Try to focus on the sections that use the requested module/function + relevant_parts = [] + module_parts = fullname.split(".") + + if len(module_parts) >= 2: + # Look for imports of this module + module_prefix = f"from {module_parts[0]}.{module_parts[1]}" + import_lines = [line for line in code.split("\n") if module_prefix in line] + if import_lines: + relevant_parts.extend(import_lines) + + # Look for usages of the imported functions + if len(module_parts) >= 3: + func_name = module_parts[2] + func_usage_lines = [ + line + for line in code.split("\n") + if func_name in line and not line.startswith(("import ", "from ")) + ] + if func_usage_lines: + relevant_parts.extend(func_usage_lines) + + # Include relevant parts if found, otherwise use the whole code + if relevant_parts: + caller_context = f""" + Here is the code that is importing and using this module/function: + ```python + # File: {caller_info.get("filename", "unknown")} + # --- Relevant snippets --- + {"\n".join(relevant_parts)} + ``` + + And here is the full context: + ```python + {code} + ``` + + Pay special attention to how the requested functionality will be used in the code snippets above. + """ + else: + caller_context = f""" + Here is the code that is importing this module/function: + ```python + # File: {caller_info.get("filename", "unknown")} + {code} + ``` + + Pay special attention to how the requested functionality will be used in this code. + """ + + logger.debug(f"Including caller context from {caller_info.get('filename')}") + + # Create a prompt for the OpenAI API + system_message = """ + You are an expert Python developer tasked with generating high-quality, production-ready Python modules. + + Follow these guidelines precisely: + + 1. 
CODE QUALITY: + - Write clean, efficient, and well-documented code with docstrings + - Follow PEP 8 style guidelines strictly + - Include type hints where appropriate (Python 3.12+ compatible) + - Add comprehensive error handling for edge cases + - Create descriptive variable names that clearly convey their purpose + + 2. UNDERSTANDING CONTEXT: + - Carefully analyze existing code to maintain consistency + - Match the naming conventions and patterns in related modules + - Ensure your implementation will work with the exact data structures shown in caller code + - Make reasonable assumptions when information is missing, but document those assumptions + + 3. RESPONSE FORMAT: + - ONLY provide clean Python code with no explanations outside of code comments + - Do NOT include markdown formatting, explanations, or any text outside the code + - Do NOT include ```python or ``` markers around your code + - Your entire response should be valid Python code that can be executed directly + + 4. IMPORTS: + - Use only Python standard library modules unless explicitly told otherwise + - If you need to import from within the library (autogenlib), do so as if those modules exist + - Format imports according to PEP 8 (stdlib, third-party, local) + + The code you generate will be directly executed by the Python interpreter, so it must be syntactically perfect. + """ + + if function_name and existing_code: + prompt = f""" + TASK: Extend an existing Python module named '{module_name}' with a new function/class. + + LIBRARY PURPOSE: + {current_description} + + EXISTING MODULE CODE: + ```python + {existing_code} + ``` + + CODEBASE CONTEXT: + {codebase_context} + + CALLER CONTEXT: + {caller_context} + + REQUIREMENTS: + Add a new {"class" if function_name[0].isupper() else "function"} named '{function_name}' that implements: + {description} + + IMPORTANT INSTRUCTIONS: + 1. Keep all existing functions and classes intact + 2. Follow the existing coding style for consistency + 3. Add comprehensive docstrings and comments where needed + 4. Include proper type hints and error handling + 5. Return ONLY the complete Python code for the entire module + 6. Do NOT include any explanations or markdown formatting in your response + """ + elif function_name: + prompt = f""" + TASK: Create a new Python module named '{module_name}' with a specific function/class. + + LIBRARY PURPOSE: + {current_description} + + CODEBASE CONTEXT: + {codebase_context} + + CALLER CONTEXT: + {caller_context} + + REQUIREMENTS: + Create a module that contains a {"class" if function_name[0].isupper() else "function"} named '{function_name}' that implements: + {description} + + IMPORTANT INSTRUCTIONS: + 1. Start with an appropriate module docstring summarizing the purpose + 2. Include comprehensive docstrings for all functions/classes + 3. Add proper type hints and error handling + 4. Return ONLY the complete Python code for the module + 5. Do NOT include any explanations or markdown formatting in your response + """ + else: + prompt = f""" + TASK: Create a new Python package module named '{module_name}'. + + LIBRARY PURPOSE: + {current_description} + + CODEBASE CONTEXT: + {codebase_context} + + CALLER CONTEXT: + {caller_context} + + REQUIREMENTS: + Implement functionality for: + {description} + + IMPORTANT INSTRUCTIONS: + 1. Create a well-structured module with appropriate functions and classes + 2. Start with a comprehensive module docstring + 3. Include proper docstrings, type hints, and error handling + 4. 
Return ONLY the complete Python code without any explanations + 5. Do NOT include file paths or any markdown formatting in your response + """ + + try: + # Set API key from environment variable + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + raise ValueError("Please set the OPENAI_API_KEY environment variable.") + + base_url = os.environ.get("OPENAI_API_BASE_URL") + model = os.environ.get("OPENAI_MODEL", "gpt-4.1") + + # Initialize the OpenAI client + client = openai.OpenAI(api_key=api_key, base_url=base_url) + + logger.debug("Prompt: %s", prompt) + + # Call the OpenAI API + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_message}, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + ) + + # Get the generated code + raw_response = response.choices[0].message.content.strip() + + logger.debug("Raw response: %s", raw_response) + + # Extract and clean the Python code from the response + code = extract_python_code(raw_response) + + logger.debug("Extracted code: %s", code) + + # Validate the code + if validate_code(code): + return code + else: + logger.error("Generated code is not valid. Attempting to fix...") + + # Try to clean up common issues + # Remove any additional text before or after code blocks + clean_code = re.sub(r'^.*?(?=(?:"""|\'\'\'))', "", code, flags=re.DOTALL) + + if validate_code(clean_code): + logger.info("Fixed code validation issues") + return clean_code + + logger.error("Generated code is not valid and could not be fixed") + return None + except Exception as e: + logger.error(f"Error generating code: {e}") + return None diff --git a/Libraries/autogenlib/autogenlib/_state.py b/Libraries/autogenlib/_state.py similarity index 96% rename from Libraries/autogenlib/autogenlib/_state.py rename to Libraries/autogenlib/_state.py index 7c1d18a8..8c906057 100644 --- a/Libraries/autogenlib/autogenlib/_state.py +++ b/Libraries/autogenlib/_state.py @@ -1,10 +1,10 @@ -"""Shared state for the autogenlib package.""" - -# The global description provided by the user -description = "A useful library." - -# Flag to enable/disable the exception handler -exception_handler_enabled = True - -# Flag to enable/disable caching -caching_enabled = False +"""Shared state for the autogenlib package.""" + +# The global description provided by the user +description = "A useful library." 
+ +# Flag to enable/disable the exception handler +exception_handler_enabled = True + +# Flag to enable/disable caching +caching_enabled = False diff --git a/Libraries/autogenlib/examples/error_handling.py b/Libraries/autogenlib/examples/error_handling.py deleted file mode 100644 index a787733f..00000000 --- a/Libraries/autogenlib/examples/error_handling.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Example demonstrating automatic error handling in autogenlib.""" - -from autogenlib import setup_exception_handler - -# Initialize error handling -setup_exception_handler() - -# Throw an error -1 / 0 diff --git a/Libraries/autogenlib/examples/logger.py b/Libraries/autogenlib/examples/logger.py deleted file mode 100644 index e9e8d4fc..00000000 --- a/Libraries/autogenlib/examples/logger.py +++ /dev/null @@ -1,12 +0,0 @@ -from logging import DEBUG, basicConfig - -basicConfig(level=DEBUG) - -from autogenlib import init - -init("Library for easy and beatiful logging in CLI") - -from autogenlib.easylog import warn, info - -warn("Ahtung!") -info("Some useful (no) message") diff --git a/Libraries/autogenlib/examples/totp.py b/Libraries/autogenlib/examples/totp.py deleted file mode 100644 index 2bab2ad8..00000000 --- a/Libraries/autogenlib/examples/totp.py +++ /dev/null @@ -1,17 +0,0 @@ -from autogenlib import init - -# Initialize with our general description -init("Library for cryptographic operations and secure communications") - -# First import - generates the totp module with totp_generator and current_totp_code functions -from autogenlib.totp import generate_totp_secret, current_totp_code - -secret = generate_totp_secret() -print(f"Secret: {secret}") - -code = current_totp_code(secret) -print(f"Code: {code}") - -from autogenlib.totp import validate_code_against_secret - -print(f"Is valid?: {validate_code_against_secret(code, secret)}") diff --git a/Libraries/autogenlib/pyproject.toml b/Libraries/autogenlib/pyproject.toml deleted file mode 100644 index 92657a5e..00000000 --- a/Libraries/autogenlib/pyproject.toml +++ /dev/null @@ -1,20 +0,0 @@ -[project] -name = "autogenlib" -version = "0.1.3" -description = "Import wisdom, export code." 
-readme = "README.md" -requires-python = ">=3.12" -dependencies = ["openai>=1.78.0"] -authors = [{ name = "Egor Ternovoi", email = "i.am@cfb.wtf" }] -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", -] - -[project.urls] -Homepage = "https://github.com/cofob/autogenlib" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" diff --git a/Libraries/autogenlib/uv.lock b/Libraries/autogenlib/uv.lock deleted file mode 100644 index 59e6919b..00000000 --- a/Libraries/autogenlib/uv.lock +++ /dev/null @@ -1,262 +0,0 @@ -version = 1 -requires-python = ">=3.12" - -[[package]] -name = "annotated-types" -version = "0.7.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, -] - -[[package]] -name = "anyio" -version = "4.9.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "idna" }, - { name = "sniffio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/95/7d/4c1bd541d4dffa1b52bd83fb8527089e097a106fc90b467a7313b105f840/anyio-4.9.0.tar.gz", hash = "sha256:673c0c244e15788651a4ff38710fea9675823028a6f08a5eda409e0c9840a028", size = 190949 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916 }, -] - -[[package]] -name = "autogenlib" -version = "0.1.0" -source = { editable = "." 
} -dependencies = [ - { name = "openai" }, -] - -[package.metadata] -requires-dist = [{ name = "openai", specifier = ">=1.78.0" }] - -[[package]] -name = "certifi" -version = "2025.4.26" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e8/9e/c05b3920a3b7d20d3d3310465f50348e5b3694f4f88c6daf736eef3024c4/certifi-2025.4.26.tar.gz", hash = "sha256:0a816057ea3cdefcef70270d2c515e4506bbc954f417fa5ade2021213bb8f0c6", size = 160705 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4a/7e/3db2bd1b1f9e95f7cddca6d6e75e2f2bd9f51b1246e546d88addca0106bd/certifi-2025.4.26-py3-none-any.whl", hash = "sha256:30350364dfe371162649852c63336a15c70c6510c2ad5015b21c2345311805f3", size = 159618 }, -] - -[[package]] -name = "colorama" -version = "0.4.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, -] - -[[package]] -name = "distro" -version = "1.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277 }, -] - -[[package]] -name = "h11" -version = "0.16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515 }, -] - -[[package]] -name = "httpcore" -version = "1.0.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "h11" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784 }, -] - -[[package]] -name = "httpx" -version = "0.28.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "certifi" }, - { name = "httpcore" }, - { name = "idna" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, -] - -[[package]] -name = "idna" -version = "3.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, -] - -[[package]] -name = "jiter" -version = "0.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1e/c2/e4562507f52f0af7036da125bb699602ead37a2332af0788f8e0a3417f36/jiter-0.9.0.tar.gz", hash = "sha256:aadba0964deb424daa24492abc3d229c60c4a31bfee205aedbf1acc7639d7893", size = 162604 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/af/d7/c55086103d6f29b694ec79156242304adf521577530d9031317ce5338c59/jiter-0.9.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7b46249cfd6c48da28f89eb0be3f52d6fdb40ab88e2c66804f546674e539ec11", size = 309203 }, - { url = "https://files.pythonhosted.org/packages/b0/01/f775dfee50beb420adfd6baf58d1c4d437de41c9b666ddf127c065e5a488/jiter-0.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:609cf3c78852f1189894383cf0b0b977665f54cb38788e3e6b941fa6d982c00e", size = 319678 }, - { url = "https://files.pythonhosted.org/packages/ab/b8/09b73a793714726893e5d46d5c534a63709261af3d24444ad07885ce87cb/jiter-0.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d726a3890a54561e55a9c5faea1f7655eda7f105bd165067575ace6e65f80bb2", size = 341816 }, - { url = "https://files.pythonhosted.org/packages/35/6f/b8f89ec5398b2b0d344257138182cc090302854ed63ed9c9051e9c673441/jiter-0.9.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2e89dc075c1fef8fa9be219e249f14040270dbc507df4215c324a1839522ea75", size = 364152 }, - { url = "https://files.pythonhosted.org/packages/9b/ca/978cc3183113b8e4484cc7e210a9ad3c6614396e7abd5407ea8aa1458eef/jiter-0.9.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04e8ffa3c353b1bc4134f96f167a2082494351e42888dfcf06e944f2729cbe1d", size = 406991 }, - { url = "https://files.pythonhosted.org/packages/13/3a/72861883e11a36d6aa314b4922125f6ae90bdccc225cd96d24cc78a66385/jiter-0.9.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:203f28a72a05ae0e129b3ed1f75f56bc419d5f91dfacd057519a8bd137b00c42", size = 395824 }, - { url = "https://files.pythonhosted.org/packages/87/67/22728a86ef53589c3720225778f7c5fdb617080e3deaed58b04789418212/jiter-0.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fca1a02ad60ec30bb230f65bc01f611c8608b02d269f998bc29cca8619a919dc", size = 351318 }, - { url = 
"https://files.pythonhosted.org/packages/69/b9/f39728e2e2007276806d7a6609cda7fac44ffa28ca0d02c49a4f397cc0d9/jiter-0.9.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:237e5cee4d5d2659aaf91bbf8ec45052cc217d9446070699441a91b386ae27dc", size = 384591 }, - { url = "https://files.pythonhosted.org/packages/eb/8f/8a708bc7fd87b8a5d861f1c118a995eccbe6d672fe10c9753e67362d0dd0/jiter-0.9.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:528b6b71745e7326eed73c53d4aa57e2a522242320b6f7d65b9c5af83cf49b6e", size = 520746 }, - { url = "https://files.pythonhosted.org/packages/95/1e/65680c7488bd2365dbd2980adaf63c562d3d41d3faac192ebc7ef5b4ae25/jiter-0.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9f48e86b57bc711eb5acdfd12b6cb580a59cc9a993f6e7dcb6d8b50522dcd50d", size = 512754 }, - { url = "https://files.pythonhosted.org/packages/78/f3/fdc43547a9ee6e93c837685da704fb6da7dba311fc022e2766d5277dfde5/jiter-0.9.0-cp312-cp312-win32.whl", hash = "sha256:699edfde481e191d81f9cf6d2211debbfe4bd92f06410e7637dffb8dd5dfde06", size = 207075 }, - { url = "https://files.pythonhosted.org/packages/cd/9d/742b289016d155f49028fe1bfbeb935c9bf0ffeefdf77daf4a63a42bb72b/jiter-0.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:099500d07b43f61d8bd780466d429c45a7b25411b334c60ca875fa775f68ccb0", size = 207999 }, - { url = "https://files.pythonhosted.org/packages/e7/1b/4cd165c362e8f2f520fdb43245e2b414f42a255921248b4f8b9c8d871ff1/jiter-0.9.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:2764891d3f3e8b18dce2cff24949153ee30c9239da7c00f032511091ba688ff7", size = 308197 }, - { url = "https://files.pythonhosted.org/packages/13/aa/7a890dfe29c84c9a82064a9fe36079c7c0309c91b70c380dc138f9bea44a/jiter-0.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:387b22fbfd7a62418d5212b4638026d01723761c75c1c8232a8b8c37c2f1003b", size = 318160 }, - { url = "https://files.pythonhosted.org/packages/6a/38/5888b43fc01102f733f085673c4f0be5a298f69808ec63de55051754e390/jiter-0.9.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d8da8629ccae3606c61d9184970423655fb4e33d03330bcdfe52d234d32f69", size = 341259 }, - { url = "https://files.pythonhosted.org/packages/3d/5e/bbdbb63305bcc01006de683b6228cd061458b9b7bb9b8d9bc348a58e5dc2/jiter-0.9.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a1be73d8982bdc278b7b9377426a4b44ceb5c7952073dd7488e4ae96b88e1103", size = 363730 }, - { url = "https://files.pythonhosted.org/packages/75/85/53a3edc616992fe4af6814c25f91ee3b1e22f7678e979b6ea82d3bc0667e/jiter-0.9.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2228eaaaa111ec54b9e89f7481bffb3972e9059301a878d085b2b449fbbde635", size = 405126 }, - { url = "https://files.pythonhosted.org/packages/ae/b3/1ee26b12b2693bd3f0b71d3188e4e5d817b12e3c630a09e099e0a89e28fa/jiter-0.9.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:11509bfecbc319459647d4ac3fd391d26fdf530dad00c13c4dadabf5b81f01a4", size = 393668 }, - { url = "https://files.pythonhosted.org/packages/11/87/e084ce261950c1861773ab534d49127d1517b629478304d328493f980791/jiter-0.9.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f22238da568be8bbd8e0650e12feeb2cfea15eda4f9fc271d3b362a4fa0604d", size = 352350 }, - { url = "https://files.pythonhosted.org/packages/f0/06/7dca84b04987e9df563610aa0bc154ea176e50358af532ab40ffb87434df/jiter-0.9.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = 
"sha256:17f5d55eb856597607562257c8e36c42bc87f16bef52ef7129b7da11afc779f3", size = 384204 }, - { url = "https://files.pythonhosted.org/packages/16/2f/82e1c6020db72f397dd070eec0c85ebc4df7c88967bc86d3ce9864148f28/jiter-0.9.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:6a99bed9fbb02f5bed416d137944419a69aa4c423e44189bc49718859ea83bc5", size = 520322 }, - { url = "https://files.pythonhosted.org/packages/36/fd/4f0cd3abe83ce208991ca61e7e5df915aa35b67f1c0633eb7cf2f2e88ec7/jiter-0.9.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e057adb0cd1bd39606100be0eafe742de2de88c79df632955b9ab53a086b3c8d", size = 512184 }, - { url = "https://files.pythonhosted.org/packages/a0/3c/8a56f6d547731a0b4410a2d9d16bf39c861046f91f57c98f7cab3d2aa9ce/jiter-0.9.0-cp313-cp313-win32.whl", hash = "sha256:f7e6850991f3940f62d387ccfa54d1a92bd4bb9f89690b53aea36b4364bcab53", size = 206504 }, - { url = "https://files.pythonhosted.org/packages/f4/1c/0c996fd90639acda75ed7fa698ee5fd7d80243057185dc2f63d4c1c9f6b9/jiter-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:c8ae3bf27cd1ac5e6e8b7a27487bf3ab5f82318211ec2e1346a5b058756361f7", size = 204943 }, - { url = "https://files.pythonhosted.org/packages/78/0f/77a63ca7aa5fed9a1b9135af57e190d905bcd3702b36aca46a01090d39ad/jiter-0.9.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:f0b2827fb88dda2cbecbbc3e596ef08d69bda06c6f57930aec8e79505dc17001", size = 317281 }, - { url = "https://files.pythonhosted.org/packages/f9/39/a3a1571712c2bf6ec4c657f0d66da114a63a2e32b7e4eb8e0b83295ee034/jiter-0.9.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:062b756ceb1d40b0b28f326cba26cfd575a4918415b036464a52f08632731e5a", size = 350273 }, - { url = "https://files.pythonhosted.org/packages/ee/47/3729f00f35a696e68da15d64eb9283c330e776f3b5789bac7f2c0c4df209/jiter-0.9.0-cp313-cp313t-win_amd64.whl", hash = "sha256:6f7838bc467ab7e8ef9f387bd6de195c43bad82a569c1699cb822f6609dd4cdf", size = 206867 }, -] - -[[package]] -name = "openai" -version = "1.78.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "distro" }, - { name = "httpx" }, - { name = "jiter" }, - { name = "pydantic" }, - { name = "sniffio" }, - { name = "tqdm" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d1/7c/7c48bac9be52680e41e99ae7649d5da3a0184cd94081e028897f9005aa03/openai-1.78.0.tar.gz", hash = "sha256:254aef4980688468e96cbddb1f348ed01d274d02c64c6c69b0334bf001fb62b3", size = 442652 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/41/d64a6c56d0ec886b834caff7a07fc4d43e1987895594b144757e7a6b90d7/openai-1.78.0-py3-none-any.whl", hash = "sha256:1ade6a48cd323ad8a7715e7e1669bb97a17e1a5b8a916644261aaef4bf284778", size = 680407 }, -] - -[[package]] -name = "pydantic" -version = "2.11.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "annotated-types" }, - { name = "pydantic-core" }, - { name = "typing-extensions" }, - { name = "typing-inspection" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/77/ab/5250d56ad03884ab5efd07f734203943c8a8ab40d551e208af81d0257bf2/pydantic-2.11.4.tar.gz", hash = "sha256:32738d19d63a226a52eed76645a98ee07c1f410ee41d93b4afbfa85ed8111c2d", size = 786540 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/12/46b65f3534d099349e38ef6ec98b1a5a81f42536d17e0ba382c28c67ba67/pydantic-2.11.4-py3-none-any.whl", hash = "sha256:d9615eaa9ac5a063471da949c8fc16376a84afb5024688b3ff885693506764eb", size = 443900 }, 
-] - -[[package]] -name = "pydantic-core" -version = "2.33.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000 }, - { url = "https://files.pythonhosted.org/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996 }, - { url = "https://files.pythonhosted.org/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957 }, - { url = "https://files.pythonhosted.org/packages/ec/6b/1ec2c03837ac00886ba8160ce041ce4e325b41d06a034adbef11339ae422/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011", size = 1964199 }, - { url = "https://files.pythonhosted.org/packages/2d/1d/6bf34d6adb9debd9136bd197ca72642203ce9aaaa85cfcbfcf20f9696e83/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f", size = 2120296 }, - { url = "https://files.pythonhosted.org/packages/e0/94/2bd0aaf5a591e974b32a9f7123f16637776c304471a0ab33cf263cf5591a/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88", size = 2676109 }, - { url = "https://files.pythonhosted.org/packages/f9/41/4b043778cf9c4285d59742281a769eac371b9e47e35f98ad321349cc5d61/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1", size = 2002028 }, - { url = "https://files.pythonhosted.org/packages/cb/d5/7bb781bf2748ce3d03af04d5c969fa1308880e1dca35a9bd94e1a96a922e/pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b", size = 2100044 }, - { url = "https://files.pythonhosted.org/packages/fe/36/def5e53e1eb0ad896785702a5bbfd25eed546cdcf4087ad285021a90ed53/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1", size = 2058881 }, - { url = "https://files.pythonhosted.org/packages/01/6c/57f8d70b2ee57fc3dc8b9610315949837fa8c11d86927b9bb044f8705419/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6", size = 2227034 }, - { url = "https://files.pythonhosted.org/packages/27/b9/9c17f0396a82b3d5cbea4c24d742083422639e7bb1d5bf600e12cb176a13/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = 
"sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea", size = 2234187 }, - { url = "https://files.pythonhosted.org/packages/b0/6a/adf5734ffd52bf86d865093ad70b2ce543415e0e356f6cacabbc0d9ad910/pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290", size = 1892628 }, - { url = "https://files.pythonhosted.org/packages/43/e4/5479fecb3606c1368d496a825d8411e126133c41224c1e7238be58b87d7e/pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2", size = 1955866 }, - { url = "https://files.pythonhosted.org/packages/0d/24/8b11e8b3e2be9dd82df4b11408a67c61bb4dc4f8e11b5b0fc888b38118b5/pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = "sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab", size = 1888894 }, - { url = "https://files.pythonhosted.org/packages/46/8c/99040727b41f56616573a28771b1bfa08a3d3fe74d3d513f01251f79f172/pydantic_core-2.33.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1082dd3e2d7109ad8b7da48e1d4710c8d06c253cbc4a27c1cff4fbcaa97a9e3f", size = 2015688 }, - { url = "https://files.pythonhosted.org/packages/3a/cc/5999d1eb705a6cefc31f0b4a90e9f7fc400539b1a1030529700cc1b51838/pydantic_core-2.33.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f517ca031dfc037a9c07e748cefd8d96235088b83b4f4ba8939105d20fa1dcd6", size = 1844808 }, - { url = "https://files.pythonhosted.org/packages/6f/5e/a0a7b8885c98889a18b6e376f344da1ef323d270b44edf8174d6bce4d622/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a9f2c9dd19656823cb8250b0724ee9c60a82f3cdf68a080979d13092a3b0fef", size = 1885580 }, - { url = "https://files.pythonhosted.org/packages/3b/2a/953581f343c7d11a304581156618c3f592435523dd9d79865903272c256a/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2b0a451c263b01acebe51895bfb0e1cc842a5c666efe06cdf13846c7418caa9a", size = 1973859 }, - { url = "https://files.pythonhosted.org/packages/e6/55/f1a813904771c03a3f97f676c62cca0c0a4138654107c1b61f19c644868b/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ea40a64d23faa25e62a70ad163571c0b342b8bf66d5fa612ac0dec4f069d916", size = 2120810 }, - { url = "https://files.pythonhosted.org/packages/aa/c3/053389835a996e18853ba107a63caae0b9deb4a276c6b472931ea9ae6e48/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb2d542b4d66f9470e8065c5469ec676978d625a8b7a363f07d9a501a9cb36a", size = 2676498 }, - { url = "https://files.pythonhosted.org/packages/eb/3c/f4abd740877a35abade05e437245b192f9d0ffb48bbbbd708df33d3cda37/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdac5d6ffa1b5a83bca06ffe7583f5576555e6c8b3a91fbd25ea7780f825f7d", size = 2000611 }, - { url = "https://files.pythonhosted.org/packages/59/a7/63ef2fed1837d1121a894d0ce88439fe3e3b3e48c7543b2a4479eb99c2bd/pydantic_core-2.33.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04a1a413977ab517154eebb2d326da71638271477d6ad87a769102f7c2488c56", size = 2107924 }, - { url = "https://files.pythonhosted.org/packages/04/8f/2551964ef045669801675f1cfc3b0d74147f4901c3ffa42be2ddb1f0efc4/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c8e7af2f4e0194c22b5b37205bfb293d166a7344a5b0d0eaccebc376546d77d5", size = 2063196 }, - { url = 
"https://files.pythonhosted.org/packages/26/bd/d9602777e77fc6dbb0c7db9ad356e9a985825547dce5ad1d30ee04903918/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:5c92edd15cd58b3c2d34873597a1e20f13094f59cf88068adb18947df5455b4e", size = 2236389 }, - { url = "https://files.pythonhosted.org/packages/42/db/0e950daa7e2230423ab342ae918a794964b053bec24ba8af013fc7c94846/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:65132b7b4a1c0beded5e057324b7e16e10910c106d43675d9bd87d4f38dde162", size = 2239223 }, - { url = "https://files.pythonhosted.org/packages/58/4d/4f937099c545a8a17eb52cb67fe0447fd9a373b348ccfa9a87f141eeb00f/pydantic_core-2.33.2-cp313-cp313-win32.whl", hash = "sha256:52fb90784e0a242bb96ec53f42196a17278855b0f31ac7c3cc6f5c1ec4811849", size = 1900473 }, - { url = "https://files.pythonhosted.org/packages/a0/75/4a0a9bac998d78d889def5e4ef2b065acba8cae8c93696906c3a91f310ca/pydantic_core-2.33.2-cp313-cp313-win_amd64.whl", hash = "sha256:c083a3bdd5a93dfe480f1125926afcdbf2917ae714bdb80b36d34318b2bec5d9", size = 1955269 }, - { url = "https://files.pythonhosted.org/packages/f9/86/1beda0576969592f1497b4ce8e7bc8cbdf614c352426271b1b10d5f0aa64/pydantic_core-2.33.2-cp313-cp313-win_arm64.whl", hash = "sha256:e80b087132752f6b3d714f041ccf74403799d3b23a72722ea2e6ba2e892555b9", size = 1893921 }, - { url = "https://files.pythonhosted.org/packages/a4/7d/e09391c2eebeab681df2b74bfe6c43422fffede8dc74187b2b0bf6fd7571/pydantic_core-2.33.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61c18fba8e5e9db3ab908620af374db0ac1baa69f0f32df4f61ae23f15e586ac", size = 1806162 }, - { url = "https://files.pythonhosted.org/packages/f1/3d/847b6b1fed9f8ed3bb95a9ad04fbd0b212e832d4f0f50ff4d9ee5a9f15cf/pydantic_core-2.33.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95237e53bb015f67b63c91af7518a62a8660376a6a0db19b89acc77a4d6199f5", size = 1981560 }, - { url = "https://files.pythonhosted.org/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777 }, -] - -[[package]] -name = "sniffio" -version = "1.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, -] - -[[package]] -name = "tqdm" -version = "4.67.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 }, -] - -[[package]] -name = "typing-extensions" -version = "4.13.2" 
-source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f6/37/23083fcd6e35492953e8d2aaaa68b860eb422b34627b13f2ce3eb6106061/typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef", size = 106967 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/54/b1ae86c0973cc6f0210b53d508ca3641fb6d0c56823f288d108bc7ab3cc8/typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c", size = 45806 }, -] - -[[package]] -name = "typing-inspection" -version = "0.4.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/82/5c/e6082df02e215b846b4b8c0b887a64d7d08ffaba30605502639d44c06b82/typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122", size = 76222 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 }, -] diff --git a/Libraries/graph_sitter_lib/__init__.py b/Libraries/graph_sitter_lib/__init__.py new file mode 100644 index 00000000..0c39c92a --- /dev/null +++ b/Libraries/graph_sitter_lib/__init__.py @@ -0,0 +1 @@ +"""Synced from https://github.com/Zeeeepa/graph-sitter.git""" diff --git a/Libraries/graph_sitter_lib/analysis.py b/Libraries/graph_sitter_lib/analysis.py new file mode 100644 index 00000000..322d6fcf --- /dev/null +++ b/Libraries/graph_sitter_lib/analysis.py @@ -0,0 +1,2097 @@ +#!/usr/bin/env python3 +"""Comprehensive Python Code Analysis Backend with Graph-Sitter Integration + +This module provides a complete code analysis system that combines: +- Graph-sitter for structural codebase analysis +- SolidLSP for real-time language server diagnostics +- Multiple static analysis tools (ruff, mypy, pylint, etc.) +- AutoGenLib integration for AI-powered error fixing +- Advanced error categorization and reporting + +Usage: + python analysis.py --target /path/to/project --comprehensive + python analysis.py --target /path/to/file.py --fix-errors + python analysis.py --target . 
--interactive +""" + +import argparse +import json +import logging +import os +import sqlite3 +import subprocess +import sys +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field +from typing import Any + +import yaml + +# Third-party imports +try: + import openai + from rich.console import Console + from rich.panel import Panel + from rich.progress import Progress, SpinnerColumn, TextColumn + from rich.syntax import Syntax + from rich.table import Table + from rich.tree import Tree + + RICH_AVAILABLE = True +except ImportError: + RICH_AVAILABLE = False + Console = None + +# Graph-sitter integration +try: + from graph_sitter import Codebase + from graph_sitter.codebase.codebase_analysis import ( + get_class_summary, + get_codebase_summary, + get_file_summary, + get_function_summary, + get_symbol_summary, + ) + from graph_sitter.configs.models.codebase import CodebaseConfig + + GRAPH_SITTER_AVAILABLE = True +except ImportError: + GRAPH_SITTER_AVAILABLE = False + Codebase = None + +# SolidLSP integration +try: + from graph_sitter.extensions.lsp.solidlsp.ls import SolidLanguageServer + from graph_sitter.extensions.lsp.solidlsp.ls_config import Language, LanguageServerConfig + from graph_sitter.extensions.lsp.solidlsp.ls_logger import LanguageServerLogger + from graph_sitter.extensions.lsp.solidlsp.settings import SolidLSPSettings + + SOLIDLSP_AVAILABLE = True +except ImportError as e: + logging.debug(f"SolidLSP not available: {e}") + SOLIDLSP_AVAILABLE = False + +# AutoGenLib integration +try: + from graph_sitter.extensions import autogenlib + from graph_sitter.extensions.autogenlib._cache import cache_module + from graph_sitter.extensions.autogenlib._exception_handler import generate_fix + + AUTOGENLIB_AVAILABLE = True +except ImportError as e: + AUTOGENLIB_AVAILABLE = False + logging.debug(f"AutoGenLib not available: {e}") + + +@dataclass +class AnalysisError: + """Structured representation of a code analysis error.""" + + file_path: str + line: int + column: int + error_type: str + severity: str + message: str + tool_source: str + category: str = "general" + fix_suggestion: str | None = None + confidence: float = 1.0 + context: str | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "file_path": self.file_path, + "line": self.line, + "column": self.column, + "error_type": self.error_type, + "severity": self.severity, + "message": self.message, + "tool_source": self.tool_source, + "category": self.category, + "fix_suggestion": self.fix_suggestion, + "confidence": self.confidence, + "context": self.context, + } + + +@dataclass +class ToolConfig: + """Configuration for a code analysis tool.""" + + name: str + command: str + enabled: bool = True + args: list[str] = field(default_factory=list) + config_file: str | None = None + timeout: int = 300 + priority: int = 2 # 1=critical, 2=important, 3=optional + requires_network: bool = False + + +class GraphSitterAnalysis: + """Graph-sitter based code analysis wrapper.""" + + def __init__(self, target_path: str): + """Initialize graph-sitter analysis.""" + if not GRAPH_SITTER_AVAILABLE: + msg = "graph-sitter not available. 
Install with: pip install graph-sitter" + raise ImportError(msg) + + self.target_path = target_path + self.codebase = None + self._initialize_codebase() + + def _initialize_codebase(self): + """Initialize the graph-sitter codebase.""" + try: + config = CodebaseConfig( + method_usages=True, + generics=True, + sync_enabled=True, + full_range_index=True, + py_resolve_syspath=True, + exp_lazy_graph=False, + ) + self.codebase = Codebase(self.target_path, config=config) + except Exception as e: + logging.warning(f"Failed to initialize graph-sitter codebase: {e}") + self.codebase = None + + @property + def functions(self): + """All functions in codebase.""" + if not self.codebase: + return [] + return getattr(self.codebase, "functions", []) + + @property + def classes(self): + """All classes in codebase.""" + if not self.codebase: + return [] + return getattr(self.codebase, "classes", []) + + @property + def imports(self): + """All imports in codebase.""" + if not self.codebase: + return [] + return getattr(self.codebase, "imports", []) + + @property + def files(self): + """All files in codebase.""" + if not self.codebase: + return [] + return getattr(self.codebase, "files", []) + + @property + def symbols(self): + """All symbols in codebase.""" + if not self.codebase: + return [] + return getattr(self.codebase, "symbols", []) + + @property + def external_modules(self): + """External dependencies.""" + if not self.codebase: + return [] + return getattr(self.codebase, "external_modules", []) + + def get_codebase_summary(self) -> dict[str, Any]: + """Get comprehensive codebase summary.""" + if not self.codebase: + return {} + + try: + return get_codebase_summary(self.codebase) + except Exception as e: + logging.warning(f"Failed to get codebase summary: {e}") + return { + "files": len(self.files), + "functions": len(self.functions), + "classes": len(self.classes), + "imports": len(self.imports), + "external_modules": len(self.external_modules), + } + + def get_function_analysis(self, function_name: str) -> dict[str, Any]: + """Get detailed analysis for a specific function.""" + functions_attr = getattr(self.codebase, "functions", []) + functions = list(functions_attr) if hasattr(functions_attr, "__iter__") else [] + + for func in functions: + if getattr(func, "name", "") == function_name: + try: + return get_function_summary(func) + except Exception: + # Fallback to basic analysis + return { + "name": func.name, + "parameters": [p.name for p in getattr(func, "parameters", [])], + "return_type": getattr(func, "return_type", None), + "decorators": [d.name for d in getattr(func, "decorators", [])], + "is_async": getattr(func, "is_async", False), + "complexity": getattr(func, "complexity", 0), + "usages": len(getattr(func, "usages", [])), + } + return {} + + def get_class_analysis(self, class_name: str) -> dict[str, Any]: + """Get detailed analysis for a specific class.""" + classes_attr = getattr(self.codebase, "classes", []) + classes = list(classes_attr) if hasattr(classes_attr, "__iter__") else [] + + for cls in classes: + if getattr(cls, "name", "") == class_name: + try: + return get_class_summary(cls) + except Exception: + # Fallback to basic analysis + return { + "name": cls.name, + "methods": len(getattr(cls, "methods", [])), + "attributes": len(getattr(cls, "attributes", [])), + "superclasses": [sc.name for sc in getattr(cls, "superclasses", [])], + "subclasses": [sc.name for sc in getattr(cls, "subclasses", [])], + "is_abstract": getattr(cls, "is_abstract", False), + } + return {} + + +class 
RuffIntegration: + """Enhanced Ruff integration for comprehensive error detection.""" + + def __init__(self, target_path: str): + self.target_path = target_path + + def run_comprehensive_analysis(self) -> list[AnalysisError]: + """Run comprehensive Ruff analysis with all rule categories.""" + errors = [] + + # Ruff rule categories for comprehensive analysis + rule_categories = [ + ("E", "pycodestyle errors"), + ("W", "pycodestyle warnings"), + ("F", "pyflakes"), + ("C", "mccabe complexity"), + ("I", "isort"), + ("N", "pep8-naming"), + ("D", "pydocstyle"), + ("UP", "pyupgrade"), + ("YTT", "flake8-2020"), + ("ANN", "flake8-annotations"), + ("ASYNC", "flake8-async"), + ("S", "flake8-bandit"), + ("BLE", "flake8-blind-except"), + ("FBT", "flake8-boolean-trap"), + ("B", "flake8-bugbear"), + ("A", "flake8-builtins"), + ("COM", "flake8-commas"), + ("CPY", "flake8-copyright"), + ("C4", "flake8-comprehensions"), + ("DTZ", "flake8-datetimez"), + ("T10", "flake8-debugger"), + ("DJ", "flake8-django"), + ("EM", "flake8-errmsg"), + ("EXE", "flake8-executable"), + ("FA", "flake8-future-annotations"), + ("ISC", "flake8-implicit-str-concat"), + ("ICN", "flake8-import-conventions"), + ("G", "flake8-logging-format"), + ("INP", "flake8-no-pep420"), + ("PIE", "flake8-pie"), + ("T20", "flake8-print"), + ("PYI", "flake8-pyi"), + ("PT", "flake8-pytest-style"), + ("Q", "flake8-quotes"), + ("RSE", "flake8-raise"), + ("RET", "flake8-return"), + ("SLF", "flake8-self"), + ("SLOT", "flake8-slots"), + ("SIM", "flake8-simplify"), + ("TID", "flake8-tidy-imports"), + ("TCH", "flake8-type-checking"), + ("INT", "flake8-gettext"), + ("ARG", "flake8-unused-arguments"), + ("PTH", "flake8-use-pathlib"), + ("TD", "flake8-todos"), + ("FIX", "flake8-fixme"), + ("ERA", "eradicate"), + ("PD", "pandas-vet"), + ("PGH", "pygrep-hooks"), + ("PL", "pylint"), + ("TRY", "tryceratops"), + ("FLY", "flynt"), + ("NPY", "numpy"), + ("AIR", "airflow"), + ("PERF", "perflint"), + ("FURB", "refurb"), + ("LOG", "flake8-logging"), + ("RUF", "ruff-specific"), + ] + + for category, description in rule_categories: + try: + cmd = [ + "ruff", + "check", + "--select", + category, + "--output-format", + "json", + "--no-cache", + self.target_path, + ] + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=120) + + if result.stdout: + try: + ruff_errors = json.loads(result.stdout) + for error in ruff_errors: + errors.append( + AnalysisError( + file_path=error.get("filename", ""), + line=error.get("location", {}).get("row", 0), + column=error.get("location", {}).get("column", 0), + error_type=error.get("code", ""), + severity=self._map_ruff_severity(error.get("code", "")), + message=error.get("message", ""), + tool_source="ruff", + category=self._categorize_ruff_error(error.get("code", "")), + confidence=0.9, + ) + ) + except json.JSONDecodeError: + pass + + except subprocess.TimeoutExpired: + logging.warning(f"Ruff analysis timed out for category {category}") + except Exception as e: + logging.warning(f"Ruff analysis failed for category {category}: {e}") + + return errors + + def _map_ruff_severity(self, code: str) -> str: + """Map Ruff error codes to severity levels.""" + if code.startswith(("E", "F")): + return "ERROR" + elif code.startswith(("W", "C", "N")): + return "WARNING" + elif code.startswith(("S", "B")): + return "SECURITY" + else: + return "INFO" + + def _categorize_ruff_error(self, code: str) -> str: + """Categorize Ruff errors by type.""" + category_map = { + "E": "syntax_style", + "W": "style_warning", + "F": "logic_error", + 
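+ # Keys in this mapping are Ruff rule-code prefixes; any prefix not listed here
+ # falls back to the "general" category via the .get() lookup at the end of this method.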
"C": "complexity", + "I": "import_style", + "N": "naming", + "D": "documentation", + "S": "security", + "B": "bug_risk", + "A": "builtin_shadow", + "T": "debug_code", + "UP": "modernization", + "ANN": "type_annotation", + "ASYNC": "async_issues", + "PL": "pylint_issues", + "RUF": "ruff_specific", + } + + prefix = code.split("0")[0] if "0" in code else code[:1] + return category_map.get(prefix, "general") + + +class LSPDiagnosticsCollector: + """Collects diagnostics from Language Server Protocol servers.""" + + def __init__(self, target_path: str): + self.target_path = target_path + self.diagnostics = [] + self.logger = LanguageServerLogger() if SOLIDLSP_AVAILABLE else None + + def collect_python_diagnostics(self) -> list[AnalysisError]: + """Collect diagnostics from Python language servers.""" + if not SOLIDLSP_AVAILABLE: + logging.warning("SolidLSP not available, skipping LSP diagnostics") + return [] + + errors = [] + + try: + # Configure Pyright for comprehensive analysis + config = LanguageServerConfig(code_language=Language.PYTHON, trace_lsp_communication=False) + + settings = SolidLSPSettings() + + # Initialize Pyright language server + from graph_sitter.extensions.lsp.solidlsp.language_servers.pyright_server import PyrightServer + + with PyrightServer(config, self.logger, self.target_path, settings) as lsp: + lsp.start_server() + + # Find Python files to analyze + python_files = [] + if os.path.isfile(self.target_path) and self.target_path.endswith(".py"): + python_files = [self.target_path] + elif os.path.isdir(self.target_path): + for root, dirs, files in os.walk(self.target_path): + # Skip common ignore directories + dirs[:] = [ + d + for d in dirs + if d + not in { + "__pycache__", + ".git", + ".venv", + "venv", + "node_modules", + } + ] + for file in files: + if file.endswith(".py"): + python_files.append(os.path.join(root, file)) + + # Open files and collect diagnostics + for file_path in python_files[:10]: # Limit for performance + try: + with open(file_path, encoding="utf-8") as f: + content = f.read() + + # Open document in LSP + lsp.open_document(file_path, content) + + # Wait for diagnostics + time.sleep(0.5) + + # Retrieve diagnostics + diagnostics = lsp.get_diagnostics(file_path) + + for diag in diagnostics: + errors.append( + AnalysisError( + file_path=file_path, + line=diag.get("range", {}).get("start", {}).get("line", 0), + column=diag.get("range", {}).get("start", {}).get("character", 0), + error_type=diag.get("code", "LSP_ERROR"), + severity=self._map_lsp_severity(diag.get("severity", 1)), + message=diag.get("message", ""), + tool_source="pyright", + category="type_checking", + confidence=0.95, + ) + ) + + lsp.close_document(file_path) + + except Exception as e: + logging.warning(f"Failed to analyze {file_path} with LSP: {e}") + + except Exception as e: + logging.exception(f"LSP diagnostics collection failed: {e}") + + return errors + + def _map_lsp_severity(self, severity: int) -> str: + """Map LSP severity to our severity levels.""" + severity_map = {1: "ERROR", 2: "WARNING", 3: "INFO", 4: "HINT"} + return severity_map.get(severity, "INFO") + + +class ErrorDatabase: + """SQLite database for storing and querying analysis errors.""" + + def __init__(self, db_path: str = "analysis_errors.db"): + self.db_path = db_path + self._init_database() + + def _init_database(self): + """Initialize the SQLite database schema.""" + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS analysis_sessions ( + id INTEGER PRIMARY KEY 
AUTOINCREMENT, + target_path TEXT NOT NULL, + timestamp TEXT NOT NULL, + tools_used TEXT NOT NULL, + total_errors INTEGER DEFAULT 0, + config_hash TEXT, + completed BOOLEAN DEFAULT FALSE + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS errors ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER, + file_path TEXT NOT NULL, + line INTEGER, + column INTEGER, + error_type TEXT, + severity TEXT, + message TEXT, + tool_source TEXT, + category TEXT, + fix_suggestion TEXT, + confidence REAL, + context TEXT, + FOREIGN KEY (session_id) REFERENCES analysis_sessions (id) + ) + """) + + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_errors_session + ON errors (session_id) + """) + + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_errors_category + ON errors (category) + """) + + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_errors_severity + ON errors (severity) + """) + + def create_session(self, target_path: str, tools_used: list[str], config: dict[str, Any]) -> int: + """Create a new analysis session.""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute( + """ + INSERT INTO analysis_sessions + (target_path, timestamp, tools_used, config_hash) + VALUES (?, ?, ?, ?) + """, + ( + target_path, + time.strftime("%Y-%m-%d %H:%M:%S"), + json.dumps(tools_used), + str(hash(json.dumps(config, sort_keys=True))), + ), + ) + return cursor.lastrowid + + def store_errors(self, errors: list[AnalysisError], session_id: int): + """Store errors in the database.""" + with sqlite3.connect(self.db_path) as conn: + for error in errors: + conn.execute( + """ + INSERT INTO errors + (session_id, file_path, line, column, error_type, severity, + message, tool_source, category, fix_suggestion, confidence, context) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + session_id, + error.file_path, + error.line, + error.column, + error.error_type, + error.severity, + error.message, + error.tool_source, + error.category, + error.fix_suggestion, + error.confidence, + error.context, + ), + ) + + def update_session(self, session_id: int, total_errors: int): + """Update session with final error count.""" + with sqlite3.connect(self.db_path) as conn: + conn.execute( + """ + UPDATE analysis_sessions + SET total_errors = ?, completed = TRUE + WHERE id = ? + """, + (total_errors, session_id), + ) + + def query_errors(self, filters: dict[str, Any]) -> list[dict[str, Any]]: + """Query errors with filters.""" + query = "SELECT * FROM errors WHERE 1=1" + params = [] + + for key, value in filters.items(): + if key in ["severity", "category", "tool_source", "error_type"]: + query += f" AND {key} = ?" 
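+ # Only whitelisted column names (checked just above) are interpolated into the
+ # SQL text; the filter value itself is passed as a bound query parameter below.
+ # Hypothetical usage: db.query_errors({"severity": "ERROR", "tool_source": "ruff"})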
+ params.append(value) + + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute(query, params) + return [dict(row) for row in cursor.fetchall()] + + +class AutoGenLibFixer: + """Integration with AutoGenLib for AI-powered error fixing.""" + + def __init__(self): + if not AUTOGENLIB_AVAILABLE: + msg = "AutoGenLib not available" + raise ImportError(msg) + + # Initialize AutoGenLib for code fixing + autogenlib.init( + "Advanced Python code analysis and error fixing system", + enable_exception_handler=True, + enable_caching=True, + ) + + def generate_fix_for_error(self, error: AnalysisError, source_code: str) -> dict[str, Any] | None: + """Generate a fix for a specific error using AutoGenLib's LLM integration.""" + try: + # Create a mock exception for the error + mock_exception_type = type(error.error_type, (Exception,), {}) + mock_exception_value = Exception(error.message) + + # Create a simplified traceback string + mock_traceback = f""" +File "{error.file_path}", line {error.line}, in + {error.context or "# Error context not available"} +{error.error_type}: {error.message} +""" + + # Use AutoGenLib's fix generation + fix_info = generate_fix( + module_name=os.path.basename(error.file_path).replace(".py", ""), + current_code=source_code, + exc_type=mock_exception_type, + exc_value=mock_exception_value, + traceback_str=mock_traceback, + is_autogenlib=False, + source_file=error.file_path, + ) + + return fix_info + + except Exception as e: + logging.exception(f"Failed to generate fix for error: {e}") + return None + + def apply_fix_to_file(self, file_path: str, fixed_code: str) -> bool: + """Apply a fix to a file (with backup).""" + try: + # Create backup + backup_path = f"{file_path}.backup_{int(time.time())}" + with open(file_path) as original: + with open(backup_path, "w") as backup: + backup.write(original.read()) + + # Apply fix + with open(file_path, "w") as f: + f.write(fixed_code) + + logging.info(f"Applied fix to {file_path} (backup: {backup_path})") + return True + + except Exception as e: + logging.exception(f"Failed to apply fix to {file_path}: {e}") + return False + + +class ComprehensiveAnalyzer: + """Main analyzer that orchestrates all analysis tools.""" + + # Enhanced tool configuration with comprehensive coverage + DEFAULT_TOOLS = { + "ruff": ToolConfig( + "ruff", + "ruff", + args=[ + "check", + "--output-format=json", + "--select=ALL", + "--ignore=COM812,ISC001", + ], + timeout=180, + priority=1, + ), + "mypy": ToolConfig( + "mypy", + "mypy", + args=[ + "--strict", + "--show-error-codes", + "--show-column-numbers", + "--json-report=/tmp/mypy.json", + ], + timeout=300, + priority=1, + ), + "pyright": ToolConfig("pyright", "pyright", args=["--outputjson"], timeout=600, priority=1), + "pylint": ToolConfig( + "pylint", + "pylint", + args=["--output-format=json", "--reports=y", "--score=y"], + timeout=300, + priority=2, + ), + "bandit": ToolConfig( + "bandit", + "bandit", + args=["-r", "-f", "json", "--severity-level=low", "--confidence-level=low"], + timeout=180, + priority=1, + ), + "safety": ToolConfig( + "safety", + "safety", + args=["check", "--json", "--full-report"], + timeout=120, + priority=1, + requires_network=True, + ), + "semgrep": ToolConfig( + "semgrep", + "semgrep", + args=["--config=p/python", "--json", "--severity=WARNING"], + timeout=300, + priority=2, + requires_network=True, + ), + "vulture": ToolConfig( + "vulture", + "vulture", + args=["--min-confidence=60", "--sort-by-size"], + timeout=120, + priority=2, + ), + 
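+ # The remaining entries below are lower-priority (2-3) complexity, formatting,
+ # style, and docstring checks; their output goes through the generic output parser.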
"radon": ToolConfig( + "radon", + "radon", + args=["cc", "-j", "--total-average", "--show-complexity"], + timeout=120, + priority=2, + ), + "xenon": ToolConfig( + "xenon", + "xenon", + args=["--max-absolute=B", "--max-modules=B", "--max-average=C"], + timeout=120, + priority=2, + ), + "black": ToolConfig("black", "black", args=["--check", "--diff"], timeout=60, priority=2), + "isort": ToolConfig( + "isort", + "isort", + args=["--check-only", "--diff", "--profile=black"], + timeout=60, + priority=2, + ), + "pydocstyle": ToolConfig( + "pydocstyle", + "pydocstyle", + args=["--convention=google"], + timeout=120, + priority=2, + ), + "pyflakes": ToolConfig("pyflakes", "pyflakes", timeout=60, priority=2), + "pycodestyle": ToolConfig( + "pycodestyle", + "pycodestyle", + args=["--statistics", "--count"], + timeout=120, + priority=3, + ), + "mccabe": ToolConfig("mccabe", "python -m mccabe", args=["--min", "5"], timeout=60, priority=3), + } + + def __init__( + self, + target_path: str, + config: dict[str, Any] | None = None, + verbose: bool = False, + ): + self.target_path = target_path + self.config = config or {} + self.verbose = verbose + self.tools_config = self.DEFAULT_TOOLS.copy() + self.console = Console() if RICH_AVAILABLE else None + + # Initialize components + self.graph_sitter = None + self.lsp_collector = None + self.ruff_integration = None + self.autogenlib_fixer = None + self.error_db = ErrorDatabase() + + # Results storage + self.last_results = None + + self._initialize_components() + self._load_config() + + def _initialize_components(self): + """Initialize analysis components.""" + try: + if GRAPH_SITTER_AVAILABLE: + self.graph_sitter = GraphSitterAnalysis(self.target_path) + if self.verbose: + print("✓ Graph-sitter initialized") + except Exception as e: + logging.warning(f"Graph-sitter initialization failed: {e}") + + try: + if SOLIDLSP_AVAILABLE: + self.lsp_collector = LSPDiagnosticsCollector(self.target_path) + if self.verbose: + print("✓ LSP diagnostics collector initialized") + except Exception as e: + logging.warning(f"LSP collector initialization failed: {e}") + + try: + self.ruff_integration = RuffIntegration(self.target_path) + if self.verbose: + print("✓ Ruff integration initialized") + except Exception as e: + logging.warning(f"Ruff integration failed: {e}") + + try: + if AUTOGENLIB_AVAILABLE: + self.autogenlib_fixer = AutoGenLibFixer() + if self.verbose: + print("✓ AutoGenLib fixer initialized") + except Exception as e: + logging.warning(f"AutoGenLib fixer initialization failed: {e}") + + def _load_config(self): + """Load configuration from file or use defaults.""" + config_file = self.config.get("config_file") + if config_file and os.path.exists(config_file): + try: + with open(config_file) as f: + if config_file.endswith(".yaml") or config_file.endswith(".yml"): + file_config = yaml.safe_load(f) + else: + file_config = json.load(f) + + # Update tool configurations + tools_config = file_config.get("analysis", {}).get("tools", {}) + for tool_name, tool_config in tools_config.items(): + if tool_name in self.tools_config: + if "enabled" in tool_config: + self.tools_config[tool_name].enabled = tool_config["enabled"] + if "args" in tool_config: + self.tools_config[tool_name].args = tool_config["args"] + if "timeout" in tool_config: + self.tools_config[tool_name].timeout = tool_config["timeout"] + if "priority" in tool_config: + self.tools_config[tool_name].priority = tool_config["priority"] + + if self.verbose: + print(f"✓ Configuration loaded from {config_file}") + + except 
Exception as e: + logging.warning(f"Failed to load config from {config_file}: {e}") + + def run_comprehensive_analysis(self) -> dict[str, Any]: + """Run comprehensive analysis using all available tools and methods.""" + start_time = time.time() + all_errors = [] + + if self.console: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=self.console, + ) as progress: + task = progress.add_task("Running comprehensive analysis...", total=None) + + # Graph-sitter analysis + progress.update(task, description="Analyzing codebase structure...") + graph_sitter_results = self._run_graph_sitter_analysis() + + # LSP diagnostics + progress.update(task, description="Collecting LSP diagnostics...") + lsp_errors = self._collect_lsp_diagnostics() + all_errors.extend(lsp_errors) + + # Ruff comprehensive analysis + progress.update(task, description="Running Ruff analysis...") + ruff_errors = self._run_ruff_analysis() + all_errors.extend(ruff_errors) + + # Traditional tools + progress.update(task, description="Running traditional analysis tools...") + tool_errors = self._run_traditional_tools() + all_errors.extend(tool_errors) + + progress.update(task, description="Categorizing and processing errors...") + else: + print("Running comprehensive analysis...") + graph_sitter_results = self._run_graph_sitter_analysis() + lsp_errors = self._collect_lsp_diagnostics() + all_errors.extend(lsp_errors) + ruff_errors = self._run_ruff_analysis() + all_errors.extend(ruff_errors) + tool_errors = self._run_traditional_tools() + all_errors.extend(tool_errors) + + # Categorize errors + categorized_errors = self._categorize_errors(all_errors) + + # Calculate metrics + metrics = self._calculate_metrics(all_errors, graph_sitter_results) + + # Detect dead code + dead_code = self._detect_dead_code(graph_sitter_results) + + # Generate summary + summary = self._generate_summary(all_errors, categorized_errors, metrics) + + end_time = time.time() + + results = { + "metadata": { + "target_path": self.target_path, + "analysis_time": round(end_time - start_time, 2), + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "tools_used": [name for name, config in self.tools_config.items() if config.enabled], + "graph_sitter_available": GRAPH_SITTER_AVAILABLE, + "lsp_available": SOLIDLSP_AVAILABLE, + "autogenlib_available": AUTOGENLIB_AVAILABLE, + }, + "summary": summary, + "errors": [error.to_dict() for error in all_errors], + "categorized_errors": categorized_errors, + "graph_sitter_results": graph_sitter_results, + "dead_code": dead_code, + "metrics": metrics, + "quality_score": self._calculate_quality_score(all_errors, metrics), + } + + # Store in database + session_id = self.error_db.create_session(self.target_path, results["metadata"]["tools_used"], self.config) + self.error_db.store_errors(all_errors, session_id) + self.error_db.update_session(session_id, len(all_errors)) + + self.last_results = results + return results + + def _run_graph_sitter_analysis(self) -> dict[str, Any]: + """Run graph-sitter analysis.""" + if not self.graph_sitter: + return {} + + try: + summary = self.graph_sitter.get_codebase_summary() + + # Add detailed analysis + functions_analysis = [] + for func in self.graph_sitter.functions[:20]: # Limit for performance + func_name = getattr(func, "name", "") if hasattr(func, "name") else "" + if func_name: + func_analysis = self.graph_sitter.get_function_analysis(func_name) + if func_analysis: + functions_analysis.append(func_analysis) + + classes_analysis = [] + for cls 
in self.graph_sitter.classes[:20]: # Limit for performance + cls_name = getattr(cls, "name", "") if hasattr(cls, "name") else "" + if cls_name: + cls_analysis = self.graph_sitter.get_class_analysis(cls_name) + if cls_analysis: + classes_analysis.append(cls_analysis) + + return { + "summary": summary, + "functions": functions_analysis, + "classes": classes_analysis, + "external_modules": [getattr(mod, "name", "") for mod in self.graph_sitter.external_modules[:50]], + } + + except Exception as e: + logging.exception(f"Graph-sitter analysis failed: {e}") + return {} + + def _collect_lsp_diagnostics(self) -> list[AnalysisError]: + """Collect diagnostics from LSP servers.""" + if not self.lsp_collector: + return [] + + try: + return self.lsp_collector.collect_python_diagnostics() + except Exception as e: + logging.exception(f"LSP diagnostics collection failed: {e}") + return [] + + def _run_ruff_analysis(self) -> list[AnalysisError]: + """Run comprehensive Ruff analysis.""" + if not self.ruff_integration: + return [] + + try: + return self.ruff_integration.run_comprehensive_analysis() + except Exception as e: + logging.exception(f"Ruff analysis failed: {e}") + return [] + + def _run_traditional_tools(self) -> list[AnalysisError]: + """Run traditional analysis tools.""" + errors = [] + + # Run tools in parallel by priority + priority_groups = {} + for tool_name, tool_config in self.tools_config.items(): + if not tool_config.enabled: + continue + + priority = tool_config.priority + if priority not in priority_groups: + priority_groups[priority] = [] + priority_groups[priority].append((tool_name, tool_config)) + + # Execute by priority (1=critical first, then 2, then 3) + for priority in sorted(priority_groups.keys()): + tools_group = priority_groups[priority] + + with ThreadPoolExecutor(max_workers=4) as executor: + future_to_tool = {executor.submit(self._run_single_tool, tool_name, tool_config): tool_name for tool_name, tool_config in tools_group} + + for future in as_completed(future_to_tool): + tool_name = future_to_tool[future] + try: + tool_errors = future.result() + errors.extend(tool_errors) + except Exception as e: + logging.exception(f"Tool {tool_name} failed: {e}") + + return errors + + def _run_single_tool(self, tool_name: str, tool_config: ToolConfig) -> list[AnalysisError]: + """Run a single analysis tool.""" + errors = [] + + try: + # Build command + cmd = [tool_config.command, *tool_config.args, self.target_path] + + # Run tool + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=tool_config.timeout, + cwd=os.path.dirname(self.target_path) or ".", + ) + + # Parse output based on tool + if tool_name == "mypy" and result.stdout: + errors.extend(self._parse_mypy_output(result.stdout)) + elif tool_name == "pylint" and result.stdout: + errors.extend(self._parse_pylint_output(result.stdout)) + elif tool_name == "bandit" and result.stdout: + errors.extend(self._parse_bandit_output(result.stdout)) + elif tool_name == "safety" and result.stdout: + errors.extend(self._parse_safety_output(result.stdout)) + elif tool_name == "semgrep" and result.stdout: + errors.extend(self._parse_semgrep_output(result.stdout)) + elif result.stdout or result.stderr: + # Generic parsing for other tools + errors.extend(self._parse_generic_output(tool_name, result)) + + except subprocess.TimeoutExpired: + logging.warning(f"Tool {tool_name} timed out") + except Exception as e: + logging.exception(f"Tool {tool_name} failed: {e}") + + return errors + + def _parse_mypy_output(self, output: 
str) -> list[AnalysisError]: + """Parse MyPy JSON output.""" + errors = [] + try: + # Try to parse as JSON first + if output.strip().startswith("{"): + data = json.loads(output) + for error in data.get("errors", []): + errors.append( + AnalysisError( + file_path=error.get("file", ""), + line=error.get("line", 0), + column=error.get("column", 0), + error_type=error.get("code", "mypy-error"), + severity="ERROR" if error.get("severity") == "error" else "WARNING", + message=error.get("message", ""), + tool_source="mypy", + category="type_checking", + confidence=0.9, + ) + ) + else: + # Parse text format + for line in output.split("\n"): + if ":" in line and ("error:" in line or "warning:" in line): + parts = line.split(":") + if len(parts) >= 4: + errors.append( + AnalysisError( + file_path=parts[0], + line=int(parts[1]) if parts[1].isdigit() else 0, + column=int(parts[2]) if parts[2].isdigit() else 0, + error_type="mypy-error", + severity="ERROR" if "error:" in line else "WARNING", + message=":".join(parts[3:]).strip(), + tool_source="mypy", + category="type_checking", + confidence=0.9, + ) + ) + except Exception as e: + logging.warning(f"Failed to parse MyPy output: {e}") + + return errors + + def _parse_pylint_output(self, output: str) -> list[AnalysisError]: + """Parse Pylint JSON output.""" + errors = [] + try: + data = json.loads(output) + for error in data: + errors.append( + AnalysisError( + file_path=error.get("path", ""), + line=error.get("line", 0), + column=error.get("column", 0), + error_type=error.get("message-id", ""), + severity=error.get("type", "INFO").upper(), + message=error.get("message", ""), + tool_source="pylint", + category=self._categorize_pylint_error(error.get("message-id", "")), + confidence=0.8, + ) + ) + except Exception as e: + logging.warning(f"Failed to parse Pylint output: {e}") + + return errors + + def _parse_bandit_output(self, output: str) -> list[AnalysisError]: + """Parse Bandit JSON output.""" + errors = [] + try: + data = json.loads(output) + for result in data.get("results", []): + errors.append( + AnalysisError( + file_path=result.get("filename", ""), + line=result.get("line_number", 0), + column=0, + error_type=result.get("test_id", ""), + severity="SECURITY", + message=result.get("issue_text", ""), + tool_source="bandit", + category="security_critical", + confidence=result.get("confidence", 0.5), + ) + ) + except Exception as e: + logging.warning(f"Failed to parse Bandit output: {e}") + + return errors + + def _parse_safety_output(self, output: str) -> list[AnalysisError]: + """Parse Safety JSON output.""" + errors = [] + try: + data = json.loads(output) + for vuln in data.get("vulnerabilities", []): + errors.append( + AnalysisError( + file_path=self.target_path, + line=0, + column=0, + error_type=vuln.get("id", ""), + severity="SECURITY", + message=f"Vulnerable dependency: {vuln.get('package_name', '')} {vuln.get('installed_version', '')}", + tool_source="safety", + category="dependency_security", + confidence=0.95, + ) + ) + except Exception as e: + logging.warning(f"Failed to parse Safety output: {e}") + + return errors + + def _parse_semgrep_output(self, output: str) -> list[AnalysisError]: + """Parse Semgrep JSON output.""" + errors = [] + try: + data = json.loads(output) + for result in data.get("results", []): + errors.append( + AnalysisError( + file_path=result.get("path", ""), + line=result.get("start", {}).get("line", 0), + column=result.get("start", {}).get("col", 0), + error_type=result.get("check_id", ""), + severity="WARNING", + 
message=result.get("extra", {}).get("message", ""), + tool_source="semgrep", + category="security_pattern", + confidence=0.85, + ) + ) + except Exception as e: + logging.warning(f"Failed to parse Semgrep output: {e}") + + return errors + + def _parse_generic_output(self, tool_name: str, result: subprocess.CompletedProcess) -> list[AnalysisError]: + """Parse generic tool output.""" + errors = [] + + # If tool failed or has output, create a generic error + if result.returncode != 0 or result.stdout.strip() or result.stderr.strip(): + output = result.stdout or result.stderr + + # Try to extract file:line information + for line in output.split("\n"): + if ":" in line and any(keyword in line.lower() for keyword in ["error", "warning", "issue"]): + parts = line.split(":") + if len(parts) >= 2: + errors.append( + AnalysisError( + file_path=parts[0] if os.path.exists(parts[0]) else self.target_path, + line=int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0, + column=0, + error_type=f"{tool_name}-issue", + severity="WARNING", + message=line.strip(), + tool_source=tool_name, + category="general", + confidence=0.7, + ) + ) + + return errors + + def _categorize_errors(self, errors: list[AnalysisError]) -> dict[str, list[AnalysisError]]: + """Categorize errors into comprehensive categories.""" + categories = { + "syntax_critical": [], + "type_critical": [], + "security_critical": [], + "logic_critical": [], + "import_critical": [], + "performance_major": [], + "complexity_major": [], + "style_major": [], + "documentation_major": [], + "naming_major": [], + "dependency_major": [], + "async_major": [], + "testing_minor": [], + "formatting_minor": [], + "optimization_minor": [], + "general": [], + } + + for error in errors: + category = error.category + + # Enhanced categorization based on error type and severity + if error.severity == "ERROR": + if "syntax" in error.error_type.lower() or "F" in error.error_type: + categories["syntax_critical"].append(error) + elif "type" in error.error_type.lower() or error.tool_source in [ + "mypy", + "pyright", + ]: + categories["type_critical"].append(error) + elif "import" in error.error_type.lower(): + categories["import_critical"].append(error) + else: + categories["logic_critical"].append(error) + elif error.severity == "SECURITY": + categories["security_critical"].append(error) + elif "performance" in category or "complexity" in category: + categories["performance_major"].append(error) + elif "style" in category or "format" in category: + categories["style_major"].append(error) + elif "doc" in category: + categories["documentation_major"].append(error) + elif "naming" in category: + categories["naming_major"].append(error) + elif "dependency" in category: + categories["dependency_major"].append(error) + elif "async" in category: + categories["async_major"].append(error) + else: + categories["general"].append(error) + + return categories + + def _detect_dead_code(self, graph_sitter_results: dict[str, Any]) -> list[dict[str, Any]]: + """Detect dead code using graph-sitter analysis.""" + dead_code = [] + + if not self.graph_sitter: + return dead_code + + try: + # Find unused functions + for func in self.graph_sitter.functions: + func_name = getattr(func, "name", "") + usages = getattr(func, "usages", []) + if func_name and len(usages) == 0: + dead_code.append( + { + "type": "function", + "name": func_name, + "file_path": getattr(func, "file_path", ""), + "line": getattr(func, "line", 0), + "reason": "Function is defined but never called", + } + ) + + # 
Find unused classes + for cls in self.graph_sitter.classes: + cls_name = getattr(cls, "name", "") + usages = getattr(cls, "usages", []) + if cls_name and len(usages) == 0: + dead_code.append( + { + "type": "class", + "name": cls_name, + "file_path": getattr(cls, "file_path", ""), + "line": getattr(cls, "line", 0), + "reason": "Class is defined but never instantiated or referenced", + } + ) + + # Find unused imports + for imp in self.graph_sitter.imports: + imp_name = getattr(imp, "name", "") + usages = getattr(imp, "usages", []) + if imp_name and len(usages) == 0: + dead_code.append( + { + "type": "import", + "name": imp_name, + "file_path": getattr(imp, "file_path", ""), + "line": getattr(imp, "line", 0), + "reason": "Import is unused", + } + ) + + except Exception as e: + logging.exception(f"Dead code detection failed: {e}") + + return dead_code + + def _calculate_metrics(self, errors: list[AnalysisError], graph_sitter_results: dict[str, Any]) -> dict[str, Any]: + """Calculate comprehensive code metrics.""" + metrics = { + "total_errors": len(errors), + "error_density": 0, + "complexity_metrics": {}, + "dependency_metrics": {}, + "performance_metrics": {}, + } + + # Error density calculation + if graph_sitter_results.get("summary", {}).get("files", 0) > 0: + metrics["error_density"] = len(errors) / graph_sitter_results["summary"]["files"] + + # Complexity metrics + complexity_errors = [e for e in errors if "complexity" in e.category] + metrics["complexity_metrics"] = { + "high_complexity_count": len(complexity_errors), + "average_complexity": sum(e.confidence for e in complexity_errors) / len(complexity_errors) if complexity_errors else 0, + } + + # Dependency metrics + dependency_errors = [e for e in errors if "dependency" in e.category] + metrics["dependency_metrics"] = { + "vulnerable_dependencies": len(dependency_errors), + "external_dependencies": len(graph_sitter_results.get("external_modules", [])), + } + + # Performance metrics + performance_errors = [e for e in errors if "performance" in e.category] + metrics["performance_metrics"] = { + "performance_issues": len(performance_errors), + "performance_warnings": [e.message for e in performance_errors[:5]], + } + + return metrics + + def _generate_summary( + self, + errors: list[AnalysisError], + categorized_errors: dict[str, list[AnalysisError]], + metrics: dict[str, Any], + ) -> dict[str, Any]: + """Generate comprehensive analysis summary.""" + return { + "overview": { + "total_errors": len(errors), + "critical_errors": len(categorized_errors.get("syntax_critical", [])) + len(categorized_errors.get("type_critical", [])) + len(categorized_errors.get("security_critical", [])), + "major_issues": sum(len(categorized_errors.get(cat, [])) for cat in categorized_errors if "major" in cat), + "minor_issues": sum(len(categorized_errors.get(cat, [])) for cat in categorized_errors if "minor" in cat), + }, + "by_severity": { + "ERROR": len([e for e in errors if e.severity == "ERROR"]), + "WARNING": len([e for e in errors if e.severity == "WARNING"]), + "SECURITY": len([e for e in errors if e.severity == "SECURITY"]), + "INFO": len([e for e in errors if e.severity == "INFO"]), + }, + "by_tool": {tool: len([e for e in errors if e.tool_source == tool]) for tool in set(e.tool_source for e in errors)}, + "quality_metrics": metrics, + } + + def _calculate_quality_score(self, errors: list[AnalysisError], metrics: dict[str, Any]) -> float: + """Calculate overall code quality score (0-100).""" + if not errors: + return 100.0 + + # Base score + score = 
100.0 + + # Deduct points for errors by severity + for error in errors: + if error.severity == "ERROR": + score -= 5.0 + elif error.severity == "SECURITY": + score -= 8.0 + elif error.severity == "WARNING": + score -= 2.0 + else: + score -= 0.5 + + # Deduct for complexity + complexity_count = metrics.get("complexity_metrics", {}).get("high_complexity_count", 0) + score -= complexity_count * 3.0 + + # Deduct for dependency issues + vuln_deps = metrics.get("dependency_metrics", {}).get("vulnerable_dependencies", 0) + score -= vuln_deps * 10.0 + + return max(0.0, min(100.0, score)) + + def _categorize_pylint_error(self, message_id: str) -> str: + """Categorize Pylint errors.""" + if message_id.startswith("C"): + return "style_major" + elif message_id.startswith("R"): + return "complexity_major" + elif message_id.startswith("W"): + return "warning_major" + elif message_id.startswith("E"): + return "logic_critical" + else: + return "general" + + def fix_errors_with_autogenlib(self, max_fixes: int = 5) -> dict[str, Any]: + """Use AutoGenLib to generate fixes for errors.""" + if not self.autogenlib_fixer or not self.last_results: + return {"error": "AutoGenLib not available or no analysis results"} + + fixes_applied = [] + errors_to_fix = [] + + # Select errors to fix (prioritize critical errors) + categorized = self.last_results.get("categorized_errors", {}) + + # Get critical errors first + for category in ["syntax_critical", "type_critical", "logic_critical"]: + errors_to_fix.extend(categorized.get(category, [])[:2]) # Max 2 per category + + # Limit total fixes + errors_to_fix = errors_to_fix[:max_fixes] + + for error_dict in errors_to_fix: + error = AnalysisError(**error_dict) + + try: + # Read source file + with open(error.file_path) as f: + source_code = f.read() + + # Generate fix + fix_info = self.autogenlib_fixer.generate_fix_for_error(error, source_code) + + if fix_info and fix_info.get("fixed_code"): + # Apply fix + success = self.autogenlib_fixer.apply_fix_to_file(error.file_path, fix_info["fixed_code"]) + + fixes_applied.append( + { + "error": error.to_dict(), + "fix_applied": success, + "fix_explanation": fix_info.get("explanation", ""), + "backup_created": success, + } + ) + + except Exception as e: + logging.exception(f"Failed to fix error in {error.file_path}: {e}") + fixes_applied.append( + { + "error": error.to_dict(), + "fix_applied": False, + "error_message": str(e), + } + ) + + return { + "fixes_attempted": len(errors_to_fix), + "fixes_applied": len([f for f in fixes_applied if f.get("fix_applied")]), + "fixes_details": fixes_applied, + } + + +class InteractiveAnalyzer: + """Interactive analysis session for code exploration.""" + + def __init__(self, analyzer: ComprehensiveAnalyzer): + self.analyzer = analyzer + self.console = Console() if RICH_AVAILABLE else None + + def start_interactive_session(self): + """Start an interactive analysis session.""" + if self.console: + self.console.print( + Panel.fit( + "🔍 Interactive Code Analysis Session\nCommands: summary, errors [category], function [name], class [name], fix, export [format], quit", + title="Analysis Shell", + ) + ) + else: + print("=== Interactive Code Analysis Session ===") + print("Commands: summary, errors [category], function [name], class [name], fix, export [format], quit") + + while True: + try: + command = input("\nanalysis> ").strip().lower() + + if command == "quit" or command == "exit": + break + elif command == "summary": + self._show_summary() + elif command.startswith("errors"): + category = 
command.split()[1] if len(command.split()) > 1 else None + self._show_errors(category) + elif command.startswith("function"): + func_name = command.split()[1] if len(command.split()) > 1 else None + self._show_function_analysis(func_name) + elif command.startswith("class"): + class_name = command.split()[1] if len(command.split()) > 1 else None + self._show_class_analysis(class_name) + elif command == "fix": + self._apply_fixes() + elif command.startswith("export"): + format_type = command.split()[1] if len(command.split()) > 1 else "json" + self._export_results(format_type) + else: + print("Unknown command. Available: summary, errors, function, class, fix, export, quit") + + except KeyboardInterrupt: + break + except Exception as e: + print(f"Error: {e}") + + print("Interactive session ended.") + + def _show_summary(self): + """Show analysis summary.""" + if not self.analyzer.last_results: + print("No analysis results available. Run analysis first.") + return + + summary = self.analyzer.last_results.get("summary", {}) + + if self.console: + table = Table(title="Analysis Summary") + table.add_column("Metric", style="cyan") + table.add_column("Value", style="magenta") + + overview = summary.get("overview", {}) + for key, value in overview.items(): + table.add_row(key.replace("_", " ").title(), str(value)) + + self.console.print(table) + else: + print("\n=== Analysis Summary ===") + overview = summary.get("overview", {}) + for key, value in overview.items(): + print(f"{key.replace('_', ' ').title()}: {value}") + + def _show_errors(self, category: str | None = None): + """Show errors, optionally filtered by category.""" + if not self.analyzer.last_results: + print("No analysis results available.") + return + + categorized = self.analyzer.last_results.get("categorized_errors", {}) + + if category: + errors = categorized.get(category, []) + print(f"\n=== {category.replace('_', ' ').title()} Errors ===") + else: + errors = self.analyzer.last_results.get("errors", []) + print(f"\n=== All Errors ({len(errors)}) ===") + + for i, error in enumerate(errors[:20]): # Show first 20 + print(f"{i + 1}. 
{error['file_path']}:{error['line']} - {error['message']} [{error['tool_source']}]") + + def _show_function_analysis(self, func_name: str | None): + """Show function analysis.""" + if not func_name: + print("Please specify a function name: function ") + return + + if not self.analyzer.graph_sitter: + print("Graph-sitter not available for function analysis.") + return + + analysis = self.analyzer.graph_sitter.get_function_analysis(func_name) + if analysis: + print(f"\n=== Function Analysis: {func_name} ===") + for key, value in analysis.items(): + print(f"{key}: {value}") + else: + print(f"Function '{func_name}' not found.") + + def _show_class_analysis(self, class_name: str | None): + """Show class analysis.""" + if not class_name: + print("Please specify a class name: class ") + return + + if not self.analyzer.graph_sitter: + print("Graph-sitter not available for class analysis.") + return + + analysis = self.analyzer.graph_sitter.get_class_analysis(class_name) + if analysis: + print(f"\n=== Class Analysis: {class_name} ===") + for key, value in analysis.items(): + print(f"{key}: {value}") + else: + print(f"Class '{class_name}' not found.") + + def _apply_fixes(self): + """Apply AutoGenLib fixes.""" + if not self.analyzer.autogenlib_fixer: + print("AutoGenLib not available for fixing.") + return + + print("Applying AI-powered fixes...") + fix_results = self.analyzer.fix_errors_with_autogenlib() + + print(f"Fixes attempted: {fix_results.get('fixes_attempted', 0)}") + print(f"Fixes applied: {fix_results.get('fixes_applied', 0)}") + + for fix in fix_results.get("fixes_details", []): + if fix.get("fix_applied"): + print(f"✓ Fixed: {fix['error']['message']}") + else: + print(f"✗ Failed: {fix['error']['message']}") + + def _export_results(self, format_type: str): + """Export results in specified format.""" + if not self.analyzer.last_results: + print("No results to export.") + return + + timestamp = time.strftime("%Y%m%d_%H%M%S") + filename = f"analysis_export_{timestamp}.{format_type}" + + try: + if format_type == "json": + with open(filename, "w") as f: + json.dump(self.analyzer.last_results, f, indent=2) + elif format_type == "html": + html_content = ReportGenerator(self.analyzer.last_results).generate_html_report() + with open(filename, "w") as f: + f.write(html_content) + else: + print(f"Unsupported format: {format_type}") + return + + print(f"Results exported to: {filename}") + + except Exception as e: + print(f"Export failed: {e}") + + +class ReportGenerator: + """Generate comprehensive analysis reports.""" + + def __init__(self, results: dict[str, Any]): + self.results = results + + def generate_terminal_report(self) -> str: + """Generate a comprehensive terminal report.""" + lines = [] + + # Header + lines.append("=" * 100) + lines.append("COMPREHENSIVE CODE ANALYSIS REPORT") + lines.append("=" * 100) + + # Metadata + metadata = self.results.get("metadata", {}) + lines.append(f"Target: {metadata.get('target_path', 'Unknown')}") + lines.append(f"Analysis Time: {metadata.get('analysis_time', 0)}s") + lines.append(f"Timestamp: {metadata.get('timestamp', 'Unknown')}") + lines.append(f"Tools Used: {', '.join(metadata.get('tools_used', []))}") + lines.append("") + + # Quality Score + quality_score = self.results.get("quality_score", 0) + lines.append(f"🎯 QUALITY SCORE: {quality_score:.1f}/100") + lines.append("") + + # Summary + summary = self.results.get("summary", {}) + overview = summary.get("overview", {}) + + lines.append("📊 SUMMARY") + lines.append("-" * 50) + lines.append(f"Total 
Errors: {overview.get('total_errors', 0)}")
+        lines.append(f"Critical Errors: {overview.get('critical_errors', 0)}")
+        lines.append(f"Major Issues: {overview.get('major_issues', 0)}")
+        lines.append(f"Minor Issues: {overview.get('minor_issues', 0)}")
+        lines.append("")
+
+        # Graph-sitter results
+        gs_results = self.results.get("graph_sitter_results", {})
+        if gs_results:
+            lines.append("🏗️ CODEBASE STRUCTURE")
+            lines.append("-" * 50)
+            gs_summary = gs_results.get("summary", {})
+            for key, value in gs_summary.items():
+                lines.append(f"{key.replace('_', ' ').title()}: {value}")
+            lines.append("")
+
+        # Dead code
+        dead_code = self.results.get("dead_code", [])
+        if dead_code:
+            lines.append(f"💀 DEAD CODE ({len(dead_code)} items)")
+            lines.append("-" * 50)
+            for item in dead_code[:10]:  # Show first 10
+                lines.append(f"{item['type'].title()}: {item['name']} - {item['reason']}")
+            if len(dead_code) > 10:
+                lines.append(f"... and {len(dead_code) - 10} more items")
+            lines.append("")
+
+        # Categorized errors
+        categorized = self.results.get("categorized_errors", {})
+        lines.append("🔍 ERRORS BY CATEGORY")
+        lines.append("-" * 50)
+
+        for category, errors in categorized.items():
+            if errors:
+                lines.append(f"\n{category.replace('_', ' ').title()} ({len(errors)} errors):")
+                for error in errors[:5]:  # Show first 5 per category
+                    lines.append(f" • {error['file_path']}:{error['line']} - {error['message']}")
+                if len(errors) > 5:
+                    lines.append(f" ... and {len(errors) - 5} more")
+
+        lines.append("")
+        lines.append("=" * 100)
+
+        return "\n".join(lines)
+
+    def generate_html_report(self) -> str:
+        """Generate a comprehensive HTML report (plain, unstyled markup)."""
+        html = f"""<!DOCTYPE html>
+<html>
+<head>
+    <meta charset="utf-8">
+    <title>Comprehensive Code Analysis Report</title>
+</head>
+<body>
+    <h1>🔍 Comprehensive Code Analysis Report</h1>
+
+    <div class="metrics">
+        <div class="metric-card">
+            <div class="metric-value">{self.results.get("quality_score", 0):.1f}</div>
+            <div class="metric-label">Quality Score</div>
+        </div>
+        <div class="metric-card">
+            <div class="metric-value">{len(self.results.get("errors", []))}</div>
+            <div class="metric-label">Total Issues</div>
+        </div>
+        <div class="metric-card">
+            <div class="metric-value">{len(self.results.get("dead_code", []))}</div>
+            <div class="metric-label">Dead Code Items</div>
+        </div>
+        <div class="metric-card">
+            <div class="metric-value">{len(self.results.get("metadata", {}).get("tools_used", []))}</div>
+            <div class="metric-label">Tools Used</div>
+        </div>
+    </div>
+
+    <h2>📈 Error Categories</h2>
+    {self._generate_error_categories_html()}
+
+    <h2>💀 Dead Code Analysis</h2>
+    {self._generate_dead_code_html()}
+
+    <h2>🏗️ Codebase Structure</h2>
+    {self._generate_structure_html()}
+
+    <h2>📊 Detailed Metrics</h2>
+    {self._generate_metrics_html()}
+</body>
+</html>
+ + +""" + return html + + def _generate_error_categories_html(self) -> str: + """Generate HTML for error categories.""" + categorized = self.results.get("categorized_errors", {}) + html_parts = [] + + for category, errors in categorized.items(): + if not errors: + continue + + html_parts.append('
<div class="category">')
+            html_parts.append(f"<h3>{category.replace('_', ' ').title()} ({len(errors)} errors)</h3>")
+
+            for error in errors[:10]:  # Show first 10 per category
+                html_parts.append(f"""
+                <div class="error-item">
+                    <strong>{error["file_path"]}:{error["line"]}</strong>
+                    <span class="tool">{error["tool_source"]}</span>
+                    <p>{error["message"]}</p>
+                </div>
+                """)
+
+            if len(errors) > 10:
+                html_parts.append(f"<p>... and {len(errors) - 10} more errors</p>")
+
+            html_parts.append("</div>")
+
+        return "".join(html_parts)
+
+    def _generate_dead_code_html(self) -> str:
+        """Generate HTML for dead code analysis."""
+        dead_code = self.results.get("dead_code", [])
+
+        if not dead_code:
+            return "<p>No dead code detected.</p>"
+
+        html_parts = ['<div class="dead-code">']
+
+        for item in dead_code[:20]:  # Show first 20
+            html_parts.append(f"""
+            <div class="dead-code-item">
+                <strong>{item["type"].title()}: {item["name"]}</strong><br>
+                {item["file_path"]}:{item["line"]}<br>
+                <em>{item["reason"]}</em>
+            </div>
+            """)
+
+        if len(dead_code) > 20:
+            html_parts.append(f"<p>... and {len(dead_code) - 20} more items</p>")
+
+        html_parts.append("</div>")
+        return "".join(html_parts)
+
+    def _generate_structure_html(self) -> str:
+        """Generate HTML for codebase structure."""
+        gs_results = self.results.get("graph_sitter_results", {})
+
+        if not gs_results:
+            return "<p>Graph-sitter analysis not available.</p>"
+
+        summary = gs_results.get("summary", {})
+        html_parts = ['<div class="structure">']
+
+        for key, value in summary.items():
+            html_parts.append(f"""
+            <div class="metric-card">
+                <div class="metric-value">{value}</div>
+                <div class="metric-label">{key.replace("_", " ").title()}</div>
+            </div>
+            """)
+
+        html_parts.append("</div>")
+        return "".join(html_parts)
+
+    def _generate_metrics_html(self) -> str:
+        """Generate HTML for detailed metrics."""
+        metrics = self.results.get("metrics", {})
+
+        html_parts = []
+        for category, data in metrics.items():
+            html_parts.append(f"<h3>{category.replace('_', ' ').title()}</h3>")
+            html_parts.append("<ul>")
+
+            if isinstance(data, dict):
+                for key, value in data.items():
+                    html_parts.append(f"<li>{key}: {value}</li>")
+            else:
+                html_parts.append(f"<li>{data}</li>")
+
+            html_parts.append("</ul>
") + + return "".join(html_parts) + + +def main(): + """Main entry point for the comprehensive analysis system.""" + parser = argparse.ArgumentParser(description="Comprehensive Python Code Analysis with Graph-Sitter, LSP, and AI-powered fixing") + parser.add_argument("--target", required=True, help="Target file or directory to analyze") + parser.add_argument("--verbose", action="store_true", help="Verbose output") + parser.add_argument("--config", help="Configuration file path") + parser.add_argument("--comprehensive", action="store_true", help="Run comprehensive analysis") + parser.add_argument("--fix-errors", action="store_true", help="Apply AI-powered fixes") + parser.add_argument("--interactive", action="store_true", help="Start interactive session") + parser.add_argument( + "--format", + choices=["terminal", "json", "html"], + default="terminal", + help="Output format", + ) + parser.add_argument("--output", help="Output file path") + parser.add_argument("--max-fixes", type=int, default=5, help="Maximum number of fixes to apply") + + args = parser.parse_args() + + # Initialize analyzer + config = {"config_file": args.config} if args.config else {} + analyzer = ComprehensiveAnalyzer(args.target, config, args.verbose) + + try: + if args.comprehensive or not (args.fix_errors or args.interactive): + # Run comprehensive analysis + print("🚀 Starting comprehensive analysis...") + results = analyzer.run_comprehensive_analysis() + + # Generate report + generator = ReportGenerator(results) + + if args.format == "terminal": + report = generator.generate_terminal_report() + print(report) + + if args.output: + with open(args.output, "w") as f: + f.write(report) + print(f"\nReport saved to: {args.output}") + + elif args.format == "json": + output_file = args.output or f"analysis_report_{int(time.time())}.json" + with open(output_file, "w") as f: + json.dump(results, f, indent=2) + print(f"JSON report saved to: {output_file}") + + elif args.format == "html": + output_file = args.output or f"analysis_report_{int(time.time())}.html" + html_report = generator.generate_html_report() + with open(output_file, "w") as f: + f.write(html_report) + print(f"HTML report saved to: {output_file}") + + # Try to open in browser + try: + import webbrowser + + webbrowser.open(output_file) + except Exception: + pass + + if args.fix_errors: + # Apply fixes + if not analyzer.last_results: + print("Running analysis first...") + analyzer.run_comprehensive_analysis() + + print("🔧 Applying AI-powered fixes...") + fix_results = analyzer.fix_errors_with_autogenlib(args.max_fixes) + + print(f"Fixes attempted: {fix_results.get('fixes_attempted', 0)}") + print(f"Fixes applied: {fix_results.get('fixes_applied', 0)}") + + for fix in fix_results.get("fixes_details", []): + if fix.get("fix_applied"): + print(f"✓ Fixed: {fix['error']['message']}") + else: + print(f"✗ Failed: {fix['error']['message']}") + + if args.interactive: + # Start interactive session + if not analyzer.last_results: + print("Running analysis first...") + analyzer.run_comprehensive_analysis() + + interactive = InteractiveAnalyzer(analyzer) + interactive.start_interactive_session() + + except KeyboardInterrupt: + print("\nAnalysis interrupted by user.") + sys.exit(1) + except Exception as e: + print(f"Analysis failed: {e}") + if args.verbose: + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + # Configure logging + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + main() diff 
--git a/Libraries/graph_sitter_lib/analysisbig.py b/Libraries/graph_sitter_lib/analysisbig.py new file mode 100644 index 00000000..593b392c --- /dev/null +++ b/Libraries/graph_sitter_lib/analysisbig.py @@ -0,0 +1,4461 @@ +#!/usr/bin/env python3 +""" +Comprehensive Python Code Analysis Backend with Graph-Sitter Integration + +This advanced analysis tool provides deep codebase insights using graph-sitter for +structural analysis, LSP integration for real-time error detection, and comprehensive +static analysis tools for maximum error comprehension. + +Features: +- Graph-sitter based codebase analysis +- LSP error detection and reporting +- Comprehensive static analysis tool integration +- Advanced error categorization and presentation +- Interactive analysis capabilities +- Performance profiling and metrics +- Security vulnerability detection +- Code quality assessment +- Dependency analysis +- Documentation coverage analysis + +Usage: + python analysis.py --target /path/to/codebase + python analysis.py --target /path/to/file.py --interactive + python analysis.py --target /path/to/project --graph-sitter +""" + +import argparse +import ast +import dataclasses +import json +import logging +import os +import re +import subprocess +import sys +import time +import traceback +import threading +import queue +import hashlib +import sqlite3 +from abc import ABC, abstractmethod +from collections import Counter, defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple, Union, Callable +from urllib.parse import urlparse +import yaml +import pickle + +# Graph-sitter imports (simulated - would be actual imports in real implementation) +try: + from graph_sitter import Codebase + from graph_sitter.configs.models.codebase import CodebaseConfig + from graph_sitter.codebase.codebase_analysis import ( + get_codebase_summary, + get_file_summary, + get_class_summary, + get_function_summary, + get_symbol_summary + ) + from graph_sitter.external_module import ExternalModule + from graph_sitter.import_resolution import Import + from graph_sitter.symbol import Symbol + from graph_sitter.file import File + from graph_sitter.function import Function + from graph_sitter.class_def import ClassDef + from graph_sitter.import_stmt import ImportStmt + from graph_sitter.directory import Directory + GRAPH_SITTER_AVAILABLE = True +except ImportError: + # Fallback implementations for demonstration + GRAPH_SITTER_AVAILABLE = False + print("Warning: Graph-sitter not available. 
Using fallback implementations.") + +# LSP types and error handling (based on solidlsp) +class ErrorCodes(Enum): + """LSP error codes for categorization.""" + PARSE_ERROR = -32700 + INVALID_REQUEST = -32600 + METHOD_NOT_FOUND = -32601 + INVALID_PARAMS = -32602 + INTERNAL_ERROR = -32603 + SERVER_ERROR_START = -32099 + SERVER_ERROR_END = -32000 + SERVER_NOT_INITIALIZED = -32002 + UNKNOWN_ERROR_CODE = -32001 + REQUEST_FAILED = -32803 + SERVER_CANCELLED = -32802 + CONTENT_MODIFIED = -32801 + REQUEST_CANCELLED = -32800 + + +class MessageType(Enum): + """LSP message types for severity classification.""" + ERROR = 1 + WARNING = 2 + INFO = 3 + LOG = 4 + + +class DiagnosticSeverity(Enum): + """Diagnostic severity levels.""" + ERROR = 1 + WARNING = 2 + INFORMATION = 3 + HINT = 4 + + +@dataclass +class Position: + """Represents a position in a text document.""" + line: int + character: int + + +@dataclass +class Range: + """Represents a range in a text document.""" + start: Position + end: Position + + +@dataclass +class Diagnostic: + """Represents a diagnostic message.""" + range: Range + severity: DiagnosticSeverity + code: Optional[str] + source: Optional[str] + message: str + related_information: Optional[List[Dict]] = None + tags: Optional[List[int]] = None + + +@dataclass +class LSPError: + """LSP error representation.""" + code: ErrorCodes + message: str + data: Optional[Dict] = None + + def to_dict(self) -> Dict[str, Any]: + result = {"code": self.code.value, "message": self.message} + if self.data: + result["data"] = self.data + return result + + +@dataclass +class AnalysisError: + """Represents an analysis error with comprehensive context.""" + file_path: str + line_number: Optional[int] + column_number: Optional[int] + error_type: str + severity: DiagnosticSeverity + message: str + tool_source: str + code_context: Optional[str] = None + fix_suggestion: Optional[str] = None + category: Optional[str] = None + subcategory: Optional[str] = None + error_code: Optional[str] = None + related_errors: List[str] = field(default_factory=list) + confidence: float = 1.0 + tags: List[str] = field(default_factory=list) + + +@dataclass +class ToolConfig: + """Configuration for a code analysis tool.""" + name: str + command: str + enabled: bool = True + args: List[str] = field(default_factory=list) + config_file: Optional[str] = None + timeout: int = 300 + category: str = "general" + priority: int = 1 + requires_network: bool = False + output_format: str = "text" + + +class ErrorCategory(Enum): + """Comprehensive error categorization system.""" + SYNTAX = "syntax" + TYPE = "type" + IMPORT = "import" + LOGIC = "logic" + STYLE = "style" + SECURITY = "security" + PERFORMANCE = "performance" + COMPATIBILITY = "compatibility" + DOCUMENTATION = "documentation" + TESTING = "testing" + MAINTAINABILITY = "maintainability" + DESIGN = "design" + DEPENDENCIES = "dependencies" + CONFIGURATION = "configuration" + DEPLOYMENT = "deployment" + + +class AnalysisSeverity(Enum): + """Analysis severity levels with priority.""" + CRITICAL = 1 + ERROR = 2 + WARNING = 3 + INFO = 4 + HINT = 5 + + +# Fallback Graph-sitter Analysis implementation for demonstration +class MockAnalysis: + """Mock Analysis class when graph-sitter is not available.""" + + def __init__(self, codebase_path: str): + self.codebase_path = codebase_path + self._files = [] + self._functions = [] + self._classes = [] + self._imports = [] + self._symbols = [] + self._external_modules = [] + self._scan_codebase() + + def _scan_codebase(self): + """Scan the 
codebase and extract basic information.""" + for root, dirs, files in os.walk(self.codebase_path): + # Skip common ignored directories + dirs[:] = [d for d in dirs if d not in {'.git', '__pycache__', '.venv', 'venv', 'node_modules'}] + + for file in files: + if file.endswith('.py'): + file_path = os.path.join(root, file) + try: + self._analyze_file(file_path) + except Exception as e: + logging.warning(f"Failed to analyze {file_path}: {e}") + + def _analyze_file(self, file_path: str): + """Analyze a single file using AST parsing.""" + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + tree = ast.parse(content) + + # Extract functions, classes, and imports + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + self._functions.append({ + 'name': node.name, + 'file_path': file_path, + 'line_number': node.lineno, + 'is_async': isinstance(node, ast.AsyncFunctionDef), + 'decorators': [d.id if isinstance(d, ast.Name) else str(d) for d in node.decorator_list], + 'parameters': [arg.arg for arg in node.args.args] + }) + self._symbols.append(self._functions[-1]) + + elif isinstance(node, ast.ClassDef): + self._classes.append({ + 'name': node.name, + 'file_path': file_path, + 'line_number': node.lineno, + 'bases': [base.id if isinstance(base, ast.Name) else str(base) for base in node.bases], + 'decorators': [d.id if isinstance(d, ast.Name) else str(d) for d in node.decorator_list] + }) + self._symbols.append(self._classes[-1]) + + elif isinstance(node, (ast.Import, ast.ImportFrom)): + if isinstance(node, ast.Import): + for alias in node.names: + self._imports.append({ + 'module': alias.name, + 'name': alias.asname or alias.name, + 'file_path': file_path, + 'line_number': node.lineno, + 'is_external': not alias.name.startswith('.') + }) + else: # ImportFrom + module = node.module or '' + for alias in node.names: + self._imports.append({ + 'module': module, + 'name': alias.name, + 'file_path': file_path, + 'line_number': node.lineno, + 'is_external': not module.startswith('.') + }) + + self._files.append({ + 'path': file_path, + 'content': content, + 'line_count': len(content.splitlines()) + }) + + except Exception as e: + logging.warning(f"Error parsing {file_path}: {e}") + + @property + def functions(self): + return self._functions + + @property + def classes(self): + return self._classes + + @property + def imports(self): + return self._imports + + @property + def files(self): + return self._files + + @property + def symbols(self): + return self._symbols + + @property + def external_modules(self): + return self._external_modules + + +class GraphSitterAnalysis: + """ + Enhanced Analysis class providing pre-computed graph element access and advanced analysis. + + Integrates with graph-sitter for comprehensive codebase understanding and provides + the exact API shown in the documentation with additional error detection capabilities. 
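+
+    Example (illustrative; the summary keys shown are those of the AST-based fallback):
+
+        analysis = GraphSitterAnalysis("/path/to/project")
+        summary = analysis.get_codebase_summary()
+        print(summary["total_files"], summary["total_functions"])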
+ """ + + def __init__(self, codebase_path: str, config: Optional[Dict] = None): + """Initialize Analysis with a Codebase instance.""" + self.codebase_path = codebase_path + self.config = config or {} + + if GRAPH_SITTER_AVAILABLE: + # Initialize with performance optimizations + gs_config = CodebaseConfig( + method_usages=True, # Enable method usage resolution + generics=True, # Enable generic type resolution + sync_enabled=True, # Enable graph sync during commits + full_range_index=True, # Full range-to-node mapping + py_resolve_syspath=True, # Resolve sys.path imports + exp_lazy_graph=False, # Lazy graph construction + ) + + self.codebase = Codebase(codebase_path, config=gs_config) + self.analysis = self.codebase # Direct codebase access + else: + # Use fallback implementation + self.analysis = MockAnalysis(codebase_path) + self.codebase = None + + @property + def functions(self): + """ + All functions in codebase with enhanced analysis. + + Each function provides: + - function.usages # All usage sites + - function.call_sites # All call locations + - function.dependencies # Function dependencies + - function.function_calls # Functions this function calls + - function.parameters # Function parameters + - function.return_statements # Return statements + - function.decorators # Function decorators + - function.is_async # Async function detection + - function.is_generator # Generator function detection + """ + return self.analysis.functions + + @property + def classes(self): + """ + All classes in codebase with comprehensive analysis. + + Each class provides: + - cls.superclasses # Parent classes + - cls.subclasses # Child classes + - cls.methods # Class methods + - cls.attributes # Class attributes + - cls.decorators # Class decorators + - cls.usages # Class usage sites + - cls.dependencies # Class dependencies + - cls.is_abstract # Abstract class detection + """ + return self.analysis.classes + + @property + def imports(self): + """All import statements in the codebase.""" + return self.analysis.imports + + @property + def files(self): + """ + All files in the codebase with import analysis. 
+ + Each file provides: + - file.imports # Outbound imports + - file.inbound_imports # Files that import this file + - file.symbols # Symbols defined in file + - file.external_modules # External dependencies + """ + return self.analysis.files + + @property + def symbols(self): + """All symbols (functions, classes, variables) in the codebase.""" + return self.analysis.symbols + + @property + def external_modules(self): + """External dependencies imported by the codebase.""" + return self.analysis.external_modules + + def get_codebase_summary(self) -> Dict[str, Any]: + """Get comprehensive codebase summary.""" + if GRAPH_SITTER_AVAILABLE and self.codebase: + return get_codebase_summary(self.codebase) + + # Fallback implementation + return { + "total_files": len(self.files), + "total_functions": len(self.functions), + "total_classes": len(self.classes), + "total_imports": len(self.imports), + "total_symbols": len(self.symbols), + "external_dependencies": len(self.external_modules), + "lines_of_code": sum(f.get('line_count', 0) for f in self.files), + "complexity_score": self._calculate_complexity_score() + } + + def get_file_summary(self, file_path: str) -> Dict[str, Any]: + """Get detailed file summary.""" + if GRAPH_SITTER_AVAILABLE and self.codebase: + file_obj = self.codebase.get_file(file_path) + if file_obj: + return get_file_summary(file_obj) + + # Fallback implementation + file_info = next((f for f in self.files if f['path'] == file_path), None) + if not file_info: + return {"error": f"File not found: {file_path}"} + + file_functions = [f for f in self.functions if f['file_path'] == file_path] + file_classes = [c for c in self.classes if c['file_path'] == file_path] + file_imports = [i for i in self.imports if i['file_path'] == file_path] + + return { + "path": file_path, + "lines_of_code": file_info['line_count'], + "functions": len(file_functions), + "classes": len(file_classes), + "imports": len(file_imports), + "complexity": self._calculate_file_complexity(file_path) + } + + def get_function_analysis(self, function_name: str) -> Dict[str, Any]: + """Get detailed analysis for a specific function.""" + if GRAPH_SITTER_AVAILABLE and self.codebase: + func = self.codebase.get_function(function_name) + if func: + return get_function_summary(func) + + # Fallback implementation + func_info = next((f for f in self.functions if f['name'] == function_name), None) + if not func_info: + return {"error": f"Function not found: {function_name}"} + + return { + 'name': func_info['name'], + 'file_path': func_info['file_path'], + 'line_number': func_info['line_number'], + 'parameters': func_info.get('parameters', []), + 'decorators': func_info.get('decorators', []), + 'is_async': func_info.get('is_async', False), + 'complexity': self._calculate_function_complexity(func_info), + 'usages': self._find_function_usages(function_name) + } + + def get_class_analysis(self, class_name: str) -> Dict[str, Any]: + """Get detailed analysis for a specific class.""" + if GRAPH_SITTER_AVAILABLE and self.codebase: + cls = self.codebase.get_class(class_name) + if cls: + return get_class_summary(cls) + + # Fallback implementation + class_info = next((c for c in self.classes if c['name'] == class_name), None) + if not class_info: + return {"error": f"Class not found: {class_name}"} + + class_methods = [f for f in self.functions + if f['file_path'] == class_info['file_path'] + and f['line_number'] > class_info['line_number']] + + return { + 'name': class_info['name'], + 'file_path': class_info['file_path'], + 
'line_number': class_info['line_number'], + 'methods': len(class_methods), + 'bases': class_info.get('bases', []), + 'decorators': class_info.get('decorators', []), + 'complexity': self._calculate_class_complexity(class_info) + } + + def get_symbol_analysis(self, symbol_name: str) -> Dict[str, Any]: + """Get detailed analysis for a specific symbol.""" + if GRAPH_SITTER_AVAILABLE and self.codebase: + symbol = self.codebase.get_symbol(symbol_name) + if symbol: + return get_symbol_summary(symbol) + + # Fallback implementation + symbol_info = next((s for s in self.symbols if s['name'] == symbol_name), None) + if not symbol_info: + return {"error": f"Symbol not found: {symbol_name}"} + + return { + 'name': symbol_info['name'], + 'type': 'function' if 'parameters' in symbol_info else 'class', + 'file_path': symbol_info['file_path'], + 'line_number': symbol_info['line_number'], + 'usage_count': len(self._find_symbol_usages(symbol_name)) + } + + def _calculate_complexity_score(self) -> float: + """Calculate overall codebase complexity score.""" + if not self.functions: + return 0.0 + + total_complexity = sum(self._calculate_function_complexity(f) for f in self.functions) + return total_complexity / len(self.functions) + + def _calculate_file_complexity(self, file_path: str) -> float: + """Calculate complexity score for a specific file.""" + file_functions = [f for f in self.functions if f['file_path'] == file_path] + if not file_functions: + return 0.0 + + total_complexity = sum(self._calculate_function_complexity(f) for f in file_functions) + return total_complexity / len(file_functions) + + def _calculate_function_complexity(self, func_info: Dict) -> float: + """Calculate complexity score for a function.""" + # Basic complexity calculation based on parameters and decorators + base_complexity = 1.0 + param_complexity = len(func_info.get('parameters', [])) * 0.2 + decorator_complexity = len(func_info.get('decorators', [])) * 0.1 + async_complexity = 0.3 if func_info.get('is_async', False) else 0.0 + + return base_complexity + param_complexity + decorator_complexity + async_complexity + + def _calculate_class_complexity(self, class_info: Dict) -> float: + """Calculate complexity score for a class.""" + base_complexity = 2.0 + base_complexity += len(class_info.get('bases', [])) * 0.5 + base_complexity += len(class_info.get('decorators', [])) * 0.2 + return base_complexity + + def _find_function_usages(self, function_name: str) -> List[Dict]: + """Find usages of a function across the codebase.""" + usages = [] + for file_info in self.files: + try: + content = file_info['content'] + lines = content.splitlines() + for i, line in enumerate(lines): + if function_name in line and 'def ' not in line: + usages.append({ + 'file_path': file_info['path'], + 'line_number': i + 1, + 'context': line.strip() + }) + except Exception: + continue + return usages + + def _find_symbol_usages(self, symbol_name: str) -> List[Dict]: + """Find usages of a symbol across the codebase.""" + return self._find_function_usages(symbol_name) + + +class LSPClient: + """JSON-RPC client for Language Server Protocol communication.""" + + def __init__(self, server_command: List[str], cwd: Optional[str] = None): + self.server_command = server_command + self.cwd = cwd or os.getcwd() + self.process = None + self.request_id = 0 + self.responses = {} + self.notifications = [] + self.diagnostics = {} + self._running = False + + def start(self): + """Start the language server process.""" + try: + self.process = subprocess.Popen( + 
self.server_command, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=self.cwd, + text=False + ) + self._running = True + + # Start reader thread + self.reader_thread = threading.Thread(target=self._read_messages) + self.reader_thread.daemon = True + self.reader_thread.start() + + # Initialize the server + self._initialize() + + except Exception as e: + logging.error(f"Failed to start LSP server: {e}") + return False + + return True + + def stop(self): + """Stop the language server process.""" + self._running = False + if self.process: + self.process.terminate() + self.process.wait() + + def _initialize(self): + """Initialize the language server.""" + init_params = { + "processId": os.getpid(), + "clientInfo": {"name": "AnalysisBackend", "version": "1.0.0"}, + "rootUri": f"file://{self.cwd}", + "capabilities": { + "textDocument": { + "publishDiagnostics": {"relatedInformation": True} + } + } + } + + response = self.send_request("initialize", init_params) + if response: + self.send_notification("initialized", {}) + + def send_request(self, method: str, params: Any) -> Optional[Dict]: + """Send a request to the language server.""" + if not self.process: + return None + + self.request_id += 1 + request = { + "jsonrpc": "2.0", + "id": self.request_id, + "method": method, + "params": params + } + + try: + message = self._create_message(request) + self.process.stdin.write(message) + self.process.stdin.flush() + + # Wait for response (simplified - in real implementation would be async) + time.sleep(0.1) + return self.responses.get(self.request_id) + + except Exception as e: + logging.error(f"Error sending request: {e}") + return None + + def send_notification(self, method: str, params: Any): + """Send a notification to the language server.""" + if not self.process: + return + + notification = { + "jsonrpc": "2.0", + "method": method, + "params": params + } + + try: + message = self._create_message(notification) + self.process.stdin.write(message) + self.process.stdin.flush() + except Exception as e: + logging.error(f"Error sending notification: {e}") + + def _create_message(self, payload: Dict) -> bytes: + """Create a properly formatted LSP message.""" + body = json.dumps(payload, separators=(',', ':')).encode('utf-8') + header = f"Content-Length: {len(body)}\r\n\r\n".encode('utf-8') + return header + body + + def _read_messages(self): + """Read messages from the language server.""" + while self._running and self.process: + try: + # Read header + header = b"" + while not header.endswith(b"\r\n\r\n"): + chunk = self.process.stdout.read(1) + if not chunk: + break + header += chunk + + if not header: + break + + # Parse content length + content_length = 0 + for line in header.decode('utf-8').split('\r\n'): + if line.startswith('Content-Length:'): + content_length = int(line.split(':')[1].strip()) + break + + if content_length == 0: + continue + + # Read body + body = self.process.stdout.read(content_length) + if not body: + break + + # Parse message + try: + message = json.loads(body.decode('utf-8')) + self._handle_message(message) + except json.JSONDecodeError as e: + logging.error(f"Failed to parse LSP message: {e}") + + except Exception as e: + logging.error(f"Error reading LSP messages: {e}") + break + + def _handle_message(self, message: Dict): + """Handle incoming message from language server.""" + if "id" in message: + # Response to our request + self.responses[message["id"]] = message + elif message.get("method") == "textDocument/publishDiagnostics": + # 
Diagnostic notification + params = message.get("params", {}) + uri = params.get("uri", "") + diagnostics = params.get("diagnostics", []) + self.diagnostics[uri] = diagnostics + else: + # Other notification + self.notifications.append(message) + + def get_diagnostics(self, file_path: str) -> List[Diagnostic]: + """Get diagnostics for a specific file.""" + file_uri = f"file://{os.path.abspath(file_path)}" + raw_diagnostics = self.diagnostics.get(file_uri, []) + + diagnostics = [] + for diag in raw_diagnostics: + try: + range_info = diag.get("range", {}) + start_pos = Position( + line=range_info.get("start", {}).get("line", 0), + character=range_info.get("start", {}).get("character", 0) + ) + end_pos = Position( + line=range_info.get("end", {}).get("line", 0), + character=range_info.get("end", {}).get("character", 0) + ) + + diagnostic = Diagnostic( + range=Range(start=start_pos, end=end_pos), + severity=DiagnosticSeverity(diag.get("severity", 1)), + code=diag.get("code"), + source=diag.get("source"), + message=diag.get("message", ""), + related_information=diag.get("relatedInformation"), + tags=diag.get("tags") + ) + diagnostics.append(diagnostic) + except Exception as e: + logging.warning(f"Failed to parse diagnostic: {e}") + + return diagnostics + + +class RuffIntegration: + """Integration with Ruff for comprehensive Python linting and type checking.""" + + def __init__(self, target_path: str): + self.target_path = target_path + self.config_file = self._find_ruff_config() + + def _find_ruff_config(self) -> Optional[str]: + """Find Ruff configuration file.""" + search_paths = [ + "ruff.toml", + "pyproject.toml", + ".ruff.toml" + ] + + current_dir = Path(self.target_path if os.path.isdir(self.target_path) else os.path.dirname(self.target_path)) + + while current_dir != current_dir.parent: + for config_file in search_paths: + config_path = current_dir / config_file + if config_path.exists(): + return str(config_path) + current_dir = current_dir.parent + + return None + + def run_ruff_check(self) -> List[AnalysisError]: + """Run Ruff check and parse results.""" + cmd = ["ruff", "check", "--output-format=json", self.target_path] + + if self.config_file: + cmd.extend(["--config", self.config_file]) + + try: + result = subprocess.run( + cmd, capture_output=True, text=True, timeout=300 + ) + + errors = [] + if result.stdout: + try: + ruff_results = json.loads(result.stdout) + for issue in ruff_results: + error = AnalysisError( + file_path=issue.get("filename", ""), + line_number=issue.get("location", {}).get("row"), + column_number=issue.get("location", {}).get("column"), + error_type="ruff", + severity=self._map_ruff_severity(issue.get("severity", "error")), + message=issue.get("message", ""), + tool_source="ruff", + error_code=issue.get("code"), + category=self._categorize_ruff_error(issue.get("code", "")), + fix_suggestion=issue.get("fix", {}).get("message") + ) + errors.append(error) + except json.JSONDecodeError: + # Fallback to text parsing + errors.extend(self._parse_ruff_text_output(result.stdout)) + + return errors + + except subprocess.TimeoutExpired: + return [AnalysisError( + file_path=self.target_path, + line_number=None, + column_number=None, + error_type="timeout", + severity=DiagnosticSeverity.ERROR, + message="Ruff analysis timed out", + tool_source="ruff" + )] + except Exception as e: + return [AnalysisError( + file_path=self.target_path, + line_number=None, + column_number=None, + error_type="execution_error", + severity=DiagnosticSeverity.ERROR, + message=f"Ruff execution 
failed: {str(e)}", + tool_source="ruff" + )] + + def run_ruff_format_check(self) -> List[AnalysisError]: + """Run Ruff format check.""" + cmd = ["ruff", "format", "--check", "--diff", self.target_path] + + try: + result = subprocess.run( + cmd, capture_output=True, text=True, timeout=120 + ) + + errors = [] + if result.returncode != 0 and result.stdout: + # Parse diff output for formatting issues + errors.append(AnalysisError( + file_path=self.target_path, + line_number=None, + column_number=None, + error_type="formatting", + severity=DiagnosticSeverity.WARNING, + message="Code formatting issues detected", + tool_source="ruff-format", + category=ErrorCategory.STYLE.value, + code_context=result.stdout[:500] # First 500 chars of diff + )) + + return errors + + except Exception as e: + return [] + + def _map_ruff_severity(self, severity: str) -> DiagnosticSeverity: + """Map Ruff severity to diagnostic severity.""" + mapping = { + "error": DiagnosticSeverity.ERROR, + "warning": DiagnosticSeverity.WARNING, + "info": DiagnosticSeverity.INFORMATION, + "hint": DiagnosticSeverity.HINT + } + return mapping.get(severity.lower(), DiagnosticSeverity.WARNING) + + def _categorize_ruff_error(self, code: str) -> str: + """Categorize Ruff error based on error code.""" + if not code: + return ErrorCategory.STYLE.value + + code_prefix = code.split()[0] if ' ' in code else code[:3] + + category_mapping = { + "E": ErrorCategory.STYLE.value, + "W": ErrorCategory.STYLE.value, + "F": ErrorCategory.LOGIC.value, + "C": ErrorCategory.DESIGN.value, + "N": ErrorCategory.STYLE.value, + "B": ErrorCategory.LOGIC.value, + "A": ErrorCategory.COMPATIBILITY.value, + "COM": ErrorCategory.STYLE.value, + "CPY": ErrorCategory.DOCUMENTATION.value, + "DJ": ErrorCategory.DESIGN.value, + "EM": ErrorCategory.MAINTAINABILITY.value, + "EXE": ErrorCategory.SECURITY.value, + "FA": ErrorCategory.COMPATIBILITY.value, + "FBT": ErrorCategory.DESIGN.value, + "FLY": ErrorCategory.PERFORMANCE.value, + "FURB": ErrorCategory.MAINTAINABILITY.value, + "G": ErrorCategory.STYLE.value, + "I": ErrorCategory.IMPORT.value, + "ICN": ErrorCategory.STYLE.value, + "INP": ErrorCategory.IMPORT.value, + "INT": ErrorCategory.TYPE.value, + "ISC": ErrorCategory.STYLE.value, + "LOG": ErrorCategory.MAINTAINABILITY.value, + "NPY": ErrorCategory.PERFORMANCE.value, + "PD": ErrorCategory.PERFORMANCE.value, + "PERF": ErrorCategory.PERFORMANCE.value, + "PGH": ErrorCategory.MAINTAINABILITY.value, + "PIE": ErrorCategory.LOGIC.value, + "PL": ErrorCategory.LOGIC.value, + "PT": ErrorCategory.TESTING.value, + "PTH": ErrorCategory.COMPATIBILITY.value, + "PYI": ErrorCategory.TYPE.value, + "Q": ErrorCategory.STYLE.value, + "RET": ErrorCategory.LOGIC.value, + "RSE": ErrorCategory.LOGIC.value, + "RUF": ErrorCategory.MAINTAINABILITY.value, + "S": ErrorCategory.SECURITY.value, + "SIM": ErrorCategory.MAINTAINABILITY.value, + "SLF": ErrorCategory.DESIGN.value, + "SLOT": ErrorCategory.PERFORMANCE.value, + "T": ErrorCategory.STYLE.value, + "TCH": ErrorCategory.TYPE.value, + "TD": ErrorCategory.DOCUMENTATION.value, + "TID": ErrorCategory.IMPORT.value, + "TRY": ErrorCategory.LOGIC.value, + "UP": ErrorCategory.COMPATIBILITY.value, + "YTT": ErrorCategory.COMPATIBILITY.value + } + + for prefix, category in category_mapping.items(): + if code.startswith(prefix): + return category + + return ErrorCategory.STYLE.value + + def _parse_ruff_text_output(self, output: str) -> List[AnalysisError]: + """Parse Ruff text output when JSON is not available.""" + errors = [] + for line in output.splitlines(): 
+ if ':' in line and '.py:' in line: + try: + parts = line.split(':') + if len(parts) >= 4: + file_path = parts[0] + line_num = int(parts[1]) + col_num = int(parts[2]) + message = ':'.join(parts[3:]).strip() + + # Extract error code if present + code_match = re.search(r'\[([A-Z0-9]+)\]', message) + error_code = code_match.group(1) if code_match else None + + error = AnalysisError( + file_path=file_path, + line_number=line_num, + column_number=col_num, + error_type="ruff", + severity=DiagnosticSeverity.WARNING, + message=message, + tool_source="ruff", + error_code=error_code, + category=self._categorize_ruff_error(error_code or "") + ) + errors.append(error) + except (ValueError, IndexError): + continue + + return errors + + +class ComprehensiveAnalyzer: + """Main analyzer class that orchestrates all analysis tools and provides comprehensive error reporting.""" + + # Extended tool configuration with categorization + COMPREHENSIVE_TOOLS = { + # Core Python linting and type checking + "ruff": ToolConfig( + "ruff", "ruff", + args=["check", "--output-format=json", "--select=ALL"], + category=ErrorCategory.STYLE.value, + priority=1 + ), + "mypy": ToolConfig( + "mypy", "mypy", + args=["--strict", "--show-error-codes", "--show-column-numbers", "--pretty"], + category=ErrorCategory.TYPE.value, + priority=1 + ), + "pyright": ToolConfig( + "pyright", "pyright", + args=["--outputjson"], + category=ErrorCategory.TYPE.value, + priority=1, + timeout=600 + ), + "pylint": ToolConfig( + "pylint", "pylint", + args=["--output-format=json", "--reports=y", "--score=y"], + category=ErrorCategory.LOGIC.value, + priority=2 + ), + + # Security analysis + "bandit": ToolConfig( + "bandit", "bandit", + args=["-r", "-f", "json", "--severity-level=low", "--confidence-level=low"], + category=ErrorCategory.SECURITY.value, + priority=1 + ), + "safety": ToolConfig( + "safety", "safety", + args=["check", "--json", "--full-report"], + category=ErrorCategory.SECURITY.value, + priority=1, + requires_network=True + ), + "semgrep": ToolConfig( + "semgrep", "semgrep", + args=["--config=p/python", "--json", "--severity=WARNING"], + category=ErrorCategory.SECURITY.value, + priority=2, + requires_network=True + ), + + # Code quality and complexity + "radon": ToolConfig( + "radon", "radon", + args=["cc", "-j", "--total-average"], + category=ErrorCategory.MAINTAINABILITY.value, + priority=2 + ), + "xenon": ToolConfig( + "xenon", "xenon", + args=["--max-absolute=B", "--max-modules=B", "--max-average=C"], + category=ErrorCategory.MAINTAINABILITY.value, + priority=2 + ), + "cohesion": ToolConfig( + "cohesion", "cohesion", + args=["--below", "80", "--format", "json"], + category=ErrorCategory.DESIGN.value, + priority=3 + ), + + # Import and dependency analysis + "isort": ToolConfig( + "isort", "isort", + args=["--check-only", "--diff", "--profile=black"], + category=ErrorCategory.IMPORT.value, + priority=2 + ), + "vulture": ToolConfig( + "vulture", "vulture", + args=["--min-confidence=60", "--sort-by-size"], + category=ErrorCategory.LOGIC.value, + priority=2 + ), + "pydeps": ToolConfig( + "pydeps", "pydeps", + args=["--max-bacon=3", "--show-cycles", "--format", "json"], + category=ErrorCategory.DEPENDENCIES.value, + priority=3 + ), + + # Style and formatting + "black": ToolConfig( + "black", "black", + args=["--check", "--diff"], + category=ErrorCategory.STYLE.value, + priority=2 + ), + "pycodestyle": ToolConfig( + "pycodestyle", "pycodestyle", + args=["--statistics", "--count"], + category=ErrorCategory.STYLE.value, + priority=3 + ), + 
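+        # Docstring conventions (PEP 257), checked with the Google convention preset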
"pydocstyle": ToolConfig( + "pydocstyle", "pydocstyle", + args=["--convention=google"], + category=ErrorCategory.DOCUMENTATION.value, + priority=2 + ), + + # Performance analysis + "py-spy": ToolConfig( + "py-spy", "py-spy", + args=["record", "-o", "/tmp/profile.svg", "--", "python"], + category=ErrorCategory.PERFORMANCE.value, + priority=3, + enabled=False # Requires special setup + ), + + # Testing analysis + "pytest": ToolConfig( + "pytest", "pytest", + args=["--collect-only", "--quiet"], + category=ErrorCategory.TESTING.value, + priority=3 + ), + "coverage": ToolConfig( + "coverage", "coverage", + args=["report", "--format=json"], + category=ErrorCategory.TESTING.value, + priority=3 + ), + + # Compatibility analysis + "pyupgrade": ToolConfig( + "pyupgrade", "pyupgrade", + args=["--py312-plus"], + category=ErrorCategory.COMPATIBILITY.value, + priority=3 + ), + "modernize": ToolConfig( + "modernize", "python-modernize", + args=["--print", "--no-diffs"], + category=ErrorCategory.COMPATIBILITY.value, + priority=3 + ), + + # Additional analysis tools + "pyflakes": ToolConfig( + "pyflakes", "pyflakes", + category=ErrorCategory.LOGIC.value, + priority=2 + ), + "mccabe": ToolConfig( + "mccabe", "python", + args=["-m", "mccabe", "--min", "5"], + category=ErrorCategory.MAINTAINABILITY.value, + priority=3 + ), + "dodgy": ToolConfig( + "dodgy", "dodgy", + args=["--ignore-paths=venv,.venv,env,.env,__pycache__,.git"], + category=ErrorCategory.SECURITY.value, + priority=3 + ), + "dlint": ToolConfig( + "dlint", "dlint", + category=ErrorCategory.SECURITY.value, + priority=3 + ), + "codespell": ToolConfig( + "codespell", "codespell", + args=["--quiet-level=2"], + category=ErrorCategory.DOCUMENTATION.value, + priority=3 + ), + + # Advanced analysis + "prospector": ToolConfig( + "prospector", "prospector", + args=["--output-format=json", "--full-pep8"], + category=ErrorCategory.LOGIC.value, + priority=2 + ), + "pyanalyze": ToolConfig( + "pyanalyze", "python", + args=["-m", "pyanalyze"], + category=ErrorCategory.TYPE.value, + priority=3 + ) + } + + def __init__(self, target_path: str, config: Optional[Dict] = None, verbose: bool = False): + """Initialize the comprehensive analyzer.""" + self.target_path = os.path.abspath(target_path) + self.config = config or {} + self.verbose = verbose + self.tools_config = self.COMPREHENSIVE_TOOLS.copy() + self.graph_analysis = None + self.lsp_client = None + self.ruff_integration = None + self.analysis_cache = {} + self.error_database = ErrorDatabase() + + # Initialize components + self._initialize_graph_analysis() + self._initialize_lsp_client() + self._initialize_ruff_integration() + + # Apply configuration + self._apply_config() + + def _initialize_graph_analysis(self): + """Initialize graph-sitter analysis.""" + try: + self.graph_analysis = GraphSitterAnalysis(self.target_path, self.config) + if self.verbose: + print("Graph-sitter analysis initialized successfully") + except Exception as e: + logging.error(f"Failed to initialize graph-sitter analysis: {e}") + self.graph_analysis = None + + def _initialize_lsp_client(self): + """Initialize LSP client for real-time error detection.""" + # Try to find and start a Python language server + lsp_commands = [ + ["pylsp"], # Python LSP Server + ["pyright-langserver", "--stdio"], # Pyright + ["jedi-language-server"] # Jedi + ] + + for cmd in lsp_commands: + try: + if self._command_exists(cmd[0]): + self.lsp_client = LSPClient(cmd, self.target_path) + if self.lsp_client.start(): + if self.verbose: + print(f"LSP client started 
with {cmd[0]}") + break + else: + self.lsp_client = None + except Exception as e: + if self.verbose: + print(f"Failed to start LSP with {cmd[0]}: {e}") + continue + + def _initialize_ruff_integration(self): + """Initialize Ruff integration.""" + try: + self.ruff_integration = RuffIntegration(self.target_path) + if self.verbose: + print("Ruff integration initialized") + except Exception as e: + logging.error(f"Failed to initialize Ruff integration: {e}") + + def _command_exists(self, command: str) -> bool: + """Check if a command exists in the system.""" + try: + subprocess.run( + ["which", command], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True + ) + return True + except subprocess.CalledProcessError: + return False + + def _apply_config(self): + """Apply configuration settings.""" + if "tools" in self.config: + for tool_name, tool_config in self.config["tools"].items(): + if tool_name in self.tools_config: + self.tools_config[tool_name].enabled = tool_config.get("enabled", True) + if "args" in tool_config: + self.tools_config[tool_name].args = tool_config["args"] + if "timeout" in tool_config: + self.tools_config[tool_name].timeout = tool_config["timeout"] + + def run_comprehensive_analysis(self) -> Dict[str, Any]: + """Run comprehensive analysis using all available tools and methods.""" + start_time = time.time() + + analysis_results = { + "metadata": { + "target_path": self.target_path, + "analysis_start_time": time.strftime("%Y-%m-%d %H:%M:%S"), + "python_version": sys.version, + "tools_used": [], + "analysis_duration": 0 + }, + "graph_sitter_analysis": {}, + "lsp_diagnostics": {}, + "static_analysis": {}, + "errors": [], + "summary": {}, + "entrypoints": [], + "dead_code": [], + "categorized_errors": {}, + "performance_metrics": {}, + "security_issues": [], + "dependency_analysis": {}, + "quality_metrics": {} + } + + try: + # 1. Graph-sitter Analysis + if self.graph_analysis: + analysis_results["graph_sitter_analysis"] = self._run_graph_sitter_analysis() + analysis_results["metadata"]["tools_used"].append("graph-sitter") + + # 2. LSP Diagnostics + if self.lsp_client: + analysis_results["lsp_diagnostics"] = self._run_lsp_analysis() + analysis_results["metadata"]["tools_used"].append("lsp") + + # 3. Ruff Integration + if self.ruff_integration: + ruff_errors = self._run_ruff_analysis() + analysis_results["errors"].extend(ruff_errors) + analysis_results["metadata"]["tools_used"].append("ruff") + + # 4. Static Analysis Tools + static_analysis_results = self._run_static_analysis() + analysis_results["static_analysis"] = static_analysis_results + analysis_results["errors"].extend(self._extract_errors_from_static_analysis(static_analysis_results)) + + # 5. Advanced Analysis + analysis_results["entrypoints"] = self._find_entrypoints() + analysis_results["dead_code"] = self._find_dead_code() + analysis_results["dependency_analysis"] = self._analyze_dependencies() + analysis_results["performance_metrics"] = self._calculate_performance_metrics() + analysis_results["quality_metrics"] = self._calculate_quality_metrics() + + # 6. Error Categorization and Prioritization + analysis_results["categorized_errors"] = self._categorize_errors(analysis_results["errors"]) + + # 7. 
Generate Summary + analysis_results["summary"] = self._generate_comprehensive_summary(analysis_results) + + except Exception as e: + logging.error(f"Error during comprehensive analysis: {e}") + analysis_results["fatal_error"] = str(e) + + finally: + # Cleanup + if self.lsp_client: + self.lsp_client.stop() + + # Calculate duration + end_time = time.time() + analysis_results["metadata"]["analysis_duration"] = round(end_time - start_time, 2) + analysis_results["metadata"]["analysis_end_time"] = time.strftime("%Y-%m-%d %H:%M:%S") + + return analysis_results + + def _run_graph_sitter_analysis(self) -> Dict[str, Any]: + """Run comprehensive graph-sitter analysis.""" + if not self.graph_analysis: + return {"error": "Graph-sitter analysis not available"} + + try: + results = { + "codebase_summary": self.graph_analysis.get_codebase_summary(), + "files_analysis": {}, + "symbols_analysis": {}, + "dependency_graph": {}, + "inheritance_hierarchy": {}, + "call_graph": {}, + "import_graph": {} + } + + # Analyze each file + for file_info in self.graph_analysis.files[:50]: # Limit for performance + file_path = file_info.get('path', '') if isinstance(file_info, dict) else getattr(file_info, 'path', '') + if file_path: + results["files_analysis"][file_path] = self.graph_analysis.get_file_summary(file_path) + + # Analyze symbols + for symbol_info in self.graph_analysis.symbols[:100]: # Limit for performance + symbol_name = symbol_info.get('name', '') if isinstance(symbol_info, dict) else getattr(symbol_info, 'name', '') + if symbol_name: + symbol_analysis = self.graph_analysis.get_symbol_analysis(symbol_name) + results["symbols_analysis"][symbol_name] = symbol_analysis + + # Additional analysis based on symbol type + if symbol_analysis.get('type') == 'function': + func_analysis = self.graph_analysis.get_function_analysis(symbol_name) + results["symbols_analysis"][symbol_name].update(func_analysis) + elif symbol_analysis.get('type') == 'class': + class_analysis = self.graph_analysis.get_class_analysis(symbol_name) + results["symbols_analysis"][symbol_name].update(class_analysis) + + # Build dependency graph + results["dependency_graph"] = self._build_dependency_graph() + + # Build inheritance hierarchy + results["inheritance_hierarchy"] = self._build_inheritance_hierarchy() + + # Build call graph + results["call_graph"] = self._build_call_graph() + + # Build import graph + results["import_graph"] = self._build_import_graph() + + return results + + except Exception as e: + logging.error(f"Error in graph-sitter analysis: {e}") + return {"error": str(e)} + + def _run_lsp_analysis(self) -> Dict[str, Any]: + """Run LSP analysis for real-time error detection.""" + if not self.lsp_client: + return {"error": "LSP client not available"} + + try: + results = { + "diagnostics_by_file": {}, + "total_diagnostics": 0, + "error_summary": {} + } + + # Get diagnostics for each Python file + if os.path.isfile(self.target_path): + files_to_analyze = [self.target_path] + else: + files_to_analyze = [] + for root, dirs, files in os.walk(self.target_path): + for file in files: + if file.endswith('.py'): + files_to_analyze.append(os.path.join(root, file)) + + for file_path in files_to_analyze[:20]: # Limit for performance + # Send textDocument/didOpen notification + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + self.lsp_client.send_notification("textDocument/didOpen", { + "textDocument": { + "uri": f"file://{file_path}", + "languageId": "python", + "version": 1, + "text": content + } + }) + + # Wait a bit 
for diagnostics + time.sleep(0.5) + + # Get diagnostics + diagnostics = self.lsp_client.get_diagnostics(file_path) + if diagnostics: + results["diagnostics_by_file"][file_path] = [ + { + "range": { + "start": {"line": d.range.start.line, "character": d.range.start.character}, + "end": {"line": d.range.end.line, "character": d.range.end.character} + }, + "severity": d.severity.value, + "message": d.message, + "source": d.source, + "code": d.code + } + for d in diagnostics + ] + results["total_diagnostics"] += len(diagnostics) + + # Generate error summary + results["error_summary"] = self._summarize_lsp_diagnostics(results["diagnostics_by_file"]) + + return results + + except Exception as e: + logging.error(f"Error in LSP analysis: {e}") + return {"error": str(e)} + + def _run_ruff_analysis(self) -> List[AnalysisError]: + """Run comprehensive Ruff analysis.""" + if not self.ruff_integration: + return [] + + errors = [] + + try: + # Run standard Ruff check + ruff_errors = self.ruff_integration.run_ruff_check() + errors.extend(ruff_errors) + + # Run Ruff format check + format_errors = self.ruff_integration.run_ruff_format_check() + errors.extend(format_errors) + + # Additional Ruff configurations for comprehensive analysis + additional_checks = [ + ["--select=ALL", "--ignore=COM812,ISC001"], # All rules except conflicting ones + ["--select=E,W,F", "--statistics"], # Core errors and warnings + ["--select=I", "--show-fixes"], # Import sorting + ["--select=N", "--show-source"], # Naming conventions + ["--select=S", "--show-source"], # Security + ["--select=B", "--show-source"], # Bugbear + ["--select=C90", "--show-source"], # Complexity + ["--select=PL", "--show-source"], # Pylint + ["--select=RUF", "--show-source"] # Ruff-specific + ] + + for check_args in additional_checks: + try: + cmd = ["ruff", "check"] + check_args + [self.target_path] + result = subprocess.run( + cmd, capture_output=True, text=True, timeout=120 + ) + + if result.stdout: + additional_errors = self._parse_ruff_output(result.stdout, check_args[0]) + errors.extend(additional_errors) + + except Exception as e: + if self.verbose: + print(f"Additional Ruff check failed: {e}") + continue + + except Exception as e: + logging.error(f"Error in Ruff analysis: {e}") + + return errors + + def _run_static_analysis(self) -> Dict[str, Any]: + """Run all configured static analysis tools.""" + results = {} + + # Group tools by category for organized execution + tools_by_category = defaultdict(list) + for tool_name, tool_config in self.tools_config.items(): + if tool_config.enabled: + tools_by_category[tool_config.category].append((tool_name, tool_config)) + + # Run tools in parallel within each category + for category, tools in tools_by_category.items(): + if self.verbose: + print(f"Running {category} analysis tools...") + + category_results = {} + + with ThreadPoolExecutor(max_workers=3) as executor: + future_to_tool = { + executor.submit(self._run_single_tool, tool_name, tool_config): tool_name + for tool_name, tool_config in tools + } + + for future in as_completed(future_to_tool): + tool_name = future_to_tool[future] + try: + tool_result = future.result() + category_results[tool_name] = tool_result + except Exception as e: + category_results[tool_name] = { + "error": str(e), + "success": False + } + + results[category] = category_results + + return results + + def _run_single_tool(self, tool_name: str, tool_config: ToolConfig) -> Dict[str, Any]: + """Run a single analysis tool.""" + if not 
self._command_exists(tool_config.command.split()[0]): + return { + "error": f"Tool {tool_name} not found", + "success": False, + "skipped": True + } + + # Build command + cmd = [tool_config.command] + tool_config.args + [self.target_path] + if tool_config.command == "python": + # Module-based tools still need the interpreter on the command line + cmd = [sys.executable] + tool_config.args + [self.target_path] + + # Find configuration file + config_file = self._find_tool_config(tool_name, tool_config) + if config_file: + cmd = self._add_config_to_command(cmd, tool_name, config_file) + + try: + start_time = time.time() + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=tool_config.timeout, + cwd=os.path.dirname(self.target_path) if os.path.isfile(self.target_path) else self.target_path + ) + + execution_time = time.time() - start_time + + return { + "command": " ".join(cmd), + "returncode": result.returncode, + "stdout": result.stdout, + "stderr": result.stderr, + "success": result.returncode == 0, + "execution_time": execution_time, + "tool_category": tool_config.category, + "tool_priority": tool_config.priority + } + + except subprocess.TimeoutExpired: + return { + "error": f"Tool {tool_name} timed out after {tool_config.timeout}s", + "success": False, + "timeout": True, + "tool_category": tool_config.category + } + except Exception as e: + return { + "error": str(e), + "success": False, + "tool_category": tool_config.category + } + + def _find_tool_config(self, tool_name: str, tool_config: ToolConfig) -> Optional[str]: + """Find configuration file for a tool.""" + if not tool_config.config_file: + return None + + search_dir = self.target_path if os.path.isdir(self.target_path) else os.path.dirname(self.target_path) + current_dir = Path(search_dir) + + while current_dir != current_dir.parent: + config_path = current_dir / tool_config.config_file + if config_path.exists(): + return str(config_path) + current_dir = current_dir.parent + + return None + + def _add_config_to_command(self, cmd: List[str], tool_name: str, config_file: str) -> List[str]: + """Add configuration file to command.""" + config_mappings = { + "pylint": f"--rcfile={config_file}", + "mypy": f"--config-file={config_file}", + "flake8": f"--config={config_file}", + "bandit": f"--configfile={config_file}" + } + + if tool_name in config_mappings: + cmd.insert(-1, config_mappings[tool_name]) + + return cmd + + def _extract_errors_from_static_analysis(self, static_results: Dict) -> List[AnalysisError]: + """Extract and convert static analysis results to AnalysisError objects.""" + errors = [] + + for category, tools in static_results.items(): + for tool_name, tool_result in tools.items(): + if not tool_result.get("success", False) or tool_result.get("stdout", "").strip(): + errors.extend(self._parse_tool_output(tool_name, tool_result, category)) + + return errors + + def _parse_tool_output(self, tool_name: str, tool_result: Dict, category: str) -> List[AnalysisError]: + """Parse tool output and convert to AnalysisError objects.""" + errors = [] + + if tool_result.get("timeout"): + errors.append(AnalysisError( + file_path=self.target_path, + line_number=None, + column_number=None, + error_type="timeout", + severity=DiagnosticSeverity.ERROR, + message=f"{tool_name} analysis timed out", + tool_source=tool_name, + category=category + )) + return errors + + output = tool_result.get("stdout", "") + tool_result.get("stderr", "") + + # Tool-specific parsing + if tool_name == "pylint": + errors.extend(self._parse_pylint_output(output, tool_name, category)) + elif tool_name == "mypy": +
errors.extend(self._parse_mypy_output(output, tool_name, category)) + elif tool_name == "pyright": + errors.extend(self._parse_pyright_output(output, tool_name, category)) + elif tool_name == "bandit": + errors.extend(self._parse_bandit_output(output, tool_name, category)) + elif tool_name == "safety": + errors.extend(self._parse_safety_output(output, tool_name, category)) + elif tool_name == "vulture": + errors.extend(self._parse_vulture_output(output, tool_name, category)) + elif tool_name == "radon": + errors.extend(self._parse_radon_output(output, tool_name, category)) + else: + # Generic parsing + errors.extend(self._parse_generic_output(output, tool_name, category)) + + return errors + + def _parse_pylint_output(self, output: str, tool_name: str, category: str) -> List[AnalysisError]: + """Parse Pylint output.""" + errors = [] + + # Try JSON format first + try: + pylint_data = json.loads(output) + for item in pylint_data: + if isinstance(item, dict) and "path" in item: + error = AnalysisError( + file_path=item["path"], + line_number=item.get("line"), + column_number=item.get("column"), + error_type=item.get("type", "unknown"), + severity=self._map_pylint_severity(item.get("type", "")), + message=item.get("message", ""), + tool_source=tool_name, + error_code=item.get("message-id"), + category=self._categorize_pylint_error(item.get("message-id", "")), + confidence=item.get("confidence", 1.0) / 10.0 # Convert to 0-1 scale + ) + errors.append(error) + except json.JSONDecodeError: + # Fallback to text parsing + for line in output.splitlines(): + if ".py:" in line and ": " in line: + match = re.match(r"(.+?):(\d+):(\d+): (.+?): (.+)", line) + if match: + file_path, line_num, col_num, msg_type, message = match.groups() + error = AnalysisError( + file_path=file_path, + line_number=int(line_num), + column_number=int(col_num), + error_type=msg_type, + severity=self._map_pylint_severity(msg_type), + message=message, + tool_source=tool_name, + category=category + ) + errors.append(error) + + return errors + + def _parse_mypy_output(self, output: str, tool_name: str, category: str) -> List[AnalysisError]: + """Parse MyPy output.""" + errors = [] + + for line in output.splitlines(): + if ".py:" in line and ": " in line: + # MyPy format: file.py:line:column: error: message [error-code] + match = re.match(r"(.+?):(\d+):(?:(\d+):)?\s*(\w+):\s*(.+?)(?:\s*\[(.+?)\])?$", line) + if match: + file_path, line_num, col_num, severity, message, error_code = match.groups() + + error = AnalysisError( + file_path=file_path, + line_number=int(line_num), + column_number=int(col_num) if col_num else None, + error_type=severity, + severity=self._map_mypy_severity(severity), + message=message, + tool_source=tool_name, + error_code=error_code, + category=ErrorCategory.TYPE.value, + subcategory=self._categorize_mypy_error(error_code or message) + ) + errors.append(error) + + return errors + + def _parse_pyright_output(self, output: str, tool_name: str, category: str) -> List[AnalysisError]: + """Parse Pyright JSON output.""" + errors = [] + + try: + pyright_data = json.loads(output) + + for diagnostic in pyright_data.get("generalDiagnostics", []): + error = AnalysisError( + file_path=diagnostic.get("file", ""), + line_number=diagnostic.get("range", {}).get("start", {}).get("line"), + column_number=diagnostic.get("range", {}).get("start", {}).get("character"), + error_type=diagnostic.get("severity", "error"), + severity=self._map_pyright_severity(diagnostic.get("severity", "error")), + message=diagnostic.get("message", 
""), + tool_source=tool_name, + error_code=diagnostic.get("rule"), + category=ErrorCategory.TYPE.value + ) + errors.append(error) + + except json.JSONDecodeError: + # Fallback to text parsing + for line in output.splitlines(): + if " - error:" in line or " - warning:" in line: + parts = line.split(" - ") + if len(parts) >= 2: + location = parts[0].strip() + severity_message = parts[1] + + # Extract file path and line number + location_match = re.match(r"(.+?):(\d+):(\d+)", location) + if location_match: + file_path, line_num, col_num = location_match.groups() + + severity = "error" if "error:" in severity_message else "warning" + message = severity_message.split(":", 1)[1].strip() if ":" in severity_message else severity_message + + error = AnalysisError( + file_path=file_path, + line_number=int(line_num), + column_number=int(col_num), + error_type=severity, + severity=DiagnosticSeverity.ERROR if severity == "error" else DiagnosticSeverity.WARNING, + message=message, + tool_source=tool_name, + category=ErrorCategory.TYPE.value + ) + errors.append(error) + + return errors + + def _parse_bandit_output(self, output: str, tool_name: str, category: str) -> List[AnalysisError]: + """Parse Bandit security analysis output.""" + errors = [] + + try: + bandit_data = json.loads(output) + + for result in bandit_data.get("results", []): + error = AnalysisError( + file_path=result.get("filename", ""), + line_number=result.get("line_number"), + column_number=result.get("col_offset"), + error_type="security", + severity=self._map_bandit_severity(result.get("issue_severity", "MEDIUM")), + message=result.get("issue_text", ""), + tool_source=tool_name, + error_code=result.get("test_id"), + category=ErrorCategory.SECURITY.value, + subcategory=result.get("issue_type", "unknown"), + confidence=self._map_bandit_confidence(result.get("issue_confidence", "MEDIUM")), + code_context=result.get("code") + ) + errors.append(error) + + except json.JSONDecodeError: + # Fallback to text parsing + current_file = None + for line in output.splitlines(): + if "Test results:" in line: + break + + if line.startswith(">>"): + # File indicator + current_file = line.replace(">>", "").strip() + elif "Issue:" in line and current_file: + # Parse issue line + issue_match = re.search(r"Issue: \[(.+?)\] (.+)", line) + if issue_match: + severity, message = issue_match.groups() + + error = AnalysisError( + file_path=current_file, + line_number=None, + column_number=None, + error_type="security", + severity=self._map_bandit_severity(severity), + message=message, + tool_source=tool_name, + category=ErrorCategory.SECURITY.value + ) + errors.append(error) + + return errors + + def _parse_safety_output(self, output: str, tool_name: str, category: str) -> List[AnalysisError]: + """Parse Safety vulnerability output.""" + errors = [] + + try: + # Safety can output JSON or text + if output.strip().startswith('[') or output.strip().startswith('{'): + safety_data = json.loads(output) + + vulnerabilities = safety_data if isinstance(safety_data, list) else safety_data.get("vulnerabilities", []) + + for vuln in vulnerabilities: + error = AnalysisError( + file_path=self.target_path, + line_number=None, + column_number=None, + error_type="vulnerability", + severity=DiagnosticSeverity.ERROR, + message=f"Vulnerability in {vuln.get('package_name', 'unknown')}: {vuln.get('advisory', '')}", + tool_source=tool_name, + error_code=vuln.get("vulnerability_id"), + category=ErrorCategory.SECURITY.value, + subcategory="dependency_vulnerability", + 
fix_suggestion=f"Upgrade to version {vuln.get('fixed_in', 'latest')}" + ) + errors.append(error) + else: + # Text parsing + for line in output.splitlines(): + if "vulnerability" in line.lower() or "insecure" in line.lower(): + error = AnalysisError( + file_path=self.target_path, + line_number=None, + column_number=None, + error_type="vulnerability", + severity=DiagnosticSeverity.ERROR, + message=line.strip(), + tool_source=tool_name, + category=ErrorCategory.SECURITY.value + ) + errors.append(error) + + except json.JSONDecodeError: + pass + + return errors + + def _parse_vulture_output(self, output: str, tool_name: str, category: str) -> List[AnalysisError]: + """Parse Vulture dead code output.""" + errors = [] + + for line in output.splitlines(): + if ".py:" in line: + # Vulture format: file.py:line: unused function/class/variable 'name' + match = re.match(r"(.+?):(\d+):\s*(.+)", line) + if match: + file_path, line_num, message = match.groups() + + error = AnalysisError( + file_path=file_path, + line_number=int(line_num), + column_number=None, + error_type="dead_code", + severity=DiagnosticSeverity.WARNING, + message=message, + tool_source=tool_name, + category=ErrorCategory.LOGIC.value, + subcategory="unused_code", + tags=["dead_code", "optimization"] + ) + errors.append(error) + + return errors + + def _parse_radon_output(self, output: str, tool_name: str, category: str) -> List[AnalysisError]: + """Parse Radon complexity output.""" + errors = [] + + try: + # Try JSON format + if output.strip().startswith('{'): + radon_data = json.loads(output) + + for file_path, metrics in radon_data.items(): + for item in metrics: + complexity = item.get("complexity", 0) + if complexity > 10: # High complexity threshold + error = AnalysisError( + file_path=file_path, + line_number=item.get("lineno"), + column_number=item.get("col_offset"), + error_type="complexity", + severity=DiagnosticSeverity.WARNING if complexity < 20 else DiagnosticSeverity.ERROR, + message=f"High complexity ({complexity}) in {item.get('name', 'unknown')}", + tool_source=tool_name, + category=ErrorCategory.MAINTAINABILITY.value, + subcategory="complexity" + ) + errors.append(error) + else: + # Text parsing + for line in output.splitlines(): + if " - " in line and ("A " in line or "B " in line or "C " in line): + complexity_match = re.search(r"([A-F])\s*\((\d+)\)", line) + if complexity_match: + grade, complexity = complexity_match.groups() + complexity_value = int(complexity) + + if complexity_value > 10: + error = AnalysisError( + file_path=self.target_path, + line_number=None, + column_number=None, + error_type="complexity", + severity=DiagnosticSeverity.WARNING if complexity_value < 20 else DiagnosticSeverity.ERROR, + message=f"High complexity ({complexity_value}, grade {grade}): {line.split(' - ')[0]}", + tool_source=tool_name, + category=ErrorCategory.MAINTAINABILITY.value + ) + errors.append(error) + + except json.JSONDecodeError: + pass + + return errors + + def _parse_generic_output(self, output: str, tool_name: str, category: str) -> List[AnalysisError]: + """Generic parser for tool output.""" + errors = [] + + # Look for common error patterns + error_patterns = [ + r"(.+?):(\d+):(\d+):\s*(.+)", # file:line:col: message + r"(.+?):(\d+):\s*(.+)", # file:line: message + r"(.+?)\s*(.+?)\s*(.+)" # Generic three-part pattern + ] + + for line in output.splitlines(): + if not line.strip() or line.startswith('#'): + continue + + for pattern in error_patterns: + match = re.match(pattern, line) + if match: + groups = 
match.groups() + + if len(groups) >= 3 and '.py' in groups[0]: + try: + file_path = groups[0] + line_num = int(groups[1]) if groups[1].isdigit() else None + col_num = int(groups[2]) if len(groups) > 3 and groups[2].isdigit() else None + message = groups[-1] + + error = AnalysisError( + file_path=file_path, + line_number=line_num, + column_number=col_num, + error_type="unknown", + severity=DiagnosticSeverity.WARNING, + message=message, + tool_source=tool_name, + category=category + ) + errors.append(error) + break + except (ValueError, IndexError): + continue + + return errors + + def _parse_ruff_output(self, output: str, select_arg: str) -> List[AnalysisError]: + """Parse additional Ruff output.""" + errors = [] + + try: + if output.strip().startswith('['): + ruff_data = json.loads(output) + for issue in ruff_data: + error = AnalysisError( + file_path=issue.get("filename", ""), + line_number=issue.get("location", {}).get("row"), + column_number=issue.get("location", {}).get("column"), + error_type="ruff_extended", + severity=self._map_ruff_severity(issue.get("severity", "warning")), + message=issue.get("message", ""), + tool_source="ruff", + error_code=issue.get("code"), + category=self._categorize_ruff_error(issue.get("code", "")), + fix_suggestion=issue.get("fix", {}).get("message"), + tags=[select_arg.replace("--select=", "")] + ) + errors.append(error) + except json.JSONDecodeError: + # Text parsing for additional ruff output + for line in output.splitlines(): + if '.py:' in line and ':' in line: + parts = line.split(':') + if len(parts) >= 4: + try: + file_path = parts[0] + line_num = int(parts[1]) + col_num = int(parts[2]) + message = ':'.join(parts[3:]).strip() + + error = AnalysisError( + file_path=file_path, + line_number=line_num, + column_number=col_num, + error_type="ruff_extended", + severity=DiagnosticSeverity.WARNING, + message=message, + tool_source="ruff", + category=ErrorCategory.STYLE.value, + tags=[select_arg.replace("--select=", "")] + ) + errors.append(error) + except (ValueError, IndexError): + continue + + return errors + + def _map_pylint_severity(self, msg_type: str) -> DiagnosticSeverity: + """Map Pylint message type to diagnostic severity.""" + mapping = { + "error": DiagnosticSeverity.ERROR, + "warning": DiagnosticSeverity.WARNING, + "refactor": DiagnosticSeverity.INFORMATION, + "convention": DiagnosticSeverity.HINT, + "info": DiagnosticSeverity.INFORMATION + } + return mapping.get(msg_type.lower(), DiagnosticSeverity.WARNING) + + def _map_mypy_severity(self, severity: str) -> DiagnosticSeverity: + """Map MyPy severity to diagnostic severity.""" + mapping = { + "error": DiagnosticSeverity.ERROR, + "warning": DiagnosticSeverity.WARNING, + "note": DiagnosticSeverity.INFORMATION + } + return mapping.get(severity.lower(), DiagnosticSeverity.ERROR) + + def _map_pyright_severity(self, severity: str) -> DiagnosticSeverity: + """Map Pyright severity to diagnostic severity.""" + mapping = { + "error": DiagnosticSeverity.ERROR, + "warning": DiagnosticSeverity.WARNING, + "information": DiagnosticSeverity.INFORMATION + } + return mapping.get(severity.lower(), DiagnosticSeverity.ERROR) + + def _map_bandit_severity(self, severity: str) -> DiagnosticSeverity: + """Map Bandit severity to diagnostic severity.""" + mapping = { + "HIGH": DiagnosticSeverity.ERROR, + "MEDIUM": DiagnosticSeverity.WARNING, + "LOW": DiagnosticSeverity.INFORMATION + } + return mapping.get(severity.upper(), DiagnosticSeverity.WARNING) + + def _map_bandit_confidence(self, confidence: str) -> float: + 
"""Map Bandit confidence to float value.""" + mapping = { + "HIGH": 0.9, + "MEDIUM": 0.6, + "LOW": 0.3 + } + return mapping.get(confidence.upper(), 0.5) + + def _categorize_pylint_error(self, message_id: str) -> str: + """Categorize Pylint errors.""" + if not message_id: + return ErrorCategory.LOGIC.value + + category_mapping = { + "C": ErrorCategory.STYLE.value, + "R": ErrorCategory.MAINTAINABILITY.value, + "W": ErrorCategory.LOGIC.value, + "E": ErrorCategory.LOGIC.value, + "F": ErrorCategory.SYNTAX.value + } + + prefix = message_id[0] if message_id else "W" + return category_mapping.get(prefix, ErrorCategory.LOGIC.value) + + def _categorize_mypy_error(self, error_info: str) -> str: + """Categorize MyPy errors.""" + if not error_info: + return "type_checking" + + if any(keyword in error_info.lower() for keyword in ["import", "module"]): + return "import_error" + elif any(keyword in error_info.lower() for keyword in ["return", "yield"]): + return "return_type" + elif any(keyword in error_info.lower() for keyword in ["argument", "parameter"]): + return "argument_type" + elif any(keyword in error_info.lower() for keyword in ["attribute", "member"]): + return "attribute_error" + else: + return "type_checking" + + def _categorize_errors(self, errors: List[AnalysisError]) -> Dict[str, Any]: + """Categorize and organize errors for comprehensive presentation.""" + categorized = { + "by_severity": defaultdict(list), + "by_category": defaultdict(list), + "by_file": defaultdict(list), + "by_tool": defaultdict(list), + "by_error_type": defaultdict(list), + "statistics": {}, + "priority_errors": [], + "fixable_errors": [], + "security_critical": [], + "performance_critical": [], + "maintainability_issues": [] + } + + # Organize errors + for error in errors: + categorized["by_severity"][error.severity.name].append(error) + categorized["by_category"][error.category or "uncategorized"].append(error) + categorized["by_file"][error.file_path].append(error) + categorized["by_tool"][error.tool_source].append(error) + categorized["by_error_type"][error.error_type].append(error) + + # Special categorizations + if error.severity in [DiagnosticSeverity.ERROR]: + categorized["priority_errors"].append(error) + + if error.fix_suggestion: + categorized["fixable_errors"].append(error) + + if error.category == ErrorCategory.SECURITY.value: + categorized["security_critical"].append(error) + + if error.category == ErrorCategory.PERFORMANCE.value: + categorized["performance_critical"].append(error) + + if error.category == ErrorCategory.MAINTAINABILITY.value: + categorized["maintainability_issues"].append(error) + + # Calculate statistics + categorized["statistics"] = { + "total_errors": len(errors), + "critical_errors": len([e for e in errors if e.severity == DiagnosticSeverity.ERROR]), + "warnings": len([e for e in errors if e.severity == DiagnosticSeverity.WARNING]), + "info_messages": len([e for e in errors if e.severity == DiagnosticSeverity.INFORMATION]), + "files_with_errors": len(categorized["by_file"]), + "tools_with_findings": len(categorized["by_tool"]), + "categories_affected": len(categorized["by_category"]), + "fixable_count": len(categorized["fixable_errors"]), + "security_issues": len(categorized["security_critical"]), + "performance_issues": len(categorized["performance_critical"]) + } + + return categorized + + def _find_entrypoints(self) -> List[Dict[str, Any]]: + """Find entrypoints in the codebase.""" + entrypoints = [] + + if self.graph_analysis: + # Use graph-sitter to find entrypoints + for func 
in self.graph_analysis.functions[:20]: # Limit for performance + func_name = func.get('name', '') if isinstance(func, dict) else getattr(func, 'name', '') + if func_name in ['main', '__main__', 'run', 'start', 'execute']: + entrypoint = { + "type": "function", + "name": func_name, + "file_path": func.get('file_path', '') if isinstance(func, dict) else getattr(func, 'file_path', ''), + "line_number": func.get('line_number', 0) if isinstance(func, dict) else getattr(func, 'line_number', 0), + "category": "entrypoint" + } + entrypoints.append(entrypoint) + + # Look for classes that might be entrypoints + for cls in self.graph_analysis.classes[:20]: + cls_name = cls.get('name', '') if isinstance(cls, dict) else getattr(cls, 'name', '') + if any(pattern in cls_name.lower() for pattern in ['app', 'main', 'server', 'client', 'runner']): + entrypoint = { + "type": "class", + "name": cls_name, + "file_path": cls.get('file_path', '') if isinstance(cls, dict) else getattr(cls, 'file_path', ''), + "line_number": cls.get('line_number', 0) if isinstance(cls, dict) else getattr(cls, 'line_number', 0), + "category": "entrypoint" + } + entrypoints.append(entrypoint) + + # Look for if __name__ == "__main__" patterns + for file_info in (self.graph_analysis.files if self.graph_analysis else [])[:30]: + file_path = file_info.get('path', '') if isinstance(file_info, dict) else getattr(file_info, 'path', '') + if file_path and file_path.endswith('.py'): + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + if 'if __name__ == "__main__"' in content: + entrypoint = { + "type": "script", + "name": os.path.basename(file_path), + "file_path": file_path, + "line_number": content.count('\n', 0, content.find('if __name__ == "__main__"')) + 1, + "category": "entrypoint" + } + entrypoints.append(entrypoint) + + except Exception: + continue + + return entrypoints + + def _find_dead_code(self) -> List[Dict[str, Any]]: + """Find dead code in the codebase.""" + dead_code = [] + + if not self.graph_analysis: + return dead_code + + # Find unused functions + for func in self.graph_analysis.functions: + func_name = func.get('name', '') if isinstance(func, dict) else getattr(func, 'name', '') + if func_name: + usages = self.graph_analysis._find_function_usages(func_name) + if not usages and not func_name.startswith('_'): # Ignore private functions + dead_code.append({ + "type": "function", + "name": func_name, + "file_path": func.get('file_path', '') if isinstance(func, dict) else getattr(func, 'file_path', ''), + "line_number": func.get('line_number', 0) if isinstance(func, dict) else getattr(func, 'line_number', 0), + "category": "unused_function", + "context": "Not used by any other code" + }) + + # Find unused classes + for cls in self.graph_analysis.classes: + cls_name = cls.get('name', '') if isinstance(cls, dict) else getattr(cls, 'name', '') + if cls_name: + usages = self.graph_analysis._find_symbol_usages(cls_name) + if not usages and not cls_name.startswith('_'): + dead_code.append({ + "type": "class", + "name": cls_name, + "file_path": cls.get('file_path', '') if isinstance(cls, dict) else getattr(cls, 'file_path', ''), + "line_number": cls.get('line_number', 0) if isinstance(cls, dict) else getattr(cls, 'line_number', 0), + "category": "unused_class", + "context": "Not used by any other code" + }) + + # Find unused imports + for imp in self.graph_analysis.imports: + imp_name = imp.get('name', '') if isinstance(imp, dict) else getattr(imp, 'name', '') + imp_file = imp.get('file_path', '') 
if isinstance(imp, dict) else getattr(imp, 'file_path', '') + + if imp_name and imp_file: + # Check if the imported name is used in the file + try: + with open(imp_file, 'r', encoding='utf-8') as f: + content = f.read() + + # Simple check - look for the name in the file content + import_line_num = imp.get('line_number', 0) if isinstance(imp, dict) else getattr(imp, 'line_number', 0) + lines_after_import = content.splitlines()[import_line_num:] + + if not any(imp_name in line for line in lines_after_import): + dead_code.append({ + "type": "import", + "name": imp_name, + "file_path": imp_file, + "line_number": import_line_num, + "category": "unused_import", + "context": f"Imported but never used in {os.path.basename(imp_file)}" + }) + + except Exception: + continue + + return dead_code + + def _analyze_dependencies(self) -> Dict[str, Any]: + """Analyze project dependencies.""" + dependency_analysis = { + "external_dependencies": [], + "internal_dependencies": [], + "circular_dependencies": [], + "dependency_graph": {}, + "security_issues": [], + "outdated_packages": [], + "dependency_statistics": {} + } + + try: + # Analyze external dependencies + if self.graph_analysis: + for ext_mod in self.graph_analysis.external_modules: + mod_name = ext_mod.get('name', '') if isinstance(ext_mod, dict) else getattr(ext_mod, 'name', '') + if mod_name: + dependency_analysis["external_dependencies"].append({ + "name": mod_name, + "usage_count": len(self.graph_analysis._find_symbol_usages(mod_name)), + "import_locations": [ + imp.get('file_path', '') if isinstance(imp, dict) else getattr(imp, 'file_path', '') + for imp in self.graph_analysis.imports + if (imp.get('module', '') if isinstance(imp, dict) else getattr(imp, 'module', '')) == mod_name + ] + }) + + # Check for circular dependencies using imports + dependency_analysis["circular_dependencies"] = self._find_circular_dependencies() + + # Analyze requirements.txt or pyproject.toml + dependency_analysis["dependency_statistics"] = self._analyze_dependency_files() + + except Exception as e: + logging.error(f"Error in dependency analysis: {e}") + dependency_analysis["error"] = str(e) + + return dependency_analysis + + def _find_circular_dependencies(self) -> List[Dict[str, Any]]: + """Find circular dependencies in the codebase.""" + circular_deps = [] + + if not self.graph_analysis: + return circular_deps + + # Build a simple dependency graph + deps_graph = defaultdict(set) + + for imp in self.graph_analysis.imports: + imp_file = imp.get('file_path', '') if isinstance(imp, dict) else getattr(imp, 'file_path', '') + imp_module = imp.get('module', '') if isinstance(imp, dict) else getattr(imp, 'module', '') + + if imp_file and imp_module and not imp.get('is_external', True): + # Convert module name to file path (simplified) + target_file = imp_module.replace('.', '/') + '.py' + deps_graph[imp_file].add(target_file) + + # Simple cycle detection + visited = set() + rec_stack = set() + + def has_cycle(node, path): + if node in rec_stack: + # Found a cycle + cycle_start = path.index(node) + cycle = path[cycle_start:] + [node] + return cycle + + if node in visited: + return None + + visited.add(node) + rec_stack.add(node) + + for neighbor in deps_graph.get(node, []): + cycle = has_cycle(neighbor, path + [node]) + if cycle: + return cycle + + rec_stack.remove(node) + return None + + for node in deps_graph: + if node not in visited: + cycle = has_cycle(node, []) + if cycle: + circular_deps.append({ + "cycle": cycle, + "type": "import_cycle", + "severity": "high", + 
"description": f"Circular import detected: {' -> '.join(cycle)}" + }) + + return circular_deps + + def _analyze_dependency_files(self) -> Dict[str, Any]: + """Analyze dependency configuration files.""" + stats = { + "requirements_files": [], + "pyproject_config": {}, + "total_dependencies": 0, + "dev_dependencies": 0, + "optional_dependencies": 0 + } + + # Check for requirements files + req_files = ["requirements.txt", "requirements-dev.txt", "requirements-test.txt"] + base_dir = self.target_path if os.path.isdir(self.target_path) else os.path.dirname(self.target_path) + + for req_file in req_files: + req_path = os.path.join(base_dir, req_file) + if os.path.exists(req_path): + try: + with open(req_path, 'r') as f: + lines = [line.strip() for line in f.readlines() if line.strip() and not line.startswith('#')] + stats["requirements_files"].append({ + "file": req_file, + "dependencies": len(lines), + "content": lines[:10] # First 10 for sample + }) + stats["total_dependencies"] += len(lines) + except Exception: + pass + + # Check pyproject.toml + pyproject_path = os.path.join(base_dir, "pyproject.toml") + if os.path.exists(pyproject_path): + try: + import tomllib + with open(pyproject_path, 'rb') as f: + pyproject_data = tomllib.load(f) + + project = pyproject_data.get("project", {}) + dependencies = project.get("dependencies", []) + optional_deps = project.get("optional-dependencies", {}) + + stats["pyproject_config"] = { + "dependencies": len(dependencies), + "optional_dependency_groups": len(optional_deps), + "total_optional": sum(len(deps) for deps in optional_deps.values()) + } + stats["total_dependencies"] += len(dependencies) + stats["optional_dependencies"] = sum(len(deps) for deps in optional_deps.values()) + + except Exception as e: + stats["pyproject_config"] = {"error": str(e)} + + return stats + + def _calculate_performance_metrics(self) -> Dict[str, Any]: + """Calculate performance-related metrics.""" + metrics = { + "complexity_analysis": {}, + "function_metrics": {}, + "class_metrics": {}, + "file_metrics": {}, + "performance_warnings": [] + } + + if not self.graph_analysis: + return metrics + + try: + # Function complexity metrics + function_complexities = [] + for func in self.graph_analysis.functions: + complexity = self.graph_analysis._calculate_function_complexity(func) + function_complexities.append(complexity) + + if complexity > 15: # High complexity threshold + func_name = func.get('name', '') if isinstance(func, dict) else getattr(func, 'name', '') + func_file = func.get('file_path', '') if isinstance(func, dict) else getattr(func, 'file_path', '') + + metrics["performance_warnings"].append({ + "type": "high_complexity", + "function": func_name, + "file": func_file, + "complexity": complexity, + "recommendation": "Consider refactoring into smaller functions" + }) + + metrics["function_metrics"] = { + "total_functions": len(function_complexities), + "average_complexity": sum(function_complexities) / len(function_complexities) if function_complexities else 0, + "max_complexity": max(function_complexities) if function_complexities else 0, + "high_complexity_count": len([c for c in function_complexities if c > 10]) + } + + # Class metrics + class_complexities = [] + for cls in self.graph_analysis.classes: + complexity = self.graph_analysis._calculate_class_complexity(cls) + class_complexities.append(complexity) + + metrics["class_metrics"] = { + "total_classes": len(class_complexities), + "average_complexity": sum(class_complexities) / len(class_complexities) if 
class_complexities else 0, + "max_complexity": max(class_complexities) if class_complexities else 0 + } + + # File metrics + file_sizes = [] + for file_info in self.graph_analysis.files: + size = file_info.get('line_count', 0) if isinstance(file_info, dict) else getattr(file_info, 'line_count', 0) + file_sizes.append(size) + + if size > 500: # Large file threshold + file_path = file_info.get('path', '') if isinstance(file_info, dict) else getattr(file_info, 'path', '') + metrics["performance_warnings"].append({ + "type": "large_file", + "file": file_path, + "lines": size, + "recommendation": "Consider splitting into smaller modules" + }) + + metrics["file_metrics"] = { + "total_files": len(file_sizes), + "average_file_size": sum(file_sizes) / len(file_sizes) if file_sizes else 0, + "largest_file": max(file_sizes) if file_sizes else 0, + "large_files_count": len([s for s in file_sizes if s > 300]) + } + + except Exception as e: + logging.error(f"Error calculating performance metrics: {e}") + metrics["error"] = str(e) + + return metrics + + def _calculate_quality_metrics(self) -> Dict[str, Any]: + """Calculate code quality metrics.""" + metrics = { + "documentation_coverage": 0.0, + "test_coverage": 0.0, + "maintainability_index": 0.0, + "technical_debt_ratio": 0.0, + "code_duplication": 0.0, + "quality_gates": {}, + "recommendations": [] + } + + try: + if self.graph_analysis: + # Documentation coverage + documented_functions = 0 + total_functions = len(self.graph_analysis.functions) + + for func in self.graph_analysis.functions: + func_file = func.get('file_path', '') if isinstance(func, dict) else getattr(func, 'file_path', '') + func_line = func.get('line_number', 0) if isinstance(func, dict) else getattr(func, 'line_number', 0) + + if func_file and func_line: + try: + with open(func_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + + # Check for docstring after function definition + if func_line < len(lines): + for i in range(func_line, min(func_line + 5, len(lines))): + if '"""' in lines[i] or "'''" in lines[i]: + documented_functions += 1 + break + except Exception: + continue + + metrics["documentation_coverage"] = documented_functions / total_functions if total_functions > 0 else 0.0 + + # Calculate maintainability index (simplified) + avg_complexity = self.graph_analysis._calculate_complexity_score() + total_loc = sum(f.get('line_count', 0) for f in self.graph_analysis.files) + + # Simplified maintainability index calculation + metrics["maintainability_index"] = max(0, 100 - (avg_complexity * 10) - (total_loc / 1000)) + + # Quality gates + metrics["quality_gates"] = { + "documentation_gate": metrics["documentation_coverage"] >= 0.7, + "complexity_gate": avg_complexity <= 10, + "file_size_gate": all(f.get('line_count', 0) <= 500 for f in self.graph_analysis.files), + "maintainability_gate": metrics["maintainability_index"] >= 60 + } + + # Generate recommendations + if metrics["documentation_coverage"] < 0.5: + metrics["recommendations"].append("Improve documentation coverage") + if avg_complexity > 15: + metrics["recommendations"].append("Reduce code complexity") + if metrics["maintainability_index"] < 50: + metrics["recommendations"].append("Focus on code maintainability") + + except Exception as e: + logging.error(f"Error calculating quality metrics: {e}") + metrics["error"] = str(e) + + return metrics + + def _build_dependency_graph(self) -> Dict[str, Any]: + """Build comprehensive dependency graph.""" + graph = { + "nodes": [], + "edges": [], + "clusters": [], + 
"metrics": {} + } + + if not self.graph_analysis: + return graph + + try: + # Add file nodes + for file_info in self.graph_analysis.files: + file_path = file_info.get('path', '') if isinstance(file_info, dict) else getattr(file_info, 'path', '') + if file_path: + graph["nodes"].append({ + "id": file_path, + "type": "file", + "label": os.path.basename(file_path), + "size": file_info.get('line_count', 0) if isinstance(file_info, dict) else getattr(file_info, 'line_count', 0) + }) + + # Add dependency edges + for imp in self.graph_analysis.imports: + source_file = imp.get('file_path', '') if isinstance(imp, dict) else getattr(imp, 'file_path', '') + target_module = imp.get('module', '') if isinstance(imp, dict) else getattr(imp, 'module', '') + + if source_file and target_module: + # Convert module to file path (simplified) + if not imp.get('is_external', True): + target_file = target_module.replace('.', '/') + '.py' + graph["edges"].append({ + "source": source_file, + "target": target_file, + "type": "import", + "weight": 1 + }) + + # Calculate metrics + graph["metrics"] = { + "total_nodes": len(graph["nodes"]), + "total_edges": len(graph["edges"]), + "average_dependencies": len(graph["edges"]) / len(graph["nodes"]) if graph["nodes"] else 0 + } + + except Exception as e: + logging.error(f"Error building dependency graph: {e}") + graph["error"] = str(e) + + return graph + + def _build_inheritance_hierarchy(self) -> Dict[str, Any]: + """Build inheritance hierarchy.""" + hierarchy = { + "inheritance_trees": [], + "abstract_classes": [], + "leaf_classes": [], + "metrics": {} + } + + if not self.graph_analysis: + return hierarchy + + try: + # Build inheritance relationships + class_inheritance = {} + for cls in self.graph_analysis.classes: + cls_name = cls.get('name', '') if isinstance(cls, dict) else getattr(cls, 'name', '') + bases = cls.get('bases', []) if isinstance(cls, dict) else getattr(cls, 'bases', []) + + if cls_name: + class_inheritance[cls_name] = { + "bases": bases, + "file_path": cls.get('file_path', '') if isinstance(cls, dict) else getattr(cls, 'file_path', ''), + "line_number": cls.get('line_number', 0) if isinstance(cls, dict) else getattr(cls, 'line_number', 0) + } + + # Find inheritance chains + for cls_name, cls_info in class_inheritance.items(): + if cls_info["bases"]: + chain = [cls_name] + current = cls_name + + while current in class_inheritance and class_inheritance[current]["bases"]: + bases = class_inheritance[current]["bases"] + if bases and bases[0] in class_inheritance: + current = bases[0] + chain.append(current) + else: + break + + if len(chain) > 1: + hierarchy["inheritance_trees"].append({ + "chain": chain, + "depth": len(chain), + "root_class": chain[-1], + "leaf_class": chain[0] + }) + + # Calculate metrics + hierarchy["metrics"] = { + "total_classes": len(class_inheritance), + "classes_with_inheritance": len([c for c in class_inheritance.values() if c["bases"]]), + "max_inheritance_depth": max([len(tree["chain"]) for tree in hierarchy["inheritance_trees"]], default=0), + "average_inheritance_depth": sum([len(tree["chain"]) for tree in hierarchy["inheritance_trees"]]) / len(hierarchy["inheritance_trees"]) if hierarchy["inheritance_trees"] else 0 + } + + except Exception as e: + logging.error(f"Error building inheritance hierarchy: {e}") + hierarchy["error"] = str(e) + + return hierarchy + + def _build_call_graph(self) -> Dict[str, Any]: + """Build function call graph.""" + call_graph = { + "nodes": [], + "edges": [], + "metrics": {}, + "hotspots": [], + 
"call_chains": [] + } + + if not self.graph_analysis: + return call_graph + + try: + # Add function nodes + for func in self.graph_analysis.functions: + func_name = func.get('name', '') if isinstance(func, dict) else getattr(func, 'name', '') + if func_name: + call_graph["nodes"].append({ + "id": func_name, + "type": "function", + "file": func.get('file_path', '') if isinstance(func, dict) else getattr(func, 'file_path', ''), + "complexity": self.graph_analysis._calculate_function_complexity(func), + "parameters": len(func.get('parameters', []) if isinstance(func, dict) else getattr(func, 'parameters', [])) + }) + + # Simple call relationship detection (would be more sophisticated with graph-sitter) + for func in self.graph_analysis.functions: + func_name = func.get('name', '') if isinstance(func, dict) else getattr(func, 'name', '') + func_file = func.get('file_path', '') if isinstance(func, dict) else getattr(func, 'file_path', '') + + if func_file and func_name: + try: + with open(func_file, 'r', encoding='utf-8') as f: + content = f.read() + + # Look for function calls (simplified) + for other_func in self.graph_analysis.functions: + other_name = other_func.get('name', '') if isinstance(other_func, dict) else getattr(other_func, 'name', '') + if other_name and other_name != func_name and f"{other_name}(" in content: + call_graph["edges"].append({ + "source": func_name, + "target": other_name, + "type": "function_call" + }) + except Exception: + continue + + # Find hotspots (most called functions) + call_counts = Counter(edge["target"] for edge in call_graph["edges"]) + call_graph["hotspots"] = [ + {"function": func, "call_count": count} + for func, count in call_counts.most_common(10) + ] + + # Calculate metrics + call_graph["metrics"] = { + "total_functions": len(call_graph["nodes"]), + "total_calls": len(call_graph["edges"]), + "average_calls_per_function": len(call_graph["edges"]) / len(call_graph["nodes"]) if call_graph["nodes"] else 0, + "most_called_function": call_graph["hotspots"][0]["function"] if call_graph["hotspots"] else None + } + + except Exception as e: + logging.error(f"Error building call graph: {e}") + call_graph["error"] = str(e) + + return call_graph + + def _build_import_graph(self) -> Dict[str, Any]: + """Build import dependency graph.""" + import_graph = { + "internal_imports": [], + "external_imports": [], + "import_clusters": [], + "unused_imports": [], + "metrics": {} + } + + if not self.graph_analysis: + return import_graph + + try: + # Categorize imports + for imp in self.graph_analysis.imports: + imp_data = { + "module": imp.get('module', '') if isinstance(imp, dict) else getattr(imp, 'module', ''), + "name": imp.get('name', '') if isinstance(imp, dict) else getattr(imp, 'name', ''), + "file_path": imp.get('file_path', '') if isinstance(imp, dict) else getattr(imp, 'file_path', ''), + "line_number": imp.get('line_number', 0) if isinstance(imp, dict) else getattr(imp, 'line_number', 0) + } + + if imp.get('is_external', True): + import_graph["external_imports"].append(imp_data) + else: + import_graph["internal_imports"].append(imp_data) + + # Find import clusters (files that import similar modules) + external_by_file = defaultdict(set) + for imp in import_graph["external_imports"]: + external_by_file[imp["file_path"]].add(imp["module"]) + + # Group files with similar import patterns + import_patterns = defaultdict(list) + for file_path, modules in external_by_file.items(): + pattern_key = tuple(sorted(modules)) + 
import_patterns[pattern_key].append(file_path) + + for pattern, files in import_patterns.items(): + if len(files) > 1: + import_graph["import_clusters"].append({ + "pattern": list(pattern), + "files": files, + "cluster_size": len(files) + }) + + # Calculate metrics + import_graph["metrics"] = { + "total_imports": len(self.graph_analysis.imports), + "external_imports": len(import_graph["external_imports"]), + "internal_imports": len(import_graph["internal_imports"]), + "unique_external_modules": len(set(imp["module"] for imp in import_graph["external_imports"])), + "import_clusters": len(import_graph["import_clusters"]), + "average_imports_per_file": len(self.graph_analysis.imports) / len(self.graph_analysis.files) if self.graph_analysis.files else 0 + } + + except Exception as e: + logging.error(f"Error building import graph: {e}") + import_graph["error"] = str(e) + + return import_graph + + def _summarize_lsp_diagnostics(self, diagnostics_by_file: Dict) -> Dict[str, Any]: + """Summarize LSP diagnostics.""" + summary = { + "total_files_with_errors": len(diagnostics_by_file), + "severity_breakdown": defaultdict(int), + "error_sources": defaultdict(int), + "most_problematic_files": [] + } + + file_error_counts = [] + + for file_path, diagnostics in diagnostics_by_file.items(): + error_count = len(diagnostics) + file_error_counts.append((file_path, error_count)) + + for diagnostic in diagnostics: + severity = diagnostic.get("severity", 1) + source = diagnostic.get("source", "unknown") + + summary["severity_breakdown"][severity] += 1 + summary["error_sources"][source] += 1 + + # Sort files by error count + file_error_counts.sort(key=lambda x: x[1], reverse=True) + summary["most_problematic_files"] = file_error_counts[:10] + + return summary + + def _generate_comprehensive_summary(self, analysis_results: Dict[str, Any]) -> Dict[str, Any]: + """Generate comprehensive analysis summary.""" + summary = { + "overview": {}, + "critical_findings": [], + "recommendations": [], + "quality_score": 0.0, + "risk_assessment": {}, + "action_items": [] + } + + try: + # Overview statistics + total_errors = len(analysis_results.get("errors", [])) + categorized = analysis_results.get("categorized_errors", {}) + + summary["overview"] = { + "total_errors": total_errors, + "critical_errors": categorized.get("statistics", {}).get("critical_errors", 0), + "warnings": categorized.get("statistics", {}).get("warnings", 0), + "files_analyzed": len(analysis_results.get("graph_sitter_analysis", {}).get("files_analysis", {})), + "tools_used": len(analysis_results.get("metadata", {}).get("tools_used", [])), + "analysis_duration": analysis_results.get("metadata", {}).get("analysis_duration", 0) + } + + # Critical findings + priority_errors = categorized.get("priority_errors", []) + security_issues = categorized.get("security_critical", []) + + summary["critical_findings"] = [ + f"Found {len(priority_errors)} critical errors", + f"Identified {len(security_issues)} security issues", + f"Detected {len(analysis_results.get('dead_code', []))} dead code instances" + ] + + # Risk assessment + summary["risk_assessment"] = { + "security_risk": "high" if len(security_issues) > 5 else "medium" if len(security_issues) > 0 else "low", + "maintenance_risk": "high" if total_errors > 100 else "medium" if total_errors > 20 else "low", + "quality_risk": "high" if categorized.get("statistics", {}).get("critical_errors", 0) > 10 else "low" + } + + # Calculate quality score + quality_metrics = analysis_results.get("quality_metrics", {}) + 
performance_metrics = analysis_results.get("performance_metrics", {}) + + doc_coverage = quality_metrics.get("documentation_coverage", 0.0) + maintainability = quality_metrics.get("maintainability_index", 0.0) + error_penalty = min(total_errors * 0.5, 50) # Cap penalty at 50 points + + summary["quality_score"] = max(0.0, (doc_coverage * 30) + (maintainability * 0.7) - error_penalty) + + # Generate recommendations + if doc_coverage < 0.5: + summary["recommendations"].append("Improve code documentation") + if len(security_issues) > 0: + summary["recommendations"].append("Address security vulnerabilities") + if total_errors > 50: + summary["recommendations"].append("Reduce overall error count") + if len(analysis_results.get("dead_code", [])) > 10: + summary["recommendations"].append("Remove dead code") + + # Action items + summary["action_items"] = [ + { + "priority": "high", + "category": "security", + "action": f"Fix {len(security_issues)} security issues", + "estimated_effort": "medium" + }, + { + "priority": "medium", + "category": "quality", + "action": f"Address {categorized.get('statistics', {}).get('critical_errors', 0)} critical errors", + "estimated_effort": "high" + }, + { + "priority": "low", + "category": "maintenance", + "action": f"Clean up {len(analysis_results.get('dead_code', []))} dead code items", + "estimated_effort": "low" + } + ] + + except Exception as e: + logging.error(f"Error generating summary: {e}") + summary["error"] = str(e) + + return summary + + +class ErrorDatabase: + """Database for storing and querying analysis errors.""" + + def __init__(self, db_path: str = ":memory:"): + self.db_path = db_path + self.connection = sqlite3.connect(db_path, check_same_thread=False) + self._create_tables() + + def _create_tables(self): + """Create database tables for storing errors.""" + cursor = self.connection.cursor() + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS analysis_errors ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_path TEXT NOT NULL, + line_number INTEGER, + column_number INTEGER, + error_type TEXT NOT NULL, + severity TEXT NOT NULL, + message TEXT NOT NULL, + tool_source TEXT NOT NULL, + category TEXT, + subcategory TEXT, + error_code TEXT, + confidence REAL DEFAULT 1.0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + fixed BOOLEAN DEFAULT FALSE + ) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS analysis_sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + target_path TEXT NOT NULL, + start_time TIMESTAMP, + end_time TIMESTAMP, + total_errors INTEGER, + tools_used TEXT, + config_hash TEXT + ) + """) + + # executescript is required here: cursor.execute() only accepts a single statement + cursor.executescript(""" + CREATE INDEX IF NOT EXISTS idx_file_path ON analysis_errors(file_path); + CREATE INDEX IF NOT EXISTS idx_severity ON analysis_errors(severity); + CREATE INDEX IF NOT EXISTS idx_category ON analysis_errors(category); + CREATE INDEX IF NOT EXISTS idx_tool_source ON analysis_errors(tool_source); + """) + + self.connection.commit() + + def store_errors(self, errors: List[AnalysisError], session_id: int): + """Store errors in the database.""" + cursor = self.connection.cursor() + + for error in errors: + cursor.execute(""" + INSERT INTO analysis_errors + (file_path, line_number, column_number, error_type, severity, message, + tool_source, category, subcategory, error_code, confidence) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ """, ( + error.file_path, + error.line_number, + error.column_number, + error.error_type, + error.severity.name, + error.message, + error.tool_source, + error.category, + error.subcategory, + error.error_code, + error.confidence + )) + + self.connection.commit() + + def create_session(self, target_path: str, tools_used: List[str], config: Dict) -> int: + """Create a new analysis session.""" + cursor = self.connection.cursor() + + config_hash = hashlib.md5(json.dumps(config, sort_keys=True).encode()).hexdigest() + + cursor.execute(""" + INSERT INTO analysis_sessions (target_path, start_time, tools_used, config_hash) + VALUES (?, ?, ?, ?) + """, (target_path, time.strftime("%Y-%m-%d %H:%M:%S"), json.dumps(tools_used), config_hash)) + + self.connection.commit() + return cursor.lastrowid + + def update_session(self, session_id: int, total_errors: int): + """Update session with final statistics.""" + cursor = self.connection.cursor() + + cursor.execute(""" + UPDATE analysis_sessions + SET end_time = ?, total_errors = ? + WHERE id = ? + """, (time.strftime("%Y-%m-%d %H:%M:%S"), total_errors, session_id)) + + self.connection.commit() + + def query_errors(self, filters: Dict[str, Any]) -> List[Dict]: + """Query errors with filters.""" + cursor = self.connection.cursor() + + query = "SELECT * FROM analysis_errors WHERE 1=1" + params = [] + + if "file_path" in filters: + query += " AND file_path = ?" + params.append(filters["file_path"]) + + if "severity" in filters: + query += " AND severity = ?" + params.append(filters["severity"]) + + if "category" in filters: + query += " AND category = ?" + params.append(filters["category"]) + + if "tool_source" in filters: + query += " AND tool_source = ?" + params.append(filters["tool_source"]) + + cursor.execute(query, params) + columns = [description[0] for description in cursor.description] + + return [dict(zip(columns, row)) for row in cursor.fetchall()] + + +class ReportGenerator: + """Generate comprehensive analysis reports in multiple formats.""" + + def __init__(self, analysis_results: Dict[str, Any]): + self.results = analysis_results + + def generate_terminal_report(self) -> str: + """Generate comprehensive terminal report.""" + lines = [] + + # Header + lines.extend([ + "="*100, + "COMPREHENSIVE PYTHON CODE ANALYSIS REPORT", + "="*100, + "" + ]) + + # Metadata + metadata = self.results.get("metadata", {}) + lines.extend([ + f"Target: {metadata.get('target_path', 'Unknown')}", + f"Analysis Duration: {metadata.get('analysis_duration', 0):.2f}s", + f"Tools Used: {', '.join(metadata.get('tools_used', []))}", + f"Python Version: {metadata.get('python_version', 'Unknown')}", + "" + ]) + + # Graph-sitter analysis summary + gs_analysis = self.results.get("graph_sitter_analysis", {}) + if gs_analysis and "codebase_summary" in gs_analysis: + summary = gs_analysis["codebase_summary"] + lines.extend([ + "CODEBASE STRUCTURE:", + "-" * 50, + f"Files: {summary.get('total_files', 0)}", + f"Functions: {summary.get('total_functions', 0)}", + f"Classes: {summary.get('total_classes', 0)}", + f"Imports: {summary.get('total_imports', 0)}", + f"External Dependencies: {summary.get('external_dependencies', 0)}", + f"Lines of Code: {summary.get('lines_of_code', 0)}", + f"Complexity Score: {summary.get('complexity_score', 0):.2f}", + "" + ]) + + # Entrypoints + entrypoints = self.results.get("entrypoints", []) + if entrypoints: + lines.extend([ + f"ENTRYPOINTS: [{len(entrypoints)}]", + "-" * 50 + ]) + for i, entry in enumerate(entrypoints[:10], 1): + entry_type = 
entry.get("type", "unknown").title() + entry_name = entry.get("name", "unknown") + entry_file = entry.get("file_path", "") + lines.append(f"{i}. {entry_type}: {entry_name} [{os.path.basename(entry_file)}]") + lines.append("") + + # Dead code + dead_code = self.results.get("dead_code", []) + if dead_code: + dead_by_type = defaultdict(int) + for item in dead_code: + dead_by_type[item.get("category", "unknown")] += 1 + + lines.extend([ + f"DEAD CODE: {len(dead_code)} [{', '.join(f'{cat.title()}: {count}' for cat, count in dead_by_type.items())}]", + "-" * 50 + ]) + for i, item in enumerate(dead_code[:20], 1): + item_type = item.get("type", "unknown") + item_name = item.get("name", "unknown") + item_file = os.path.basename(item.get("file_path", "")) + context = item.get("context", "") + lines.append(f"{i}. {item_type.title()}: '{item_name}' [{item_file}] - {context}") + + if len(dead_code) > 20: + lines.append(f"... and {len(dead_code) - 20} more items") + lines.append("") + + # Error summary + categorized = self.results.get("categorized_errors", {}) + stats = categorized.get("statistics", {}) + + if stats: + lines.extend([ + f"ERRORS: {stats.get('total_errors', 0)} " + f"[Critical: {stats.get('critical_errors', 0)}] " + f"[Warnings: {stats.get('warnings', 0)}] " + f"[Info: {stats.get('info_messages', 0)}]", + "-" * 50 + ]) + + # List errors by severity and tool + by_tool = categorized.get("by_tool", {}) + for i, (tool, tool_errors) in enumerate(by_tool.items(), 1): + if tool_errors: + critical_count = len([e for e in tool_errors if e.severity == DiagnosticSeverity.ERROR]) + warning_count = len([e for e in tool_errors if e.severity == DiagnosticSeverity.WARNING]) + + lines.append(f"{i}. {tool.upper()}: {len(tool_errors)} issues " + f"[Critical: {critical_count}, Warnings: {warning_count}]") + + lines.append("") + + # Detailed error listing by category + by_category = categorized.get("by_category", {}) + for category, category_errors in by_category.items(): + if category_errors: + lines.extend([ + f"{category.upper()} ERRORS: {len(category_errors)}", + "-" * 30 + ]) + + # Group by file for better organization + errors_by_file = defaultdict(list) + for error in category_errors[:50]: # Limit display + errors_by_file[error.file_path].append(error) + + for file_path, file_errors in list(errors_by_file.items())[:10]: # Limit files + rel_path = os.path.relpath(file_path, self.results.get("metadata", {}).get("target_path", "")) + lines.append(f" {rel_path}: {len(file_errors)} issues") + + for error in file_errors[:5]: # Limit errors per file + severity_icon = { + DiagnosticSeverity.ERROR: "⚠️", + DiagnosticSeverity.WARNING: "👉", + DiagnosticSeverity.INFORMATION: "🔍", + DiagnosticSeverity.HINT: "💡" + }.get(error.severity, "•") + + location = f"Line {error.line_number}" if error.line_number else "Unknown location" + tool_info = f"[{error.tool_source}]" + + lines.append(f" {severity_icon} {location}: {error.message[:80]}{'...' if len(error.message) > 80 else ''} {tool_info}") + + if len(file_errors) > 5: + lines.append(f" ... and {len(file_errors) - 5} more issues in this file") + + if len(errors_by_file) > 10: + lines.append(f" ... 
and {len(errors_by_file) - 10} more files with {category} errors") + + lines.append("") + + # Performance and quality insights + perf_metrics = self.results.get("performance_metrics", {}) + quality_metrics = self.results.get("quality_metrics", {}) + + if perf_metrics.get("function_metrics"): + func_metrics = perf_metrics["function_metrics"] + lines.extend([ + "PERFORMANCE INSIGHTS:", + "-" * 30, + f"Average Function Complexity: {func_metrics.get('average_complexity', 0):.2f}", + f"High Complexity Functions: {func_metrics.get('high_complexity_count', 0)}", + f"Largest File: {perf_metrics.get('file_metrics', {}).get('largest_file', 0)} lines", + "" + ]) + + if quality_metrics: + lines.extend([ + "QUALITY METRICS:", + "-" * 30, + f"Documentation Coverage: {quality_metrics.get('documentation_coverage', 0)*100:.1f}%", + f"Maintainability Index: {quality_metrics.get('maintainability_index', 0):.1f}", + f"Quality Score: {self.results.get('summary', {}).get('quality_score', 0):.1f}/100", + "" + ]) + + # Recommendations + recommendations = self.results.get("summary", {}).get("recommendations", []) + if recommendations: + lines.extend([ + "RECOMMENDATIONS:", + "-" * 30 + ]) + for i, rec in enumerate(recommendations, 1): + lines.append(f"{i}. {rec}") + lines.append("") + + except Exception as e: + lines.extend([ + "ERROR GENERATING SUMMARY:", + str(e), + "" + ]) + + lines.append("="*100) + return "\n".join(lines) + + def generate_json_report(self) -> str: + """Generate JSON report.""" + return json.dumps(self.results, indent=2, default=str) + + def generate_html_report(self) -> str: + """Generate comprehensive HTML report.""" + html_template = """ + + + + + + Comprehensive Code Analysis Report + + + + +
+    <div class="container">
+        <div class="header">
+            <h1>Code Analysis Report</h1>
+            <p>Target: {target_path}</p>
+            <p>Generated on {timestamp}</p>
+        </div>
+
+        <div class="metrics-grid">
+            {metrics_cards}
+        </div>
+
+        <div class="section">
+            <h2>📊 Analysis Overview</h2>
+            {overview_content}
+        </div>
+
+        <div class="section">
+            <h2>🎯 Entrypoints</h2>
+            {entrypoints_content}
+        </div>
+
+        <div class="section">
+            <h2>💀 Dead Code Analysis</h2>
+            {dead_code_content}
+        </div>
+
+        <div class="section">
+            <h2>🚨 Error Analysis</h2>
+            {error_analysis_content}
+        </div>
+
+        <div class="section">
+            <h2>📈 Performance Metrics</h2>
+            {performance_content}
+        </div>
+
+        <div class="section">
+            <h2>🔒 Security Analysis</h2>
+            {security_content}
+        </div>
+
+        <div class="section">
+            <h2>📋 Recommendations</h2>
+            {recommendations_content}
+        </div>
+    </div>
+ + + """ + + # Generate content sections + metadata = self.results.get("metadata", {}) + categorized = self.results.get("categorized_errors", {}) + stats = categorized.get("statistics", {}) + + # Metrics cards + metrics_cards = self._generate_metrics_cards(stats) + + # Overview content + overview_content = self._generate_overview_content() + + # Entrypoints content + entrypoints_content = self._generate_entrypoints_content() + + # Dead code content + dead_code_content = self._generate_dead_code_content() + + # Error analysis content + error_analysis_content = self._generate_error_analysis_content() + + # Performance content + performance_content = self._generate_performance_content() + + # Security content + security_content = self._generate_security_content() + + # Recommendations content + recommendations_content = self._generate_recommendations_content() + + # Fill template + html_report = html_template.format( + target_path=metadata.get("target_path", "Unknown"), + timestamp=metadata.get("analysis_start_time", "Unknown"), + metrics_cards=metrics_cards, + overview_content=overview_content, + entrypoints_content=entrypoints_content, + dead_code_content=dead_code_content, + error_analysis_content=error_analysis_content, + performance_content=performance_content, + security_content=security_content, + recommendations_content=recommendations_content + ) + + return html_report + + def _generate_metrics_cards(self, stats: Dict) -> str: + """Generate HTML for metrics cards.""" + cards = [] + + metrics = [ + ("Total Errors", stats.get("total_errors", 0), "dc3545"), + ("Critical", stats.get("critical_errors", 0), "dc3545"), + ("Warnings", stats.get("warnings", 0), "ffc107"), + ("Files Analyzed", stats.get("files_with_errors", 0), "28a745"), + ("Tools Used", stats.get("tools_with_findings", 0), "17a2b8"), + ("Categories", stats.get("categories_affected", 0), "6f42c1"), + ("Fixable", stats.get("fixable_count", 0), "28a745"), + ("Security Issues", stats.get("security_issues", 0), "dc3545") + ] + + for label, value, color in metrics: + cards.append(f""" +
+            <div class="metric-card">
+                <div class="metric-value" style="color: #{color};">{value}</div>
+                <div class="metric-label">{label}</div>
+            </div>
+ """) + + return "".join(cards) + + def _generate_overview_content(self) -> str: + """Generate overview content.""" + gs_analysis = self.results.get("graph_sitter_analysis", {}) + summary = self.results.get("summary", {}) + + content = [] + + if gs_analysis.get("codebase_summary"): + codebase_sum = gs_analysis["codebase_summary"] + content.append(f""" +
+            <div class="overview-block">
+                <strong>Codebase Structure:</strong>
+                <ul>
+                    <li>📁 Total Files: {codebase_sum.get('total_files', 0)}</li>
+                    <li>🔧 Functions: {codebase_sum.get('total_functions', 0)}</li>
+                    <li>🏗️ Classes: {codebase_sum.get('total_classes', 0)}</li>
+                    <li>📦 Imports: {codebase_sum.get('total_imports', 0)}</li>
+                    <li>🌐 External Dependencies: {codebase_sum.get('external_dependencies', 0)}</li>
+                    <li>📏 Lines of Code: {codebase_sum.get('lines_of_code', 0):,}</li>
+                    <li>📊 Complexity Score: {codebase_sum.get('complexity_score', 0):.2f}</li>
+                </ul>
+            </div>
+ """) + + # Quality score with progress bar + quality_score = summary.get("quality_score", 0) + content.append(f""" +
+            <div class="quality-score">
+                <h3>Quality Score: {quality_score:.1f}/100</h3>
+                <div class="progress-bar">
+                    <div class="progress-fill" style="width: {quality_score}%"></div>
+                </div>
+            </div>
+ """) + + return "".join(content) + + def _generate_entrypoints_content(self) -> str: + """Generate entrypoints content.""" + entrypoints = self.results.get("entrypoints", []) + + if not entrypoints: + return "

No entrypoints detected in the codebase.

" + + content = [f"

Found {len(entrypoints)} entrypoints in the codebase:

"] + content.append("") + content.append("") + + for entry in entrypoints: + content.append(f""" + + + + + + + """) + + content.append("
TypeNameFileLine
{entry.get('type', 'unknown')}{entry.get('name', 'unknown')}{os.path.basename(entry.get('file_path', ''))}{entry.get('line_number', 'N/A')}
") + return "".join(content) + + def _generate_dead_code_content(self) -> str: + """Generate dead code content.""" + dead_code = self.results.get("dead_code", []) + + if not dead_code: + return "

No dead code detected. Great job!

" + + # Group by type + by_type = defaultdict(list) + for item in dead_code: + by_type[item.get("category", "unknown")].append(item) + + content = [f"

Found {len(dead_code)} dead code items:

"] + + for category, items in by_type.items(): + content.append(f""" +
+
{category.replace('_', ' ').title()}: {len(items)} items
+
    + """) + + for item in items[:20]: # Limit display + content.append(f""" +
  • + {item.get('type', 'unknown').title()}: {item.get('name', 'unknown')} +
    + 📁 {os.path.basename(item.get('file_path', ''))} + 📍 Line {item.get('line_number', 'N/A')} + 💭 {item.get('context', '')} +
    +
  • + """) + + if len(items) > 20: + content.append(f"
  • ... and {len(items) - 20} more items
  • ") + + content.extend(["
", "
"]) + + return "".join(content) + + def _generate_error_analysis_content(self) -> str: + """Generate error analysis content.""" + categorized = self.results.get("categorized_errors", {}) + + content = [] + + # Error category filter buttons + content.append(""" +
+ + """) + + by_category = categorized.get("by_category", {}) + for category in by_category.keys(): + content.append(f""" + + """) + + content.append("
") + + # Error sections by category + for category, errors in by_category.items(): + content.append(f""" +
+
{category.upper()}: {len(errors)} issues
+
    + """) + + # Group errors by file for better organization + errors_by_file = defaultdict(list) + for error in errors[:100]: # Limit for performance + errors_by_file[error.file_path].append(error) + + for file_path, file_errors in list(errors_by_file.items())[:15]: # Limit files shown + content.append(f""" + + ") + + content.extend(["
", "
"]) + + return "".join(content) + + def _generate_performance_content(self) -> str: + """Generate performance analysis content.""" + perf_metrics = self.results.get("performance_metrics", {}) + + if not perf_metrics or "error" in perf_metrics: + return "

Performance metrics not available.

" + + content = [] + + # Function metrics + func_metrics = perf_metrics.get("function_metrics", {}) + if func_metrics: + content.append(f""" +
+

Function Analysis

+
    +
  • Total Functions: {func_metrics.get('total_functions', 0)}
  • +
  • Average Complexity: {func_metrics.get('average_complexity', 0):.2f}
  • +
  • Maximum Complexity: {func_metrics.get('max_complexity', 0):.2f}
  • +
  • High Complexity Functions: {func_metrics.get('high_complexity_count', 0)}
  • +
+
+ """) + + # Performance warnings + warnings = perf_metrics.get("performance_warnings", []) + if warnings: + content.append("

Performance Warnings

") + content.append("
    ") + + for warning in warnings: + content.append(f""" +
  • + {warning.get('type', 'unknown').replace('_', ' ').title()} +
    + {warning.get('function', warning.get('file', 'Unknown'))} + - {warning.get('recommendation', '')} +
    +
  • + """) + + content.append("
") + + return "".join(content) + + def _generate_security_content(self) -> str: + """Generate security analysis content.""" + categorized = self.results.get("categorized_errors", {}) + security_issues = categorized.get("security_critical", []) + + if not security_issues: + return "

✅ Security Status

No critical security issues detected.

" + + content = [f"

Found {len(security_issues)} security issues requiring attention:

"] + content.append("
    ") + + for issue in security_issues: + severity_class = "error" if issue.severity == DiagnosticSeverity.ERROR else "warning" + + content.append(f""" +
  • + 🔒 {issue.message} +
    + 📁 {os.path.basename(issue.file_path)} + 📍 Line {issue.line_number or 'N/A'} + 🔧 {issue.tool_source} + {f'🏷️ {issue.error_code}' if issue.error_code else ''} +
    + Security Risk + {f'Confidence: {issue.confidence:.1%}' if hasattr(issue, 'confidence') else ''} +
    + {f'
    Fix: {issue.fix_suggestion}
    ' if issue.fix_suggestion else ''} +
  • + """) + + content.append("
") + return "".join(content) + + def _generate_recommendations_content(self) -> str: + """Generate recommendations content.""" + summary = self.results.get("summary", {}) + recommendations = summary.get("recommendations", []) + action_items = summary.get("action_items", []) + + content = [] + + if recommendations: + content.append("
") + content.append("

🎯 Key Recommendations

") + content.append("
    ") + for rec in recommendations: + content.append(f"
  • {rec}
  • ") + content.append("
") + content.append("
") + + if action_items: + content.append("

📋 Action Items

") + content.append("") + content.append("") + + for item in action_items: + priority_color = { + "high": "#dc3545", + "medium": "#ffc107", + "low": "#28a745" + }.get(item.get("priority", "low"), "#6c757d") + + content.append(f""" + + + + + + + """) + + content.append("
PriorityCategoryActionEffort
{item.get('priority', 'low').title()}{item.get('category', 'unknown').title()}{item.get('action', '')}{item.get('estimated_effort', 'unknown').title()}
") + + return "".join(content) + + +class InteractiveAnalyzer: + """Interactive analyzer for real-time code analysis and exploration.""" + + def __init__(self, analyzer: ComprehensiveAnalyzer): + self.analyzer = analyzer + self.current_context = None + self.analysis_history = [] + + def start_interactive_session(self): + """Start an interactive analysis session.""" + print("\n" + "="*80) + print("INTERACTIVE CODE ANALYSIS SESSION") + print("="*80) + print("Commands:") + print(" analyze - Analyze specific file") + print(" summary - Show codebase summary") + print(" errors - Show errors by category") + print(" function - Analyze specific function") + print(" class - Analyze specific class") + print(" deps - Show dependency analysis") + print(" security - Show security issues") + print(" performance - Show performance metrics") + print(" dead-code - Show dead code") + print(" export - Export report (json/html)") + print(" help - Show this help") + print(" quit - Exit session") + print("\n") + + while True: + try: + command = input("analysis> ").strip().lower() + + if command == "quit" or command == "exit": + break + elif command == "help": + self._show_help() + elif command == "summary": + self._show_summary() + elif command.startswith("analyze "): + file_path = command[8:].strip() + self._analyze_file(file_path) + elif command.startswith("errors "): + category = command[7:].strip() + self._show_errors_by_category(category) + elif command.startswith("function "): + func_name = command[9:].strip() + self._analyze_function(func_name) + elif command.startswith("class "): + class_name = command[6:].strip() + self._analyze_class(class_name) + elif command == "deps": + self._show_dependencies() + elif command == "security": + self._show_security_issues() + elif command == "performance": + self._show_performance_metrics() + elif command == "dead-code": + self._show_dead_code() + elif command.startswith("export "): + format_type = command[7:].strip() + self._export_report(format_type) + else: + print(f"Unknown command: {command}. Type 'help' for available commands.") + + except KeyboardInterrupt: + print("\nSession interrupted. Type 'quit' to exit.") + except Exception as e: + print(f"Error: {e}") + + def _show_help(self): + """Show detailed help information.""" + print(""" +Available Commands: + +📊 ANALYSIS COMMANDS: + summary - Show comprehensive codebase summary + analyze - Deep analysis of specific file + function - Detailed function analysis + class - Detailed class analysis + +🚨 ERROR COMMANDS: + errors - Show errors by category (syntax, type, security, etc.) + security - Show all security-related issues + performance - Show performance bottlenecks + +🔍 EXPLORATION COMMANDS: + deps - Show dependency analysis and graphs + dead-code - Show unused code detection results + +📤 EXPORT COMMANDS: + export json - Export full report as JSON + export html - Export interactive HTML report + +💡 TIPS: + - Use tab completion for file names and function names + - Commands are case-insensitive + - Use 'quit' or Ctrl+C to exit + """) + + def _show_summary(self): + """Show interactive summary.""" + if not hasattr(self.analyzer, 'last_results'): + print("No analysis results available. 
Run analysis first.") + return + + results = self.analyzer.last_results + summary = results.get("summary", {}) + overview = summary.get("overview", {}) + + print("\n📊 CODEBASE SUMMARY") + print("-" * 50) + print(f"Total Errors: {overview.get('total_errors', 0)}") + print(f"Critical Errors: {overview.get('critical_errors', 0)}") + print(f"Warnings: {overview.get('warnings', 0)}") + print(f"Files Analyzed: {overview.get('files_analyzed', 0)}") + print(f"Analysis Duration: {overview.get('analysis_duration', 0):.2f}s") + + quality_score = summary.get("quality_score", 0) + print(f"\nQuality Score: {quality_score:.1f}/100") + + # Show top recommendations + recommendations = summary.get("recommendations", []) + if recommendations: + print("\n💡 TOP RECOMMENDATIONS:") + for i, rec in enumerate(recommendations[:5], 1): + print(f" {i}. {rec}") + + def _analyze_file(self, file_path: str): + """Analyze specific file interactively.""" + if not self.analyzer.graph_analysis: + print("Graph analysis not available.") + return + + file_summary = self.analyzer.graph_analysis.get_file_summary(file_path) + + if "error" in file_summary: + print(f"Error: {file_summary['error']}") + return + + print(f"\n📄 FILE ANALYSIS: {os.path.basename(file_path)}") + print("-" * 50) + print(f"Path: {file_path}") + print(f"Lines of Code: {file_summary.get('lines_of_code', 0)}") + print(f"Functions: {file_summary.get('functions', 0)}") + print(f"Classes: {file_summary.get('classes', 0)}") + print(f"Imports: {file_summary.get('imports', 0)}") + print(f"Complexity: {file_summary.get('complexity', 0):.2f}") + + # Show file-specific errors + if hasattr(self.analyzer, 'last_results'): + categorized = self.analyzer.last_results.get("categorized_errors", {}) + file_errors = categorized.get("by_file", {}).get(file_path, []) + + if file_errors: + print(f"\n🚨 ERRORS IN THIS FILE: {len(file_errors)}") + for error in file_errors[:10]: + severity_icon = { + DiagnosticSeverity.ERROR: "⚠️", + DiagnosticSeverity.WARNING: "👉", + DiagnosticSeverity.INFORMATION: "🔍" + }.get(error.severity, "•") + + print(f" {severity_icon} Line {error.line_number or 'N/A'}: {error.message} [{error.tool_source}]") + else: + print("\n✅ No errors found in this file!") + + def _show_errors_by_category(self, category: str): + """Show errors filtered by category.""" + if not hasattr(self.analyzer, 'last_results'): + print("No analysis results available.") + return + + categorized = self.analyzer.last_results.get("categorized_errors", {}) + by_category = categorized.get("by_category", {}) + + if category not in by_category: + available = ", ".join(by_category.keys()) + print(f"Category '{category}' not found. 
Available: {available}") + return + + errors = by_category[category] + print(f"\n🚨 {category.upper()} ERRORS: {len(errors)}") + print("-" * 50) + + # Group by file + by_file = defaultdict(list) + for error in errors: + by_file[error.file_path].append(error) + + for file_path, file_errors in list(by_file.items())[:10]: + print(f"\n📁 {os.path.basename(file_path)} ({len(file_errors)} issues):") + + for error in file_errors[:5]: + severity_icon = { + DiagnosticSeverity.ERROR: "⚠️", + DiagnosticSeverity.WARNING: "👉", + DiagnosticSeverity.INFORMATION: "🔍" + }.get(error.severity, "•") + + print(f" {severity_icon} Line {error.line_number or 'N/A'}: {error.message}") + if error.fix_suggestion: + print(f" 💡 Fix: {error.fix_suggestion}") + + def _analyze_function(self, function_name: str): + """Analyze specific function.""" + if not self.analyzer.graph_analysis: + print("Graph analysis not available.") + return + + func_analysis = self.analyzer.graph_analysis.get_function_analysis(function_name) + + if "error" in func_analysis: + print(f"Error: {func_analysis['error']}") + return + + print(f"\n🔧 FUNCTION ANALYSIS: {function_name}") + print("-" * 50) + print(f"File: {os.path.basename(func_analysis.get('file_path', ''))}") + print(f"Line: {func_analysis.get('line_number', 'N/A')}") + print(f"Parameters: {', '.join(func_analysis.get('parameters', []))}") + print(f"Async: {'Yes' if func_analysis.get('is_async', False) else 'No'}") + print(f"Decorators: {', '.join(func_analysis.get('decorators', [])) or 'None'}") + print(f"Complexity: {func_analysis.get('complexity', 0):.2f}") + print(f"Usage Count: {func_analysis.get('usages', 0)}") + + # Show function-specific errors + if hasattr(self.analyzer, 'last_results'): + all_errors = self.analyzer.last_results.get("errors", []) + func_errors = [e for e in all_errors if function_name in e.message or e.line_number == func_analysis.get('line_number')] + + if func_errors: + print(f"\n🚨 ISSUES IN THIS FUNCTION: {len(func_errors)}") + for error in func_errors: + print(f" • {error.message} [{error.tool_source}]") + + def _export_report(self, format_type: str): + """Export analysis report.""" + if not hasattr(self.analyzer, 'last_results'): + print("No analysis results to export.") + return + + generator = ReportGenerator(self.analyzer.last_results) + timestamp = time.strftime("%Y%m%d_%H%M%S") + + if format_type == "json": + filename = f"analysis_report_{timestamp}.json" + content = generator.generate_json_report() + elif format_type == "html": + filename = f"analysis_report_{timestamp}.html" + content = generator.generate_html_report() + else: + print(f"Unsupported format: {format_type}. 
Use 'json' or 'html'.") + return + + try: + with open(filename, 'w', encoding='utf-8') as f: + f.write(content) + print(f"Report exported to: {filename}") + except Exception as e: + print(f"Failed to export report: {e}") + + +def main(): + """Main entry point with comprehensive argument parsing.""" + parser = argparse.ArgumentParser( + description="Comprehensive Python Code Analysis with Graph-Sitter Integration", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python analysis.py --target /path/to/project --comprehensive + python analysis.py --target file.py --interactive --verbose + python analysis.py --target /path/to/project --graph-sitter --format html + python analysis.py --target /path/to/project --security-focus --export-db + """ + ) + + # Core arguments + parser.add_argument( + "--target", required=True, + help="Path to file or directory to analyze" + ) + parser.add_argument( + "--verbose", action="store_true", + help="Enable verbose output" + ) + parser.add_argument( + "--config", + help="Path to configuration file (JSON/YAML)" + ) + + # Analysis modes + parser.add_argument( + "--comprehensive", action="store_true", + help="Run comprehensive analysis with all tools" + ) + parser.add_argument( + "--graph-sitter", action="store_true", + help="Enable graph-sitter analysis" + ) + parser.add_argument( + "--interactive", action="store_true", + help="Start interactive analysis session" + ) + parser.add_argument( + "--lsp", action="store_true", + help="Enable LSP diagnostics" + ) + + # Tool selection + parser.add_argument( + "--tools", nargs="+", + choices=list(ComprehensiveAnalyzer.COMPREHENSIVE_TOOLS.keys()), + help="Specific tools to run" + ) + parser.add_argument( + "--exclude", nargs="+", + choices=list(ComprehensiveAnalyzer.COMPREHENSIVE_TOOLS.keys()), + help="Tools to exclude" + ) + parser.add_argument( + "--categories", nargs="+", + choices=[cat.value for cat in ErrorCategory], + help="Specific error categories to focus on" + ) + + # Focus modes + parser.add_argument( + "--security-focus", action="store_true", + help="Focus on security analysis" + ) + parser.add_argument( + "--performance-focus", action="store_true", + help="Focus on performance analysis" + ) + parser.add_argument( + "--type-focus", action="store_true", + help="Focus on type checking" + ) + + # Output options + parser.add_argument( + "--format", choices=["text", "json", "html"], + default="html", + help="Output format" + ) + parser.add_argument( + "--output", + help="Output file path" + ) + parser.add_argument( + "--export-db", action="store_true", + help="Export results to SQLite database" + ) + + # Performance options + parser.add_argument( + "--timeout", type=int, default=300, + help="Default timeout for tools" + ) + parser.add_argument( + "--parallel", type=int, default=3, + help="Number of parallel tool executions" + ) + parser.add_argument( + "--cache", action="store_true", + help="Enable analysis caching" + ) + + args = parser.parse_args() + + # Setup logging + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig( + level=log_level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + # Load configuration + config = {} + if args.config and os.path.exists(args.config): + try: + with open(args.config, 'r') as f: + if args.config.endswith('.json'): + config = json.load(f) + elif args.config.endswith(('.yaml', '.yml')): + config = yaml.safe_load(f) + except Exception as e: + print(f"Warning: Failed to load config file: {e}") + + # 
Apply focus modes + if args.security_focus: + config["focus"] = "security" + args.tools = ["ruff", "bandit", "safety", "semgrep", "dlint"] + elif args.performance_focus: + config["focus"] = "performance" + args.tools = ["radon", "xenon", "vulture", "py-spy"] + elif args.type_focus: + config["focus"] = "type" + args.tools = ["mypy", "pyright", "ruff"] + + # Apply tool filters + if args.tools or args.exclude: + tool_config = {} + for tool_name in ComprehensiveAnalyzer.COMPREHENSIVE_TOOLS: + enabled = True + if args.tools: + enabled = tool_name in args.tools + if args.exclude: + enabled = enabled and tool_name not in args.exclude + tool_config[tool_name] = {"enabled": enabled} + config["tools"] = tool_config + + # Initialize analyzer + print("Initializing comprehensive analyzer...") + analyzer = ComprehensiveAnalyzer(args.target, config, args.verbose) + + if args.interactive: + # Run initial analysis then start interactive session + print("Running initial analysis...") + results = analyzer.run_comprehensive_analysis() + analyzer.last_results = results + + # Start interactive session + interactive = InteractiveAnalyzer(analyzer) + interactive.start_interactive_session() + + else: + # Run analysis + print("Running comprehensive analysis...") + results = analyzer.run_comprehensive_analysis() + + # Store in database if requested + if args.export_db: + db_path = f"analysis_{int(time.time())}.db" + db = ErrorDatabase(db_path) + session_id = db.create_session( + args.target, + results.get("metadata", {}).get("tools_used", []), + config + ) + db.store_errors(results.get("errors", []), session_id) + db.update_session(session_id, len(results.get("errors", []))) + print(f"Results stored in database: {db_path}") + + # Generate report + generator = ReportGenerator(results) + + if args.format == "json": + output = generator.generate_json_report() + elif args.format == "html": + output = generator.generate_html_report() + else: + output = generator.generate_terminal_report() + + # Determine output path + if args.output: + output_path = args.output + else: + timestamp = time.strftime("%Y%m%d_%H%M%S") + target_name = os.path.basename(args.target) + output_path = f"analysis_{target_name}_{timestamp}.{args.format}" + + # Write output + try: + with open(output_path, 'w', encoding='utf-8') as f: + f.write(output) + print(f"\nAnalysis complete! 
Report saved to: {output_path}") + + # Show summary in terminal + if args.format != "text": + print(generator.generate_terminal_report()) + + # Open HTML report in browser + if args.format == "html": + try: + import webbrowser + webbrowser.open(f"file://{os.path.abspath(output_path)}") + except Exception: + pass + + except Exception as e: + print(f"Error writing output: {e}") + # Fallback to terminal output + print(generator.generate_terminal_report()) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Libraries/graph_sitter_lib/autogenlib_adapter.py b/Libraries/graph_sitter_lib/autogenlib_adapter.py new file mode 100644 index 00000000..145e8eb8 --- /dev/null +++ b/Libraries/graph_sitter_lib/autogenlib_adapter.py @@ -0,0 +1,1130 @@ +#!/usr/bin/env python3 +"""AutoGenLib Adapter - Consolidated Module + +Provides comprehensive AutoGenLib integration with: +- Context enrichment for AI-driven code analysis +- AI-powered error resolution +- Batch processing capabilities +- Fix validation and strategy generation + +This module consolidates functionality from: +- autogenlib_context.py: Context gathering and enrichment +- autogenlib_ai_resolve.py: AI-driven error resolution +""" + +import json +import logging +import os +import time +from typing import Any + +import openai + +from graph_sitter import Codebase +from graph_sitter.extensions.autogenlib._cache import get_all_modules, get_cached_code, get_cached_prompt +from graph_sitter.extensions.autogenlib._caller import get_caller_info +from graph_sitter.extensions.autogenlib._context import extract_defined_names, get_module_context +from graph_sitter.extensions.autogenlib._generator import ( + extract_python_code, + get_codebase_context as get_autogenlib_codebase_context, + validate_code, +) +from graph_sitter.extensions.lsp.solidlsp.lsp_protocol_handler.lsp_types import Diagnostic +from graph_sitter_analysis import GraphSitterAnalyzer +from lsp_diagnostics import EnhancedDiagnostic + +logger = logging.getLogger(__name__) + + +# ================================================================================ +# CONTEXT ENRICHMENT FUNCTIONS +# ================================================================================ + + + +def get_llm_codebase_overview(codebase: Codebase) -> dict[str, str]: + """Provides a high-level summary of the entire codebase for the LLM.""" + analyzer = GraphSitterAnalyzer(codebase) + overview = analyzer.get_codebase_overview() + return {"codebase_overview": overview.get("summary", "No specific codebase overview available.")} + + +def get_comprehensive_symbol_context(codebase: Codebase, symbol_name: str, filepath: str | None = None) -> dict[str, Any]: + """Get comprehensive context for a symbol using all available Graph-Sitter APIs.""" + analyzer = GraphSitterAnalyzer(codebase) + + # Get symbol details + symbol_details = analyzer.get_symbol_details(symbol_name, filepath) + + # Get extended context using reveal_symbol + reveal_info = analyzer.reveal_symbol_relationships(symbol_name, filepath=filepath, max_depth=3, max_tokens=2000) + + # Get function-specific details if it's a function + function_details = None + if symbol_details.get("error") is None and symbol_details.get("symbol_type") == "Function": + function_details = analyzer.get_function_details(symbol_name, filepath) + + # Get class-specific details if it's a class + class_details = None + if symbol_details.get("error") is None and symbol_details.get("symbol_type") == "Class": + class_details = 
analyzer.get_class_details(symbol_name, filepath) + + return { + "symbol_details": symbol_details, + "reveal_info": reveal_info, + "function_details": function_details, + "class_details": class_details, + "extended_dependencies": reveal_info.dependencies if reveal_info.dependencies else [], + "extended_usages": reveal_info.usages if reveal_info.usages else [], + } + + +def get_file_context(codebase: Codebase, filepath: str) -> dict[str, Any]: + """Get comprehensive context for a file.""" + analyzer = GraphSitterAnalyzer(codebase) + + # Get file details + file_details = analyzer.get_file_details(filepath) + + # Get import relationships + import_analysis = analyzer.analyze_import_relationships(filepath) + + # Get directory listing for context + directory_path = os.path.dirname(filepath) or "./" + directory_info = analyzer.list_directory_contents(directory_path, depth=1) + + # View file content with line numbers + file_view = analyzer.view_file_content(filepath, line_numbers=True, max_lines=100) + + return { + "file_details": file_details, + "import_analysis": import_analysis, + "directory_context": directory_info, + "file_preview": file_view, + "related_files": [imp["imported_by"] for imp in import_analysis.get("inbound_imports", [])] if import_analysis.get("error") is None else [], + } + + +def get_autogenlib_enhanced_context(enhanced_diagnostic: EnhancedDiagnostic) -> dict[str, Any]: + """Get enhanced context using AutoGenLib's context retrieval capabilities.""" + # Get caller context from AutoGenLib + caller_info = get_caller_info() + + # Get module context if available + module_name = enhanced_diagnostic["relative_file_path"].replace("/", ".").replace(".py", "") + module_context = get_module_context(module_name) + + # Get AutoGenLib's internal codebase context + autogenlib_codebase_context = get_autogenlib_codebase_context() + + # Get all cached modules for broader context + all_cached_modules = get_all_modules() + + # Extract defined names from the file + defined_names = extract_defined_names(enhanced_diagnostic["file_content"]) + + # Get cached code and prompts + cached_code = get_cached_code(module_name) + cached_prompt = get_cached_prompt(module_name) + + return { + "caller_info": { + "filename": caller_info.get("filename", "unknown"), + "code": caller_info.get("code", ""), + "code_length": len(caller_info.get("code", "")), + "relevant_snippets": _extract_relevant_code_snippets(caller_info.get("code", ""), enhanced_diagnostic), + }, + "module_context": { + "module_name": module_name, + "defined_names": list(defined_names), + "cached_code": cached_code or "", + "cached_prompt": cached_prompt or "", + "has_cached_context": bool(module_context), + "module_dependencies": _analyze_module_dependencies(module_name, all_cached_modules), + }, + "autogenlib_codebase_context": autogenlib_codebase_context, + "cached_modules_overview": { + "total_modules": len(all_cached_modules), + "module_names": list(all_cached_modules.keys()), + "related_modules": _find_related_modules(module_name, all_cached_modules), + }, + "file_analysis": { + "defined_names_count": len(defined_names), + "file_size": len(enhanced_diagnostic["file_content"]), + "line_count": len(enhanced_diagnostic["file_content"].splitlines()), + "import_statements": _count_import_statements(enhanced_diagnostic["file_content"]), + "function_definitions": _count_function_definitions(enhanced_diagnostic["file_content"]), + "class_definitions": _count_class_definitions(enhanced_diagnostic["file_content"]), + }, + } + + +def 
get_ai_fix_context(enhanced_diagnostic: EnhancedDiagnostic, codebase: Codebase) -> EnhancedDiagnostic: + """Aggregates all relevant context for the AI to resolve a diagnostic. + This is the central context aggregation function. + """ + # 1. Get Graph-Sitter context + diag = enhanced_diagnostic["diagnostic"] + + # Find symbol at diagnostic location + symbol_at_error = None + try: + file_obj = codebase.get_file(enhanced_diagnostic["relative_file_path"]) + + # Try to find function containing the error + for func in file_obj.functions: + if hasattr(func, "start_point") and hasattr(func, "end_point") and func.start_point.line <= diag.range.line <= func.end_point.line: + symbol_at_error = func + break + + # Try to find class containing the error if no function found + if not symbol_at_error: + for cls in file_obj.classes: + if hasattr(cls, "start_point") and hasattr(cls, "end_point") and cls.start_point.line <= diag.range.line <= cls.end_point.line: + symbol_at_error = cls + break + + except Exception as e: + logger.warning(f"Could not find symbol at error location: {e}") + + # Get comprehensive symbol context if found + symbol_context = {} + if symbol_at_error: + symbol_context = get_comprehensive_symbol_context(codebase, symbol_at_error.name, enhanced_diagnostic["relative_file_path"]) + + # Get file context + file_context = get_file_context(codebase, enhanced_diagnostic["relative_file_path"]) + + # Get codebase overview + codebase_overview = get_llm_codebase_overview(codebase) + + # 2. Get AutoGenLib enhanced context + autogenlib_context = get_autogenlib_enhanced_context(enhanced_diagnostic) + + # 3. Analyze related patterns using Graph-Sitter + analyzer = GraphSitterAnalyzer(codebase) + + # Find similar errors in the codebase + similar_patterns = [] + if diag.code: + # Look for other diagnostics with the same code + for other_file in codebase.files: + if other_file.filepath != enhanced_diagnostic["relative_file_path"]: + # This is a simplified pattern matching - in practice, you'd want more sophisticated analysis + if diag.code.lower() in other_file.source.lower(): + similar_patterns.append({"file": other_file.filepath, "pattern": diag.code, "confidence": 0.6, "line_count": len(other_file.source.splitlines())}) + + # 4. Get architectural context + architectural_context = { + "file_role": _determine_file_role(enhanced_diagnostic["relative_file_path"]), + "module_dependencies": len(file_context.get("import_analysis", {}).get("imports_analysis", [])), + "is_test_file": "test" in enhanced_diagnostic["relative_file_path"].lower(), + "is_main_file": enhanced_diagnostic["relative_file_path"].endswith("main.py") or enhanced_diagnostic["relative_file_path"].endswith("__main__.py"), + "directory_depth": len(enhanced_diagnostic["relative_file_path"].split(os.sep)) - 1, + "related_symbols": _find_related_symbols_in_file(codebase, enhanced_diagnostic["relative_file_path"], diag.range.line), + } + + # 5. Get error resolution context + resolution_context = { + "error_category": _categorize_error(diag), + "common_fixes": _get_common_fixes_for_error(diag), + "resolution_confidence": _estimate_resolution_confidence(diag, symbol_context), + "requires_manual_review": _requires_manual_review(diag), + "automated_fix_available": _has_automated_fix(diag), + } + + # 6. 
Aggregate all context + enhanced_diagnostic["graph_sitter_context"] = { + "symbol_context": symbol_context, + "file_context": file_context, + "codebase_overview": codebase_overview, + "similar_patterns": similar_patterns, + "architectural_context": architectural_context, + "resolution_context": resolution_context, + "visualization_data": _get_visualization_context(analyzer, symbol_at_error) if symbol_at_error else {}, + } + + enhanced_diagnostic["autogenlib_context"] = autogenlib_context + + return enhanced_diagnostic + + +def _extract_relevant_code_snippets(caller_code: str, enhanced_diagnostic: EnhancedDiagnostic) -> list[str]: + """Extract relevant code snippets from caller code.""" + if not caller_code: + return [] + + snippets = [] + lines = caller_code.split("\n") + + # Look for imports related to the diagnostic file + file_name = os.path.basename(enhanced_diagnostic["relative_file_path"]).replace(".py", "") + for i, line in enumerate(lines): + if "import" in line and file_name in line: + # Include surrounding context + start = max(0, i - 2) + end = min(len(lines), i + 3) + snippets.append("\n".join(lines[start:end])) + + # Look for function calls that might be related to the error + diag_message = enhanced_diagnostic["diagnostic"].message.lower() + for i, line in enumerate(lines): + if any(word in line.lower() for word in diag_message.split() if len(word) > 3): + start = max(0, i - 1) + end = min(len(lines), i + 2) + snippets.append("\n".join(lines[start:end])) + + return snippets[:5] # Limit to 5 most relevant snippets + + +def _analyze_module_dependencies(module_name: str, all_cached_modules: dict[str, Any]) -> dict[str, Any]: + """Analyze dependencies between cached modules.""" + dependencies = {"direct_dependencies": [], "dependent_modules": [], "circular_dependencies": []} + + if module_name not in all_cached_modules: + return dependencies + + module_code = all_cached_modules[module_name].get("code", "") + + # Find direct dependencies + for other_module, other_data in all_cached_modules.items(): + if other_module != module_name: + if f"from {other_module}" in module_code or f"import {other_module}" in module_code: + dependencies["direct_dependencies"].append(other_module) + + other_code = other_data.get("code", "") + if f"from {module_name}" in other_code or f"import {module_name}" in other_code: + dependencies["dependent_modules"].append(other_module) + + # Check for circular dependencies + for dep in dependencies["direct_dependencies"]: + if module_name in dependencies["dependent_modules"] and dep in dependencies["dependent_modules"]: + dependencies["circular_dependencies"].append(dep) + + return dependencies + + +def _find_related_modules(module_name: str, all_cached_modules: dict[str, Any]) -> list[str]: + """Find modules related to the given module.""" + related = [] + + # Find modules with similar names + base_name = module_name.split(".")[-1] + for other_module in all_cached_modules.keys(): + other_base = other_module.split(".")[-1] + if base_name in other_base or other_base in base_name: + if other_module != module_name: + related.append(other_module) + + return related[:10] # Limit to 10 most related + + +def _count_import_statements(file_content: str) -> int: + """Count import statements in file content.""" + lines = file_content.split("\n") + return sum(1 for line in lines if line.strip().startswith(("import ", "from "))) + + +def _count_function_definitions(file_content: str) -> int: + """Count function definitions in file content.""" + return 
len(re.findall(r"^\s*def\s+\w+", file_content, re.MULTILINE)) + + +def _count_class_definitions(file_content: str) -> int: + """Count class definitions in file content.""" + return len(re.findall(r"^\s*class\s+\w+", file_content, re.MULTILINE)) + + +def _determine_file_role(filepath: str) -> str: + """Determine the role of a file in the codebase architecture.""" + filepath_lower = filepath.lower() + + if "test" in filepath_lower: + return "test" + elif "main" in filepath_lower or "__main__" in filepath_lower: + return "entry_point" + elif "config" in filepath_lower or "settings" in filepath_lower: + return "configuration" + elif "model" in filepath_lower or "schema" in filepath_lower: + return "data_model" + elif "view" in filepath_lower or "template" in filepath_lower: + return "presentation" + elif "controller" in filepath_lower or "handler" in filepath_lower: + return "controller" + elif "service" in filepath_lower or "business" in filepath_lower: + return "business_logic" + elif "util" in filepath_lower or "helper" in filepath_lower: + return "utility" + elif "api" in filepath_lower or "endpoint" in filepath_lower: + return "api" + elif "__init__" in filepath_lower: + return "module_init" + else: + return "general" + + +def _find_related_symbols_in_file(codebase: Codebase, filepath: str, error_line: int) -> list[dict[str, Any]]: + """Find symbols related to the error location.""" + try: + file_obj = codebase.get_file(filepath) + related_symbols = [] + + # Find symbols near the error line + for func in file_obj.functions: + if hasattr(func, "start_point") and hasattr(func, "end_point"): + if func.start_point.line <= error_line <= func.end_point.line: + related_symbols.append( + { + "name": func.name, + "type": "function", + "distance": 0, # Contains the error + "complexity": _calculate_simple_complexity(func), + } + ) + elif abs(func.start_point.line - error_line) <= 10: + related_symbols.append({"name": func.name, "type": "function", "distance": abs(func.start_point.line - error_line), "complexity": _calculate_simple_complexity(func)}) + + # Find classes near the error line + for cls in file_obj.classes: + if hasattr(cls, "start_point") and hasattr(cls, "end_point"): + if cls.start_point.line <= error_line <= cls.end_point.line: + related_symbols.append({"name": cls.name, "type": "class", "distance": 0, "methods_count": len(cls.methods)}) + + return sorted(related_symbols, key=lambda x: x["distance"])[:5] + + except Exception as e: + logger.warning(f"Error finding related symbols: {e}") + return [] + + +def _calculate_simple_complexity(func) -> int: + """Calculate simple complexity metric.""" + if hasattr(func, "source") and func.source: + return func.source.count("if ") + func.source.count("for ") + func.source.count("while ") + 1 + return 1 + + +def _categorize_error(diagnostic: Diagnostic) -> str: + """Categorize error based on diagnostic information.""" + message = diagnostic.message.lower() + code = str(diagnostic.code).lower() if diagnostic.code else "" + + if any(keyword in message for keyword in ["import", "module", "not found"]): + return "import_error" + elif any(keyword in message for keyword in ["type", "annotation", "expected"]): + return "type_error" + elif any(keyword in message for keyword in ["syntax", "invalid", "unexpected"]): + return "syntax_error" + elif any(keyword in message for keyword in ["unused", "defined", "never used"]): + return "unused_code" + elif any(keyword in message for keyword in ["missing", "required", "undefined"]): + return "missing_definition" 
+ elif "circular" in message or "cycle" in message: + return "circular_dependency" + else: + return "general_error" + + +def _get_common_fixes_for_error(diagnostic: Diagnostic) -> list[str]: + """Get common fixes for an error category.""" + category = _categorize_error(diagnostic) + + fixes_map = { + "import_error": ["Add missing import statement", "Fix import path", "Install missing package", "Check module availability"], + "type_error": ["Add type annotations", "Fix type mismatch", "Import missing types", "Update function signature"], + "syntax_error": ["Fix syntax issues", "Check parentheses/brackets", "Fix indentation", "Remove invalid characters"], + "unused_code": ["Remove unused imports", "Remove unused variables", "Add underscore prefix for intentionally unused", "Use the variable or remove it"], + "missing_definition": ["Define missing variable/function", "Add missing import", "Check spelling", "Add default value"], + "circular_dependency": ["Refactor to break circular imports", "Move shared code to separate module", "Use dependency injection", "Reorganize module structure"], + } + + return fixes_map.get(category, ["Manual review required"]) + + +def _estimate_resolution_confidence(diagnostic: Diagnostic, symbol_context: dict[str, Any]) -> float: + """Estimate confidence in automated resolution.""" + confidence = 0.5 # Base confidence + + # Higher confidence for well-understood error types + category = _categorize_error(diagnostic) + category_confidence = {"import_error": 0.8, "unused_code": 0.9, "type_error": 0.7, "syntax_error": 0.6, "missing_definition": 0.5, "circular_dependency": 0.3} + + confidence = category_confidence.get(category, 0.5) + + # Adjust based on symbol context availability + if symbol_context and symbol_context.get("symbol_details", {}).get("error") is None: + confidence += 0.1 + + # Adjust based on error message clarity + if len(diagnostic.message) > 50: # Detailed error messages + confidence += 0.1 + + return min(1.0, confidence) + + +def _requires_manual_review(diagnostic: Diagnostic) -> bool: + """Check if error requires manual review.""" + category = _categorize_error(diagnostic) + manual_review_categories = ["circular_dependency", "missing_definition"] + + return ( + category in manual_review_categories + or "todo" in diagnostic.message.lower() + or "fixme" in diagnostic.message.lower() + or (diagnostic.severity and diagnostic.severity.value == 1) # Critical errors + ) + + +def _has_automated_fix(diagnostic: Diagnostic) -> bool: + """Check if error has available automated fix.""" + category = _categorize_error(diagnostic) + automated_categories = ["unused_code", "import_error", "type_error"] + + return category in automated_categories + + +def _get_visualization_context(analyzer: GraphSitterAnalyzer, symbol) -> dict[str, Any]: + """Get visualization context for a symbol.""" + if not symbol: + return {} + + try: + # Create blast radius visualization + blast_radius = analyzer.create_blast_radius_visualization(symbol.name) + + # Create dependency trace if it's a function + dependency_trace = {} + if hasattr(symbol, "function_calls"): # It's a function + dependency_trace = analyzer.create_dependency_trace_visualization(symbol.name) + + return { + "blast_radius": blast_radius, + "dependency_trace": dependency_trace, + "symbol_relationships": { + "usages_count": len(symbol.usages), + "dependencies_count": len(symbol.dependencies), + "complexity": analyzer._calculate_cyclomatic_complexity(symbol) if hasattr(symbol, "source") else 0, + }, + } + except Exception 
as e: + logger.warning(f"Error creating visualization context: {e}") + return {} + + +def get_error_pattern_context(codebase: Codebase, error_category: str, max_examples: int = 5) -> dict[str, Any]: + """Get context about similar error patterns in the codebase.""" + analyzer = GraphSitterAnalyzer(codebase) + + pattern_context = { + "category": error_category, + "common_causes": _get_common_causes_for_error_category(error_category), + "resolution_strategies": _get_resolution_strategies_for_error_category(error_category), + "related_files": [], + "similar_errors_count": 0, + "pattern_analysis": {}, + } + + # Search for similar patterns in the codebase + search_terms = _get_search_terms_for_error_category(error_category) + for term in search_terms: + for file_obj in codebase.files: + if hasattr(file_obj, "source") and term.lower() in file_obj.source.lower(): + pattern_context["related_files"].append({"filepath": file_obj.filepath, "matches": file_obj.source.lower().count(term.lower()), "file_role": _determine_file_role(file_obj.filepath)}) + pattern_context["similar_errors_count"] += 1 + + if len(pattern_context["related_files"]) >= max_examples: + break + + # Analyze patterns + if pattern_context["related_files"]: + file_roles = [f["file_role"] for f in pattern_context["related_files"]] + pattern_context["pattern_analysis"] = { + "most_affected_role": max(set(file_roles), key=file_roles.count), + "role_distribution": {role: file_roles.count(role) for role in set(file_roles)}, + "average_matches_per_file": sum(f["matches"] for f in pattern_context["related_files"]) / len(pattern_context["related_files"]), + } + + return pattern_context + + +def _get_common_causes_for_error_category(category: str) -> list[str]: + """Get common causes for an error category.""" + causes_map = { + "import_error": ["Missing package installation", "Incorrect import path", "Module not in PYTHONPATH", "Circular import dependencies"], + "type_error": ["Missing type annotations", "Incorrect type usage", "Type mismatch in function calls", "Generic type parameter issues"], + "syntax_error": ["Missing parentheses or brackets", "Incorrect indentation", "Invalid character usage", "Incomplete statements"], + "unused_code": ["Imports added but never used", "Variables defined but not referenced", "Functions created but not called", "Refactoring artifacts"], + "missing_definition": ["Variable used before definition", "Function called but not defined", "Missing import for used symbol", "Typo in variable/function name"], + "circular_dependency": ["Mutual dependencies between modules", "Poor module organization", "Shared state between modules", "Tight coupling between components"], + } + return causes_map.get(category, ["Unknown causes"]) + + +def _get_resolution_strategies_for_error_category(category: str) -> list[str]: + """Get resolution strategies for an error category.""" + strategies_map = { + "import_error": ["Fix import paths and module names", "Install missing dependencies", "Add modules to PYTHONPATH", "Reorganize module structure"], + "type_error": ["Add explicit type annotations", "Fix type mismatches", "Import missing type definitions", "Update function signatures"], + "syntax_error": ["Fix syntax issues automatically", "Use code formatter", "Check language syntax rules", "Validate with linter"], + "unused_code": ["Remove unused imports and variables", "Use import optimization tools", "Add underscore prefix for intentional unused", "Refactor to eliminate dead code"], + "missing_definition": ["Define missing variables 
and functions", "Add missing imports", "Fix typos in names", "Add default values where appropriate"], + "circular_dependency": ["Refactor shared code to separate module", "Use dependency injection patterns", "Reorganize module hierarchy", "Break tight coupling between modules"], + } + return strategies_map.get(category, ["Manual review and correction required"]) + + +def _get_search_terms_for_error_category(category: str) -> list[str]: + """Get search terms to find similar patterns for an error category.""" + terms_map = { + "import_error": ["import ", "from ", "ImportError", "ModuleNotFoundError"], + "type_error": ["TypeError", "def ", "class ", "->", ":"], + "syntax_error": ["SyntaxError", "def ", "class ", "if ", "for "], + "unused_code": ["import ", "from ", "def ", "="], + "missing_definition": ["NameError", "UnboundLocalError", "def ", "="], + "circular_dependency": ["import ", "from "], + } + return terms_map.get(category, []) + + + +# ================================================================================ +# AI RESOLUTION FUNCTIONS +# ================================================================================ + + + +def resolve_diagnostic_with_ai(enhanced_diagnostic: EnhancedDiagnostic, codebase: Codebase) -> dict[str, Any]: + """Generates a fix for a given LSP diagnostic using an AI model, with comprehensive context.""" + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + logger.error("OPENAI_API_KEY environment variable not set.") + return {"status": "error", "message": "OpenAI API key not configured."} + + base_url = os.environ.get("OPENAI_API_BASE_URL") + model = os.environ.get("OPENAI_MODEL", "gpt-4o") # Using gpt-4o for better code generation + + client = openai.OpenAI(api_key=api_key, base_url=base_url) + + # Prepare comprehensive context for the LLM + diag = enhanced_diagnostic["diagnostic"] + + # Construct the system message with comprehensive instructions + system_message = """ + You are an expert software engineer and code fixer with deep knowledge of software architecture, + design patterns, and best practices. Your task is to analyze code diagnostics and provide + precise, contextually-aware fixes. + + You have access to: + 1. LSP diagnostic information (static analysis) + 2. Runtime error context (if available) + 3. UI interaction error context (if available) + 4. Graph-Sitter codebase analysis (symbol relationships, dependencies, usages) + 5. AutoGenLib context (caller information, module context) + 6. Architectural context (file role, module structure) + 7. Visualization data (blast radius, dependency traces) + 8. Error pattern analysis (similar errors, resolution strategies) + + Follow these guidelines: + 1. Understand the diagnostic: Analyze the message, severity, and exact location + 2. Consider the full context: Use all provided context to understand the broader implications + 3. Identify root causes: Look beyond symptoms to find underlying issues + 4. Propose comprehensive fixes: Address not just the immediate error but related issues + 5. Maintain code quality: Ensure fixes follow best practices and coding standards + 6. 
Consider side effects: Think about how changes might affect other parts of the codebase + + Output format: Return a JSON object with: + - 'fixed_code': The corrected code (can be a snippet, function, or entire file) + - 'explanation': Detailed explanation of the fix and why it's necessary + - 'confidence': Confidence level (0.0-1.0) in the fix + - 'side_effects': Potential side effects or additional changes needed + - 'testing_suggestions': Suggestions for testing the fix + - 'related_changes': Other files or symbols that might need updates + """ + + # Construct comprehensive user prompt + user_prompt = f""" + DIAGNOSTIC INFORMATION: + ====================== + Severity: {diag.severity.name if diag.severity else "Unknown"} + Code: {diag.code} + Source: {diag.source} + Message: {diag.message} + File: {enhanced_diagnostic["relative_file_path"]} + Line: {diag.range.line + 1}, Character: {diag.range.character} + End Line: {diag.range.end.line + 1}, End Character: {diag.range.end.character} + + RELEVANT CODE SNIPPET (with '>>>' markers for the diagnostic range): + ================================================================ + ```python + {enhanced_diagnostic["relevant_code_snippet"]} + ``` + + FULL FILE CONTENT: + ================== + ```python + {enhanced_diagnostic["file_content"]} + ``` + + GRAPH-SITTER CONTEXT: + ===================== + Codebase Overview: {enhanced_diagnostic["graph_sitter_context"].get("codebase_overview", {}).get("codebase_overview", "N/A")} + + Symbol Context: {json.dumps(enhanced_diagnostic["graph_sitter_context"].get("symbol_context", {}), indent=2)} + + File Context: {json.dumps(enhanced_diagnostic["graph_sitter_context"].get("file_context", {}), indent=2)} + + Architectural Context: {json.dumps(enhanced_diagnostic["graph_sitter_context"].get("architectural_context", {}), indent=2)} + + Resolution Context: {json.dumps(enhanced_diagnostic["graph_sitter_context"].get("resolution_context", {}), indent=2)} + + Visualization Data: {json.dumps(enhanced_diagnostic["graph_sitter_context"].get("visualization_data", {}), indent=2)} + + AUTOGENLIB CONTEXT: + =================== + {json.dumps(enhanced_diagnostic["autogenlib_context"], indent=2)} + + RUNTIME CONTEXT: + ================ + Runtime Errors: {json.dumps(enhanced_diagnostic["runtime_context"], indent=2)} + + UI Interaction Context: {json.dumps(enhanced_diagnostic["ui_interaction_context"], indent=2)} + + ADDITIONAL CONTEXT: + =================== + Similar Patterns: {json.dumps(enhanced_diagnostic["graph_sitter_context"].get("similar_patterns", []), indent=2)} + + Your task is to provide a comprehensive fix for this diagnostic, considering all the context provided. + Return a JSON object with the required fields: fixed_code, explanation, confidence, side_effects, testing_suggestions, related_changes. 
+ """ + + try: + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_message}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"}, + temperature=0.1, # Keep it low for deterministic fixes + max_tokens=4000, # Increased for comprehensive responses + ) + + content = response.choices[0].message.content.strip() + fix_info = {} + try: + fix_info = json.loads(content) + except json.JSONDecodeError: + logger.exception(f"AI response was not valid JSON: {content}") + return { + "status": "error", + "message": "AI returned invalid JSON.", + "raw_response": content, + } + + fixed_code = fix_info.get("fixed_code", "") + explanation = fix_info.get("explanation", "No explanation provided.") + confidence = fix_info.get("confidence", 0.5) + side_effects = fix_info.get("side_effects", []) + testing_suggestions = fix_info.get("testing_suggestions", []) + related_changes = fix_info.get("related_changes", []) + + if not fixed_code: + return { + "status": "error", + "message": "AI did not provide fixed code.", + "explanation": explanation, + } + + # Basic validation of the fixed code + if not validate_code(fixed_code): + logger.warning("AI generated code that is not syntactically valid.") + # Attempt to extract valid code if it's wrapped in markdown + extracted_code = extract_python_code(fixed_code) + if validate_code(extracted_code): + fixed_code = extracted_code + else: + return { + "status": "warning", + "message": "AI generated code with syntax errors.", + "fixed_code": fixed_code, + "explanation": explanation, + "confidence": confidence * 0.5, # Reduce confidence for invalid code + } + + return { + "status": "success", + "fixed_code": fixed_code, + "explanation": explanation, + "confidence": confidence, + "side_effects": side_effects, + "testing_suggestions": testing_suggestions, + "related_changes": related_changes, + } + + except openai.APIError as e: + logger.exception(f"OpenAI API error: {e}") + return {"status": "error", "message": f"OpenAI API error: {e}"} + except Exception as e: + logger.exception(f"Error resolving diagnostic with AI: {e}") + return {"status": "error", "message": f"An unexpected error occurred: {e}"} + + +def resolve_runtime_error_with_ai(runtime_error: dict[str, Any], codebase: Codebase) -> dict[str, Any]: + """Resolve runtime errors using AI with full context.""" + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + return {"status": "error", "message": "OpenAI API key not configured."} + + client = openai.OpenAI(api_key=api_key, base_url=os.environ.get("OPENAI_API_BASE_URL")) + + system_message = """ + You are an expert Python developer specializing in runtime error resolution. + You have access to the full traceback, codebase context, and related information. + + Provide comprehensive fixes that: + 1. Address the immediate runtime error + 2. Add proper error handling + 3. Include defensive programming practices + 4. Consider the broader codebase impact + + Return JSON with: fixed_code, explanation, confidence, prevention_measures + """ + + user_prompt = f""" + RUNTIME ERROR: + ============== + Error Type: {runtime_error["error_type"]} + Message: {runtime_error["message"]} + File: {runtime_error["file_path"]} + Line: {runtime_error["line"]} + Function: {runtime_error["function"]} + + FULL TRACEBACK: + =============== + {runtime_error["traceback"]} + + Please provide a comprehensive fix for this runtime error. 
+ """ + + try: + response = client.chat.completions.create( + model="gpt-4o", + messages=[ + {"role": "system", "content": system_message}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"}, + temperature=0.1, + max_tokens=2000, + ) + + content = response.choices[0].message.content.strip() + return json.loads(content) + + except Exception as e: + logger.exception(f"Error resolving runtime error with AI: {e}") + return {"status": "error", "message": f"Failed to resolve runtime error: {e}"} + + +def resolve_ui_error_with_ai(ui_error: dict[str, Any], codebase: Codebase) -> dict[str, Any]: + """Resolve UI interaction errors using AI with full context.""" + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + return {"status": "error", "message": "OpenAI API key not configured."} + + client = openai.OpenAI(api_key=api_key, base_url=os.environ.get("OPENAI_API_BASE_URL")) + + system_message = """ + You are an expert frontend developer specializing in React/JavaScript error resolution. + You understand component lifecycles, state management, and user interaction patterns. + + Provide fixes that: + 1. Resolve the immediate UI error + 2. Improve user experience + 3. Add proper error boundaries + 4. Follow React best practices + + Return JSON with: fixed_code, explanation, confidence, user_impact + """ + + user_prompt = f""" + UI INTERACTION ERROR: + ==================== + Error Type: {ui_error["error_type"]} + Message: {ui_error["message"]} + File: {ui_error["file_path"]} + Line: {ui_error["line"]} + Component: {ui_error.get("component", "Unknown")} + + Please provide a comprehensive fix for this UI error. + """ + + try: + response = client.chat.completions.create( + model="gpt-4o", + messages=[ + {"role": "system", "content": system_message}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"}, + temperature=0.1, + max_tokens=2000, + ) + + content = response.choices[0].message.content.strip() + return json.loads(content) + + except Exception as e: + logger.exception(f"Error resolving UI error with AI: {e}") + return {"status": "error", "message": f"Failed to resolve UI error: {e}"} + + +def resolve_multiple_errors_with_ai( + enhanced_diagnostics: List[EnhancedDiagnostic], + codebase: Codebase, + max_fixes: int = 10, +) -> dict[str, Any]: + """Resolve multiple errors in batch using AI with pattern recognition.""" + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + return {"status": "error", "message": "OpenAI API key not configured."} + + client = openai.OpenAI(api_key=api_key, base_url=os.environ.get("OPENAI_API_BASE_URL")) + + # Group errors by category and file + error_groups = {} + for enhanced_diag in enhanced_diagnostics[:max_fixes]: + diag = enhanced_diag["diagnostic"] + file_path = enhanced_diag["relative_file_path"] + error_category = enhanced_diag["graph_sitter_context"].get("resolution_context", {}).get("error_category", "general") + + key = f"{error_category}:{file_path}" + if key not in error_groups: + error_groups[key] = [] + error_groups[key].append(enhanced_diag) + + batch_results = [] + + for group_key, group_diagnostics in error_groups.items(): + error_category, file_path = group_key.split(":", 1) + + # Create batch prompt for similar errors + system_message = f""" + You are an expert software engineer specializing in batch error resolution. + You are fixing {len(group_diagnostics)} {error_category} errors in {file_path}. 
+ + Provide a comprehensive fix that addresses all related errors efficiently. + Consider patterns and commonalities between the errors. + + Return JSON with: fixes (array of individual fixes), batch_explanation, overall_confidence + """ + + diagnostics_summary = [] + for enhanced_diag in group_diagnostics: + diag = enhanced_diag["diagnostic"] + diagnostics_summary.append( + { + "line": diag.range.line + 1, + "message": diag.message, + "code": diag.code, + "snippet": enhanced_diag["relevant_code_snippet"], + } + ) + + user_prompt = f""" + BATCH ERROR RESOLUTION: + ====================== + Error Category: {error_category} + File: {file_path} + Number of Errors: {len(group_diagnostics)} + + ERRORS TO FIX: + ============== + {json.dumps(diagnostics_summary, indent=2)} + + FULL FILE CONTENT: + ================== + ```python + {group_diagnostics[0]["file_content"]} + ``` + + CONTEXT SUMMARY: + ================ + Graph-Sitter Context: {json.dumps(group_diagnostics[0]["graph_sitter_context"], indent=2)} + AutoGenLib Context: {json.dumps(group_diagnostics[0]["autogenlib_context"], indent=2)} + + Please provide a batch fix for all these related errors. + """ + + try: + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_message}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"}, + temperature=0.1, + max_tokens=5000, + ) + + content = response.choices[0].message.content.strip() + batch_result = json.loads(content) + batch_result["group_key"] = group_key + batch_result["errors_count"] = len(group_diagnostics) + batch_results.append(batch_result) + + except Exception as e: + logger.exception(f"Error in batch resolution for {group_key}: {e}") + batch_results.append( + { + "group_key": group_key, + "status": "error", + "message": f"Batch resolution failed: {e}", + "errors_count": len(group_diagnostics), + } + ) + + return { + "status": "success", + "batch_results": batch_results, + "total_groups": len(error_groups), + "total_errors": sum(len(group) for group in error_groups.values()), + } + + +def generate_comprehensive_fix_strategy(codebase: Codebase, error_analysis: dict[str, Any]) -> dict[str, Any]: + """Generate a comprehensive fix strategy for all errors in the codebase.""" + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + return {"status": "error", "message": "OpenAI API key not configured."} + + client = openai.OpenAI(api_key=api_key, base_url=os.environ.get("OPENAI_API_BASE_URL")) + + system_message = """ + You are a senior software architect and code quality expert. + Analyze the comprehensive error analysis and create a strategic plan for fixing all issues. + + Consider: + 1. Error priorities and dependencies + 2. Optimal fixing order to minimize conflicts + 3. Architectural improvements needed + 4. Preventive measures for future errors + 5. 
Testing and validation strategies + + Return JSON with: strategy, phases, priorities, estimated_effort, risk_assessment + """ + + user_prompt = f""" + COMPREHENSIVE ERROR ANALYSIS: + ============================ + Total Errors: {error_analysis.get("total", 0)} + Critical: {error_analysis.get("critical", 0)} + Major: {error_analysis.get("major", 0)} + Minor: {error_analysis.get("minor", 0)} + + ERROR CATEGORIES: + ================= + {json.dumps(error_analysis.get("by_category", {}), indent=2)} + + ERROR PATTERNS: + =============== + {json.dumps(error_analysis.get("error_patterns", []), indent=2)} + + RESOLUTION RECOMMENDATIONS: + =========================== + {json.dumps(error_analysis.get("resolution_recommendations", []), indent=2)} + + Please create a comprehensive strategy for resolving all these errors efficiently. + """ + + try: + response = client.chat.completions.create( + model="gpt-4o", + messages=[ + {"role": "system", "content": system_message}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"}, + temperature=0.2, + max_tokens=3000, + ) + + content = response.choices[0].message.content.strip() + strategy = json.loads(content) + + return {"status": "success", "strategy": strategy, "generated_at": time.time()} + + except Exception as e: + logger.exception(f"Error generating fix strategy: {e}") + return {"status": "error", "message": f"Failed to generate strategy: {e}"} + + +def validate_fix_with_context(fixed_code: str, enhanced_diagnostic: EnhancedDiagnostic, codebase: Codebase) -> dict[str, Any]: + """Validate a fix using comprehensive context analysis.""" + validation_results = { + "syntax_valid": False, + "context_compatible": False, + "dependencies_satisfied": False, + "style_consistent": False, + "warnings": [], + "suggestions": [], + } + + # 1. Syntax validation + try: + validate_code(fixed_code) + validation_results["syntax_valid"] = True + except Exception as e: + validation_results["warnings"].append(f"Syntax error: {e}") + + # 2. Context compatibility validation + symbol_context = enhanced_diagnostic["graph_sitter_context"].get("symbol_context", {}) + if symbol_context and symbol_context.get("symbol_details", {}).get("error") is None: + # Check if fix maintains expected function signature + if "function_details" in symbol_context: + func_details = symbol_context["function_details"] + if "def " in fixed_code: + validation_results["context_compatible"] = True + else: + validation_results["warnings"].append("Fix doesn't appear to maintain function structure") + + # 3. Dependencies validation + file_context = enhanced_diagnostic["graph_sitter_context"].get("file_context", {}) + if file_context and "import_analysis" in file_context: + import_analysis = file_context["import_analysis"] + # Check if fix introduces new dependencies + for imp in import_analysis.get("imports_analysis", []): + if imp["name"] in fixed_code and not imp["is_external"]: + validation_results["dependencies_satisfied"] = True + break + + # 4. 
Style consistency validation + original_style = _analyze_code_style(enhanced_diagnostic["file_content"]) + fixed_style = _analyze_code_style(fixed_code) + + if _styles_compatible(original_style, fixed_style): + validation_results["style_consistent"] = True + else: + validation_results["suggestions"].append("Consider adjusting code style to match existing patterns") + + return validation_results + + +def _analyze_code_style(code: str) -> dict[str, Any]: + """Analyze code style patterns.""" + return { + "indentation": "spaces" if " " in code else "tabs", + "quote_style": "double" if code.count('"') > code.count("'") else "single", + "line_length": max(len(line) for line in code.split("\n")) if code else 0, + "has_type_hints": "->" in code or ": " in code, + } + + +def _styles_compatible(style1: dict[str, Any], style2: dict[str, Any]) -> bool: + """Check if two code styles are compatible.""" + return style1.get("indentation") == style2.get("indentation") and style1.get("quote_style") == style2.get("quote_style") + + +import time + diff --git a/Libraries/graph_sitter_lib/codemods/canonical/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/add_function_parameter_type_annotations/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/add_function_parameter_type_annotations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py b/Libraries/graph_sitter_lib/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py new file mode 100644 index 00000000..79dfe8ca --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py @@ -0,0 +1,51 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that adds type annotations for function parameters named 'db' to be of type 'SessionLocal' from 'app.db'. The codemod should +also ensure that the necessary import statement is added if it is not already present. Include examples of the code before and after the +transformation.""", + uid="d62a3590-14ef-4759-853c-39c5cf755ce5", +) +@canonical +class AddFunctionParameterTypeAnnotations(Codemod, Skill): + """Adds type annotation for function parameters that takes in a 'db' parameter, which is a `SessionLocal` from `app.db`. + It also adds the necessary import if not already present. 
+ + Before: + ``` + def some_function(db): + pass + ``` + + After: + ``` + from app.db import SessionLocal + + def some_function(db: SessionLocal): + pass + ``` + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + # Iterate over all functions in the codebase + for function in codebase.functions: + # Check each parameter of the function + for param in function.parameters: + # Identify parameters named 'db' + if param.name == "db": + # Change the type annotation to 'SessionLocal' + param.set_type_annotation("SessionLocal") + # Ensure the necessary import is present + file = function.file + if "SessionLocal" not in [imp.name for imp in file.imports]: + file.add_import("from app.db import SessionLocal") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/add_internal_to_non_exported_components/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/add_internal_to_non_exported_components/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/add_internal_to_non_exported_components/add_internal_to_non_exported_components.py b/Libraries/graph_sitter_lib/codemods/canonical/add_internal_to_non_exported_components/add_internal_to_non_exported_components.py new file mode 100644 index 00000000..0154df6e --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/add_internal_to_non_exported_components/add_internal_to_non_exported_components.py @@ -0,0 +1,44 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a codemod that iterates through a codebase and renames all non-exported React function components by appending 'Internal' to their names. The +codemod should check each function to determine if it is a JSX component and not exported, then rename it accordingly.""", + uid="302d8f7c-c848-4020-9dea-30e8e622d709", +) +@canonical +class AddInternalToNonExportedComponents(Codemod, Skill): + """This codemod renames all React function components that are not exported from their file to be suffixed with 'Internal'. + + Example: + Before: + ``` + const Inner = () =>
<div />; + const Outer = () => <div />; + export default Outer; + ``` + After: + ``` + const InnerInternal = () => <div />; + const Outer = () => <div />
; + export default Outer; + ``` + """ + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase) -> None: + # Iterate over all files + for file in codebase.files: + for function in file.functions: + # Check if the function is a React component and is not exported + if function.is_jsx and not function.is_exported: + # Rename the function to include 'Internal' + function.rename(f"{function.name}Internal") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/bang_bang_to_boolean/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/bang_bang_to_boolean/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/bang_bang_to_boolean/bang_bang_to_boolean.py b/Libraries/graph_sitter_lib/codemods/canonical/bang_bang_to_boolean/bang_bang_to_boolean.py new file mode 100644 index 00000000..d45ad0c4 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/bang_bang_to_boolean/bang_bang_to_boolean.py @@ -0,0 +1,37 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a TypeScript codemod that transforms instances of '!!(expression)' into 'Boolean(expression)'. The codemod should search through all +TypeScript files in a codebase, using a regular expression to identify the pattern. Upon finding a match, it should replace '!!' with 'Boolean(' and +append a closing parenthesis to complete the transformation.""", + uid="d1ece8d3-7da9-4696-9288-4087737e2952", +) +@canonical +class BangBangToBoolean(Codemod, Skill): + """This codemod converts !!(expression) to Boolean(expression)""" + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase) -> None: + # Regular expression pattern as a string to find '!!' followed by an identifier or any bracketed expression + pattern = r"!!\s*(\w+|\([^\)]*\))" + + # Iterate over all files in the codebase + for file in codebase.files: + # Check if the file is a TypeScript file + if file.extension == ".ts": + # Search for the pattern in the file's source code using the string pattern + matches = file.search(pattern, include_strings=False, include_comments=False) + for match in matches: + # Replace the '!!' 
with 'Boolean(' + match.replace("!!", "Boolean(", count=1) + # Wrap the expression in closing parenthesis + match.insert_after(")", newline=False) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/built_in_type_annotation/built_in_type_annotation.py b/Libraries/graph_sitter_lib/codemods/canonical/built_in_type_annotation/built_in_type_annotation.py new file mode 100644 index 00000000..54e26d09 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/built_in_type_annotation/built_in_type_annotation.py @@ -0,0 +1,44 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that replaces type annotations from the typing module with their corresponding built-in types. The codemod should iterate +through all files in a codebase, check for imports from the typing module, remove those imports, and replace any usages of typing.List, typing.Dict, +typing.Set, and typing.Tuple with list, dict, set, and tuple respectively.""", + uid="b2cd98af-d3c5-4e45-b396-e7abf06df924", +) +@canonical +class BuiltInTypeAnnotation(Codemod, Skill): + """Replaces type annotations using typing module with builtin types. + + Examples: + typing.List -> list + typing.Dict -> dict + typing.Set -> set + typing.Tuple -> tuple + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + import_replacements = {"List": "list", "Dict": "dict", "Set": "set", "Tuple": "tuple"} + # Iterate over all files in the codebase + for file in codebase.files: + # Iterate over all imports in the file + for imported in file.imports: + # Check if the import is from the typing module and is a builtin type + if imported.module == "typing" and imported.name in import_replacements: + # Remove the type import + imported.remove() + # Iterate over all symbols that use this imported module + for usage in imported.usages: + # Replace the usage with the builtin type + if usage.match.source == imported.name: + usage.match.edit(import_replacements[imported.name]) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/change_component_tag_names/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/change_component_tag_names/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/change_component_tag_names/change_component_tag_names.py b/Libraries/graph_sitter_lib/codemods/canonical/change_component_tag_names/change_component_tag_names.py new file mode 100644 index 00000000..47c506b0 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/change_component_tag_names/change_component_tag_names.py @@ -0,0 +1,59 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a codemod that updates all instances of the JSX element to within React components in a TypeScript +codebase. 
Ensure that the new component is imported if it is not already present. The codemod should check for the existence of the + component and raise an error if it is not found.""", + uid="ab5879e3-e3ea-4231-b928-b756473f290d", +) +@canonical +class ChangeJSXElementName(Codemod, Skill): + """This codemod updates specific JSX elements inside of React components + + In particular, this: + <> + test + + + + gets updated to: + <> + test + + + + Inside of all React components in the codebase. + """ + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase): + # Grab the NewName component + PrivateRoutesContainer = codebase.get_symbol("PrivateRoutesContainer", optional=True) + if PrivateRoutesContainer is None or not PrivateRoutesContainer.is_jsx: + msg = "PrivateRoutesContainer component not found in codebase" + raise ValueError(msg) + + # Iterate over all functions in the codebase + for file in codebase.files: + # Iterate over each function in the file + for function in file.functions: + # Check if the function is a React component + if function.is_jsx: + # Iterate over all JSXElements in the React component + for element in function.jsx_elements: + # Check if the element named improperly + if element.name == "PrivateRoute": + # Update the JSXElement's name + element.set_name("PrivateRoutesContainer") + # Add the import if it doesn't exist + if not file.has_import("PrivateRoutesContainer"): + file.add_import(PrivateRoutesContainer) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/classnames_to_backtick.py b/Libraries/graph_sitter_lib/codemods/canonical/classnames_to_backtick.py new file mode 100644 index 00000000..f062981a --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/classnames_to_backtick.py @@ -0,0 +1,50 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a TypeScript codemod that converts all `className='...'` props in JSX elements to use backticks. The codemod should iterate through all files +in a codebase, identify JSX components, and for each JSX element, check its props. If a prop is named `className` and its value is not already wrapped +in curly braces, replace the quotes with backticks, updating the prop value accordingly.""", + uid="bf22f4d7-a93a-458f-be78-470c24487d4c", +) +@canonical +class ClassNamesToBackTick(Codemod, Skill): + """This Codemod converts all `classNames="..."` props in JSX elements to use backticks. + + Example: + Before: +
<div className="my-class" /> + + After: + <div className={`my-class`} />
+ + """ + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase) -> None: + # Iterate over all files in the codebase + for file in codebase.files: + # Check if the file is likely to contain JSX elements (commonly in .tsx files) + for function in file.functions: + # Check if the function is a JSX component + if function.is_jsx: + # Iterate over all JSX elements in the function + for element in function.jsx_elements: + # Access the props of the JSXElement + for prop in element.props: + # Check if the prop is named 'className' + if prop.name == "className": + # Get the current value of the prop + if not prop.value.startswith("{"): + # Replace single or double quotes with backticks + new_value = "{`" + prop.value.strip("\"'") + "`}" + # Update the attribute value + prop.set_value(new_value) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/convert_array_type_to_square_bracket/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/convert_array_type_to_square_bracket/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/convert_array_type_to_square_bracket/convert_array_type_to_square_bracket.py b/Libraries/graph_sitter_lib/codemods/canonical/convert_array_type_to_square_bracket/convert_array_type_to_square_bracket.py new file mode 100644 index 00000000..089fd344 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/convert_array_type_to_square_bracket/convert_array_type_to_square_bracket.py @@ -0,0 +1,38 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.core.expressions.generic_type import GenericType +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a TypeScript codemod that converts types from `Array` to `T[]`. The codemod should iterate through all files in a codebase, checking each +function's return type and parameters. If a return type or parameter type is of the form `Array`, it should be transformed to `T[]`. 
Ensure that the +codemod handles edge cases, such as nested Array types, appropriately.""", + uid="97184a15-5992-405b-be7b-30122556fe8b", +) +@canonical +class ConvertArrayTypeToSquareBracket(Codemod, Skill): + """This codemod converts types of the form `Array` to `T[]`, while avoiding edge cases like nested Array types""" + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase) -> None: + # Iterate over all files in the codebase + for file in codebase.files: + # Iterate over all functions in the file + for func in file.functions: + # Check if the return type is of the form Array + if (return_type := func.return_type) and isinstance(return_type, GenericType) and return_type.name == "Array": + # Array<..> syntax only allows one type argument + func.set_return_type(f"({return_type.parameters[0].source})[]") + + # Process each parameter in the function + for param in func.parameters: + if (param_type := param.type) and isinstance(param_type, GenericType) and param_type.name == "Array": + # Array<..> syntax only allows one type argument + param_type.edit(f"({param_type.parameters[0].source})[]") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/convert_attribute_to_decorator/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/convert_attribute_to_decorator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py b/Libraries/graph_sitter_lib/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py new file mode 100644 index 00000000..72eb4938 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py @@ -0,0 +1,59 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that transforms class attributes initializing specific Session objects into decorators. The codemod should iterate through +all classes in a codebase, check for attributes with values 'NullSession' or 'SecureCookieSession', import the corresponding decorators, add them to +the class, and remove the original attributes. Ensure the decorators are imported from 'src.flask.sessions'.""", + uid="b200fb43-dad4-4241-a0b2-75a6fbf5aca6", +) +@canonical +class ConvertAttributeToDecorator(Codemod, Skill): + """This converts any class attributes that initializes a set of Session objects to a decorator. + + For example, before: + + class MySession(SessionInterface): + session_class = NullSession + ... + + After: + @null_session + class MySession(SessionInterface): + ... + + That is, it deletes the attribute and adds the appropriate decorator via the `cls.add_decorator` method. + Note that `cls.file.add_import(import_str)` is the method used to add import for the decorator. 
+ """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + attr_value_to_decorator = { + "NullSession": "null_session", + "SecureCookieSession": "secure_cookie_session", + } + # Iterate over all classes in the codebase + for cls in codebase.classes: + # Check if the class contains any targeted attributes + for attribute in cls.attributes: + if attribute.right is None: + continue + + if attribute.right.source in attr_value_to_decorator: + decorator_name = attr_value_to_decorator[attribute.right.source] + # Import the necessary decorators + required_import = f"from src.flask.sessions import {decorator_name}" + cls.file.add_import(required_import) + + # Add the appropriate decorator + cls.add_decorator(f"@{decorator_name}") + # Remove the attribute + attribute.remove() diff --git a/Libraries/graph_sitter_lib/codemods/canonical/convert_comments_to_JSDoc_style/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/convert_comments_to_JSDoc_style/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/convert_comments_to_JSDoc_style/convert_comments_to_JSDoc_style.py b/Libraries/graph_sitter_lib/codemods/canonical/convert_comments_to_JSDoc_style/convert_comments_to_JSDoc_style.py new file mode 100644 index 00000000..11fe37fb --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/convert_comments_to_JSDoc_style/convert_comments_to_JSDoc_style.py @@ -0,0 +1,45 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a codemod that converts comments on exported functions and classes in a TypeScript codebase to JSDoc style. The codemod should iterate +through all functions and classes, check if they are exported, and if they lack docstrings. If comments are present and do not contain 'eslint', +escape any occurrences of '*/' in the comments to prevent breaking the JSDoc block, then convert the comments to JSDoc format. Finally, remove the +original comments after conversion.""", + uid="846a3894-b534-4de2-9810-94bc691a5687", +) +@canonical +class ConvertCommentsToJSDocStyle(Codemod, Skill): + """This codemod converts the comments on any exported function or class to JSDoc style if they aren't already in JSDoc style. 
+ + A JSDoc style comment is one that uses /** */ instead of // + + It also accounts for some common edgecases like avoiding eslint comments or comments which include a */ in them that needs to be escaped + """ + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase) -> None: + # Iterate over all functions and classes in the codebase + for symbol in codebase.functions + codebase.classes: + # Check if the symbol is exported + if symbol.is_exported: + # Check if the symbol is missing docstrings + if not symbol.docstring: + # Check if the symbol has comments + if symbol.comment: + # If eslint comments are present, skip conversion + if "eslint" not in symbol.comment.text: + # Escape any `*/` found in the comment to prevent breaking the JSDoc block + escaped_comment = symbol.comment.text.replace("*/", r"*\/") + # Convert comment to JSdoc docstrings + # symbol.set_docstring(escaped_comment, force_multiline=True) + symbol.set_docstring(escaped_comment, force_multiline=True) + symbol.comment.remove() diff --git a/Libraries/graph_sitter_lib/codemods/canonical/convert_docstring_to_google_style/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/convert_docstring_to_google_style/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/convert_docstring_to_google_style/convert_docstring_to_google_style.py b/Libraries/graph_sitter_lib/codemods/canonical/convert_docstring_to_google_style/convert_docstring_to_google_style.py new file mode 100644 index 00000000..e40cd5c5 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/convert_docstring_to_google_style/convert_docstring_to_google_style.py @@ -0,0 +1,30 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod class named `ConvertDocstringToGoogleStyle` that inherits from `Codemod` and `Skill`. The class should have a docstring +explaining its purpose: converting docstrings of functions and classes to Google style if they aren't already. The `execute` method should iterate +over the functions in a given `codebase`, check if each function has a docstring, and if so, convert it to Google style using a method +`to_google_docstring`.""", + uid="99da3cd9-6ba8-4a4e-8ceb-8c1b2a60562d", +) +@canonical +class ConvertDocstringToGoogleStyle(Codemod, Skill): + """This codemod converts docstrings on any function or class to Google docstring style if they aren't already. + + A Google docstring style is one that specifies the args, return value, and raised exceptions in a structured format. 
+ """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + for function in codebase.functions: + if (docstring := function.docstring) is not None: + function.set_docstring(docstring.to_google_docstring(function)) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/delete_unused_functions/delete_unused_functions.py b/Libraries/graph_sitter_lib/codemods/canonical/delete_unused_functions/delete_unused_functions.py new file mode 100644 index 00000000..e6eb6bb2 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/delete_unused_functions/delete_unused_functions.py @@ -0,0 +1,34 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that deletes all unused functions from a codebase. The codemod should iterate through each file in the codebase, check for +top-level functions, and remove any function that has no usages or call-sites. Ensure that the implementation follows best practices for identifying +unused functions.""", + uid="4024ceb5-54de-49de-b8f5-122ca2d3a6ee", +) +@canonical +class DeleteUnusedFunctionsCodemod(Codemod, Skill): + """This Codemod deletes all functions that are not used in the codebase (no usages). + In general, when deleting unused things, it's good practice to check both usages and call-sites, even though + call-sites should be basically a subset of usages (every call-site should correspond to a usage). + This is not always the case, however, so it's good to check both. + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + for file in codebase.files: + # Iterate over top-level functions in the file + for function in file.functions: + # Check conditions: function has no usages/call-sites + if not function.usages: + # Remove the function from the codebase when it has no call sites + function.remove() diff --git a/Libraries/graph_sitter_lib/codemods/canonical/emojify_py_files_codemod/emojify_py_files_codemod.py b/Libraries/graph_sitter_lib/codemods/canonical/emojify_py_files_codemod/emojify_py_files_codemod.py new file mode 100644 index 00000000..9bbb96c6 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/emojify_py_files_codemod/emojify_py_files_codemod.py @@ -0,0 +1,28 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that iterates over all Python files in a codebase and adds a rainbow emoji comment at the beginning of each file. The +codemod should be implemented in the `execute` function of the `EmojifyPyFilesCodemod` class, which inherits from `Codemod` and `Skill`. 
Ensure that +the new content for each file starts with the comment '#🌈' followed by the original content of the file.""", + uid="5d8f1994-7f74-42e8-aaa8-0c41ced228ef", +) +@canonical +class EmojifyPyFilesCodemod(Codemod, Skill): + """Trivial codemod to add a rainbow emoji in a comment at the beginning of all Python files.""" + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + # iterate over files + for file in codebase.files: + # add the rainbow emoji to the top of the file + new_content = "#🌈" + "\n" + file.content + file.edit(new_content) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/enum_mover/enum_mover.py b/Libraries/graph_sitter_lib/codemods/canonical/enum_mover/enum_mover.py new file mode 100644 index 00000000..d444769c --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/enum_mover/enum_mover.py @@ -0,0 +1,49 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import CodebaseType +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that iterates through all classes in a codebase, identifies subclasses of Enum, and moves them to a designated enums.py +file. Ensure that the codemod checks if the class is already in the correct file, flags it for movement if necessary, and creates the enums.py file if +it does not exist.""", + uid="55bc76e5-15d2-4da6-bac1-59b408a59be7", +) +@canonical +class EnumMover(Codemod, Skill): + """This codemod moves all enums (Enum subclasses) to a designated enums.py file within the same directory of the + file they're defined in. It ensures that the enums are moved to the correct file and creates the enums.py file if + it does not exist. Furthermore, it flags the class for movement which is necessary for splitting up the + modifications into separate pull requests. 
+ """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: CodebaseType): + # Iterate over all classes in the codebase + for cls in codebase.classes: + # Check if the class is a subclass of Enum + if cls.is_subclass_of("Enum"): + # Determine the target file path for enums.py + target_filepath = "/".join(cls.file.filepath.split("/")[:-1]) + "/enums.py" + + # Check if the current class is already in the correct enums.py file + if cls.file.filepath.endswith("enums.py"): + continue + + # Flag the class for potential movement + flag = codebase.flag_instance(symbol=cls) + if codebase.should_fix(flag): + # Check if the enums.py file exists, if not, create it + if not codebase.has_file(target_filepath): + enums_file = codebase.create_file(target_filepath, "") + else: + enums_file = codebase.get_file(target_filepath) + + # Move the enum class to the enums.py file + cls.move_to_file(enums_file) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/insert_arguments_to_decorator/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/insert_arguments_to_decorator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/insert_arguments_to_decorator/insert_arguments_to_decorator.py b/Libraries/graph_sitter_lib/codemods/canonical/insert_arguments_to_decorator/insert_arguments_to_decorator.py new file mode 100644 index 00000000..581c5552 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/insert_arguments_to_decorator/insert_arguments_to_decorator.py @@ -0,0 +1,45 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that iterates through a codebase, identifying all instances of the `@app.function` decorator. For each decorator, check if +the `cloud` and `region` arguments are present. If they are missing, append `cloud='aws'` and `region='us-east-1'` to the decorator's arguments. +Ensure that the modifications are made only when the arguments are not already included.""", + uid="de868e09-796c-421b-9efd-151f94f08aef", +) +@canonical +class InsertArgumentsToDecorator(Codemod, Skill): + """This codemod inserts the cloud and region arguments to every app.function decorator. + it decides whether to insert the arguments based on whether they are already present in the decorator. + if they are not present, it inserts them. 
+ for example: + + -@app.function(image=runner_image, secrets=[modal.Secret.from_name("aws-secret")]) + +@app.function(image=runner_image, secrets=[modal.Secret.from_name("aws-secret")], cloud="aws", region="us-east-1") + + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + # Iterate over all files in the codebase + for file in codebase.files: + # Iterate over all functions in each file + for function in file.functions: + # Check each decorator for the function + for decorator in function.decorators: + # Identify decorators that are app.function and modify them + if decorator.source.startswith("@app.function("): + # Parse the existing decorator to add or update the cloud and region parameters + # Check if 'cloud' and 'region' are already in the decorator + if "cloud=" not in decorator.source: + decorator.call.args.append('cloud="aws"') + if "region=" not in decorator.source: + decorator.call.args.append('region="us-east-1"') diff --git a/Libraries/graph_sitter_lib/codemods/canonical/invite_factory_create_params/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/invite_factory_create_params/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/invite_factory_create_params/invite_factory_create_params.py b/Libraries/graph_sitter_lib/codemods/canonical/invite_factory_create_params/invite_factory_create_params.py new file mode 100644 index 00000000..692827d3 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/invite_factory_create_params/invite_factory_create_params.py @@ -0,0 +1,69 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.core.detached_symbols.function_call import FunctionCall +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that updates calls to `InviteFactory.create`, `InviteFactory.build`, and `InviteFactory(...)` to use the `invitee` parameter +instead of `invitee_id`, `invitee['email']`, or `invitee.id`. The codemod should iterate through all files in a codebase, find the relevant function +calls, and modify the arguments accordingly. Specifically, it should replace `invitee_id` with `invitee`, and adjust the value to remove `.id` or +`['email']` as needed.""", + uid="1c43f274-e4bc-49c7-abca-8b273e9cad9a", +) +@canonical +class InviteFactoryCreateParams(Codemod, Skill): + """This codemod updates calls to InviteFactory.create, InviteFactory.build and InviteFactory(...) to use the `invitee` parameter instead of `invitee_id`, `invitee["email"]`, or `invitee.id`. + + For example: + + InviteFactory.create(invitee_id=user_deleted_recently.id) + + Becomes: + + InviteFactory.create(invitee=user_deleted_recently) + + Note that this involves grabbing the function calls by using `file.find` and `file.search` to find the function calls, and then using `FunctionCall.from_usage` to create a `FunctionCall` object from the usage. 
This is because **the current version of GraphSitter does not support finding method usages** + """ # noqa: E501 + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + # Iterate over all files + for file in codebase.files: + # Find invocations of InviteFactory.create and InviteFactory.build in the file + usages = file.find("InviteFactory.create", exact=True) # returns an Editable + usages += file.find("InviteFactory.build", exact=True) + usages += file.search(r"\bInviteFactory\(") + + # Iterate over all these function calls + for usage in usages: + # Create a function call from this `usage` + function_call = FunctionCall.from_usage(usage) + if function_call is None: + continue + + # Grab the invitee_id argument + invitee_arg = function_call.get_arg_by_parameter_name("invitee_id") + # If it exists... + if invitee_arg: + # Grab the current value + arg_value = invitee_arg.value + + # Replace the arg value with the correct value + if arg_value.endswith(".id"): + # replace `xyz.id` with `xyz` + invitee_arg.set_value(arg_value.replace(".id", "")) + elif arg_value.endswith('["email"]'): + # replace `xyz["email"]` with `xyz` + invitee_arg.set_value(arg_value.replace('["email"]', "")) + else: + continue + + # Update the arg keyword from `invitee_id` => 'invitee' + invitee_arg.rename("invitee") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/js_to_esm_codemod/js_to_esm_codemod.py b/Libraries/graph_sitter_lib/codemods/canonical/js_to_esm_codemod/js_to_esm_codemod.py new file mode 100644 index 00000000..58dc60e6 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/js_to_esm_codemod/js_to_esm_codemod.py @@ -0,0 +1,33 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python function named `execute` within a class `JsToEsmCodemod` that iterates through all files in a given `codebase`. For each file, check +if its name contains '.router'. 
If it does, convert the file to ESM format and update its filename to have a '.ts' extension, preserving the original +directory structure.""", + uid="f93122d3-f469-4740-a8bf-f53016de41b2", +) +@canonical +class JsToEsmCodemod(Codemod, Skill): + """This codemod will convert all JS files that have .router in their name to be proper ESM modules""" + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase) -> None: + # iterate all files in the codebase + for file in codebase.files: + # Check if the file is not a router file + if ".router" in file.name: + # Convert the file to ESM + file.convert_js_to_esm() + # Update filename + new_file_dir = "/".join(file.filepath.split("/")[:-1]) + new_file_name = ".".join(file.name.split(".")[:3]) + file.update_filepath(f"{new_file_dir}/{new_file_name}.ts") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/mark_as_internal_codemod/mark_as_internal_codemod.py b/Libraries/graph_sitter_lib/codemods/canonical/mark_as_internal_codemod/mark_as_internal_codemod.py new file mode 100644 index 00000000..7b5bd123 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/mark_as_internal_codemod/mark_as_internal_codemod.py @@ -0,0 +1,49 @@ +from pathlib import Path + +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a TypeScript codemod that marks functions as internal by adding the @internal tag to their docstrings. The codemod should check if a function +is only used within the same directory or subdirectory, ensuring it is not exported, re-exported, or overloaded. If the function's docstring does not +already contain the @internal tag, append it appropriately.""", + uid="fe61add3-ab41-49ec-9c26-c2d13e2647d1", +) +@canonical +class MarkAsInternalCodemod(Codemod, Skill): + """Mark all functions that are only used in the same directory or subdirectory as an internal function. + To mark function as internal by adding the @internal tag to the docstring. 
+ """ + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase) -> None: + # Check if the caller and callee are in the same directory + def check_caller_directory(caller_file: str, callee_file: str) -> bool: + caller_path = Path(caller_file).resolve() + callee_path = Path(callee_file).resolve() + return str(caller_path).startswith(str(callee_path.parent)) + + # Iterate over all the functions in the codebase + for function in codebase.functions: + # Ignore functions that are exported + if function.is_exported: + # Check if all usages of the function are in the same file + if all([check_caller_directory(caller.file.filepath, function.file.filepath) for caller in function.symbol_usages]): + # Check if function is not re-exported + if not function.is_reexported and not function.is_overload: + # Check if function is not already marked as internal + docstring = function.docstring.text if function.docstring else "" + if "@internal" not in docstring: + # Add @internal to the docstring + if function.docstring: + function.set_docstring(f"{function.docstring.text}\n\n@internal") + else: + function.set_docstring("@internal") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/mark_internal_to_module/mark_internal_to_module.py b/Libraries/graph_sitter_lib/codemods/canonical/mark_internal_to_module/mark_internal_to_module.py new file mode 100644 index 00000000..dbe77e33 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/mark_internal_to_module/mark_internal_to_module.py @@ -0,0 +1,32 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that iterates through all functions in the `app` directory of a codebase. For each function that is not private and is not +being imported anywhere, rename it to be internal by prefixing its name with an underscore. 
Ensure that the function checks the file path to confirm +it belongs to the `app` directory and uses a method to find import usages.""", + uid="cb5c6f1d-0a00-46e3-ac0d-c540ab665041", +) +@canonical +class MarkInternalToModule(Codemod, Skill): + """This codemod looks at all functions in the `app` directory and marks them as internal if they are not being imported anywhere""" + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + for function in codebase.functions: + if "app" in function.file.filepath: + # Check if the function is not internal + if not function.is_private and function.name is not None: + # Check if the function is not being imported anywhere + if not any(usage.kind in (UsageKind.IMPORTED, UsageKind.IMPORTED_WILDCARD) for usage in function.usages): + # Rename the function to be internal + function.rename("_" + function.name) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/mark_is_boolean/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/mark_is_boolean/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/mark_is_boolean/mark_is_boolean.py b/Libraries/graph_sitter_lib/codemods/canonical/mark_is_boolean/mark_is_boolean.py new file mode 100644 index 00000000..40714bcc --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/mark_is_boolean/mark_is_boolean.py @@ -0,0 +1,45 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a TypeScript codemod that renames function parameters of boolean type that do not start with 'is'. The codemod should iterate through all +files in a codebase, check each function's parameters, and if a parameter is boolean and does not start with 'is', it should be renamed to start with +'is' followed by the capitalized parameter name. Additionally, all function calls using the old parameter name should be updated to use the new name.""", + uid="e848b784-c703-4f4f-bfa4-e3876b2468d1", +) +@canonical +class MarkIsBoolean(Codemod, Skill): + """This (TypeScript) Codemod illustrates how to rename function parameters that are boolean types but do not start with 'is'. + + In a real application, you would probably also check for other valid prefixes, like `should` etc. 
+ """ + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase) -> None: + # Iterate over all files in the codebase + for file in codebase.files: + # Iterate over all functions in the file + for function in file.functions: + # Iterate over all parameters in each function + for param in function.parameters: + # Check if the parameter is a boolean type + if param.type == "boolean" or param.default in ["true", "false"]: + # Check if the parameter name does not start with 'is' + if not param.name.startswith("is"): + # Generate the new parameter name + new_name = "is" + param.name.capitalize() + # Rename the parameter and update all usages + param.rename(new_name) + # Update all function calls with the new parameter name + for call in function.call_sites: + arg = call.get_arg_by_parameter_name(param.name) + if arg: + arg.rename(new_name) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/migrate_class_attributes/migrate_class_attributes.py b/Libraries/graph_sitter_lib/codemods/canonical/migrate_class_attributes/migrate_class_attributes.py new file mode 100644 index 00000000..72c7fcd2 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/migrate_class_attributes/migrate_class_attributes.py @@ -0,0 +1,61 @@ +import logging +import textwrap + +from codemods.codemod import Codemod +from graph_sitter.core.codebase import PyCodebaseType +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + +logger = logging.getLogger(__name__) + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that migrates class attributes from a source class named 'RequestResetPassword' to a destination class named +'UserGroupsSettingsControlPanel'. The migrated attributes should be made private in the source class by renaming them with a leading underscore. +Additionally, create a hybrid property for each migrated attribute in the source class, including getter and setter methods that manage the private +attribute and maintain a copy in the source class.""", + uid="739061ae-4f4f-48eb-a825-7424417ce540", +) +@canonical +class MigrateClassAttributes(Codemod, Skill): + """Migrates class attributes from a source class to another class. + Any migrated attributes are made private in the source class. 
+ """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: PyCodebaseType) -> None: + # Get the source and destination classes + source_class = codebase.get_class("RequestResetPassword") + dest_class = codebase.get_class("UserGroupsSettingsControlPanel") + dest_attr_names = [x.name for x in dest_class.attributes] + + # Iterate over all attributes in the source class + for attribute in source_class.attributes(private=False): + # Skip attributes that are already added + if attribute.name in dest_attr_names: + continue + + # Add the attribute to the destination class (and bring its dependencies with it) + dest_class.add_attribute(attribute, include_dependencies=True) + + # Make this attribute private (_name) in the source class + attribute.rename(f"_{attribute.name}") + + # Add a "shadow copy write" to the source class + return_type = attribute.assignment.type.source if attribute.assignment.type else "None" + source_class.add_attribute_from_source(f"""{attribute.name} = hybrid_property(fget=get_{attribute.name}, fset=set_{attribute.name})""") + source_class.methods.append( + textwrap.dedent(f""" + def get_{attribute.name}(self) -> {return_type}: + return self._{attribute.name} + + def set_{attribute.name}(self, value: str) -> None: + self._{attribute.name} = value + self.copy.{attribute.name} = value + """) + ) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/move_enums_codemod/move_enums_codemod.py b/Libraries/graph_sitter_lib/codemods/canonical/move_enums_codemod/move_enums_codemod.py new file mode 100644 index 00000000..fbaffe3b --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/move_enums_codemod/move_enums_codemod.py @@ -0,0 +1,42 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that moves all enum classes from various files in a codebase to a single file named 'enums.py'. The codemod should check if +'enums.py' already exists in the current directory; if not, it should create it. 
For each enum class found, the codemod should move the class along +with its dependencies to 'enums.py' and add a back edge import to the original file.""", + uid="47e9399c-b8d5-4f39-a5cf-fd40c51620b0", +) +@canonical +class MoveEnumsCodemod(Codemod, Skill): + """Moves all enums to a file called enums.py in current directory if it doesn't already exist""" + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + for file in codebase.files: + if not file.name.endswith("enums.py"): + for cls in file.classes: + # check if the class inherits from the Enum class + if cls.is_subclass_of("Enum"): + # generate the new filename for the enums.py file + new_filename = "/".join(file.filepath.split("/")[:-1]) + "/enums.py" + + # check if the enums.py file exists + if not codebase.has_file(new_filename): + # if it doesn't exist, create a new file + dst_file = codebase.create_file(new_filename, "from enum import Enum\n\n") + else: + # if it exists, get a reference to the existing file + dst_file = codebase.get_file(new_filename) + + # move the enum class and its dependencies to the enums.py file + # add a "back edge" import to the original file + cls.move_to_file(dst_file, include_dependencies=True, strategy="add_back_edge") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/move_functions_to_new_file/move_functions_to_new_file.py b/Libraries/graph_sitter_lib/codemods/canonical/move_functions_to_new_file/move_functions_to_new_file.py new file mode 100644 index 00000000..c394b8e4 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/move_functions_to_new_file/move_functions_to_new_file.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING + +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + +if TYPE_CHECKING: + from graph_sitter.core.file import SourceFile + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that moves all functions starting with 'pylsp_' from existing files in a codebase to a new file named 'pylsp_shared.py'. +Ensure that all imports across the codebase are updated to reflect the new location of these functions. The codemod should iterate through each file +in the codebase, create the new file, and move the matching functions while including their dependencies.""", + uid="b29f6b8b-0837-4548-b770-b597bbcd3e02", +) +@canonical +class MoveFunctionsToNewFile(Codemod, Skill): + """This codemod moves functions that starts with "pylsp_" in their names to a new file called pylsp_shared.py + + When it moves them to this file, all imports across the codebase will get updated to reflect the new location. 
+ """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase): + # Create a new file for storing the functions that contain pylsp util functions + new_file: SourceFile = codebase.create_file("pylsp/pylsp_shared.py", "") + for file in codebase.files: + # Move function's name contains 'pylsp_' as a prefix + for function in file.functions: + if function.name.startswith("pylsp_"): + # Move each function that matches the criteria to the new file + function.move_to_file(new_file, include_dependencies=True, strategy="update_all_imports") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/openapi_add_response_none/openapi_add_response_none.py b/Libraries/graph_sitter_lib/codemods/canonical/openapi_add_response_none/openapi_add_response_none.py new file mode 100644 index 00000000..d0c995b7 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/openapi_add_response_none/openapi_add_response_none.py @@ -0,0 +1,75 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.core.detached_symbols.decorator import Decorator +from graph_sitter.core.symbol import Symbol +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that adds a `@xys_ns.response(200)` decorator to Flask Resource methods that lack return status codes. The codemod should +check for Flask Resource classes and their HTTP methods (GET, POST, PUT, PATCH, DELETE). If a method does not have any `@response` decorators and has +a valid return statement, the codemod should extract the namespace from the class's `@xys.route` decorator and add the `@xys_ns.response(200)` +decorator to the method.""", + uid="c1596668-8169-44b4-9e0e-b244eb7671d9", +) +@canonical +class OpenAPIAddResponseNone(Codemod, Skill): + """This one adds a `@xys_ns.response(200)` decorator to Flask Resource methods that do not contain any return status codes + + Before: + + @xyz_ns.route("/ping", methods=["GET"]) + class XYZResource(Resource): + + @decorator + def get(self): + return "pong" + + After: + + @xyz_ns.route("/ping", methods=["GET"]) + class XYZResource(Resource): + + @decorator + @xyz_ns.response(200) + def get(self): + return "pong" + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase): + def get_response_decorators(method: Symbol) -> list[Decorator]: + """Returns a list of decorators that contain the string '.response' in the source code""" + return [d for d in method.decorators if ".response" in d.source] + + def get_namespace_decorator(symbol: Symbol) -> Decorator | None: + """Returns the first decorator that contains the string '.route' in the source code""" + matches = [d for d in symbol.decorators if ".route" in d.source] + if len(matches) == 0: + return None + return matches[0] + + for cls in codebase.classes: + # Get Flask Resource classes + if cls.superclasses and any("Resource" in sc.source for sc in cls.superclasses): + for method in cls.methods: + # Filter to HTTP methods + if method.name in ("get", "post", "put", "patch", "delete"): + # Check if it has no `@response` decorators + response_decorators = 
get_response_decorators(method) + if len(response_decorators) == 0: + # Make sure it has `@xys.route` on the class + ns_decorator = get_namespace_decorator(cls) + if ns_decorator is not None: + # Check if returns a status code + if method.return_statements and not any(ret.value and ret.value.ts_node_type == "expression_list" for ret in method.return_statements): + # Extract the namespace name + ns_name = ns_decorator.source.split("@")[1].split(".")[0] + # Add the decorator + method.add_decorator(f"@{ns_name}.response(200)") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/openapi_no_reference_request/openapi_no_reference_request.py b/Libraries/graph_sitter_lib/codemods/canonical/openapi_no_reference_request/openapi_no_reference_request.py new file mode 100644 index 00000000..a69c7029 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/openapi_no_reference_request/openapi_no_reference_request.py @@ -0,0 +1,49 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.core.detached_symbols.decorator import Decorator +from graph_sitter.core.symbol import Symbol +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that adds `@my_namespace.expect(None)` to all Flask route methods (GET, POST, PUT, PATCH, DELETE) in classes ending with +'Resource' that do not access the request object. Ensure that these methods do not already have an `expect` decorator or similar decorators like +`load_with`, `use_args`, or `use_kwargs`. The codemod should also check for the presence of a namespace decorator in the class to determine the +correct namespace to use.""", + uid="5341d15f-92c7-4a3e-b409-416603dfa7f6", +) +@canonical +class OpenAPINoReferenceRequest(Codemod, Skill): + """As part of the OpenAPI typing initiative for Flask endpoints, this codemod will add `@my_namespace.expect(None)` to all Flask routes that do not interact with the request object.""" + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + request_accesses = ["request_get_json", "request.json", "request.args", "request.form", "request.files", "request", "self.request"] + + def get_namespace_decorator(symbol: Symbol) -> Decorator | None: + matches = [d for d in symbol.decorators if "_ns.route" in d.source] + if len(matches) == 0: + return None + return matches[0] + + for cls in codebase.classes: + if cls.name.endswith("Resource"): + for method in cls.methods: + if method.name in ("get", "post", "put", "patch", "delete"): + # Check if it has any request accesses + if not any([access in method.source for access in request_accesses]): + # Check if it has an existing `expect` + decorators = method.decorators + if not any([x in decorator.source for decorator in decorators for x in ["load_with", "expect", "use_args", "use_kwargs"]]): + # Make sure it has `@xys_ns.route` on the class + ns_decorator = get_namespace_decorator(cls) + if ns_decorator is not None: + ns_name = ns_decorator.source.split("@")[1].split(".")[0] + # Add the decorator + method.add_decorator(f"@{ns_name}.expect(None)") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/pascal_case_symbols/__init__.py 
b/Libraries/graph_sitter_lib/codemods/canonical/pascal_case_symbols/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/pascal_case_symbols/pascal_case_symbols.py b/Libraries/graph_sitter_lib/codemods/canonical/pascal_case_symbols/pascal_case_symbols.py new file mode 100644 index 00000000..e56e685e --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/pascal_case_symbols/pascal_case_symbols.py @@ -0,0 +1,41 @@ +from codemods.codemod import Codemod +from graph_sitter.core.class_definition import Class +from graph_sitter.core.codebase import Codebase +from graph_sitter.core.interface import Interface +from graph_sitter.core.type_alias import TypeAlias +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a TypeScript codemod that converts all Classes, Interfaces, and Types in a codebase to PascalCase. The codemod should iterate through all +symbols in the codebase, check if each symbol is a Class, Interface, or Type using the `isinstance` function, and if the symbol's name is not +capitalized, it should convert the name to PascalCase by capitalizing the first letter of each word and removing underscores. Finally, the codemod +should rename the symbol and update all references accordingly.""", + uid="bbb9e26a-7911-4b94-a4eb-207b9d32d18f", +) +@canonical +class PascalCaseSymbols(Codemod, Skill): + """This (Typescript) codemod converts all Classes, Interfaces and Types to be in PascalCase using simple logic. + + Note the use of the `isinstance(symbol, (Class | Interface | Type))` syntax to check if the symbol is a Class, Interface, or Type. + You should always use the abstract base class to check for the type of a symbol. 
+ """ + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase) -> None: + # Iterate over all symbols in the codebase + for symbol in codebase.symbols: + # Check if the symbol is a Class, Interface, or Type with `isinstance` syntax + if isinstance(symbol, (Class | Interface | TypeAlias)): + # Check if the name isn't capitalized + if not symbol.name[0].isupper(): + # Generate the PascalCase name + new_name = "".join(word.capitalize() for word in symbol.name.replace("_", " ").split()) + # Rename the symbol and update all references + symbol.rename(new_name) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/pivot_return_types/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/pivot_return_types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/pivot_return_types/pivot_return_types.py b/Libraries/graph_sitter_lib/codemods/canonical/pivot_return_types/pivot_return_types.py new file mode 100644 index 00000000..e7513c14 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/pivot_return_types/pivot_return_types.py @@ -0,0 +1,49 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that transforms all functions returning a string type to return a custom FastStr type instead. The codemod should iterate +through the codebase, check for functions with a return type of 'str', update the return type to 'FastStr', add the necessary import statement for +FastStr, and modify all return statements to wrap the returned value in the FastStr constructor.""", + uid="a357f5c4-2ff0-4fb2-a5c6-be051428604a", +) +@canonical +class PivotReturnTypes(Codemod, Skill): + """This codemod allows us to take all functions that return str and safely convert it to a custom FastStr type. + It does so by wrapping the return statement value in the CustomStr constructor and update the return type annotation. + + def f() -> str: + ... + return content + + Becomes + + def f() -> FastStr: + ... 
+ return FastStr(str=content) + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + # Iterate over all functions in the codebase + for function in codebase.functions: + # Check if the function's return type annotation is 'str' + if (return_type := function.return_type) and return_type.source == "str": + # Update the return type to 'FastStr' + function.set_return_type("FastStr") + + # Add import for 'FastStr' if it doesn't exist + function.file.add_import("from app.models.fast_str import FastStr") + + # Modify all return statements within the function + for return_stmt in function.code_block.return_statements: + # Wrap return statements with FastStr constructor + return_stmt.set_value(f"FastStr(str={return_stmt.value})") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/refactor_react_components_into_separate_files/refactor_react_components_into_separate_files.py b/Libraries/graph_sitter_lib/codemods/canonical/refactor_react_components_into_separate_files/refactor_react_components_into_separate_files.py new file mode 100644 index 00000000..d21d0bf7 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/refactor_react_components_into_separate_files/refactor_react_components_into_separate_files.py @@ -0,0 +1,48 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python function that refactors React components in a codebase. The function should iterate through all files, identify React function +components, and separate non-default exported components into new files. Ensure that the new files are named after the components and that all imports +are updated accordingly. Include necessary error handling and commit changes to the codebase after each move.""", + uid="b64406f4-a670-4d65-8356-c6db25c4f4b7", +) +@canonical +class RefactorReactComponentsIntoSeparateFiles(Codemod, Skill): + """This codemod breaks up JSX/TSX files by moving components that aren't exported by default + into separate files. 
+ """ + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase) -> None: + # Iterate over all files in the codebase + for file in codebase.files: + # Find all React function components in the file + react_components = [func for func in file.functions if func.is_jsx and func.name is not None] + + # Identify the default exported component + default_component = next((comp for comp in react_components if comp.is_exported and comp.export.is_default_export()), None) + if default_component is None: + continue + + # Move non-default components to new files + for component in react_components: + if component != default_component and component in file.symbols: + # Create a new file for the component + new_file_path = "/".join(file.filepath.split("/")[:-1]) + "/" + component.name + ".tsx" + if not codebase.has_file(new_file_path): + new_file = codebase.create_file(new_file_path) + + # Move the component to the new file and update all imports + component.move_to_file(new_file, strategy="update_all_imports") + + # Commit is NECESSARY since subsequent steps depend on current symbol locations + codebase.commit() diff --git a/Libraries/graph_sitter_lib/codemods/canonical/remove_indirect_imports/remove_indirect_imports.py b/Libraries/graph_sitter_lib/codemods/canonical/remove_indirect_imports/remove_indirect_imports.py new file mode 100644 index 00000000..a205de61 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/remove_indirect_imports/remove_indirect_imports.py @@ -0,0 +1,53 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.core.external_module import ExternalModule +from graph_sitter.core.import_resolution import Import +from graph_sitter.core.symbol import Symbol +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python function named `execute` within a class `RemoveIndirectImports` that processes a codebase to remove all indirect imports. The +function should iterate through all files in the codebase, check each import to determine if it points to another import, and replace it with a direct +import. Handle cases where the resolved import is either an external module or a symbol, ensuring that the import is updated accordingly.""", + uid="0648c80e-a569-4aa5-b241-38a2dd320e9a", +) +@canonical +class RemoveIndirectImports(Codemod, Skill): + """This codemod removes all indirect imports from a codebase (i.e. an import that points to another import), + replacing them instead with direct imports + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + # iterate over all files -> imports + for file in codebase.files: + for original_import in file.imports: + # Grab the symbol being imported + imported_symbol = original_import.imported_symbol + + # Check if the symbol being imported is itself import + if isinstance(imported_symbol, Import): + # We've found an import that points to another import which means it's an indirect import! 
+ # Get the symbol that the import eventually resolves to + imported_symbol = original_import.resolved_symbol + + # Case: we can't find the final destination symbol + if imported_symbol is None: + continue + + # Case: the resolved import is an external module. + elif isinstance(imported_symbol, ExternalModule): + original_import.edit(imported_symbol.source) + + # Case: the resolved import is Symbol. + elif isinstance(imported_symbol, Symbol): + # Replace the module in the import with the final destination symbol's module + # e.g. `from abc import ABC` -> `from xyz import ABC` or equivalent in your language. + original_import.set_import_module(imported_symbol.file.import_module_name) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/rename_function_parameters/rename_function_parameters.py b/Libraries/graph_sitter_lib/codemods/canonical/rename_function_parameters/rename_function_parameters.py new file mode 100644 index 00000000..03586ec9 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/rename_function_parameters/rename_function_parameters.py @@ -0,0 +1,33 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that iterates through all files in a codebase, identifies function parameters containing the substring 'obj', and renames +them to 'new_obj'. The codemod should be structured as a class that inherits from Codemod and Skill, with an execute method that performs the +renaming operation.""", + uid="1576b2fd-8a00-44e4-9659-eb0f585e015a", +) +@canonical +class RenameFunctionParameters(Codemod, Skill): + """This takes all functions that renames any parameter that contains 'obj' and replaces with 'new_obj'""" + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + # Iterate over all files + for file in codebase.files: + for function in file.functions: + # Search for parameter names that contain 'obj' + params_to_rename = [p for p in function.parameters if "obj" in p.name] + if params_to_rename: + # Rename the parameters + for param in params_to_rename: + new_param_name = param.name.replace("obj", "new_obj") + param.rename(new_param_name) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/rename_local_variables/rename_local_variables.py b/Libraries/graph_sitter_lib/codemods/canonical/rename_local_variables/rename_local_variables.py new file mode 100644 index 00000000..649a6322 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/rename_local_variables/rename_local_variables.py @@ -0,0 +1,48 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that iterates through a codebase, identifying functions with local variables containing the name 'position'. 
For each +identified function, rename all occurrences of the local variable 'position' to 'pos', ensuring that the renaming is applied to all relevant usages +within the function.""", + uid="79c10c00-bbce-4bdb-8c39-d91586307a2b", +) +@canonical +class RenameLocalVariables(Codemod, Skill): + """This codemod renames all local variables in functions that contain 'position' to 'pos' + + Example: + Before: + ``` + def some_function(x, y, position): + position_x = x + position + position_y = y + position + return position_x, position_y + ``` + After: + ``` + def some_function(x, y, position): + pos_x = x + position + pos_y = y + position + return pos_x, pos_y + ``` + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + # iterate over files + for file in codebase.files: + for function in file.functions: + # Check if any local variable names contain "position" + position_usages = function.code_block.get_variable_usages("position", fuzzy_match=True) + if len(position_usages) > 0: + # Rename + function.rename_local_variable("position", "pos", fuzzy_match=True) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/replace_prop_values/replace_prop_values.py b/Libraries/graph_sitter_lib/codemods/canonical/replace_prop_values/replace_prop_values.py new file mode 100644 index 00000000..5bd84415 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/replace_prop_values/replace_prop_values.py @@ -0,0 +1,36 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a TypeScript codemod that iterates through a codebase, identifies JSX functions, and replaces any occurrences of the prop value 'text-center' +with 'text-left' in all JSX elements.""", + uid="c1914552-556b-4ae0-99f0-33cb7bfb702e", +) +@canonical +class ReplacePropValues(Codemod, Skill): + """Replaces any JSX props with text-center to text-left""" + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase) -> None: + # Iterate over all files in the codebase + for file in codebase.files: + # Iterate over all functions in the file + for function in file.functions: + # Filter for JSX functions + if function.is_jsx: + # Iterate over all JSX elements in the function + for jsx_element in function.jsx_elements: + # Iterate over all the props of the component + for prop in jsx_element.props: + # Check if prop has a value + if prop.value: + # Replace text-center with text-left + prop.value.replace("text-center", "text-left") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/return_none_type_annotation/return_none_type_annotation.py b/Libraries/graph_sitter_lib/codemods/canonical/return_none_type_annotation/return_none_type_annotation.py new file mode 100644 index 00000000..ea6d28bc --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/return_none_type_annotation/return_none_type_annotation.py @@ -0,0 +1,36 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import 
ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that iterates through all functions and methods in a codebase. For each function or method that lacks return statements and +a return type annotation, set the return type to 'None'. Ensure the implementation handles both standalone functions and methods within classes.""", + uid="fcac16ed-a915-472a-9dfe-1562452d9ab3", +) +@canonical +class ReturnNoneTypeAnnotation(Codemod, Skill): + """This codemod sets the return type of functions that do not have any return statements""" + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + # Iterate over all functions in the codebase + for function in codebase.functions: + # Look at ones that do not have return statements and no return type annotation + if len(function.return_statements) == 0 and not function.return_type: + # Set the return type to None + function.set_return_type("None") + + # Do the same for methods (have to call it `cls`, not `class`, since `class` is a reserved keyword) + for cls in codebase.classes: + for method in cls.methods: + # Look at ones that do not have return statements and no return type annotation + if len(method.return_statements) == 0 and not method.return_type: + # Set the return type to None + method.set_return_type("None") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/split_decorators/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/split_decorators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/split_decorators/split_decorators.py b/Libraries/graph_sitter_lib/codemods/canonical/split_decorators/split_decorators.py new file mode 100644 index 00000000..57d01927 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/split_decorators/split_decorators.py @@ -0,0 +1,52 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that transforms a single decorator call into multiple calls. The codemod should iterate through all classes in a codebase, +identify decorators matching the pattern '@generic_repr', and replace them with separate decorators for each argument passed to the original +decorator. Ensure that the original decorator's ordering is preserved by editing in-place.""", + uid="3f6325b8-02c3-4d90-a726-830f8bccce3a", +) +@canonical +class SplitDecorators(Codemod, Skill): + """This codemod splits a single decorator call into multiple + + For example: + @generic_repr("id", "name", "email") + def f(): + ... + + Becomes: + @generic_repr("id") + @generic_repr("name") + @generic_repr("email") + def f(): + ... + + Note that we edit the original decorator in-place (`decorator.edit(...)`), so as to keep the original decorator's ordering! + + If we instead did `add_decorator` etc., we would have to figure out where to insert the new decorators. 
+ """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + # Iterate over all classes in the codebase + for cls in codebase.classes: + # Find all decorators of the function that match the pattern for `@allow_update` - this is a list of Decorator instances with '{' in the source + target_decorators = [decorator for decorator in cls.decorators if "@generic_repr" in decorator.source] + for decorator in target_decorators: + new_decorators = [] + for arg in decorator.call.args: + new_decorator_source = f"@generic_repr({arg})" + new_decorators.append(new_decorator_source) + + # Remove the original decorator as it will be replaced + decorator.edit("\n".join(new_decorators), fix_indentation=True) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/split_file/split_file.py b/Libraries/graph_sitter_lib/codemods/canonical/split_file/split_file.py new file mode 100644 index 00000000..bbdbdeb3 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/split_file/split_file.py @@ -0,0 +1,38 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that splits a large file by moving all subclasses of 'Enum' from 'sqlglot/optimizer/scope.py' to a new file named +'sqlglot/optimizer/enums.py'. The codemod should check if the large file exists, raise a FileNotFoundError if it does not, and then create the new +file before iterating through the classes in the large file to move the relevant subclasses.""", + uid="a7c7388d-f473-4a37-b316-e881079fe093", +) +@canonical +class SplitFile(Codemod, Skill): + """This codemod moves symbols from one large to a new file with the goal of breaking up a large file.""" + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase): + # Grab large file to split + file = codebase.get_file("sqlglot/optimizer/scope.py", optional=True) + if file is None: + msg = "The file `sqlglot/optimizer/scope.py` was not found." 
+ raise FileNotFoundError(msg) + + # Create a new file for storing all our 'Enum' classes + new_file = codebase.create_file("sqlglot/optimizer/enums.py") + + # iterate over all classes + for cls in file.classes: + # Check inheritance + if cls.is_subclass_of("Enum"): + # Move symbol + cls.move_to_file(new_file) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/split_file_and_rename_symbols/split_file_and_rename_symbols.py b/Libraries/graph_sitter_lib/codemods/canonical/split_file_and_rename_symbols/split_file_and_rename_symbols.py new file mode 100644 index 00000000..46e7d71b --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/split_file_and_rename_symbols/split_file_and_rename_symbols.py @@ -0,0 +1,59 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import CodebaseType +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that splits a file by moving classes containing 'Configuration' to a new file named 'configuration.py'. After moving, commit +the changes to ensure the new classes are recognized. Then, rename all 'Configuration' classes in the new file to 'Config'. Finally, update the +original file's path from 'types.py' to 'schemas.py'.""", + uid="816415d9-27e8-4228-b284-1b18b3072f0d", +) +@canonical +class SplitFileAndRenameSymbols(Codemod, Skill): + """Split file and rename moved symbols + + This codemod first moves several symbols to new files and then renames them. + + This requires a codebase.commit() call between the move and the rename step. + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: CodebaseType): + # Get file to split up + source_file = codebase.get_file("redash/models/types.py", optional=True) + if source_file is None: + msg = "[1] The file `redash/models/types.py` was not found." 
+ raise FileNotFoundError(msg) + + # Get file symbols will be moved to + configuration_file = codebase.create_file("redash/models/configuration.py") + + # Move all the classes that contain with `Configuration` to the new configuration file + for cls in source_file.classes: + # Move the `_filter` functions + if "Configuration" in cls.name: + # Move the function to the filters file and rename it + # move_to_file should also take care of updating the imports of the functions, and bringing over any imports or local references the function needs + cls.move_to_file(configuration_file, include_dependencies=True, strategy="update_all_imports") + + # Commit is NECESSARY for the codebase graph to be aware of the new classes moved into configuration file + codebase.commit() + + # re-acquire the configuration file with the latest changes + configuration_file = codebase.get_file("redash/models/configuration.py") + + # rename all the `Configuration` classes to `Config` + for cls in configuration_file.classes: + if cls.name == "Configuration": + cls.rename("Config") + + # re-acquire the source file with the latest changes + source_file = codebase.get_file("redash/models/types.py") + source_file.update_filepath("redash/models/schemas.py") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/split_large_files/split_large_files.py b/Libraries/graph_sitter_lib/codemods/canonical/split_large_files/split_large_files.py new file mode 100644 index 00000000..899571a3 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/split_large_files/split_large_files.py @@ -0,0 +1,49 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a TypeScript codemod that processes a codebase to split large files. The codemod should define constants for maximum file length (500 lines) +and maximum symbol length (50 lines). It should iterate through all files in the codebase, checking if a file exceeds the maximum length. If a file +has more than 3 symbols that exceed the maximum symbol length, create a new directory for the file (removing the .ts extension) and move each long +symbol into its own new file within that directory. 
Ensure to add a back edge to the original file for each moved symbol.""", + uid="b5bbec91-5bfe-4b4b-b62e-0a1ec94089b5", +) +@canonical +class SplitLargeFiles(Codemod, Skill): + """This codemod splits all large files.""" + + language = ProgrammingLanguage.TYPESCRIPT + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) + def execute(self, codebase: Codebase): + # Define constants for maximum lengths + MAX_FILE_LENGTH = 500 + MAX_SYMBOL_LENGTH = 50 + + # Iterate over all files in the codebase + for file in codebase.files: + # Check if the file has more than the maximum file length + if len(file.content.splitlines()) > MAX_FILE_LENGTH: + # Count the number of symbols with more than the maximum symbol length + long_symbols_count = sum(1 for symbol in file.symbols if len(symbol.source.splitlines()) > MAX_SYMBOL_LENGTH) + # Proceed if there are more than 3 long symbols + if long_symbols_count > 3: + # Create a new directory for the file + dir_name = file.filepath.replace(".ts", "") + codebase.create_directory(dir_name, exist_ok=True) + # Iterate over symbols in the file + for symbol in file.symbols: + # Only move symbols that exceed the maximum symbol length + if len(symbol.source.splitlines()) > MAX_SYMBOL_LENGTH: + # Create a new file for the symbol + new_file = codebase.create_file(f"{dir_name}/{symbol.name}.ts", sync=False) + # Move the symbol to the new file + symbol.move_to_file(new_file) + # Add a back edge to the original file + file.add_import(symbol) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/swap_call_site_imports/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/swap_call_site_imports/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py b/Libraries/graph_sitter_lib/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py new file mode 100644 index 00000000..719e770f --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py @@ -0,0 +1,63 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that replaces all imports of a legacy function with its new replacement. The codemod should find all call sites of the +legacy function, update the import module to the new module, and handle the edge case where the legacy function is called within the same file it is +defined. In this case, the codemod should remove the legacy function and add an import for its replacement. The legacy function is located in +'redash/settings/helpers.py' and is named 'array_from_string'. The new import module is 'redash.settings.collections'. Include comments to explain +each step.""", + uid="8fa00be7-adad-473d-8436-fc5f70e6ac6d", +) +@canonical +class SwapCallSiteImports(Codemod, Skill): + """This codemod replaces all imports of a legacy function with its new replacement.
+ + This involves: + - Finding all the call sites of the old function + - Updating the import module of the old function import to the new module + - Edge case: legacy function is called within the same file it's defined in + - There won't be an import to the legacy function in this file (b/c it's where it's defined) + - For this case we have to both remove the legacy function and add an import to it's replacement. + + Example: + Before: + from mod import func + + func() + + After: + from new_mode import func + + func() + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + legacy_func_file = codebase.get_file("redash/settings/helpers.py") + legacy_function = legacy_func_file.get_function("array_from_string") + + # Find all call sites of the legacy function + for call_site in legacy_function.call_sites: + # Get the import of the legacy function in the call site file + legacy_import = next((x for x in call_site.file.imports if x.resolved_symbol == legacy_function), None) + + # Update the import module of the old function import to the new module + if legacy_import: + legacy_import.set_import_module("redash.settings.collections") + + # Edge case: legacy function is called within the same file it's defined in + if call_site.file == legacy_function.file: + # Remove the legacy function + legacy_function.remove() + + # Add import of the new function + call_site.file.add_import(f"from settings.collections import {legacy_function.name}") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py b/Libraries/graph_sitter_lib/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py new file mode 100644 index 00000000..24662ef9 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py @@ -0,0 +1,63 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that transfers attributes from one class to another. The codemod should rename parameters of functions that use the first +class (GraphRagConfig) to use the second class (CacheConfig) instead. It should also handle variable renaming to avoid conflicts, update function +definitions, add necessary imports, and modify function call sites accordingly.""", + uid="4a3569c2-cf58-4bdc-822b-7a5747f476ab", +) +@canonical +class SwapClassAttributeUsages(Codemod, Skill): + """This codemod takes two classes (class A and class B) and transfers one class's attributes to the other. 
+ It does this by: + - Renaming any parameters that are passing the class A and replaces it to take in class B instead + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + class_a_symb = codebase.get_symbol("GraphRagConfig") + class_b_symb = codebase.get_symbol("CacheConfig") + + for function in codebase.functions: + parameters = function.parameters + if any(p.type == class_a_symb for p in parameters): + # Rename existing instances of `cache_config`=> `cache_config_` (prevents mypy issue) + name_conflict_vars = function.code_block.get_local_var_assignments("cache_config") + for name_conflict_var in name_conflict_vars: + name_conflict_var.rename("cache_config_") + + # Get the parameter to update + class_a_param = function.get_parameter_by_type(class_a_symb) + + # Update original function definition + class_a_param.edit("cache_config: CacheConfig") + + # Add import of `CacheConfig` to function definition file + function.file.add_import(class_b_symb) + + # Check if the function body is using `cache_config` + if len(function.code_block.get_variable_usages(class_a_param.name)) > 0: + # Add "wrapper" inside the function + # This creates the `cache_config` variable internally + proxy_var_declaration = f"""{class_a_param.name} = cache_config.settings # added by Codegen""" + function.prepend_statements(proxy_var_declaration) + + # Update all callsites of original function to take in `cache_config` instead of `graph_rag_config` + fcalls = function.call_sites + for fcall in fcalls: + arg = fcall.get_arg_by_parameter_name(class_a_param.name) + if not arg: + continue + if arg.is_named: + arg.edit(f"cache_config={arg.value}.cache_config") + else: + arg.edit(f"{arg.value}.cache_config") diff --git a/Libraries/graph_sitter_lib/codemods/canonical/update_optional_type_annotations/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/update_optional_type_annotations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py b/Libraries/graph_sitter_lib/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py new file mode 100644 index 00000000..02f6ae6f --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py @@ -0,0 +1,55 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.core.expressions import Type +from graph_sitter.core.expressions.generic_type import GenericType +from graph_sitter.core.expressions.union_type import UnionType +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that updates type annotations in a codebase. The codemod should replace instances of 'Optional[X]' with 'X | None' and +handle other generic types and unions appropriately. Ensure that the codemod iterates through all files, processes functions and methods, checks for +typed parameters, and modifies their annotations as needed. 
Additionally, include an import statement for future annotations if any changes are made.""", + uid="0e2d60db-bff0-4020-bda7-f264ff6c7f46", +) +@canonical +class UpdateOptionalTypeAnnotations(Codemod, Skill): + """Replaces type annotations with builtin ones, e.g.: + def f(x: Optional[int]): + becomes + def f(x: int | None): + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + def update_type_annotation(type: Type) -> str: + if "Optional" in type.source: + if isinstance(type, GenericType): + if type.name == "Optional": + return update_type_annotation(type.parameters[0]) + " | None" + else: + return f"{type.name}[{', '.join(update_type_annotation(param) for param in type.parameters)}]" + if isinstance(type, UnionType): + return " | ".join(update_type_annotation(param) for param in type) + return type.source + + # Iterate over all files in the codebase + for file in codebase.files: + # Process standalone functions and methods within classes + for function in file.functions + [method for cls in file.classes for method in cls.methods]: + # Iterate over all parameters in the function + if function.parameters: + for parameter in function.parameters: + if parameter.is_typed: + # Check if the parameter has a type annotation + new_type = update_type_annotation(parameter.type) + if parameter.type != new_type: + # Add the future annotations import + file.add_import("from __future__ import annotations\n") + parameter.type.edit(new_type) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/update_union_types/__init__.py b/Libraries/graph_sitter_lib/codemods/canonical/update_union_types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/codemods/canonical/update_union_types/update_union_types.py b/Libraries/graph_sitter_lib/codemods/canonical/update_union_types/update_union_types.py new file mode 100644 index 00000000..e0d0a5cf --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/update_union_types/update_union_types.py @@ -0,0 +1,41 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that updates type annotations from the old Union[x, y] syntax to the new x | y syntax for migration from Python 3.9 to +Python 3.10. The codemod should iterate through all files in a codebase, check for imports of Union from typing, and replace occurrences of Union in +both generic type and subscript forms. 
Ensure that the new syntax is correctly formatted, handling cases with multiple types and removing any empty +strings from trailing commas.""", + uid="7637d11a-b907-4716-a09f-07776f81a359", +) +@canonical +class UpdateUnionTypes(Codemod, Skill): + """This updates the Union [ x , y ] syntax for x | y for migrations for python 3.9 to python 3.10""" + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + for file in codebase.files: + # Check if the file imports Union from typing + if "Union" in [imp.name for imp in file.imports]: + # Search for Union type annotations in the file + for editable in file.find("Union["): + if editable.ts_node_type == "generic_type": + new_type = editable.source.replace("Union[", "").replace("]", "", 1).replace(", ", " | ") + editable.replace(editable.source, new_type) + elif editable.ts_node_type == "subscript": + # Handle subscript case (like TypeAlias = Union[...]) + types = editable.source[6:-1].split(",") + # Remove any empty strings that might result from trailing commas + types = [t.strip() for t in types if t.strip()] + new_type = " | ".join(types) + if len(types) > 1: + new_type = f"({new_type})" + editable.replace(editable.source, new_type) diff --git a/Libraries/graph_sitter_lib/codemods/canonical/use_named_kwargs/use_named_kwargs.py b/Libraries/graph_sitter_lib/codemods/canonical/use_named_kwargs/use_named_kwargs.py new file mode 100644 index 00000000..274ab6f9 --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/use_named_kwargs/use_named_kwargs.py @@ -0,0 +1,57 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.core.external_module import ExternalModule +from graph_sitter.python.class_definition import PyClass +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a Python codemod that converts all function calls in a codebase to use named keyword arguments if they have two or more positional arguments. +The codemod should iterate through all files and functions, checking each function call to determine if it meets the criteria for conversion. Ensure +that the conversion is skipped if all arguments are already named, if there are fewer than two arguments, if the function definition cannot be found, +if the function is a class without a constructor, or if the function is part of an external module.""", + uid="1a4b9e66-1df5-4ad1-adbb-034976add8e0", +) +@canonical +class UseNamedKwargs(Codemod, Skill): + """Converts all functions to use named kwargs if there are more than >= 2 args being used. 
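
    Example (illustrative; `create_user` is a hypothetical call site):
        create_user("alice", True)              # before
        create_user(name="alice", active=True)  # after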
+ + In general you can use FunctionCall.convert_args_to_kwargs() once you have filtered properly + """ + + language = ProgrammingLanguage.PYTHON + + @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) + def execute(self, codebase: Codebase) -> None: + # Iterate over all files + for file in codebase.files: + # TODO: doesn't handle global function calls + # Iterate over all functions + for function in file.functions: + # look at the function calls + for call in function.function_calls: + # Skip if all args are already named + if all(arg.is_named for arg in call.args): + continue + + # Skip if call sites has < 2 args + if len(call.args) < 2: + continue + + # Skip if we can't find the def of the function + function_def = call.function_definition + if not function_def: + continue + + # Skip if function_def is a class and the class has no constructor + if isinstance(function_def, PyClass) and not function_def.constructor: + continue + + if isinstance(function_def, ExternalModule): + continue + + call.convert_args_to_kwargs() diff --git a/Libraries/graph_sitter_lib/codemods/canonical/wrap_with_component/wrap_with_component.py b/Libraries/graph_sitter_lib/codemods/canonical/wrap_with_component/wrap_with_component.py new file mode 100644 index 00000000..7d89321d --- /dev/null +++ b/Libraries/graph_sitter_lib/codemods/canonical/wrap_with_component/wrap_with_component.py @@ -0,0 +1,51 @@ +from codemods.codemod import Codemod +from graph_sitter.core.codebase import Codebase +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.writer_decorators import canonical +from tests.shared.skills.decorators import skill, skill_impl +from tests.shared.skills.skill import Skill + + +@skill( + canonical=True, + prompt="""Generate a codemod in TypeScript that wraps all instances of the JSX element ; + } + } + Button.propTypes = { + text: PropTypes.string.isRequired, + onClick: PropTypes.func.isRequired + }; + + // After + interface ButtonProps { + text: string; + onClick: CallableFunction; + } + + class Button extends React.Component { + render() { + return ; + } + } + ``` + """ + if self.parent_classes and len(self.parent_classes) > 0: + react_parent = self.parent_classes[0] + if "Component" in react_parent.source: + if interface_name := self.convert_to_react_interface(): + if isinstance(react_parent, GenericType): + react_parent.parameters.insert(0, interface_name) + else: + react_parent.insert_after(f"<{interface_name}>", newline=False) + + @writer + def class_component_to_function_component(self) -> None: + """Converts a class component to a function component.""" + return self.ctx.ts_declassify.declassify(self.source, filename=os.path.basename(self.file.file_path)) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/config_parser.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/config_parser.py new file mode 100644 index 00000000..9a47c2a0 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/config_parser.py @@ -0,0 +1,63 @@ +from pathlib import Path +from typing import TYPE_CHECKING + +from graph_sitter.codebase.config_parser import ConfigParser +from graph_sitter.core.file import File +from graph_sitter.enums import NodeType +from graph_sitter.typescript.ts_config import TSConfig + +if TYPE_CHECKING: + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.typescript.file import TSFile + +import os +from functools import cache + + +class TSConfigParser(ConfigParser): + 
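    """Finds and parses the tsconfig.json that governs each file in the codebase, caching TSConfig objects by absolute path and precomputing their import aliases."""
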
# Cache of path names to TSConfig objects + config_files: dict[Path, TSConfig] + ctx: "CodebaseContext" + + def __init__(self, codebase_context: "CodebaseContext", default_config_name: str = "tsconfig.json"): + super().__init__() + self.config_files = dict() + self.ctx = codebase_context + self.default_config_name = default_config_name + + def get_config(self, config_path: os.PathLike) -> TSConfig | None: + path = self.ctx.to_absolute(config_path) + if path in self.config_files: + return self.config_files[path] + if path.exists(): + self.config_files[path] = TSConfig(File.from_content(config_path, path.read_text(), self.ctx, sync=False), self) + return self.config_files.get(path) + return None + + def parse_configs(self): + # This only yields a 0.05s speedup, but its funny writing dynamic programming code + @cache + def get_config_for_dir(dir_path: Path) -> TSConfig | None: + # Check if the config file exists in the directory + ts_config_path = dir_path / self.default_config_name + # If it does, return the config + if ts_config_path.exists(): + if ts_config := self.get_config(self.ctx.to_absolute(ts_config_path)): + self.config_files[ts_config_path] = ts_config + return ts_config + # Otherwise, check the parent directory + if dir_path.is_relative_to(self.ctx.repo_path): + return get_config_for_dir(dir_path.parent) + return None + + # Get all the files in the codebase + for file in self.ctx.get_nodes(NodeType.FILE): + file: TSFile # This should be safe because we only call this on TSFiles + # Get the config for the directory the file is in + config = get_config_for_dir(file.path.parent) + # Set the config for the file + file.ts_config = config + + # Loop through all the configs and precompute their import aliases + for config in self.config_files.values(): + config._precompute_import_aliases() diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/code_block.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/code_block.py new file mode 100644 index 00000000..563bebc6 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/code_block.py @@ -0,0 +1,79 @@ +from typing import TYPE_CHECKING, Generic, Self, TypeVar + +from graph_sitter.compiled.utils import find_line_start_and_end_nodes +from graph_sitter.core.autocommit import reader, writer +from graph_sitter.core.detached_symbols.code_block import CodeBlock +from graph_sitter.core.interfaces.editable import Editable +from graph_sitter.core.statements.statement import Statement +from graph_sitter.core.symbol_groups.multi_line_collection import MultiLineCollection +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc + +if TYPE_CHECKING: + from graph_sitter.typescript.interfaces.has_block import TSHasBlock + + +Parent = TypeVar("Parent", bound="TSHasBlock") + + +@ts_apidoc +class TSCodeBlock(CodeBlock[Parent, "TSAssignment"], Generic[Parent]): + """Extends the CodeBlock class to provide TypeScript-specific functionality.""" + + @noapidoc + @reader + def _parse_statements(self) -> MultiLineCollection[Statement, Self]: + statements: list[Statement] = self.ctx.parser.parse_ts_statements(self.ts_node, self.file_node_id, self.ctx, self) + line_nodes = find_line_start_and_end_nodes(self.ts_node) + start_node = line_nodes[1][0] if len(line_nodes) > 1 else line_nodes[0][0] + end_node = line_nodes[-2][1] if len(line_nodes) > 1 else line_nodes[-1][1] + indent_size = start_node.start_point[1] + collection = MultiLineCollection( + children=statements, + 
file_node_id=self.file_node_id, + ctx=self.ctx, + parent=self, + node=self.ts_node, + indent_size=indent_size, + leading_delimiter="", + start_byte=start_node.start_byte - indent_size, + end_byte=end_node.end_byte + 1, + ) + return collection + + @reader + @noapidoc + def _get_line_starts(self) -> list[Editable]: + """Returns an ordered list of first Editable for each non-empty line within the code block""" + line_start_nodes = super()._get_line_starts() + if len(line_start_nodes) >= 3 and line_start_nodes[0].source == "{" and line_start_nodes[-1].source == "}": + # Remove the first and last line of the code block as they are opening and closing braces. + return line_start_nodes[1:-1] + return line_start_nodes + + @reader + @noapidoc + def _get_line_ends(self) -> list[Editable]: + """Returns an ordered list of last Editable for each non-empty line within the code block""" + line_end_nodes = super()._get_line_ends() + # Remove the first and last line of the code block as they are opening and closing braces. + return line_end_nodes[1:-1] + + @writer + def unwrap(self) -> None: + """Unwraps a code block by removing its opening and closing braces. + + This method removes both the opening and closing braces of a code block, including any trailing whitespace + up to the next sibling node if it exists, or up to the closing brace of the last line if no sibling exists. + This is commonly used to flatten nested code structures like if statements, with statements, and function bodies. + + Returns: + None + """ + super().unwrap() + # Also remove the closing brace of the last line. + next_sibling = self.ts_node.next_sibling + if next_sibling: + self.remove_byte_range(self.ts_node.end_byte - 1, next_sibling.start_byte) + else: + # If there is no next sibling, remove up to the closing brace of the last line + self.remove_byte_range(self._get_line_ends()[-1].end_byte, self.ts_node.end_byte) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/decorator.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/decorator.py new file mode 100644 index 00000000..93fec06e --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/decorator.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from graph_sitter.core.autocommit import reader +from graph_sitter.core.detached_symbols.decorator import Decorator +from graph_sitter.core.detached_symbols.function_call import FunctionCall +from graph_sitter.shared.decorators.docs import ts_apidoc + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + +@ts_apidoc +class TSDecorator(Decorator["TSClass", "TSFunction", "TSParameter"]): + """Abstract representation of a Decorator""" + + @reader + def _get_name_node(self) -> TSNode: + """Returns the name of the decorator.""" + for child in self.ts_node.children: + # =====[ Identifier ]===== + # Just `@dataclass` etc. + if child.type == "identifier": + return child + + # =====[ Attribute ]===== + # e.g. `@a.b` + elif child.type == "member_expression": + return child + + # =====[ Call ]===== + # e.g. `@a.b()` + elif child.type == "call_expression": + func = child.child_by_field_name("function") + return func + + msg = f"Could not find decorator name within {self.source}" + raise ValueError(msg) + + @property + @reader + def call(self) -> FunctionCall | None: + """Retrieves the function call expression associated with the decorator. 
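
        Example (illustrative; `ts_class.decorators` is an assumed accessor on a TSClass):
            for decorator in ts_class.decorators:
                if decorator.call:              # e.g. @Component({...})
                    print(decorator.call.args)
                else:                           # e.g. @Injectable
                    print(decorator.name)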
+ + This property checks if the decorator has a function call expression (e.g., @decorator()) and returns it as a FunctionCall object. + If the decorator is a simple identifier (e.g., @decorator), returns None. + + Returns: + FunctionCall | None: A FunctionCall object representing the decorator's call expression if present, None otherwise. + """ + if call_node := next((x for x in self.ts_node.named_children if x.type == "call_expression"), None): + return FunctionCall(call_node, self.file_node_id, self.ctx, self.parent) + return None diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/jsx/element.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/jsx/element.py new file mode 100644 index 00000000..1de2a1fb --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/jsx/element.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING, Generic, TypeVar, override + +from graph_sitter.compiled.autocommit import commiter +from graph_sitter.core.autocommit import reader, writer +from graph_sitter.core.expressions import Expression, Value +from graph_sitter.core.expressions.name import Name +from graph_sitter.core.interfaces.has_name import HasName +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.detached_symbols.jsx.prop import JSXProp +from graph_sitter.utils import find_all_descendants + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.dataclasses.usage import UsageKind + from graph_sitter.core.interfaces.editable import Editable + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.typescript.detached_symbols.jsx.expression import JSXExpression + +Parent = TypeVar("Parent", bound="Editable") + + +@ts_apidoc +class JSXElement(Expression[Parent], HasName, Generic[Parent]): + """Abstract representation of TSX/JSX elements, e.g. ``. This allows for many React-specific modifications, like adding props, changing the name, etc.""" + + _name_node: Name | None + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent) -> None: + super().__init__(ts_node, file_node_id, ctx, parent) + open_tag = self.ts_node.child_by_field_name("open_tag") or self.ts_node + name_node = open_tag.child_by_field_name("name") + self._name_node = self._parse_expression(name_node, default=Name) + self.children # Force parse children of this JSX element + + @cached_property + @reader + def jsx_elements(self) -> list[JSXElement]: + """Returns a list of JSX elements nested within the current element. + + Gets all JSX elements that are descendants of this element in the syntax tree, excluding the element itself. + This includes both regular JSX elements (`...`) and self-closing elements (``). + + Args: + None + + Returns: + list[JSXElement]: A list of JSXElement objects representing all nested JSX elements. + """ + jsx_elements = [] + for node in self.extended_nodes: + jsx_element_nodes = find_all_descendants(node.ts_node, {"jsx_element", "jsx_self_closing_element"}) + jsx_elements.extend([self._parse_expression(x) for x in jsx_element_nodes if x != self.ts_node]) + return jsx_elements + + @cached_property + @reader + def expressions(self) -> list[JSXExpression]: + """Gets all JSX expressions within the JSX element. 
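
        Example (illustrative; `element` is a hypothetical JSXElement):
            for expr in element.expressions:
                print(expr.source)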
+ + Retrieves all JSX expressions that are descendant nodes of the current JSX element, including expressions in child elements and attributes. + + Returns: + list[JSXExpression]: A list of JSX expression objects found within this element, excluding the current element itself. + """ + jsx_expressions = [] + for node in self.extended_nodes: + jsx_expressions_nodes = find_all_descendants(node.ts_node, {"jsx_expression"}) + jsx_expressions.extend([self._parse_expression(x) for x in jsx_expressions_nodes if x != self.ts_node]) + return jsx_expressions + + @property + @noapidoc + @reader + def _attribute_nodes(self) -> list[Editable]: + """Returns all attribute nodes of the element""" + open_tag = self.ts_node.child_by_field_name("open_tag") or self.ts_node + attribute_nodes = open_tag.children_by_field_name("attribute") + return [Value(x, self.file_node_id, self.ctx, self) for x in attribute_nodes] + + @property + @reader + def props(self) -> list[JSXProp]: + """Retrieves all JSXProps (attributes) from a JSX element. + + Gets all props (attributes) on the current JSX element. For example, in ``, this would return a list with one JSXProp object representing `prop1="value"`. + + Args: + self: The JSXElement instance. + + Returns: + list[JSXProp]: A list of JSXProp objects representing each attribute on the element. + """ + return [self._parse_expression(x.ts_node, default=JSXProp) for x in self._attribute_nodes] + + @reader + def get_prop(self, name: str) -> JSXProp | None: + """Returns the JSXProp with the given name from the JSXElement. + + Searches through the element's props to find a prop with a matching name. + + Args: + name (str): The name of the prop to find. + + Returns: + JSXProp | None: The matching JSXProp object if found, None if not found. + """ + for prop in self.props: + if prop.name == name: + return prop + return None + + @property + def attributes(self) -> list[JSXProp]: + """Returns all JSXProp on this JSXElement, an alias for JSXElement.props. + + Returns all JSXProp attributes (props) on this JSXElement. For example, for a JSX element like + ``, this would return a list containing one JSXProp object. + + Returns: + list[JSXProp]: A list of JSXProp objects representing each attribute/prop on the JSXElement. + """ + return [self._parse_expression(x.ts_node, default=JSXProp) for x in self._attribute_nodes] + + @writer + def set_name(self, name: str) -> None: + """Sets the name of a JSXElement by modifying both opening and closing tags. + + Updates the name of a JSX element, affecting both self-closing tags (``) and elements with closing tags (``). + + Args: + name (str): The new name to set for the JSX element. + + Returns: + None: The method modifies the JSXElement in place. + """ + # This should correctly set the name of both the opening and closing tags + if open_tag := self.ts_node.child_by_field_name("open_tag"): + name_node = self._parse_expression(open_tag.child_by_field_name("name"), default=Name) + name_node.edit(name) + if close_tag := self.ts_node.child_by_field_name("close_tag"): + name_node = self._parse_expression(close_tag.child_by_field_name("name"), default=Name) + name_node.edit(name) + else: + # If the element is self-closing, we only need to edit the name of the element + super().set_name(name) + + @writer + def add_prop(self, prop_name: str, prop_value: str) -> None: + """Adds a new prop to a JSXElement. + + Adds a prop with the specified name and value to the JSXElement. 
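        For example (illustrative), calling `element.add_prop("disabled", "{true}")` on `<Button />` yields `<Button disabled={true} />`.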
If the element already has props, + the new prop is added after the last existing prop. If the element has no props, the new prop is + added immediately after the element name. + + Args: + prop_name (str): The name of the prop to add. + prop_value (str): The value of the prop to add. + + Returns: + None + """ + if len(self.props) > 0: + last_prop = self.props[-1] + # Extra padding is handled by the insert_after method on prop + last_prop.insert_after(f"{prop_name}={prop_value}", newline=False) + else: + self._name_node.insert_after(f" {prop_name}={prop_value}", newline=False) + + @property + @reader + @noapidoc + def _source(self): + """Text representation of the Editable instance""" + return self.ts_node.text.decode("utf-8").strip() + + @writer + def wrap(self, opening_tag: str, closing_tag: str) -> None: + """Wraps the current JSXElement with the provided opening and closing tags, properly handling indentation. + + Args: + opening_tag (str): The opening JSX tag to wrap around the current element (e.g. `
`) + closing_tag (str): The closing JSX tag to wrap around the current element (e.g. `
`) + """ + current_source = self.source + indented_source = "\n".join(f" {line.rstrip()}" for line in current_source.split("\n")) + new_source = f"{opening_tag}\n{indented_source}\n{closing_tag}" + self.edit(new_source, fix_indentation=True) + + @commiter + @noapidoc + @override + def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: + for node in self.children: + node._compute_dependencies(usage_type, dest=dest) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/jsx/expression.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/jsx/expression.py new file mode 100644 index 00000000..beaed181 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/jsx/expression.py @@ -0,0 +1,74 @@ +from functools import cached_property +from typing import Self, override + +from graph_sitter.compiled.autocommit import commiter +from graph_sitter.core.autocommit import reader, writer +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.expressions import Expression +from graph_sitter.core.interfaces.editable import Editable +from graph_sitter.core.interfaces.has_name import HasName +from graph_sitter.core.interfaces.unwrappable import Unwrappable +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc + + +@ts_apidoc +class JSXExpression(Unwrappable["Function | JSXElement | JSXProp"]): + """Abstract representation of TSX/JSX expression""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.statement + + @cached_property + @reader + def statement(self) -> Editable[Self] | None: + """Returns the editable component of this JSX expression. + + Retrieves the editable contained within this JSX expression by accessing the second child node. Returns None if the JSX expression doesn't + contain an editable object. + + Returns: + Editable[Self]: A Editable object representing the statement of this JSX expression. None if the object doesn't have an Editable object. + """ + return self._parse_expression(self.ts_node.named_children[0]) if len(self.ts_node.named_children) > 0 else None + + @commiter + @noapidoc + @override + def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: + if self.statement: + self.statement._compute_dependencies(usage_type, dest=dest) + + @writer + def reduce_condition(self, bool_condition: bool, node: Editable) -> None: + """Simplifies a JSX expression by reducing it based on a boolean condition. + + + Args: + bool_condition (bool): The boolean value to reduce the condition to. + + """ + if self.ts_node.parent.type == "jsx_attribute" and not bool_condition: + node.edit(self.ctx.node_classes.bool_conversion[bool_condition]) + else: + self.remove() + + @writer + @override + def unwrap(self, node: Expression | None = None) -> None: + """Removes the brackets from a JSX expression. 
+ + + Returns: + None + """ + from graph_sitter.typescript.detached_symbols.jsx.element import JSXElement + from graph_sitter.typescript.detached_symbols.jsx.prop import JSXProp + + if node is None: + node = self + if isinstance(self.parent, JSXProp): + return + if isinstance(node, JSXExpression | JSXElement | JSXProp): + for child in self._anonymous_children: + child.remove() diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/jsx/prop.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/jsx/prop.py new file mode 100644 index 00000000..dfaff0f4 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/jsx/prop.py @@ -0,0 +1,116 @@ +from typing import TYPE_CHECKING, override + +from tree_sitter import Node as TSNode + +from graph_sitter.codebase.codebase_context import CodebaseContext +from graph_sitter.compiled.autocommit import commiter +from graph_sitter.core.autocommit import reader, writer +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.expressions import Expression +from graph_sitter.core.expressions.name import Name +from graph_sitter.core.interfaces.has_name import HasName +from graph_sitter.core.interfaces.has_value import HasValue +from graph_sitter.core.node_id_factory import NodeId +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.detached_symbols.jsx.expression import JSXExpression + +if TYPE_CHECKING: + from graph_sitter.core.function import Function + from graph_sitter.typescript.detached_symbols.jsx.element import JSXElement + + +@ts_apidoc +class JSXProp(Expression["Function | JSXElement | JSXProp"], HasName, HasValue): + """Abstract representation of TSX/JSX prop, e.g .""" + + _name_node: Name | None + _expression_node: JSXExpression | None + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: "Function | JSXElement | JSXProp") -> None: + super().__init__(ts_node, file_node_id, ctx, parent) + self._name_node = self._parse_expression(self.ts_node.children[0], default=Name) + if len(self.ts_node.children) > 2: + self._value_node = self._parse_expression(self.ts_node.children[2]) + if self._value_node.ts_node.type == "jsx_expression": + self._expression_node = self._parse_expression(self._value_node.ts_node) + else: + self._expression_node = None + else: + # If there is no value node, then the prop is a boolean prop + # For example, is equivalent to + self._value_node = None + + @property + @reader + def expression(self) -> JSXExpression | None: + """Retrieves the JSX expression associated with this JSX prop. + + Returns the JSX expression node if this prop has one, e.g., for props like prop={expression}. + For boolean props or string literal props, returns None. + + Returns: + JSXExpression | None: The JSX expression node if present, None otherwise. + """ + return self._expression_node + + @writer + def insert_after( + self, + new_src: str, + fix_indentation: bool = False, + newline: bool = True, + priority: int = 0, + dedupe: bool = True, + ) -> None: + """Inserts source code after a JSX prop in a TypeScript/JSX file. + + Inserts the provided source code after the current JSX prop, adding necessary spacing. + + Args: + new_src (str): The source code to insert after the prop. + fix_indentation (bool, optional): Whether to fix the indentation of the inserted code. Defaults to False. + newline (bool, optional): Whether to add a newline after the inserted code. Defaults to True. 
+ priority (int, optional): The priority of the insertion. Defaults to 0. + dedupe (bool, optional): Whether to prevent duplicate insertions. Defaults to True. + + Returns: + None + """ + # TODO: This may not be transaction save with adds and deletes + # Insert space after the prop name + super().insert_after(" " + new_src, fix_indentation, newline, priority, dedupe) + + @writer + def insert_before( + self, + new_src: str, + fix_indentation: bool = False, + newline: bool = True, + priority: int = 0, + dedupe: bool = True, + ) -> None: + """Insert a new source code string before a JSX prop in a React component. + + Inserts a new string of source code before a JSX prop, maintaining proper spacing. + Automatically adds a trailing space after the inserted code. + + Args: + new_src (str): The source code string to insert before the prop. + fix_indentation (bool, optional): Whether to adjust the indentation of the inserted code. Defaults to False. + newline (bool, optional): Whether to add a newline after the inserted code. Defaults to True. + priority (int, optional): Priority of this insertion relative to others. Defaults to 0. + dedupe (bool, optional): Whether to avoid duplicate insertions. Defaults to True. + + Returns: + None + """ + # TODO: This may not be transaction save with adds and deletes + # Insert space before the prop name + super().insert_before(new_src + " ", fix_indentation, newline, priority, dedupe) + + @commiter + @noapidoc + @override + def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: + for node in self.children: + node._compute_dependencies(usage_type, dest=dest) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/parameter.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/parameter.py new file mode 100644 index 00000000..76de0666 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/parameter.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, override + +from graph_sitter.compiled.autocommit import commiter +from graph_sitter.core.autocommit import reader +from graph_sitter.core.autocommit.decorators import writer +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.detached_symbols.parameter import Parameter +from graph_sitter.core.expressions.union_type import UnionType +from graph_sitter.core.symbol_groups.collection import Collection +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.expressions.object_type import TSObjectType +from graph_sitter.typescript.expressions.type import TSType +from graph_sitter.typescript.symbol_groups.dict import TSPair + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.core.interfaces.has_name import HasName + from graph_sitter.core.placeholder.placeholder import Placeholder + from graph_sitter.typescript.function import TSFunction + + +@ts_apidoc +class TSParameter(Parameter[TSType, Collection["TSParameter", "TSFunction"]]): + """A class representing a TypeScript function parameter with extensive type analysis capabilities. + + This class provides functionality to inspect and manipulate TypeScript function parameters, + including support for destructured parameters, optional parameters, variadic parameters, + default values, and type annotations. + + Attributes: + type (TSType): The TypeScript type annotation of the parameter. 
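
    Example (illustrative; `function` is a hypothetical TSFunction):
        for param in function.parameters:
            if param.is_optional and param.default is None:
                print(f"{param.name} is optional and has no default value")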
+ """ + + def __init__(self, ts_node: TSNode, index: int, parent: TSFunction, type: TSType | Placeholder | None = None) -> None: + super().__init__(ts_node, index, parent) + if not self.type and type is not None: + self.type = type # Destructured types + + @property + @reader + def is_destructured(self) -> bool: + """Determines if a parameter is part of an object destructuring pattern. + + Checks the parameter's tree-sitter node type to determine if it represents a destructured parameter. + A parameter is considered destructured if it appears within an object destructuring pattern. + + Returns: + bool: True if the parameter is destructured, False otherwise. + """ + return self.ts_node.type in ("shorthand_property_identifier_pattern", "object_assignment_pattern") + + @property + @reader + def is_optional(self) -> bool: + """Determines if a parameter is marked as optional in TypeScript. + + Checks whether a parameter is marked with the '?' syntax in TypeScript, indicating that it is optional. + If the parameter is part of a destructured pattern, this function returns False as optionality is + handled at the function level for destructured parameters. + + Returns: + bool: True if the parameter is marked as optional, False otherwise. + """ + if self.is_destructured: + # In this case, individual destructured parameters are not marked as optional + # The entire object might be optional, but that's handled at the function level + return False + else: + return self.ts_node.type == "optional_parameter" + + @property + @reader + def is_variadic(self) -> bool: + """Determines if a parameter is variadic (using the rest operator). + + A property that checks if the parameter uses the rest pattern (e.g., ...args in TypeScript), + which allows the parameter to accept an arbitrary number of arguments. + + Returns: + bool: True if the parameter is variadic (uses rest pattern), False otherwise. + """ + pattern = self.ts_node.child_by_field_name("pattern") + return pattern is not None and pattern.type == "rest_pattern" + + @property + @reader + def default(self) -> str | None: + """Returns the default value of a parameter. + + Retrieves the default value of a parameter, handling both destructured and non-destructured parameters. + For destructured parameters, returns the default value if it's an object assignment pattern. + For non-destructured parameters, returns the value specified after the '=' sign. + + Returns: + str | None: The default value of the parameter as a string if it exists, None otherwise. + """ + # =====[ Destructured ]===== + if self.is_destructured: + if self.ts_node.type == "object_assignment_pattern": + return self.ts_node.children[-1].text.decode("utf-8") + else: + return None + + # =====[ Not destructured ]===== + default_node = self.ts_node.child_by_field_name("value") + if default_node is None: + return None + return default_node.text.decode("utf-8") + + @noapidoc + @commiter + @override + def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: + if self.type: + if not (self.is_destructured and self.index > 0): + self.type._compute_dependencies(UsageKind.TYPE_ANNOTATION, dest or self.parent.self_dest) + if self.value: + self.value._compute_dependencies(UsageKind.DEFAULT_VALUE, dest or self.parent.self_dest) + + @writer + def convert_to_interface(self) -> None: + """Converts a parameter's inline type definition to an interface. + + For React components, converts inline props type definitions to a separate interface. 
+ Handles both simple types and complex types including generics, extends patterns, and union types. + The interface will be named {ComponentName}Props and inserted before the component. + Supports extracting types from destructured parameters and preserves any type parameters. + + Example: + ```typescript + // Before + function Button(props: { text: string, onClick: () => void }) { + return ; + } + + // After + interface ButtonProps { + text: string; + onClick: () => void; + } + function Button(props: ButtonProps) { + return ; + } + ``` + """ + if not self.type or not self.parent_function.is_jsx or not isinstance(self.type, TSObjectType | UnionType): + return + + # # Get the type definition and component name + # type_def = self.type.source + component_name = self.parent_function.name + + # # Handle extends pattern + extends_clause: str = "" + + type = self.type + if isinstance(type, UnionType): + for subtype in type: + if isinstance(subtype, TSObjectType): + type = subtype + else: + extends_clause += f" extends {subtype.source}" + + # # Extract generic type parameters if present + generic_params = "" + if self.parent_function.type_parameters: + generic_params = self.parent_function.type_parameters.source + interface_name = f"{component_name}Props" + # # Update parameter type to use interface + if generic_params: + interface_name += generic_params + + # # Convert type definition to interface + interface_def = f"interface {interface_name}{extends_clause} {{\n" + + # Strip outer braces and convert to semicolon-separated lines + for value in type.values(): + interface_def += f" {value.parent_of_type(TSPair).source.rstrip(',')};\n" + interface_def += "}" + + # Insert interface before the function + self.parent_function.insert_before(interface_def + "\n") + + self.type.edit(interface_name) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/promise_chain.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/promise_chain.py new file mode 100644 index 00000000..78aa9399 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/detached_symbols/promise_chain.py @@ -0,0 +1,559 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from graph_sitter.core.autocommit import reader, writer +from graph_sitter.core.expressions import Name +from graph_sitter.core.statements.statement import StatementType + +if TYPE_CHECKING: + from graph_sitter.core.class_definition import Class + from graph_sitter.core.detached_symbols.function_call import FunctionCall + from graph_sitter.core.statements.statement import Statement + from graph_sitter.core.symbol_groups.multi_line_collection import MultiLineCollection + from graph_sitter.typescript.function import TSFunction + + +class TSPromiseChain: + """A class representing a TypeScript Promise chain. + + This class parses and handles Promise chains in TypeScript code, including .then(), .catch(), and .finally() chains. + It provides functionality to convert Promise chains to async/await syntax. 
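
    Example (a minimal sketch; `attribute_chain` is assumed to be the ordered list of calls
    making up a statement such as `fetchData().then(...).catch(...)`):
        chain = TSPromiseChain(attribute_chain)
        async_source = chain.convert_to_async_await(inplace_edit=False)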
+ """ + + base_chain: list[FunctionCall | Name] + then_chain: list[FunctionCall] + catch_call: FunctionCall | None + finally_call: FunctionCall | None + after_promise_chain: list[FunctionCall | Name] + base_attribute: Name + parent_statement: Statement + parent_function: FunctionCall + parent_class: Class + declared_vars: set[str] + base_indent: str + name: str | None + log_statements: list[str] = ["console.error", "console.warn", "console.log"] + + def __init__(self, attribute_chain: list[FunctionCall | Name]) -> None: + """Initialize a TSPromiseChain instance. + + Args: + attribute_chain: A list of function calls or a Name object representing the Promise chain + """ + # Parse the chain and assign all attributes + (self.base_chain, self.then_chain, self.catch_call, self.finally_call, self.after_promise_chain) = self._parse_chain(attribute_chain) + + self.base_attribute = self.base_chain[-1].parent.object + self.parent_statement = self.base_chain[0].parent_statement + self.parent_function = self.parent_statement.parent_function + self.parent_class = self.parent_statement.parent_class + self.declared_vars = set() + self.base_indent = " " * self.parent_statement._get_indent() + self.name = self.base_chain[0].source if isinstance(self.base_chain[0], Name) else self.base_chain[0].name + + @reader + def _parse_chain(self, attribute_chain: list[FunctionCall | Name]) -> tuple[list[FunctionCall], list[FunctionCall], FunctionCall | None, FunctionCall | None, list[FunctionCall | Name]]: + """Parse the Promise chain into its component parts. + + Args: + attribute_chain: The chain of function calls to parse + + Returns: + A tuple containing: + - base_chain: Initial function calls + - then_chain: .then() calls + - catch_call: .catch() call if present + - finally_call: .finally() call if present + - after_promise_chain: Calls after the Promise chain + """ + base_chain: list[FunctionCall | Name] = [] + then_chain: list[FunctionCall] = [] + catch_call: FunctionCall | None = None + finally_call: FunctionCall | None = None + after_promise_chain: list[FunctionCall | Name] = [] + + in_then_chain: bool = False + promise_chain_ended: bool = False + + for attribute in attribute_chain: + if not isinstance(attribute, Name): + if attribute.name == "then": + in_then_chain = True + then_chain.append(attribute) + elif attribute.name == "catch": + catch_call = attribute + in_then_chain = False + elif attribute.name == "finally": + finally_call = attribute + in_then_chain = False + promise_chain_ended = True + else: + if promise_chain_ended: + after_promise_chain.append(attribute) + elif in_then_chain: + then_chain.append(attribute) + else: + base_chain.append(attribute) + else: + if promise_chain_ended: + after_promise_chain.append(attribute) + elif in_then_chain: + then_chain.append(attribute) + else: + base_chain.append(attribute) + + return base_chain, then_chain, catch_call, finally_call, after_promise_chain + + @property + @reader + def is_return_statement(self) -> bool: + """Check if the parent statement is a return statement. + + Returns: + bool: True if the parent statement is a return statement + """ + return self.parent_statement.statement_type == StatementType.RETURN_STATEMENT + + @property + @reader + def assigned_var(self) -> str | None: + """Get the variable being assigned to in an assignment statement. 
+ + Returns: + Optional[str]: The name of the variable being assigned to, or None if not an assignment + """ + if self.parent_statement.statement_type == StatementType.ASSIGNMENT: + return self.parent_statement.left + + @reader + def get_next_call_params(self, call: FunctionCall | None) -> list[str]: + from graph_sitter.typescript.function import TSFunction + + """Get parameters from the next then/catch/finally call. + + Args: + call: The function call to extract parameters from + + Returns: + List[str]: List of parameter names from the call + """ + # handling the .then in parameter function + if call and len(call.args) > 0 and isinstance(call.args[0].value, TSFunction): + return [p.source for p in call.args[0].value.parameters] + + return [] + + @reader + def _needs_anonymous_function(self, arrow_fn: TSFunction) -> bool: + """Determine if we need to use an anonymous function wrapper. + + Returns True if: + 1. There are multiple return statements + 2. The code block has complex control flow (if/else, loops, etc) + + Args: + arrow_fn: The arrow function to analyze + + Returns: + bool: True if an anonymous function wrapper is needed + """ + statements = arrow_fn.code_block.get_statements() + return_count = sum(1 for stmt in statements if stmt.statement_type == StatementType.RETURN_STATEMENT) + return return_count > 1 + + @reader + def format_param_assignment(self, params: list[str], base_expr: str, declare: bool = True) -> str: + """Format parameter assignment with proper let declaration if needed. + + Args: + params: List of parameter names to assign + base_expr: The base expression to assign from + declare: Whether to declare new variables with 'let' + + Returns: + str: Formatted parameter assignment string + """ + if not params: + return base_expr + + if len(params) > 1: + param_str = ", ".join(params) + if declare and not any(p in self.declared_vars for p in params): + self.declared_vars.update(params) + return f"let [{param_str}] = {base_expr}" + return f"[{param_str}] = {base_expr}" + else: + param = params[0] + if declare and param not in self.declared_vars: + self.declared_vars.add(param) + return f"let {param} = {base_expr}" + return f"{param} = {base_expr}" + + @reader + def handle_base_call(self) -> str: + """Format the base promise call. + + Returns: + str: Formatted base call string + """ + new_handle = None + if "await" not in self.base_attribute.extended_source: + new_handle = f"await {self.base_attribute.extended_source};" + else: + new_handle = f"{self.base_attribute.extended_source};" + + next_params = self.get_next_call_params(self.then_chain[0]) + if next_params: + new_handle = self.format_param_assignment(next_params, new_handle) + return new_handle + + @reader + def handle_then_block(self, call: FunctionCall, next_call: FunctionCall | None = None) -> str: + from graph_sitter.typescript.function import TSFunction + + """Format a then block in the promise chain. 
+ + Args: + call: The then call to format + next_call: The next function call in the chain, if any + + Returns: + str: Formatted then block code + """ + # a then block must have a callback handler + if not call or call.name != "then" or len(call.args) != 1: + msg = "Invalid then call provided" + raise Exception(msg) + + arrow_fn = call.args[0].value + if not isinstance(arrow_fn, TSFunction): + msg = "callback function not provided in the argument" + raise Exception(msg) + + statements = arrow_fn.code_block.statements + + formatted_statements = [] + + # adds anonymous function if then block handler has ambiguous returns + if self._needs_anonymous_function(arrow_fn): + anon_block = self._format_anonymous_function(arrow_fn, next_call) + formatted_statements.append(f"{self.base_indent}{anon_block}") + + elif self._is_implicit_return(arrow_fn): + implicit_block = self._handle_last_block_implicit_return(statements, is_catch=False) + formatted_statements.append(f"{self.base_indent}{implicit_block}") + else: + for stmt in statements: + if stmt.statement_type == StatementType.RETURN_STATEMENT: + return_value = stmt.source[7:].strip() + next_params = self.get_next_call_params(next_call) + await_expression = f"await {return_value}" + if next_params: + formatted_statements.append(f"{self.base_indent}{self.format_param_assignment(next_params, await_expression, declare=True)}") + else: + formatted_statements.append(f"{self.base_indent}{await_expression}") + else: + formatted_statements.append(f"{self.base_indent}{stmt.source.strip()}") + + return "\n".join(formatted_statements) + + @reader + def parse_last_then_block(self, call: FunctionCall, assignment_variable_name: str | None = None) -> str: + from graph_sitter.typescript.function import TSFunction + + """Parse the last .then() block in the chain. + + Args: + call: The last .then() call to parse + assignment_variable_name: Optional custom variable name for assignment + + Returns: + str: Formatted code for the last .then() block + """ + arrow_fn = call.args[0].value + + if not isinstance(arrow_fn, TSFunction): + msg = "callback function not provided in the argument" + raise Exception(msg) + + statements = arrow_fn.code_block.statements + + if self._needs_anonymous_function(arrow_fn): + return self._format_anonymous_function(arrow_fn, assignment_variable_name=assignment_variable_name) + + if self._is_implicit_return(arrow_fn): + return self._handle_last_block_implicit_return(statements, assignment_variable_name=assignment_variable_name) + else: + formatted_statements = [] + for stmt in statements: + if stmt.statement_type == StatementType.RETURN_STATEMENT: + return_value = self._handle_last_block_normal_return(stmt, assignment_variable_name=assignment_variable_name) + formatted_statements.append(return_value) + else: + formatted_statements.append(stmt.source.strip()) + return "\n".join(formatted_statements) + + @reader + def _handle_last_block_normal_return(self, stmt: Statement, is_catch: bool = False, assignment_variable_name: str | None = None) -> str: + """Handle a normal return statement in the last block of a Promise chain. 
+ + Args: + stmt: The return statement to handle + is_catch: Whether this is in a catch block + assignment_variable_name: Optional custom variable name for assignment + + Returns: + str: Formatted return statement code + """ + return_value = stmt.source[7:].strip() # Remove 'return ' prefix + + var_name = assignment_variable_name if assignment_variable_name else self.assigned_var + if var_name: + return self.format_param_assignment([var_name], return_value) + elif self.is_return_statement: + if is_catch: + return f"throw {return_value}" + else: + return f"return {return_value}" + else: + if is_catch: + return f"throw {return_value}" + else: + return f"await {return_value}" + + @reader + def _handle_last_block_implicit_return(self, statements: MultiLineCollection[Statement], is_catch: bool = False, assignment_variable_name: str | None = None) -> str: + """Handle an implicit return in the last block of a Promise chain. + + Args: + statements: The statements in the block + is_catch: Whether this is in a catch block + assignment_variable_name: Optional custom variable name for assignment + + Returns: + str: Formatted implicit return code + """ + stmt_source = statements[0].source.strip() + var_name = assignment_variable_name if assignment_variable_name else self.assigned_var + + if any(stmt_source.startswith(console_method) for console_method in self.log_statements): + return stmt_source + ";" + elif is_catch: + return "throw " + stmt_source + ";" + elif var_name: + return self.format_param_assignment([var_name], stmt_source) + elif self.is_return_statement: + return "return " + stmt_source + ";" + else: + return "await " + stmt_source + ";" + + @reader + def handle_catch_block(self, call: FunctionCall, assignment_variable_name: str | None = None) -> str: + """Handle catch block in the promise chain. + + Args: + call: The catch function call to handle + assignment_variable_name: Optional custom variable name for assignment + + Returns: + str: Formatted catch block code + """ + # a catch block must have a callback handler + if not call or call.name != "catch" or len(call.args) != 1: + msg = "Invalid catch call provided" + raise Exception(msg) + + arrow_fn = call.args[0].value + statements = arrow_fn.code_block.statements + if len(arrow_fn.parameters) > 0: + error_param = arrow_fn.parameters[0].source + else: + error_param = "" + + formatted_statements = [f"{self.base_indent}}} catch({error_param}: any) {{"] + + # adds annonymous function if catch block handler has ambiguous returns + if self._needs_anonymous_function(arrow_fn): + anon_block = self._format_anonymous_function(arrow_fn, assignment_variable_name=assignment_variable_name) + formatted_statements.append(f"{self.base_indent}{anon_block}") + + elif self._is_implicit_return(arrow_fn): + implicit_block = self._handle_last_block_implicit_return(statements, is_catch=True, assignment_variable_name=assignment_variable_name) + formatted_statements.append(f"{self.base_indent}{implicit_block}") + else: + for stmt in statements: + if stmt.statement_type == StatementType.RETURN_STATEMENT: + return_block = self._handle_last_block_normal_return(stmt, is_catch=True, assignment_variable_name=assignment_variable_name) + formatted_statements.append(f"{self.base_indent}{return_block}") + else: + formatted_statements.append(f"{self.base_indent}{stmt.source.strip()}") + + return "\n".join(formatted_statements) + + @reader + def handle_finally_block(self, call: FunctionCall) -> str: + """Handle finally block in the promise chain. 
+ + Args: + call: The finally function call to handle + + Returns: + str: Formatted finally block code + """ + if not call or call.name != "finally": + msg = "Invalid finally call provided" + raise Exception(msg) + + arrow_fn = call.args[0].value + statements = arrow_fn.code_block.statements + + formatted_statements = [f"{self.base_indent}}} finally {{"] + + for stmt in statements: + formatted_statements.append(f"{self.base_indent}{stmt.source.strip()}") + + return "\n".join(formatted_statements) + + @writer + def convert_to_async_await(self, assignment_variable_name: str | None = None, inplace_edit: bool = True) -> str | None: + """Convert the promise chain to async/await syntax. + + Args: + assignment_variable_name: Optional custom variable name for assignment + inplace_edit: If set to true, will call statement.edit(); else will return a string of the new code + + Returns: + Optional[str]: The converted async/await code + """ + # check if promise expression needs to be wrapped in a try/catch/finally block + needs_wrapping = self.has_catch_call or self.has_finally_call + formatted_blocks = [] + + if needs_wrapping: + formatted_blocks.append(f"\n{self.base_indent}try {{") + + base_call = self.handle_base_call() + formatted_blocks.append(f"{self.base_indent}{base_call}") + + for idx, then_call in enumerate(self.then_chain): + is_last_then = idx == len(self.then_chain) - 1 + + # if it's the last then block, then parse differently + if is_last_then: + formatted_block = self.parse_last_then_block(then_call, assignment_variable_name=assignment_variable_name) + else: + next_call = self.then_chain[idx + 1] if idx + 1 < len(self.then_chain) else None + formatted_block = self.handle_then_block(then_call, next_call) + formatted_blocks.append(f"{self.base_indent}{formatted_block}") + + if self.catch_call: + catch_block = self.handle_catch_block(self.catch_call, assignment_variable_name=assignment_variable_name) + formatted_blocks.append(catch_block) + + if self.finally_call: + finally_block = self.handle_finally_block(self.finally_call) + formatted_blocks.append(finally_block) + + if needs_wrapping: + formatted_blocks.append(f"{self.base_indent}}}") + + if self.parent_statement.parent_function: + self.parent_statement.parent_function.asyncify() + + diff_changes = "\n".join(formatted_blocks) + if inplace_edit: + self.parent_statement.edit(diff_changes) + else: + return diff_changes + + @reader + def _is_implicit_return(self, arrow_fn: TSFunction) -> bool: + """Check if an arrow function has an implicit return. + + An implicit return occurs when: + 1. The function has exactly one statement + 2. The statement is not a comment + 3. The function body is not wrapped in curly braces + + Args: + arrow_fn: The arrow function to check + + Returns: + bool: True if the function has an implicit return + """ + statements = arrow_fn.code_block.statements + if len(statements) != 1: + return False + + stmt = statements[0] + return not stmt.statement_type == StatementType.COMMENT and not arrow_fn.code_block.source.strip().startswith("{") + + @reader + def _format_anonymous_function(self, arrow_fn: TSFunction, next_call: FunctionCall | None = None, assignment_variable_name: str | None = None) -> str: + """Format an arrow function as an anonymous async function. 
+ + Args: + arrow_fn: The arrow function to format + next_call: The next function call in the chain, if any + assignment_variable_name: Optional custom variable name for assignment + + Returns: + str: Formatted anonymous function code + """ + params = arrow_fn.parameters + params_str = ", ".join(p.source for p in params) if params else "" + lines = [] + + var_name = assignment_variable_name if assignment_variable_name else self.assigned_var + + if next_call and next_call.name == "then": + next_params = self.get_next_call_params(next_call) + if next_params: + lines.append(f"{self.base_indent}{self.format_param_assignment(next_params, f'await (async ({params_str}) => {{', declare=True)}") + else: + prefix = "" + if self.is_return_statement: + prefix = "return " + elif var_name: + prefix = f"{var_name} = " + lines.append(f"{self.base_indent}{prefix}await (async ({params_str}) => {{") + + code_block = arrow_fn.code_block + block_content = code_block.source.strip() + if block_content.startswith("{"): + block_content = block_content[1:] + if block_content.endswith("}"): + block_content = block_content[:-1] + + block_lines = block_content.split("\n") + for line in block_lines: + if line.strip(): + lines.append(f"{self.base_indent} {line.strip()}") + + if params_str: + lines.append(f"{self.base_indent}}})({params_str});") + else: + lines.append(f"{self.base_indent}}})();") + + return "\n".join(lines) + + @property + @reader + def has_catch_call(self) -> bool: + """Check if the Promise chain has a catch call. + + Returns: + bool: True if there is a catch call + """ + return self.catch_call is not None + + @property + @reader + def has_finally_call(self) -> bool: + """Check if the Promise chain has a finally call. + + Returns: + bool: True if there is a finally call + """ + return self.finally_call is not None diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/enum_definition.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/enum_definition.py new file mode 100644 index 00000000..d59ff09a --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/enum_definition.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Self, TypeVar, override + +from graph_sitter.core.autocommit import commiter, reader +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.interfaces.has_attribute import HasAttribute +from graph_sitter.enums import SymbolType +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.interfaces.has_block import TSHasBlock +from graph_sitter.typescript.statements.attribute import TSAttribute +from graph_sitter.typescript.symbol import TSSymbol + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.detached_symbols.code_block import CodeBlock + from graph_sitter.core.expressions import Expression + from graph_sitter.core.interfaces.has_name import HasName + from graph_sitter.core.interfaces.importable import Importable + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.core.statements.statement import Statement + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + +Parent = TypeVar("Parent", bound="TSHasBlock") + + +@ts_apidoc +class TSEnum(TSHasBlock, TSSymbol, HasAttribute[TSAttribute]): + """Representation of an Enum in TypeScript. 
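+
+    Example (illustrative source):
+
+        enum Color {
+            Red = "red",
+            Green = "green",
+        }
+
+    For such an enum, `attributes` lists the `Red` and `Green` members, and
+    `get_attribute("Red")` returns the matching `TSAttribute`, or `None` if no
+    member has that name.
+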
+ + Attributes: + symbol_type: The type of symbol, set to SymbolType.Enum. + body: The expression representing the body of the enum. + code_block: The code block associated with the enum. + """ + + symbol_type = SymbolType.Enum + body: Expression[Self] + code_block: TSCodeBlock + + def __init__( + self, + ts_node: TSNode, + file_id: NodeId, + ctx: CodebaseContext, + parent: Statement[CodeBlock[Parent, ...]], + ) -> None: + name_node = ts_node.child_by_field_name("name") + super().__init__(ts_node, file_id, ctx, parent, name_node=name_node) + self.body = self._parse_expression(ts_node.child_by_field_name("body")) + + @property + @reader + def attributes(self) -> list[TSAttribute[Self, None]]: + """Property that retrieves the attributes of a TypeScript enum. + + Returns the list of attributes defined within the enum's code block. + + Returns: + list[TSAttribute[Self, None]]: List of TSAttribute objects representing the enum's attributes. + """ + return self.code_block.attributes + + @reader + def get_attribute(self, name: str) -> TSAttribute | None: + """Returns an attribute from the TypeScript enum by its name. + + Args: + name (str): The name of the attribute to retrieve. + + Returns: + TSAttribute | None: The attribute with the given name if it exists, None otherwise. + """ + return next((x for x in self.attributes if x.name == name), None) + + @noapidoc + @commiter + def _compute_dependencies(self, usage_type: UsageKind = UsageKind.BODY, dest: HasName | None = None) -> None: + dest = dest or self.self_dest + self.body._compute_dependencies(usage_type, dest) + + @property + @noapidoc + def descendant_symbols(self) -> list[Importable]: + return super().descendant_symbols + self.body.descendant_symbols + + @noapidoc + @reader + @override + def resolve_attribute(self, name: str) -> TSAttribute | None: + return self.get_attribute(name) + + @staticmethod + @noapidoc + def _get_name_node(ts_node: TSNode) -> TSNode | None: + if ts_node.type == "enum_declaration": + return ts_node.child_by_field_name("name") + return None diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/enums.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/enums.py new file mode 100644 index 00000000..ce101ec6 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/enums.py @@ -0,0 +1,36 @@ +from enum import StrEnum + + +class TSFunctionTypeNames(StrEnum): + # const a = function functionExpression(): void { + # console.log("This is a regular function expression"); + # }; + FunctionExpression = "function_expression" + + # let arrowFunction = (x,y) => { x + y }; + ArrowFunction = "arrow_function" + + # function* generatorFunctionDeclaration(): Generator { + # yield 1; + # } + GeneratorFunctionDeclaration = "generator_function_declaration" + + # const a = function* generatorFunction(): Generator { + # yield 1; + # }; + GeneratorFunction = "generator_function" + + # function functionDeclaration(name: string): string { + # return `Hello, ${name}!`; + # } + FunctionDeclaration = "function_declaration" + + # class Example { + # methodDefinition(): void { + # console.log("This is a method definition"); + # } + # } + MethodDefinition = "method_definition" + + # Decorated methods (assuming decorators are supported in your JavaScript/TypeScript parser) + DecoratedMethodDefinition = "decorated_method_definition" diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/export.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/export.py new file mode 100644 index 00000000..d7e48b02 --- 
/dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/export.py @@ -0,0 +1,705 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, Literal, Self, TypeVar, override + +from graph_sitter.compiled.utils import cached_property +from graph_sitter.core.autocommit import commiter, reader +from graph_sitter.core.autocommit.decorators import writer +from graph_sitter.core.dataclasses.usage import UsageKind, UsageType +from graph_sitter.core.export import Export +from graph_sitter.core.expressions.name import Name +from graph_sitter.core.external_module import ExternalModule +from graph_sitter.core.import_resolution import Import +from graph_sitter.core.interfaces.chainable import Chainable +from graph_sitter.core.interfaces.has_value import HasValue +from graph_sitter.core.interfaces.importable import Importable +from graph_sitter.enums import EdgeType, ImportType, NodeType +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.assignment import TSAssignment +from graph_sitter.typescript.class_definition import TSClass +from graph_sitter.typescript.enum_definition import TSEnum +from graph_sitter.typescript.enums import TSFunctionTypeNames +from graph_sitter.typescript.function import TSFunction +from graph_sitter.typescript.import_resolution import TSImport +from graph_sitter.typescript.interface import TSInterface +from graph_sitter.typescript.namespace import TSNamespace +from graph_sitter.typescript.statements.assignment_statement import TSAssignmentStatement +from graph_sitter.typescript.type_alias import TSTypeAlias +from graph_sitter.utils import find_all_descendants + +if TYPE_CHECKING: + from collections.abc import Generator + + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.codebase.resolution_stack import ResolutionStack + from graph_sitter.core.interfaces.exportable import Exportable + from graph_sitter.core.interfaces.has_name import HasName + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.core.statements.export_statement import ExportStatement + from graph_sitter.core.symbol_groups.collection import Collection + from graph_sitter.typescript.symbol import TSSymbol + + +@ts_apidoc +class TSExport(Export["Collection[TSExport, ExportStatement[TSExport]]"], HasValue, Chainable): + """Represents a single exported symbol. + + There is a 1:M relationship between an ExportStatement and an Export + + Attributes: + node_type: The type of the node, set to NodeType.EXPORT. 
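+
+    Example (illustrative): each of the following statements produces one or
+    more `TSExport` nodes:
+
+        export const a = 1;               // named export
+        export default function () {}     // default export
+        export { foo as default };        // re-export of `foo` as default
+        export * from "./module";         // wildcard re-export
+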
+ """ + + _declared_symbol: TSSymbol | TSImport | None + _exported_symbol: Name | None + _name_node: Name | None + node_type: Literal[NodeType.EXPORT] = NodeType.EXPORT + + def __init__( + self, + ts_node: TSNode, + file_node_id: NodeId, + parent: Collection[TSExport, ExportStatement[TSExport]], + ctx: CodebaseContext, + name_node: TSNode | None = None, + declared_symbol: TSSymbol | TSImport | None = None, + exported_symbol: TSNode | None = None, + value_node: TSNode | None = None, + ) -> None: + """Given an `export_statement` tree sitter node, parses all implicit export symbols.""" + if declared_symbol and exported_symbol and declared_symbol.name != exported_symbol.text.decode("utf-8"): + msg = "The exported symbol name must match the declared symbol name" + raise ValueError(msg) + + super().__init__(ts_node, file_node_id, ctx, parent) + self._name_node = self._parse_expression(name_node, default=Name) + self._declared_symbol = declared_symbol + self._exported_symbol = self._parse_expression(exported_symbol, default=Name) + # if self.is_wildcard_export(): + # self.node_id = NodeIdFactory.export_node_id(name=f"wildcard_export_<{self._declared_symbol.node_id}>", file_id=self.file_node_id, is_default=self.is_default_export()) + # else: + # self.node_id = NodeIdFactory.export_node_id(name=self.name, file_id=self.file_node_id, is_default=self.is_default_export()) + self.parse(ctx) + self._value_node = self._parse_expression(value_node) + + @classmethod + @noapidoc + def from_export_statement_with_declaration( + cls, + export_statement: TSNode, + declaration: TSNode, + file_id: NodeId, + ctx: CodebaseContext, + parent: ExportStatement[TSExport], + pos: int, + ) -> list[TSExport]: + declared_symbols = [] + + # =====[ Symbol Definitions ]===== + if declaration.type in ["function_declaration", "generator_function_declaration"]: + # e.g. export function* namedGenerator() {} + declared_symbols.append(TSFunction(declaration, file_id, ctx, parent)) + elif declaration.type == "class_declaration": + # e.g. export class NamedClass {} + declared_symbols.append(TSClass(declaration, file_id, ctx, parent)) + elif declaration.type in ["variable_declaration", "lexical_declaration"]: + if len(arrow_functions := find_all_descendants(declaration, {"arrow_function"}, max_depth=2)) > 0: + # e.g. export const arrowFunction = () => {}, but not export const a = { func: () => null } + for arrow_func in arrow_functions: + declared_symbols.append(TSFunction.from_function_type(arrow_func, file_id, ctx, parent)) + else: + # e.g. export const a = value; + for child in declaration.named_children: + if child.type in TSAssignmentStatement.assignment_types: + s = TSAssignmentStatement.from_assignment(declaration, file_id, ctx, parent.parent, pos, assignment_node=child) + declared_symbols.extend(s.assignments) + elif declaration.type == "interface_declaration": + # e.g. export interface MyInterface {} + declared_symbols.append(TSInterface(declaration, file_id, ctx, parent)) + elif declaration.type == "type_alias_declaration": + # e.g. export type MyType = {} + declared_symbols.append(TSTypeAlias(declaration, file_id, ctx, parent)) + elif declaration.type == "enum_declaration": + # e.g. export enum MyEnum {} + declared_symbols.append(TSEnum(declaration, file_id, ctx, parent)) + elif declaration.type == "internal_module": + # e.g. 
export namespace MyNamespace {} + declared_symbols.append(TSNamespace(declaration, file_id, ctx, parent)) + else: + declared_symbols.append(None) + + exports = [] + for declared_symbol in declared_symbols: + name_node = declared_symbol._name_node.ts_node if declared_symbol and declared_symbol._name_node else declaration + export = cls(ts_node=declaration, file_node_id=file_id, ctx=ctx, name_node=name_node, declared_symbol=declared_symbol, parent=parent.exports) + exports.append(export) + return exports + + @classmethod + @noapidoc + def from_export_statement_with_value(cls, export_statement: TSNode, value: TSNode, file_id: NodeId, ctx: CodebaseContext, parent: ExportStatement[TSExport], pos: int) -> list[TSExport]: + declared_symbols = [] + exported_name_and_symbol = [] # tuple of export name node and export symbol name + detached_value_node = None + + # =====[ Symbol Definitions ]===== + if value.type in [function_type.value for function_type in TSFunctionTypeNames]: + # e.g. export default async function() {} + declared_symbols.append(parent._parse_expression(value)) + elif value.type == "class": + # e.g. export default class {} + declared_symbols.append(parent._parse_expression(value, default=TSClass)) + elif value.type == "object": + # e.g. export default { a, b, c }, export = { a, b, c } + # Export symbol usage will get resolved in _compute_dependencies based on identifiers in value + # TODO: parse as TSDict + detached_value_node = value + for child in value.named_children: + if child.type == "pair": + key_value = child.child_by_field_name("key") + pair_value = child.child_by_field_name("value") + if pair_value.type in [function_type.value for function_type in TSFunctionTypeNames]: + declared_symbols.append(TSFunction(pair_value, file_id, ctx, parent)) + elif pair_value.type == "class": + declared_symbols.append(TSClass(pair_value, file_id, ctx, parent)) + else: + exported_name_and_symbol.append((key_value, pair_value)) + elif child.type == "shorthand_property_identifier": + exported_name_and_symbol.append((child, child)) + elif value.type == "assignment_expression": + left = value.child_by_field_name("left") + right = value.child_by_field_name("right") + assignment = TSAssignment(value, file_id, ctx, parent, left, right, left) + declared_symbols.append(assignment) + else: + # Other values are detached symbols: array, number, string, true, null, undefined, new_expression, call_expression + # Export symbol usage will get resolved in _compute_dependencies based on identifiers in value + detached_value_node = value + declared_symbols.append(None) + + exports = [] + for declared_symbol in declared_symbols: + if declared_symbol is None: + name_node = value + else: + name_node = declared_symbol._name_node.ts_node if declared_symbol._name_node else declared_symbol.ts_node + export = cls(ts_node=export_statement, file_node_id=file_id, ctx=ctx, name_node=name_node, declared_symbol=declared_symbol, value_node=detached_value_node, parent=parent.exports) + exports.append(export) + for name_node, symbol_name_node in exported_name_and_symbol: + exports.append(cls(ts_node=export_statement, file_node_id=file_id, ctx=ctx, name_node=name_node, exported_symbol=symbol_name_node, value_node=detached_value_node, parent=parent.exports)) + return exports + + @noapidoc + @commiter + def parse(self, ctx: CodebaseContext) -> None: + pass + + @noapidoc + @commiter + def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: + if self.exported_symbol: + for frame in 
self.resolved_type_frames: + if frame.parent_frame: + frame.parent_frame.add_usage(self._name_node or self, UsageKind.EXPORTED_SYMBOL, self, self.ctx) + elif self._exported_symbol: + if not next(self.resolve_name(self._exported_symbol.source), None): + self._exported_symbol._compute_dependencies(UsageKind.BODY, dest=dest or self) + elif self.value: + self.value._compute_dependencies(UsageKind.EXPORTED_SYMBOL, self) + + @noapidoc + @commiter + def compute_export_dependencies(self) -> None: + """Create Export edges from this export to it's used symbols""" + if self.declared_symbol is not None: + assert self.ctx.has_node(self.declared_symbol.node_id) + self.ctx.add_edge(self.node_id, self.declared_symbol.node_id, type=EdgeType.EXPORT) + elif self._exported_symbol is not None: + symbol_name = self._exported_symbol.source + if (used_node := next(self.resolve_name(symbol_name), None)) and isinstance(used_node, Importable) and self.ctx.has_node(used_node.node_id): + self.ctx.add_edge(self.node_id, used_node.node_id, type=EdgeType.EXPORT) + elif self.value is not None: + if isinstance(self.value, Chainable): + for resolved in self.value.resolved_types: + if self.ctx.has_node(getattr(resolved, "node_id", None)): + self.ctx.add_edge(self.node_id, resolved.node_id, type=EdgeType.EXPORT) + elif self.name is None: + # This is the export *; case + self.ctx.add_edge(self.node_id, self.file_node_id, type=EdgeType.EXPORT) + if self.is_wildcard_export(): + for file in self.file.importers: + file.__dict__.pop("valid_symbol_names", None) + file.__dict__.pop("valid_import_names", None) + + @reader + def is_named_export(self) -> bool: + """Determines whether this export is a named export. + + Named exports are exports that are not default exports. For example, `export const foo = 'bar'` is a named export, + while `export default foo` is not. + + Returns: + bool: True if this is a named export, False if it is a default export. + """ + return not self.is_default_export() + + @reader + def is_default_export(self) -> bool: + """Determines if an export is the default export for a file. + + This function checks if the export is a default export by examining the export source code and the export's symbol. It handles various cases of default exports including: + - Re-exports as default (`export { foo as default }`) + - Default exports (`export default foo`) + - Module exports (`export = foo`) + + Returns: + bool: True if this is a default export, False otherwise. + """ + exported_symbol = self.exported_symbol + if exported_symbol and isinstance(exported_symbol, TSImport) and exported_symbol.is_default_import(): + return True + + # ==== [ Case: Named re-export as default ] ==== + # e.g. export { foo as default } from './other-module'; + exported_symbol = self.exported_symbol + if exported_symbol is not None and exported_symbol.node_type == NodeType.IMPORT and exported_symbol.source == self.source: + return self.name == "default" + + # ==== [ Case: Default export ] ==== + # e.g. export default foo; export default { foo }; export = foo; export = { foo }; + return self.parent.parent.source.startswith("export default ") or self.parent.parent.source.startswith("export = ") + + @reader + def is_default_symbol_export(self) -> bool: + """Returns True if this is exporting a default symbol, as opposed to a default object export. + + This method checks if an export is a default symbol export (e.g. 'export default foo') rather than a default object export (e.g. 'export default { foo }'). 
+ It handles both direct exports and re-exports. + + Args: + self (TSExport): The export object being checked. + + Returns: + bool: True if this is a default symbol export, False otherwise. + """ + if not self.is_default_export(): + return False + + # ==== [ Case: Default import re-export ] ==== + exported_symbol = self.exported_symbol + if exported_symbol is not None and exported_symbol.node_type == NodeType.IMPORT and exported_symbol.source == self.source: + return self.name == "default" + + # === [ Case: Default symbol export ] ==== + export_object = next((x for x in self.ts_node.children if x.type == "object"), None) + return export_object is None + + @reader + def is_type_export(self) -> bool: + """Determines if this export is exclusively exporting a type. + + Checks if this export starts with "export type" to identify if it's only exporting a type definition. + This method is used to distinguish between value exports and type exports in TypeScript. + + Returns: + bool: True if this is a type-only export, False otherwise. + """ + # TODO: do this more robustly + return self.source.startswith("export type ") + + @reader + def is_reexport(self) -> bool: + """Returns whether the export is re-exporting an import or export. + + Checks if this export node is re-exporting a symbol that was originally imported from another module or exported from another location. This includes wildcard re-exports of entire modules. + + Args: + self (TSExport): The export node being checked. + + Returns: + bool: True if this export re-exports an imported/exported symbol or entire module, False otherwise. + """ + if exported_symbol := self.exported_symbol: + return exported_symbol.node_type == NodeType.IMPORT or exported_symbol.node_type == NodeType.EXPORT or exported_symbol == self.file + return False + + @reader + def is_wildcard_export(self) -> bool: + """Determines if the export is a wildcard export. + + Checks if the export statement contains a wildcard export pattern 'export *' or 'export *;'. A wildcard export exports all symbols from a module. + + Returns: + bool: True if the export is a wildcard export (e.g. 'export * from "./module"'), False otherwise. + """ + return "export * " in self.source or "export *;" in self.source + + @reader + def is_module_export(self) -> bool: + """Determines if the export is exporting a module rather than a symbol. + + Returns True if the export is a wildcard export (e.g. 'export *') or if it is a default export but not of a symbol (e.g. 'export default { foo }'). + + Returns: + bool: True if the export represents a module export, False otherwise. + """ + return self.is_wildcard_export() or (self.is_default_export() and not self.is_default_symbol_export()) + + @property + @reader(cache=False) + def declared_symbol(self) -> TSSymbol | TSImport | None: + """Returns the symbol that was defined in this export. + + Returns the symbol that was directly declared within this export statement. For class, function, + interface, type alias, enum declarations or assignments, returns the declared symbol. + For re-exports or exports without declarations, returns None. + + Returns: + Union[TSSymbol, TSImport, None]: The symbol declared within this export statement, + or None if no symbol was declared. + """ + return self._declared_symbol + + @property + @reader + def exported_symbol(self) -> Exportable | None: + """Returns the symbol, file, or import being exported from this export object. 
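+
+        Note (illustrative): for a re-export such as `export { foo } from './bar'`
+        this yields the intermediate `TSImport`; use `resolved_symbol` to follow
+        the chain to the final target.
+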
+ + Retrieves the symbol or module being exported by this export node by finding the node connected via an EXPORT edge. + This method is the inverse of Import.imported_symbol. + + Args: + None + + Returns: + Exportable | None: The exported symbol, file, or import, or None if no symbol is exported. + """ + return next(iter(self.ctx.successors(self.node_id, edge_type=EdgeType.EXPORT)), None) + + @property + @reader + def resolved_symbol(self) -> Exportable | None: + """Returns the Symbol, SourceFile or External module that this export resolves to. + + Recursively traverses through indirect imports and exports to find the final resolved symbol. + This is useful for determining what symbol an export ultimately points to, particularly in cases of re-exports and import-export chains. + + Returns: + Exportable | None: The final resolved Symbol, SourceFile or External module, or None if the resolution fails. The resolution follows this chain: + - If the symbol is an Import, resolves to its imported symbol + - If the symbol is an Export, resolves to its exported symbol + - Otherwise returns the symbol itself + + Note: + Handles circular references by tracking visited symbols to prevent infinite loops. + """ + ix_seen = set() + resolved_symbol = self.exported_symbol + + while resolved_symbol is not None and (resolved_symbol.node_type == NodeType.IMPORT or resolved_symbol.node_type == NodeType.EXPORT): + if resolved_symbol in ix_seen: + return resolved_symbol + + ix_seen.add(resolved_symbol) + if resolved_symbol.node_type == NodeType.IMPORT: + resolved_symbol = resolved_symbol.resolved_symbol + else: + resolved_symbol = resolved_symbol.exported_symbol + + return resolved_symbol + + @writer + def make_non_default(self) -> None: + """Converts the export to a named export. + + Transforms default exports into named exports by modifying the export syntax and updating any corresponding export/import usages. + For default exports, it removes the 'default' keyword and adjusts all import statements that reference this export. + + Args: + None + + Returns: + None + """ + if self.is_default_export(): + # Default node is: + # export default foo = ... + # ^^^^^^^ + default_node = self.parent.parent._anonymous_children[1] + + if default_node.ts_node.type == "default": + if isinstance(self.declared_symbol, TSAssignment): + # Converts `export default foo` to `export const foo` + default_node.edit("const") + else: + # Converts `export default foo` to `export { foo }` + default_node.remove() + if name_node := self.get_name(): + name_node.insert_before("{ ", newline=False) + name_node.insert_after(" }", newline=False) + + # Update all usages of this export + for usage in self.usages(usage_types=UsageType.DIRECT): + if usage.match is not None and usage.kind == UsageKind.IMPORTED: + # === [ Case: Exported Symbol ] === + # Fixes Exports of the form `export { ... } from ...` + if usage.usage_symbol.source.startswith("export") and usage.match.source == "default": + # Export clause is: + # export { default as foo } from ... + # ^^^^^^^^^^^^^^^^^^ + export_clause = usage.usage_symbol.children[0] + for export_specifier in export_clause.children: + # This is the case where `export { ... as ... 
}` + if len(export_specifier.children) == 2 and export_specifier.children[0] == usage.match: + if export_specifier.children[1].source == self.name: + # Converts `export { default as foo }` to `export { foo }` + export_specifier.edit(self.name) + else: + # Converts `export { default as renamed_foo }` to `export { foo as renamed_foo }` + usage.match.edit(self.name) + # This is the case where `export { ... } from ...`, (specifically `export { default }`) + elif len(export_specifier.children) == 1 and export_specifier.children[0] == usage.match: + # Converts `export { default }` to `export { foo }` + export_specifier.edit(self.name) + + # === [ Case: Imported Symbol ] === + # Fixes Imports of the form `import { default as foo }` + else: + # Import clause is: + # import A, { B } from ... + # ^^^^^^^^ + import_clause = usage.usage_symbol.children[0] + + # Fixes imports of the form `import foo, { ... } from ...` + if len(import_clause.children) > 1 and import_clause.children[0] == usage.match: + # This is a terrible hack :skull: + + # Named imports are: + # import foo, { ... } + # ^^^^^^^ + named_imports = import_clause.children[1] + + # This converts `import foo, { bar, baz as waz }` to `import { foo, bar, baz as waz }` + import_clause.children[0].remove() # Remove `foo, ` + named_imports.children[0].insert_before(f"{self.name}, ", newline=False) # Add the `foo, ` + # Fixes imports of the form `import foo from ...` + else: + # This converts `import foo` to `import { foo }` + usage.match.insert_before("{ ", newline=False) + usage.match.insert_after(" }", newline=False) + + @cached_property + @noapidoc + @reader + def _wildcards(self) -> dict[str, WildcardExport[Self]]: + if self.is_wildcard_export() and isinstance(self.exported_symbol, Import): + res = {} + for name, symbol in self.exported_symbol._wildcards.items(): + res[name] = WildcardExport(self, symbol) + return res + return {} + + @reader + @noapidoc + @override + def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: + aliased = self.is_aliased() + if self.exported_symbol is not None: + yield from self.with_resolution_frame(self.exported_symbol, direct=True, aliased=aliased) + elif self.value is not None: + yield from self.with_resolution_frame(self.value, direct=True, aliased=aliased) + + @property + @noapidoc + def names(self) -> Generator[tuple[str, Self | WildcardExport[Self]], None, None]: + if self.exported_name is None: + if self.is_wildcard_export(): + yield from self._wildcards.items() + else: + yield self.exported_name, self + + @property + def descendant_symbols(self) -> list[Importable]: + """Returns a list of all descendant symbols from this export's declared symbol. + + Returns all child symbols that are contained within the declared symbol of this export. For example, + if the declared symbol is a class, this will return all methods, properties and nested classes. + If the export has no declared symbol, returns an empty list. + + Returns: + list[Importable]: List of descendant symbols. Empty list if no declared symbol exists. 
+ """ + if self.declared_symbol: + return [self, *self.declared_symbol.descendant_symbols] + return [self] + + def __hash__(self): + if self._hash is None: + self._hash = hash((self.filepath, self.range, self.ts_node.kind_id, self.name)) + return self._hash + + @reader + def __eq__(self, other: object): + if isinstance(other, TSExport): + return super().__eq__(other) and self.name == other.name + return super().__eq__(other) + + @property + @reader + def source(self) -> str: + """Returns the source code of the symbol. + + Gets the source code of the symbol from its extended representation, which includes the export statement. + + Returns: + str: The complete source code of the symbol including any extended nodes. + """ + return self.parent.parent.source + + @property + @reader + def is_external_export(self) -> bool: + """Determines if this export is exporting a symbol from an external (non-relative) module. + + An external module is one that comes from outside the project's codebase. + + Returns: + bool: True if the export is from an external module, False otherwise. + """ + if self.is_reexport(): + if isinstance(self.exported_symbol, TSImport): + for resolved in self.exported_symbol.resolved_types: + if isinstance(resolved, ExternalModule): + return True + return False + + @reader + def to_import_string(self) -> str: + """Converts this export into its equivalent import string representation. + + This is primarily used for handling re-exports, converting them into their + equivalent import statements. + + Returns: + str: The import string representation of this export. + + Examples: + - For `export { foo } from './bar'` -> `import { foo } from './bar'` + - For `export * from './bar'` -> `import * as _namespace from './bar'` + - For `export { default as foo } from './bar'` -> `import foo from './bar'` + """ + module_path = self.exported_symbol.module.source.strip("'\"") if self.exported_symbol.module is not None else "" + type_prefix = "type " if self.is_type_export() else "" + + if self.is_wildcard_export(): + namespace = self.name or module_path.split("/")[-1].split(".")[0] + return f"import * as {namespace} from '{module_path}';" + + if self.is_default_export(): + if self.is_type_export() and self.is_aliased(): + original_name = self.exported_symbol.symbol_name.source if self.exported_symbol.symbol_name is not None else self.exported_symbol.name + print(original_name) + if original_name == "default": + return f"import {type_prefix}{{ default as {self.name} }} from '{module_path}';" + else: + return f"import {type_prefix}{{ {original_name} as default }} from '{module_path}';" + + # Handle mixed type and value exports + if "type" in self.source and "," in self.source and "{" in self.source and "}" in self.source: + content = self.source[self.source.index("{") + 1 : self.source.index("}")].strip() + return f"import {{ {content} }} from '{module_path}';" + + original_name = self.exported_symbol.symbol_name.source if self.exported_symbol.symbol_name is not None else self.exported_symbol.name + return f"import {{ {original_name} as {self.name} }} from '{module_path}';" + + @reader + def get_import_string(self, alias: str | None = None, module: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> str: + """Returns the import string for this export. + + Args: + alias (str | None): Optional alias to use when importing the symbol. + module (str | None): Optional module name to import from. + import_type (ImportType): The type of import to generate. 
+ is_type_import (bool): Whether this is a type-only import. + + Returns: + str: The formatted import string. + """ + if self.is_reexport(): + return self.to_import_string() + + module_path = self.file.import_module_name.strip("'\"") + type_prefix = "type " if is_type_import else "" + + if import_type == ImportType.WILDCARD: + namespace = alias or module_path.split("/")[-1].split(".")[0] + return f"import * as {namespace} from '{module_path}';" + + # Handle default exports + if self.is_default_export(): + name = alias or self.name + return f"import {name} from '{module_path}';" + + # Handle named exports + original_name = self.name + if alias and alias != original_name: + return f"import {type_prefix}{{ {original_name} as {alias} }} from '{module_path}';" + return f"import {type_prefix}{{ {original_name} }} from '{module_path}';" + + @reader + def reexport_symbol(self) -> TSImport | None: + """Returns the import object that is re-exporting this symbol. + + For re-exports like: + - `export { foo } from './bar'` # Direct re-export + - `export { default as baz } from './bar'` # Direct default re-export + - `export * from './bar'` # Direct wildcard re-export + - `import { foo } from './bar'; export { foo }` # Local re-export + + This returns the corresponding import object that's being re-exported. + + Returns: + TSImport | None: The import object being re-exported, or None if this + is not a re-export or no import was found. + """ + # Only exports can have re-export sources + if not self.is_reexport(): + return None + + # For direct re-exports (export { x } from './y'), use declared_symbol + if self.declared_symbol is not None: + return self.declared_symbol + + # For local re-exports (import x; export { x }), use exported_symbol + if self.exported_symbol is not None and self.exported_symbol.node_type == NodeType.IMPORT: + return self.exported_symbol + + return None + + +TExport = TypeVar("TExport", bound="Export") + + +class WildcardExport(Chainable, Generic[TExport]): + """Class to represent one of many wildcard exports.""" + + exp: TExport + symbol: Exportable + + def __init__(self, exp: TExport, symbol: Exportable): + self.exp = exp + self.symbol = symbol + + @reader + @noapidoc + @override + def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: + """Resolve the types used by this import.""" + yield from self.exp.with_resolution_frame(self.symbol, direct=False) + + @noapidoc + @reader + def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: + pass diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/array_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/array_type.py new file mode 100644 index 00000000..f777403f --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/array_type.py @@ -0,0 +1,19 @@ +from typing import TypeVar + +from tree_sitter import Node as TSNode + +from graph_sitter.shared.decorators.docs import ts_apidoc +from graph_sitter.typescript.expressions.named_type import TSNamedType + +Parent = TypeVar("Parent") + + +@ts_apidoc +class TSArrayType(TSNamedType[Parent]): + """Array type + Examples: + string[] + """ + + def _get_name_node(self) -> TSNode | None: + return self.ts_node.named_children[0] diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/chained_attribute.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/chained_attribute.py new file mode 100644 index 00000000..d56f8f25 --- /dev/null +++ 
b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/chained_attribute.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING, Generic, TypeVar + +from graph_sitter.compiled.autocommit import reader +from graph_sitter.core.detached_symbols.function_call import FunctionCall +from graph_sitter.core.expressions import Expression, Name +from graph_sitter.core.expressions.chained_attribute import ChainedAttribute +from graph_sitter.shared.decorators.docs import ts_apidoc + +if TYPE_CHECKING: + from graph_sitter.core.interfaces.editable import Editable + +Parent = TypeVar("Parent", bound="Editable") + + +@ts_apidoc +class TSChainedAttribute(ChainedAttribute[Expression, Name, Parent], Generic[Parent]): + """A TypeScript chained attribute class representing member access expressions. + + This class handles the representation and analysis of chained attribute access expressions in TypeScript, + such as 'object.property' or 'object.method()'. It provides functionality for accessing the object + and property components of the expression, as well as analyzing function calls made on the object. + """ + + def __init__(self, ts_node, file_node_id, ctx, parent: Parent): + super().__init__(ts_node, file_node_id, ctx, parent=parent, object=ts_node.child_by_field_name("object"), attribute=ts_node.child_by_field_name("property")) + + @property + @reader + def function_calls(self) -> list[FunctionCall]: + """Returns a list of function calls associated with this chained attribute's object. + + Retrieves all function calls made on the object component of this chained attribute. + This is useful for analyzing call sites and call patterns in code analysis and refactoring tasks. + + Returns: + list[FunctionCall]: A list of function calls made on this chained attribute's object. + """ + # Move the parent reference to its own parent to skip over an identifier type in parent chain + return self._object.function_calls diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/conditional_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/conditional_type.py new file mode 100644 index 00000000..55da92d9 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/conditional_type.py @@ -0,0 +1,59 @@ +from collections.abc import Generator +from typing import TYPE_CHECKING, Generic, Self, TypeVar, override + +from tree_sitter import Node as TSNode + +from graph_sitter.codebase.resolution_stack import ResolutionStack +from graph_sitter.core.autocommit import reader +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.expressions.type import Type +from graph_sitter.core.interfaces.importable import Importable +from graph_sitter.core.node_id_factory import NodeId +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc + +if TYPE_CHECKING: + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.typescript.expressions.type import TSType + + +Parent = TypeVar("Parent") + + +@ts_apidoc +class TSConditionalType(Type[Parent], Generic[Parent]): + """Conditional Type + + Examples: + typeof s + + Attributes: + left: The left-hand side type of the conditional type. + right: The right-hand side type of the conditional type. + consequence: The type if the condition is true. + alternative: The type if the condition is false. 
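+
+    Note: a typical conditional type has the form `T extends U ? X : Y`, for
+    example `type Flag<T> = T extends string ? "yes" : "no"`; `consequence` and
+    `alternative` hold the two branches.
+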
+ """ + + left: "TSType[Self]" + right: "TSType[Self]" + consequence: "TSType[Self]" + alternative: "TSType[Self]" + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): + super().__init__(ts_node, file_node_id, ctx, parent) + self.left = self.child_by_field_name("left") + self.right = self.child_by_field_name("right") + self.consequence = self.child_by_field_name("consequence") + self.alternative = self.child_by_field_name("alternative") + + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + self.left._compute_dependencies(usage_type, dest) + self.right._compute_dependencies(usage_type, dest) + self.consequence._compute_dependencies(usage_type, dest) + self.alternative._compute_dependencies(usage_type, dest) + + @reader + @noapidoc + @override + def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: + yield from self.with_resolution_frame(self.consequence) + yield from self.with_resolution_frame(self.alternative) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/expression_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/expression_type.py new file mode 100644 index 00000000..57d88fab --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/expression_type.py @@ -0,0 +1,29 @@ +from typing import TYPE_CHECKING, Generic, TypeVar + +from tree_sitter import Node as TSNode + +from graph_sitter.core.expressions import Expression +from graph_sitter.core.node_id_factory import NodeId +from graph_sitter.shared.decorators.docs import ts_apidoc +from graph_sitter.typescript.expressions.named_type import TSNamedType + +if TYPE_CHECKING: + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.interfaces.editable import Editable + +Parent = TypeVar("Parent", bound="Editable") + + +@ts_apidoc +class TSExpressionType(TSNamedType, Generic[Parent]): + """Type defined by evaluation of an expression + + Attributes: + expression: The expression to evaluate that yields the type + """ + + expression: Expression["TSExpressionType[Parent]"] + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): + super().__init__(ts_node, file_node_id, ctx, parent) + self.expression = self._parse_expression(ts_node) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/function_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/function_type.py new file mode 100644 index 00000000..dd913b86 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/function_type.py @@ -0,0 +1,94 @@ +from collections.abc import Generator +from typing import TYPE_CHECKING, Generic, Self, TypeVar, override + +from tree_sitter import Node as TSNode + +from graph_sitter.codebase.resolution_stack import ResolutionStack +from graph_sitter.core.autocommit import reader, writer +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.expressions.type import Type +from graph_sitter.core.interfaces.importable import Importable +from graph_sitter.core.node_id_factory import NodeId +from graph_sitter.core.symbol_groups.collection import Collection +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.detached_symbols.parameter import TSParameter +from graph_sitter.typescript.placeholder.placeholder_return_type import TSReturnTypePlaceholder + +if TYPE_CHECKING: + from 
graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.typescript.expressions.type import TSType + + +Parent = TypeVar("Parent") + + +@ts_apidoc +class TSFunctionType(Type[Parent], Generic[Parent]): + """Function type definition. + + Example: + a: (a: number) => number + + Attributes: + return_type: Return type of the function. + name: This lets parameters generate their node_id properly. + """ + + return_type: "TSType[Self] | TSReturnTypePlaceholder[Self]" + _parameters: Collection[TSParameter, Self] + name: None = None + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): + super().__init__(ts_node, file_node_id, ctx, parent) + self.return_type = self.child_by_field_name("return_type", placeholder=TSReturnTypePlaceholder) + params_node = self.ts_node.child_by_field_name("parameters") + params = [TSParameter(child, idx, self) for idx, child in enumerate(params_node.named_children) if child.type != "comment"] + self._parameters = Collection(params_node, file_node_id, ctx, self, children=params) + + @property + @reader + def parameters(self) -> Collection[TSParameter, Self]: + """Retrieves the parameters of a function type. + + Returns the collection of parameters associated with this function type. These parameters represent the arguments that can be passed to the function. + + Returns: + Collection[TSParameter, Self]: A collection of TSParameter objects representing the function's parameters. + """ + return self._parameters + + @writer + def asyncify(self) -> None: + """Modifies the function type to be asynchronous by wrapping its return type in a Promise. + + This method transforms a synchronous function type into an asynchronous one by modifying + its return type. It wraps the existing return type in a Promise, effectively changing + 'T' to 'Promise'. + + Args: + self: The TSFunctionType instance to modify. 
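+
+        For example (illustrative), a function type written as
+        `(x: number) => string` ends up as `(x: number) => Promise<string>`.
+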
+ + Returns: + None + """ + if self.return_type: + self.return_type.insert_before("Promise<", newline=False) + self.return_type.insert_after(">", newline=False) + + def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: Importable | None = None): + if self.return_type: + self.return_type._compute_dependencies(UsageKind.GENERIC, dest) + + @reader + @noapidoc + @override + def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: + yield from self.with_resolution_frame(self.return_type) + + @property + @noapidoc + def descendant_symbols(self) -> list[Importable]: + symbols = [] + for param in self.parameters: + symbols.extend(param.descendant_symbols) + return symbols diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/generic_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/generic_type.py new file mode 100644 index 00000000..0d3131a9 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/generic_type.py @@ -0,0 +1,27 @@ +from typing import Self, TypeVar + +from tree_sitter import Node as TSNode + +from graph_sitter.core.expressions.generic_type import GenericType +from graph_sitter.core.symbol_groups.collection import Collection +from graph_sitter.core.symbol_groups.dict import Dict +from graph_sitter.shared.decorators.docs import ts_apidoc + +Parent = TypeVar("Parent") + + +@ts_apidoc +class TSGenericType(GenericType["TSType", Parent]): + """Generic type + + Examples: + `Array` + """ + + def _get_name_node(self) -> TSNode: + return self.child_by_field_name("name").ts_node + + def _get_parameters(self) -> Collection[Self, Self] | Dict[Self, Self] | None: + type_parameter = self.child_by_field_types("type_arguments").ts_node + types = [self._parse_type(child) for child in type_parameter.named_children] + return Collection(node=type_parameter, file_node_id=self.file_node_id, ctx=self.ctx, parent=self, children=types) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/lookup_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/lookup_type.py new file mode 100644 index 00000000..48ab87c0 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/lookup_type.py @@ -0,0 +1,65 @@ +from collections.abc import Generator +from typing import TYPE_CHECKING, Generic, Self, TypeVar, override + +from tree_sitter import Node as TSNode + +from graph_sitter.codebase.resolution_stack import ResolutionStack +from graph_sitter.core.autocommit import reader +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.expressions import Expression +from graph_sitter.core.expressions.type import Type +from graph_sitter.core.interfaces.importable import Importable +from graph_sitter.core.node_id_factory import NodeId +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc + +if TYPE_CHECKING: + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.typescript.expressions.type import TSType + + +Parent = TypeVar("Parent") + + +@ts_apidoc +class TSLookupType(Type[Parent], Generic[Parent]): + """Type lookup + + Examples: + a["key"] + + Attributes: + type: The type of the TypeScript object being looked up. + lookup: The expression used for the lookup operation. 
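+
+        For example (illustrative), in the lookup type `Person["name"]` the
+        `type` attribute refers to `Person` and `lookup` is the `"name"` literal.
+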
+ """ + + type: "TSType[Self]" + lookup: Expression + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): + super().__init__(ts_node, file_node_id, ctx, parent) + self.type = self._parse_type(ts_node.named_children[0]) + if literal_type := self.child_by_field_types("literal_type"): + self.lookup = self._parse_expression(literal_type.ts_node.named_children[0]) + + @property + @reader + def name(self) -> str | None: + """Retrieves the name of the type object. + + Gets the name property of the underlying type object. This property is commonly used to access type names in TypeScript-style type lookups. + + Returns: + str | None: The name of the type object if it exists, None otherwise. + """ + return self.type.name + + @reader + @noapidoc + @override + def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: + # TODO: not implemented properly. Needs to look at the actual lookup + self._log_parse("Cannot resolve lookup type properly") + yield from self.with_resolution_frame(self.type) + + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + self.type._compute_dependencies(usage_type, dest) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/named_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/named_type.py new file mode 100644 index 00000000..1fba0ca4 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/named_type.py @@ -0,0 +1,19 @@ +from typing import TypeVar + +from tree_sitter import Node as TSNode + +from graph_sitter.core.expressions.named_type import NamedType +from graph_sitter.shared.decorators.docs import ts_apidoc + +Parent = TypeVar("Parent") + + +@ts_apidoc +class TSNamedType(NamedType[Parent]): + """Named type + Examples: + string + """ + + def _get_name_node(self) -> TSNode | None: + return self.ts_node diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/object_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/object_type.py new file mode 100644 index 00000000..61cd7230 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/object_type.py @@ -0,0 +1,81 @@ +from typing import TYPE_CHECKING, Generic, Self, TypeVar + +from tree_sitter import Node as TSNode + +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.expressions.expression import Expression +from graph_sitter.core.expressions.type import Type +from graph_sitter.core.expressions.value import Value +from graph_sitter.core.interfaces.importable import Importable +from graph_sitter.core.node_id_factory import NodeId +from graph_sitter.shared.decorators.docs import ts_apidoc +from graph_sitter.shared.logging.get_logger import get_logger +from graph_sitter.typescript.symbol_groups.dict import TSDict, TSPair + +if TYPE_CHECKING: + from graph_sitter.codebase.codebase_context import CodebaseContext + + +logger = get_logger(__name__) + + +Parent = TypeVar("Parent") + + +class TSObjectPair(TSPair, Generic[Parent]): + """Object type + + Examples: + a: {a: int; b?(a: int): c} + """ + + def _get_key_value(self) -> tuple[Expression[Self] | None, Expression[Self] | None]: + from graph_sitter.typescript.expressions.function_type import TSFunctionType + + key, value = None, None + if self.ts_node_type == "property_signature": + type_node = self.ts_node.child_by_field_name("type") + value = self._parse_expression(type_node) + key = 
self._parse_expression(self.ts_node.child_by_field_name("name")) + elif self.ts_node_type == "call_signature": + value = TSFunctionType(self.ts_node, self.file_node_id, self.ctx, self) + elif self.ts_node_type == "index_signature": + value = self._parse_expression(self.ts_node.child_by_field_name("type")) + key = self._parse_expression(self.ts_node.named_children[0]) + elif self.ts_node_type == "method_signature": + value = TSFunctionType(self.ts_node, self.file_node_id, self.ctx, self) + key = self._parse_expression(self.ts_node.child_by_field_name("name")) + elif self.ts_node_type == "method_definition": + key = self._parse_expression(self.ts_node.child_by_field_name("mapped_clause_type")) + value = self._parse_expression(self.ts_node.child_by_field_name("type")) + else: + key, value = super()._get_key_value() + if isinstance(value, Value): + # HACK: sometimes types are weird + value = self._parse_expression(value.ts_node.named_children[0]) + elif not isinstance(value, Type): + self._log_parse(f"{value} of type {value.__class__.__name__} from {self.ts_node} not a valid type") + + return key, value + + +Parent = TypeVar("Parent") + + +@ts_apidoc +class TSObjectType(TSDict, Type[Parent], Generic[Parent]): + """A class representing a TypeScript object type with type annotations and dependencies. + + A specialized class extending `TSDict` and implementing `Type` for handling TypeScript object type annotations. + This class handles object type definitions including nested type structures and manages their dependencies. + It provides functionality for computing dependencies within the type structure and handling type relationships + in TypeScript code. + """ + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, delimiter=";", pair_type=TSObjectPair) + + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + for child in self.values(): + if isinstance(child, Type): + child._compute_dependencies(usage_type, dest) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/query_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/query_type.py new file mode 100644 index 00000000..44aa2e41 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/query_type.py @@ -0,0 +1,59 @@ +from collections.abc import Generator +from typing import TYPE_CHECKING, Generic, Self, TypeVar, override + +from tree_sitter import Node as TSNode + +from graph_sitter.codebase.resolution_stack import ResolutionStack +from graph_sitter.core.autocommit import reader +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.expressions.type import Type +from graph_sitter.core.interfaces.importable import Importable +from graph_sitter.core.node_id_factory import NodeId +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc + +if TYPE_CHECKING: + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.typescript.expressions.type import TSType + + +Parent = TypeVar("Parent") + + +@ts_apidoc +class TSQueryType(Type[Parent], Generic[Parent]): + """Type query + + Examples: + typeof s + + Attributes: + query: The TypeScript type associated with the query. 
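+
+        For example (illustrative):
+
+            const settings = { debug: true };
+            type Settings = typeof settings;  // query over the value `settings`
+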
+ """ + + query: "TSType[Self]" + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): + super().__init__(ts_node, file_node_id, ctx, parent) + self.query = self._parse_type(ts_node.named_children[0]) + + @property + @reader + def name(self) -> str | None: + """Returns the name of the query type. + + A property that retrieves the name of the query type. This property is used to get the name + associated with TypeScript type queries (e.g., 'typeof s'). + + Returns: + str | None: The name of the query type, or None if no name is available. + """ + return self.query.name + + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + self.query._compute_dependencies(usage_type, dest) + + @reader + @noapidoc + @override + def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: + yield from self.with_resolution_frame(self.query) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/readonly_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/readonly_type.py new file mode 100644 index 00000000..2b527f89 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/readonly_type.py @@ -0,0 +1,59 @@ +from collections.abc import Generator +from typing import TYPE_CHECKING, Generic, Self, TypeVar, override + +from tree_sitter import Node as TSNode + +from graph_sitter.codebase.resolution_stack import ResolutionStack +from graph_sitter.core.autocommit import reader +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.expressions.type import Type +from graph_sitter.core.interfaces.importable import Importable +from graph_sitter.core.node_id_factory import NodeId +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc + +if TYPE_CHECKING: + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.typescript.expressions.type import TSType + + +Parent = TypeVar("Parent") + + +@ts_apidoc +class TSReadonlyType(Type[Parent], Generic[Parent]): + """Readonly type + + Examples: + readonly s + + Attributes: + type: The underlying TypeScript type associated with this readonly type. + """ + + type: "TSType[Self]" + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent): + super().__init__(ts_node, file_node_id, ctx, parent) + self.type = self._parse_type(ts_node.named_children[0]) + + @property + @reader + def name(self) -> str | None: + """Retrieves the name of the type. + + Gets the name from the underlying type object. Since this is a property getter, it is decorated with @reader + to ensure safe concurrent access. + + Returns: + str | None: The name of the type, or None if the type has no name. 
+ """ + return self.type.name + + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + self.type._compute_dependencies(usage_type, dest) + + @reader + @noapidoc + @override + def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: + yield from self.with_resolution_frame(self.type) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/string.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/string.py new file mode 100644 index 00000000..042d15b2 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/string.py @@ -0,0 +1,35 @@ +from typing import TYPE_CHECKING, Generic, TypeVar + +from tree_sitter import Node as TSNode + +from graph_sitter.core.expressions import Expression, String +from graph_sitter.core.node_id_factory import NodeId +from graph_sitter.shared.decorators.docs import ts_apidoc + +if TYPE_CHECKING: + from graph_sitter.codebase.codebase_context import CodebaseContext + + +Parent = TypeVar("Parent", bound="Expression") + + +@ts_apidoc +class TSString(String, Generic[Parent]): + """A TypeScript string node representing both literal strings and template strings. + + This class handles both regular string literals and template strings in TypeScript, + providing functionality to parse and manage template string expressions. It extends + the base String class with TypeScript-specific capabilities. + + Attributes: + expressions (list): A list of parsed expressions from template string substitutions. + Empty for regular string literals. + """ + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: + super().__init__(ts_node, file_node_id, ctx, parent=parent) + if ts_node.type == "template_string": + substitutions = [x for x in ts_node.named_children if x.type == "template_substitution"] + self.expressions = [self._parse_expression(x.named_children[0]) for x in substitutions] + else: + self.expressions = [] diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/ternary_expression.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/ternary_expression.py new file mode 100644 index 00000000..ff08f602 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/ternary_expression.py @@ -0,0 +1,20 @@ +from typing import TYPE_CHECKING, TypeVar + +from graph_sitter.core.expressions.ternary_expression import TernaryExpression +from graph_sitter.shared.decorators.docs import ts_apidoc + +if TYPE_CHECKING: + from graph_sitter.core.interfaces.editable import Editable + +Parent = TypeVar("Parent", bound="Editable") + + +@ts_apidoc +class TSTernaryExpression(TernaryExpression[Parent]): + """Any ternary expression in the code where a condition will determine branched execution""" + + def __init__(self, ts_node, file_node_id, ctx, parent: Parent) -> None: + super().__init__(ts_node, file_node_id, ctx, parent=parent) + self.condition = self.child_by_field_name("condition") + self.consequence = self.child_by_field_name("consequence") + self.alternative = self.child_by_field_name("alternative") diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/type.py new file mode 100644 index 00000000..74d4d7ad --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/type.py @@ -0,0 +1,2 @@ +TSType = "TSUnionType[Parent] | TSObjectType[Parent] | TSNamedType[Parent] | 
TSGenericType[Parent] | TSQueryType[Parent] | TSReadonlyType[Parent] | NoneType[Parent] | TSUndefinedType[Parent]" +__all__ = ["TSType"] diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/undefined_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/undefined_type.py new file mode 100644 index 00000000..5de45e14 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/undefined_type.py @@ -0,0 +1,29 @@ +from collections.abc import Generator +from typing import Generic, Self, TypeVar, override + +from graph_sitter.codebase.resolution_stack import ResolutionStack +from graph_sitter.compiled.autocommit import reader +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.expressions.type import Type +from graph_sitter.core.interfaces.importable import Importable +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc + +Parent = TypeVar("Parent") + + +@ts_apidoc +class TSUndefinedType(Type[Parent], Generic[Parent]): + """Undefined type. Represents the undefined keyword + Examples: + undefined + """ + + @noapidoc + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + pass + + @reader + @noapidoc + @override + def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: + yield from [] diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/union_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/union_type.py new file mode 100644 index 00000000..92276b06 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/expressions/union_type.py @@ -0,0 +1,17 @@ +from typing import Generic, TypeVar + +from graph_sitter.core.expressions.union_type import UnionType +from graph_sitter.shared.decorators.docs import ts_apidoc + +Parent = TypeVar("Parent") + + +@ts_apidoc +class TSUnionType(UnionType["TSType", Parent], Generic[Parent]): + """Union type + + Examples: + string | number + """ + + pass diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/external/dependency_manager.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/external/dependency_manager.py new file mode 100644 index 00000000..528d9547 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/external/dependency_manager.py @@ -0,0 +1,376 @@ +import concurrent.futures +import json +import os +import shutil +import subprocess +import uuid +from dataclasses import dataclass +from enum import Enum + +import pyjson5 +import requests + +from graph_sitter.core.external.dependency_manager import DependencyManager +from graph_sitter.shared.logging.get_logger import get_logger +from graph_sitter.utils import shadow_files + +logger = get_logger(__name__) + + +class InstallerType(Enum): + NPM = "npm" + YARN = "yarn" + PNPM = "pnpm" + UNKNOWN = "unknown" + + +@dataclass +class PackageJsonData: + dependencies: dict[str, str] + dev_dependencies: dict[str, str] + package_data: dict + + +class TypescriptDependencyManager(DependencyManager): + should_install_dependencies: bool + installer_type: InstallerType + package_json_data: dict[str, PackageJsonData] + base_package_json_data: PackageJsonData | None + + """Handles dependency management for Typescript projects. 
Uses npm, yarn, or pnpm if applicable.""" + + def __init__(self, repo_path: str, base_path: str | None = None, should_install_dependencies: bool = True, force_installer: str | None = None): + super().__init__(repo_path, base_path) + logger.info(f"Initializing TypescriptDependencyManager with should_install_dependencies={should_install_dependencies}") + # Ensure that node, npm, yarn, and pnpm are installed + if not shutil.which("node"): + msg = "NodeJS is not installed" + raise RuntimeError(msg) + if not shutil.which("corepack"): + msg = "corepack is not installed" + raise RuntimeError(msg) + if not shutil.which("npm"): + msg = "npm is not installed" + raise RuntimeError(msg) + if not shutil.which("yarn"): + msg = "yarn is not installed" + raise RuntimeError(msg) + if not shutil.which("pnpm"): + msg = "pnpm is not installed" + raise RuntimeError(msg) + + self.should_install_dependencies = should_install_dependencies + # Detect the installer type + if force_installer: + self.installer_type = InstallerType(force_installer) + else: + self.installer_type = self._detect_installer_type() + + logger.info(f"Detected installer type: {self.installer_type}") + + # List of package.json files with their parsed data + self.package_json_data: dict[str, PackageJsonData] = {} + self.base_package_json_data: PackageJsonData | None = None + + def _detect_installer_type(self) -> InstallerType: + if os.path.exists(os.path.join(self.full_path, "yarn.lock")): + return InstallerType.YARN + elif os.path.exists(os.path.join(self.full_path, "package-lock.json")): + return InstallerType.NPM + elif os.path.exists(os.path.join(self.full_path, "pnpm-lock.yaml")): + return InstallerType.PNPM + else: + logger.warning("Could not detect installer type. Defaulting to NPM!") + return InstallerType.NPM + # return InstallerType.UNKNOWN + + @staticmethod + def _check_package_exists(package_name: str) -> bool: + """Check if a package exists on the npm registry.""" + url = f"https://registry.npmjs.org/{package_name}" + try: + response = requests.head(url) + return response.status_code == 200 + except requests.RequestException: + return False + + @classmethod + def _validate_dependencies(cls, deps: dict[str, str]) -> tuple[dict[str, str], dict[str, str]]: + """Validate a dictionary of dependencies against npm registry.""" + valid_deps = {} + invalid_deps = {} + + # Use ThreadPoolExecutor for concurrent validation + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: + future_to_package = {executor.submit(cls._check_package_exists, package): (package, version) for package, version in deps.items()} + + for future in concurrent.futures.as_completed(future_to_package): + package, version = future_to_package[future] + try: + exists = future.result() + # Hack to fix github packages + if "github" in version: + version = version.split("#")[0] + if exists: + valid_deps[package] = version + else: + invalid_deps[package] = version + except Exception as e: + logger.exception(f"Error checking package {package}: {e}") + + return valid_deps, invalid_deps + + def parse_dependencies(self): + # Clear the package_json_data + self.package_json_data.clear() + + # Walk through directory tree + for current_dir, subdirs, files in os.walk(self.full_path): + # Skip node_modules directories + if "node_modules" in current_dir: + continue + + # Check if package.json exists in current directory + if "package.json" in files: + # Convert to absolute path and append to results + package_json_path = os.path.join(current_dir, "package.json") + + # 
Parse the package.json file + try: + # Read package.json + with open(package_json_path) as f: + package_data = pyjson5.load(f) + + # Get dependencies and devDependencies + dependencies = package_data.get("dependencies", {}) + dev_dependencies = package_data.get("devDependencies", {}) + + self.package_json_data[package_json_path] = PackageJsonData(dependencies, dev_dependencies, package_data) + + except FileNotFoundError: + logger.exception(f"Could not find package.json at {package_json_path}") + except ValueError: + logger.exception(f"Invalid json in package.json at {package_json_path}") + except Exception as e: + raise e + + # Set the base package.json data + base_package_json_path = os.path.join(self.full_path, "package.json") + self.base_package_json_data = self.package_json_data.get(base_package_json_path, None) + + def _install_dependencies_npm(self): + logger.info("Installing dependencies with NPM") + # Shadow package-lock.json, if it exists + files_to_shadow = [] + # Check if package-lock.json exists. + if os.path.exists(os.path.join(self.full_path, "package-lock.json")): + files_to_shadow.append(os.path.join(self.full_path, "package-lock.json")) + + # Shadow the files + with shadow_files(files_to_shadow): + # Remove the original package-lock.json + for file_path in files_to_shadow: + os.remove(file_path) + + # Print the node version + logger.info(f"Node version: {subprocess.check_output(['node', '--version'], cwd=self.full_path, text=True).strip()}") + + # Print the npm version + logger.info(f"NPM version: {subprocess.check_output(['npm', '--version'], cwd=self.full_path, text=True).strip()}") + + # NPM Install + try: + logger.info(f"Running npm install with cwd {self.full_path}") + subprocess.run(["npm", "install"], cwd=self.full_path, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as e: + logger.exception(f"NPM FAIL: npm install failed with exit code {e.returncode}") + logger.exception(f"NPM FAIL stdout: {e.stdout}") + logger.exception(f"NPM FAIL stderr: {e.stderr}") + raise + + def _install_dependencies_yarn(self): + logger.info("Installing dependencies with Yarn") + # Shadow yarn.lock, yarn.config.cjs, and .yarnrc.yml, if they exist + files_to_shadow = [] + # Check if yarn.lock exists. + if os.path.exists(os.path.join(self.full_path, "yarn.lock")): + files_to_shadow.append(os.path.join(self.full_path, "yarn.lock")) + # Check if yarn.config.cjs exists. This fixes constraints + if os.path.exists(os.path.join(self.full_path, "yarn.config.cjs")): + files_to_shadow.append(os.path.join(self.full_path, "yarn.config.cjs")) + # Check if .yarnrc.yml exists. 
This fixes pre and post install scripts + if os.path.exists(os.path.join(self.full_path, ".yarnrc.yml")): + files_to_shadow.append(os.path.join(self.full_path, ".yarnrc.yml")) + + # Shadow the files + with shadow_files(files_to_shadow): + # If .yarnrc.yml exists, check if the yarnPath option is set and save it + yarn_path = None + if os.path.exists(os.path.join(self.full_path, ".yarnrc.yml")): + # Grab the line with "yarnPath" + with open(os.path.join(self.full_path, ".yarnrc.yml")) as f: + for line in f: + if "yarnPath" in line: + yarn_path = line.split(":")[1].strip() + break + # Remove all the shadowed files + for file_path in files_to_shadow: + os.remove(file_path) + + try: + # Disable PnP + with open(os.path.join(self.full_path, ".yarnrc.yml"), "w") as f: + f.write("nodeLinker: node-modules\n") + if yarn_path: + f.write(f"yarnPath: {yarn_path}\n") + + # Print the node version + logger.info(f"Node version: {subprocess.check_output(['node', '--version'], cwd=self.full_path, text=True).strip()}") + + # Print the yarn version + logger.info(f"Yarn version: {subprocess.check_output(['yarn', '--version'], cwd=self.full_path, text=True).strip()}") + + # This fixes a bug where swapping yarn versions corrups the metadata and package caches, + # causing all sorts of nasty issues + yarn_temp_global_dir: str = f"/tmp/yarn_tmp_{uuid.uuid4()}" + try: + # Yarn Install + try: + # Create custom flags for yarn + yarn_custom_flags = { + "YARN_ENABLE_IMMUTABLE_INSTALLS": "false", + "YARN_ENABLE_TELEMETRY": "false", + "YARN_ENABLE_GLOBAL_CACHE": "true", + "YARN_GLOBAL_FOLDER": yarn_temp_global_dir, + } + yarn_environ = { + **os.environ, + **yarn_custom_flags, + } + + # Set up yarn + logger.info(f"Running yarn install with cwd {self.full_path} and yarn_custom_flags {yarn_custom_flags}") + subprocess.run(["corepack", "enable"], cwd=self.full_path, check=True, capture_output=True, text=True) + subprocess.run(["corepack", "prepare", "--activate"], cwd=self.full_path, check=True, capture_output=True, text=True) + subprocess.run(["yarn", "install"], cwd=self.full_path, check=True, capture_output=True, text=True, env=yarn_environ) + except subprocess.CalledProcessError as e: + logger.exception(f"Yarn FAIL: yarn install failed with exit code {e.returncode}") + logger.exception(f"Yarn FAIL stdout: {e.stdout}") + logger.exception(f"Yarn FAIL stderr: {e.stderr}") + raise + finally: + # Clean up the temporary global directory + if os.path.exists(yarn_temp_global_dir): + shutil.rmtree(yarn_temp_global_dir) + finally: + # Check if the .yarnrc.yml file exists + if os.path.exists(os.path.join(self.full_path, ".yarnrc.yml")): + # Delete the .yarnrc.yml file + os.remove(os.path.join(self.full_path, ".yarnrc.yml")) + + def _install_dependencies_pnpm(self): + logger.info("Installing dependencies with PNPM") + # Shadow pnpm-lock.yaml, if it exists + files_to_shadow = [] + if os.path.exists(os.path.join(self.full_path, "pnpm-lock.yaml")): + files_to_shadow.append(os.path.join(self.full_path, "pnpm-lock.yaml")) + + # Shadow the files + with shadow_files(files_to_shadow): + # Remove all the shadowed files + for file_path in files_to_shadow: + os.remove(file_path) + + # Print the node version + logger.info(f"Node version: {subprocess.check_output(['node', '--version'], cwd=self.full_path, text=True).strip()}") + + # Print the pnpm version + logger.info(f"PNPM version: {subprocess.check_output(['pnpm', '--version'], cwd=self.full_path, text=True).strip()}") + + # PNPM Install + try: + logger.info(f"Running pnpm install with cwd 
{self.full_path}") + subprocess.run(["pnpm", "install"], cwd=self.full_path, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as e: + logger.exception(f"PNPM FAIL: pnpm install failed with exit code {e.returncode}") + logger.exception(f"PNPM FAIL stdout: {e.stdout}") + logger.exception(f"PNPM FAIL stderr: {e.stderr}") + raise + + def _clean_package_json(self, package_json_path: str): + # Get the package data + data = self.package_json_data[package_json_path] + + # Get valid dependencies + valid_deps, _ = self._validate_dependencies(data.dependencies) + valid_dev_deps, _ = self._validate_dependencies(data.dev_dependencies) + + # Create a slimmed down package.json with only the valid dependencies + clean_package_data = {} + + # Copy important fields + clean_package_data["name"] = data.package_data.get("name", "unknown") + clean_package_data["version"] = data.package_data.get("version", "v1.0.0") + if "packageManager" in data.package_data: + clean_package_data["packageManager"] = data.package_data["packageManager"] + if "workspaces" in data.package_data: + clean_package_data["workspaces"] = data.package_data["workspaces"] + + # Copy dependencies + clean_package_data["dependencies"] = valid_deps + clean_package_data["devDependencies"] = valid_dev_deps + + # Write the cleaned package.json + with open(package_json_path, "w") as f: + json_str = json.dumps(clean_package_data, indent=2) + f.write(json_str) + + def install_dependencies(self, validate_dependencies: bool = True): + if validate_dependencies: + with shadow_files(list(self.package_json_data.keys())): + logger.info(f"Cleaning package.json files: {list(self.package_json_data.keys())}") + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: + executor.map(self._clean_package_json, self.package_json_data.keys()) + + # Install dependencies, now that we have a valid package.json + return self.install_dependencies(validate_dependencies=False) + else: + if self.installer_type == InstallerType.NPM: + return self._install_dependencies_npm() + elif self.installer_type == InstallerType.YARN: + return self._install_dependencies_yarn() + elif self.installer_type == InstallerType.PNPM: + return self._install_dependencies_pnpm() + else: + logger.warning(f"Installer type {self.installer_type} not implemented") + + def remove_dependencies(self): + # Delete node_modules folder if it exists + node_modules_path = os.path.join(self.full_path, "node_modules") + if os.path.exists(node_modules_path): + shutil.rmtree(node_modules_path) + + def _start(self): + try: + logger.info(f"Starting TypescriptDependencyManager with should_install_dependencies={self.should_install_dependencies}") + super()._start() + # Remove dependencies if we are installing them + if self.should_install_dependencies: + logger.info("Removing existing dependencies") + self.remove_dependencies() + + # Parse dependencies + logger.info("Parsing dependencies") + self.parse_dependencies() + + # Install dependencies if we are installing them + if self.should_install_dependencies: + logger.info("Installing dependencies") + self.install_dependencies() + + # We are ready + logger.info("Finalizing TypescriptDependencyManager") + self.is_ready = True + except Exception as e: + self._error = e + logger.error(f"Error starting TypescriptDependencyManager: {e}", exc_info=True) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/external/mega_racer.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/external/mega_racer.py new file mode 
100644 index 00000000..ea0a9807 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/external/mega_racer.py @@ -0,0 +1,30 @@ +from py_mini_racer import MiniRacer, init_mini_racer +from py_mini_racer._context import Context +from py_mini_racer._set_timeout import INSTALL_SET_TIMEOUT + + +class MegaRacer(MiniRacer): + """MegaRacer is a patch on MiniRacer that allows for more memory. + + Original MiniRacer: + MiniRacer evaluates JavaScript code using a V8 isolate. + + A MiniRacer instance can be explicitly closed using the close() method, or by using + the MiniRacer as a context manager, i.e,: + + with MiniRacer() as mr: + ... + + The MiniRacer instance will otherwise clean up the underlying V8 resource upon + garbage collection. + + Attributes: + json_impl: JSON module used by helper methods default is + [json](https://docs.python.org/3/library/json.html) + """ + + def __init__(self) -> None: + # Set the max old space size to 64GB + dll = init_mini_racer(ignore_duplicate_init=True, flags=["--max-old-space-size=65536"]) + self._ctx = Context(dll) + self.eval(INSTALL_SET_TIMEOUT) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/external/ts_analyzer_engine.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/external/ts_analyzer_engine.py new file mode 100644 index 00000000..a1202299 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/external/ts_analyzer_engine.py @@ -0,0 +1,250 @@ +import json +import os +import shutil +import subprocess +import uuid +from abc import abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING + +from py_mini_racer import MiniRacer +from py_mini_racer._objects import JSMappedObject +from py_mini_racer._types import JSEvalException + +from graph_sitter.core.external.language_engine import LanguageEngine +from graph_sitter.shared.logging.get_logger import get_logger +from graph_sitter.typescript.external.mega_racer import MegaRacer + +if TYPE_CHECKING: + from graph_sitter.core.external.dependency_manager import DependencyManager + from graph_sitter.core.interfaces.editable import Editable + + +logger = get_logger(__name__) + + +class TypescriptEngine(LanguageEngine): + dependency_manager: "DependencyManager | None" + + def __init__(self, repo_path: str, base_path: str | None = None, dependency_manager: "DependencyManager | None" = None): + super().__init__(repo_path, base_path) + self.dependency_manager = dependency_manager + + @abstractmethod + def _start(self): + # If a dependency manager is provided, make sure it is ready + if self.dependency_manager: + logger.info(f"TypescriptEngine: Waiting for {self.dependency_manager.__class__.__name__} to be ready...") + self.dependency_manager.wait_until_ready(ignore_error=True) + # Start the engine + super()._start() + + +class V8TypescriptEngine(TypescriptEngine): + """Typescript-compiler based language engine using MiniRacer's V8-based JS engine. + + More experimental approach to type inference, but is faster and more flexible. 
+ + Attributes: + hard_memory_limit (int): Maximum memory limit in bytes before V8 will force garbage collection + soft_memory_limit (int): Memory threshold in bytes that triggers garbage collection + """ + + hard_memory_limit: int + soft_memory_limit: int + ctx: MiniRacer | None + mr_type_script_analyzer: JSMappedObject | None + + def __init__( + self, + repo_path: str, + base_path: str | None = None, + dependency_manager: "DependencyManager | None" = None, + hard_memory_limit: int = 1024 * 1024 * 1024 * 16, + soft_memory_limit: int = 1024 * 1024 * 1024 * 8, + ): + super().__init__(repo_path, base_path, dependency_manager) + logger.info(f"Initializing V8TypescriptEngine with hard_memory_limit={hard_memory_limit} and soft_memory_limit={soft_memory_limit}") + self.hard_memory_limit: int = hard_memory_limit + self.soft_memory_limit: int = soft_memory_limit + self.ctx: MiniRacer | None = None + self.mr_type_script_analyzer: JSMappedObject | None = None + # Get the path to the current file + self.current_file_path: str = os.path.abspath(__file__) + # Get the path of the language engine + self.engine_path: str = os.path.join(os.path.dirname(self.current_file_path), "typescript_analyzer", "dist", "index.js") + if not os.path.exists(self.engine_path): + msg = f"Typescript analyzer engine not found at {self.engine_path}" + raise FileNotFoundError(msg) + self.engine_source: str = open(self.engine_path).read() + self._patch_engine_source() + + def _start(self): + try: + logger.info("Starting V8TypescriptEngine") + super()._start() + # Create the MiniRacer/MegaRacer context + self.ctx = MegaRacer() # MegaRacer is a patch on MiniRacer that allows for more memory + # Set to 16GB + self.ctx.set_hard_memory_limit(self.hard_memory_limit) + self.ctx.set_soft_memory_limit(self.soft_memory_limit) + + # Load the engine + logger.info(f"Loading engine source with {len(self.engine_source)} bytes") + self.ctx.eval(self.engine_source) + + # Set up proxy file system + logger.info("Setting up proxy file system") + self.ctx.eval("var interop_fs = new ProxyFileSystem();") + self.ctx.eval("var fs_files = {};") + fs_files = self.ctx.eval("fs_files") + self._populate_fs_files(fs_files) + self.ctx.eval("fs_file_map = new Map(Object.entries(fs_files));") + self.ctx.eval("interop_fs.setFiles(fs_file_map);") + + # Set up the analyzer + logger.info(f"Setting up analyzer with path {self.full_path}") + self.ctx.eval(f"const type_script_analyzer = new TypeScriptAnalyzer('{self.full_path}', interop_fs);") + self.mr_type_script_analyzer = self.ctx.eval("type_script_analyzer") + + # Finalize + logger.info("Finalizing V8TypescriptEngine") + self.is_ready = True + except Exception as e: + self._error = e + logger.error(f"Error starting V8TypescriptEngine: {e}", exc_info=True) + + def _populate_fs_files(self, fs_files: dict): + for root, _, files in os.walk(self.full_path): + for filename in files: + file_path = Path(root) / filename + s_fp = str(file_path) + + # Only process JS/TS related files + if not s_fp.endswith((".ts", ".tsx", ".js", ".jsx", ".json", ".d.ts")): + continue + + try: + with open(file_path, encoding="utf-8") as f: + if "node_modules" in s_fp: + if not s_fp.endswith(".json") and not s_fp.endswith(".d.ts"): + continue + content = f.read() + fs_files[str(file_path)] = content + except (UnicodeDecodeError, OSError): + # Skip files that can't be read as text + continue + + def _patch_engine_source(self): + """MiniRacer does not support require and export, so we need to patch the engine source to remove them.""" + 
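+        # The bundled analyzer is a CommonJS build; evaluated directly in V8 it would
+        # hit `require(...)` and `exports.*` statements that MiniRacer cannot resolve,
+        # so the replacements below blank those lines out before the source is eval'd.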
logger.info("Patching engine source to remove require and export") + patch_map = { + "var require$$1 = require('fs');": "", + "var require$$2 = require('path');": "", + "var require$$3 = require('os');": "", + "var require$$6 = require('inspector');": "", + "exports.ProxyFileSystem = ProxyFileSystem;": "", + "exports.TypeScriptAnalyzer = TypeScriptAnalyzer;": "", + } + for old, new in patch_map.items(): + self.engine_source = self.engine_source.replace(old, new) + + def get_return_type(self, node: "Editable") -> str | None: + file_path = os.path.join(self.repo_path, node.filepath) + try: + return self.ctx.eval(f"type_script_analyzer.getFunctionAtPosition('{file_path}', {node.start_byte})") + except JSEvalException as e: + return None + + +class NodeTypescriptEngine(TypescriptEngine): + """Typescript-compiler based language engine using NodeJS and the TypeScript compiler. + + More mature approach to type inference, but is slower and less flexible. + + Attributes: + type_data (dict | None): Type data for the codebase + """ + + type_data: dict | None + + def __init__(self, repo_path: str, base_path: str | None = None, dependency_manager: "DependencyManager | None" = None): + super().__init__(repo_path, base_path, dependency_manager) + logger.info("Initializing NodeTypescriptEngine") + self.type_data: dict | None = None + + # Get the path to the current file + self.current_file_path: str = os.path.abspath(__file__) + # Ensure NodeJS and npm are installed + if not shutil.which("node") or not shutil.which("npm"): + msg = "NodeJS or npm is not installed" + raise RuntimeError(msg) + + # Get the path to the typescript analyzer + self.analyzer_path: str = os.path.join(os.path.dirname(self.current_file_path), "typescript_analyzer") + self.analyzer_entry: str = os.path.join(self.analyzer_path, "src", "run_full.ts") + if not os.path.exists(self.analyzer_path): + msg = f"Typescript analyzer not found at {self.analyzer_path}" + raise FileNotFoundError(msg) + + def _start(self): + try: + logger.info("Starting NodeTypescriptEngine") + super()._start() + # NPM Install + try: + logger.info("Installing typescript analyzer dependencies") + subprocess.run(["npm", "install"], cwd=self.analyzer_path, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as e: + logger.exception(f"NPM FAIL: npm install failed with exit code {e.returncode}") + logger.exception(f"NPM FAIL stdout: {e.stdout}") + logger.exception(f"NPM FAIL stderr: {e.stderr}") + raise + + # Create a temporary output file with a random name + output_file_path: str = f"/tmp/ts_analyzer_output_{uuid.uuid4()}.json" + try: + # Run the analyzer + try: + # Create custom flags for node + node_environ = {**os.environ, "NODE_OPTIONS": "--max_old_space_size=8192"} + + # Run the analyzer + logger.info(f"Running analyzer with project path {self.full_path} and output file {output_file_path}") + subprocess.run( + ["node", "--loader", "ts-node/esm", self.analyzer_entry, "--project", self.full_path, "--output", output_file_path], + cwd=self.analyzer_path, + check=True, + capture_output=True, + text=True, + env=node_environ, + ) + except subprocess.CalledProcessError as e: + logger.exception(f"ANALYZER FAIL: analyzer failed with exit code {e.returncode}") + logger.exception(f"ANALYZER FAIL stdout: {e.stdout}") + logger.exception(f"ANALYZER FAIL stderr: {e.stderr}") + raise + + # Load the type data + self.type_data = json.load(open(output_file_path)) + finally: + # Clean up the output file + if os.path.exists(output_file_path): + 
os.remove(output_file_path) + + # Finalize + logger.info("Finalizing NodeTypescriptEngine") + self.is_ready = True + except Exception as e: + self._error = e + logger.error(f"Error starting NodeTypescriptEngine: {e}", exc_info=True) + + def get_return_type(self, node: "Editable") -> str | None: + file_path: str = os.path.join(self.repo_path, node.filepath) + if not self.type_data: + return None + codebase_data: dict = self.type_data.get("files", {}) + file_data: dict = codebase_data.get(file_path, {}) + functions_data: dict = file_data.get("functions", {}) + function_data: dict = functions_data.get(node.name, {}) + return function_data.get("returnType", None) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/external/ts_declassify/ts_declassify.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/external/ts_declassify/ts_declassify.py new file mode 100644 index 00000000..2c1d8361 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/external/ts_declassify/ts_declassify.py @@ -0,0 +1,94 @@ +import os +import shutil +import subprocess + +from graph_sitter.core.external.external_process import ExternalProcess +from graph_sitter.shared.logging.get_logger import get_logger + +logger = get_logger(__name__) + + +class TSDeclassify(ExternalProcess): + def __init__(self, repo_path: str, base_path: str, working_dir: str = "/tmp/ts_declassify"): + super().__init__(repo_path, base_path) + self.working_dir = working_dir + + # Ensure NodeJS and npm are installed + if not shutil.which("node") or not shutil.which("npm"): + msg = "NodeJS or npm is not installed" + raise RuntimeError(msg) + + def _start(self): + try: + logger.info("Installing ts-declassify...") + + # Remove existing working directory + if os.path.exists(self.working_dir): + shutil.rmtree(self.working_dir) + + # Creating ts-declassify working directory + os.makedirs(self.working_dir, exist_ok=True) + + # NPM Init + try: + logger.info(f"Running npm init in {self.working_dir}") + subprocess.run(["npm", "init", "-y"], cwd=self.working_dir, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as e: + logger.exception(f"NPM FAIL: npm init failed with exit code {e.returncode}") + logger.exception(f"NPM FAIL stdout: {e.stdout}") + logger.exception(f"NPM FAIL stderr: {e.stderr}") + raise + + # NPM Install + try: + logger.info(f"Running npm install in {self.working_dir}") + subprocess.run(["npm", "install", "-D", "@codemod/cli", "react-declassify"], cwd=self.working_dir, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as e: + logger.exception(f"NPM FAIL: npm install failed with exit code {e.returncode}") + logger.exception(f"NPM FAIL stdout: {e.stdout}") + logger.exception(f"NPM FAIL stderr: {e.stderr}") + raise + + # Finalize + self.is_ready = True + except Exception as e: + self._error = e + logger.exception(f"Error installing ts-declassify: {e}") + raise e + + def reparse(self): + msg = "TSDeclassify does not support reparse" + raise NotImplementedError(msg) + + def declassify(self, source: str, filename: str = "file.tsx", error_on_failure: bool = True): + assert self.ready(), "TSDeclassify is not ready" + + try: + # Remove and recreate file.tsx + source_file = os.path.join(self.working_dir, filename) + with open(source_file, "w") as f: + f.write(source) + + # Run declassify + try: + subprocess.run(["npx", "codemod", "--plugin", "react-declassify", source_file], cwd=self.working_dir, check=True, capture_output=True, text=True) + except 
subprocess.CalledProcessError as e: + logger.exception(f"DECLASSIFY FAIL: declassify failed with exit code {e.returncode}") + logger.exception(f"DECLASSIFY FAIL stdout: {e.stdout}") + logger.exception(f"DECLASSIFY FAIL stderr: {e.stderr}") + raise + + # Get the declassified source + with open(source_file) as f: + declassified_source = f.read() + + # Raise an error if the declassification failed + if error_on_failure and "Cannot perform transformation" in declassified_source: + msg = "Declassification failed!" + raise RuntimeError(msg) + finally: + # Remove file.tsx if it exists + if os.path.exists(source_file): + os.remove(source_file) + + return declassified_source diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/file.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/file.py new file mode 100644 index 00000000..66e67c7f --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/file.py @@ -0,0 +1,450 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +from graph_sitter.compiled.sort import sort_editables +from graph_sitter.compiled.utils import cached_property +from graph_sitter.core.autocommit import mover, reader, writer +from graph_sitter.core.file import SourceFile +from graph_sitter.core.interfaces.exportable import Exportable +from graph_sitter.enums import ImportType, NodeType, SymbolType +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.typescript.assignment import TSAssignment +from graph_sitter.typescript.class_definition import TSClass +from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock +from graph_sitter.typescript.export import TSExport +from graph_sitter.typescript.function import TSFunction +from graph_sitter.typescript.import_resolution import TSImport +from graph_sitter.typescript.interface import TSInterface +from graph_sitter.typescript.interfaces.has_block import TSHasBlock +from graph_sitter.typescript.namespace import TSNamespace +from graph_sitter.utils import calculate_base_path + +if TYPE_CHECKING: + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.statements.export_statement import ExportStatement + from graph_sitter.core.symbol import Symbol + from graph_sitter.typescript.detached_symbols.promise_chain import TSPromiseChain + from graph_sitter.typescript.symbol import TSSymbol + from graph_sitter.typescript.ts_config import TSConfig + from graph_sitter.typescript.type_alias import TSTypeAlias + + +@ts_apidoc +class TSFile(SourceFile[TSImport, TSFunction, TSClass, TSAssignment, TSInterface, TSCodeBlock], TSHasBlock, Exportable): + """Extends the SourceFile class to provide TypeScript-specific functionality. + + Attributes: + programming_language: The programming language of the file. Set to ProgrammingLanguage.TYPESCRIPT. + ts_config: The ts_config file nearest to this file. + """ + + programming_language = ProgrammingLanguage.TYPESCRIPT + ts_config: TSConfig | None = None + + @cached_property + @reader(cache=False) + def exports(self) -> list[TSExport]: + """Returns all Export symbols in the file. + + Retrieves a list of all top-level export declarations in the current TypeScript file. + Does not include exports inside namespaces. + + Returns: + list[TSExport]: A list of TSExport objects representing all top-level export declarations in the file. 
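+            For example, `[exp.name for exp in file.exports]` lists the names of all
+            top-level exports of a parsed `TSFile` instance `file`.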
+ """ + # Filter to only get exports that are direct children of the file's code block + return sort_editables(filter(lambda node: isinstance(node, TSExport) and ((node.parent.parent.parent == self) or (node.parent.parent == self)), self.get_nodes(sort=False)), by_id=True) + + @property + @reader(cache=False) + def export_statements(self) -> list[ExportStatement[TSExport]]: + """Returns a list of all export statements in the file. + + Each export statement in the returned list can contain multiple exports. The export statements + are sorted by their position in the file. + + Args: + None + + Returns: + list[ExportStatement[TSExport]]: A list of ExportStatement objects, where each ExportStatement + contains one or more TSExport objects. + """ + export_statements = [exp.export_statement for exp in self.exports] + return sort_editables(export_statements) + + @property + @reader(cache=False) + def default_exports(self) -> list[TSExport]: + """Returns all default export symbols from the file. + + A property method that retrieves all export objects that are designated as default exports from the file. + + Returns: + list[TSExport]: A list of default export objects. Each object belongs to a single export statement. + """ + return [x for x in self.exports if x.is_default_export()] + + @property + @reader + def named_exports(self) -> list[TSExport]: + """Returns the named exports declared in the file. + + Gets all export statements in the file that are not default exports. These exports are defined + using the `export` keyword rather than `export default`. + + Args: + self (TSFile): The TypeScript file object. + + Returns: + list[TSExport]: A list of TSExport objects representing named exports in the file. + """ + return [x for x in self.exports if not x.is_default_export()] + + @reader + def get_export(self, export_name: str) -> TSExport | None: + """Returns an export object with the specified name from the file. + + This method searches for an export with the given name in the file. + + Args: + export_name (str): The name of the export to find. + + Returns: + TSExport | None: The export object if found, None otherwise. + """ + return next((x for x in self.exports if x.name == export_name), None) + + @property + @reader + def interfaces(self) -> list[TSInterface]: + """Returns all Interfaces in the file. + + Retrieves all symbols in the file that are of type Interface. + + Args: + None + + Returns: + list[TSInterface]: A list of TypeScript interface symbols defined in the file. + """ + return [s for s in self.symbols if s.symbol_type == SymbolType.Interface] + + @reader + def get_interface(self, name: str) -> TSInterface | None: + """Retrieves a specific interface from the file by its name. + + Args: + name (str): The name of the interface to find. + + Returns: + TSInterface | None: The interface with the specified name if found, None otherwise. + """ + return next((x for x in self.interfaces if x.name == name), None) + + @property + @reader + def types(self) -> list[TSTypeAlias]: + """Returns all type aliases in the file. + + Retrieves a list of all type aliases defined in the current TypeScript/JavaScript file. + + Returns: + list[TSTypeAlias]: A list of all type aliases in the file. Empty list if no type aliases are found. + """ + return [s for s in self.symbols if s.symbol_type == SymbolType.Type] + + @reader + def get_type(self, name: str) -> TSTypeAlias | None: + """Returns a specific Type by name from the file's types. 
+ + Retrieves a TypeScript type alias by its name from the file's collection of types. + + Args: + name (str): The name of the type alias to retrieve. + + Returns: + TSTypeAlias | None: The TypeScript type alias with the matching name, or None if not found. + """ + return next((x for x in self.types if x.name == name), None) + + @staticmethod + def get_extensions() -> list[str]: + """Returns a list of file extensions that this class can parse. + + Returns a list of file extensions for TypeScript and JavaScript files that this File class can parse and process. + + Returns: + list[str]: A list of file extensions including '.tsx', '.ts', '.jsx', and '.js'. + """ + return [".tsx", ".ts", ".jsx", ".js"] + + def symbol_can_be_added(self, symbol: TSSymbol) -> bool: + """Determines if a TypeScript symbol can be added to this file based on its type and JSX compatibility. + + This method checks whether a given symbol can be added to the current TypeScript file by validating its compatibility with the file's extension. + In particular, it ensures that JSX functions are only added to appropriate file types (.tsx or .jsx). + + Args: + symbol (TSSymbol): The TypeScript symbol to be checked. + + Returns: + bool: True if the symbol can be added to this file, False otherwise. + """ + if symbol.symbol_type == SymbolType.Function: + if symbol.is_jsx: + if not (self.file_path.endswith("tsx") or self.file_path.endswith("jsx")): + return False + return True + + @reader + def get_config(self) -> TSConfig | None: + """Returns the nearest tsconfig.json applicable to this file. + + Gets the TypeScript configuration for the current file by retrieving the nearest tsconfig.json file in the directory hierarchy. + + Returns: + TSConfig | None: The TypeScript configuration object if found, None otherwise. + """ + return self.ts_config + + @writer + def add_export_to_symbol(self, symbol: TSSymbol) -> None: + """Adds an export keyword to a symbol in a TypeScript file. + + Marks a symbol for export by adding the 'export' keyword. This modifies the symbol's + declaration to make it available for import by other modules. + + Args: + symbol (TSSymbol): The TypeScript symbol (function, class, interface, etc.) to be exported. + + Returns: + None + """ + # TODO: this should be in symbol.py class. Rename as `add_export` + symbol.add_keyword("export") + + @writer + def remove_unused_exports(self) -> None: + """Removes unused exports from the file. + + Analyzes all exports in the file and removes any that are not used. An export is considered unused if it has no direct + symbol usages and no re-exports that are used elsewhere in the codebase. + + When removing unused exports, the method also cleans up any related unused imports. For default exports, it removes + the 'export default' keyword, and for named exports, it removes the 'export' keyword or the entire export statement. + + Args: + None + + Returns: + None + """ + for export in self.exports: + symbol_export_unused = True + symbols_to_remove = [] + + exported_symbol = export.resolved_symbol + for export_usage in export.symbol_usages: + if export_usage.node_type == NodeType.IMPORT or (export_usage.node_type == NodeType.EXPORT and export_usage.resolved_symbol != exported_symbol): + # If the import has no usages then we can add the import to the list of symbols to remove + reexport_usages = export_usage.symbol_usages + if len(reexport_usages) == 0: + symbols_to_remove.append(export_usage) + break + + # If any of the import's usages are valid symbol usages, export is used. 
+ if any(usage.node_type == NodeType.SYMBOL for usage in reexport_usages): + symbol_export_unused = False + break + + symbols_to_remove.append(export_usage) + + elif export_usage.node_type == NodeType.SYMBOL: + symbol_export_unused = False + break + + # export is not used, remove it + if symbol_export_unused: + # remove the unused imports + for imp in symbols_to_remove: + imp.remove() + + if exported_symbol == exported_symbol.export.declared_symbol: + # change this to be more robust + if exported_symbol.source.startswith("export default "): + exported_symbol.replace("export default ", "") + else: + exported_symbol.replace("export ", "") + else: + exported_symbol.export.remove() + if exported_symbol.export != export: + export.remove() + + @noapidoc + def _get_export_data(self, relative_path: str, export_type: str = "EXPORT") -> tuple[tuple[str, str], dict[str, callable]]: + quoted_paths = (f"'{relative_path}'", f'"{relative_path}"') + export_type_conditions = { + "WILDCARD": lambda exp: exp.is_wildcard_export(), + "TYPE": lambda exp: exp.is_type_export(), + # Changed this condition - it was incorrectly handling type exports + "EXPORT": lambda exp: (not exp.is_type_export() and not exp.is_wildcard_export()), + } + return quoted_paths, export_type_conditions + + @reader + def has_export_statement_for_path(self, relative_path: str, export_type: str = "EXPORT") -> bool: + """Checks if the file has exports of specified type that contains the given path in single or double quotes. + + Args: + relative_path (str): The path to check for in export statements + export_type (str): Type of export to check for - "WILDCARD", "TYPE", or "EXPORT" (default) + + Returns: + bool: True if there exists an export of specified type with the exact relative path (quoted) + in its source, False otherwise. + """ + if not self.export_statements: + return False + + quoted_paths, export_type_conditions = self._get_export_data(relative_path, export_type) + condition = export_type_conditions[export_type] + + return any(any(quoted_path in stmt.source for quoted_path in quoted_paths) and any(condition(exp) for exp in stmt.exports) for stmt in self.export_statements) + + #################################################################################################################### + # GETTERS + #################################################################################################################### + + @reader + def get_export_statement_for_path(self, relative_path: str, export_type: str = "EXPORT") -> ExportStatement | None: + """Gets the first export of specified type that contains the given path in single or double quotes. + + Args: + relative_path (str): The path to check for in export statements + export_type (str): Type of export to get - "WILDCARD", "TYPE", or "EXPORT" (default) + + Returns: + TSExport | None: The first matching export if found, None otherwise. 
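+        Example:
+            `file.get_export_statement_for_path("./utils", "WILDCARD")` returns the
+            first wildcard export (e.g. `export * from './utils'`), if any.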
+ """ + if not self.export_statements: + return None + + quoted_paths, export_type_conditions = self._get_export_data(relative_path, export_type) + condition = export_type_conditions[export_type] + + for stmt in self.export_statements: + if any(quoted_path in stmt.source for quoted_path in quoted_paths): + for exp in stmt.exports: + if condition(exp): + return exp + + return None + + @noapidoc + def get_import_module_name_for_file(self, filepath: str, ctx: CodebaseContext) -> str: + """Returns the module name that this file gets imported as""" + # TODO: support relative and absolute module path + import_path = filepath + + # Apply path import aliases to import_path + if self.ts_config: + import_path = self.ts_config.translate_absolute_path(import_path) + + # Remove file extension + import_path = os.path.splitext(import_path)[0] + return f"'{import_path}'" + + @reader + def get_import_string(self, alias: str | None = None, module: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> str: + """Generates and returns an import statement for the file. + + Constructs an import statement string based on the file's name and module information. + + Args: + alias (str | None): Alternative name for the imported module. Defaults to None. + module (str | None): Module path to import from. If None, uses file's default module name. + import_type (ImportType): The type of import statement. Defaults to ImportType.UNKNOWN. + is_type_import (bool): Whether this is a type-only import. Defaults to False. + + Returns: + str: A formatted import statement string importing all exports from the module. + """ + import_module = module if module is not None else self.import_module_name + file_module = self.name + return f"import * as {file_module} from {import_module}" + + @cached_property + @noapidoc + @reader(cache=True) + def valid_import_names(self) -> dict[str, Symbol | TSImport]: + """Returns a dict mapping name => Symbol (or import) in this file that can be imported from another file""" + valid_export_names = {} + if len(self.default_exports) == 1: + valid_export_names["default"] = self.default_exports[0] + for export in self.exports: + for name, dest in export.names: + valid_export_names[name] = dest + return valid_export_names + + #################################################################################################################### + # MANIPULATIONS + #################################################################################################################### + + @mover + def update_filepath(self, new_filepath: str) -> None: + """Updates the file path of the current file and all associated imports. + + Renames the current file to a new file path and updates all imports that reference this file to point to the new location. + + Args: + new_filepath (str): The new file path to move the file to. 
+ + Returns: + None + """ + # =====[ Add the new filepath as a new file node in the graph ]===== + new_file = self.ctx.node_classes.file_cls.from_content(new_filepath, self.content, self.ctx) + # =====[ Change the file on disk ]===== + self.transaction_manager.add_file_rename_transaction(self, new_filepath) + # =====[ Update all the inbound imports to point to the new module ]===== + for imp in self.inbound_imports: + existing_imp = imp.module.source.strip("'") + new_module_name = new_file.import_module_name.strip("'") + # Web specific hacks + if self.ctx.repo_name == "web": + if existing_imp.startswith("./"): + relpath = calculate_base_path(new_filepath, existing_imp) + new_module_name = new_module_name.replace(relpath, ".") + elif existing_imp.startswith("~/src"): + new_module_name = new_module_name.replace("src/", "~/src/") + imp.set_import_module(f"'{new_module_name}'") + + @reader + def get_namespace(self, name: str) -> TSNamespace | None: + """Returns a specific namespace by name from the file's namespaces. + + Args: + name (str): The name of the namespace to find. + + Returns: + TSNamespace | None: The namespace with the specified name if found, None otherwise. + """ + return next((x for x in self.symbols if isinstance(x, TSNamespace) and x.name == name), None) + + @property + @reader + def promise_chains(self) -> list[TSPromiseChain]: + """Returns all promise chains in the file. + + Returns: + list[TSPromiseChain]: A list of promise chains in the file. + """ + promise_chains = [] + for function in self.functions: + for promise_chain in function.promise_chains: + promise_chains.append(promise_chain) + return promise_chains diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/function.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/function.py new file mode 100644 index 00000000..6fe17ed3 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/function.py @@ -0,0 +1,452 @@ +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING + +from graph_sitter.core.autocommit import commiter, reader, writer +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.function import Function +from graph_sitter.core.symbol_groups.collection import Collection +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.shared.logging.get_logger import get_logger +from graph_sitter.typescript.detached_symbols.decorator import TSDecorator +from graph_sitter.typescript.detached_symbols.parameter import TSParameter +from graph_sitter.typescript.enums import TSFunctionTypeNames +from graph_sitter.typescript.expressions.type import TSType +from graph_sitter.typescript.interfaces.has_block import TSHasBlock +from graph_sitter.typescript.placeholder.placeholder_return_type import TSReturnTypePlaceholder +from graph_sitter.typescript.symbol import TSSymbol +from graph_sitter.utils import find_all_descendants + +if TYPE_CHECKING: + from collections.abc import Generator + + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.import_resolution import Import, WildcardImport + from graph_sitter.core.interfaces.has_name import HasName + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.core.statements.export_statement import ExportStatement + from graph_sitter.core.statements.symbol_statement import SymbolStatement + from graph_sitter.core.symbol import Symbol + from 
graph_sitter.typescript.detached_symbols.promise_chain import TSPromiseChain +_VALID_TYPE_NAMES = {function_type.value for function_type in TSFunctionTypeNames} +logger = get_logger(__name__) + + +@ts_apidoc +class TSFunction(Function[TSDecorator, "TSCodeBlock", TSParameter, TSType], TSHasBlock, TSSymbol): + """Representation of a Function in JavaScript/TypeScript""" + + @noapidoc + @commiter + def parse(self, ctx: CodebaseContext) -> None: + super().parse(ctx) + + self.return_type = self.child_by_field_name("return_type", placeholder=TSReturnTypePlaceholder) + if parameters_node := self.ts_node.child_by_field_name("parameters"): + self._parameters = Collection(parameters_node, self.file_node_id, self.ctx, self) + params = [x for x in parameters_node.children if x.type in ("required_parameter", "optional_parameter")] + symbols = None + # Deconstructed object parameters + if len(params) == 1: + pattern = params[0].child_by_field_name("pattern") + type_annotation = None + if type_node := params[0].child_by_field_name("type"): + type_annotation = self._parse_type(type_node) + if pattern and pattern.type == "object_pattern": + params = [x for x in pattern.children if x.type in ("shorthand_property_identifier_pattern", "object_assignment_pattern", "pair_pattern")] + symbols = [TSParameter(x, i, self._parameters, type_annotation) for (i, x) in enumerate(params)] + # Default case - regular parameters + if symbols is None: + symbols = [TSParameter(x, i, self._parameters) for (i, x) in enumerate(params)] + self._parameters._init_children(symbols) + elif parameters_node := self.ts_node.child_by_field_name("parameter"): + self._parameters = Collection(parameters_node, self.file_node_id, self.ctx, self) + self._parameters._init_children([TSParameter(parameters_node, 0, self._parameters)]) + else: + logger.warning(f"Couldn't find parameters for {self!r}") + self._parameters = [] + self.type_parameters = self.child_by_field_name("type_parameters") + + @property + @reader + def function_type(self) -> TSFunctionTypeNames: + """Gets the type of function from its TreeSitter node. + + Extracts and returns the type of function (e.g., arrow function, generator function, function expression) + from the node's type information. + + Args: + None: Property method that uses instance's ts_node. + + Returns: + TSFunctionTypeNames: The function type enum value representing the specific type of function. + """ + return TSFunctionTypeNames(self.ts_node.type) + + @noapidoc + @commiter + def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: + # If a destination is provided, use it, otherwise use the default destination + # This is used for cases where a non-symbol (eg. 
argument) value parses as a function + dest = dest or self.self_dest + + # =====[ Typed Parameters ]===== + # Have to grab types from the parameters + if self.parameters is not None: + for param in self.parameters: + assignment_patterns = find_all_descendants(param.ts_node, {"object_pattern", "object_assignment_pattern", "assignment_pattern"}) + if assignment_patterns: + dest.add_all_identifier_usages_for_child_node(UsageKind.GENERIC, assignment_patterns[0]) + if self.type_parameters: + self.type_parameters._compute_dependencies(UsageKind.GENERIC, dest) + # =====[ Return type ]===== + if self.return_type: + # Need to parse all the different types + self.return_type._compute_dependencies(UsageKind.RETURN_TYPE, dest) + + # =====[ Code Block ]===== + self.code_block._compute_dependencies(usage_type, dest) + + @classmethod + @noapidoc + def from_function_type(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: SymbolStatement | ExportStatement) -> TSFunction: + """Creates a TSFunction object from a function declaration.""" + if ts_node.type not in [function_type.value for function_type in TSFunctionTypeNames]: + msg = f"Node type={ts_node.type} is not a function declaration" + raise ValueError(msg) + file = ctx.get_node(file_node_id) + if canonical := file._range_index.get_canonical_for_range(ts_node.range, ts_node.kind_id): + return canonical + return cls(ts_node, file_node_id, ctx, parent=parent) + + @staticmethod + @noapidoc + def _get_name_node(ts_node: TSNode) -> TSNode | None: + if ts_node.type == "function_declaration": + return ts_node.child_by_field_name("name") + elif ts_node.type == "function_expression": + if name := ts_node.child_by_field_name("name"): + return name + return ts_node.parent.child_by_field_name("name") + elif ts_node.type == "arrow_function": + ts_node = ts_node.parent + while ts_node.type in ("parenthesized_expression", "binary_expression"): + ts_node = ts_node.parent + if ts_node.type == "pair": + return ts_node.child_by_field_name("key") + elif ts_node.type == "return_statement": + func_expression = next((x for x in ts_node.children if x.type == ("function_expression")), None) + if func_expression: + return func_expression.child_by_field_name("name") + return ts_node.child_by_field_name("name") + + @property + @reader + def function_signature(self) -> str: + """Returns a string representation of the function's signature. + + Generates a string containing the full function signature including name, parameters, and return type + based on the function's type (arrow function, generator function, function expression, etc.). + + Returns: + str: A string containing the complete function signature. For example: 'function foo(bar: string): number' + + Raises: + NotImplementedError: If the function type is not implemented. 
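+            For an arrow function such as `const add = (a: number) => a`, this yields
+            'add = (a: number)'; the arrow body itself is not included.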
+ """ + if self.function_type == TSFunctionTypeNames.FunctionDeclaration: + func_def_src = f"function {self.name}" + elif self.function_type == TSFunctionTypeNames.GeneratorFunctionDeclaration: + func_def_src = f"function* {self.name}" + elif self.function_type == TSFunctionTypeNames.ArrowFunction: + func_def_src = f"{self.name} = " + elif self.function_type == TSFunctionTypeNames.FunctionExpression: + func_def_src = f"{self.name} = function" + else: + msg = "function type not implemented" + raise NotImplementedError(msg) + if self.parameters is not None: + func_def_src += self.parameters.source + if self.return_type: + func_def_src += ": " + self.return_type.source + return func_def_src + + @cached_property + @reader + def is_private(self) -> bool: + """Determines if a function is private based on its accessibility modifier. + + This property examines the function's accessibility modifier to determine if it's marked as private. In TypeScript, this means the function has the 'private' keyword. + + Returns: + bool: True if the function has a 'private' accessibility modifier, False otherwise. + """ + modifier = self.ts_node.children[0] + return modifier.type == "accessibility_modifier" and modifier.text == b"private" + + @cached_property + @reader + def is_magic(self) -> bool: + """Returns whether this method is a magic method. + + A magic method is a method whose name starts and ends with double underscores, like __init__ or __str__. + In this implementation, all methods are considered non-magic in TypeScript. + + Returns: + bool: False, as TypeScript does not have magic methods. + """ + return False + + @property + @reader + def is_anonymous(self) -> bool: + """Property indicating whether a function is anonymous. + + Returns True if the function has no name or if its name is an empty string. + + Returns: + bool: True if the function is anonymous, False otherwise. + """ + return not self.name or self.name.strip() == "" + + @property + def is_async(self) -> bool: + """Determines if the function is asynchronous. + + Checks the function's node children to determine if the function is marked as asynchronous. + + Returns: + bool: True if the function is asynchronous (has 'async' keyword), False otherwise. + """ + return any("async" == x.type for x in self.ts_node.children) + + @property + @reader + def is_arrow(self) -> bool: + """Returns True iff the function is an arrow function. + + Identifies whether the current function is an arrow function (lambda function) in TypeScript/JavaScript. + + Returns: + bool: True if the function is an arrow function, False otherwise. + """ + return self.function_type == TSFunctionTypeNames.ArrowFunction + + @property + @reader + def is_property(self) -> bool: + """Determines if the function is a property. + + Checks if any of the function's decorators are '@property' or '@cached_property'. + + Returns: + bool: True if the function has a @property or @cached_property decorator, False otherwise. + """ + return any(dec in ("@property", "@cached_property") for dec in self.decorators) + + @property + @reader + def _named_arrow_function(self) -> TSNode | None: + """Returns the name of the named arrow function, if it exists.""" + if self.is_arrow: + node = self.ts_node + if name := self.get_name(): + node = name.ts_node + parent = node.parent + if parent.type == "variable_declarator": + return parent.parent + return None + + @property + @reader + def is_jsx(self) -> bool: + """Determines if the function is a React component by checking if it returns a JSX element. 
+ + A function is considered a React component if it contains at least one JSX element in its body + and either has no name or has a name that starts with an uppercase letter. + + Returns: + bool: True if the function is a React component, False otherwise. + """ + # Must contain a React component + if len(self.jsx_elements) == 0: + return False + # Must be uppercase name + if not self.name: + return True + return self.name[0].isupper() + + #################################################################################################################### + # MANIPULATIONS + #################################################################################################################### + + @writer + def asyncify(self) -> None: + """Modifies the function to be asynchronous, if it is not already. + + This method converts a synchronous function to be asynchronous by adding the 'async' keyword and wrapping + the return type in a Promise if a return type exists. + + Returns: + None + + Note: + If the function is already asynchronous, this method does nothing. + """ + if self.is_async: + return + self.add_keyword("async") + if self.return_type and self.return_type.name != "Promise": + self.return_type.insert_before("Promise<", newline=False) + self.return_type.insert_after(">", newline=False) + + @writer + def arrow_to_named(self, name: str | None = None) -> None: + """Converts an arrow function to a named function in TypeScript/JavaScript. + + Transforms an arrow function into a named function declaration, preserving type parameters, parameters, + return types, and function body. If the function is already asynchronous, the async modifier is preserved. + + Args: + name (str | None): The name for the converted function. If None, uses the name of the variable + the arrow function is assigned to. + + Returns: + None + + Raises: + ValueError: If name is None and the arrow function is not assigned to a named variable. + """ + if not self.is_arrow or self.name is None: + return + + if name is None and self._name_node is None: + msg = "The `name` argument must be provided when converting an arrow function that is not assigned to any variable." 
+ raise ValueError(msg) + + node = self._named_arrow_function + # Replace variable declaration with function declaration + async_prefix = "async " if self.is_async else "" + edit_start = node.start_byte + type_param_node = self.ts_node.child_by_field_name("type_parameters") + if param_node := self.ts_node.child_by_field_name("parameters"): + edit_end = param_node.start_byte + self._edit_byte_range(f"{async_prefix}function {name or self.name}{type_param_node.text.decode('utf-8') if type_param_node else ''}", edit_start, edit_end) + elif param_node := self.ts_node.child_by_field_name("parameter"): + edit_end = param_node.start_byte + self._edit_byte_range(f"{async_prefix}function {name or self.name}{type_param_node.text.decode('utf-8') if type_param_node else ''}(", edit_start, edit_end) + self.insert_at(param_node.end_byte, ")") + + # Remove the arrow => + if self.return_type: + remove_start = self.return_type.end_byte + 1 + else: + remove_start = param_node.end_byte + 1 + self.remove_byte_range(remove_start, self.code_block.start_byte) + + # Add brackets surrounding the code block if not already present + if not self.code_block.source.startswith("{"): + self.insert_at(self.code_block.start_byte, "{ return ") + self.insert_at(node.end_byte, " }") + + # Move over variable type annotations as parameter type annotations + if (type_node := node.named_children[0].child_by_field_name("type")) and len(param_node.named_children) == 1: + destructured_param = self.parameters.ts_node.named_children[0] + self.insert_at(destructured_param.end_byte, type_node.text.decode("utf-8")) + + @noapidoc + @reader + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: + """Resolves the name of a symbol in the function. + + This method resolves the name of a symbol in the function. If the name is "this", it returns the parent class. + Otherwise, it calls the superclass method to resolve the name. + + Args: + name (str): The name of the symbol to resolve. + start_byte (int | None): The start byte of the symbol to resolve. + strict (bool): If True considers candidates that don't satisfy start byte if none do. + + Returns: + Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import, or None if not found. + """ + if self.is_method: + if name == "this": + yield self.parent_class + return + yield from super().resolve_name(name, start_byte, strict=strict) + + @staticmethod + def is_valid_node(node: TSNode) -> bool: + """Determines if a given tree-sitter node corresponds to a valid function type. + + This method checks if a tree-sitter node's type matches one of the valid function types defined in the _VALID_TYPE_NAMES set. + + Args: + node (TSNode): The tree-sitter node to validate. + + Returns: + bool: True if the node's type is a valid function type, False otherwise. + """ + return node.type in _VALID_TYPE_NAMES + + @writer + def convert_props_to_interface(self) -> None: + """Converts React component props to TypeScript interfaces. + + For React components, converts inline props type definitions and PropTypes declarations + to a separate interface. The interface will be named {ComponentName}Props and inserted + before the component. 
+ + Handles both simple types and complex types including: + - Inline object type definitions + - PropTypes declarations + - Union types and optional props + - Destructured parameters + - Generic type parameters + + Example: + ```typescript + // Before + function Button({ text, onClick }: { text: string, onClick: () => void }) { + return <button onClick={onClick}>{text}</button>; + } + + // After + interface ButtonProps { + text: string; + onClick: () => void; + } + function Button({ text, onClick }: ButtonProps) { + return <button onClick={onClick}>{text}</button>; + } + ``` + """ + if self.parameters and len(self.parameters) > 0: + if interface_name := self.convert_to_react_interface(): + if not self.parameters[0].is_destructured: + self.parameters[0].edit(interface_name) + else: + self.insert_at(self.parameters.ts_node.end_byte - 1, f": {interface_name}") + + @property + @reader + def promise_chains(self) -> list[TSPromiseChain]: + """Returns a list of promise chains in the function. + + Returns: + list[TSPromiseChain]: A list of promise chains in the function. + """ + promise_chains = [] + visited_base_functions = set() + function_calls = self.function_calls + + for function_call in function_calls: + if function_call.name == "then" and function_call.base not in visited_base_functions: + promise_chains.append(function_call.promise_chain) + visited_base_functions.add(function_call.base) + + return promise_chains diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/import_resolution.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/import_resolution.py new file mode 100644 index 00000000..fdd3575e --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/import_resolution.py @@ -0,0 +1,648 @@ +from __future__ import annotations + +import os +from collections import deque +from typing import TYPE_CHECKING, Self, override + +from graph_sitter.core.autocommit import reader +from graph_sitter.core.expressions import Name +from graph_sitter.core.import_resolution import Import, ImportResolution, WildcardImport +from graph_sitter.core.interfaces.exportable import Exportable +from graph_sitter.enums import ImportType, NodeType, SymbolType +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.utils import find_all_descendants, find_first_ancestor, find_first_descendant + +if TYPE_CHECKING: + from collections.abc import Generator + + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.external_module import ExternalModule + from graph_sitter.core.interfaces.editable import Editable + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.core.statements.import_statement import ImportStatement + from graph_sitter.core.symbol import Symbol + from graph_sitter.typescript.file import TSFile + from graph_sitter.typescript.namespace import TSNamespace + from graph_sitter.typescript.statements.import_statement import TSImportStatement + + +@ts_apidoc +class TSImport(Import["TSFile"], Exportable): + """Extends Import for TypeScript codebases.""" + + @reader + def is_type_import(self) -> bool: + """Checks if an import is a type import. + + Determines whether an import statement is specifically for types. This includes explicit type imports + (e.g., 'import type foo from bar'), exports of types, and dynamic imports followed by property access. + + Returns: + bool: True if the import is a type import, False otherwise.
+ """ + if self.ts_node.type == "import_statement": + return self.source.startswith("import type ") + elif self.ts_node.type == "export_statement": + return self.source.startswith("export type ") + elif call_node := find_first_descendant(self.ts_node, ["call_expression"]): + # If the import is an import using functions `import` or `require`, + # assume it is a type import if it is followed by a dot notation + while call_node.parent and call_node.parent.type in ["await_expression", "parenthesized_expression"]: + call_node = call_node.parent + sibling = call_node.next_named_sibling + return sibling and sibling.type == "property_identifier" + return False + + @reader + def is_module_import(self) -> bool: + """Determines if an import represents a module-level import. + + Module imports represent imports of an entire file rather than specific symbols from a file. + These imports must traverse through the file to resolve the actual symbol(s) being imported. + + Args: + self (TSImport): The import object to check. + + Returns: + bool: True if the import is a module-level import, False otherwise. + Returns True for: + - Imports of type MODULE, WILDCARD, or DEFAULT_EXPORT + - Side effect imports that are not type imports + """ + if self.import_type in [ImportType.MODULE, ImportType.WILDCARD, ImportType.DEFAULT_EXPORT]: + return True + return self.import_type == ImportType.SIDE_EFFECT and not self.is_type_import() + + @reader + def is_default_import(self) -> bool: + """Determines whether the import is a default export import. + + Checks if the import is importing a default export from a module. The default export + may be a single symbol or an entire module. + + Args: + self (TSImport): The import instance. + + Returns: + bool: True if the import is a default export import, False otherwise. + """ + return self.import_type == ImportType.DEFAULT_EXPORT + + @property + @reader + def namespace(self) -> str | None: + """If import is a module import, returns any namespace prefix that must be used with import reference. + + Returns the namespace prefix for import reference when the import is a module import, specifically when + the import resolves to a file node_type. The namespace is determined by the alias if set, otherwise None. + + Returns: + str | None: The alias name if the import resolves to a file node_type and has an alias, + None otherwise. + """ + resolved_symbol = self.resolved_symbol + if resolved_symbol is not None and resolved_symbol.node_type == NodeType.FILE: + return self.alias.source if self.alias is not None else None + return None + + @property + @reader + def imported_exports(self) -> list[Exportable]: + """Returns the enumerated list of exports imported from a module import. + + Returns a list of exports that this import statement references. The exports can be direct exports + or re-exports from other modules. + + Returns: + list[Exportable]: List of exported symbols. Empty list if this import doesn't reference any exports + or if imported_symbol is None. + """ + if self.imported_symbol is None: + return [] + + if not self.is_module_import(): + return [] if self.imported_symbol.export is None else [self.imported_symbol.export] + + from_file = self.imported_symbol + if from_file.node_type != NodeType.FILE: + return [] + + if self.is_default_import(): + return from_file.default_exports + + return from_file.exports + + @property + @reader + def resolved_symbol(self) -> Symbol | ExternalModule | TSFile | None: + """Returns the resolved symbol that the import is referencing. 
+ + Follows the imported symbol and returns the final symbol it resolves to. For default imports, resolves to the exported symbol. + For module imports with matching symbol names, resolves through module imports to find the matching symbol. + For indirect imports, follows the import chain to find the ultimate symbol. + + Returns: + Union[Symbol, ExternalModule, TSFile, None]: The resolved symbol. Returns None if the import cannot be resolved, + Symbol for resolved import symbols, ExternalModule for external module imports, + or TSFile for module/file imports. + """ + imports_seen = set() + resolved_symbol = self.imported_symbol + + if resolved_symbol is None: + return None + + # If the default import is a single symbol export, resolve to the symbol + if self.is_default_import(): + if resolved_symbol is not None and resolved_symbol.node_type == NodeType.FILE: + file = resolved_symbol + if len(file.default_exports) == 1 and (export_symbol := file.default_exports[0]).is_default_symbol_export(): + while export_symbol and export_symbol.node_type == NodeType.EXPORT: + export_symbol = export_symbol.exported_symbol + resolved_symbol = export_symbol + + # If the imported symbol is a file even though the import is not a module import, + # hop through the file module imports to resolve the symbol that matches the import symbol name + if resolved_symbol and resolved_symbol.node_type == NodeType.FILE and not self.is_module_import(): + # Perform BFS search on the file's module imports to find the resolved symbol + module_imps_seen = set() + module_imports_to_search = deque([imp for imp in resolved_symbol.imports if imp.is_module_import()]) + while module_imports_to_search: + module_imp = module_imports_to_search.popleft() + if module_imp in module_imps_seen: + continue + + module_imps_seen.add(module_imp) + # Search through all the symbols that this module imp is potentially importing! + for export in module_imp.imported_exports: + if export.is_named_export(): + # TODO: Why does this break? When is symbol_name None? + if self.symbol_name is not None and export.name == self.symbol_name.source: + resolved_symbol = export.resolved_symbol + break + else: + exported_symbol = export.exported_symbol + if isinstance(exported_symbol, TSImport) and exported_symbol.is_module_import(): + module_imports_to_search.append(exported_symbol) + + # If the imported symbol is an indirect import, hop through the import resolution edges + while resolved_symbol is not None and resolved_symbol.node_type == NodeType.IMPORT: + if resolved_symbol in imports_seen: + return resolved_symbol + + imports_seen.add(resolved_symbol) + resolved_symbol = resolved_symbol.imported_symbol + + return resolved_symbol + + @reader + def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSFile] | None: + """Resolves an import statement to its target file and symbol. + + This method is used by GraphBuilder to resolve import statements to their target files and symbols. It handles both relative and absolute imports, + and supports various import types including named imports, default imports, and module imports. + + Args: + base_path (str | None): The base path to resolve imports from. If None, uses the codebase's base path + or the tsconfig base URL. + + Returns: + ImportResolution[TSFile] | None: An ImportResolution object containing the resolved file and symbol, + or None if the import could not be resolved (treated as an external module). 
+ The ImportResolution contains: + - from_file: The file being imported from + - symbol: The specific symbol being imported (None for module imports) + - imports_file: True if importing the entire file/module + """ + try: + self.file: TSFile # Type cast ts_file + base_path = base_path or self.ctx.projects[0].base_path or "" + + # Get the import source path + import_source = self.module.source.strip('"').strip("'") if self.module else "" + + # Try to resolve the import using the tsconfig paths + if self.file.ts_config: + import_source = self.file.ts_config.translate_import_path(import_source) + + # Check if need to resolve relative import path to absolute path + relative_import = False + if import_source.startswith("."): + relative_import = True + + # Insert base path + # This has the happen before the relative path resolution + if not import_source.startswith(base_path): + import_source = os.path.join(base_path, import_source) + + # If the import is relative, convert it to an absolute path + if relative_import: + import_source = self._relative_to_absolute_import(import_source) + else: + import_source = os.path.normpath(import_source) + + # covers the case where the import is from a directory ex: "import { postExtract } from './post'" + import_name = import_source.split("/")[-1] + if "." not in import_name: + possible_paths = ["index.ts", "index.js", "index.tsx", "index.jsx"] + for p_path in possible_paths: + if self.ctx.to_absolute(os.path.join(import_source, p_path)).exists(): + import_source = os.path.join(import_source, p_path) + break + + # Loop through all extensions and try to find the file + extensions = ["", ".ts", ".d.ts", ".tsx", ".d.tsx", ".js", ".jsx"] + # Try both filename with and without extension + for import_source_base in (import_source, os.path.splitext(import_source)[0]): + for extension in extensions: + import_source_ext = import_source_base + extension + if file := self.ctx.get_file(import_source_ext): + if self.is_module_import(): + return ImportResolution(from_file=file, symbol=None, imports_file=True) + else: + # If the import is a named import, resolve to the named export in the file + if self.symbol_name is None: + return ImportResolution(from_file=file, symbol=None, imports_file=True) + export_symbol = file.get_export(export_name=self.symbol_name.source) + if export_symbol is None: + # If the named export is not found, it is importing a module re-export. + # In this case, resolve to the file itself and dynamically resolve the symbol later. + return ImportResolution(from_file=file, symbol=None, imports_file=True) + return ImportResolution(from_file=file, symbol=export_symbol) + + # If the imported file is not found, treat it as an external module + return None + except AssertionError: + # Codebase is probably trying to import file from outside repo + return None + + @noapidoc + @reader + def _relative_to_absolute_import(self, relative_import: str) -> str: + """Helper to go from a relative import to an absolute one. + Ex: "./foo/bar" in "src/file.ts" would be -> "src/foo/bar" + Ex: "../foo/bar" in "project/src/file.ts" would be -> "project/foo/bar" + """ + import_file_path = self.to_file.file_path # the filepath the import is in + import_dir = os.path.dirname(import_file_path) # the directory of the file this import is in + absolute_import = os.path.join(import_dir, relative_import) # absolute path of the import + normalized_absolute_import = os.path.normpath(absolute_import) # normalized absolute path of the import. 
removes redundant separators and './' or '../' segments. + return normalized_absolute_import + + @classmethod + @noapidoc + def from_export_statement(cls, source_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSImportStatement) -> list[TSImport]: + """Constructs import objects defined from an export statement""" + export_statement_node = find_first_ancestor(source_node, ["export_statement"]) + imports = [] + if export_clause := next((child for child in export_statement_node.named_children if child.type == "export_clause"), None): + # === [ Named export import ] === + # e.g. export { default as subtract } from './subtract'; + for export_specifier in export_clause.named_children: + name = export_specifier.child_by_field_name("name") + alias = export_specifier.child_by_field_name("alias") or name + import_type = ImportType.DEFAULT_EXPORT if (name and name.text.decode("utf-8") == "default") else ImportType.NAMED_EXPORT + imp = cls(ts_node=export_statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=source_node, name_node=name, alias_node=alias, import_type=import_type) + imports.append(imp) + else: + # ==== [ Wildcard export import ] ==== + # Note: re-exporting using wildcard syntax does NOT include the default export! + if namespace_export := next((child for child in export_statement_node.named_children if child.type == "namespace_export"), None): + # Aliased wildcard export (e.g. export * as myNamespace from './m';) + alias = next(child for child in namespace_export.named_children if child.type == "identifier") or namespace_export + imp = cls( + ts_node=export_statement_node, + file_node_id=file_node_id, + ctx=ctx, + parent=parent, + module_node=source_node, + name_node=namespace_export, + alias_node=alias, + import_type=ImportType.WILDCARD, + ) + imports.append(imp) + else: + # No alias wildcard export (e.g. export * from './m';) + imp = cls(ts_node=export_statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=source_node, name_node=None, alias_node=None, import_type=ImportType.WILDCARD) + imports.append(imp) + return imports + + @classmethod + @noapidoc + def from_import_statement(cls, import_statement_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSImportStatement) -> list[TSImport]: + source_node = import_statement_node.child_by_field_name("source") + import_clause = next((x for x in import_statement_node.named_children if x.type == "import_clause"), None) + if import_clause is None: + # === [ Side effect module import ] === + # Will not have any import usages in the file! (e.g. import './module';) + return [cls(ts_node=import_statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=source_node, name_node=None, alias_node=None, import_type=ImportType.SIDE_EFFECT)] + + imports = [] + for import_type_node in import_clause.named_children: + if import_type_node.type == "identifier": + # === [ Default export import ] === + # e.g. import a from './module' + imp = cls( + ts_node=import_statement_node, + file_node_id=file_node_id, + ctx=ctx, + parent=parent, + module_node=source_node, + name_node=import_type_node, + alias_node=import_type_node, + import_type=ImportType.DEFAULT_EXPORT, + ) + imports.append(imp) + elif import_type_node.type == "named_imports": + # === [ Named export import ] === + # e.g. 
import { a, b as c } from './module'; + for import_specifier in import_type_node.named_children: + # Skip comment nodes + if import_specifier.type == "comment": + continue + + name_node = import_specifier.child_by_field_name("name") + alias_node = import_specifier.child_by_field_name("alias") or name_node + imp = cls( + ts_node=import_statement_node, + file_node_id=file_node_id, + ctx=ctx, + parent=parent, + module_node=source_node, + name_node=name_node, + alias_node=alias_node, + import_type=ImportType.NAMED_EXPORT, + ) + imports.append(imp) # MODIFY IMPORT HERE ? + elif import_type_node.type == "namespace_import": + # === [ Wildcard module import ] === + # Imports both default and named exports e.g. import * as someAlias from './module'; + alias_node = next(x for x in import_type_node.named_children if x.type == "identifier") + imp = cls( + ts_node=import_statement_node, + file_node_id=file_node_id, + ctx=ctx, + module_node=source_node, + parent=parent, + name_node=import_type_node, + alias_node=alias_node, + import_type=ImportType.WILDCARD, + ) + imports.append(imp) + return imports + + @classmethod + @noapidoc + def from_dynamic_import_statement(cls, import_call_node: TSNode, module_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: ImportStatement) -> list[TSImport]: + """Parses a dynamic import statement, given a reference to the `import`/`require` node and `module` node. + e.g. + const myModule = await import('./someFile')`; + const { exportedFunction, exportedVariable: aliasedVariable } = await import('./someFile'); + import('./someFile'); + + const myModule = require('./someFile')`; + const { exportedFunction, exportedVariable: aliasedVariable } = require('./someFile'); + require('./someFile'); + Note: imports using `require` will import whatever is defined in `module.exports = ...` or `export = ...` + """ + if module_node is None: + # TODO: fixme + return [] + imports = [] + + # TODO: FIX THIS, is a horrible hack to avoid a crash on the next.js + if len(module_node.named_children) == 0: + return [] + + # Grab the first element of dynamic import call expression argument list + module_node = module_node.named_children[0] + + # Get the top most parent of call expression node that bypasses wrappers that doesn't change the semantics + call_node = find_first_ancestor(import_call_node, ["call_expression"]) + while call_node.parent and call_node.parent.type in ["await_expression", "parenthesized_expression", "binary_expression", "ternary_expression"]: + call_node = call_node.parent + + import_statement_node = call_node.parent + if import_statement_node.type == "expression_statement": + # ==== [ Side effect module import ] ==== + # Will not have any import usages in the file! (e.g. await import('./module');) + imp = cls(ts_node=import_statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=module_node, name_node=None, alias_node=None, import_type=ImportType.SIDE_EFFECT) + imports.append(imp) + else: + if import_statement_node.type == "member_expression": + # ==== [ Type import ] ==== + # Imports a type defined in module -- in javascript, type imports are entirely emitted + # e.g. 
type DynamicType = typeof import('./module').SomeType; + # const MyType = typeof import('./module').SomeType; + # const DefaultType = (await import('./module')).default + # import('./module').SomeType + # function foo(param: import('./module').SomeType) {} + name_node = import_statement_node.child_by_field_name("property") + parent_type_names = ["type_alias_declaration", "variable_declarator", "assignment_expression", "expression_statement"] + import_statement_node = find_first_ancestor(import_statement_node, parent_type_names, max_depth=2) or import_statement_node + else: + name_type_name = "left" if import_statement_node.type == "assignment_expression" else "name" + name_node = import_statement_node.child_by_field_name(name_type_name) + + # TODO: Handle dynamic import name not found (CG-8722) + if name_node is None: + alias_node = import_statement_node.child_by_field_name("name") or import_statement_node.child_by_field_name("left") + imp = cls( + ts_node=import_statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=module_node, name_node=None, alias_node=alias_node, import_type=ImportType.SIDE_EFFECT + ) + imports.append(imp) + return imports + + # If import statement is a variable declaration, capture the variable scoping keyword (const, let, var, etc) + if import_statement_node.type == "lexical_declaration": + statement_node = import_statement_node + else: + statement_node = import_statement_node.parent if import_statement_node.type in ["variable_declarator", "assignment_expression"] else import_statement_node + + # ==== [ Named dynamic import ] ==== + if name_node.type == "property_identifier": + # If the type import is being stored into a variable, get the alias + if import_statement_node.type in ["type_alias_declaration", "variable_declarator"]: + alias_node = import_statement_node.child_by_field_name("name") + elif import_statement_node.type == "assignment_expression": + alias_node = import_statement_node.child_by_field_name("left") + else: + alias_node = name_node + import_type = ImportType.DEFAULT_EXPORT if name_node.text.decode("utf-8") == "default" else ImportType.NAMED_EXPORT + imp = cls(ts_node=statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=module_node, name_node=name_node, alias_node=alias_node, import_type=import_type) + imports.append(imp) + elif name_node.type == "identifier": + # ==== [ Aliased module import ] ==== + # Imports both default and named exports (e.g. const moduleImp = await import('./module');) + imp = cls(ts_node=statement_node, file_node_id=file_node_id, ctx=ctx, parent=parent, module_node=module_node, name_node=name_node, alias_node=name_node, import_type=ImportType.MODULE) + imports.append(imp) + elif name_node.type == "object_pattern": + # ==== [ Deconstructed import ] ==== + for imported_symbol in name_node.named_children: + if imported_symbol.type == "shorthand_property_identifier_pattern": + # ==== [ Named export import ] ==== + # e.g. const { symbol } = await import('./module') + imp = cls( + ts_node=statement_node, + file_node_id=file_node_id, + ctx=ctx, + parent=parent, + module_node=module_node, + name_node=imported_symbol, + alias_node=imported_symbol, + import_type=ImportType.NAMED_EXPORT, + ) + imports.append(imp) + elif imported_symbol.type == "pair_pattern": + # ==== [ Aliased named export import ] ==== + # e.g. 
const { symbol: aliasedSymbol } = await import('./module') + name_node = imported_symbol.child_by_field_name("key") + alias_node = imported_symbol.child_by_field_name("value") + imp = cls( + ts_node=statement_node, + file_node_id=file_node_id, + ctx=ctx, + parent=parent, + module_node=module_node, + name_node=name_node, + alias_node=alias_node, + import_type=ImportType.NAMED_EXPORT, + ) + imports.append(imp) + else: + continue + # raise ValueError(f"Unexpected alias name node type {imported_symbol.type}") + return imports + + @property + @reader + def import_specifier(self) -> Editable: + """Retrieves the import specifier node for this import. + + Finds and returns the import specifier node containing this import's name and optional alias. + For named imports, this is the import_specifier or export_specifier node. + For other imports, this is the identifier node containing the import name. + + Returns: + Editable: The import specifier node containing this import's name and alias. + For named imports, returns the import_specifier/export_specifier node. + For other imports, returns the identifier node containing the import name. + Returns None if no matching specifier is found. + """ + import_specifiers = find_all_descendants(self.ts_node, {"import_specifier", "export_specifier"}) + for import_specifier in import_specifiers: + alias = import_specifier.child_by_field_name("alias") + if alias is not None: + is_match = self.alias.source == alias.text.decode("utf-8") + else: + name = import_specifier.child_by_field_name("name") + is_match = self.symbol_name.source == name.text.decode("utf-8") + if is_match: + return Name(import_specifier, self.file_node_id, self.ctx, self) + if named := next(iter(find_all_descendants(self.ts_node, {"identifier"})), None): + if named.text.decode("utf-8") == self.symbol_name.source: + return Name(named, self.file_node_id, self.ctx, self) + + @reader + def get_import_string(self, alias: str | None = None, module: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> str: + """Generates an import string for an import statement. + + Generates a string representation of an import statement with optional type and alias information. + + Args: + alias (str | None): Alias name for the imported symbol. Defaults to None. + module (str | None): Module name to import from. Defaults to None. If not provided, uses the file's import module name. + import_type (ImportType): Type of import (e.g. WILDCARD, NAMED_EXPORT). Defaults to ImportType.UNKNOWN. + is_type_import (bool): Whether this is a type import. Defaults to False. + + Returns: + str: A string representation of the import statement. + """ + type_prefix = "type " if is_type_import else "" + import_module = module if module is not None else self.file.import_module_name + + if import_type == ImportType.WILDCARD: + file_as_module = self.file.name + return f"import {type_prefix}* as {file_as_module} from {import_module};" + elif alias is not None and alias != self.name: + return f"import {type_prefix}{{ {self.name} as {alias} }} from {import_module};" + else: + return f"import {type_prefix}{{ {self.name} }} from {import_module};" + + @property + @noapidoc + @override + def names(self) -> Generator[tuple[str, Self | WildcardImport[Self]], None, None]: + if self.import_type == ImportType.SIDE_EFFECT: + return + yield from super().names + + @property + def namespace_imports(self) -> list[TSNamespace]: + """Returns any namespace objects imported by this import statement. 
+ + For example: + import * as MyNS from './mymodule'; + + Returns: + List of namespace objects imported + """ + if not self.is_namespace_import(): + return [] + + from graph_sitter.typescript.namespace import TSNamespace + + resolved = self.resolved_symbol + if resolved is None or not isinstance(resolved, TSNamespace): + return [] + + return [resolved] + + def is_namespace_import(self) -> bool: + """Returns True if this import is importing a namespace. + + Examples: + import { MathUtils } from './file1'; # True if MathUtils is a namespace + import * as AllUtils from './utils'; # True + """ + # For wildcard imports with namespace alias + if self.import_type == ImportType.WILDCARD and self.namespace: + return True + + # For named imports, check if any imported symbol is a namespace + if self.import_type == ImportType.NAMED_EXPORT: + for name, _ in self.names: + symbol = self.resolved_symbol + if symbol and symbol.symbol_type == SymbolType.Namespace: + return True + + return False + + @override + def set_import_module(self, new_module: str) -> None: + """Sets the module of an import. + + Updates the module of an import statement while maintaining the import symbol. + Uses single quotes by default (TypeScript standard), falling back to double quotes + only if the path contains single quotes. + + Args: + new_module (str): The new module path to import from. + + Returns: + None + """ + if self.module is None: + return + + # If already quoted, use as is + if (new_module.startswith('"') and new_module.endswith('"')) or (new_module.startswith("'") and new_module.endswith("'")): + self.module.source = new_module + return + + # Use double quotes if path contains single quotes, otherwise use single quotes (TypeScript standard) + quote = '"' if "'" in new_module else "'" + self.module.source = f"{quote}{new_module}{quote}" diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/interface.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/interface.py new file mode 100644 index 00000000..38452a4c --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/interface.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar + +from graph_sitter.core.autocommit import commiter, reader +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.interface import Interface +from graph_sitter.core.symbol_groups.parents import Parents +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock +from graph_sitter.typescript.expressions.type import TSType +from graph_sitter.typescript.function import TSFunction +from graph_sitter.typescript.interfaces.has_block import TSHasBlock +from graph_sitter.typescript.statements.attribute import TSAttribute +from graph_sitter.typescript.symbol import TSSymbol + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.detached_symbols.code_block import CodeBlock + from graph_sitter.core.interfaces.has_name import HasName + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.core.statements.statement import Statement + +Parent = TypeVar("Parent", bound="TSHasBlock") + + +@ts_apidoc +class TSInterface(Interface[TSCodeBlock, TSAttribute, TSFunction, TSType], TSSymbol, TSHasBlock): + """Representation of an Interface in TypeScript + + Attributes: + parent_interfaces: All the
interfaces that this interface extends. + code_block: The code block that contains the interface's body. + """ + + def __init__( + self, + ts_node: TSNode, + file_id: NodeId, + ctx: CodebaseContext, + parent: Statement[CodeBlock[Parent, ...]], + ) -> None: + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + + super().__init__(ts_node, file_id, ctx, parent) + body_node = ts_node.child_by_field_name("body") + + # Find the nearest parent with a code_block + current_parent = parent + while not hasattr(current_parent, "code_block"): + current_parent = current_parent.parent + + self.code_block = TSCodeBlock(body_node, current_parent.code_block.level + 1, current_parent.code_block, self) + self.code_block.parse() + + @commiter + @noapidoc + def parse(self, ctx: CodebaseContext) -> None: + # =====[ Extends ]===== + # Look for parent interfaces in the "extends" clause + if extends_clause := self.child_by_field_types("extends_type_clause"): + self.parent_interfaces = Parents(extends_clause.ts_node, self.file_node_id, self.ctx, self) + super().parse(ctx) + + @noapidoc + @commiter + def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: + dest = dest or self.self_dest + + # =====[ Extends ]===== + if self.parent_interfaces is not None: + self.parent_interfaces._compute_dependencies(UsageKind.SUBCLASS, dest) + + # =====[ Body ]===== + # Look for type references in the interface body + self.code_block._compute_dependencies(usage_type, dest) + + @staticmethod + @noapidoc + def _get_name_node(ts_node: TSNode) -> TSNode | None: + if ts_node.type == "interface_declaration": + return ts_node.child_by_field_name("name") + return None + + @property + @reader + def attributes(self) -> list[TSAttribute]: + """Retrieves the list of attributes defined in the TypeScript interface. + + Args: + None + + Returns: + list[TSAttribute]: A list of the interface's attributes stored in the code block. + """ + return self.code_block.attributes diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/interfaces/has_block.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/interfaces/has_block.py new file mode 100644 index 00000000..5ff22e4a --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/interfaces/has_block.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING, Self + +from graph_sitter.compiled.utils import find_all_descendants +from graph_sitter.core.autocommit import reader, writer +from graph_sitter.core.detached_symbols.code_block import CodeBlock +from graph_sitter.core.interfaces.has_block import HasBlock +from graph_sitter.core.statements.statement import StatementType +from graph_sitter.shared.decorators.docs import ts_apidoc +from graph_sitter.typescript.detached_symbols.decorator import TSDecorator +from graph_sitter.typescript.statements.comment import TSComment, TSCommentType +from graph_sitter.typescript.symbol_groups.comment_group import TSCommentGroup +from graph_sitter.utils import find_index + +if TYPE_CHECKING: + from graph_sitter.typescript.detached_symbols.jsx.element import JSXElement + + +@ts_apidoc +class TSHasBlock(HasBlock["TSCodeBlock", TSDecorator]): + """A TypeScript base class that provides block-level code organization and decorator handling capabilities. + + This class extends the concept of block scoping for TypeScript code elements like classes and functions. 
+ It provides functionality for managing code blocks, decorators, JSX elements, and documentation within + those blocks. The class supports operations such as retrieving and manipulating docstrings, + handling JSX components, and managing TypeScript decorators. + """ + + @property + @reader + def is_decorated(self) -> bool: + """Checks if the current symbol has a decorator. + + Determines if the symbol has a preceding decorator node. + + Returns: + bool: True if the symbol has a decorator node as its previous named sibling, + False otherwise. + """ + previous_sibling = self.ts_node.prev_named_sibling + # is decorated if it has a previous named sibling (i.e. the text above the function) and it is type=decorator + return previous_sibling and previous_sibling.type == "decorator" + + @property + @reader + def decorators(self) -> list[TSDecorator]: + """Returns a list of decorators associated with this symbol. + + Retrieves all decorators applied to this symbol by looking at both previous named siblings and decorator fields. + This includes both inline decorators and standalone decorator statements. + + Returns: + list[TSDecorator]: A list of TSDecorator objects representing all decorators applied to this symbol. + Returns an empty list if no decorators are found. + """ + decorators = [] + # Get all previous named siblings that are decorators, break once we hit a non decorator + prev_named_sibling = self.ts_node.prev_named_sibling + while prev_named_sibling and prev_named_sibling.type == "decorator": + decorators.append(TSDecorator(prev_named_sibling, self)) + prev_named_sibling = prev_named_sibling.prev_named_sibling + for child in self.ts_node.children_by_field_name("decorator"): + decorators.append(TSDecorator(child, self)) + return decorators + + @property + @reader + def jsx_elements(self) -> list[JSXElement[Self]]: + """Returns a list of all JSX elements contained within this symbol. + + Searches through the extended nodes of the symbol for any JSX elements or self-closing JSX elements + and returns them as a list of JSXElement objects. + + Args: + None + + Returns: + list[JSXElement[Self]]: A list of JSXElement objects contained within this symbol. + """ + jsx_elements = [] + for node in self.extended_nodes: + jsx_element_nodes = find_all_descendants(node.ts_node, {"jsx_element", "jsx_self_closing_element"}) + jsx_elements.extend([self._parse_expression(x) for x in jsx_element_nodes]) + return jsx_elements + + @reader + def get_component(self, component_name: str) -> JSXElement[Self] | None: + """Returns a specific JSX element from within this symbol's JSX elements. + + Searches through all JSX elements in this symbol's code block and returns the first one that matches + the given component name. + + Args: + component_name (str): The name of the JSX component to find. + + Returns: + JSXElement[Self] | None: The matching JSX element if found, None otherwise. + """ + for component in self.jsx_elements: + if component.name == component_name: + return component + return None + + @cached_property + @reader + def docstring(self) -> TSCommentGroup | None: + """Retrieves the docstring of a function or class. + + Returns any comments immediately preceding this node as a docstring. For nodes that are children of a HasBlock, it returns consecutive comments that end on the line before the node starts. + For other nodes, it returns formatted docstring comments. + + Returns: + TSCommentGroup | None: A CommentGroup representing the docstring if one exists, None otherwise. 
+ """ + if self.parent.parent.parent and isinstance(self.parent.parent, CodeBlock): + comments = [] + sibling_statements = self.parent.parent.statements + index = find_index(self.ts_node, [x.ts_node for x in sibling_statements]) + if index == -1: + return None + + row = self.start_point[0] + for statement in reversed(sibling_statements[:index]): + if statement.end_point[0] != row - 1: + break + row = statement.start_point[0] + if statement.statement_type == StatementType.COMMENT: + comments.append(statement) + + return TSCommentGroup.from_comment_nodes(list(reversed(comments)), self) + + return TSCommentGroup.from_docstring(self) + + @writer + def set_docstring(self, docstring: str, auto_format: bool = True, clean_format: bool = True, leading_star: bool = True, force_multiline: bool = False) -> None: + """Sets or updates a docstring for a code element. + + Adds a new docstring if none exists, or updates the existing docstring. Handles formatting and placement + of the docstring according to the specified parameters. + + Args: + docstring (str): The docstring text to be added or updated. + auto_format (bool, optional): Whether to automatically format the text into a docstring format. Defaults to True. + clean_format (bool, optional): Whether to clean existing formatting from the docstring before inserting. Defaults to True. + leading_star (bool, optional): Whether to add leading "*" to each line of the comment block. Defaults to True. + force_multiline (bool, optional): Whether to force single line comments to be multi-line. Defaults to False. + + Returns: + None + """ + # Clean existing formatting off docstring + if clean_format: + docstring = TSComment.clean_comment(docstring) + + # If the docstring exists, edit it + if self.docstring: + if auto_format: + self.docstring.edit_text(docstring) + else: + self.docstring.edit(docstring) + else: + if auto_format: + docstring = TSComment.generate_comment(docstring, TSCommentType.SLASH_STAR, leading_star=leading_star, force_multiline=force_multiline) + # If a comment exists, insert the docstring after it + if self.comment: + self.comment.insert_after(docstring) + # If no comment exists, insert the docstring before the function + else: + self.extended.insert_before(docstring, fix_indentation=True) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/namespace.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/namespace.py new file mode 100644 index 00000000..e1bf9483 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/namespace.py @@ -0,0 +1,416 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Self, override + +from graph_sitter.compiled.autocommit import reader +from graph_sitter.compiled.resolution import ResolutionStack +from graph_sitter.compiled.sort import sort_editables +from graph_sitter.compiled.utils import cached_property +from graph_sitter.core.autocommit import commiter +from graph_sitter.core.autocommit.decorators import writer +from graph_sitter.core.export import Export +from graph_sitter.core.interfaces.chainable import Chainable +from graph_sitter.core.interfaces.has_attribute import HasAttribute +from graph_sitter.core.interfaces.has_name import HasName +from graph_sitter.enums import SymbolType +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.shared.logging.get_logger import get_logger +from graph_sitter.typescript.class_definition import TSClass +from graph_sitter.typescript.enum_definition import TSEnum +from 
graph_sitter.typescript.function import TSFunction +from graph_sitter.typescript.interface import TSInterface +from graph_sitter.typescript.interfaces.has_block import TSHasBlock +from graph_sitter.typescript.symbol import TSSymbol +from graph_sitter.typescript.type_alias import TSTypeAlias + +if TYPE_CHECKING: + from collections.abc import Generator, Sequence + + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.dataclasses.usage import UsageKind + from graph_sitter.core.interfaces.importable import Importable + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.core.statements.statement import Statement + from graph_sitter.core.symbol import Symbol + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + from graph_sitter.typescript.export import TSExport + from graph_sitter.typescript.import_resolution import TSImport + + +logger = get_logger(__name__) + + +@ts_apidoc +class TSNamespace( + TSSymbol, + TSHasBlock, + Chainable, + HasName, + HasAttribute, +): + """Representation of a namespace module in TypeScript. + + Attributes: + symbol_type: The type of the symbol, set to SymbolType.Namespace. + code_block: The code block associated with this namespace. + """ + + symbol_type = SymbolType.Namespace + code_block: TSCodeBlock + + def __init__(self, ts_node: TSNode, file_id: NodeId, ctx: CodebaseContext, parent: Statement, namespace_node: TSNode | None = None) -> None: + ts_node = namespace_node or ts_node + name_node = ts_node.child_by_field_name("name") + super().__init__(ts_node, file_id, ctx, parent, name_node=name_node) + + @noapidoc + @commiter + def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: + """Computes dependencies for the namespace by analyzing its code block. + + Args: + usage_type: Optional UsageKind specifying how the dependencies are used + dest: Optional HasName destination for the dependencies + """ + # Use self as destination if none provided + dest = dest or self.self_dest + # Compute dependencies from namespace's code block + self.code_block._compute_dependencies(usage_type, dest) + + @cached_property + def symbols(self) -> list[Symbol]: + """Returns all symbols defined within this namespace, including nested ones.""" + all_symbols = [] + for stmt in self.code_block.statements: + if stmt.ts_node_type == "export_statement": + for export in stmt.exports: + all_symbols.append(export.declared_symbol) + elif hasattr(stmt, "assignments"): + all_symbols.extend(stmt.assignments) + else: + all_symbols.append(stmt) + return all_symbols + + def get_symbol(self, name: str, recursive: bool = True, get_private: bool = False) -> Symbol | None: + """Get an exported or private symbol by name from this namespace. Returns only exported symbols by default. 
+ + Args: + name: Name of the symbol to find + recursive: If True, also search in nested namespaces + get_private: If True, also search in private symbols + + Returns: + Symbol | None: The found symbol, or None if not found + """ + # First check direct symbols in this namespace + for symbol in self.symbols: + # Handle TSAssignmentStatement case + if hasattr(symbol, "assignments"): + for assignment in symbol.assignments: + if assignment.name == name: + # If we are looking for private symbols then return it, else only return exported symbols + if get_private: + return assignment + elif assignment.is_exported: + return assignment + + # Handle regular symbol case + if hasattr(symbol, "name") and symbol.name == name: + if get_private: + return symbol + elif symbol.is_exported: + return symbol + + # If recursive and this is a namespace, check its symbols + if recursive and isinstance(symbol, TSNamespace): + nested_symbol = symbol.get_symbol(name, recursive=True, get_private=get_private) + return nested_symbol + + return None + + @reader(cache=False) + @noapidoc + def get_nodes(self, *, sort_by_id: bool = False, sort: bool = True) -> Sequence[Importable]: + """Returns all nodes in the namespace, sorted by position in the namespace.""" + file_nodes = self.file.get_nodes(sort_by_id=sort_by_id, sort=sort) + start_limit = self.start_byte + end_limit = self.end_byte + namespace_nodes = [] + for file_node in file_nodes: + if file_node.start_byte > start_limit: + if file_node.end_byte < end_limit: + namespace_nodes.append(file_node) + else: + break + return namespace_nodes + + @cached_property + @reader(cache=False) + def exports(self) -> list[TSExport]: + """Returns all Export symbols in the namespace. + + Retrieves a list of all top-level export declarations in the current TypeScript namespace. + + Returns: + list[TSExport]: A list of TSExport objects representing all top-level export declarations in the namespace. + """ + # Filter to only get exports that are direct children of the namespace's code block + return sort_editables(filter(lambda node: isinstance(node, Export), self.get_nodes(sort=False)), by_id=True) + + @cached_property + def functions(self) -> list[TSFunction]: + """Get all functions defined in this namespace. + + Returns: + List of Function objects in this namespace + """ + return [symbol for symbol in self.symbols if isinstance(symbol, TSFunction)] + + def get_function(self, name: str, recursive: bool = True) -> TSFunction | None: + """Get a function by name from this namespace. + + Args: + name: Name of the function to find + recursive: If True, also search in nested namespaces + """ + symbol = self.get_symbol(name, recursive=recursive) + return symbol if isinstance(symbol, TSFunction) else None + + @cached_property + def classes(self) -> list[TSClass]: + """Get all classes defined in this namespace. + + Returns: + List of Class objects in this namespace + """ + return [symbol for symbol in self.symbols if isinstance(symbol, TSClass)] + + def get_class(self, name: str, recursive: bool = True) -> TSClass | None: + """Get a class by name from this namespace. + + Args: + name: Name of the class to find + recursive: If True, also search in nested namespaces + """ + symbol = self.get_symbol(name, recursive=recursive) + return symbol if isinstance(symbol, TSClass) else None + + def get_interface(self, name: str, recursive: bool = True) -> TSInterface | None: + """Get an interface by name from this namespace. 
+ + Args: + name: Name of the interface to find + recursive: If True, also search in nested namespaces + """ + symbol = self.get_symbol(name, recursive=recursive) + return symbol if isinstance(symbol, TSInterface) else None + + def get_type(self, name: str, recursive: bool = True) -> TSTypeAlias | None: + """Get a type alias by name from this namespace. + + Args: + name: Name of the type to find + recursive: If True, also search in nested namespaces + """ + symbol = self.get_symbol(name, recursive=recursive) + return symbol if isinstance(symbol, TSTypeAlias) else None + + def get_enum(self, name: str, recursive: bool = True) -> TSEnum | None: + """Get an enum by name from this namespace. + + Args: + name: Name of the enum to find + recursive: If True, also search in nested namespaces + """ + symbol = self.get_symbol(name, recursive=recursive) + return symbol if isinstance(symbol, TSEnum) else None + + def get_namespace(self, name: str, recursive: bool = True) -> TSNamespace | None: + """Get a namespace by name from this namespace. + + Args: + name: Name of the namespace to find + recursive: If True, also search in nested namespaces + + Returns: + TSNamespace | None: The found namespace, or None if not found + """ + # First check direct symbols in this namespace + for symbol in self.symbols: + if isinstance(symbol, TSNamespace) and symbol.name == name: + return symbol + + # If recursive and this is a namespace, check its symbols + if recursive and isinstance(symbol, TSNamespace): + nested_namespace = symbol.get_namespace(name, recursive=True) + return nested_namespace + + return None + + def get_nested_namespaces(self) -> list[TSNamespace]: + """Get all nested namespaces within this namespace. + + Returns: + list[TSNamespace]: List of all nested namespace objects + """ + nested = [] + for symbol in self.symbols: + if isinstance(symbol, TSNamespace): + nested.append(symbol) + nested.extend(symbol.get_nested_namespaces()) + return nested + + @writer + def add_symbol_from_source(self, source: str) -> None: + """Adds a symbol to a namespace from a string representation. + + This method adds a new symbol definition to the namespace by appending its source code string. The symbol will be added + after existing symbols if present, otherwise at the beginning of the namespace. + + Args: + source (str): String representation of the symbol to be added. This should be valid source code for + the file's programming language. + + Returns: + None: The symbol is added directly to the namespace's content. + """ + symbols = self.symbols + if len(symbols) > 0: + symbols[-1].insert_after("\n" + source, fix_indentation=True) + else: + self.insert_after("\n" + source) + + @commiter + def add_symbol(self, symbol: TSSymbol, should_export: bool = True) -> TSSymbol | None: + """Adds a new symbol to the namespace, optionally exporting it if applicable. If the symbol already exists in the namespace, returns the existing symbol. + + Args: + symbol: The symbol to add to the namespace (either a TSSymbol instance or source code string) + export: Whether to export the symbol. Defaults to True. + + Returns: + TSSymbol | None: The existing symbol if it already exists in the file or None if it was added. + """ + existing_symbol = self.get_symbol(symbol.name) + if existing_symbol is not None: + return existing_symbol + + if not self.file.symbol_can_be_added(symbol): + msg = f"Symbol {symbol.name} cannot be added to this file type." 
+ raise ValueError(msg) + + source = symbol.source + if isinstance(symbol, TSFunction) and symbol.is_arrow: + raw_source = symbol._named_arrow_function.text.decode("utf-8") + else: + raw_source = symbol.ts_node.text.decode("utf-8") + if should_export and hasattr(symbol, "export") and (not symbol.is_exported or raw_source not in symbol.export.source): + source = source.replace(source, f"export {source}") + self.add_symbol_from_source(source) + + @commiter + def remove_symbol(self, symbol_name: str) -> TSSymbol | None: + """Removes a symbol from the namespace by name. + + Args: + symbol_name: Name of the symbol to remove + + Returns: + The removed symbol if found, None otherwise + """ + symbol = self.get_symbol(symbol_name) + if symbol: + # Remove from code block statements + for i, stmt in enumerate(self.code_block.statements): + if symbol.source == stmt.source: + logger.debug(f"stmt to be removed: {stmt}") + self.code_block.statements.pop(i) + return symbol + return None + + @commiter + def rename_symbol(self, old_name: str, new_name: str) -> None: + """Renames a symbol within the namespace. + + Args: + old_name: Current symbol name + new_name: New symbol name + """ + symbol = self.get_symbol(old_name) + if symbol: + symbol.rename(new_name) + + @commiter + @noapidoc + def export_symbol(self, name: str) -> None: + """Marks a symbol as exported in the namespace. + + Args: + name: Name of symbol to export + """ + symbol = self.get_symbol(name, get_private=True) + if not symbol or symbol.is_exported: + return + + export_source = f"export {symbol.source}" + symbol.parent.edit(export_source) + + @cached_property + @noapidoc + @reader(cache=True) + def valid_import_names(self) -> dict[str, TSSymbol | TSImport]: + """Returns set of valid import names for this namespace. + + This includes all exported symbols plus the namespace name itself + for namespace imports. + """ + valid_export_names = {} + valid_export_names[self.name] = self + for export in self.exports: + for name, dest in export.names: + valid_export_names[name] = dest + return valid_export_names + + def resolve_import(self, import_name: str) -> Symbol | None: + """Resolves an import name to a symbol within this namespace. + + Args: + import_name: Name to resolve + + Returns: + Resolved symbol or None if not found + """ + # First check direct symbols + for symbol in self.symbols: + if symbol.is_exported and symbol.name == import_name: + return symbol + + # Then check nested namespaces + for nested in self.get_nested_namespaces(): + resolved = nested.resolve_import(import_name) + if resolved is not None: + return resolved + + return None + + @override + def resolve_attribute(self, name: str) -> Symbol | None: + """Resolves an attribute access on the namespace. + + Args: + name: Name of the attribute to resolve + + Returns: + The resolved symbol or None if not found + """ + return self.valid_import_names.get(name, None) + + @override + def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: + """Returns the resolved types for this namespace. + + This includes all exports and the namespace itself. 
+        """
+        yield ResolutionStack(self)
diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/placeholder/placeholder_return_type.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/placeholder/placeholder_return_type.py
new file mode 100644
index 00000000..1e380e8d
--- /dev/null
+++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/placeholder/placeholder_return_type.py
@@ -0,0 +1,44 @@
+from typing import TYPE_CHECKING, Generic, TypeVar
+
+from graph_sitter.core.placeholder.placeholder import Placeholder
+from graph_sitter.shared.decorators.docs import ts_apidoc
+
+if TYPE_CHECKING:
+    from graph_sitter.core.interfaces.editable import Editable
+
+Parent = TypeVar("Parent", bound="Editable")
+
+
+@ts_apidoc
+class TSReturnTypePlaceholder(Placeholder[Parent], Generic[Parent]):
+    """A placeholder class for function return type annotations in TypeScript.
+
+    This class represents a placeholder for function return type annotations, allowing for modification
+    and addition of return type annotations after the parameter list. It provides functionality to
+    add or modify return type annotations with proper formatting.
+    """
+
+    def edit(self, new_src: str, fix_indentation: bool = False, priority: int = 0, dedupe: bool = True) -> None:
+        """Modifies the return type annotation of a function.
+
+        Adds or modifies the return type annotation of a function after its parameter list.
+
+        Args:
+            new_src (str): The return type annotation to add. If it doesn't start with ': ', a ': ' prefix will be prepended.
+            fix_indentation (bool, optional): Whether to fix the indentation of the added code. Defaults to False.
+            priority (int, optional): The priority of this edit. Defaults to 0.
+            dedupe (bool, optional): Whether to remove duplicate edits. Defaults to True.
+
+        Returns:
+            None
+
+        Note:
+            If new_src is empty or None, the method returns without making any changes.
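+
+        Example:
+            A minimal illustrative sketch; assumes ``placeholder`` is the return type placeholder of a
+            TypeScript function declared as ``function foo(x: number) { ... }``:
+
+                placeholder.edit("Promise<void>")
+                # The ": " prefix is added automatically, yielding:
+                #   function foo(x: number): Promise<void> { ... }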
+ """ + if new_src == "" or new_src is None: + return + if not new_src.startswith(": "): + new_src = ": " + new_src + + param_node = self._parent_node.child_by_field_name("parameters") + param_node.insert_after(new_src, newline=False) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/__init__.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/assignment_statement.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/assignment_statement.py new file mode 100644 index 00000000..6ddc02b3 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/assignment_statement.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from collections import deque +from typing import TYPE_CHECKING, Self + +from graph_sitter.compiled.autocommit import reader +from graph_sitter.core.expressions.multi_expression import MultiExpression +from graph_sitter.core.statements.assignment_statement import AssignmentStatement +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.shared.logging.get_logger import get_logger +from graph_sitter.typescript.assignment import TSAssignment + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + from graph_sitter.typescript.interfaces.has_block import TSHasBlock + + +logger = get_logger(__name__) + + +@ts_apidoc +class TSAssignmentStatement(AssignmentStatement["TSCodeBlock", TSAssignment]): + """A class that represents a TypeScript assignment statement in a codebase, such as `const x = 1` or `const { a: b } = myFunc()`.""" + + assignment_types = {"assignment_expression", "augmented_assignment_expression", "variable_declarator", "public_field_definition", "property_signature"} + + @classmethod + @reader + @noapidoc + def from_assignment(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int, assignment_node: TSNode) -> TSAssignmentStatement: + """Creates an assignment statement node from a TreeSitter assignment node. + + This class method constructs a TSAssignmentStatement from a TreeSitter node representing an assignment. The method validates that the assignment node type is + one of the supported types: assignment_expression, augmented_assignment_expression, variable_declarator, public_field_definition, or property_signature. + + Args: + ts_node (TSNode): The TreeSitter node representing the entire statement. + file_node_id (NodeId): The identifier for the file containing this node. + ctx (CodebaseContext): The codebase context being constructed. + parent (TSHasBlock): The parent block containing this statement. + code_block (TSCodeBlock): The code block containing this statement. + pos (int): The position of this statement within its code block. + assignment_node (TSNode): The TreeSitter node representing the assignment. + + Returns: + TSAssignmentStatement: A new assignment statement node. + + Raises: + ValueError: If the assignment_node.type is not one of the supported assignment types. 
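+
+        Example:
+            An illustrative sketch; ``stmt_node``, ``decl_node``, ``file_id``, ``ctx`` and ``block`` are assumed
+            to be pre-existing tree-sitter / graph-sitter objects, not values defined in this module:
+
+                stmt = TSAssignmentStatement.from_assignment(stmt_node, file_id, ctx, block, 0, assignment_node=decl_node)
+                # Raises ValueError if decl_node.type is not in TSAssignmentStatement.assignment_types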
+ """ + if assignment_node.type not in cls.assignment_types: + msg = f"Invalid assignment node type: {assignment_node.type}" + raise ValueError(msg) + + return cls(ts_node, file_node_id, ctx, parent, pos, assignment_node=assignment_node) + + def _parse_assignments(self, assignment_node: TSNode) -> MultiExpression[Self, TSAssignment]: + if assignment_node.type in ["assignment_expression", "augmented_assignment_expression"]: + return TSAssignment.from_assignment(assignment_node, self.file_node_id, self.ctx, self) + elif assignment_node.type in ["variable_declarator", "public_field_definition", "property_signature"]: + return TSAssignment.from_named_expression(assignment_node, self.file_node_id, self.ctx, self) + + logger.info(f"Unknown assignment type: {assignment_node.type}") + return MultiExpression(assignment_node, self.file_node_id, self.ctx, self.parent, [self.parent._parse_expression(assignment_node)]) + + def _DEPRECATED_parse_assignments(self) -> MultiExpression[TSHasBlock, TSAssignment]: + if self.ts_node.type in ["lexical_declaration", "variable_declaration"]: + return MultiExpression(self.ts_node, self.file_node_id, self.ctx, self.parent, self._DEPRECATED_parse_assignment_declarations()) + elif self.ts_node.type in ["expression_statement"]: + return MultiExpression(self.ts_node, self.file_node_id, self.ctx, self.parent, self._DEPRECATED_parse_assignment_expression()) + elif self.ts_node.type in ["public_field_definition", "property_signature", "enum_assignment"]: + return MultiExpression(self.ts_node, self.file_node_id, self.ctx, self.parent, self._DEPRECATED_parse_attribute_assignments()) + else: + msg = f"Unknown assignment type: {self.ts_node.type}" + raise ValueError(msg) + + def _DEPRECATED_parse_attribute_assignments(self) -> list[TSAssignment]: + left = self.ts_node.child_by_field_name("name") + right = self.ts_node.child_by_field_name("value") + return [TSAssignment(self.ts_node, self.file_node_id, self.ctx, self, left, right, left)] + + def _DEPRECATED_parse_assignment_declarations(self) -> list[TSAssignment]: + assignments = [] + for variable_declarator in self.ts_node.named_children: + if variable_declarator.type != "variable_declarator": + continue + left = variable_declarator.child_by_field_name("name") + type_node = variable_declarator.child_by_field_name("type") + right = variable_declarator.child_by_field_name("value") + if len(left.named_children) > 0: + to_parse: deque[tuple[TSNode, TSNode | None]] = deque([(left, type_node)]) + while to_parse: + child, _type = to_parse.popleft() + for identifier in child.named_children: + if identifier.type == "pair_pattern": + value = identifier.child_by_field_name("value") + to_parse.append((value, _type)) # TODO:CG-10064 + if value.type == "identifier": + # TODO: Support type resolution for aliased object unpacks + assignments.append(TSAssignment(variable_declarator, self.file_node_id, self.ctx, self, left, right, value)) + else: + key = identifier.child_by_field_name("key") + assignments.append(TSAssignment(variable_declarator, self.file_node_id, self.ctx, self, left, right, key)) + else: + assignments.append(TSAssignment(variable_declarator, self.file_node_id, self.ctx, self, left, right, identifier)) + + else: + assignments.append(TSAssignment(variable_declarator, self.file_node_id, self.ctx, self, left, right, left)) + while right and right.type == "assignment_expression": + left = right.child_by_field_name("left") + right = right.child_by_field_name("right") + assignments.append(TSAssignment(variable_declarator, 
self.file_node_id, self.ctx, self, left, right, left)) + + return assignments + + def _DEPRECATED_parse_assignment_expression(self) -> list[TSAssignment]: + assignments = [] + for child in self.ts_node.named_children: + if child.type not in ["assignment_expression", "augmented_assignment_expression"]: + continue + left = child.child_by_field_name("left") + right = child.child_by_field_name("right") + assignments.append(TSAssignment(child, self.file_node_id, self.ctx, self, left, right, left)) + + return assignments diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/attribute.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/attribute.py new file mode 100644 index 00000000..4eff6d4f --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/attribute.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from graph_sitter._proxy import proxy_property +from graph_sitter.core.autocommit import reader +from graph_sitter.core.statements.attribute import Attribute +from graph_sitter.shared.decorators.docs import ts_apidoc +from graph_sitter.typescript.assignment import TSAssignment +from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock +from graph_sitter.typescript.statements.assignment_statement import TSAssignmentStatement + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.interfaces.editable import Editable + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.typescript.interfaces.has_block import TSHasBlock + + +@ts_apidoc +class TSAttribute(Attribute[TSCodeBlock, TSAssignment], TSAssignmentStatement): + """Typescript implementation of Attribute detached symbol.""" + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, pos=pos, assignment_node=ts_node) + self.type = self.assignments[0].type + + @reader + def _get_name_node(self) -> TSNode: + """Returns the ID node from the root node of the symbol""" + return self.ts_node.child_by_field_name("name") + + @proxy_property + @reader + def local_usages(self: TSAttribute[TSHasBlock, TSCodeBlock]) -> list[Editable]: + """Returns local usages of a TypeScript attribute within its code block. + + Searches through all statements in the attribute's parent code block and finds instances where the attribute is referenced with 'this.' prefix. Excludes the attribute's own + declaration/assignment. + + Args: + self (TSAttribute[TSHasBlock, TSCodeBlock]): The TypeScript attribute instance. + + Returns: + list[Editable]: A sorted list of unique Editable instances representing local usages of the attribute, ordered by their position in the source code. + + Note: + This method can be called as both a property or a method. If used as a property, it is equivalent to invoking it without arguments. 
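+
+        Example:
+            An illustrative sketch; assumes ``attr`` is the TSAttribute for ``private count = 0`` in a class
+            whose methods reference ``this.count`` twice:
+
+                len(attr.local_usages)  # -> 2, one Editable per `this.count` reference outside the declaration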
+ """ + usages = [] + for statement in self.parent.statements: + var_references = statement.find(f"this.{self.name}", exact=True) + for var_reference in var_references: + # Exclude the variable usage in the assignment itself + if self.ts_node.byte_range[0] <= var_reference.ts_node.start_byte and self.ts_node.byte_range[1] >= var_reference.ts_node.end_byte: + continue + usages.append(var_reference) + return sorted(dict.fromkeys(usages), key=lambda x: x.ts_node.start_byte) + + @property + def is_private(self) -> bool: + """Determines if this attribute has a private accessibility modifier. + + Args: + self: The TypeScript attribute instance. + + Returns: + bool: True if the attribute has a 'private' accessibility modifier, False otherwise. + """ + modifier = self.ts_node.children[0] + return modifier.type == "accessibility_modifier" and modifier.text == b"private" + + @property + def is_optional(self) -> bool: + """Returns True if this attribute is marked as optional in TypeScript. + + Checks if the attribute has a question mark (`?`) symbol after its name, indicating it's an optional field. + + Returns: + bool: True if the attribute is optional, False otherwise. + """ + if sibling := self.get_name().next_sibling: + return sibling.ts_node.type == "?" + return False diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/block_statement.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/block_statement.py new file mode 100644 index 00000000..366a6c73 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/block_statement.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +from graph_sitter.core.statements.block_statement import BlockStatement +from graph_sitter.shared.decorators.docs import apidoc +from graph_sitter.typescript.interfaces.has_block import TSHasBlock + +if TYPE_CHECKING: + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + +Parent = TypeVar("Parent", bound="TSCodeBlock") + + +@apidoc +class TSBlockStatement(BlockStatement[Parent], TSHasBlock, Generic[Parent]): + """Statement which contains a block.""" diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/catch_statement.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/catch_statement.py new file mode 100644 index 00000000..177d8851 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/catch_statement.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +from graph_sitter.core.statements.catch_statement import CatchStatement +from graph_sitter.shared.decorators.docs import apidoc, noapidoc +from graph_sitter.typescript.statements.block_statement import TSBlockStatement + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.interfaces.conditional_block import ConditionalBlock + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + +Parent = TypeVar("Parent", bound="TSCodeBlock") + + +@apidoc +class TSCatchStatement(CatchStatement[Parent], TSBlockStatement, Generic[Parent]): + """Typescript catch clause. 
+ + Attributes: + code_block: The code block that may trigger an exception + condition: The condition which triggers this clause + """ + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int | None = None) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, pos) + self.condition = self.child_by_field_name("parameter") + + @property + @noapidoc + def other_possible_blocks(self) -> list[ConditionalBlock]: + return [self.parent] diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/comment.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/comment.py new file mode 100644 index 00000000..5e259f10 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/comment.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from enum import StrEnum + +from graph_sitter.core.autocommit import commiter, reader +from graph_sitter.core.statements.comment import Comment, lowest_indentation +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc + + +@ts_apidoc +class TSCommentType(StrEnum): + """An enumeration representing different types of comments in TypeScript. + + Represents the possible types of comments that can be used in TypeScript code, + including double slash comments (//), slash star comments (/* */), and unknown + comment types. + + Attributes: + DOUBLE_SLASH (str): Represents a single-line comment starting with //. + SLASH_STAR (str): Represents a multi-line comment enclosed in /* */. + UNKNOWN (str): Represents an unknown or unrecognized comment type. + """ + + DOUBLE_SLASH = "DOUBLE_SLASH" + SLASH_STAR = "SLASH_STAR" + UNKNOWN = "UNKNOWN" + + +@ts_apidoc +class TSComment(Comment): + """Abstract representation of typescript comments""" + + @property + @reader + def comment_type(self) -> TSCommentType: + """Determines the type of comment in a TypeScript source code. + + Parses the comment markers to determine if it's a single-line comment (//) or a multi-line comment (/* */). If no known comment markers are found, returns UNKNOWN. + + Args: + self: The TSComment instance. + + Returns: + TSCommentType: The type of the comment. Can be DOUBLE_SLASH for single-line comments, + SLASH_STAR for multi-line comments, or UNKNOWN if no known comment markers are found. 
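+
+        Example:
+            An illustrative sketch; assumes ``line_comment`` and ``block_comment`` are TSComment instances
+            whose sources are ``// TODO`` and ``/* TODO */`` respectively:
+
+                line_comment.comment_type   # -> TSCommentType.DOUBLE_SLASH
+                block_comment.comment_type  # -> TSCommentType.SLASH_STAR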
+ """ + if self.source.startswith("//"): + return TSCommentType.DOUBLE_SLASH + elif self.source.startswith("/*"): + return TSCommentType.SLASH_STAR + return TSCommentType.UNKNOWN + + @noapidoc + @commiter + def _parse_comment(self) -> str: + """Parse out the comment into its text content""" + # Remove comment markers + if self.comment_type == TSCommentType.DOUBLE_SLASH: + if self.source.startswith("// "): + return self.source[3:] + elif self.source.startswith("//"): + return self.source[2:] + else: + return self.source + elif self.comment_type == TSCommentType.SLASH_STAR: + formatted_text = self.source + # Remove comment markers + if self.source.startswith("/** "): + formatted_text = self.source[4:] + elif self.source.startswith("/**"): + formatted_text = self.source[3:] + elif self.source.startswith("/* "): + formatted_text = self.source[3:] + elif self.source.startswith("/*"): + formatted_text = self.source[2:] + if formatted_text.endswith(" */"): + formatted_text = formatted_text[:-3] + elif formatted_text.endswith("*/"): + formatted_text = formatted_text[:-2] + formatted_text = formatted_text.strip("\n") + formatted_split = formatted_text.split("\n") + # Get indentation level + padding = lowest_indentation(formatted_split) + # Remove indentation + formatted_text = "\n".join([line[padding:] for line in formatted_split]) + # Remove leading "* " from each line + text_lines = [] + for line in formatted_text.split("\n"): + if line.lstrip().startswith("* "): + text_lines.append(line.lstrip()[2:]) + elif line.lstrip().startswith("*"): + text_lines.append(line.lstrip()[1:]) + else: + text_lines.append(line) + return "\n".join(text_lines).rstrip() + else: + # Return the source if the comment type is unknown + return self.source + + @noapidoc + @reader + def _unparse_comment(self, new_src: str): + """Unparses cleaned text content into a comment block""" + should_add_leading_star = any([line.lstrip().startswith("*") for line in self.source.split("\n")[:-1]]) if len(self.source.split("\n")) > 1 else True + return self.generate_comment(new_src, self.comment_type, leading_star=should_add_leading_star) + + @staticmethod + def generate_comment(new_src: str, comment_type: TSCommentType, leading_star: bool = True, force_multiline: bool = False) -> str: + """Generates a TypeScript comment block from the given text content. + + Creates a comment block in either single-line (//) or multi-line (/* */) format based on the specified comment type. + + Args: + new_src (str): The text content to be converted into a comment. + comment_type (TSCommentType): The type of comment to generate (DOUBLE_SLASH or SLASH_STAR). + leading_star (bool, optional): Whether to add leading "*" to each line in multi-line comments. Defaults to True. + force_multiline (bool, optional): Whether to force multi-line format for single-line content. Defaults to False. + + Returns: + str: The formatted comment block as a string. 
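+
+        Example:
+            An illustrative sketch of the expected formatting:
+
+                TSComment.generate_comment("fix me", TSCommentType.DOUBLE_SLASH)
+                # -> "// fix me"
+
+                TSComment.generate_comment("first\nsecond", TSCommentType.SLASH_STAR)
+                # -> "/**\n * first\n * second\n */"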
+ """ + # Generate the comment block based on the comment type + if comment_type == TSCommentType.DOUBLE_SLASH: + # Add the comment character to each line + new_src = "\n".join([f"// {line}" for line in new_src.split("\n")]) + elif comment_type == TSCommentType.SLASH_STAR: + # Add triple quotes to the text + if "\n" in new_src or force_multiline: + # Check if we should add leading "* " to each line + if leading_star: + new_src = "\n".join([(" * " + x).rstrip() for x in new_src.split("\n")]) + new_src = "/**\n" + new_src + "\n */" + else: + new_src = "/*\n" + new_src + "\n*/" + else: + new_src = "/* " + new_src + " */" + return new_src + + @staticmethod + def clean_comment(comment: str) -> str: + """Cleans comment markers and whitespace from a comment string. + + Removes various types of comment markers ('/', '/*', '/**', '*/') and trims whitespace + from the beginning and end of the comment text. + + Args: + comment (str): The raw comment string to be cleaned. + + Returns: + str: The cleaned comment text with comment markers and excess whitespace removed. + """ + comment = comment.lstrip() + if comment.startswith("//"): + comment = comment[2:] + if comment.startswith("/**"): + comment = comment[3:] + if comment.startswith("/*"): + comment = comment[2:] + if comment.endswith("*/"): + comment = comment[:-2] + return comment.strip() diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/for_loop_statement.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/for_loop_statement.py new file mode 100644 index 00000000..24fb6e09 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/for_loop_statement.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from graph_sitter.compiled.autocommit import commiter, reader +from graph_sitter.core.statements.for_loop_statement import ForLoopStatement +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.statements.block_statement import TSBlockStatement + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.dataclasses.usage import UsageKind + from graph_sitter.core.detached_symbols.function_call import FunctionCall + from graph_sitter.core.expressions import Expression + from graph_sitter.core.interfaces.has_name import HasName + from graph_sitter.core.interfaces.importable import Importable + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + + +@ts_apidoc +class TSForLoopStatement(ForLoopStatement["TSCodeBlock"], TSBlockStatement["TSCodeBlock"]): + """Abstract representation of the for loop in TypeScript. + + Attributes: + item: An item in the iterable object. Only applicable for `for...of` loops. + iterable: The iterable that is being iterated over. Only applicable for `for...of` loops. + + initializer: The counter variable. Applicable for traditional for loops. + condition: The condition for the loop. Applicable for traditional for loops. + increment: The increment expression. Applicable for traditional for loops. 
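+
+    Example:
+        An illustrative sketch; assumes ``counting`` wraps ``for (let i = 0; i < 5; i++) {...}`` and
+        ``iterating`` wraps ``for (const item of items) {...}``:
+
+            counting.is_for_in_loop    # -> False; initializer, condition and increment are populated
+            iterating.is_for_in_loop   # -> True; item and iterable are populated instead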
+ """ + + # TODO: parse as statement + item: Expression[TSForLoopStatement] | None = None + # TODO: parse as statement + iterable: Expression[TSForLoopStatement] | None = None + + initializer: Expression[TSForLoopStatement] | None = None + condition: Expression[TSForLoopStatement] | None = None + increment: Expression[TSForLoopStatement] | None = None + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int | None = None) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, pos) + if ts_node.type == "for_statement": + self.initializer = self.child_by_field_name("initializer") + self.condition = self.child_by_field_name("condition") + self.increment = self.child_by_field_name("increment") + elif ts_node.type == "for_in_statement": + self.item = self.child_by_field_name("left") + self.iterable = self.child_by_field_name("right") + else: + msg = f"Invalid for loop type: {ts_node.type}" + raise ValueError(msg) + + @property + @reader + def is_for_in_loop(self) -> bool: + """Determines whether the current for loop is a `for...in` loop. + + A property that identifies if the current for loop is a 'for...in' loop by checking its tree-sitter node type. + + Returns: + bool: True if the for loop is a 'for...in' loop, False otherwise. + """ + return self.ts_node.type == "for_in_statement" + + @property + @reader + def function_calls(self) -> list[FunctionCall]: + """Retrieves all function calls within a for loop statement. + + For a for...in loop, collects function calls from the iterable expression. + For a traditional for loop, collects function calls from the initializer, + condition, and increment expressions. Also includes function calls from + the superclass implementation. + + Returns: + list[FunctionCall]: A list of all FunctionCall objects found within the for loop statement. 
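+
+        Example:
+            An illustrative sketch; assumes ``loop`` wraps ``for (const row of fetchRows()) { render(row); }``
+            and that each FunctionCall exposes a ``name``:
+
+                [call.name for call in loop.function_calls]  # -> ["fetchRows", "render"]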
+ """ + fcalls = [] + if self.is_for_in_loop: + fcalls.extend(self.iterable.function_calls) + else: + fcalls.extend(self.initializer.function_calls) + fcalls.extend(self.condition.function_calls) + if self.increment: + fcalls.extend(self.increment.function_calls) + fcalls.extend(super().function_calls) + return fcalls + + @noapidoc + @commiter + def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: + if self.is_for_in_loop: + self.item._compute_dependencies(usage_type, dest) + self.iterable._compute_dependencies(usage_type, dest) + else: + self.initializer._compute_dependencies(usage_type, dest) + self.condition._compute_dependencies(usage_type, dest) + if self.increment: + self.increment._compute_dependencies(usage_type, dest) + super()._compute_dependencies(usage_type, dest) + + @property + @noapidoc + def descendant_symbols(self) -> list[Importable]: + symbols = [] + if self.is_for_in_loop: + symbols.extend(self.item.descendant_symbols) + symbols.extend(self.iterable.descendant_symbols) + else: + symbols.extend(self.initializer.descendant_symbols) + symbols.extend(self.condition.descendant_symbols) + if self.increment: + symbols.extend(self.increment.descendant_symbols) + symbols.extend(super().descendant_symbols) + return symbols diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/if_block_statement.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/if_block_statement.py new file mode 100644 index 00000000..bad8692a --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/if_block_statement.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +from graph_sitter.core.autocommit import reader, writer +from graph_sitter.core.statements.if_block_statement import IfBlockStatement +from graph_sitter.core.statements.statement import StatementType +from graph_sitter.shared.decorators.docs import apidoc +from graph_sitter.shared.logging.get_logger import get_logger + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + + +logger = get_logger(__name__) + + +Parent = TypeVar("Parent", bound="TSCodeBlock") + + +@apidoc +class TSIfBlockStatement(IfBlockStatement[Parent, "TSIfBlockStatement"], Generic[Parent]): + """Typescript implementation of the if/elif/else statement block. + For example, if there is a code block like: + if (condition1) { + block1 + } else if (condition2) { + block2 + } else { + block3 + } + This class represents the entire block, including the conditions and nested code blocks. 
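+
+    Example:
+        An illustrative sketch; assumes ``if_block`` wraps the leading ``if (condition1)`` block above:
+
+            if_block.is_if_statement    # -> True
+            if_block.is_elif_statement  # -> False
+            if_block.is_else_statement  # -> False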
+ """ + + statement_type = StatementType.IF_BLOCK_STATEMENT + _else_clause_node: TSNode | None = None + + def __init__( + self, + ts_node: TSNode, + file_node_id: NodeId, + ctx: CodebaseContext, + parent: Parent, + pos: int, + else_clause_node: TSNode | None = None, + main_if_block: TSIfBlockStatement | None = None, + ) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, pos) + self._else_clause_node = else_clause_node + self._main_if_block = main_if_block + # Call .value to unwrap the parenthesis + condition = self.child_by_field_name("condition") + self.condition = condition.value if condition else None + self.consequence_block = self._parse_consequence_block() + self._alternative_blocks = self._parse_alternative_blocks() if self.is_if_statement else None + self.consequence_block.parse() + + @reader + def _parse_consequence_block(self) -> TSCodeBlock: + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + + if self.is_if_statement or self.is_elif_statement: + consequence_node = self.ts_node.child_by_field_name("consequence") + else: + consequence_node = self.ts_node.named_children[0] + return TSCodeBlock(consequence_node, self.parent.level + 1, self.parent, self) + + @reader + def _parse_alternative_blocks(self) -> list[TSIfBlockStatement]: + if self.is_else_statement or self.is_elif_statement: + return [] + + if_blocks = [] + alt_block = self + while alt_node := alt_block.ts_node.child_by_field_name("alternative"): + if (if_node := alt_node.named_children[0]).type == "if_statement": + # Elif statements are represented as if statements with an else clause as the parent node + alt_block = TSIfBlockStatement(if_node, self.file_node_id, self.ctx, self.parent, self.index, else_clause_node=alt_node, main_if_block=self._main_if_block or self) + else: + # Else clause + alt_block = TSIfBlockStatement(alt_node, self.file_node_id, self.ctx, self.parent, self.index, main_if_block=self._main_if_block or self) + if_blocks.append(alt_block) + return if_blocks + + @property + @reader + def is_if_statement(self) -> bool: + """Determines if the current block is a standalone 'if' statement. + + Args: + None + + Returns: + bool: True if the current block is a standalone 'if' statement, False otherwise. + """ + return self.ts_node.type == "if_statement" and self._else_clause_node is None + + @property + @reader + def is_else_statement(self) -> bool: + """Determines if the current block is an else block. + + A property that checks if the current TreeSitter node represents an else clause in an if/elif/else statement structure. + + Returns: + bool: True if the current block is an else block, False otherwise. + """ + return self.ts_node.type == "else_clause" + + @property + @reader + def is_elif_statement(self) -> bool: + """Determines if the current block is an elif block. + + This method checks if the current block is an elif block by verifying that it is both an if_statement and has an else clause node associated with it. + + Returns: + bool: True if the current block is an elif block, False otherwise. + """ + return self.ts_node.type == "if_statement" and self._else_clause_node is not None + + @writer + def _else_if_to_if(self) -> None: + """Converts an elif block to an if block. 
+ + Args: + None + + Returns: + None + """ + if not self.is_elif_statement: + return + + self.remove_byte_range(self.ts_node.start_byte - len("else "), self.ts_node.start_byte) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/import_statement.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/import_statement.py new file mode 100644 index 00000000..e1d5420e --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/import_statement.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from graph_sitter.core.expressions.builtin import Builtin +from graph_sitter.core.statements.import_statement import ImportStatement +from graph_sitter.core.symbol_groups.collection import Collection +from graph_sitter.shared.decorators.docs import ts_apidoc +from graph_sitter.typescript.import_resolution import TSImport + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + + +@ts_apidoc +class TSImportStatement(ImportStatement["TSFile", TSImport, "TSCodeBlock"], Builtin): + """A class representing an import statement in TypeScript, managing both static and dynamic imports. + + This class handles various types of TypeScript imports including regular import statements, + dynamic imports, and export statements. It provides functionality to manage and track imports + within a TypeScript file, enabling operations like analyzing dependencies, moving imports, + and modifying import statements. + + Attributes: + imports (Collection): A collection of TypeScript imports contained within the statement. 
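+
+    Example:
+        An illustrative sketch; assumes ``stmt`` wraps ``import { readFile, writeFile } from "fs";`` and that
+        one TSImport is created per imported name:
+
+            len(stmt.imports)  # -> 2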
+ """ + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int, *, source_node: TSNode | None = None) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, pos) + imports = [] + if ts_node.type == "import_statement": + imports.extend(TSImport.from_import_statement(ts_node, file_node_id, ctx, self)) + elif ts_node.type in ["call_expression", "lexical_declaration", "expression_statement", "type_alias_declaration"]: + import_call_node = source_node.child_by_field_name("function") + arguments = source_node.child_by_field_name("arguments") + imports.extend(TSImport.from_dynamic_import_statement(import_call_node, arguments, file_node_id, ctx, self)) + elif ts_node.type == "export_statement": + imports.extend(TSImport.from_export_statement(source_node, file_node_id, ctx, self)) + self.imports = Collection(ts_node, file_node_id, ctx, self, delimiter="\n", children=imports) + for imp in self.imports: + imp.import_statement = self diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/labeled_statement.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/labeled_statement.py new file mode 100644 index 00000000..6cc6c0ef --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/labeled_statement.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +from graph_sitter.core.expressions import Expression, Name +from graph_sitter.core.interfaces.has_name import HasName +from graph_sitter.core.statements.statement import Statement, StatementType +from graph_sitter.shared.decorators.docs import ts_apidoc + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + + +Parent = TypeVar("Parent", bound="TSCodeBlock") + + +@ts_apidoc +class TSLabeledStatement(Statement[Parent], HasName, Generic[Parent]): + """Statement with a named label. It resolves to various types of statements like loops, switch cases, etc. + + Examples: + ``` + outerLoop: for (let i = 0; i < 5; i++) { + innerLoop: for (let j = 0; j < 5; j++) { + if (i === 2 && j === 2) { + break outerLoop; // This will break out of the outer loop + } + console.log(`i: ${i}, j: ${j}`); + } + } + ``` + ``` + emptyStatement: { pass } + ``` + + Attributes: + body: The body of the labeled statement, which can be an Expression or None. + """ + + statement_type = StatementType.LABELED_STATEMENT + body: Expression | None + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, pos) + self._name_node = Name(ts_node.child_by_field_name("label"), file_node_id, ctx, self) + body_node = self.ts_node.child_by_field_name("body") + self.body = self._parse_expression(body_node) if body_node else None + + @property + def label(self) -> str: + """Returns the label of the labeled statement. + + Acts as a property getter that returns the name of the labeled statement. For example, in code like + 'outerLoop: for...', this would return 'outerLoop'. + + Returns: + str: The label name of the statement. 
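+
+        Example:
+            An illustrative sketch; assumes ``stmt`` wraps the ``outerLoop:`` statement shown in the class docstring:
+
+                stmt.label  # -> "outerLoop"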
+ """ + return self.name diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/switch_case.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/switch_case.py new file mode 100644 index 00000000..5bf9e424 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/switch_case.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING + +from tree_sitter import Node as TSNode + +from graph_sitter.core.node_id_factory import NodeId +from graph_sitter.core.statements.switch_case import SwitchCase +from graph_sitter.shared.decorators.docs import ts_apidoc +from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock +from graph_sitter.typescript.statements.block_statement import TSBlockStatement + +if TYPE_CHECKING: + from graph_sitter.codebase.codebase_context import CodebaseContext + + +@ts_apidoc +class TSSwitchCase(SwitchCase[TSCodeBlock["TSSwitchStatement"]], TSBlockStatement): + """Typescript switch case. + + Attributes: + default: is this a default case? + """ + + default: bool + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: TSCodeBlock, pos: int | None = None) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, pos) + self.condition = self.child_by_field_name("value") + self.default = self.ts_node.type == "switch_default" diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/switch_statement.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/switch_statement.py new file mode 100644 index 00000000..4b235463 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/switch_statement.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from graph_sitter.core.statements.switch_statement import SwitchStatement +from graph_sitter.shared.decorators.docs import ts_apidoc +from graph_sitter.typescript.statements.switch_case import TSSwitchCase + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + + +@ts_apidoc +class TSSwitchStatement(SwitchStatement["TSCodeBlock[Self]", "TSCodeBlock", TSSwitchCase]): + """Typescript switch statement""" + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int | None = None) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, pos) + self.value = self.child_by_field_name("value") + code_block = self.ts_node.child_by_field_name("body") + self.cases = [] + for node in code_block.named_children: + self.cases.append(TSSwitchCase(node, file_node_id, ctx, self)) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/try_catch_statement.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/try_catch_statement.py new file mode 100644 index 00000000..b24ca310 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/try_catch_statement.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Self, override + +from graph_sitter.compiled.autocommit import commiter, reader +from graph_sitter.core.statements.try_catch_statement import TryCatchStatement +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.statements.block_statement import 
TSBlockStatement +from graph_sitter.typescript.statements.catch_statement import TSCatchStatement + +if TYPE_CHECKING: + from collections.abc import Sequence + + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.dataclasses.usage import UsageKind + from graph_sitter.core.detached_symbols.function_call import FunctionCall + from graph_sitter.core.interfaces.conditional_block import ConditionalBlock + from graph_sitter.core.interfaces.has_name import HasName + from graph_sitter.core.interfaces.importable import Importable + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + + +@ts_apidoc +class TSTryCatchStatement(TryCatchStatement["TSCodeBlock"], TSBlockStatement): + """Abstract representation of the try/catch/finally block in TypeScript. + + Attributes: + catch: The catch block. + """ + + catch: TSCatchStatement[Self] | None = None + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int | None = None) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, pos) + if handler_node := self.ts_node.child_by_field_name("handler"): + self.catch = TSCatchStatement(handler_node, file_node_id, ctx, self) + if finalizer_node := self.ts_node.child_by_field_name("finalizer"): + self.finalizer = TSBlockStatement(finalizer_node, file_node_id, ctx, self.code_block) + + @property + @reader + def function_calls(self) -> list[FunctionCall]: + """Gets all function calls within a try-catch-finally statement. + + This property retrieves all function calls from the try block, catch block (if present), and finally block (if present). + + Returns: + list[FunctionCall]: A list of function calls found within the try-catch-finally statement, including those from + the try block, catch block (if it exists), and finally block (if it exists). + """ + fcalls = super().function_calls + if self.catch: + fcalls.extend(self.catch.function_calls) + if self.finalizer: + fcalls.extend(self.finalizer.function_calls) + return fcalls + + @noapidoc + @commiter + def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: + super()._compute_dependencies(usage_type, dest) + if self.catch: + self.catch._compute_dependencies(usage_type, dest) + if self.finalizer: + self.finalizer._compute_dependencies(usage_type, dest) + + @property + @noapidoc + def descendant_symbols(self) -> list[Importable]: + symbols = super().descendant_symbols + if self.catch: + symbols.extend(self.catch.descendant_symbols) + if self.finalizer: + symbols.extend(self.finalizer.descendant_symbols) + return symbols + + @property + @reader + @override + def nested_code_blocks(self) -> list[TSCodeBlock]: + """Returns all nested CodeBlocks within the statement. + + Retrieves a list of all the code blocks nested within this try/catch/finally statement, including the catch and finally blocks if they exist. + + Returns: + list[TSCodeBlock]: A list of nested code blocks, including the catch and finally blocks. 
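+
+        Example:
+            An illustrative sketch; assumes ``stmt`` wraps a ``try { ... } catch (e) { ... } finally { ... }`` statement:
+
+                len(stmt.nested_code_blocks)  # -> 3: the try, catch and finally blocks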
+ """ + nested_blocks = super().nested_code_blocks + if self.catch: + nested_blocks.append(self.catch.code_block) + if self.finalizer: + nested_blocks.append(self.finalizer.code_block) + return nested_blocks + + @property + @noapidoc + def other_possible_blocks(self) -> Sequence[ConditionalBlock]: + if self.catch: + return [self.catch] + else: + return [] diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/while_statement.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/while_statement.py new file mode 100644 index 00000000..cb83d352 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/statements/while_statement.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from graph_sitter.core.statements.while_statement import WhileStatement +from graph_sitter.shared.decorators.docs import ts_apidoc +from graph_sitter.typescript.interfaces.has_block import TSHasBlock + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.codebase_context import CodebaseContext + from graph_sitter.core.node_id_factory import NodeId + from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock + + +@ts_apidoc +class TSWhileStatement(WhileStatement["TSCodeBlock"], TSHasBlock): + """A TypeScript while statement class that represents while loops and manages their condition and code block. + + This class provides functionality for handling while statements in TypeScript code, + including managing the loop's condition and associated code block. It extends the base + WhileStatement class with TypeScript-specific behavior. + + Attributes: + condition (str | None): The condition expression of the while loop. + """ + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: TSCodeBlock, pos: int | None = None) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, pos) + condition = self.child_by_field_name("condition") + self.condition = condition.value if condition else None diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/symbol.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/symbol.py new file mode 100644 index 00000000..2752c857 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/symbol.py @@ -0,0 +1,512 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Self, Unpack + +from graph_sitter.core.assignment import Assignment +from graph_sitter.core.autocommit import reader, writer +from graph_sitter.core.dataclasses.usage import UsageKind, UsageType +from graph_sitter.core.detached_symbols.function_call import FunctionCall +from graph_sitter.core.expressions import Value +from graph_sitter.core.expressions.chained_attribute import ChainedAttribute +from graph_sitter.core.expressions.type import Type +from graph_sitter.core.interfaces.exportable import Exportable +from graph_sitter.core.symbol import Symbol +from graph_sitter.core.type_alias import TypeAlias +from graph_sitter.enums import ImportType, NodeType +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.import_resolution import TSImport +from graph_sitter.typescript.statements.comment import TSComment, TSCommentType +from graph_sitter.typescript.symbol_groups.comment_group import TSCommentGroup + +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from graph_sitter.codebase.flagging.code_flag import CodeFlag + from 
graph_sitter.codebase.flagging.enums import FlagKwargs + from graph_sitter.core.detached_symbols.parameter import Parameter + from graph_sitter.core.file import SourceFile + from graph_sitter.core.import_resolution import Import + from graph_sitter.core.interfaces.editable import Editable + from graph_sitter.core.node_id_factory import NodeId + + +@ts_apidoc +class TSSymbol(Symbol["TSHasBlock", "TSCodeBlock"], Exportable): + """A TypeScript symbol representing a code element with advanced manipulation capabilities. + + This class extends Symbol and Exportable to provide TypeScript-specific functionality for managing + code symbols. It offers methods for handling imports, comments, code refactoring, and file operations + like moving symbols between files while maintaining their dependencies and references. + + The class provides functionality for managing both inline and block comments, setting and retrieving + import strings, and maintaining semicolon presence. It includes capabilities for moving symbols between + files with options to handle dependencies and import strategy selection. + """ + + @reader + def get_import_string(self, alias: str | None = None, module: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> str: + """Generates the appropriate import string for a symbol. + + Constructs and returns an import statement string based on the provided parameters, formatting it according + to TypeScript import syntax rules. + + Args: + alias (str | None, optional): The alias to use for the imported symbol. Defaults to None. + module (str | None, optional): The module to import from. If None, uses the file's import module name. + Defaults to None. + import_type (ImportType, optional): The type of import to generate (e.g., WILDCARD). Defaults to + ImportType.UNKNOWN. + is_type_import (bool, optional): Whether this is a type-only import. Defaults to False. + + Returns: + str: A formatted import statement string. + """ + type_prefix = "type " if is_type_import else "" + import_module = module if module is not None else self.file.import_module_name + + if import_type == ImportType.WILDCARD: + file_as_module = self.file.name + return f"import {type_prefix}* as {file_as_module} from {import_module};" + elif alias is not None and alias != self.name: + return f"import {type_prefix}{{ {self.name} as {alias} }} from {import_module};" + else: + return f"import {type_prefix}{{ {self.name} }} from {import_module};" + + @property + @reader(cache=False) + def extended_nodes(self) -> list[Editable]: + """Returns the list of nodes associated with this symbol including extended nodes. + + This property returns a list of Editable nodes that includes any wrapping or extended symbols like `export`, `public`, or decorators. + For example, if the symbol is within an `export_statement` or `lexical_declaration`, those nodes will be included in the list. + + Args: + No arguments. + + Returns: + list[Editable]: A list of Editable nodes including the symbol's extended nodes like export statements and decorators. 
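+
+        Example:
+            An illustrative sketch; assumes ``func`` is the TSSymbol for ``export function helper() {...}``:
+
+                any("export" in node.source for node in func.extended_nodes)  # -> True; the export_statement wrapper is included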
+ """ + nodes = super().extended_nodes + + # Check if the symbol is wrapped by another node like 'export_statement' + new_ts_node = self.ts_node + while (parent := new_ts_node.parent).type in ("export_statement", "lexical_declaration", "variable_declarator"): + new_ts_node = parent + + return [Value(new_ts_node, self.file_node_id, self.ctx, self.parent) if node.ts_node == self.ts_node else node for node in nodes] + + @property + @reader + def comment(self) -> TSCommentGroup | None: + """Retrieves the comment group associated with the symbol. + + Returns the TSCommentGroup object that contains any comments associated with the symbol. + A comment group represents one or more related comments that precede the symbol in the code. + + Returns: + TSCommentGroup | None: The comment group for the symbol if one exists, None otherwise. + """ + return TSCommentGroup.from_symbol_comments(self) + + @property + @reader + def inline_comment(self) -> TSCommentGroup | None: + """Property that retrieves the inline comment group associated with the symbol. + + Args: + None + + Returns: + TSCommentGroup | None: The inline comment group associated with the symbol if it exists, + otherwise None. + """ + return TSCommentGroup.from_symbol_inline_comments(self) + + @writer + def set_comment(self, comment: str, auto_format: bool = True, clean_format: bool = True, comment_type: TSCommentType = TSCommentType.DOUBLE_SLASH) -> None: + """Sets a comment to the symbol. + + Adds or updates a comment for a code symbol. If a comment already exists, it will be edited. If no + comment exists, a new comment group will be created. + + Args: + comment (str): The comment text to be added. + auto_format (bool, optional): Whether to automatically format the text into a comment syntax. + Defaults to True. + clean_format (bool, optional): Whether to clean the format of the comment before inserting. + Defaults to True. + comment_type (TSCommentType, optional): The style of comment to add. + Defaults to TSCommentType.DOUBLE_SLASH. + + Returns: + None + + Raises: + None + """ + if clean_format: + comment = TSComment.clean_comment(comment) + + # If comment already exists, add the comment to the existing comment group + if self.comment: + if auto_format: + self.comment.edit_text(comment) + else: + self.comment.edit(comment, fix_indentation=True) + else: + if auto_format: + comment = TSComment.generate_comment(comment, comment_type) + self.insert_before(comment, fix_indentation=True) + + @writer + def add_comment(self, comment: str, auto_format: bool = True, clean_format: bool = True, comment_type: TSCommentType = TSCommentType.DOUBLE_SLASH) -> None: + """Adds a new comment to the symbol. + + Appends a comment to an existing comment group or creates a new comment group if none exists. + + Args: + comment (str): The comment text to be added. + auto_format (bool): Whether to automatically format the text into a comment style. Defaults to True. + clean_format (bool): Whether to clean the format of the comment before inserting. Defaults to True. + comment_type (TSCommentType): Type of comment to add. Defaults to TSCommentType.DOUBLE_SLASH. 
+ + Returns: + None + + Raises: + None + """ + if clean_format: + comment = TSComment.clean_comment(comment) + if auto_format: + comment = TSComment.generate_comment(comment, comment_type) + + # If comment already exists, add the comment to the existing comment group + if self.comment: + self.comment.insert_after(comment, fix_indentation=True) + else: + self.insert_before(comment, fix_indentation=True) + + @writer + def set_inline_comment(self, comment: str, auto_format: bool = True, clean_format: bool = True, node: TSNode | None = None) -> None: + """Sets an inline comment to the symbol. + + Sets or replaces an inline comment for a symbol at its current position. If an inline comment + already exists, it is replaced with the new comment. If no inline comment exists, a new one + will be created adjacent to the symbol. + + Args: + comment (str): The inline comment text to be added. + auto_format (bool, optional): Whether to automatically format the text as a comment. + Defaults to True. + clean_format (bool, optional): Whether to clean the comment format before inserting. + Defaults to True. + node (TSNode | None, optional): The specific node to attach the comment to. + Defaults to None. + + Returns: + None + + Raises: + None + """ + if clean_format: + comment = TSComment.clean_comment(comment) + + if self.inline_comment: + if auto_format: + self.inline_comment.edit_text(comment) + else: + self.inline_comment.edit(comment) + else: + if auto_format: + comment = " " + TSComment.generate_comment(comment, TSCommentType.DOUBLE_SLASH) + node = node or self.ts_node + Value(node, self.file_node_id, self.ctx, self).insert_after(comment, fix_indentation=False, newline=False) + + @property + @reader + def semicolon_node(self) -> Editable | None: + """Retrieves the semicolon node associated with a TypeScript symbol. + + A semicolon node is a TreeSitter node of type ';' that appears immediately after the symbol node. + + Returns: + Editable | None: The semicolon node wrapped as an Editable if it exists, None otherwise. + """ + sibbling = self.ts_node.next_sibling + if sibbling and sibbling.type == ";": + return Value(sibbling, self.file_node_id, self.ctx, self) + return None + + @property + @reader + def has_semicolon(self) -> bool: + """Checks whether the current symbol has a semicolon at the end. + + This property determines if a semicolon is present at the end of the symbol by checking + if the semicolon_node property exists. + + Returns: + bool: True if the symbol has a semicolon at the end, False otherwise. + """ + return self.semicolon_node is not None + + @noapidoc + def _move_to_file( + self, + file: SourceFile, + encountered_symbols: set[Symbol | Import], + include_dependencies: bool = True, + strategy: Literal["add_back_edge", "update_all_imports", "duplicate_dependencies"] = "update_all_imports", + ) -> tuple[NodeId, NodeId]: + # TODO: Prevent creation of import loops (!) 
- raise a ValueError and make the agent fix it + # =====[ Arg checking ]===== + if file == self.file: + return file.file_node_id, self.node_id + + # =====[ Move over dependencies recursively ]===== + if include_dependencies: + try: + for dep in self.dependencies: + if dep in encountered_symbols: + continue + + # =====[ Symbols - move over ]===== + elif isinstance(dep, TSSymbol): + if dep.is_top_level: + encountered_symbols.add(dep) + dep._move_to_file(file, encountered_symbols=encountered_symbols, include_dependencies=True, strategy=strategy) + + # =====[ Imports - copy over ]===== + elif isinstance(dep, TSImport): + if dep.imported_symbol: + file.add_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type) + else: + file.add_import(dep.source) + + else: + msg = f"Unknown dependency type {type(dep)}" + raise ValueError(msg) + except Exception as e: + print(f"Failed to move dependencies of {self.name}: {e}") + else: + try: + for dep in self.dependencies: + if isinstance(dep, Assignment): + msg = "Assignment not implemented yet" + raise NotImplementedError(msg) + + # =====[ Symbols - move over ]===== + elif isinstance(dep, Symbol) and dep.is_top_level: + file.add_import(imp=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=isinstance(dep, TypeAlias)) + + if not dep.is_exported: + dep.file.add_export_to_symbol(dep) + pass + + # =====[ Imports - copy over ]===== + elif isinstance(dep, TSImport): + if dep.imported_symbol: + file.add_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type, is_type_import=dep.is_type_import()) + else: + file.add_import(dep.source) + + except Exception as e: + print(f"Failed to move dependencies of {self.name}: {e}") + + # =====[ Make a new symbol in the new file ]===== + # This will update all edges etc. + file.add_symbol(self) + import_line = self.get_import_string(module=file.import_module_name) + + # =====[ Checks if symbol is used in original file ]===== + # Takes into account that it's dependencies will be moved + is_used_in_file = any(usage.file == self.file and usage.node_type == NodeType.SYMBOL and usage not in encountered_symbols for usage in self.symbol_usages) + + # ======[ Strategy: Duplicate Dependencies ]===== + if strategy == "duplicate_dependencies": + # If not used in the original file. 
or if not imported from elsewhere, we can just remove the original symbol + if not is_used_in_file and not any(usage.kind is UsageKind.IMPORTED and usage.usage_symbol not in encountered_symbols for usage in self.usages): + self.remove() + + # ======[ Strategy: Add Back Edge ]===== + # Here, we will add a "back edge" to the old file importing the self + elif strategy == "add_back_edge": + if is_used_in_file: + self.file.add_import(import_line) + if self.is_exported: + self.file.add_import(f"export {{ {self.name} }}") + elif self.is_exported: + module_name = file.name + self.file.add_import(f"export {{ {self.name} }} from '{module_name}'") + # Delete the original symbol + self.remove() + + # ======[ Strategy: Update All Imports ]===== + # Update the imports in all the files which use this symbol to get it from the new file now + elif strategy == "update_all_imports": + for usage in self.usages: + if isinstance(usage.usage_symbol, TSImport): + # Add updated import + if usage.usage_symbol.resolved_symbol is not None and usage.usage_symbol.resolved_symbol.node_type == NodeType.SYMBOL and usage.usage_symbol.resolved_symbol == self: + usage.usage_symbol.file.add_import(import_line) + usage.usage_symbol.remove() + elif usage.usage_type == UsageType.CHAINED: + # Update all previous usages of import * to the new import name + if usage.match and "." + self.name in usage.match: + if isinstance(usage.match, FunctionCall): + usage.match.get_name().edit(self.name) + if isinstance(usage.match, ChainedAttribute): + usage.match.edit(self.name) + usage.usage_symbol.file.add_import(import_line) + if is_used_in_file: + self.file.add_import(import_line) + # Delete the original symbol + self.remove() + + def _convert_proptype_to_typescript(self, prop_type: Editable, param: Parameter | None, level: int) -> str: + """Converts a PropType definition to its TypeScript equivalent.""" + # Handle basic types + type_map = {"string": "string", "number": "number", "bool": "boolean", "object": "object", "array": "any[]", "func": "CallableFunction"} + if prop_type.source in type_map: + return type_map[prop_type.source] + if isinstance(prop_type, ChainedAttribute): + if prop_type.attribute.source == "node": + return "T" + if prop_type.attribute.source == "element": + self.file.add_import("import React from 'react';\n") + return "React.ReactElement" + if prop_type.attribute.source in type_map: + return type_map[prop_type.attribute.source] + # if prop_type.attribute.source == "func": + # params = [] + # if param: + # for usage in param.usages: + # call = None + # if isinstance(usage.match, FunctionCall): + # call = usage.match + # elif isinstance(usage.match.parent, FunctionCall): + # call = usage.match.parent + # if call: + # for arg in call.args: + # resolved_value = arg.value.resolved_value + # if resolved_value.rstrip("[]") not in ("number", "string", "boolean", "any", "object"): + # resolved_value = "any" + # params.append(f"{arg.name or arg.source}: {resolved_value}") + # return f"({",".join(params)}) => void" + return "Function" + if prop_type.attribute.source == "isRequired": + return self._convert_proptype_to_typescript(prop_type.object, param, level) + if isinstance(prop_type, FunctionCall): + if prop_type.name == "isRequired": + return self._convert_proptype_to_typescript(prop_type.args[0].value, param, level) + # Handle arrays + if prop_type.name == "arrayOf": + item = self._convert_proptype_to_typescript(prop_type.args[0].value, param, level) + # needs_parens = isinstance(prop_type.args[0].value, FunctionCall) + 
needs_parens = False + return f"({item})[]" if needs_parens else f"{item}[]" + + # Handle oneOf + if prop_type.name == "oneOf": + values = [arg.source for arg in prop_type.args[0].value] + # Add parentheses if one of the values is a function + return " | ".join(f"({t})" if "() => void" == t else t for t in values) + # Handle anyOf (alias for oneOf) + if prop_type.name == "anyOf": + values = [arg.source for arg in prop_type.args[0].value] + # Add parentheses if one of the values is a function + return " | ".join(f"({t})" if "() => void" == t else t for t in values) + + # Handle oneOfType + if prop_type.name == "oneOfType": + types = [self._convert_proptype_to_typescript(arg, param, level) for arg in prop_type.args[0].value] + # Only add parentheses if one of the types is a function + return " | ".join(f"({t})" if "() => void" == t else t for t in types) + + # Handle shape + if prop_type.name == "shape": + return self._convert_dict(prop_type.args[0].value, level) + if prop_type.name == "objectOf": + return self._convert_object_of(prop_type.args[0].value, level) + return "any" + + def _convert_dict(self, value: Type, level: int) -> str: + """Converts a dictionary of PropTypes to a TypeScript interface string.""" + result = "{\n" + for key, value in value.items(): + is_required = isinstance(value, ChainedAttribute) and value.attribute.source == "isRequired" + optional = "" if is_required else "?" + indent = " " * level + param = next((p for p in self.parameters if p.name == key), None) if self.parameters else None + result += f"{indent}{key}{optional}: {self._convert_proptype_to_typescript(value, param, level + 1)};\n" + indent = " " * (level - 1) + + result += f"{indent}}}" + return result + + def _convert_object_of(self, value: Type, level: int) -> str: + """Converts a dictionary of PropTypes to a TypeScript interface string.""" + indent = " " * level + prev_indent = " " * (level - 1) + type_value = self._convert_proptype_to_typescript(value, None, level + 1) + return f"{{\n{indent}[key: string]: {type_value};\n{prev_indent}}}" + + def _get_static_prop_types(self) -> Type | None: + """Returns a dictionary of prop types for a React component.""" + for usage in self.usages: + if isinstance(usage.usage_symbol, Assignment) and usage.usage_symbol.name == "propTypes": + assert isinstance(usage.usage_symbol.value, Type), usage.usage_symbol.value.__class__ + return usage.usage_symbol.value + return None + + @noapidoc + def convert_to_react_interface(self) -> str | None: + if not self.is_jsx: + return None + + component_name = self.name + # Handle class components with static propTypes + if proptypes := self._get_static_prop_types(): + generics = "" + generic_name = "" + if "PropTypes.node" in proptypes.source: + generics = "" + generic_name = "" + self.file.add_import("import React from 'react';\n") + interface_name = f"{component_name}Props" + # Create interface definition + interface_def = f"interface {interface_name}{generics} {self._convert_dict(proptypes, 1)}" + + # Insert interface and update component + self.insert_before(interface_def + "\n") + + proptypes.parent_statement.remove() + for imp in self.file.imports: + if imp.module.source.strip("'").strip('"') in ("react", "prop-types"): + imp.remove_if_unused() + return interface_name + generic_name + + @writer + def flag(self, **kwargs: Unpack[FlagKwargs]) -> CodeFlag[Self]: + """Flags a TypeScript symbol by adding a flag comment and returning a CodeFlag. 
+ + This implementation first creates the CodeFlag through the standard flagging system, + then adds a TypeScript-specific comment to visually mark the flagged code. + + Args: + **kwargs: Flag keyword arguments including optional 'message' + + Returns: + CodeFlag[Self]: The code flag object for tracking purposes + """ + # First create the standard CodeFlag through the base implementation + code_flag = super().flag(**kwargs) + + # Add a TypeScript comment to visually mark the flag + message = kwargs.get("message", "") + if message: + self.set_inline_comment(f"🚩 {message}") + + return code_flag diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/symbol_groups/comment_group.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/symbol_groups/comment_group.py new file mode 100644 index 00000000..c9947424 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/symbol_groups/comment_group.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from graph_sitter.core.symbol_groups.comment_group import CommentGroup +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.statements.comment import TSComment, TSCommentType + +if TYPE_CHECKING: + from graph_sitter.typescript.symbol import TSSymbol + + +@ts_apidoc +class TSCommentGroup(CommentGroup): + """A group of related symbols that represent a comment or docstring in TypeScript + + For example: + ``` + // Comment 1 + // Comment 2 + // Comment 3 + ``` + would be 3 individual comments (accessible via `symbols`), but together they form a `CommentGroup` (accessible via `self). + """ + + @staticmethod + @noapidoc + def _get_sibbling_comments(symbol: TSSymbol) -> list[TSComment]: + # Locate the body that contains the comment nodes + current_node = symbol.ts_node + parent_node = symbol.ts_node.parent + while parent_node and parent_node.type not in ["program", "class_body", "block", "statement_block"]: + current_node = parent_node + parent_node = parent_node.parent + + if not parent_node: + return None + + # Find the correct index of function_node in parent_node's children + function_index = parent_node.children.index(current_node) + + if function_index is None: + return None # function_node is not a child of parent_node + + if function_index == 0: + return None # No nodes before this function, hence no comments + + comment_nodes = [] + # Iterate backwards from the function node to collect all preceding comment nodes + for i in range(function_index - 1, -1, -1): + if parent_node.children[i].type == "comment": + # Check if the comment is directly above each other + if parent_node.children[i].end_point[0] == parent_node.children[i + 1].start_point[0] - 1: + comment = TSComment.from_code_block(parent_node.children[i], symbol) + comment_nodes.insert(0, comment) + else: + break # Stop if there is a break in the comments + else: + break # Stop if a non-comment node is encountered + + return comment_nodes + + @classmethod + @noapidoc + def from_symbol_comments(cls, symbol: TSSymbol): + comment_nodes = cls._get_sibbling_comments(symbol) + if not comment_nodes: + return None + return cls(comment_nodes, symbol.file_node_id, symbol.ctx, symbol) + + @classmethod + @noapidoc + def from_symbol_inline_comments(cls, symbol: TSSymbol): + # Locate the body that contains the comment nodes + current_node = symbol.ts_node + parent_node = symbol.ts_node.parent + while parent_node and parent_node.type not in ["program", "class_body", "block", "statement_block"]: + 
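+            # Walk upwards until we reach a container node (program, class body, or block)
+            # whose direct children can hold both the symbol and any preceding comment nodes.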
current_node = parent_node + parent_node = parent_node.parent + + if not parent_node: + return None + + # Find the correct index of function_node in parent_node's children + function_index = parent_node.children.index(current_node) + + if function_index is None: + return None # function_node is not a child of parent_node + + comment_nodes = [] + # Check if there are any comments after the function node + if function_index + 1 < len(parent_node.children): + if parent_node.children[function_index + 1].type == "comment": + # Check if the comment is on the same line + if parent_node.children[function_index].end_point[0] == parent_node.children[function_index + 1].start_point[0]: + comment = TSComment.from_code_block(parent_node.children[function_index + 1], symbol) + comment_nodes.append(comment) + + if not comment_nodes: + return None + + return cls(comment_nodes, symbol.file_node_id, symbol.ctx, symbol) + + @classmethod + @noapidoc + def from_docstring(cls, symbol: TSSymbol) -> TSCommentGroup | None: + """Returns the docstring of the function""" + comment_nodes = cls._get_sibbling_comments(symbol) + if not comment_nodes: + return None + # Docstring comments are filtered by SLASH_STAR comments + docstring_nodes = [comment for comment in comment_nodes if comment.comment_type == TSCommentType.SLASH_STAR] + if not docstring_nodes: + return None + return cls(docstring_nodes, symbol.file_node_id, symbol.ctx, symbol) + + @classmethod + @noapidoc + def from_comment_nodes(cls, comment_nodes: list[TSComment], symbol: TSSymbol): + if not comment_nodes: + return None + + # Docstring comments are filtered by SLASH_STAR comments + docstring_nodes = [comment for comment in comment_nodes if comment.comment_type == TSCommentType.SLASH_STAR] + if not docstring_nodes: + return None + return cls(docstring_nodes, symbol.file_node_id, symbol.ctx, symbol) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/symbol_groups/dict.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/symbol_groups/dict.py new file mode 100644 index 00000000..dec042d6 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/symbol_groups/dict.py @@ -0,0 +1,144 @@ +from typing import TYPE_CHECKING, Self, TypeVar, override + +from tree_sitter import Node as TSNode + +from graph_sitter.compiled.autocommit import reader +from graph_sitter.core.autocommit import writer +from graph_sitter.core.expressions import Expression +from graph_sitter.core.expressions.string import String +from graph_sitter.core.interfaces.editable import Editable +from graph_sitter.core.interfaces.has_attribute import HasAttribute +from graph_sitter.core.node_id_factory import NodeId +from graph_sitter.core.symbol_groups.dict import Dict, Pair +from graph_sitter.shared.decorators.docs import apidoc, noapidoc, ts_apidoc +from graph_sitter.shared.logging.get_logger import get_logger + +if TYPE_CHECKING: + from graph_sitter.codebase.codebase_context import CodebaseContext + +Parent = TypeVar("Parent", bound="Editable") +TExpression = TypeVar("TExpression", bound=Expression) + +logger = get_logger(__name__) + + +@ts_apidoc +class TSPair(Pair): + """A TypeScript pair node that represents key-value pairs in object literals. + + A specialized class extending `Pair` for handling TypeScript key-value pairs, + particularly in object literals. It provides functionality for handling both + regular key-value pairs and shorthand property identifiers, with support for + reducing boolean conditions. 
+ + Attributes: + shorthand (bool): Indicates whether this pair uses shorthand property syntax. + """ + + shorthand: bool + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent) -> None: + super().__init__(ts_node, file_node_id, ctx, parent) + self.shorthand = ts_node.type == "shorthand_property_identifier" + + def _get_key_value(self) -> tuple[Expression[Self] | None, Expression[Self] | None]: + from graph_sitter.typescript.function import TSFunction + + key, value = None, None + + if self.ts_node.type == "pair": + key = self.child_by_field_name("key") + value = self.child_by_field_name("value") + if TSFunction.is_valid_node(value.ts_node): + value = self._parse_expression(value.ts_node) + elif self.ts_node.type == "shorthand_property_identifier": + key = value = self._parse_expression(self.ts_node) + elif TSFunction.is_valid_node(self.ts_node): + value = self._parse_expression(self.ts_node) + key = value.get_name() + else: + return super()._get_key_value() + return key, value + + @writer + def reduce_condition(self, bool_condition: bool, node: Editable | None = None) -> None: + """Reduces an editable to the following condition""" + if self.shorthand and node == self.value: + # Object shorthand + self.parent[self.key.source] = self.ctx.node_classes.bool_conversion[bool_condition] + else: + super().reduce_condition(bool_condition, node) + + +@apidoc +class TSDict(Dict, HasAttribute): + """A typescript dict object. You can use standard operations to operate on this dict (IE len, del, set, get, etc)""" + + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent, delimiter: str = ",", pair_type: type[Pair] = TSPair) -> None: + super().__init__(ts_node, file_node_id, ctx, parent, delimiter=delimiter, pair_type=pair_type) + + def __getitem__(self, __key: str) -> TExpression: + for pair in self._underlying: + pair_match = None + + if isinstance(pair, Pair): + if isinstance(pair.key, String): + if pair.key.content == str(__key): + pair_match = pair + elif pair.key is not None: + if pair.key.source == str(__key): + pair_match = pair + + if pair_match: + if pair_match.value is not None: + return pair_match.value + else: + return pair_match.key + msg = f"Key {__key} not found in {list(self.keys())} {self._underlying!r}" + raise KeyError(msg) + + def __setitem__(self, __key: str, __value: TExpression) -> None: + new_value = __value.source if isinstance(__value, Editable) else str(__value) + for pair in self._underlying: + pair_match = None + + if isinstance(pair, Pair): + if isinstance(pair.key, String): + if pair.key.content == str(__key): + pair_match = pair + elif pair.key is not None: + if pair.key.source == str(__key): + pair_match = pair + + if pair_match: + # CASE: {a: b} + if not pair_match.shorthand: + if __key == new_value: + pair_match.edit(f"{__key}") + else: + pair.value.edit(f"{new_value}") + # CASE: {a} + else: + if __key == new_value: + pair_match.edit(f"{__key}") + else: + pair_match.edit(f"{__key}: {new_value}") + break + # CASE: {} + else: + if not self.ctx.node_classes.int_dict_key: + try: + int(__key) + __key = f"'{__key}'" + except ValueError: + pass + if __key == new_value: + self._underlying.append(f"{__key}") + else: + self._underlying.append(f"{__key}: {new_value}") + + @reader + @noapidoc + @override + def resolve_attribute(self, name: str) -> "Expression | None": + return self.get(name, None) diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/ts_config.py 
b/Libraries/graph_sitter_lib/graph_sitter/typescript/ts_config.py new file mode 100644 index 00000000..fe05a3c5 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/ts_config.py @@ -0,0 +1,485 @@ +import os +from functools import cache +from pathlib import Path +from typing import TYPE_CHECKING + +import pyjson5 + +from graph_sitter.core.directory import Directory +from graph_sitter.core.file import File +from graph_sitter.shared.decorators.docs import ts_apidoc +from graph_sitter.shared.logging.get_logger import get_logger + +if TYPE_CHECKING: + from graph_sitter.typescript.config_parser import TSConfigParser + from graph_sitter.typescript.file import TSFile + +logger = get_logger(__name__) + + +@ts_apidoc +class TSConfig: + """TypeScript configuration file specified in tsconfig.json, used for import resolution and computing dependencies. + + Attributes: + config_file: The configuration file object representing the tsconfig.json file. + config_parser: The parser used to interpret the TypeScript configuration. + config: A dictionary containing the parsed configuration settings. + """ + + config_file: File + config_parser: "TSConfigParser" + config: dict + + # Base config values + _base_config: "TSConfig | None" = None + _base_url: str | None = None + _out_dir: str | None = None + _root_dir: str | None = None + _root_dirs: list[str] = [] + _paths: dict[str, list[str]] = {} + _references: list[tuple[str, Directory | File]] = [] + + # Self config values + _self_base_url: str | None = None + _self_out_dir: str | None = None + _self_root_dir: str | None = None + _self_root_dirs: list[str] = [] + _self_paths: dict[str, list[str]] = {} + _self_references: list[Directory | File] = [] + + # Precomputed import aliases + _computed_path_import_aliases: bool = False + _path_import_aliases: dict[str, list[str]] = {} + _reference_import_aliases: dict[str, list[str]] = {} + # Optimization hack. If all the path alises start with `@` or `~`, then we can skip any path that doesn't start with `@` or `~` + # when computing the import resolution. + _import_optimization_enabled: bool = False + + def __init__(self, config_file: File, config_parser: "TSConfigParser"): + self.config_file = config_file + self.config_parser = config_parser + # Try to parse the config file as JSON5. Fallback to empty dict if it fails. + # We use json5 because it supports comments in the config file. 
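+        # (tsconfig.json files commonly contain // comments and trailing commas, which the
+        # standard json module rejects but JSON5 accepts.)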
+ try: + self.config = pyjson5.loads(config_file.content) + except pyjson5.Json5Exception: + logger.exception(f"Failed to parse tsconfig.json file: {config_file.filepath}") + self.config = {} + + # Precompute the base config, base url, paths, and references + self._precompute_config_values() + + def __repr__(self): + return f"TSConfig({self.config_file.filepath})" + + def _precompute_config_values(self): + """Precomputes the base config, base url, paths, and references.""" + # Precompute the base config + self._base_config = None + extends = self.config.get("extends", None) + if isinstance(extends, list): + # TODO: Support multiple extends + extends = extends[0] # Grab the first config in the list + base_config_path = self._parse_parent_config_path(extends) + + if base_config_path and base_config_path.exists(): + self._base_config = self.config_parser.get_config(base_config_path) + + # Precompute the base url + self._base_url = None + self._self_base_url = None + if base_url := self.config.get("compilerOptions", {}).get("baseUrl", None): + self._base_url = base_url + self._self_base_url = base_url + elif base_url := {} if self.base_config is None else self.base_config.base_url: + self._base_url = base_url + + # Precompute the outDir + self._out_dir = None + self._self_out_dir = None + if out_dir := self.config.get("compilerOptions", {}).get("outDir", None): + self._out_dir = out_dir + self._self_out_dir = out_dir + elif out_dir := {} if self.base_config is None else self.base_config.out_dir: + self._out_dir = out_dir + + # Precompute the rootDir + self._root_dir = None + self._self_root_dir = None + if root_dir := self.config.get("compilerOptions", {}).get("rootDir", None): + self._root_dir = root_dir + self._self_root_dir = root_dir + elif root_dir := {} if self.base_config is None else self.base_config.root_dir: + self._root_dir = root_dir + + # Precompute the rootDirs + self._root_dirs = [] + self._self_root_dirs = [] + if root_dirs := self.config.get("compilerOptions", {}).get("rootDirs", None): + self._root_dirs = root_dirs + self._self_root_dirs = root_dirs + elif root_dirs := [] if self.base_config is None else self.base_config.root_dirs: + self._root_dirs = root_dirs + + # Precompute the paths + base_paths = {} if self.base_config is None else self.base_config.paths + self_paths = self.config.get("compilerOptions", {}).get("paths", {}) + self._paths = {**base_paths, **self_paths} + self._self_paths = self_paths + + # Precompute the references + self_references = [] + references = self.config.get("references", None) + if references is not None: + for reference in references: + if ref_path := reference.get("path", None): + abs_ref_path = str(self.config_file.ctx.to_relative(self._relative_to_absolute_directory_path(ref_path))) + if directory := self.config_file.ctx.get_directory(self.config_file.ctx.to_absolute(abs_ref_path)): + self_references.append((ref_path, directory)) + elif ts_config := self.config_parser.get_config(abs_ref_path): + self_references.append((ref_path, ts_config.config_file)) + elif file := self.config_file.ctx.get_file(abs_ref_path): + self_references.append((ref_path, file)) + self._references = [*self_references] # MAYBE add base references here? This breaks the reference chain though. 
+ self._self_references = self_references + + def _precompute_import_aliases(self): + """Precomputes the import aliases.""" + if self._computed_path_import_aliases: + return + + # Force compute alias of the base config + if self.base_config is not None: + self.base_config._precompute_import_aliases() + + # Precompute the formatted paths based on compilerOptions/paths + base_path_import_aliases = {} if self.base_config is None else self.base_config.path_import_aliases + self_path_import_aliases = {} + for pattern, relative_paths in self._self_paths.items(): + formatted_pattern = pattern.replace("*", "").rstrip("/").replace("//", "/") + formatted_relative_paths = [] + for relative_path in relative_paths: + cleaned_relative_path = relative_path.replace("*", "").rstrip("/").replace("//", "/") + if self._self_base_url: + cleaned_relative_path = os.path.join(self._self_base_url, cleaned_relative_path) + formatted_absolute_path = self._relative_to_absolute_directory_path(cleaned_relative_path) + formatted_relative_path = str(self.config_file.ctx.to_relative(formatted_absolute_path)) + # Fix absolute path if its base + if formatted_relative_path == ".": + formatted_relative_path = "" + formatted_relative_paths.append(formatted_relative_path) + self_path_import_aliases[formatted_pattern] = formatted_relative_paths + self._path_import_aliases = {**base_path_import_aliases, **self_path_import_aliases} + + # Precompute the formatted paths based on references + base_reference_import_aliases = {} if self.base_config is None else self.base_config.reference_import_aliases + self_reference_import_aliases = {} + # For each reference, try to grab its tsconfig. + for ref_path, reference in self._self_references: + # TODO: THIS ENTIRE PROCESS IS KINDA HACKY. + # If the reference is a file, get its directory. + if isinstance(reference, File): + reference_dir = self.config_file.ctx.get_directory(os.path.dirname(reference.filepath)) + elif isinstance(reference, Directory): + reference_dir = reference + else: + logger.warning(f"Unknown reference type during self_reference_import_aliases computation in _precompute_import_aliases: {type(reference)}") + continue + + # With the directory, try to grab the next available file and get its tsconfig. + if reference_dir and reference_dir.files(recursive=True): + next_file: TSFile = reference_dir.files(recursive=True)[0] + else: + logger.warning(f"No next file found for reference during self_reference_import_aliases computation in _precompute_import_aliases: {reference.dirpath}") + continue + target_ts_config = next_file.ts_config + if target_ts_config is None: + logger.warning(f"No tsconfig found for reference during self_reference_import_aliases computation in _precompute_import_aliases: {reference.dirpath}") + continue + + # With the tsconfig, grab its rootDirs and outDir + target_root_dirs = target_ts_config.root_dirs if target_ts_config.root_dirs else ["."] + target_out_dir = target_ts_config.out_dir + + # Calculate the formatted pattern and formatted relative paths + formatted_relative_paths = [os.path.normpath(os.path.join(reference_dir.path, root_dir)) for root_dir in target_root_dirs] + + # Loop through each possible path part of the reference + # For example, if the reference is "../../a/b/c" and the out dir is "dist" + # then the possible reference aliases are: + # - "a/b/c/dist" + # - "b/c/dist" + # - "c/dist" + # (ignoring any .. 
segments) + path_parts = [p for p in ref_path.split(os.path.sep) if p and not p.startswith("..")] + for i in range(len(path_parts)): + target_path = os.path.sep.join(path_parts[i:]) + if target_path: + formatted_target_path = os.path.normpath(os.path.join(target_path, target_out_dir) if target_out_dir else target_path) + self_reference_import_aliases[formatted_target_path] = formatted_relative_paths + + self._reference_import_aliases = {**base_reference_import_aliases, **self_reference_import_aliases} + + # Precompute _import_optimization_enabled + self._import_optimization_enabled = all(k.startswith("@") or k.startswith("~") for k in list(self.path_import_aliases.keys()) + list(self.reference_import_aliases.keys())) + + # Mark that we've precomputed the import aliases + self._computed_path_import_aliases = True + + def _parse_parent_config_path(self, config_filepath: str | None) -> Path | None: + """Returns a TSConfig object from a file path.""" + if config_filepath is None: + return None + + path = self._relative_to_absolute_directory_path(config_filepath) + return Path(path if path.suffix == ".json" else f"{path}.json") + + def _relative_to_absolute_directory_path(self, relative_path: str) -> Path: + """Helper to go from a relative module to an absolute one. + Ex: "../pkg-common/" would be -> "src/dir/pkg-common/" + """ + # TODO: This could also use its parent config to resolve the path + relative = self.config_file.path.parent / relative_path.strip('"') + return self.config_file.ctx.to_absolute(relative) + + def translate_import_path(self, import_path: str) -> str: + """Translates an import path to an absolute path using the tsconfig paths. + + Takes an import path and translates it to an absolute path using the configured paths in the tsconfig file. If the import + path matches a path alias, it will be resolved according to the tsconfig paths mapping. + + For example, converts `@abc/my/pkg/src` to `a/b/c/my/pkg/src` or however it's defined in the tsconfig. + + Args: + import_path (str): The import path to translate. + + Returns: + str: The translated absolute path. If no matching path alias is found, returns the original import path unchanged. 
+ """ + # Break out early if we can + if self._import_optimization_enabled and not import_path.startswith("@") and not import_path.startswith("~"): + return import_path + + # Step 1: Try to resolve with import_resolution_overrides + if self.config_file.ctx.config.import_resolution_overrides: + if path_check := TSConfig._find_matching_path(frozenset(self.config_file.ctx.config.import_resolution_overrides.keys()), import_path): + to_base = self.config_file.ctx.config.import_resolution_overrides[path_check] + + # Get the remaining path after the matching prefix + remaining_path = import_path[len(path_check) :].lstrip("/") + + # Join the path together + import_path = os.path.join(to_base, remaining_path) + + return import_path + + # Step 2: Keep traveling down the parent config paths until we find a match a reference_import_aliases + if path_check := TSConfig._find_matching_path(frozenset(self.reference_import_aliases.keys()), import_path): + # TODO: This assumes that there is only one to_base path for the given from_base path + to_base = self.reference_import_aliases[path_check][0] + + # Get the remaining path after the matching prefix + remaining_path = import_path[len(path_check) :].lstrip("/") + + # Join the path together + import_path = os.path.join(to_base, remaining_path) + + return import_path + + # Step 3: Keep traveling down the parent config paths until we find a match a path_import_aliases + if path_check := TSConfig._find_matching_path(frozenset(self.path_import_aliases.keys()), import_path): + # TODO: This assumes that there is only one to_base path for the given from_base path + to_base = self.path_import_aliases[path_check][0] + + # Get the remaining path after the matching prefix + remaining_path = import_path[len(path_check) :].lstrip("/") + + # Join the path together + import_path = os.path.join(to_base, remaining_path) + + return import_path + + # Step 4: Try to resolve with base path for non-relative imports + return self.resolve_base_url(import_path) + + def translate_absolute_path(self, absolute_path: str) -> str: + """Translates an absolute path to an import path using the tsconfig paths. + + Takes an absolute path and translates it to an import path using the configured paths in the tsconfig file. + + For example, converts `a/b/c/my/pkg/src` to `@abc/my/pkg/src` or however it's defined in the tsconfig. + + Args: + import_path (str): The absolute path to translate. + + Returns: + str: The translated import path. + """ + path_aliases = self._path_import_aliases + for alias, paths in path_aliases.items(): + for path in paths: + if absolute_path.startswith(path): + # Pick the first alias that matches + return absolute_path.replace(path, alias, 1) + + return absolute_path + + def resolve_base_url(self, import_path: str) -> str: + """Resolves an import path with the base url. + + If a base url is not defined, try to resolve it with its base config. 
+ """ + # Do nothing if the import path is relative + if import_path.startswith("."): + return import_path + + # If the current config has a base url, use itq + if self._self_base_url: + if not import_path.startswith(self._self_base_url): + import_path = os.path.join(self._self_base_url, import_path) + import_path = str(self._relative_to_absolute_directory_path(import_path)) + return import_path + # If there is a base config, try to resolve it with its base url + elif self.base_config: + return self.base_config.resolve_base_url(import_path) + # Otherwise, do nothing + else: + return import_path + + @staticmethod + @cache + def _find_matching_path(path_import_aliases: set[str], path_check: str): + """Recursively find the longest matching path in path_import_aliases.""" + # Base case + if not path_check or path_check == "/": + return None + + # Recursive case + if path_check in path_import_aliases: + return path_check + elif f"{path_check}/" in path_import_aliases: + return f"{path_check}/" + else: + return TSConfig._find_matching_path(path_import_aliases, os.path.dirname(path_check)) + + @property + def base_config(self) -> "TSConfig | None": + """Returns the base TSConfig that this config inherits from. + + Gets the base configuration file that this TSConfig extends. The base configuration is used for inheriting settings like paths, baseUrl,and other compiler options. + + Returns: + TSConfig | None: The parent TSConfig object if this config extends another config file, None otherwise. + """ + return self._base_config + + @property + def base_url(self) -> str | None: + """Returns the base URL defined in the TypeScript configuration. + + This property retrieves the baseUrl from the project's TypeScript configuration file. + The baseUrl is used for resolving non-relative module names. + + Returns: + str | None: The base URL if defined in the config file or inherited from a base config, + None if not specified. + """ + return self._base_url + + @property + def out_dir(self) -> str | None: + """Returns the outDir defined in the TypeScript configuration. + + The outDir specifies the output directory for all emitted files. When specified, .js (as well as .d.ts, .js.map, etc.) + files will be emitted into this directory. The directory structure of the source files is preserved. + + Returns: + str | None: The output directory path if specified in the config file or inherited from a base config, + None if not specified. + """ + return self._out_dir + + @property + def root_dir(self) -> str | None: + """Returns the rootDir defined in the TypeScript configuration. + + The rootDir specifies the root directory of input files. This is used to control the output directory structure + with outDir. When TypeScript compiles files, it maintains the directory structure of the source files relative + to rootDir when generating output. + + Returns: + str | None: The root directory path if specified in the config file or inherited from a base config, + None if not specified. + """ + return self._root_dir + + @property + def root_dirs(self) -> list[str]: + """Returns the rootDirs defined in the TypeScript configuration. + + The rootDirs allows a list of root directories to be specified that are merged and treated as one virtual directory. + This can be used when your project structure doesn't match your runtime expectations. For example, when you have + both generated and hand-written source files that need to appear to be in the same directory at runtime. 
+ + Returns: + list[str]: A list of root directory paths specified in the config file or inherited from a base config. + Returns an empty list if not specified. + """ + if self._root_dirs is not None: + return self._root_dirs + elif self.root_dir is not None: + return [self.root_dir] + return [] + + @property + def paths(self) -> dict[str, list[str]]: + """Returns all custom module path mappings defined in the tsconfig file. + + Retrieves path mappings from both the current tsconfig file and any inherited base config file, + translating all relative paths to absolute paths. + + Returns: + dict[str, list[str]]: A dictionary mapping path patterns to lists of absolute path destinations. + Each key is a path pattern (e.g., '@/*') and each value is a list of corresponding + absolute path destinations. + """ + return self._paths + + @property + def references(self) -> list[Directory | File]: + """Returns a list of directories that this TypeScript configuration file depends on. + + The references are defined in the 'references' field of the tsconfig.json file. These directories + are used to resolve import conflicts and narrow the search space for import resolution. + + Returns: + list[Directory | File | TSConfig]: A list of Directory, File, or TSConfig objects representing the dependent directories. + """ + return self._references + + @property + def path_import_aliases(self) -> dict[str, list[str]]: + """Returns a formatted version of the paths property from a TypeScript configuration file. + + Processes the paths dictionary by formatting path patterns and their corresponding target paths. All wildcards (*), trailing slashes, and double + slashes are removed from both the path patterns and their target paths. Target paths are also converted from relative to absolute paths. + + Returns: + dict[str, list[str]]: A dictionary where keys are formatted path patterns and values are lists of formatted absolute target paths. + """ + return self._path_import_aliases + + @property + def reference_import_aliases(self) -> dict[str, list[str]]: + """Returns a formatted version of the references property from a TypeScript configuration file. + + Processes the references dictionary by formatting reference paths and their corresponding target paths. For each + reference, retrieves its tsconfig file and path mappings. Also includes any path mappings inherited from base + configs. + + Returns: + dict[str, list[str]]: A dictionary where keys are formatted reference paths (e.g. 'module/dist') and values + are lists of absolute target paths derived from the referenced tsconfig's rootDirs and outDir settings. 
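+
+        Example (illustrative, not from the original docstring): a reference to "../pkg-a" whose own
+        tsconfig sets `"outDir": "dist"` and `"rootDirs": ["src"]` could produce an entry such as
+        {"pkg-a/dist": ["packages/pkg-a/src"]}, assuming the referenced package lives under "packages/pkg-a".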
+ """ + return {k: [str(self.config_file.ctx.to_relative(v)) for v in vs] for k, vs in self._reference_import_aliases.items()} diff --git a/Libraries/graph_sitter_lib/graph_sitter/typescript/type_alias.py b/Libraries/graph_sitter_lib/graph_sitter/typescript/type_alias.py new file mode 100644 index 00000000..43de5d66 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/typescript/type_alias.py @@ -0,0 +1,73 @@ +from graph_sitter.core.autocommit import commiter, reader +from graph_sitter.core.dataclasses.usage import UsageKind +from graph_sitter.core.interfaces.has_name import HasName +from graph_sitter.core.type_alias import TypeAlias +from graph_sitter.enums import SymbolType +from graph_sitter.shared.decorators.docs import noapidoc, ts_apidoc +from graph_sitter.typescript.detached_symbols.code_block import TSCodeBlock +from graph_sitter.typescript.interfaces.has_block import TSHasBlock +from graph_sitter.typescript.statements.attribute import TSAttribute +from graph_sitter.typescript.symbol import TSSymbol + + +@ts_apidoc +class TSTypeAlias(TypeAlias[TSCodeBlock, TSAttribute], TSSymbol, TSHasBlock): + """Representation of an Interface in TypeScript. + + Attributes: + symbol_type: The type of symbol, set to SymbolType.Type. + """ + + symbol_type = SymbolType.Type + + @noapidoc + @commiter + def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: + dest = dest or self.self_dest + # =====[ Type Identifiers ]===== + # Look for type references in the interface body + self.value._compute_dependencies(UsageKind.TYPE_DEFINITION, dest) + self.code_block._compute_dependencies(UsageKind.TYPE_DEFINITION, dest) + # body = self.ts_node.child_by_field_name("value") + # if body: + # # Handle type queries (typeof) + # type_queries = find_all_descendants(body, ["type_query"]) + # for type_query in type_queries: + # query_identifiers = find_all_descendants(type_query, ["identifier"]) + # self._add_symbol_usages(query_identifiers, SymbolUsageType.TYPE) + # + # type_identifiers = find_all_descendants(body, ["type_identifier"]) + # self._add_symbol_usages(type_identifiers, SymbolUsageType.TYPE) + if self.type_parameters: + self.type_parameters._compute_dependencies(UsageKind.GENERIC, dest) + + @reader + def _parse_code_block(self) -> TSCodeBlock: + """Returns the code block of the function""" + value_node = self.ts_node.child_by_field_name("value") + return super()._parse_code_block(value_node) + + @property + @reader + def attributes(self) -> list[TSAttribute]: + """Retrieves all attributes belonging to this type alias. + + Returns a list of attributes that are defined within the type alias's code block. + These attributes represent named values or properties associated with the type alias. + + Returns: + list[TSAttribute[TSTypeAlias, None]]: A list of TSAttribute objects representing the type alias's attributes. + """ + return self.code_block.attributes + + @reader + def get_attribute(self, name: str) -> TSAttribute | None: + """Retrieves a specific attribute from a TypeScript type alias by its name. + + Args: + name (str): The name of the attribute to retrieve. + + Returns: + TSAttribute[TSTypeAlias, None] | None: The attribute with the specified name if found, None otherwise. 
+ """ + return next((x for x in self.attributes if x.name == name), None) diff --git a/Libraries/graph_sitter_lib/graph_sitter/utils.py b/Libraries/graph_sitter_lib/graph_sitter/utils.py new file mode 100644 index 00000000..de614ced --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/utils.py @@ -0,0 +1,341 @@ +import os +import re +import shutil +import statistics +from collections.abc import Iterable +from contextlib import contextmanager +from xml.dom.minidom import parseString + +import dicttoxml +import xmltodict +from tree_sitter import Node as TSNode + +from graph_sitter.compiled.utils import find_all_descendants, find_first_descendant, get_all_identifiers +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage +from graph_sitter.typescript.enums import TSFunctionTypeNames + +""" +Utility functions for traversing the tree sitter structure. +Do not include language specific traversals, or string manipulations here. +""" + + +class XMLUtils: + @staticmethod + def dict_to_xml(data: dict | list, format: bool = False, **kwargs) -> str: + result = dicttoxml.dicttoxml(data, return_bytes=False, **kwargs) + if not isinstance(result, str): + msg = "Failed to convert dict to XML" + raise ValueError(msg) + if format: + result = parseString(result).toprettyxml() + return result + + @staticmethod + def add_cdata_to_function_body(xml_string): + pattern = r"()(.*?)()" + replacement = r"\1\3" + updated_xml_string = re.sub(pattern, replacement, xml_string, flags=re.DOTALL) + return updated_xml_string + + @staticmethod + def add_cdata_to_tags(xml_string: str, tags: Iterable[str]) -> str: + patterns = [rf"(<{tag}>)(.*?)()" for tag in tags] + updated_xml_string = xml_string + + for pattern in patterns: + replacement = r"\1\3" + updated_xml_string = re.sub(pattern, replacement, updated_xml_string, flags=re.DOTALL) + + return updated_xml_string + + @staticmethod + def xml_to_dict(xml_string: str, **kwargs) -> dict: + return xmltodict.parse(XMLUtils.add_cdata_to_tags(xml_string, ["function_body", "reasoning"]), **kwargs) + + @staticmethod + def strip_after_tag(xml_string, tag): + pattern = re.compile(f"<{tag}.*?>.*", re.DOTALL) + match = pattern.search(xml_string) + if match: + return xml_string[: match.start()] + else: + return xml_string + + @staticmethod + def strip_tag(xml_string: str, tag: str): + pattern = re.compile(f"<{tag}>.*?", re.DOTALL) + return pattern.sub("", xml_string).strip() + + @staticmethod + def strip_all_tags(xml_string: str): + pattern = re.compile(r"<[^>]*>") + return pattern.sub("", xml_string).strip() + + @staticmethod + def extract_elements(xml_string: str, tag: str, keep_tag: bool = False) -> list[str]: + pattern = re.compile(f"<{tag}.*?", re.DOTALL) + matches = pattern.findall(xml_string) + if keep_tag: + return matches + else: + return [match.strip(f"<{tag}>").strip(f"") for match in matches] + + +def find_first_function_descendant(node: TSNode) -> TSNode: + type_names = [function_type.value for function_type in TSFunctionTypeNames] + return find_first_descendant(node=node, type_names=type_names, max_depth=2) + + +def find_import_node(node: TSNode) -> TSNode | None: + """Get the import node from a node that may contain an import. + Returns None if the node does not contain an import. + + Returns: + TSNode | None: The import_statement or call_expression node if it's an import, None otherwise + """ + # Static imports + if node.type == "import_statement": + return node + + # Dynamic imports and requires can be either: + # 1. 
Inside expression_statement -> call_expression + # 2. Direct call_expression + + # we only parse imports inside expressions and variable declarations + + if member_expression := find_first_descendant(node, ["member_expression"]): + # there may be multiple call expressions (for cases such as import(a).then(module => module).then(module => module) + descendants = find_all_descendants(member_expression, ["call_expression"], stop_at_first="statement_block") + if descendants: + import_node = descendants[-1] + else: + # this means this is NOT a dynamic import() + return None + else: + import_node = find_first_descendant(node, ["call_expression"]) + + # thus we only consider the deepest one + if import_node: + function = import_node.child_by_field_name("function") + if function and (function.type == "import" or (function.type == "identifier" and function.text.decode("utf-8") == "require")): + return import_node + + return None + + +def find_index(target: TSNode, siblings: list[TSNode]) -> int: + """Returns the index of the target node in the list of siblings, or -1 if not found. Recursive implementation.""" + if target in siblings: + return siblings.index(target) + + for i, sibling in enumerate(siblings): + index = find_index(target, sibling.named_children if target.is_named else sibling.children) + if index != -1: + return i + return -1 + + +def find_first_ancestor(node: TSNode, type_names: list[str], max_depth: int | None = None) -> TSNode | None: + depth = 0 + while node is not None and (max_depth is None or depth <= max_depth): + if node.type in type_names: + return node + node = node.parent + depth += 1 + return None + + +def find_first_child_by_field_name(node: TSNode, field_name: str) -> TSNode | None: + child = node.child_by_field_name(field_name) + if child is not None: + return child + for child in node.children: + first_descendant = find_first_child_by_field_name(child, field_name) + if first_descendant is not None: + return first_descendant + return None + + +def has_descendant(node: TSNode, type_name: str) -> bool: + def traverse(current_node: TSNode, depth: int = 0) -> bool: + if current_node.type == type_name: + return True + return any(traverse(child, depth + 1) for child in current_node.children) + + return traverse(node) + + +def get_first_identifier(node: TSNode) -> TSNode | None: + """Get the text of the first identifier child of a tree-sitter node. 
Recursive implementation""" + if node.type in ("identifier", "shorthand_property_identifier_pattern"): + return node + for child in node.children: + output = get_first_identifier(child) + if output is not None: + return output + return None + + +def descendant_for_byte_range(node: TSNode, start_byte: int, end_byte: int, allow_comment_boundaries: bool = True) -> TSNode | None: + """Proper implementation of descendant_for_byte_range, which returns the lowest node that contains the byte range.""" + ts_match = node.descendant_for_byte_range(start_byte, end_byte) + + # We don't care if the match overlaps with comments + if allow_comment_boundaries: + return ts_match + + # Want to prevent it from matching with part of the match within a comment + else: + if not ts_match.children: + return ts_match + comments = find_all_descendants(ts_match, "comment") + # see if any of these comments partially overlaps with the match + if any(comment.start_byte < start_byte < comment.end_byte or comment.start_byte < end_byte < comment.end_byte for comment in comments): + return None + return ts_match + + +@contextmanager +def shadow_files(files: str | list[str]): + """Creates shadow copies of the given files. Restores the original files after the context manager is exited. + + Returns list of filenames of shadowed files. + """ + if isinstance(files, str): + files = [files] + shadowed_files = {} + # Generate shadow file names + for file_name in files: + shadow_file_name = file_name + ".gs_internal.bak" + shadowed_files[file_name] = shadow_file_name + # Shadow files + try: + # Backup the original files + for file_name, shadow_file_name in shadowed_files.items(): + shutil.copy(file_name, shadow_file_name) + yield shadowed_files.values() + finally: + # Restore the original files + for file_name, shadow_file_name in shadowed_files.items(): + # If shadow file was created, restore the original file and delete the shadow file + if os.path.exists(shadow_file_name): + # Delete the original file if it exists + if os.path.exists(file_name): + os.remove(file_name) + # Copy the shadow file to the original file path + shutil.copy(shadow_file_name, file_name) + # Delete the shadow file + os.remove(shadow_file_name) + + +def calculate_base_path(full_path, relative_path): + """Calculate the base path represented by './' in a relative path. 
+ + :param full_path: The full path to a file or directory + :param relative_path: A relative path starting with './' + :return: The base path represented by './' in the relative path + """ + # Normalize paths to handle different path separators + full_path = os.path.normpath(full_path) + relative_path = os.path.normpath(relative_path) + + # Split paths into components + full_components = full_path.split(os.sep) + relative_components = relative_path.split(os.sep) + + # Remove './' from the start of relative path if present + if relative_components[0] == ".": + relative_components = relative_components[1:] + + # Calculate the number of components to keep from the full path + keep_components = len(full_components) - len(relative_components) + + # Join the components to form the base path + base_path = os.sep.join(full_components[:keep_components]) + + return base_path + + +__all__ = [ + "find_all_descendants", + "find_first_ancestor", + "find_first_child_by_field_name", + "find_first_descendant", + "get_all_identifiers", + "has_descendant", +] + + +def get_language_file_extensions(language: ProgrammingLanguage): + """Returns the file extensions for the given language.""" + from graph_sitter.python import PyFile + from graph_sitter.typescript.file import TSFile + + if language == ProgrammingLanguage.PYTHON: + return set(PyFile.get_extensions()) + elif language == ProgrammingLanguage.TYPESCRIPT: + return set(TSFile.get_extensions()) + + +def truncate_line(input: str, max_chars: int) -> str: + input = str(input) + if len(input) > max_chars: + return input[:max_chars] + f"...(truncated from {len(input)} characters)." + return input + + +def is_minified_js(content): + """Analyzes a string to determine if it contains minified JavaScript code. + + Args: + content: String containing JavaScript code to analyze + + Returns: + bool: True if the content appears to be minified JavaScript, False otherwise + """ + try: + # Skip empty content + if not content.strip(): + return False + + # Characteristics of minified JS files + lines = content.split("\n") + + # 1. Check for average line length (minified files have very long lines) + line_lengths = [len(line) for line in lines if line.strip()] + if not line_lengths: # Handle empty content case + return False + + avg_line_length = statistics.mean(line_lengths) + + # 2. Check for semicolon-to-newline ratio (minified often has ; instead of newlines) + semicolons = content.count(";") + newlines = len(lines) - 1 + semicolon_ratio = semicolons / max(newlines, 1) # Avoid division by zero + + # 3. Check whitespace ratio (minified has low whitespace) + whitespace_chars = len(re.findall(r"[\s]", content)) + total_chars = len(content) + whitespace_ratio = whitespace_chars / total_chars if total_chars else 0 + + # 4. Check for common minification patterns + has_common_patterns = bool(re.search(r"[\w\)]\{[\w:]+\}", content)) # Condensed object notation + + # 5. 
Check for short variable names (common in minified code) + variable_names = re.findall(r"var\s+(\w+)", content) + avg_var_length = statistics.mean([len(name) for name in variable_names]) if variable_names else 0 + + # Decision logic - tuned threshold values + is_minified = ( + (avg_line_length > 250) # Very long average line length + and (semicolon_ratio > 0.8 or has_common_patterns) # High semicolon ratio or minification patterns + and (whitespace_ratio < 0.08) # Very low whitespace ratio + and (avg_var_length < 3 or not variable_names) # Extremely short variable names or no vars + ) + + return is_minified + + except Exception as e: + print(f"Error analyzing content: {e}") + return False diff --git a/Libraries/graph_sitter_lib/graph_sitter/visualizations/blast_radius.py b/Libraries/graph_sitter_lib/graph_sitter/visualizations/blast_radius.py new file mode 100644 index 00000000..98aedf40 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/visualizations/blast_radius.py @@ -0,0 +1,118 @@ +import graph_sitter +import networkx as nx +from graph_sitter import Codebase +from graph_sitter.core.dataclasses.usage import Usage +from graph_sitter.python.function import PyFunction +from graph_sitter.python.symbol import PySymbol + +# Create a directed graph for visualizing relationships between code elements +G = nx.DiGraph() + +# Maximum depth to traverse in the call graph to prevent infinite recursion +MAX_DEPTH = 5 + +# Define colors for different types of nodes in the visualization +COLOR_PALETTE = { + "StartFunction": "#9cdcfe", # Starting function (light blue) + "PyFunction": "#a277ff", # Python functions (purple) + "PyClass": "#ffca85", # Python classes (orange) + "ExternalModule": "#f694ff", # External module imports (pink) + "HTTP_METHOD": "#ffca85", # HTTP method handlers (orange) +} + +# List of common HTTP method names to identify route handlers +HTTP_METHODS = ["get", "put", "patch", "post", "head", "delete"] + + +def generate_edge_meta(usage: Usage) -> dict: + """ + Generate metadata for graph edges based on a usage relationship. + + Args: + usage: A Usage object representing how a symbol is used + + Returns: + dict: Edge metadata including source location and symbol info + """ + return {"name": usage.match.source, "file_path": usage.match.filepath, "start_point": usage.match.start_point, "end_point": usage.match.end_point, "symbol_name": usage.match.__class__.__name__} + + +def is_http_method(symbol: PySymbol) -> bool: + """ + Check if a symbol represents an HTTP method handler. + + Args: + symbol: A Python symbol to check + + Returns: + bool: True if symbol is an HTTP method handler + """ + if isinstance(symbol, PyFunction) and symbol.is_method: + return symbol.name in HTTP_METHODS + return False + + +def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): + """ + Recursively build a graph visualization showing how a symbol is used. + Shows the "blast radius" - everything that would be affected by changes. 
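+
+    Illustrative driver (mirrors the `run` entrypoint below; "export_asset" is
+    simply the example function used there):
+
+        target_func = codebase.get_function("export_asset")
+        G.add_node(target_func, color=COLOR_PALETTE.get("StartFunction"))
+        create_blast_radius_visualization(target_func)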
+ + Args: + symbol: Starting symbol to analyze + depth: Current recursion depth + """ + # Stop recursion if we hit max depth + if depth >= MAX_DEPTH: + return + + # Process each usage of the symbol + for usage in symbol.usages: + usage_symbol = usage.usage_symbol + + # Determine node color based on symbol type + if is_http_method(usage_symbol): + color = COLOR_PALETTE.get("HTTP_METHOD") + else: + color = COLOR_PALETTE.get(usage_symbol.__class__.__name__, "#f694ff") + + # Add node and edge to graph + G.add_node(usage_symbol, color=color) + G.add_edge(symbol, usage_symbol, **generate_edge_meta(usage)) + + # Recurse to process usages of this symbol + create_blast_radius_visualization(usage_symbol, depth + 1) + + +@graph_sitter.function("visualize-function-blast-radius") +def run(codebase: Codebase): + """ + Generate a visualization showing the blast radius of changes to a function. + + This codemod: + 1. Identifies all usages of a target function + 2. Creates a graph showing how the function is used throughout the codebase + 3. Highlights HTTP method handlers and different types of code elements + """ + global G + G = nx.DiGraph() + + # Get the target function to analyze + target_func = codebase.get_function("export_asset") + + # Add starting function to graph with special color + G.add_node(target_func, color=COLOR_PALETTE.get("StartFunction")) + + # Build the visualization starting from target function + create_blast_radius_visualization(target_func) + + print(G) + print("Use codegen.sh to visualize the graph!") + + +if __name__ == "__main__": + print("Initializing codebase...") + codebase = Codebase.from_repo("codegen-oss/posthog", commit="b174f2221ea4ae50e715eb6a7e70e9a2b0760800", language="python") + print(f"Codebase with {len(codebase.files)} files and {len(codebase.functions)} functions.") + print("Creating graph...") + + run(codebase) diff --git a/Libraries/graph_sitter_lib/graph_sitter/visualizations/call_trace.py b/Libraries/graph_sitter_lib/graph_sitter/visualizations/call_trace.py new file mode 100644 index 00000000..3a0a7439 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/visualizations/call_trace.py @@ -0,0 +1,120 @@ +import graph_sitter +import networkx as nx +from graph_sitter import Codebase +from graph_sitter.core.class_definition import Class +from graph_sitter.core.detached_symbols.function_call import FunctionCall +from graph_sitter.core.external_module import ExternalModule +from graph_sitter.core.function import Function + +G = nx.DiGraph() + +IGNORE_EXTERNAL_MODULE_CALLS = True +IGNORE_CLASS_CALLS = False +MAX_DEPTH = 10 + +# Color scheme for different types of nodes in the visualization +# Each node type has a distinct color for better visual differentiation +COLOR_PALETTE = { + "StartFunction": "#9cdcfe", # Base purple - draws attention to the root node + "PyFunction": "#a277ff", # Mint green - complementary to purple + "PyClass": "#ffca85", # Warm peach - provides contrast + "ExternalModule": "#f694ff", # Light pink - analogous to base purple +} + + +def generate_edge_meta(call: FunctionCall) -> dict: + """Generate metadata for graph edges representing function calls + + Args: + call (FunctionCall): Object containing information about the function call + + Returns: + dict: Metadata including name, file path, and location information + """ + return {"name": call.name, "file_path": call.filepath, "start_point": call.start_point, "end_point": call.end_point, "symbol_name": "FunctionCall"} + + +def create_downstream_call_trace(src_func: Function, depth: int 
= 0): + """Creates call graph for parent function by recursively traversing all function calls + + This function builds a directed graph showing all downstream function calls, + up to MAX_DEPTH levels deep. Each node represents a function and edges + represent calls between functions. + + Args: + src_func (Function): The function for which a call graph will be created + depth (int): Current depth in the recursive traversal + """ + # Stop recursion if max depth reached + if MAX_DEPTH <= depth: + return + # Stop if the source is an external module + if isinstance(src_func, ExternalModule): + return + + # Examine each function call made by the source function + for call in src_func.function_calls: + # Skip recursive calls + if call.name == src_func.name: + continue + + # Get the function definition being called + func = call.function_definition + + # Skip if function definition not found + if not func: + continue + # Apply filtering based on configuration flags + if isinstance(func, ExternalModule) and IGNORE_EXTERNAL_MODULE_CALLS: + continue + if isinstance(func, Class) and IGNORE_CLASS_CALLS: + continue + + # Generate the display name for the function + # For methods, include the class name + if isinstance(func, (Class, ExternalModule)): + func_name = func.name + elif isinstance(func, Function): + func_name = f"{func.parent_class.name}.{func.name}" if func.is_method else func.name + + # Add node and edge to the graph with appropriate metadata + G.add_node(func, name=func_name, color=COLOR_PALETTE.get(func.__class__.__name__)) + G.add_edge(src_func, func, **generate_edge_meta(call)) + + # Recursively process called function if it's a regular function + if isinstance(func, Function): + create_downstream_call_trace(func, depth + 1) + + +@graph_sitter.function("visualize-function-call-relationships") +def run(codebase: Codebase): + """Generate a visualization of function call relationships in a codebase. + + This codemod: + 1. Creates a directed graph of function calls starting from a target method + 2. Tracks relationships between functions, classes, and external modules + 3. 
Generates a visual representation of the call hierarchy + """ + global G + G = nx.DiGraph() + + target_class = codebase.get_class("SharingConfigurationViewSet") + target_method = target_class.get_method("patch") + + # Generate the call graph starting from the target method + create_downstream_call_trace(target_method) + + # Add the root node (target method) to the graph + G.add_node(target_method, name=f"{target_class.name}.{target_method.name}", color=COLOR_PALETTE.get("StartFunction")) + + print(G) + print("Use codegen.sh to visualize the graph!") + + +if __name__ == "__main__": + print("Initializing codebase...") + codebase = Codebase.from_repo("codegen-oss/posthog", commit="b174f2221ea4ae50e715eb6a7e70e9a2b0760800", language="python") + print(f"Codebase with {len(codebase.files)} files and {len(codebase.functions)} functions.") + print("Creating graph...") + + run(codebase) diff --git a/Libraries/graph_sitter_lib/graph_sitter/visualizations/dependency_trace.py b/Libraries/graph_sitter_lib/graph_sitter/visualizations/dependency_trace.py new file mode 100644 index 00000000..b145098d --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/visualizations/dependency_trace.py @@ -0,0 +1,82 @@ +import graph_sitter +import networkx as nx +from graph_sitter import Codebase +from graph_sitter.core.class_definition import Class +from graph_sitter.core.import_resolution import Import +from graph_sitter.core.symbol import Symbol + +G = nx.DiGraph() + +IGNORE_EXTERNAL_MODULE_CALLS = True +IGNORE_CLASS_CALLS = False +MAX_DEPTH = 10 + +COLOR_PALETTE = { + "StartFunction": "#9cdcfe", # Light blue for the starting function + "PyFunction": "#a277ff", # Purple for Python functions + "PyClass": "#ffca85", # Orange for Python classes + "ExternalModule": "#f694ff", # Pink for external module references +} + +# Dictionary to track visited nodes and prevent cycles +visited = {} + + +def create_dependencies_visualization(symbol: Symbol, depth: int = 0): + """Creates a visualization of symbol dependencies in the codebase + + Recursively traverses the dependency tree of a symbol (function, class, etc.) + and creates a directed graph representation. Dependencies can be either direct + symbol references or imports. + + Args: + symbol (Symbol): The starting symbol whose dependencies will be mapped + depth (int): Current depth in the recursive traversal + """ + if depth >= MAX_DEPTH: + return + + for dep in symbol.dependencies: + dep_symbol = None + + if isinstance(dep, Symbol): + dep_symbol = dep + elif isinstance(dep, Import): + dep_symbol = dep.resolved_symbol if dep.resolved_symbol else None + + if dep_symbol: + G.add_node(dep_symbol, color=COLOR_PALETTE.get(dep_symbol.__class__.__name__, "#f694ff")) + G.add_edge(symbol, dep_symbol) + + if not isinstance(dep_symbol, Class): + create_dependencies_visualization(dep_symbol, depth + 1) + + +@graph_sitter.function("visualize-symbol-dependencies") +def run(codebase: Codebase): + """Generate a visualization of symbol dependencies in a codebase. + + This codemod: + 1. Creates a directed graph of symbol dependencies starting from a target function + 2. Tracks relationships between functions, classes, and imports + 3. 
Generates a visual representation of the dependency hierarchy + """ + global G + G = nx.DiGraph() + + target_func = codebase.get_function("get_query_runner") + G.add_node(target_func, color=COLOR_PALETTE.get("StartFunction")) + + create_dependencies_visualization(target_func) + + print(G) + print("Use codegen.sh to visualize the graph!") + + +if __name__ == "__main__": + print("Initializing codebase...") + codebase = Codebase.from_repo("codegen-oss/posthog", commit="b174f2221ea4ae50e715eb6a7e70e9a2b0760800", language="python") + print(f"Codebase with {len(codebase.files)} files and {len(codebase.functions)} functions.") + print("Creating graph...") + + run(codebase) diff --git a/Libraries/graph_sitter_lib/graph_sitter/visualizations/enums.py b/Libraries/graph_sitter_lib/graph_sitter/visualizations/enums.py new file mode 100644 index 00000000..fda63303 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/visualizations/enums.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from enum import StrEnum + + +@dataclass(frozen=True) +class VizNode: + name: str | None = None + text: str | None = None + code: str | None = None + color: str | None = None + shape: str | None = None + start_point: tuple | None = None + emoji: str | None = None + end_point: tuple | None = None + file_path: str | None = None + symbol_name: str | None = None + + +@dataclass(frozen=True) +class GraphJson: + type: str + data: dict + + +class GraphType(StrEnum): + TREE = "tree" + GRAPH = "graph" diff --git a/Libraries/graph_sitter_lib/graph_sitter/visualizations/method_relationships.py b/Libraries/graph_sitter_lib/graph_sitter/visualizations/method_relationships.py new file mode 100644 index 00000000..76df2439 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/visualizations/method_relationships.py @@ -0,0 +1,106 @@ +import graph_sitter +import networkx as nx +from graph_sitter import Codebase +from graph_sitter.core.class_definition import Class +from graph_sitter.core.detached_symbols.function_call import FunctionCall +from graph_sitter.core.external_module import ExternalModule +from graph_sitter.core.function import Function + +G = nx.DiGraph() + +# Configuration Settings +IGNORE_EXTERNAL_MODULE_CALLS = False +IGNORE_CLASS_CALLS = True +MAX_DEPTH = 100 + +# Track visited nodes to prevent duplicate processing +visited = set() + +COLOR_PALETTE = { + "StartMethod": "#9cdcfe", # Light blue for root/entry point methods + "PyFunction": "#a277ff", # Purple for regular Python functions + "PyClass": "#ffca85", # Warm peach for class definitions + "ExternalModule": "#f694ff", # Pink for external module calls + "StartClass": "#FFE082", # Yellow for the starting class +} + + +def graph_class_methods(target_class: Class): + """Creates a graph visualization of all methods in a class and their call relationships""" + G.add_node(target_class, color=COLOR_PALETTE["StartClass"]) + + for method in target_class.methods: + method_name = f"{target_class.name}.{method.name}" + G.add_node(method, name=method_name, color=COLOR_PALETTE["StartMethod"]) + visited.add(method) + G.add_edge(target_class, method) + + for method in target_class.methods: + create_downstream_call_trace(method) + + +def generate_edge_meta(call: FunctionCall) -> dict: + """Generate metadata for graph edges representing function calls""" + return {"name": call.name, "file_path": call.filepath, "start_point": call.start_point, "end_point": call.end_point, "symbol_name": "FunctionCall"} + + +def create_downstream_call_trace(src_func: Function, depth: int 
= 0): + """Creates call graph for parent function by recursively traversing all function calls""" + if MAX_DEPTH <= depth or isinstance(src_func, ExternalModule): + return + + for call in src_func.function_calls: + if call.name == src_func.name: + continue + + func = call.function_definition + if not func: + continue + + if isinstance(func, ExternalModule) and IGNORE_EXTERNAL_MODULE_CALLS: + continue + if isinstance(func, Class) and IGNORE_CLASS_CALLS: + continue + + if isinstance(func, (Class, ExternalModule)): + func_name = func.name + elif isinstance(func, Function): + func_name = f"{func.parent_class.name}.{func.name}" if func.is_method else func.name + + if func not in visited: + G.add_node(func, name=func_name, color=COLOR_PALETTE.get(func.__class__.__name__, None)) + visited.add(func) + + G.add_edge(src_func, func, **generate_edge_meta(call)) + + if isinstance(func, Function): + create_downstream_call_trace(func, depth + 1) + + +@graph_sitter.function("visualize-class-method-relationships") +def run(codebase: Codebase): + """Generate a visualization of method call relationships within a class. + + This codemod: + 1. Creates a directed graph with the target class as the root node + 2. Adds all class methods and their downstream function calls + 3. Generates a visual representation of the call hierarchy + """ + global G, visited + G = nx.DiGraph() + visited = set() + + target_class = codebase.get_class("_Client") + graph_class_methods(target_class) + + print(G) + print("Use codegen.sh to visualize the graph!") + + +if __name__ == "__main__": + print("Initializing codebase...") + codebase = Codebase.from_repo("codegen-oss/modal-client", commit="00bf226a1526f9d775d2d70fc7711406aaf42958", language="python") + print(f"Codebase with {len(codebase.files)} files and {len(codebase.functions)} functions.") + print("Creating graph...") + + run(codebase) diff --git a/Libraries/graph_sitter_lib/graph_sitter/visualizations/py.typed b/Libraries/graph_sitter_lib/graph_sitter/visualizations/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/graph_sitter_lib/graph_sitter/visualizations/visualization_manager.py b/Libraries/graph_sitter_lib/graph_sitter/visualizations/visualization_manager.py new file mode 100644 index 00000000..2b9fb55f --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/visualizations/visualization_manager.py @@ -0,0 +1,63 @@ +import os + +import plotly.graph_objects as go +from networkx import Graph + +from graph_sitter.core.interfaces.editable import Editable +from graph_sitter.git.repo_operator.repo_operator import RepoOperator +from graph_sitter.shared.logging.get_logger import get_logger +from graph_sitter.visualizations.viz_utils import graph_to_json + +logger = get_logger(__name__) + + +class VisualizationManager: + op: RepoOperator + + def __init__( + self, + op: RepoOperator, + ) -> None: + self.op = op + + @property + def viz_path(self) -> str: + return os.path.join(self.op.base_dir, "codegen-graphviz") + + @property + def viz_file_path(self) -> str: + return os.path.join(self.viz_path, "graph.json") + + def clear_graphviz_data(self) -> None: + if self.op.folder_exists(self.viz_path): + self.op.emptydir(self.viz_path) + + def write_graphviz_data(self, G: Graph | go.Figure, root: Editable | str | int | None = None) -> None: + """Writes the graph data to a file. + + Args: + ---- + G (Graph | go.Figure): A NetworkX Graph object representing the graph to be visualized. + root (str | None): The root node to visualize. Defaults to None. 
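+
+        Example (sketch; assumes an already-constructed RepoOperator and
+        `import networkx as nx`):
+
+            manager = VisualizationManager(op=repo_operator)
+            g = nx.DiGraph()
+            g.add_edge("caller", "callee")
+            manager.write_graphviz_data(g, root="caller")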
+ + Returns: + ------ + None + """ + # Convert the graph to a JSON-serializable format + if isinstance(G, Graph): + graph_json = graph_to_json(G, root) + elif isinstance(G, go.Figure): + graph_json = G.to_json() + + # Check if the visualization path exists, if so, empty it + if self.op.folder_exists(self.viz_path): + self.op.emptydir(self.viz_path) + else: + # If the path doesn't exist, create it + self.op.mkdir(self.viz_path) + + # Write the graph data to a file + with open(self.viz_file_path, "w") as f: + f.write(graph_json) + f.flush() # Ensure data is written to disk diff --git a/Libraries/graph_sitter_lib/graph_sitter/visualizations/viz_utils.py b/Libraries/graph_sitter_lib/graph_sitter/visualizations/viz_utils.py new file mode 100644 index 00000000..52af6f30 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter/visualizations/viz_utils.py @@ -0,0 +1,70 @@ +import json +import os +from dataclasses import asdict +from typing import TYPE_CHECKING + +import networkx as nx +from networkx import DiGraph, Graph + +from graph_sitter.core.interfaces.editable import Editable +from graph_sitter.core.interfaces.importable import Importable +from graph_sitter.output.utils import DeterministicJSONEncoder +from graph_sitter.visualizations.enums import GraphJson, GraphType + +if TYPE_CHECKING: + from graph_sitter.git.repo_operator.repo_operator import RepoOperator + +#################################################################################################################### +# READING GRAPH VISUALIZATION DATA +#################################################################################################################### + + +def get_graph_json(op: "RepoOperator"): + if os.path.exists(op.viz_file_path): + with open(op.viz_file_path) as f: + graph_json = json.load(f) + return graph_json + else: + return None + + +#################################################################################################################### +# NETWORKX GRAPH TO JSON +#################################################################################################################### + + +def get_node_options(node: Editable | str | int): + if isinstance(node, Editable): + return asdict(node.viz) + return {} + + +def get_node_id(node: Editable | str | int): + if isinstance(node, Importable): + return node.node_id + elif isinstance(node, Editable): + return str(node.span) + elif isinstance(node, str) or isinstance(node, int): + return node + + +def graph_to_json(G1: Graph, root: Editable | str | int | None = None): + G2 = DiGraph() + for node_tuple in G1.nodes(data=True): + options = get_node_options(node_tuple[0]) + options.update(node_tuple[1]) + G2.add_node(get_node_id(node_tuple[0]), **options) + + for edge_tuple in G1.edges(data=True): + options = edge_tuple[2] + if "symbol" in options: + print(get_node_options(options["symbol"])) + options.update(get_node_options(options["symbol"])) + del options["symbol"] + G2.add_edge(get_node_id(edge_tuple[0]), get_node_id(edge_tuple[1]), **options) + + if root: + root = get_node_id(root) + return json.dumps(asdict(GraphJson(type=GraphType.TREE.value, data=nx.tree_data(G2, root))), cls=DeterministicJSONEncoder, indent=2) + else: + return json.dumps(asdict(GraphJson(type=GraphType.GRAPH.value, data=nx.node_link_data(G2))), cls=DeterministicJSONEncoder, indent=2) diff --git a/Libraries/graph_sitter_lib/graph_sitter/writer_decorators.py b/Libraries/graph_sitter_lib/graph_sitter/writer_decorators.py new file mode 100644 index 00000000..7c10de0d --- /dev/null 
+++ b/Libraries/graph_sitter_lib/graph_sitter/writer_decorators.py @@ -0,0 +1,10 @@ +from graph_sitter.shared.enums.programming_language import ProgrammingLanguage + + +def canonical(codemod): + """Decorator for canonical Codemods that will be used for AI-agent prompts.""" + codemod._canonical = True + if not hasattr(codemod, "language") or codemod.language not in (ProgrammingLanguage.PYTHON, ProgrammingLanguage.TYPESCRIPT): + msg = "Canonical codemods must have a `language` attribute (PYTHON or TYPESCRIPT)." + raise AttributeError(msg) + return codemod diff --git a/Libraries/graph_sitter_lib/graph_sitter_analysis.py b/Libraries/graph_sitter_lib/graph_sitter_analysis.py new file mode 100644 index 00000000..1b0746f8 --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter_analysis.py @@ -0,0 +1,1676 @@ +#!/usr/bin/env python3 +""" +Comprehensive Graph-Sitter Analysis Module +Integrates all graph-sitter folder functionalities for complete codebase analysis +""" + +import os +import logging +import math +import re +from typing import Dict, List, Optional, Any, Tuple, Union +from pathlib import Path +from collections import defaultdict, Counter + +# Import all graph-sitter modules +from graph_sitter import Codebase +from graph_sitter.core.symbol import Symbol +from graph_sitter.core.function import Function +from graph_sitter.core.class_definition import Class +from graph_sitter.core.file import SourceFile +from graph_sitter.core.import_resolution import Import +from graph_sitter.core.external_module import ExternalModule + +# Import all analysis functions from graph-sitter modules +from graph_sitter.extensions.tools.codebase_analysis import ( + get_codebase_summary, + get_file_summary, + get_class_summary, + get_function_summary, + get_symbol_summary +) + +# Import visualization and analysis tools +from graph_sitter.extensions.tools.view_file import ( + ViewFileObservation, + add_line_numbers, + view_file +) + +from graph_sitter.extensions.tools.reveal_symbol import ( + SymbolInfo, + RevealSymbolObservation, + get_symbol_info, + truncate_source, + get_extended_context, + reveal_symbol, + hop_through_imports +) + +from graph_sitter.extensions.tools.list_directory import ( + DirectoryInfo, + ListDirectoryObservation, + list_directory +) + +from graph_sitter.extensions.tools.bash import ( + RunBashCommandObservation, + validate_command, + run_bash_command +) + +from graph_sitter.extensions.tools.reflection import ( + ReflectionSection, + ReflectionObservation, + parse_reflection_response, + perform_reflection +) + +from graph_sitter.extensions.tools.observation import Observation + +from graph_sitter.extensions.tools.tools import get_workspace_tools + +from graph_sitter.extensions.tools.tool_output_types import ( + EditFileArtifacts, + ViewFileArtifacts, + ListDirectoryArtifacts, + SearchMatch, + SearchArtifacts, + SemanticEditArtifacts, + RelaceEditArtifacts +) + +# Import documentation generation +from graph_sitter.extensions.tools.generate_docs_json import generate_docs_json +from graph_sitter.extensions.tools.mdx_docs_generation import ( + render_mdx_page_for_class, + render_mdx_page_title, + render_mdx_inheritence_section, + render_mdx_attributes_section, + render_mdx_methods_section, + render_mdx_for_attribute, + format_parameter_for_mdx, + format_parameters_for_mdx, + format_return_for_mdx, + render_mdx_for_method, + get_mdx_route_for_class, + format_type_string, + resolve_type_string, + format_builtin_type_string, + span_type_string_by_pipe, + parse_link +) + +# Import codebase 
utilities +from graph_sitter.extensions.tools.current_code_codebase import ( + get_current_code_codebase, + get_codegen_codebase_base_path, + get_graphsitter_repo_path, + import_all_codegen_sdk_modules, + get_documented_objects +) + +from graph_sitter.extensions.tools.codegen_sdk_codebase import ( + get_codegen_sdk_codebase, + get_codegen_sdk_subdirectories +) + +# Import document functions +from graph_sitter.extensions.tools.document_functions import ( + run as document_functions_run, + get_extended_context as doc_get_extended_context, + hop_through_imports as doc_hop_through_imports +) + +# Import visualization modules +from graph_sitter.extensions.tools.blast_radius import ( + generate_edge_meta as blast_generate_edge_meta, + is_http_method, + create_blast_radius_visualization, + run as blast_radius_run +) + +from graph_sitter.extensions.tools.call_trace import ( + generate_edge_meta as call_generate_edge_meta, + create_downstream_call_trace, + run as call_trace_run +) + +from graph_sitter.extensions.tools.dependency_trace import ( + create_dependencies_visualization, + run as dependency_trace_run +) + +from graph_sitter.extensions.tools.method_relationships import ( + generate_edge_meta as method_generate_edge_meta, + graph_class_methods, + create_downstream_call_trace as method_create_downstream_call_trace, + run as method_relationships_run +) + +import networkx as nx + +logger = logging.getLogger(__name__) + + +class GraphSitterAnalyzer: + """ + Comprehensive analysis engine using all Graph-Sitter capabilities. + Provides unified access to all graph-sitter folder functionalities. + """ + + def __init__(self, codebase: Codebase): + self.codebase = codebase + self.analysis_cache = {} + self.visualization_cache = {} + + # ============================================================================ + # CORE ANALYSIS FUNCTIONS (from codebase_analysis.py) + # ============================================================================ + + def get_codebase_overview(self) -> Dict[str, Any]: + """Provides a high-level overview of the codebase structure.""" + if "codebase_overview" in self.analysis_cache: + return self.analysis_cache["codebase_overview"] + + summary_str = get_codebase_summary(self.codebase) + + # Parse summary into structured data + overview = { + "summary": summary_str, + "files_count": len(list(self.codebase.files)), + "functions_count": len(list(self.codebase.functions)), + "classes_count": len(list(self.codebase.classes)), + "symbols_count": len(list(self.codebase.symbols)), + "imports_count": len(list(self.codebase.imports)), + "external_modules_count": len(list(self.codebase.external_modules)), + "entrypoints": self._identify_entrypoints(), + "dead_code_summary": self._get_dead_code_summary(), + "complexity_overview": self._get_complexity_overview() + } + + self.analysis_cache["codebase_overview"] = overview + return overview + + def get_file_details(self, filepath: str) -> Dict[str, Any]: + """Retrieves detailed information about a specific file.""" + cache_key = f"file_details_{filepath}" + if cache_key in self.analysis_cache: + return self.analysis_cache[cache_key] + + try: + file_obj = self.codebase.get_file(filepath) + summary_str = get_file_summary(file_obj) + + details = { + "filepath": filepath, + "summary": summary_str, + "functions": [self._get_function_summary(f) for f in file_obj.functions], + "classes": [self._get_class_summary(c) for c in file_obj.classes], + "imports": [self._get_import_summary(i) for i in file_obj.imports], + "symbols": 
[self._get_symbol_summary(s) for s in getattr(file_obj, "symbols", [])], + "metrics": { + "lines_of_code": len(file_obj.source.splitlines()) if hasattr(file_obj, "source") else 0, + "complexity_score": self._calculate_file_complexity(file_obj), + "maintainability_index": self._calculate_maintainability_index(file_obj), + "documentation_coverage": self._calculate_file_doc_coverage(file_obj) + } + } + + self.analysis_cache[cache_key] = details + return details + except ValueError: + return {"filepath": filepath, "error": "File not found in codebase."} + + def get_function_details(self, function_name: str, filepath: Optional[str] = None) -> Dict[str, Any]: + """Retrieves comprehensive information about a function.""" + cache_key = f"function_details_{function_name}_{filepath}" + if cache_key in self.analysis_cache: + return self.analysis_cache[cache_key] + + try: + symbols = self.codebase.get_symbols(symbol_name=function_name) + if not symbols: + return {"function_name": function_name, "error": "Function not found."} + + target_symbol = self._resolve_symbol_with_filepath(symbols, filepath, Function) + if not target_symbol: + return {"function_name": function_name, "filepath": filepath, "error": "Function not found at specified path or is not a function."} + + summary_str = get_function_summary(target_symbol) + reveal_info = reveal_symbol( + codebase=self.codebase, + symbol_name=function_name, + filepath=filepath, + max_depth=3, + max_tokens=5000 + ) + + details = { + "function_name": function_name, + "filepath": target_symbol.file.filepath if target_symbol.file else "N/A", + "summary": summary_str, + "parameters": self._get_function_parameters_details(target_symbol), + "return_type": self._get_function_return_type_details(target_symbol), + "local_variables": self._get_function_local_variables_details(target_symbol), + "dependencies": [self._symbol_to_dict(s) for s in reveal_info.dependencies] if reveal_info.dependencies else [], + "usages": [self._symbol_to_dict(s) for s in reveal_info.usages] if reveal_info.usages else [], + "call_sites": [self._call_site_to_dict(cs) for cs in target_symbol.call_sites], + "function_calls": [self._function_call_to_dict(fc) for fc in target_symbol.function_calls], + "complexity_metrics": self._calculate_function_complexity_metrics(target_symbol), + "quality_metrics": self._calculate_function_quality_metrics(target_symbol) + } + + self.analysis_cache[cache_key] = details + return details + except Exception as e: + return {"function_name": function_name, "error": f"Error retrieving function details: {e}"} + + def get_class_details(self, class_name: str, filepath: Optional[str] = None) -> Dict[str, Any]: + """Retrieves comprehensive information about a class.""" + cache_key = f"class_details_{class_name}_{filepath}" + if cache_key in self.analysis_cache: + return self.analysis_cache[cache_key] + + try: + symbols = self.codebase.get_symbols(symbol_name=class_name) + if not symbols: + return {"class_name": class_name, "error": "Class not found."} + + target_symbol = self._resolve_symbol_with_filepath(symbols, filepath, Class) + if not target_symbol: + return {"class_name": class_name, "filepath": filepath, "error": "Class not found at specified path or is not a class."} + + summary_str = get_class_summary(target_symbol) + reveal_info = reveal_symbol( + codebase=self.codebase, + symbol_name=class_name, + filepath=filepath, + max_depth=3, + max_tokens=5000 + ) + + details = { + "class_name": class_name, + "filepath": target_symbol.file.filepath if target_symbol.file 
else "N/A", + "summary": summary_str, + "methods": [self._get_method_summary(m) for m in target_symbol.methods], + "attributes": [self._get_attribute_summary(a) for a in target_symbol.attributes], + "superclasses": [self._class_to_dict(sc) for sc in target_symbol.superclasses], + "subclasses": [self._class_to_dict(sc) for sc in target_symbol.subclasses], + "dependencies": [self._symbol_to_dict(s) for s in reveal_info.dependencies] if reveal_info.dependencies else [], + "usages": [self._symbol_to_dict(s) for s in reveal_info.usages] if reveal_info.usages else [], + "inheritance_metrics": self._calculate_inheritance_metrics(target_symbol), + "complexity_metrics": self._calculate_class_complexity_metrics(target_symbol) + } + + self.analysis_cache[cache_key] = details + return details + except Exception as e: + return {"class_name": class_name, "error": f"Error retrieving class details: {e}"} + + def get_symbol_details(self, symbol_name: str, filepath: Optional[str] = None) -> Dict[str, Any]: + """Retrieves comprehensive information about any symbol.""" + cache_key = f"symbol_details_{symbol_name}_{filepath}" + if cache_key in self.analysis_cache: + return self.analysis_cache[cache_key] + + try: + symbols = self.codebase.get_symbols(symbol_name=symbol_name) + if not symbols: + return {"symbol_name": symbol_name, "error": "Symbol not found."} + + target_symbol = self._resolve_symbol_with_filepath(symbols, filepath, Symbol) + if not target_symbol: + return {"symbol_name": symbol_name, "filepath": filepath, "error": "Symbol not found at specified path."} + + summary_str = get_symbol_summary(target_symbol) + reveal_info = reveal_symbol( + codebase=self.codebase, + symbol_name=symbol_name, + filepath=filepath, + max_depth=3, + max_tokens=5000 + ) + + details = { + "symbol_name": symbol_name, + "symbol_type": type(target_symbol).__name__, + "filepath": target_symbol.file.filepath if target_symbol.file else "N/A", + "summary": summary_str, + "dependencies": [self._symbol_to_dict(s) for s in reveal_info.dependencies] if reveal_info.dependencies else [], + "usages": [self._symbol_to_dict(s) for s in reveal_info.usages] if reveal_info.usages else [], + "context": self._get_symbol_context(target_symbol) + } + + self.analysis_cache[cache_key] = details + return details + except Exception as e: + return {"symbol_name": symbol_name, "error": f"Error retrieving symbol details: {e}"} + + # ============================================================================ + # VISUALIZATION FUNCTIONS (from blast_radius.py, call_trace.py, etc.) 
+ # ============================================================================ + + def create_blast_radius_visualization(self, symbol_name: str, filepath: Optional[str] = None, max_depth: int = 5) -> Dict[str, Any]: + """Creates blast radius visualization showing impact of changes.""" + try: + symbols = self.codebase.get_symbols(symbol_name=symbol_name) + if not symbols: + return {"error": f"Symbol '{symbol_name}' not found"} + + target_symbol = self._resolve_symbol_with_filepath(symbols, filepath, Symbol) + if not target_symbol: + return {"error": f"Symbol '{symbol_name}' not found at specified path"} + + # Create NetworkX graph for blast radius + G = nx.DiGraph() + + # Use the blast radius logic + self._build_blast_radius_graph(G, target_symbol, max_depth) + + return { + "nodes": [{"id": str(node), "label": str(node), "type": type(node).__name__} for node in G.nodes()], + "edges": [{"source": str(source), "target": str(target)} for source, target in G.edges()], + "metadata": { + "target_symbol": symbol_name, + "visualization_type": "blast_radius", + "node_count": len(G.nodes()), + "edge_count": len(G.edges()), + "max_depth": max_depth + } + } + except Exception as e: + return {"error": f"Failed to create blast radius visualization: {e}"} + + def create_call_trace_visualization(self, function_name: str, filepath: Optional[str] = None, max_depth: int = 10) -> Dict[str, Any]: + """Creates call trace visualization showing function call relationships.""" + try: + symbols = self.codebase.get_symbols(symbol_name=function_name) + if not symbols: + return {"error": f"Function '{function_name}' not found"} + + target_function = self._resolve_symbol_with_filepath(symbols, filepath, Function) + if not target_function: + return {"error": f"Function '{function_name}' not found at specified path or is not a function"} + + # Create NetworkX graph for call trace + G = nx.DiGraph() + + # Use the call trace logic + self._build_call_trace_graph(G, target_function, max_depth) + + return { + "nodes": [{"id": str(node), "label": str(node), "type": type(node).__name__} for node in G.nodes()], + "edges": [{"source": str(source), "target": str(target)} for source, target in G.edges()], + "metadata": { + "target_function": function_name, + "visualization_type": "call_trace", + "node_count": len(G.nodes()), + "edge_count": len(G.edges()), + "max_depth": max_depth + } + } + except Exception as e: + return {"error": f"Failed to create call trace visualization: {e}"} + + def create_dependency_trace_visualization(self, symbol_name: str, filepath: Optional[str] = None, max_depth: int = 5) -> Dict[str, Any]: + """Creates dependency trace visualization.""" + try: + symbols = self.codebase.get_symbols(symbol_name=symbol_name) + if not symbols: + return {"error": f"Symbol '{symbol_name}' not found"} + + target_symbol = self._resolve_symbol_with_filepath(symbols, filepath, Symbol) + if not target_symbol: + return {"error": f"Symbol '{symbol_name}' not found at specified path"} + + # Create NetworkX graph for dependency trace + G = nx.DiGraph() + + # Use the dependency trace logic + self._build_dependency_trace_graph(G, target_symbol, max_depth) + + return { + "nodes": [{"id": str(node), "label": str(node), "type": type(node).__name__} for node in G.nodes()], + "edges": [{"source": str(source), "target": str(target)} for source, target in G.edges()], + "metadata": { + "target_symbol": symbol_name, + "visualization_type": "dependency_trace", + "node_count": len(G.nodes()), + "edge_count": len(G.edges()), + "max_depth": 
max_depth + } + } + except Exception as e: + return {"error": f"Failed to create dependency trace visualization: {e}"} + + def create_method_relationships_visualization(self, class_name: str, filepath: Optional[str] = None) -> Dict[str, Any]: + """Creates method relationships visualization within a class.""" + try: + symbols = self.codebase.get_symbols(symbol_name=class_name) + if not symbols: + return {"error": f"Class '{class_name}' not found"} + + target_class = self._resolve_symbol_with_filepath(symbols, filepath, Class) + if not target_class: + return {"error": f"Class '{class_name}' not found at specified path or is not a class"} + + # Create NetworkX graph for method relationships + G = nx.DiGraph() + + # Use the method relationships logic + self._build_method_relationships_graph(G, target_class) + + return { + "nodes": [{"id": str(node), "label": str(node), "type": type(node).__name__} for node in G.nodes()], + "edges": [{"source": str(source), "target": str(target)} for source, target in G.edges()], + "metadata": { + "target_class": class_name, + "visualization_type": "method_relationships", + "node_count": len(G.nodes()), + "edge_count": len(G.edges()) + } + } + except Exception as e: + return {"error": f"Failed to create method relationships visualization: {e}"} + + # ============================================================================ + # FILE AND DIRECTORY OPERATIONS (from view_file.py, list_directory.py) + # ============================================================================ + + def view_file_content(self, filepath: str, line_numbers: bool = True, start_line: Optional[int] = None, + end_line: Optional[int] = None, max_lines: int = 500) -> ViewFileObservation: + """Views file content with optional line numbers and pagination.""" + return view_file( + codebase=self.codebase, + filepath=filepath, + line_numbers=line_numbers, + start_line=start_line, + end_line=end_line, + max_lines=max_lines + ) + + def list_directory_contents(self, path: str = "./", depth: int = 2) -> ListDirectoryObservation: + """Lists directory contents with specified depth.""" + return list_directory(codebase=self.codebase, path=path, depth=depth) + + def add_line_numbers_to_content(self, content: str) -> str: + """Adds line numbers to content.""" + return add_line_numbers(content) + + # ============================================================================ + # SYMBOL REVELATION AND ANALYSIS (from reveal_symbol.py) + # ============================================================================ + + def reveal_symbol_relationships(self, symbol_name: str, filepath: Optional[str] = None, + max_depth: int = 2, max_tokens: Optional[int] = None, + collect_dependencies: bool = True, collect_usages: bool = True) -> RevealSymbolObservation: + """Reveals comprehensive symbol relationships.""" + return reveal_symbol( + codebase=self.codebase, + symbol_name=symbol_name, + filepath=filepath, + max_depth=max_depth, + max_tokens=max_tokens, + collect_dependencies=collect_dependencies, + collect_usages=collect_usages + ) + + def get_symbol_info_detailed(self, symbol: Symbol, max_tokens: Optional[int] = None) -> SymbolInfo: + """Gets detailed information about a symbol.""" + return get_symbol_info(symbol, max_tokens) + + def get_extended_symbol_context(self, symbol: Symbol, degree: int, max_tokens: Optional[int] = None, + collect_dependencies: bool = True, collect_usages: bool = True) -> Tuple[List[SymbolInfo], List[SymbolInfo], int]: + """Gets extended context for a symbol.""" + return 
get_extended_context( + symbol=symbol, + degree=degree, + max_tokens=max_tokens, + collect_dependencies=collect_dependencies, + collect_usages=collect_usages + ) + + # ============================================================================ + # DOCUMENTATION GENERATION (from document_functions.py, generate_docs_json.py, mdx_docs_generation.py) + # ============================================================================ + + def generate_docstrings_for_undocumented(self) -> Dict[str, Any]: + """Generates docstrings for undocumented functions using AI.""" + try: + logger.info("Generating docstrings for undocumented functions...") + + # Count undocumented functions before + undocumented_before = len([f for f in self.codebase.functions if not getattr(f, 'docstring', None)]) + + # Run the document functions process + document_functions_run(self.codebase) + + # Count undocumented functions after + undocumented_after = len([f for f in self.codebase.functions if not getattr(f, 'docstring', None)]) + + return { + "status": "success", + "undocumented_before": undocumented_before, + "undocumented_after": undocumented_after, + "docstrings_generated": undocumented_before - undocumented_after, + "message": "Docstring generation complete. Remember to commit changes." + } + except Exception as e: + return {"status": "error", "error": f"Failed to generate docstrings: {e}"} + + def generate_structured_docs(self, head_commit: str = "latest", raise_on_missing_docstring: bool = False) -> Any: + """Generates structured JSON documentation.""" + return generate_docs_json(self.codebase, head_commit, raise_on_missing_docstring) + + def generate_mdx_documentation(self, target_classes: Optional[List[str]] = None, output_dir: str = "docs") -> Dict[str, Any]: + """Generates MDX documentation pages.""" + try: + # Generate structured docs first + structured_docs = self.generate_structured_docs() + + # Filter classes if specified + if target_classes: + classes_to_document = [cls_doc for cls_doc in structured_docs.classes if cls_doc.title in target_classes] + else: + classes_to_document = structured_docs.classes + + generated_files = [] + output_path = Path(output_dir) + output_path.mkdir(exist_ok=True) + + for cls_doc in classes_to_document: + try: + mdx_content = render_mdx_page_for_class(cls_doc) + mdx_route = get_mdx_route_for_class(cls_doc) + + # Create file path + file_path = output_path / f"{mdx_route}.mdx" + file_path.parent.mkdir(parents=True, exist_ok=True) + + # Write MDX content + file_path.write_text(mdx_content) + generated_files.append(str(file_path)) + + except Exception as e: + logger.warning(f"Failed to generate MDX for {cls_doc.title}: {e}") + + return { + "status": "success", + "generated_files": generated_files, + "classes_documented": len(generated_files), + "output_directory": output_dir + } + except Exception as e: + return {"status": "error", "error": f"Failed to generate MDX documentation: {e}"} + + # ============================================================================ + # DEAD CODE ANALYSIS + # ============================================================================ + + def find_dead_code(self) -> Dict[str, Any]: + """Identifies dead code using graph traversal from entry points.""" + if "dead_code" in self.analysis_cache: + return self.analysis_cache["dead_code"] + + dead_code = { + "total": 0, + "functions": 0, + "classes": 0, + "imports": 0, + "variables": 0, + "detailed_items": [], + "recommendations": [], + "entry_point_analysis": {} + } + + # Find entry points + 
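+        # Reachability sketch: every symbol transitively reachable from an entry
+        # point via function_calls, usages, or dependencies is marked as visited;
+        # functions, classes, and imports never visited are reported as dead code.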
entry_points = self._identify_entrypoints() + + # Perform graph traversal from entry points + visited = set() + + def traverse_from_entry_point(symbol): + if symbol in visited or isinstance(symbol, ExternalModule): + return + visited.add(symbol) + + # Traverse function calls + if hasattr(symbol, "function_calls"): + for call in symbol.function_calls: + if hasattr(call, "function_definition") and call.function_definition: + traverse_from_entry_point(call.function_definition) + + # Traverse usages + if hasattr(symbol, "usages"): + for usage in symbol.usages: + if hasattr(usage, "usage_symbol"): + traverse_from_entry_point(usage.usage_symbol) + + # Traverse dependencies + if hasattr(symbol, "dependencies"): + for dep in symbol.dependencies: + if hasattr(dep, "resolved_symbol") and dep.resolved_symbol: + traverse_from_entry_point(dep.resolved_symbol) + elif isinstance(dep, Symbol): + traverse_from_entry_point(dep) + + # Start traversal from all entry points + for entry_point in entry_points["functions"] + entry_points["classes"]: + if "symbol_obj" in entry_point: + traverse_from_entry_point(entry_point["symbol_obj"]) + + # Find dead functions + for func in self.codebase.functions: + if (func not in visited and + not self._is_test_function(func) and + not self._is_special_function(func)): + + dead_code["functions"] += 1 + dead_code["detailed_items"].append({ + "type": "function", + "name": func.name, + "file": func.filepath, + "line": func.start_point.line + 1 if hasattr(func, 'start_point') else 0, + "reason": "No usage found from entry points", + "confidence": 0.8, + "complexity": self._calculate_function_complexity(func), + "loc": len(func.source.splitlines()) if hasattr(func, 'source') else 0 + }) + + # Find dead classes + for cls in self.codebase.classes: + if cls not in visited and not self._is_test_class(cls): + dead_code["classes"] += 1 + dead_code["detailed_items"].append({ + "type": "class", + "name": cls.name, + "file": cls.filepath, + "line": cls.start_point.line + 1 if hasattr(cls, 'start_point') else 0, + "reason": "No usage found from entry points", + "confidence": 0.7, + "methods_count": len(cls.methods), + "inheritance_depth": len(cls.superclasses) + }) + + # Find dead imports + for file_obj in self.codebase.files: + for imp in file_obj.imports: + if not hasattr(imp, 'usages') or len(imp.usages) == 0: + dead_code["imports"] += 1 + dead_code["detailed_items"].append({ + "type": "import", + "name": imp.name, + "file": file_obj.filepath, + "line": imp.start_point.line + 1 if hasattr(imp, 'start_point') else 0, + "reason": "Import not used in file", + "confidence": 0.9, + "module": getattr(imp, 'module', 'unknown') + }) + + dead_code["total"] = ( + dead_code["functions"] + dead_code["classes"] + + dead_code["imports"] + dead_code["variables"] + ) + + # Generate recommendations + dead_code["recommendations"] = [ + "Review dead code items before removal", + "Check if functions are used in tests or configuration", + "Consider if classes are used for inheritance only", + "Verify imports are not used in string literals or dynamic imports", + f"Found {len(entry_points['functions'])} entry point functions", + f"Found {len(entry_points['classes'])} entry point classes" + ] + + dead_code["entry_point_analysis"] = entry_points + + self.analysis_cache["dead_code"] = dead_code + return dead_code + + # ============================================================================ + # BASH COMMAND EXECUTION (from bash.py) + # 
============================================================================ + + def validate_bash_command(self, command: str) -> Tuple[bool, str]: + """Validates if a bash command is safe to execute.""" + return validate_command(command) + + def run_bash_command(self, command: str, is_background: bool = False) -> RunBashCommandObservation: + """Runs a bash command and returns the result.""" + return run_bash_command(command, is_background) + + # ============================================================================ + # REFLECTION AND PLANNING (from reflection.py) + # ============================================================================ + + def perform_reflection(self, context_summary: str, findings_so_far: str, + current_challenges: str = "", reflection_focus: Optional[str] = None) -> ReflectionObservation: + """Performs agent reflection for strategic planning.""" + return perform_reflection( + context_summary=context_summary, + findings_so_far=findings_so_far, + current_challenges=current_challenges, + reflection_focus=reflection_focus, + codebase=self.codebase + ) + + def parse_reflection_response(self, response: str) -> List[ReflectionSection]: + """Parses reflection response into structured sections.""" + return parse_reflection_response(response) + + # ============================================================================ + # WORKSPACE TOOLS (from tools.py) + # ============================================================================ + + def get_workspace_tools(self): + """Gets all workspace tools initialized with the codebase.""" + return get_workspace_tools(self.codebase) + + # ============================================================================ + # CODEBASE UTILITIES (from current_code_codebase.py, codegen_sdk_codebase.py) + # ============================================================================ + + def get_current_code_codebase(self, config=None, secrets=None, subdirectories=None): + """Gets codebase for currently running code.""" + return get_current_code_codebase(config, secrets, subdirectories) + + def get_codegen_sdk_codebase(self): + """Gets the codegen SDK codebase.""" + return get_codegen_sdk_codebase() + + def import_all_sdk_modules(self): + """Imports all codegen SDK modules.""" + import_all_codegen_sdk_modules() + + def get_documented_objects(self): + """Gets all documented objects.""" + return get_documented_objects() + + # ============================================================================ + # ADVANCED ANALYSIS FUNCTIONS + # ============================================================================ + + def analyze_function_complexity(self, function_name: str, filepath: Optional[str] = None) -> Dict[str, Any]: + """Analyzes complexity metrics for a specific function.""" + try: + symbols = self.codebase.get_symbols(symbol_name=function_name) + if not symbols: + return {"function_name": function_name, "error": "Function not found."} + + target_function = self._resolve_symbol_with_filepath(symbols, filepath, Function) + if not target_function: + return {"function_name": function_name, "error": "Function not found at specified path or is not a function."} + + # Calculate various complexity metrics + cyclomatic_complexity = self._calculate_cyclomatic_complexity(target_function) + halstead_metrics = self._calculate_halstead_metrics(target_function) + + return { + "function_name": function_name, + "filepath": target_function.file.filepath if target_function.file else "N/A", + "cyclomatic_complexity": cyclomatic_complexity, + 
"complexity_rank": self._get_complexity_rank(cyclomatic_complexity), + "halstead_metrics": halstead_metrics, + "parameters_count": len(target_function.parameters), + "return_statements_count": len(getattr(target_function, "return_statements", [])), + "function_calls_count": len(target_function.function_calls), + "lines_of_code": len(target_function.source.splitlines()) if hasattr(target_function, "source") else 0, + "maintainability_score": self._calculate_function_maintainability(target_function) + } + except Exception as e: + return {"function_name": function_name, "error": f"Error analyzing function complexity: {e}"} + + def analyze_class_structure(self, class_name: str, filepath: Optional[str] = None) -> Dict[str, Any]: + """Analyzes structural metrics for a specific class.""" + try: + symbols = self.codebase.get_symbols(symbol_name=class_name) + if not symbols: + return {"class_name": class_name, "error": "Class not found."} + + target_class = self._resolve_symbol_with_filepath(symbols, filepath, Class) + if not target_class: + return {"class_name": class_name, "error": "Class not found at specified path or is not a class."} + + # Analyze class structure + methods_analysis = [] + for method in target_class.methods: + methods_analysis.append({ + "name": method.name, + "complexity": self._calculate_cyclomatic_complexity(method), + "parameters_count": len(method.parameters), + "is_public": not method.name.startswith("_"), + "is_property": self._is_property_method(method), + "lines_of_code": len(method.source.splitlines()) if hasattr(method, "source") else 0, + "has_docstring": bool(getattr(method, "docstring", None)) + }) + + attributes_analysis = [] + for attr in target_class.attributes: + attributes_analysis.append({ + "name": attr.name, + "is_public": not attr.name.startswith("_"), + "has_type_annotation": hasattr(attr, "type") and attr.type is not None, + "usages_count": len(attr.usages) if hasattr(attr, "usages") else 0 + }) + + return { + "class_name": class_name, + "filepath": target_class.file.filepath if target_class.file else "N/A", + "methods_count": len(target_class.methods), + "attributes_count": len(target_class.attributes), + "inheritance_depth": len(target_class.superclasses), + "subclasses_count": len(target_class.subclasses), + "methods_analysis": methods_analysis, + "attributes_analysis": attributes_analysis, + "complexity_score": sum(m["complexity"] for m in methods_analysis) / max(1, len(methods_analysis)), + "public_methods_ratio": len([m for m in methods_analysis if m["is_public"]]) / max(1, len(methods_analysis)), + "documented_methods_ratio": len([m for m in methods_analysis if m["has_docstring"]]) / max(1, len(methods_analysis)), + "cohesion_metrics": self._calculate_class_cohesion(target_class) + } + except Exception as e: + return {"class_name": class_name, "error": f"Error analyzing class structure: {e}"} + + def analyze_import_relationships(self, filepath: str) -> Dict[str, Any]: + """Analyzes import relationships for a specific file.""" + try: + file_obj = self.codebase.get_file(filepath) + + imports_analysis = [] + for imp in file_obj.imports: + imports_analysis.append({ + "name": imp.name, + "module": getattr(imp, "module", "unknown"), + "is_external": isinstance(getattr(imp, "resolved_symbol", None), ExternalModule), + "usages_count": len(imp.usages) if hasattr(imp, "usages") else 0, + "is_unused": len(imp.usages) == 0 if hasattr(imp, "usages") else True, + "line": imp.start_point.line + 1 if hasattr(imp, "start_point") else 0, + "resolved": 
bool(getattr(imp, "resolved_symbol", None)) + }) + + inbound_imports = [] + for symbol in getattr(file_obj, "symbols", []): + for usage in symbol.usages: + if hasattr(usage, "file") and usage.file != file_obj: + inbound_imports.append({ + "symbol": symbol.name, + "imported_by": usage.file.filepath, + "usage_type": type(usage).__name__ + }) + + return { + "filepath": filepath, + "imports_count": len(file_obj.imports), + "external_imports_count": len([i for i in imports_analysis if i["is_external"]]), + "unused_imports_count": len([i for i in imports_analysis if i["is_unused"]]), + "unresolved_imports_count": len([i for i in imports_analysis if not i["resolved"]]), + "imports_analysis": imports_analysis, + "inbound_imports": inbound_imports, + "inbound_imports_count": len(inbound_imports), + "circular_dependencies": self._detect_circular_imports(file_obj) + } + except ValueError: + return {"filepath": filepath, "error": "File not found in codebase."} + except Exception as e: + return {"filepath": filepath, "error": f"Error analyzing import relationships: {e}"} + + # ============================================================================ + # HELPER FUNCTIONS + # ============================================================================ + + def _resolve_symbol_with_filepath(self, symbols: List[Symbol], filepath: Optional[str], expected_type: type) -> Optional[Symbol]: + """Helper to resolve an ambiguous symbol by filepath and type.""" + if filepath: + for s in symbols: + if s.file and s.file.filepath == filepath and isinstance(s, expected_type): + return s + elif symbols and isinstance(symbols[0], expected_type): + return symbols[0] + return None + + def _identify_entrypoints(self) -> Dict[str, List[Dict[str, Any]]]: + """Identifies entry points in the codebase.""" + entrypoints = { + "functions": [], + "classes": [], + "files": [] + } + + # Function entry points + for func in self.codebase.functions: + if self._is_entrypoint_function(func): + entrypoints["functions"].append({ + "name": func.name, + "file": func.filepath, + "type": "function", + "score": self._calculate_entrypoint_score(func), + "symbol_obj": func + }) + + # Class entry points + for cls in self.codebase.classes: + if self._is_entrypoint_class(cls): + entrypoints["classes"].append({ + "name": cls.name, + "file": cls.filepath, + "type": "class", + "score": self._calculate_class_entrypoint_score(cls), + "symbol_obj": cls + }) + + # File entry points + for file_obj in self.codebase.files: + if self._is_entrypoint_file(file_obj): + entrypoints["files"].append({ + "name": file_obj.name, + "path": file_obj.filepath, + "type": "file" + }) + + return entrypoints + + def _get_dead_code_summary(self) -> Dict[str, Any]: + """Gets a summary of dead code analysis.""" + dead_code = self.find_dead_code() + return { + "total_dead_items": dead_code["total"], + "dead_functions": dead_code["functions"], + "dead_classes": dead_code["classes"], + "dead_imports": dead_code["imports"] + } + + def _get_complexity_overview(self) -> Dict[str, Any]: + """Gets complexity overview of the codebase.""" + complexities = [self._calculate_function_complexity(f) for f in self.codebase.functions] + if not complexities: + return {"average": 0, "max": 0, "distribution": {}} + + return { + "average_complexity": sum(complexities) / len(complexities), + "max_complexity": max(complexities), + "min_complexity": min(complexities), + "distribution": { + "low": len([c for c in complexities if c <= 5]), + "medium": len([c for c in complexities if 5 < c <= 10]), + 
"high": len([c for c in complexities if c > 10]) + } + } + + def _get_function_parameters_details(self, func: Function) -> List[Dict[str, Any]]: + """Extracts detailed information about function parameters.""" + params_details = [] + for param in func.parameters: + param_type_source = getattr(param.type, 'source', 'Any') if hasattr(param, 'type') and param.type else 'Any' + resolved_types = [] + + if hasattr(param, 'type') and param.type and hasattr(param.type, 'resolved_value') and param.type.resolved_value: + resolved_symbols = param.type.resolved_value + if not isinstance(resolved_symbols, list): + resolved_symbols = [resolved_symbols] + + for res_sym in resolved_symbols: + if isinstance(res_sym, Symbol): + resolved_types.append({ + "name": res_sym.name, + "type": type(res_sym).__name__, + "filepath": res_sym.filepath if hasattr(res_sym, 'filepath') else None + }) + elif isinstance(res_sym, ExternalModule): + resolved_types.append({ + "name": res_sym.name, + "type": "ExternalModule" + }) + else: + resolved_types.append({"name": str(res_sym), "type": "Unknown"}) + + params_details.append({ + "name": param.name, + "type_annotation": param_type_source, + "resolved_types": resolved_types, + "has_default": getattr(param, "has_default", False), + "is_keyword_only": getattr(param, "is_keyword_only", False), + "is_positional_only": getattr(param, "is_positional_only", False), + "is_var_arg": getattr(param, "is_var_arg", False), + "is_var_kw": getattr(param, "is_var_kw", False) + }) + return params_details + + def _get_function_return_type_details(self, func: Function) -> Dict[str, Any]: + """Extracts detailed information about function return type.""" + return_type_source = getattr(func.return_type, 'source', 'Any') if hasattr(func, 'return_type') and func.return_type else 'Any' + resolved_types = [] + + if hasattr(func, 'return_type') and func.return_type and hasattr(func.return_type, 'resolved_value') and func.return_type.resolved_value: + resolved_symbols = func.return_type.resolved_value + if not isinstance(resolved_symbols, list): + resolved_symbols = [resolved_symbols] + + for res_sym in resolved_symbols: + if isinstance(res_sym, Symbol): + resolved_types.append({ + "name": res_sym.name, + "type": type(res_sym).__name__, + "filepath": res_sym.filepath if hasattr(res_sym, 'filepath') else None + }) + elif isinstance(res_sym, ExternalModule): + resolved_types.append({ + "name": res_sym.name, + "type": "ExternalModule" + }) + else: + resolved_types.append({"name": str(res_sym), "type": "Unknown"}) + + return { + "type_annotation": return_type_source, + "resolved_types": resolved_types + } + + def _get_function_local_variables_details(self, func: Function) -> List[Dict[str, Any]]: + """Extracts details about local variables defined within a function.""" + local_vars = [] + if hasattr(func, 'code_block') and hasattr(func.code_block, 'local_var_assignments'): + for assignment in func.code_block.local_var_assignments: + var_type_source = getattr(assignment.type, 'source', 'Any') if hasattr(assignment, 'type') and assignment.type else 'Any' + local_vars.append({ + "name": assignment.name, + "type_annotation": var_type_source, + "line": assignment.start_point.line + 1 if hasattr(assignment, 'start_point') else None, + "value_snippet": assignment.source if hasattr(assignment, 'source') else None + }) + return local_vars + + def _symbol_to_dict(self, symbol_info: SymbolInfo) -> Dict[str, Any]: + """Converts SymbolInfo to dictionary.""" + return { + "name": symbol_info.name, + "filepath": 
symbol_info.filepath, + "source": symbol_info.source + } + + def _call_site_to_dict(self, call_site) -> Dict[str, Any]: + """Converts call site to dictionary.""" + return { + "name": getattr(call_site, "name", "unknown"), + "file": getattr(call_site, "file", {}).get("filepath", "unknown") if hasattr(call_site, "file") else "unknown", + "line": getattr(call_site, "start_point", {}).get("line", 0) if hasattr(call_site, "start_point") else 0 + } + + def _function_call_to_dict(self, function_call) -> Dict[str, Any]: + """Converts function call to dictionary.""" + return { + "name": getattr(function_call, "name", "unknown"), + "args_count": len(getattr(function_call, "args", [])), + "line": getattr(function_call, "start_point", {}).get("line", 0) if hasattr(function_call, "start_point") else 0 + } + + def _class_to_dict(self, cls: Class) -> Dict[str, Any]: + """Converts class to dictionary.""" + return { + "name": cls.name, + "filepath": cls.filepath, + "methods_count": len(cls.methods), + "attributes_count": len(cls.attributes) + } + + # ============================================================================ + # COMPLEXITY AND QUALITY METRICS + # ============================================================================ + + def _calculate_function_complexity(self, func: Function) -> int: + """Calculates cyclomatic complexity for a function.""" + return self._calculate_cyclomatic_complexity(func) + + def _calculate_cyclomatic_complexity(self, func: Function) -> int: + """Calculate cyclomatic complexity for a function.""" + try: + complexity = 1 # Base complexity + + if hasattr(func, "source") and func.source: + source = func.source.lower() + # Count decision points + complexity += source.count("if ") + complexity += source.count("elif ") + complexity += source.count("for ") + complexity += source.count("while ") + complexity += source.count("except ") + complexity += source.count("and ") + complexity += source.count("or ") + complexity += source.count("try:") + complexity += source.count("with ") + + return complexity + except Exception: + return 1 + + def _calculate_halstead_metrics(self, func: Function) -> Dict[str, float]: + """Calculate Halstead metrics for a function.""" + try: + if not hasattr(func, "source") or not func.source: + return {"volume": 0.0, "difficulty": 0.0, "effort": 0.0} + + operators, operands = get_operators_and_operands(func) + volume, N1, N2, n1, n2 = calculate_halstead_volume(operators, operands) + + N = N1 + N2 # Program length + n = n1 + n2 # Program vocabulary + + if n > 0 and n2 > 0: + difficulty = (n1 / 2) * (N2 / n2) + effort = difficulty * volume + else: + difficulty = effort = 0 + + return { + "volume": volume, + "difficulty": difficulty, + "effort": effort, + "length": N, + "vocabulary": n, + "unique_operators": n1, + "unique_operands": n2 + } + except Exception: + return {"volume": 0.0, "difficulty": 0.0, "effort": 0.0} + + def _get_complexity_rank(self, complexity: int) -> str: + """Get complexity rank based on cyclomatic complexity.""" + if complexity <= 5: + return "A" + elif complexity <= 10: + return "B" + elif complexity <= 20: + return "C" + elif complexity <= 30: + return "D" + elif complexity <= 40: + return "E" + else: + return "F" + + def _calculate_function_complexity_metrics(self, func: Function) -> Dict[str, Any]: + """Calculate comprehensive complexity metrics for a function.""" + return { + "cyclomatic_complexity": self._calculate_cyclomatic_complexity(func), + "halstead_metrics": self._calculate_halstead_metrics(func), + 
"lines_of_code": len(func.source.splitlines()) if hasattr(func, "source") else 0, + "parameters_count": len(func.parameters), + "nesting_depth": self._calculate_nesting_depth(func) + } + + def _calculate_function_quality_metrics(self, func: Function) -> Dict[str, Any]: + """Calculate quality metrics for a function.""" + return { + "has_docstring": bool(getattr(func, "docstring", None)), + "has_return_type": bool(getattr(func, "return_type", None)), + "typed_parameters_ratio": self._calculate_typed_parameters_ratio(func), + "maintainability_score": self._calculate_function_maintainability(func) + } + + def _calculate_class_complexity_metrics(self, cls: Class) -> Dict[str, Any]: + """Calculate complexity metrics for a class.""" + method_complexities = [self._calculate_cyclomatic_complexity(m) for m in cls.methods] + return { + "average_method_complexity": sum(method_complexities) / max(1, len(method_complexities)), + "max_method_complexity": max(method_complexities) if method_complexities else 0, + "methods_count": len(cls.methods), + "attributes_count": len(cls.attributes), + "inheritance_depth": len(cls.superclasses), + "weighted_methods_per_class": sum(method_complexities) + } + + def _calculate_inheritance_metrics(self, cls: Class) -> Dict[str, Any]: + """Calculate inheritance-related metrics.""" + return { + "depth_of_inheritance": len(cls.superclasses), + "number_of_children": len(cls.subclasses), + "coupling_between_objects": len([dep for dep in cls.dependencies if isinstance(dep, Class)]), + "response_for_class": len(cls.methods) + len([m for sc in cls.superclasses for m in sc.methods]) + } + + def _calculate_class_cohesion(self, cls: Class) -> Dict[str, Any]: + """Calculate class cohesion metrics.""" + # Simplified LCOM (Lack of Cohesion of Methods) calculation + methods = list(cls.methods) + attributes = list(cls.attributes) + + if not methods or not attributes: + return {"lcom": 0, "cohesion_score": 1.0} + + # Count method pairs that don't share attributes + non_cohesive_pairs = 0 + total_pairs = 0 + + for i, method1 in enumerate(methods): + for method2 in methods[i+1:]: + total_pairs += 1 + # Check if methods share any attributes + method1_attrs = set(getattr(method1, "variable_usages", [])) + method2_attrs = set(getattr(method2, "variable_usages", [])) + + if not method1_attrs.intersection(method2_attrs): + non_cohesive_pairs += 1 + + lcom = non_cohesive_pairs / max(1, total_pairs) + cohesion_score = 1.0 - lcom + + return { + "lcom": lcom, + "cohesion_score": cohesion_score, + "methods_count": len(methods), + "attributes_count": len(attributes) + } + + # ============================================================================ + # GRAPH BUILDING FUNCTIONS + # ============================================================================ + + def _build_blast_radius_graph(self, graph: nx.DiGraph, symbol: Symbol, max_depth: int, depth: int = 0): + """Build blast radius graph recursively.""" + if depth >= max_depth or symbol in graph.nodes: + return + + graph.add_node(symbol, name=symbol.name, type=type(symbol).__name__, depth=depth) + + # Add all usages (things that would be affected by changes) + for usage in symbol.usages: + if hasattr(usage, "usage_symbol"): + affected_symbol = usage.usage_symbol + if affected_symbol not in graph.nodes: + graph.add_node(affected_symbol, name=affected_symbol.name, + type=type(affected_symbol).__name__, depth=depth + 1) + graph.add_edge(symbol, affected_symbol, relationship="impacts") + + if depth + 1 < max_depth: + 
self._build_blast_radius_graph(graph, affected_symbol, max_depth, depth + 1) + + def _build_call_trace_graph(self, graph: nx.DiGraph, func: Function, max_depth: int, depth: int = 0): + """Build call trace graph recursively.""" + if depth >= max_depth or func in graph.nodes: + return + + graph.add_node(func, name=func.name, type="function", depth=depth) + + # Add function calls + for call in func.function_calls: + if hasattr(call, "function_definition") and call.function_definition: + called_func = call.function_definition + if not isinstance(called_func, ExternalModule) and called_func not in graph.nodes: + graph.add_node(called_func, name=called_func.name, type="function", depth=depth + 1) + graph.add_edge(func, called_func, relationship="calls") + + if depth + 1 < max_depth: + self._build_call_trace_graph(graph, called_func, max_depth, depth + 1) + + def _build_dependency_trace_graph(self, graph: nx.DiGraph, symbol: Symbol, max_depth: int, depth: int = 0): + """Build dependency trace graph recursively.""" + if depth >= max_depth or symbol in graph.nodes: + return + + graph.add_node(symbol, name=symbol.name, type=type(symbol).__name__, depth=depth) + + # Add dependencies + for dep in symbol.dependencies: + if isinstance(dep, Import): + dep = hop_through_imports(dep) + + if isinstance(dep, Symbol) and not isinstance(dep, ExternalModule) and dep not in graph.nodes: + graph.add_node(dep, name=dep.name, type=type(dep).__name__, depth=depth + 1) + graph.add_edge(symbol, dep, relationship="depends_on") + + if depth + 1 < max_depth: + self._build_dependency_trace_graph(graph, dep, max_depth, depth + 1) + + def _build_method_relationships_graph(self, graph: nx.DiGraph, cls: Class): + """Build method relationships graph for a class.""" + graph.add_node(cls, name=cls.name, type="class") + + # Add all methods + for method in cls.methods: + graph.add_node(method, name=f"{cls.name}.{method.name}", type="method") + graph.add_edge(cls, method, relationship="contains") + + # Add method call relationships + for call in method.function_calls: + if hasattr(call, "function_definition") and call.function_definition: + called_func = call.function_definition + if called_func in cls.methods: # Internal method call + graph.add_edge(method, called_func, relationship="calls") + + # ============================================================================ + # UTILITY FUNCTIONS + # ============================================================================ + + def _is_entrypoint_function(self, func: Function) -> bool: + """Check if a function is an entrypoint.""" + entrypoint_patterns = ["main", "run", "start", "execute", "cli", "app", "serve"] + return ( + any(pattern in func.name.lower() for pattern in entrypoint_patterns) or + func.name == "__main__" or + self._has_entrypoint_decorators(func) or + self._is_called_from_main_block(func) + ) + + def _is_entrypoint_class(self, cls: Class) -> bool: + """Check if a class is an entrypoint.""" + entrypoint_patterns = ["app", "application", "server", "client", "main", "runner", "service"] + return ( + any(pattern in cls.name.lower() for pattern in entrypoint_patterns) or + self._has_framework_inheritance(cls) or + self._has_singleton_pattern(cls) + ) + + def _is_entrypoint_file(self, file_obj: SourceFile) -> bool: + """Check if a file is an entrypoint.""" + entrypoint_patterns = ["main.py", "__main__.py", "app.py", "server.py", "run.py", "cli.py"] + return any(pattern in file_obj.filepath for pattern in entrypoint_patterns) + + def _is_test_function(self, func: 
Function) -> bool: + """Check if a function is a test function.""" + return ( + func.name.startswith("test_") or + "test" in func.filepath or + self._has_test_decorators(func) + ) + + def _is_test_class(self, cls: Class) -> bool: + """Check if a class is a test class.""" + return cls.name.startswith("Test") or "test" in cls.filepath + + def _is_special_function(self, func: Function) -> bool: + """Check if a function is special (shouldn't be considered dead code).""" + special_patterns = ["__init__", "__str__", "__repr__", "__call__", "setUp", "tearDown"] + return any(pattern in func.name for pattern in special_patterns) + + def _calculate_entrypoint_score(self, func: Function) -> float: + """Calculate entrypoint score for a function.""" + score = 1.0 # Base score + + # Name-based scoring + entrypoint_names = ["main", "run", "start", "execute", "app", "serve", "launch"] + if any(name in func.name.lower() for name in entrypoint_names): + score += 2.0 + + # Usage-based scoring + if len(func.usages) == 0: + score += 1.0 + elif len(func.usages) < 3: + score += 0.5 + + # Complexity-based scoring + complexity = self._calculate_cyclomatic_complexity(func) + if 5 <= complexity <= 15: + score += 1.0 + elif complexity > 15: + score += 0.5 + + # Call-based scoring + if len(func.function_calls) > 5: + score += 1.0 + elif len(func.function_calls) > 2: + score += 0.5 + + return score + + def _calculate_class_entrypoint_score(self, cls: Class) -> float: + """Calculate entrypoint score for a class.""" + score = 1.0 # Base score + + # Name-based scoring + entrypoint_names = ["app", "application", "server", "client", "main", "service"] + if any(name in cls.name.lower() for name in entrypoint_names): + score += 2.0 + + # Size-based scoring + if len(cls.methods) > 10: + score += 1.0 + + # Inheritance-based scoring + if len(cls.superclasses) > 0: + for superclass in cls.superclasses: + if any(pattern in superclass.name.lower() for pattern in ["application", "service", "handler"]): + score += 1.5 + + return score + + def _has_entrypoint_decorators(self, func: Function) -> bool: + """Check if function has entrypoint decorators.""" + if not hasattr(func, "decorators"): + return False + + for decorator in func.decorators: + decorator_source = getattr(decorator, "source", "") + if any(pattern in decorator_source.lower() for pattern in ["@app.", "@click.", "@typer.", "@fastapi."]): + return True + return False + + def _has_framework_inheritance(self, cls: Class) -> bool: + """Check if class inherits from framework classes.""" + for superclass in cls.superclasses: + if any(pattern in superclass.name.lower() for pattern in ["application", "app", "service", "handler"]): + return True + return False + + def _has_singleton_pattern(self, cls: Class) -> bool: + """Check if class implements singleton pattern.""" + return any("instance" in method.name.lower() or "singleton" in method.name.lower() for method in cls.methods) + + def _has_test_decorators(self, func: Function) -> bool: + """Check if function has test decorators.""" + if not hasattr(func, "decorators"): + return False + + for decorator in func.decorators: + decorator_source = getattr(decorator, "source", "") + if any(pattern in decorator_source.lower() for pattern in ["@pytest.", "@unittest.", "@test"]): + return True + return False + + def _is_called_from_main_block(self, func: Function) -> bool: + """Check if function is called from if __name__ == '__main__' block.""" + for usage in func.usages: + if hasattr(usage, "parent_statement"): + parent = 
usage.parent_statement + if hasattr(parent, "condition") and "__name__" in getattr(parent.condition, "source", ""): + return True + return False + + def _is_property_method(self, method: Function) -> bool: + """Check if method is a property.""" + if not hasattr(method, "decorators"): + return False + + for decorator in method.decorators: + if "@property" in getattr(decorator, "source", ""): + return True + return False + + def _calculate_nesting_depth(self, func: Function) -> int: + """Calculate maximum nesting depth in a function.""" + if not hasattr(func, "source") or not func.source: + return 0 + + lines = func.source.split('\n') + max_depth = 0 + current_depth = 0 + + for line in lines: + stripped = line.strip() + if any(keyword in stripped for keyword in ['if ', 'for ', 'while ', 'try:', 'with ']): + current_depth += 1 + max_depth = max(max_depth, current_depth) + elif stripped in ['else:', 'elif ', 'except:', 'finally:']: + continue + elif stripped == '' or stripped.startswith('#'): + continue + else: + # Check for dedentation + indent_level = len(line) - len(line.lstrip()) + if indent_level == 0: + current_depth = 0 + + return max_depth + + def _calculate_typed_parameters_ratio(self, func: Function) -> float: + """Calculate ratio of typed parameters.""" + if not func.parameters: + return 1.0 + + typed_count = sum(1 for p in func.parameters if hasattr(p, 'type') and p.type) + return typed_count / len(func.parameters) + + def _calculate_function_maintainability(self, func: Function) -> float: + """Calculate function maintainability score.""" + complexity = self._calculate_cyclomatic_complexity(func) + loc = len(func.source.splitlines()) if hasattr(func, "source") else 0 + has_docstring = bool(getattr(func, "docstring", None)) + + # Simplified maintainability calculation + base_score = 100 + complexity_penalty = complexity * 2 + loc_penalty = max(0, (loc - 20) * 0.5) + doc_bonus = 10 if has_docstring else 0 + + return max(0, base_score - complexity_penalty - loc_penalty + doc_bonus) + + def _calculate_file_complexity(self, file_obj: SourceFile) -> float: + """Calculate complexity score for a file.""" + try: + if not hasattr(file_obj, "functions"): + return 0.0 + + function_complexities = [self._calculate_cyclomatic_complexity(func) for func in file_obj.functions] + return sum(function_complexities) / max(1, len(function_complexities)) + except Exception: + return 0.0 + + def _calculate_maintainability_index(self, file_obj: SourceFile) -> float: + """Calculate maintainability index for a file.""" + try: + if not hasattr(file_obj, "source"): + return 0.0 + + loc = len(file_obj.source.splitlines()) + complexity = self._calculate_file_complexity(file_obj) + + if loc > 0: + return max(0, (171 - 5.2 * math.log(loc) - 0.23 * complexity - 16.2 * math.log(loc)) * 100 / 171) + return 0.0 + except Exception: + return 0.0 + + def _calculate_file_doc_coverage(self, file_obj: SourceFile) -> float: + """Calculate documentation coverage for a file.""" + total_symbols = len(file_obj.functions) + len(file_obj.classes) + if total_symbols == 0: + return 1.0 + + documented = sum(1 for f in file_obj.functions if getattr(f, 'docstring', None)) + documented += sum(1 for c in file_obj.classes if getattr(c, 'docstring', None)) + + return documented / total_symbols + + def _detect_circular_imports(self, file_obj: SourceFile) -> List[List[str]]: + """Detect circular import dependencies for a file.""" + import_graph = nx.DiGraph() + + # Build import graph starting from this file + def 
add_file_imports(current_file, visited=None): + if visited is None: + visited = set() + if current_file.filepath in visited: + return + visited.add(current_file.filepath) + + import_graph.add_node(current_file.filepath) + for imp in current_file.imports: + if hasattr(imp, "from_file") and imp.from_file: + import_graph.add_edge(current_file.filepath, imp.from_file.filepath) + add_file_imports(imp.from_file, visited) + + add_file_imports(file_obj) + + # Find cycles involving this file + cycles = list(nx.simple_cycles(import_graph)) + return [cycle for cycle in cycles if file_obj.filepath in cycle] + + # ============================================================================ + # SUMMARY HELPER FUNCTIONS + # ============================================================================ + + def _get_function_summary(self, func: Function) -> Dict[str, Any]: + """Get summary information for a function.""" + return { + "name": func.name, + "line": func.start_point.line + 1 if hasattr(func, 'start_point') else 0, + "complexity": self._calculate_cyclomatic_complexity(func), + "parameters_count": len(func.parameters), + "has_docstring": bool(getattr(func, "docstring", None)), + "is_async": getattr(func, "is_async", False), + "usages_count": len(func.usages) + } + + def _get_class_summary(self, cls: Class) -> Dict[str, Any]: + """Get summary information for a class.""" + return { + "name": cls.name, + "line": cls.start_point.line + 1 if hasattr(cls, 'start_point') else 0, + "methods_count": len(cls.methods), + "attributes_count": len(cls.attributes), + "inheritance_depth": len(cls.superclasses), + "has_docstring": bool(getattr(cls, "docstring", None)), + "usages_count": len(cls.usages) + } + + def _get_import_summary(self, imp: Import) -> Dict[str, Any]: + """Get summary information for an import.""" + return { + "name": imp.name, + "module": getattr(imp, "module", "unknown"), + "line": imp.start_point.line + 1 if hasattr(imp, 'start_point') else 0, + "is_external": isinstance(getattr(imp, "resolved_symbol", None), ExternalModule), + "usages_count": len(imp.usages) if hasattr(imp, "usages") else 0, + "is_resolved": bool(getattr(imp, "resolved_symbol", None)) + } + + def _get_symbol_summary(self, symbol: Symbol) -> Dict[str, Any]: + """Get summary information for a symbol.""" + return { + "name": symbol.name, + "type": type(symbol).__name__, + "line": symbol.start_point.line + 1 if hasattr(symbol, 'start_point') else 0, + "usages_count": len(symbol.usages), + "dependencies_count": len(symbol.dependencies) + } + + def _get_method_summary(self, method: Function) -> Dict[str, Any]: + """Get summary information for a method.""" + return { + "name": method.name, + "line": method.start_point.line + 1 if hasattr(method, 'start_point') else 0, + "complexity": self._calculate_cyclomatic_complexity(method), + "parameters_count": len(method.parameters), + "is_public": not method.name.startswith("_"), + "is_property": self._is_property_method(method), + "has_docstring": bool(getattr(method, "docstring", None)) + } + + def _get_attribute_summary(self, attr) -> Dict[str, Any]: + """Get summary information for an attribute.""" + return { + "name": attr.name, + "line": attr.start_point.line + 1 if hasattr(attr, 'start_point') else 0, + "is_public": not attr.name.startswith("_"), + "has_type_annotation": hasattr(attr, "type") and attr.type is not None, + "usages_count": len(attr.usages) if hasattr(attr, "usages") else 0 + } + + def _get_symbol_context(self, symbol: Symbol) -> Dict[str, Any]: + """Get contextual 
information about a symbol.""" + return { + "parent_class": symbol.parent_class.name if getattr(symbol, "parent_class", None) else None, + "parent_function": symbol.parent_function.name if getattr(symbol, "parent_function", None) else None, + "file": symbol.file.filepath if symbol.file else None, + "is_public": not symbol.name.startswith("_"), + "symbol_type": getattr(symbol, "symbol_type", type(symbol).__name__) + } \ No newline at end of file diff --git a/Libraries/graph_sitter_lib/graph_sitter_backend.py b/Libraries/graph_sitter_lib/graph_sitter_backend.py new file mode 100644 index 00000000..488b8eda --- /dev/null +++ b/Libraries/graph_sitter_lib/graph_sitter_backend.py @@ -0,0 +1,3954 @@ +#!/usr/bin/env python3 +""" +Production Graph-Sitter Backend API +Provides comprehensive codebase analysis, visualization, and transformation capabilities +using actual graph-sitter library implementation +""" + +import os +import tempfile +import shutil +import subprocess +import traceback +import uuid +import math +import ast +import re # Added for Halstead metrics +from typing import Dict, List, Any, Optional, Union +from collections import defaultdict, Counter +from datetime import datetime +import asyncio +import logging +import networkx as nx + +# FastAPI and web framework imports +from fastapi import FastAPI, HTTPException, BackgroundTasks +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +import uvicorn + +# Graph-sitter imports (actual implementation) +try: + from graph_sitter.core.codebase import Codebase + from graph_sitter.configs.models.codebase import CodebaseConfig + from graph_sitter.core.external_module import ExternalModule + from graph_sitter.core.symbol import Symbol + from graph_sitter.core.file import SourceFile # Changed from File + from graph_sitter.core.function import Function + from graph_sitter.core.class_definition import Class # Changed from ClassDef + from graph_sitter.core.statement import ( + Statement, + IfStatement, + WhileStatement, + TryStatement, + ) + from graph_sitter.core.import_statement import ( + Import, + ) # Changed from ImportStatement + from graph_sitter.core.assignment import Assignment + from graph_sitter.core.parameter import Parameter + from graph_sitter.core.function_call import FunctionCall + from graph_sitter.core.usage import Usage + + # Import analysis functions from graph_sitter_analysis.py + from graph_sitter_analysis import GraphSitterAnalyzer + + # Import LSP diagnostics manager + from lsp_diagnostics import LSPDiagnosticsManager + + # Import autogenlib AI resolution + from autogenlib_adapter import resolve_diagnostic_with_ai + from solidlsp.lsp_protocol_handler.lsp_types import ( + Diagnostic, + DocumentUri, + Range, + ) # For LSP types + from solidlsp.ls_config import Language # For LSP language enum + + # Import documentation generation + from graph_sitter.generate_docs_json import generate_docs_json + from graph_sitter.mdx_docs_generation import ( + render_mdx_page_for_class, + ) # Import specific functions + + GRAPH_SITTER_AVAILABLE = True +except ImportError as e: + print(f"Warning: graph-sitter or related modules not available: {e}") + print("Install with: pip install graph-sitter") + GRAPH_SITTER_AVAILABLE = False + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Initialize FastAPI app +app = FastAPI( + title="Graph-Sitter Comprehensive Analysis API", + description="Complete codebase analysis, visualization, and transformation using 
graph-sitter", + version="3.0.0", +) + +# Configure CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Global storage for analysis sessions +analysis_sessions: Dict[str, Dict] = {} + +# ============================================================================ +# PYDANTIC MODELS +# ============================================================================ + + +class AnalyzeRequest(BaseModel): + repo_url: str = Field(..., description="GitHub repository URL") + branch: str = Field(default="main", description="Branch to analyze") + config: Optional[Dict] = Field(default=None, description="Analysis configuration") + include_deep_analysis: bool = Field( + default=True, description="Include comprehensive analysis" + ) + language: str = Field( + default="python", + description="Programming language of the codebase (e.g., python, csharp).", + ) + + +class ErrorAnalysisResponse(BaseModel): + total_errors: int + critical_errors: int + major_errors: int + minor_errors: int + errors_by_category: Dict[str, int] + detailed_errors: List[Dict[str, Any]] + error_patterns: List[Dict[str, Any]] + suggestions: List[Dict[str, Any]] + + +class EntrypointAnalysisResponse(BaseModel): + total_entrypoints: int + main_entrypoints: List[Dict[str, Any]] + secondary_entrypoints: List[Dict[str, Any]] + test_entrypoints: List[Dict[str, Any]] + api_entrypoints: List[Dict[str, Any]] # Added + cli_entrypoints: List[Dict[str, Any]] # Added + entrypoint_graph: Dict[str, Any] + complexity_metrics: Dict[str, Any] + dependency_analysis: Dict[str, Any] # Added + call_flow_analysis: Dict[str, Any] # Added + + +class TransformationRequest(BaseModel): + analysis_id: str + transformation_type: str + target_path: str + parameters: Dict[str, Any] = Field(default_factory=dict) + dry_run: bool = Field(default=True, description="Preview changes without applying") + + +class VisualizationRequest(BaseModel): + analysis_id: str + viz_type: str = Field(..., description="Type of visualization") + entry_point: Optional[str] = Field( + default=None, description="Entry point for visualization" + ) + max_depth: int = Field(default=10, description="Maximum depth for traversal") + include_external: bool = Field( + default=False, description="Include external modules" + ) + filter_patterns: List[str] = Field( + default_factory=list, description="Filter patterns" + ) + + +class DeadCodeAnalysisResponse(BaseModel): + total_dead_items: int + dead_functions: List[Dict[str, Any]] + dead_classes: List[Dict[str, Any]] + dead_imports: List[Dict[str, Any]] + dead_variables: List[Dict[str, Any]] + potential_dead_code: List[Dict[str, Any]] + recommendations: List[str] + + +class CodeQualityMetrics(BaseModel): + complexity_score: float + maintainability_index: float + technical_debt_ratio: float + test_coverage_estimate: float + documentation_coverage: float + code_duplication_score: float + type_coverage: float # Added + function_metrics: Dict[str, Any] # Added + class_metrics: Dict[str, Any] # Added + file_metrics: Dict[str, Any] # Added + + +# ============================================================================ +# CONTEXTUAL ERROR ANALYSIS (Moved to GraphSitterAnalyzer) +# ============================================================================ + +# ============================================================================ +# UTILITY FUNCTIONS +# ============================================================================ + + +def 
clone_repository(repo_url: str, branch: str = "main") -> str: + """Clone repository to temporary directory""" + try: + temp_dir = tempfile.mkdtemp(prefix="graph_sitter_") + repo_name = repo_url.split("/")[-1].replace(".git", "") + repo_path = os.path.join(temp_dir, repo_name) + + clone_cmd = [ + "git", + "clone", + "--depth", + "1", + "--branch", + branch, + repo_url, + repo_path, + ] + + result = subprocess.run(clone_cmd, capture_output=True, text=True, timeout=300) + + if result.returncode != 0: + raise Exception(f"Git clone failed: {result.stderr}") + + logger.info(f"Successfully cloned {repo_url} to {repo_path}") + return repo_path + + except subprocess.TimeoutExpired: + raise Exception("Repository clone timed out") + except Exception as e: + logger.error(f"Error cloning repository: {e}") + raise Exception(f"Failed to clone repository: {str(e)}") + + +def calculate_doi(cls: Class): + """Calculate the depth of inheritance for a given class.""" + return len(cls.superclasses) if hasattr(cls, "superclasses") else 0 + + +def get_operators_and_operands(function: Function): + """Extract operators and operands from function for Halstead metrics.""" + operators = [] + operands = [] + + try: + if hasattr(function, "source") and function.source: + source_code = function.source + # Simple keyword-based operator detection + operators.extend( + re.findall( + r"\b(if|for|while|return|def|class|import|from|and|or|not|in|is)\b", + source_code, + ) + ) + operators.extend( + re.findall( + r"(\+|\-|\*|\/|%|\*\*|//|=|==|!=|<|>|<=|>=|\+=|\-=|\*=|/=|%=|\*\*=|//=)", + source_code, + ) + ) + + # Simple variable/literal detection for operands + # This is very basic and would need proper AST traversal for accuracy + operands.extend( + re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", source_code) + ) # Identifiers + operands.extend( + re.findall(r"\b\d+(?:\.\d+)?\b", source_code) + ) # Numbers (non-capturing group so findall keeps the full literal, not just the fractional part) + operands.extend(re.findall(r'(".*?"|\'.*?\')', source_code)) # Strings + + except Exception as e: + logger.warning(f"Error extracting operators/operands from {function.name}: {e}") + + return operators, operands + + +def calculate_halstead_volume(operators: List[str], operands: List[str]): + """Calculate Halstead volume metrics.""" + try: + n1 = len(set(operators)) # Unique operators + n2 = len(set(operands)) # Unique operands + N1 = len(operators) # Total operators + N2 = len(operands) # Total operands + + N = N1 + N2 # Program length + n = n1 + n2 # Program vocabulary + + if n > 0: + volume = N * math.log2(n) + return volume, N1, N2, n1, n2 + return 0, N1, N2, n1, n2 + except Exception: + return 0, 0, 0, 0, 0 + + +def cc_rank(complexity: int): + """Calculate cyclomatic complexity rank.""" + if complexity < 0: + raise ValueError("Complexity must be a non-negative value") + + ranks = [ + (1, 5, "A"), + (6, 10, "B"), + (11, 20, "C"), + (21, 30, "D"), + (31, 40, "E"), + (41, float("inf"), "F"), + ] + for low, high, rank in ranks: + if low <= complexity <= high: + return rank + return "F" + + +# ============================================================================ +# COMPREHENSIVE ANALYSIS ENGINE +# ============================================================================ + + +class AnalysisEngine: + """ + Comprehensive analysis engine for deep code analysis, integrating Graph-Sitter and LSP. 
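+ Starts an LSP server to collect per-file diagnostics, runs the GraphSitterAnalyzer passes (structure, errors, dead code, entrypoints, dependency graph, quality, architecture, security, performance), and merges both result sets into a single analysis report. 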
+ """ + + def __init__(self, codebase: Codebase, language: str): + self.codebase = codebase + self.language = language + self.analyzer = GraphSitterAnalyzer(codebase) + self.lsp_manager = LSPDiagnosticsManager( + codebase, Language(language) + ) # Pass codebase object + self.context_cache = {} + self.insight_cache = {} + + async def perform_full_analysis(self) -> Dict[str, Any]: + """Perform comprehensive codebase analysis using Graph-Sitter and LSP.""" + try: + # 1. Open files in LSP server for diagnostic collection + logger.info("Opening files in LSP server for diagnostic collection...") + self.lsp_manager.start_server() # Start LSP server + for file_obj in self.codebase.files: + try: + self.lsp_manager.open_file(file_obj.filepath, file_obj.source) + except Exception as e: + logger.warning( + f"Could not open file {file_obj.filepath} with LSP: {e}" + ) + + # Give LSP server some time to process files and publish diagnostics + logger.info( + "Waiting for LSP server to process files and publish diagnostics (5 seconds)..." + ) + await asyncio.sleep(5) # Adjust as needed for larger codebases + + # 2. Retrieve Enhanced Diagnostics + logger.info("Retrieving enhanced diagnostics from LSP server...") + all_lsp_diagnostics = self.lsp_manager.get_all_enhanced_diagnostics() + + # 3. Perform Graph-Sitter Analysis + logger.info("Performing comprehensive Graph-Sitter analysis...") + codebase_summary = self.analyzer.get_codebase_overview() + tree_structure = self._build_tree_structure_from_graph_sitter( + self.codebase, all_lsp_diagnostics + ) + error_analysis = self._analyze_errors_with_graph_sitter_enhanced( + self.codebase, all_lsp_diagnostics + ) + dead_code_analysis = ( + self.analyzer.find_dead_code() + ) # Using GraphSitterAnalyzer + entrypoint_analysis = self._analyze_entrypoints_with_graph_sitter_enhanced( + self.codebase + ) + dependency_graph = self._build_dependency_graph_from_graph_sitter( + self.codebase + ) + code_quality_metrics = self._calculate_code_quality_metrics(self.codebase) + architectural_insights = self._analyze_architectural_patterns(self.codebase) + security_analysis = self._analyze_security_patterns(self.codebase) + performance_analysis = self._analyze_performance_patterns(self.codebase) + + analysis = { + "codebase_summary": codebase_summary, + "tree_structure": tree_structure, + "error_analysis": error_analysis, + "dead_code_analysis": dead_code_analysis, + "entrypoint_analysis": entrypoint_analysis, + "dependency_graph": dependency_graph, + "code_quality_metrics": code_quality_metrics, + "architectural_insights": architectural_insights, + "security_analysis": security_analysis, + "performance_analysis": performance_analysis, + "metrics": { + "files": len(list(self.codebase.files)), + "functions": len(list(self.codebase.functions)), + "classes": len(list(self.codebase.classes)), + "symbols": len(list(self.codebase.symbols)), + "imports": len(list(self.codebase.imports)), + "external_modules": len(list(self.codebase.external_modules)), + }, + } + + return analysis + + except Exception as e: + logger.error(f"Error analyzing codebase with graph-sitter: {e}") + logger.error(traceback.format_exc()) + raise Exception(f"Graph-sitter analysis failed: {str(e)}") + finally: + self.lsp_manager.shutdown_server() # Ensure LSP server is shut down + + def _analyze_errors_with_graph_sitter_enhanced( + self, codebase: Codebase, lsp_diagnostics: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """Enhanced comprehensive error analysis using Graph-sitter APIs and LSP diagnostics.""" + errors = 
{ + "total": 0, + "critical": 0, + "major": 0, + "minor": 0, + "by_category": defaultdict(int), # Use defaultdict for easier counting + "detailed_errors": [], + "error_patterns": [], + "suggestions": [], + "resolution_recommendations": [], + } + + # Integrate LSP diagnostics (which are already enhanced by autogenlib_context) + for enhanced_diag in lsp_diagnostics: + diag = enhanced_diag["diagnostic"] + error_entry = { + "severity": diag.severity.name.lower() if diag.severity else "unknown", + "category": diag.code if diag.code else "lsp_diagnostic", + "file": enhanced_diag["relative_file_path"], # Use relative path + "symbol": diag.source, # LSP source + "line": diag.range.line + 1, + "message": diag.message, + "context": enhanced_diag, # Store the full enhanced diagnostic object + "suggestion": "Review LSP diagnostic message and apply fix.", + "resolution_method": "ai_resolution_lsp", + } + errors["detailed_errors"].append(error_entry) + errors["by_category"][error_entry["category"]] += 1 + if error_entry["severity"] == "error": # LSP error severity is 1 (Error) + errors["critical"] += 1 + elif ( + error_entry["severity"] == "warning" + ): # LSP warning severity is 2 (Warning) + errors["major"] += 1 + else: # Info, Hint, Unknown + errors["minor"] += 1 + errors["total"] += 1 + + # Add Graph-Sitter specific analysis (e.g., missing docstrings, unused imports, circular imports) + # These are examples; actual implementation would involve traversing codebase objects + # and checking properties like `usages`, `docstring`, `return_type`, etc. + + # Example: Missing docstrings + for func in codebase.functions: + if not hasattr(func, "docstring") or not func.docstring: + error_entry = { + "severity": "minor", + "category": "missing_docstrings", + "file": func.filepath, + "symbol": func.name, + "line": func.start_point.line + 1 + if hasattr(func, "start_point") + else 0, + "message": "Missing docstring", + "context": f"Function '{func.name}' has no documentation", + "suggestion": f'Add docstring: """Brief description of {func.name}."""', + "resolution_method": "generate_docstring", + } + errors["detailed_errors"].append(error_entry) + errors["by_category"]["missing_docstrings"] += 1 + errors["minor"] += 1 + errors["total"] += 1 + + # Example: Unused imports + for file_obj in codebase.files: + for imp in file_obj.imports: + if not hasattr(imp, "usages") or len(imp.usages) == 0: + error_entry = { + "severity": "minor", + "category": "unused_imports", + "file": file_obj.filepath, + "symbol": imp.name, + "line": imp.start_point.line + 1 + if hasattr(imp, "start_point") + else 0, + "message": "Unused import", + "context": f"Import '{imp.name}' is not used", + "suggestion": "Remove unused import", + "resolution_method": "remove_unused_imports", + } + errors["detailed_errors"].append(error_entry) + errors["by_category"]["unused_imports"] += 1 + errors["minor"] += 1 + errors["total"] += 1 + + # Example: Circular imports (using NetworkX) + import_graph = nx.DiGraph() + for file_obj in codebase.files: + import_graph.add_node(file_obj.filepath) + for imp in file_obj.imports: + if hasattr(imp, "from_file") and imp.from_file: + import_graph.add_edge(file_obj.filepath, imp.from_file.filepath) + + cycles = list(nx.simple_cycles(import_graph)) + for cycle in cycles: + for file_path in cycle: + error_entry = { + "severity": "critical", + "category": "circular_imports", + "file": file_path, + "symbol": "imports", + "line": 1, + "message": "Circular import detected", + "context": f"File is part of circular import: {' 
-> '.join(cycle)}", + "suggestion": "Refactor to remove circular dependency", + "resolution_method": "refactor_circular_imports", + } + errors["detailed_errors"].append(error_entry) + errors["by_category"]["circular_imports"] += 1 + errors["critical"] += 1 + errors["total"] += 1 + + # Generate enhanced error patterns + errors["error_patterns"] = self._analyze_error_patterns( + errors["detailed_errors"] + ) + + # Generate resolution recommendations + errors["resolution_recommendations"] = ( + self._generate_resolution_recommendations(errors) + ) + + return errors + + def _analyze_error_patterns( + self, detailed_errors: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Enhanced error pattern analysis with resolution suggestions""" + patterns = [] + + # Group errors by category and file + error_groups = defaultdict(lambda: defaultdict(list)) + for error in detailed_errors: + error_groups[error["category"]][error["file"]].append(error) + + # Analyze patterns within each category + for category, file_errors in error_groups.items(): + if len(file_errors) > 1: + total_errors = sum(len(errors) for errors in file_errors.values()) + most_affected_files = sorted( + file_errors.items(), key=lambda x: len(x[1]), reverse=True + )[:3] + + patterns.append( + { + "category": category, + "total_count": total_errors, + "affected_files": len(file_errors), + "most_affected_files": [ + {"file": file, "count": len(errors)} + for file, errors in most_affected_files + ], + "pattern_description": f"Widespread {category} errors across {len(file_errors)} files", + "severity": most_affected_files[0][1][0]["severity"] + if most_affected_files + else "minor", + "resolution_strategy": self._get_resolution_strategy(category), + } + ) + + return patterns + + def _get_resolution_strategy(self, category: str) -> str: + """Get resolution strategy for error category""" + strategies = { + "missing_types": "Use type inference tools or add explicit type annotations", + "unused_parameters": "Remove unused parameters or prefix with underscore", + "unused_imports": "Use import optimization tools to remove unused imports", + "wrong_call_sites": "Update function signatures or fix call sites", + "circular_imports": "Refactor code to break circular dependencies", + "unresolved_imports": "Fix import paths or install missing dependencies", + "missing_arguments": "Add required arguments to function calls", + "incorrect_types": "Fix type annotations or import missing types", + "unimplemented_methods": "Implement abstract methods or remove inheritance", + "missing_attributes": "Add missing class attributes or fix attribute access", + "parameter_mismatches": "Fix argument types to match parameter expectations", + "assignment_errors": "Define variables before use or fix variable references", + "lsp_diagnostic": "Consult LSP server documentation for specific diagnostic code, or use AI resolution.", + } + return strategies.get(category, "Manual review and correction required") + + def _generate_resolution_recommendations( + self, errors: Dict[str, Any] + ) -> List[Dict[str, Any]]: + """Generate comprehensive resolution recommendations""" + recommendations = [] + + # Type-related recommendations + type_errors_count = errors["by_category"].get("missing_types", 0) + errors[ + "by_category" + ].get("incorrect_types", 0) + if type_errors_count > 0: + recommendations.append( + { + "type": "type_resolution", + "priority": "high", + "description": f"Resolve {type_errors_count} type-related issues", + "actions": [ + "Run mypy --install-types to 
install missing type stubs", + "Add explicit type annotations to functions and variables", + "Use type inference tools to suggest appropriate types", + "Import missing types from appropriate modules", + ], + "automated_fix": "resolve_all_types", + } + ) + + # Import-related recommendations + import_errors_count = ( + errors["by_category"].get("unresolved_imports", 0) + + errors["by_category"].get("unused_imports", 0) + + errors["by_category"].get("circular_imports", 0) + ) + if import_errors_count > 0: + recommendations.append( + { + "type": "import_resolution", + "priority": "medium", + "description": f"Resolve {import_errors_count} import-related issues", + "actions": [ + "Fix unresolved import paths", + "Remove unused imports", + "Organize imports by type (stdlib, third-party, local)", + "Add missing imports for used symbols", + "Refactor code to break circular dependencies", + ], + "automated_fix": "resolve_all_imports", + } + ) + + # Function call recommendations + call_errors_count = ( + errors["by_category"].get("wrong_call_sites", 0) + + errors["by_category"].get("missing_arguments", 0) + + errors["by_category"].get("parameter_mismatches", 0) + ) + if call_errors_count > 0: + recommendations.append( + { + "type": "function_call_resolution", + "priority": "high", + "description": f"Resolve {call_errors_count} function call issues", + "actions": [ + "Add missing arguments to function calls", + "Fix argument types to match parameter expectations", + "Update function signatures if needed", + "Convert positional arguments to keyword arguments", + ], + "automated_fix": "resolve_all_function_calls", + } + ) + + # LSP specific recommendations + lsp_diag_count = errors["by_category"].get("lsp_diagnostic", 0) + if lsp_diag_count > 0: + recommendations.append( + { + "type": "lsp_diagnostic_resolution", + "priority": "high", + "description": f"Address {lsp_diag_count} LSP reported diagnostics", + "actions": [ + "Review detailed LSP messages for specific guidance", + "Utilize AI-driven resolution for automated fixes", + "Consult language-specific documentation for best practices", + ], + "automated_fix": "ai_resolution_lsp", + } + ) + + return recommendations + + def _analyze_entrypoints_with_graph_sitter_enhanced( + self, codebase: Codebase + ) -> Dict[str, Any]: + """Enhanced entrypoint analysis using comprehensive Graph-sitter APIs""" + entrypoints = { + "total_entrypoints": 0, + "main_entrypoints": [], + "secondary_entrypoints": [], + "test_entrypoints": [], + "api_entrypoints": [], + "cli_entrypoints": [], + "entrypoint_graph": {}, + "complexity_metrics": {}, + "dependency_analysis": {}, + "call_flow_analysis": {}, + } + + # Enhanced main entrypoint detection + for func in codebase.functions: + if self._is_entrypoint_function_enhanced(func): + entrypoint_data = { + "name": func.name, + "file": func.filepath, + "type": "function", + "complexity": self._calculate_function_complexity(func), + "dependencies": len(list(func.dependencies)), + "usages": len(list(func.usages)), + "call_sites": len( + list(func.function_calls) + ), # Changed from call_sites + "is_async": getattr(func, "is_async", False), + "parameters": self._get_function_parameters_details( + func + ), # Detailed parameters + "return_type": self._get_function_return_type_details( + func + ), # Detailed return type + "local_variables": self._get_function_local_variables_details( + func + ), # Local variables + "docstring": bool(getattr(func, "docstring", None)), + "function_calls_count": len(list(func.function_calls)), + 
"return_statements_count": len( + list(getattr(func, "return_statements", [])) + ), + "variable_usages_count": len( + list(getattr(func, "variable_usages", [])) + ), + "symbol_usages_count": len( + list(getattr(func, "symbol_usages", [])) + ), + "entrypoint_score": self._calculate_entrypoint_score(func), + } + + # Categorize entrypoints more precisely + if func.name in ["main", "__main__"] or func.name.startswith("main_"): + entrypoints["main_entrypoints"].append(entrypoint_data) + elif self._is_api_entrypoint(func): + entrypoints["api_entrypoints"].append(entrypoint_data) + elif self._is_cli_entrypoint(func): + entrypoints["cli_entrypoints"].append(entrypoint_data) + else: + entrypoints["secondary_entrypoints"].append(entrypoint_data) + + # Enhanced class entrypoint detection + for cls in codebase.classes: + if self._is_entrypoint_class_enhanced(cls): + entrypoint_data = { + "name": cls.name, + "file": cls.filepath, + "type": "class", + "complexity": self._calculate_class_complexity(cls), + "methods_count": len(list(cls.methods)), + "attributes_count": len(list(cls.attributes)), + "inheritance_depth": calculate_doi(cls), + "usages_count": len(list(cls.usages)), + "subclasses_count": len(list(cls.subclasses)), + "dependencies_count": len(list(cls.dependencies)), + "entrypoint_score": self._calculate_class_entrypoint_score(cls), + "methods_details": [ + { + "name": m.name, + "parameters": self._get_function_parameters_details(m), + "return_type": self._get_function_return_type_details(m), + "complexity": self._calculate_function_complexity(m), + } + for m in cls.methods + ], + } + + if self._is_api_entrypoint_class(cls): + entrypoints["api_entrypoints"].append(entrypoint_data) + else: + entrypoints["secondary_entrypoints"].append(entrypoint_data) + + # Enhanced test entrypoint detection + for func in codebase.functions: + if self._is_test_function_enhanced(func): + entrypoint_data = { + "name": func.name, + "file": func.filepath, + "type": "test_function", + "complexity": self._calculate_function_complexity(func), + "dependencies_count": len(list(func.dependencies)), + "test_type": self._classify_test_type(func), + "test_coverage_estimate": self._estimate_test_coverage(func), + } + entrypoints["test_entrypoints"].append(entrypoint_data) + + # Calculate totals + entrypoints["total_entrypoints"] = ( + len(entrypoints["main_entrypoints"]) + + len(entrypoints["secondary_entrypoints"]) + + len(entrypoints["test_entrypoints"]) + + len(entrypoints["api_entrypoints"]) + + len(entrypoints["cli_entrypoints"]) + ) + + # Enhanced entrypoint graph analysis + entrypoint_graph = nx.DiGraph() + all_entrypoints_list = ( + entrypoints["main_entrypoints"] + + entrypoints["secondary_entrypoints"] + + entrypoints["api_entrypoints"] + + entrypoints["cli_entrypoints"] + ) + + for entrypoint in all_entrypoints_list: + func = next( + (f for f in codebase.functions if f.name == entrypoint["name"]), None + ) + if func: + entrypoint_graph.add_node(func.name, **entrypoint) + + # Add call relationships + for call in func.function_calls: + if ( + hasattr(call, "function_definition") + and call.function_definition + ): + called_func = call.function_definition + if not isinstance(called_func, ExternalModule): + entrypoint_graph.add_edge( + func.name, + called_func.name, + relationship="calls", + call_count=1, + ) + + # Add dependency relationships + for dep in func.dependencies: + if ( + hasattr(dep, "name") + and dep.name != func.name + and not isinstance(dep, ExternalModule) + ): + entrypoint_graph.add_edge( + func.name, 
dep.name, relationship="depends_on" + ) + + entrypoints["entrypoint_graph"] = { + "nodes": len(entrypoint_graph.nodes), + "edges": len(entrypoint_graph.edges), + "connected_components": len( + list(nx.weakly_connected_components(entrypoint_graph)) + ), + "strongly_connected_components": len( + list(nx.strongly_connected_components(entrypoint_graph)) + ), + "cycles": len(list(nx.simple_cycles(entrypoint_graph))), + "max_depth": len(nx.dag_longest_path(entrypoint_graph)) + if nx.is_directed_acyclic_graph(entrypoint_graph) + else 0, + } + + # Enhanced complexity metrics + complexities = [ep["complexity"] for ep in all_entrypoints_list] + if complexities: + entrypoints["complexity_metrics"] = { + "average_complexity": sum(complexities) / len(complexities), + "max_complexity": max(complexities), + "min_complexity": min(complexities), + "high_complexity_count": len([c for c in complexities if c > 10]), + "complexity_distribution": { + "low": len([c for c in complexities if c <= 5]), + "medium": len([c for c in complexities if 5 < c <= 10]), + "high": len([c for c in complexities if c > 10]), + }, + } + + # Dependency analysis + entrypoints["dependency_analysis"] = self._analyze_entrypoint_dependencies( + all_entrypoints_list, codebase + ) + + # Call flow analysis + entrypoints["call_flow_analysis"] = self._analyze_entrypoint_call_flows( + all_entrypoints_list, codebase + ) + + return entrypoints + + def _get_function_parameters_details(self, func: Function) -> List[Dict[str, Any]]: + """Extracts detailed information about function parameters.""" + params_details = [] + for param in func.parameters: + param_type_source = ( + getattr(param.type, "source", "Any") + if hasattr(param, "type") and param.type + else "Any" + ) + resolved_types = [] + if ( + hasattr(param, "type") + and param.type + and hasattr(param.type, "resolved_value") + and param.type.resolved_value + ): + # resolved_value can be a single Symbol or a list of Symbols + resolved_symbols = param.type.resolved_value + if not isinstance(resolved_symbols, list): + resolved_symbols = [resolved_symbols] + for res_sym in resolved_symbols: + if isinstance(res_sym, Symbol): + resolved_types.append( + { + "name": res_sym.name, + "type": type(res_sym).__name__, + "filepath": res_sym.filepath + if hasattr(res_sym, "filepath") + else None, + } + ) + elif isinstance(res_sym, ExternalModule): + resolved_types.append( + {"name": res_sym.name, "type": "ExternalModule"} + ) + else: + resolved_types.append({"name": str(res_sym), "type": "Unknown"}) + + params_details.append( + { + "name": param.name, + "type_annotation": param_type_source, + "resolved_types": resolved_types, + "has_default": param.has_default, + "is_keyword_only": param.is_keyword_only, + "is_positional_only": param.is_positional_only, + "is_var_arg": param.is_var_arg, + "is_var_kw": param.is_var_kw, + } + ) + return params_details + + def _get_function_return_type_details(self, func: Function) -> Dict[str, Any]: + """Extracts detailed information about function return type.""" + return_type_source = ( + getattr(func.return_type, "source", "Any") + if hasattr(func, "return_type") and func.return_type + else "Any" + ) + resolved_types = [] + if ( + hasattr(func, "return_type") + and func.return_type + and hasattr(func.return_type, "resolved_value") + and func.return_type.resolved_value + ): + resolved_symbols = func.return_type.resolved_value + if not isinstance(resolved_symbols, list): + resolved_symbols = [resolved_symbols] + for res_sym in resolved_symbols: + if isinstance(res_sym, 
Symbol): + resolved_types.append( + { + "name": res_sym.name, + "type": type(res_sym).__name__, + "filepath": res_sym.filepath + if hasattr(res_sym, "filepath") + else None, + } + ) + elif isinstance(res_sym, ExternalModule): + resolved_types.append( + {"name": res_sym.name, "type": "ExternalModule"} + ) + else: + resolved_types.append({"name": str(res_sym), "type": "Unknown"}) + return {"type_annotation": return_type_source, "resolved_types": resolved_types} + + def _get_function_local_variables_details( + self, func: Function + ) -> List[Dict[str, Any]]: + """Extracts details about local variables defined within a function.""" + local_vars = [] + if hasattr(func, "code_block") and hasattr( + func.code_block, "local_var_assignments" + ): + for assignment in func.code_block.local_var_assignments: + var_type_source = ( + getattr(assignment.type, "source", "Any") + if hasattr(assignment, "type") and assignment.type + else "Any" + ) + local_vars.append( + { + "name": assignment.name, + "type_annotation": var_type_source, + "line": assignment.start_point.line + 1 + if hasattr(assignment, "start_point") + else None, + "value_snippet": assignment.source + if hasattr(assignment, "source") + else None, + } + ) + return local_vars + + def _is_test_function_enhanced(self, func: Function) -> bool: + """Enhanced test function detection""" + # Standard test patterns + if func.name.startswith("test_") or func.name.endswith("_test"): + return True + + # Check for test decorators + if hasattr(func, "decorators"): + for decorator in func.decorators: + decorator_source = getattr(decorator, "source", "") + if any( + pattern in decorator_source.lower() + for pattern in ["@pytest.", "@unittest.", "@test"] + ): + return True + + # Check if in test file + test_file_patterns = ["test_", "_test.", "tests/", "/test/", "spec_", "_spec."] + if any(pattern in func.filepath for pattern in test_file_patterns): + return True + + # Check for assertion patterns in function body + if hasattr(func, "source") and func.source: + assertion_patterns = ["assert ", "self.assert", "expect(", "should."] + if any(pattern in func.source for pattern in assertion_patterns): + return True + + return False + + def _is_entrypoint_class_enhanced(self, cls: Class) -> bool: + """Enhanced entrypoint class detection""" + # Standard entrypoint patterns + entrypoint_patterns = [ + "app", + "application", + "server", + "client", + "main", + "runner", + "service", + "controller", + ] + if any(pattern in cls.name.lower() for pattern in entrypoint_patterns): + return True + + # Check for framework-specific patterns + framework_patterns = ["fastapi", "flask", "django", "tornado", "aiohttp"] + for method in cls.methods: + if any(pattern in method.name.lower() for pattern in framework_patterns): + return True + + # Check for inheritance from framework classes + for superclass in cls.superclasses: + if any( + pattern in superclass.name.lower() + for pattern in ["application", "app", "service", "handler"] + ): + return True + + # Check for singleton patterns (often used for main application classes) + if any( + "instance" in method.name.lower() or "singleton" in method.name.lower() + for method in cls.methods + ): + return True + + return False + + def _is_api_entrypoint_class(self, cls: Class) -> bool: + """Check if class is an API entrypoint""" + api_patterns = ["api", "router", "controller", "handler", "endpoint", "view"] + if any(pattern in cls.name.lower() for pattern in api_patterns): + return True + + # Check for API-related decorators on methods + 
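+        # Descriptive note (added): a class counts as an API entrypoint when any of its
+        # methods carries an HTTP-style routing decorator; the check below matches the
+        # listed patterns as substrings of the lowercased decorator source text.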
for method in cls.methods: + if hasattr(method, "decorators"): + for decorator in method.decorators: + decorator_source = getattr(decorator, "source", "") + if any( + pattern in decorator_source.lower() + for pattern in ["@route", "@get", "@post", "@put", "@delete"] + ): + return True + + # Check for inheritance from API frameworks + for superclass in cls.superclasses: + if any( + pattern in superclass.name.lower() + for pattern in ["resource", "view", "handler", "controller"] + ): + return True + + return False + + def _calculate_entrypoint_score(self, func: Function) -> float: + """Calculate entrypoint score based on various factors""" + score = 0.0 + + # Base score for being a function + score += 1.0 + + # Score based on name patterns + entrypoint_names = ["main", "run", "start", "execute", "app", "serve", "launch"] + if any(name in func.name.lower() for name in entrypoint_names): + score += 2.0 + + # Score based on usage patterns + if len(func.usages) == 0: # Not called by other functions + score += 1.0 + elif len(func.usages) < 3: # Called by few functions + score += 0.5 + + # Score based on complexity (entrypoints often coordinate other functions) + complexity = self._calculate_function_complexity(func) + if 5 <= complexity <= 15: # Sweet spot for entrypoints + score += 1.0 + elif complexity > 15: # Very complex, likely important + score += 0.5 + + # Score based on function calls (entrypoints often call many other functions) + if len(func.function_calls) > 5: + score += 1.0 + elif len(func.function_calls) > 2: + score += 0.5 + + # Score based on decorators + if hasattr(func, "decorators"): + for decorator in func.decorators: + decorator_source = getattr(decorator, "source", "") + if any( + pattern in decorator_source.lower() + for pattern in ["@app.", "@click.", "@typer."] + ): + score += 2.0 + + # Score based on file location + if any( + pattern in func.filepath + for pattern in ["main.py", "app.py", "server.py", "cli.py"] + ): + score += 1.0 + + return score + + def _estimate_test_coverage(self, func: Function) -> float: + """Estimate test coverage based on function characteristics""" + coverage = 0.0 + + # Base coverage for being a test + coverage += 0.3 + + # Coverage based on assertions + if hasattr(func, "source") and func.source: + assertion_count = func.source.count("assert ") + func.source.count( + "self.assert" + ) + coverage += min(0.4, assertion_count * 0.1) + + # Coverage based on function calls (tests that call many functions likely test more) + call_count = len(func.function_calls) + coverage += min(0.3, call_count * 0.05) + + return min(1.0, coverage) + + def _analyze_entrypoint_dependencies( + self, all_entrypoints: List[Dict[str, Any]], codebase: Codebase + ) -> Dict[str, Any]: + """Analyze dependencies between entrypoints""" + dependency_analysis = { + "shared_dependencies": [], + "isolated_entrypoints": [], + "dependency_clusters": [], + "external_dependencies": [], + } + + # Find shared dependencies + entrypoint_deps = {} + for entrypoint in all_entrypoints: + func = next( + (f for f in codebase.functions if f.name == entrypoint["name"]), None + ) + if func: + entrypoint_deps[entrypoint["name"]] = set( + dep.name + for dep in func.dependencies + if not isinstance(dep, ExternalModule) + ) + # Track external dependencies + for dep in func.dependencies: + if isinstance(dep, ExternalModule): + dependency_analysis["external_dependencies"].append( + { + "entrypoint": entrypoint["name"], + "dependency": dep.name, + "module": dep.module, + } + ) + + # Find dependencies 
shared by multiple entrypoints + all_deps = set() + for deps in entrypoint_deps.values(): + all_deps.update(deps) + + for dep in all_deps: + sharing_entrypoints = [ + name for name, deps in entrypoint_deps.items() if dep in deps + ] + if len(sharing_entrypoints) > 1: + dependency_analysis["shared_dependencies"].append( + { + "dependency": dep, + "shared_by": sharing_entrypoints, + "share_count": len(sharing_entrypoints), + } + ) + + # Find isolated entrypoints + for name, deps in entrypoint_deps.items(): + shared_deps = [ + dep + for dep in deps + if any( + dep in other_deps + for other_name, other_deps in entrypoint_deps.items() + if other_name != name + ) + ] + if len(shared_deps) == 0: + dependency_analysis["isolated_entrypoints"].append(name) + + return dependency_analysis + + def _analyze_entrypoint_call_flows( + self, all_entrypoints: List[Dict[str, Any]], codebase: Codebase + ) -> Dict[str, Any]: + """Analyze call flows from entrypoints""" + call_flow_analysis = { + "max_call_depth": 0, + "average_call_depth": 0.0, + "call_patterns": [], + "recursive_calls": [], + } + + call_depths = [] + + for entrypoint in all_entrypoints: + func = next( + (f for f in codebase.functions if f.name == entrypoint["name"]), None + ) + if func: + # Calculate call depth using BFS + depth = self._calculate_call_depth(func, codebase) + call_depths.append(depth) + + # Check for recursive calls + if self._has_recursive_calls(func): + call_flow_analysis["recursive_calls"].append( + {"entrypoint": entrypoint["name"], "file": entrypoint["file"]} + ) + + if call_depths: + call_flow_analysis["max_call_depth"] = max(call_depths) + call_flow_analysis["average_call_depth"] = sum(call_depths) / len( + call_depths + ) + + return call_flow_analysis + + def _calculate_call_depth( + self, func: Function, codebase: Codebase, visited=None, depth=0 + ) -> int: + """Calculate maximum call depth from a function""" + if visited is None: + visited = set() + + if func in visited or depth > 20: # Prevent infinite recursion + return depth + + visited.add(func) + max_depth = depth + + for call in func.function_calls: + if hasattr(call, "function_definition") and call.function_definition: + called_func = call.function_definition + if not isinstance(called_func, ExternalModule): + call_depth = self._calculate_call_depth( + called_func, codebase, visited.copy(), depth + 1 + ) + max_depth = max(max_depth, call_depth) + + return max_depth + + def _has_recursive_calls(self, func: Function) -> bool: + """Check if function has recursive calls""" + for call in func.function_calls: + if hasattr(call, "function_definition") and call.function_definition: + if call.function_definition.name == func.name: + return True + return False + + def _calculate_class_entrypoint_score(self, cls: Class) -> float: + """Calculate entrypoint score for classes""" + score = 0.0 + + # Base score for being a class + score += 1.0 + + # Score based on name patterns + entrypoint_names = ["app", "application", "server", "client", "main", "service"] + if any(name in cls.name.lower() for name in entrypoint_names): + score += 2.0 + + # Score based on methods + if len(cls.methods) > 10: # Large classes often coordinate functionality + score += 1.0 + + # Score based on inheritance + if len(cls.superclasses) > 0: + for superclass in cls.superclasses: + if any( + pattern in superclass.name.lower() + for pattern in ["application", "service", "handler"] + ): + score += 1.5 + + # Score based on singleton patterns + if any("instance" in method.name.lower() for method in 
cls.methods): + score += 1.0 + + return score + + def _classify_test_type(self, func: Function) -> str: + """Classify the type of test function""" + if "unit" in func.filepath or "unit" in func.name.lower(): + return "unit" + elif "integration" in func.filepath or "integration" in func.name.lower(): + return "integration" + elif "e2e" in func.filepath or "end_to_end" in func.name.lower(): + return "end_to_end" + elif "performance" in func.filepath or "perf" in func.name.lower(): + return "performance" + else: + return "unknown" + + def _build_tree_structure_from_graph_sitter( + self, codebase: Codebase, all_lsp_diagnostics: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """Build hierarchical tree structure from graph-sitter codebase with integrated LSP errors""" + root = { + "name": "root", + "type": "directory", + "path": "", + "children": [], + "errors": {"critical": 0, "major": 0, "minor": 0}, + "isEntrypoint": False, + "metrics": { + "complexity_score": 0, + "maintainability_index": 0, + "lines_of_code": 0, + }, + } + + # Group files by directory + dir_structure = defaultdict(lambda: {"files": [], "subdirs": {}}) + + for file_obj in ( + codebase.files + ): # Changed from file to file_obj to avoid conflict with file.filepath + try: + # Get LSP diagnostics for this file + file_lsp_diagnostics = [ + d + for d in all_lsp_diagnostics + if d["relative_file_path"] == file_obj.filepath + ] + + file_node = { + "name": file_obj.name, + "type": "file", + "path": file_obj.filepath, + "children": [], + "errors": self._detect_file_errors_graph_sitter( + file_obj, file_lsp_diagnostics + ), # Pass LSP diagnostics + "isEntrypoint": self._is_entrypoint_file(file_obj), + "metrics": { + "lines": len(file_obj.source.splitlines()) + if hasattr(file_obj, "source") + else 0, + "functions": len(list(file_obj.functions)), + "classes": len(list(file_obj.classes)), + "imports": len(list(file_obj.imports)), + "symbols": len(list(getattr(file_obj, "symbols", []))), + "variable_usages": len( + list(getattr(file_obj, "variable_usages", [])) + ), + "symbol_usages": len( + list(getattr(file_obj, "symbol_usages", [])) + ), + "complexity_score": self._calculate_file_complexity(file_obj), + "maintainability_index": self._calculate_maintainability_index( + file_obj + ), + }, + } + + # Add functions as children with comprehensive metrics + for func in file_obj.functions: + try: + func_node = { + "name": func.name, + "type": "function", + "path": f"{file_obj.filepath}::{func.name}", + "children": [], + "errors": self._detect_function_errors_graph_sitter( + func, file_lsp_diagnostics + ), # Pass LSP diagnostics + "isEntrypoint": self._is_entrypoint_function(func), + "metrics": { + "parameters": self._get_function_parameters_details( + func + ), # Detailed parameters + "return_type": self._get_function_return_type_details( + func + ), # Detailed return type + "local_variables": self._get_function_local_variables_details( + func + ), # Local variables + "usages": len(list(func.usages)), + "call_sites": len( + list(func.function_calls) + ), # Changed from call_sites + "dependencies": len(list(func.dependencies)), + "return_statements_count": len( + list(getattr(func, "return_statements", [])) + ), + "variable_usages_count": len( + list(getattr(func, "variable_usages", [])) + ), + "symbol_usages_count": len( + list(getattr(func, "symbol_usages", [])) + ), + "parent_class": getattr(func, "parent_class", None) + is not None, + "parent_function": getattr( + func, "parent_function", None + ) + is not None, + "type_parameters": len( + 
list(getattr(func, "type_parameters", [])) + ), + "complexity_score": self._calculate_function_complexity( + func + ), + "halstead_volume": calculate_halstead_volume( + *get_operators_and_operands(func) + )[0], # Only volume + }, + } + file_node["children"].append(func_node) + except Exception as e: + logger.warning(f"Error processing function {func.name}: {e}") + + # Add classes as children with comprehensive metrics + for cls in file_obj.classes: + try: + class_node = { + "name": cls.name, + "type": "class", + "path": f"{file_obj.filepath}::{cls.name}", + "children": [], + "errors": self._detect_class_errors_graph_sitter( + cls, file_lsp_diagnostics + ), # Pass LSP diagnostics + "isEntrypoint": self._is_entrypoint_class(cls), + "metrics": { + "methods": len(list(cls.methods)), + "attributes": len(list(cls.attributes)), + "usages": len(list(cls.usages)), + "superclasses": len(list(cls.superclasses)), + "subclasses": len(list(cls.subclasses)), + "dependencies": len(list(cls.dependencies)), + "symbol_type": getattr(cls, "symbol_type", "class"), + "parent": getattr(cls, "parent", None) is not None, + "resolved_value": getattr(cls, "resolved_value", None) + is not None, + "inheritance_depth": calculate_doi(cls), + "complexity_score": self._calculate_class_complexity( + cls + ), + }, + } + + # Add methods as children with enhanced metrics + for method in cls.methods: + try: + method_node = { + "name": method.name, + "type": "method", + "path": f"{file_obj.filepath}::{cls.name}::{method.name}", + "children": [], + "errors": self._detect_function_errors_graph_sitter( + method, file_lsp_diagnostics + ), + "isEntrypoint": False, + "metrics": { + "parameters": self._get_function_parameters_details( + method + ), + "return_type": self._get_function_return_type_details( + method + ), + "local_variables": self._get_function_local_variables_details( + method + ), + "usages": len(list(method.usages)), + "parent_class": cls.name, + "parent_function": getattr( + method, "parent_function", None + ) + is not None, + "variable_usages_count": len( + list(getattr(method, "variable_usages", [])) + ), + "symbol_usages_count": len( + list(getattr(method, "symbol_usages", [])) + ), + "parent_statement": getattr( + method, "parent_statement", None + ) + is not None, + "complexity_score": self._calculate_function_complexity( + method + ), + }, + } + class_node["children"].append(method_node) + except Exception as e: + logger.warning( + f"Error processing method {method.name}: {e}" + ) + + file_node["children"].append(class_node) + except Exception as e: + logger.warning(f"Error processing class {cls.name}: {e}") + + # Add to directory structure + path_parts = file_obj.filepath.split(os.sep) + current_dir_level = dir_structure + for part in path_parts[:-1]: + current_dir_level = current_dir_level[part]["subdirs"] + current_dir_level[path_parts[-1]] = { + "files": [file_node], + "subdirs": {}, + } # Store file node under its name + + except Exception as e: + logger.warning(f"Error processing file {file_obj.filepath}: {e}") + + # Convert to hierarchical structure + root["children"] = self._build_directory_nodes_recursive(dir_structure, "") + return root + + def _build_directory_nodes_recursive( + self, dir_structure: Dict, current_path: str + ) -> List[Dict]: + """Recursively build directory nodes from the processed dir_structure.""" + nodes = [] + + for name, content in dir_structure.items(): + if "files" in content and content["files"]: + # This is a file node already processed + nodes.extend(content["files"]) + 
else: + # This is a directory node + dir_node = { + "name": name, + "type": "directory", + "path": os.path.join(current_path, name).replace( + "\\", "/" + ), # Ensure Unix-like paths + "children": [], + "errors": {"critical": 0, "major": 0, "minor": 0}, + "isEntrypoint": False, + "metrics": { + "total_files": 0, + "total_functions": 0, + "total_classes": 0, + "total_lines": 0, + }, + } + + # Add subdirectories and files + dir_node["children"].extend( + self._build_directory_nodes_recursive( + content["subdirs"], dir_node["path"] + ) + ) + + # Aggregate errors and metrics from children + for child in dir_node["children"]: + for severity in ["critical", "major", "minor"]: + dir_node["errors"][severity] += child["errors"][severity] + + if child["type"] == "file": + dir_node["metrics"]["total_files"] += 1 + dir_node["metrics"]["total_functions"] += child["metrics"].get( + "functions", 0 + ) + dir_node["metrics"]["total_classes"] += child["metrics"].get( + "classes", 0 + ) + dir_node["metrics"]["total_lines"] += child["metrics"].get( + "lines", 0 + ) + elif ( + child["type"] == "directory" + ): # Aggregate from sub-directories too + dir_node["metrics"]["total_files"] += child["metrics"].get( + "total_files", 0 + ) + dir_node["metrics"]["total_functions"] += child["metrics"].get( + "total_functions", 0 + ) + dir_node["metrics"]["total_classes"] += child["metrics"].get( + "total_classes", 0 + ) + dir_node["metrics"]["total_lines"] += child["metrics"].get( + "total_lines", 0 + ) + + nodes.append(dir_node) + + return nodes + + def _build_dependency_graph_from_graph_sitter( + self, codebase: Codebase + ) -> Dict[str, Any]: + """Build dependency graph from graph-sitter codebase""" + dependency_graph = { + "nodes": 0, + "edges": 0, + "cycles": 0, + "strongly_connected_components": 0, + "max_depth": 0, + "file_dependencies": {}, + "symbol_dependencies": {}, + "import_graph": {}, + } + + # Build file dependency graph + file_graph = nx.DiGraph() + for file_obj in codebase.files: + file_graph.add_node(file_obj.filepath) + for imp in file_obj.imports: + if hasattr(imp, "from_file") and imp.from_file: + file_graph.add_edge(file_obj.filepath, imp.from_file.filepath) + + dependency_graph["file_dependencies"] = { + "nodes": len(file_graph.nodes), + "edges": len(file_graph.edges), + "cycles": len(list(nx.simple_cycles(file_graph))), + "strongly_connected_components": len( + list(nx.strongly_connected_components(file_graph)) + ), + } + + # Build symbol dependency graph + symbol_graph = nx.DiGraph() + for symbol in codebase.symbols: + symbol_graph.add_node(symbol.name) + for dep in symbol.dependencies: + if hasattr(dep, "name") and not isinstance( + dep, ExternalModule + ): # Exclude external modules from internal symbol graph + symbol_graph.add_edge(symbol.name, dep.name) + + dependency_graph["symbol_dependencies"] = { + "nodes": len(symbol_graph.nodes), + "edges": len(symbol_graph.edges), + "cycles": len(list(nx.simple_cycles(symbol_graph))), + "max_depth": len(nx.dag_longest_path(symbol_graph)) + if nx.is_directed_acyclic_graph(symbol_graph) + else 0, + } + + return dependency_graph + + def _calculate_code_quality_metrics(self, codebase: Codebase) -> Dict[str, Any]: + """Calculate comprehensive code quality metrics""" + metrics = { + "complexity_score": 0.0, + "maintainability_index": 0.0, + "technical_debt_ratio": 0.0, + "test_coverage_estimate": 0.0, + "documentation_coverage": 0.0, + "code_duplication_score": 0.0, + "type_coverage": 0.0, + "function_metrics": {}, + "class_metrics": {}, + "file_metrics": {}, 
+ } + + total_functions = len(list(codebase.functions)) + total_classes = len(list(codebase.classes)) + total_files = len(list(codebase.files)) + + if total_functions == 0: + return metrics + + # Calculate function metrics + function_complexities = [] + documented_functions = 0 + typed_functions = 0 + + for func in codebase.functions: + complexity = self._calculate_function_complexity(func) + function_complexities.append(complexity) + + if hasattr(func, "docstring") and func.docstring: + documented_functions += 1 + + if hasattr(func, "return_type") and func.return_type: + typed_functions += 1 + + metrics["complexity_score"] = sum(function_complexities) / len( + function_complexities + ) + metrics["documentation_coverage"] = documented_functions / total_functions + metrics["type_coverage"] = typed_functions / total_functions + + # Calculate maintainability index + avg_complexity = metrics["complexity_score"] + avg_loc = ( + sum( + len(f.source.splitlines()) + for f in codebase.functions + if hasattr(f, "source") + ) + / total_functions + ) + + # Simplified maintainability index calculation + metrics["maintainability_index"] = max( + 0, + ( + 171 + - 5.2 * math.log(avg_loc) + - 0.23 * avg_complexity + - 16.2 * math.log(avg_loc) + ) + * 100 + / 171, + ) + + # Estimate test coverage + test_functions = len( + [f for f in codebase.functions if self._is_test_function_enhanced(f)] + ) + metrics["test_coverage_estimate"] = min( + 1.0, test_functions / max(1, total_functions - test_functions) + ) + + # Calculate technical debt ratio (simplified) + high_complexity_functions = len([c for c in function_complexities if c > 10]) + undocumented_functions = total_functions - documented_functions + untyped_functions = total_functions - typed_functions + + debt_score = ( + high_complexity_functions + undocumented_functions + untyped_functions + ) / (total_functions * 3) + metrics["technical_debt_ratio"] = debt_score + + return metrics + + def _analyze_architectural_patterns(self, codebase: Codebase) -> Dict[str, Any]: + """Analyze architectural patterns in the codebase""" + patterns = { + "mvc_pattern": False, + "repository_pattern": False, + "factory_pattern": False, + "singleton_pattern": False, + "observer_pattern": False, + "decorator_pattern": False, + "strategy_pattern": False, + "layered_architecture": False, + "microservices_indicators": [], + "design_pattern_usage": {}, + } + + # Detect MVC pattern + mvc_indicators = ["controller", "model", "view"] + mvc_count = sum( + 1 + for cls in codebase.classes + if any(indicator in cls.name.lower() for indicator in mvc_indicators) + ) + patterns["mvc_pattern"] = mvc_count >= 3 + + # Detect Repository pattern + repo_indicators = ["repository", "repo"] + patterns["repository_pattern"] = any( + indicator in cls.name.lower() + for cls in codebase.classes + for indicator in repo_indicators + ) + + # Detect Factory pattern + factory_indicators = ["factory", "builder", "creator"] + patterns["factory_pattern"] = any( + indicator in cls.name.lower() + for cls in codebase.classes + for indicator in factory_indicators + ) + + # Detect Singleton pattern + for cls in codebase.classes: + if any("singleton" in method.name.lower() for method in cls.methods): + patterns["singleton_pattern"] = True + break + + # Detect microservices indicators + microservice_indicators = ["service", "api", "endpoint", "handler"] + for indicator in microservice_indicators: + count = sum(1 for cls in codebase.classes if indicator in cls.name.lower()) + if count > 0: + 
patterns["microservices_indicators"].append( + {"pattern": indicator, "count": count} + ) + + return patterns + + def _analyze_security_patterns(self, codebase: Codebase) -> Dict[str, Any]: + """Analyze security patterns and potential issues""" + security = { + "potential_vulnerabilities": [], + "security_patterns": [], + "authentication_usage": False, + "encryption_usage": False, + "input_validation": False, + "sql_injection_risks": [], + "xss_risks": [], + "hardcoded_secrets": [], + } + + # Check for authentication patterns + auth_patterns = ["authenticate", "login", "auth", "token", "jwt"] + for func in codebase.functions: + if any(pattern in func.name.lower() for pattern in auth_patterns): + security["authentication_usage"] = True + break + + # Check for encryption usage + crypto_patterns = ["encrypt", "decrypt", "hash", "crypto", "ssl", "tls"] + for func in codebase.functions: + if any(pattern in func.name.lower() for pattern in crypto_patterns): + security["encryption_usage"] = True + break + + # Check for potential SQL injection risks + for func in codebase.functions: + if hasattr(func, "source") and func.source: + if "execute(" in func.source and "%" in func.source: + security["sql_injection_risks"].append( + { + "function": func.name, + "file": func.filepath, + "risk": "Potential SQL injection via string formatting", + } + ) + + # Check for hardcoded secrets (simplified) + secret_patterns = ["password", "secret", "key", "token"] + for file_obj in codebase.files: + if hasattr(file_obj, "source") and file_obj.source: + for pattern in secret_patterns: + if ( + f'{pattern} = "' in file_obj.source.lower() + or f'{pattern}="' in file_obj.source.lower() + ): + security["hardcoded_secrets"].append( + { + "file": file_obj.filepath, + "pattern": pattern, + "risk": "Potential hardcoded secret", + } + ) + + return security + + def _analyze_performance_patterns(self, codebase: Codebase) -> Dict[str, Any]: + """Analyze performance patterns and potential issues""" + performance = { + "potential_bottlenecks": [], + "async_usage": 0, + "database_queries": 0, + "loop_complexity": [], + "memory_usage_patterns": [], + "caching_usage": False, + "optimization_opportunities": [], + } + + # Count async functions + performance["async_usage"] = sum( + 1 for func in codebase.functions if getattr(func, "is_async", False) + ) + + # Check for database query patterns + db_patterns = ["query", "select", "insert", "update", "delete", "execute"] + for func in codebase.functions: + if hasattr(func, "source") and func.source: + if any(pattern in func.source.lower() for pattern in db_patterns): + performance["database_queries"] += 1 + + # Check for caching usage + cache_patterns = ["cache", "memoize", "redis", "memcached"] + for func in codebase.functions: + if any(pattern in func.name.lower() for pattern in cache_patterns): + performance["caching_usage"] = True + break + + # Identify potential optimization opportunities + for func in codebase.functions: + complexity = self._calculate_function_complexity(func) + if complexity > 15: + performance["optimization_opportunities"].append( + { + "function": func.name, + "file": func.filepath, + "complexity": complexity, + "suggestion": "Consider breaking down complex function", + } + ) + + return performance + + # ============================================================================ + # HELPER FUNCTIONS FOR ANALYSIS + # ============================================================================ + + def _detect_file_errors_graph_sitter( + self, file_obj: SourceFile, 
lsp_diagnostics: List[Dict[str, Any]] + ) -> Dict[str, int]: + """Detect errors in a file using graph-sitter and LSP diagnostics""" + errors = {"critical": 0, "major": 0, "minor": 0} + + # Add LSP diagnostics counts + for enhanced_diag in lsp_diagnostics: + diag = enhanced_diag["diagnostic"] + if diag.severity: + if diag.severity.name.lower() == "error": + errors["critical"] += 1 + elif diag.severity.name.lower() == "warning": + errors["major"] += 1 + else: # Info, Hint, Unknown + errors["minor"] += 1 + + try: + # Check for syntax errors (simplified) + if hasattr(file_obj, "source") and file_obj.source: + try: + ast.parse(file_obj.source) + except SyntaxError: + errors["critical"] += 1 + + # Check for import issues + for imp in file_obj.imports: + if not hasattr(imp, "resolved_symbol") or not imp.resolved_symbol: + errors["minor"] += 1 + + # Check for long files + if hasattr(file_obj, "source") and len(file_obj.source.splitlines()) > 1000: + errors["major"] += 1 + + except Exception as e: + logger.warning(f"Error detecting file errors for {file_obj.filepath}: {e}") + + return errors + + def _detect_function_errors_graph_sitter( + self, func: Function, lsp_diagnostics: List[Dict[str, Any]] + ) -> Dict[str, int]: + """Detect errors in a function using graph-sitter and LSP diagnostics""" + errors = {"critical": 0, "major": 0, "minor": 0} + + # Add LSP diagnostics counts relevant to this function's range + func_start_line = func.start_point.line if hasattr(func, "start_point") else -1 + func_end_line = func.end_point.line if hasattr(func, "end_point") else -1 + + for enhanced_diag in lsp_diagnostics: + diag = enhanced_diag["diagnostic"] + if diag.range.line >= func_start_line and diag.range.line <= func_end_line: + if diag.severity: + if diag.severity.name.lower() == "error": + errors["critical"] += 1 + elif diag.severity.name.lower() == "warning": + errors["major"] += 1 + else: # Info, Hint, Unknown + errors["minor"] += 1 + + try: + # Check complexity (not required to be presented as retrievable output) + # complexity = self._calculate_function_complexity(func) + # if complexity > 20: + # errors["critical"] += 1 + # elif complexity > 10: + # errors["major"] += 1 + + # Check for missing docstring + if not hasattr(func, "docstring") or not func.docstring: + errors["minor"] += 1 + + # Check for missing type annotations + if not hasattr(func, "return_type") or not func.return_type: + errors["minor"] += 1 + + except Exception as e: + logger.warning(f"Error detecting function errors for {func.name}: {e}") + + return errors + + def _detect_class_errors_graph_sitter( + self, cls: Class, lsp_diagnostics: List[Dict[str, Any]] + ) -> Dict[str, int]: + """Detect errors in a class using graph-sitter and LSP diagnostics""" + errors = {"critical": 0, "major": 0, "minor": 0} + + # Add LSP diagnostics counts relevant to this class's range + cls_start_line = cls.start_point.line if hasattr(cls, "start_point") else -1 + cls_end_line = cls.end_point.line if hasattr(cls, "end_point") else -1 + + for enhanced_diag in lsp_diagnostics: + diag = enhanced_diag["diagnostic"] + if diag.range.line >= cls_start_line and diag.range.line <= cls_end_line: + if diag.severity: + if diag.severity.name.lower() == "error": + errors["critical"] += 1 + elif diag.severity.name.lower() == "warning": + errors["major"] += 1 + else: # Info, Hint, Unknown + errors["minor"] += 1 + + try: + # Check for too many methods + if len(list(cls.methods)) > 20: + errors["major"] += 1 + + # Check inheritance depth + if calculate_doi(cls) > 5: + 
errors["major"] += 1 + + # Check for missing docstring + if not hasattr(cls, "docstring") or not cls.docstring: + errors["minor"] += 1 + + except Exception as e: + logger.warning(f"Error detecting class errors for {cls.name}: {e}") + + return errors + + def _is_entrypoint_file(self, file_obj: SourceFile) -> bool: + """Check if a file is an entrypoint""" + entrypoint_patterns = [ + "main.py", + "__main__.py", + "app.py", + "server.py", + "run.py", + "cli.py", + ] + return any(pattern in file_obj.filepath for pattern in entrypoint_patterns) + + def _is_entrypoint_function(self, func: Function) -> bool: + """Check if a function is an entrypoint""" + entrypoint_patterns = ["main", "run", "start", "execute", "cli", "app"] + return ( + any(pattern in func.name.lower() for pattern in entrypoint_patterns) + or func.name == "__main__" + ) + + def _is_entrypoint_class(self, cls: Class) -> bool: + """Check if a class is an entrypoint""" + entrypoint_patterns = [ + "app", + "application", + "server", + "client", + "main", + "runner", + ] + return any(pattern in cls.name.lower() for pattern in entrypoint_patterns) + + def _is_test_function(self, func: Function) -> bool: + """Check if a function is a test function""" + return func.name.startswith("test_") or "test" in func.filepath + + def _is_test_class(self, cls: Class) -> bool: + """Check if a class is a test class""" + return cls.name.startswith("Test") or "test" in cls.filepath + + def _is_special_function(self, func: Function) -> bool: + """Check if a function is a special function that shouldn't be considered dead code""" + special_patterns = [ + "__init__", + "__str__", + "__repr__", + "__call__", + "setUp", + "tearDown", + ] + return any(pattern in func.name for pattern in special_patterns) + + def _calculate_file_complexity(self, file_obj: SourceFile) -> float: + """Calculate complexity score for a file""" + try: + if not hasattr(file_obj, "functions"): + return 0.0 + + function_complexities = [ + self._calculate_function_complexity(func) for func in file_obj.functions + ] + return sum(function_complexities) / max(1, len(function_complexities)) + except Exception: + return 0.0 + + def _calculate_function_complexity(self, func: Function) -> int: + """Calculate cyclomatic complexity for a function""" + try: + complexity = 1 # Base complexity + + if hasattr(func, "source") and func.source: + source = func.source.lower() + # Count decision points + complexity += source.count("if ") + complexity += source.count("elif ") + complexity += source.count("for ") + complexity += source.count("while ") + complexity += source.count("except ") + complexity += source.count("and ") + complexity += source.count("or ") + complexity += source.count("try:") + + return complexity + except Exception: + return 1 + + def _calculate_class_complexity(self, cls: Class) -> float: + """Calculate complexity score for a class""" + try: + if not hasattr(cls, "methods"): + return 0.0 + + method_complexities = [ + self._calculate_function_complexity(method) for method in cls.methods + ] + return sum(method_complexities) / max(1, len(method_complexities)) + except Exception: + return 0.0 + + def _calculate_maintainability_index(self, file_obj: SourceFile) -> float: + """Calculate maintainability index for a file""" + try: + if not hasattr(file_obj, "source"): + return 0.0 + + loc = len(file_obj.source.splitlines()) + complexity = self._calculate_file_complexity(file_obj) + + # Simplified maintainability index + if loc > 0: + return max( + 0, + ( + 171 + - 5.2 * math.log(loc) + - 
0.23 * complexity + - 16.2 * math.log(loc) + ) + * 100 + / 171, + ) + return 0.0 + except Exception: + return 0.0 + + +# ============================================================================ +# VISUALIZATION ENGINE +# ============================================================================ + + +class EnhancedVisualizationEngine: # Changed from VisualizationEngine + """Enhanced visualization engine with dynamic target selection and scope control""" + + def __init__(self, codebase: Codebase): + self.codebase = codebase + + def create_dynamic_dependency_graph( + self, + target_type: str, + target_name: str, + scope: str = "codebase", + max_depth: int = 10, + include_external: bool = False, + ) -> Dict[str, Any]: + """Create dependency graph with dynamic target and scope selection""" + graph = nx.DiGraph() + + # Get the target symbol + target_symbol = self._get_target_symbol(target_type, target_name) + if not target_symbol: + raise ValueError(f"{target_type} '{target_name}' not found") + + # Get scope symbols + scope_symbols = self._get_scope_symbols(scope, target_symbol) + + # Build dependency graph within scope + self._build_scoped_dependency_graph( + graph, target_symbol, scope_symbols, max_depth, include_external + ) + + return self._serialize_enhanced_graph(graph, target_symbol, scope) + + def create_class_hierarchy_graph( + self, target_class: Optional[str] = None, include_methods: bool = True + ) -> Dict[str, Any]: + """Create class hierarchy visualization""" + graph = nx.DiGraph() + + if target_class: + start_class = self.codebase.get_class(target_class) + if not start_class: + raise ValueError(f"Class '{target_class}' not found") + self._build_class_hierarchy_subgraph(graph, start_class, include_methods) + else: + # Build full class hierarchy + for cls in self.codebase.classes: + self._build_class_hierarchy_subgraph( + graph, cls, include_methods + ) # Changed from _add_class_hierarchy_node + + return self._serialize_enhanced_graph(graph, target_class, "class_hierarchy") + + def create_module_dependency_graph( + self, target_module: str, max_depth: int = 10 + ) -> Dict[str, Any]: + """Create module-level dependency visualization""" + graph = nx.DiGraph() + + # Get all files in the target module + module_files = [f for f in self.codebase.files if target_module in f.filepath] + if not module_files: + raise ValueError(f"Module '{target_module}' not found") + + # Build module dependency graph + self._build_module_dependency_subgraph(graph, module_files, max_depth) + + return self._serialize_enhanced_graph(graph, target_module, "module") + + def create_function_call_trace( + self, entry_function: str, target_function: str, max_depth: int = 15 + ) -> Dict[str, Any]: + """Create call trace from entry function to target function""" + graph = nx.DiGraph() + + start_func = self.codebase.get_function(entry_function) + end_func = self.codebase.get_function(target_function) + + if not start_func: + raise ValueError(f"Entry function '{entry_function}' not found") + if not end_func: + raise ValueError(f"Target function '{target_function}' not found") + + # Build call trace + self._build_call_trace_subgraph(graph, start_func, end_func, max_depth) + + return self._serialize_enhanced_graph( + graph, f"{entry_function} -> {target_function}", "call_trace" + ) + + def create_data_flow_graph( + self, entry_point: str, max_depth: int = 10 + ) -> Dict[str, Any]: + """Create data flow visualization""" + graph = nx.DiGraph() + start_symbol = self.codebase.get_symbol(entry_point) + + if not 
start_symbol: + raise ValueError(f"Symbol '{entry_point}' not found") + + self._build_data_flow_subgraph(graph, start_symbol, max_depth) + return self._serialize_enhanced_graph(graph, entry_point, "data_flow") + + def create_blast_radius_graph( + self, entry_point: str, max_depth: int = 10 + ) -> Dict[str, Any]: + """Create blast radius visualization showing impact of changes""" + graph = nx.DiGraph() + start_symbol = self.codebase.get_symbol(entry_point) + + if not start_symbol: + raise ValueError(f"Symbol '{entry_point}' not found") + + self._build_blast_radius_subgraph(graph, start_symbol, max_depth) + return self._serialize_enhanced_graph(graph, entry_point, "blast_radius") + + def _get_target_symbol(self, target_type: str, target_name: str): + """Get target symbol based on type""" + if target_type == "function": + return self.codebase.get_function(target_name) + elif target_type == "class": + return self.codebase.get_class(target_name) + elif target_type == "file": + return self.codebase.get_file(target_name) + elif target_type == "symbol": + return self.codebase.get_symbol(target_name) + else: + raise ValueError(f"Unknown target type: {target_type}") + + def _get_scope_symbols(self, scope: str, target_symbol: Symbol) -> List[Symbol]: + """Get all symbols within the specified scope.""" + if scope == "codebase": + return list(self.codebase.symbols) + elif scope == "module": + if hasattr(target_symbol, "file"): + module_path = os.path.dirname(target_symbol.file.filepath) + return [ + s + for s in self.codebase.symbols + if os.path.dirname(s.filepath) == module_path + ] + return [] + elif scope == "file": + if hasattr(target_symbol, "file"): + return list(target_symbol.file.symbols) + return [] + elif scope == "class": + if hasattr(target_symbol, "parent_class") and target_symbol.parent_class: + return list(target_symbol.parent_class.methods) + list( + target_symbol.parent_class.attributes + ) + return [] + elif scope == "function": + if ( + hasattr(target_symbol, "parent_function") + and target_symbol.parent_function + ): + return list( + target_symbol.parent_function.code_block.local_var_assignments + ) + return [] + else: + raise ValueError(f"Unknown scope type: {scope}") + + def _build_scoped_dependency_graph( + self, + graph: nx.DiGraph, + symbol: Symbol, + scope_symbols: List[Symbol], + max_depth: int, + include_external: bool, + depth: int = 0, + ): + """Build dependency graph within specified scope recursively""" + if depth >= max_depth or symbol in graph: # Avoid cycles and depth limit + return + + graph.add_node( + symbol.name, + type=type(symbol).__name__, + file=symbol.filepath if hasattr(symbol, "filepath") else None, + is_target=True if depth == 0 else False, + ) + + for dep in symbol.dependencies: + if not include_external and isinstance(dep, ExternalModule): + continue + + # Only include dependencies within scope or if external is allowed + if dep in scope_symbols or include_external: + graph.add_node( + dep.name, + type=type(dep).__name__, + file=dep.filepath if hasattr(dep, "filepath") else None, + in_scope=dep in scope_symbols, + ) + graph.add_edge(symbol.name, dep.name, relationship="depends_on") + + if dep in scope_symbols: # Only recurse if dependency is within scope + self._build_scoped_dependency_graph( + graph, + dep, + scope_symbols, + max_depth, + include_external, + depth + 1, + ) + + def _build_class_hierarchy_subgraph( + self, graph: nx.DiGraph, cls: Class, include_methods: bool, depth: int = 0 + ): + """Build class hierarchy subgraph recursively""" + if 
depth > 10 or cls in graph: # Prevent infinite recursion and depth limit + return + + # Add class node + graph.add_node( + cls.name, + type="class", + file=cls.filepath, + methods=len(cls.methods), + attributes=len(cls.attributes), + inheritance_depth=calculate_doi(cls), + ) + + # Add superclass relationships + for superclass in cls.superclasses: + if not isinstance(superclass, ExternalModule): + graph.add_node( + superclass.name, + type="class", + file=superclass.filepath, + methods=len(superclass.methods), + attributes=len(superclass.attributes), + ) + graph.add_edge(cls.name, superclass.name, relationship="inherits_from") + self._build_class_hierarchy_subgraph( + graph, superclass, include_methods, depth + 1 + ) + + # Add subclass relationships + for subclass in cls.subclasses: + if not isinstance(subclass, ExternalModule): + graph.add_node( + subclass.name, + type="class", + file=subclass.filepath, + methods=len(subclass.methods), + attributes=len(subclass.attributes), + ) + graph.add_edge(subclass.name, cls.name, relationship="inherits_from") + self._build_class_hierarchy_subgraph( + graph, subclass, include_methods, depth + 1 + ) + + # Add methods if requested + if include_methods: + for method in cls.methods: + graph.add_node( + f"{cls.name}.{method.name}", + type="method", + file=cls.filepath, + complexity=self._calculate_function_complexity(method), + parameters=len(method.parameters), + ) + graph.add_edge( + cls.name, f"{cls.name}.{method.name}", relationship="contains" + ) + + def _build_module_dependency_subgraph( + self, graph: nx.DiGraph, module_files: List[SourceFile], max_depth: int + ): + """Build module dependency subgraph recursively""" + for file_obj in module_files: + if file_obj in graph: # Avoid cycles + continue + graph.add_node( + file_obj.filepath, + type="file", + functions=len(file_obj.functions), + classes=len(file_obj.classes), + lines=len(file_obj.source.splitlines()) + if hasattr(file_obj, "source") + else 0, + ) + + # Add import relationships + for imp in file_obj.imports: + if hasattr(imp, "from_file") and imp.from_file: + target_file = imp.from_file + if target_file not in graph: # Add target file if not already added + graph.add_node( + target_file.filepath, + type="file", + functions=len(target_file.functions), + classes=len(target_file.classes), + lines=len(target_file.source.splitlines()) + if hasattr(target_file, "source") + else 0, + ) + graph.add_edge( + file_obj.filepath, + target_file.filepath, + relationship="imports_from", + import_count=1, + ) + # Recurse into imported modules if within depth + if max_depth > 1: + self._build_module_dependency_subgraph( + graph, [target_file], max_depth - 1 + ) + + def _build_call_trace_subgraph( + self, + graph: nx.DiGraph, + start_func: Function, + end_func: Function, + max_depth: int, + ): + """Build call trace from start to end function recursively""" + visited = set() + + def trace_calls(func: Function, target: Function, depth=0): + if depth >= max_depth or func in visited: + return False + + visited.add(func) + graph.add_node( + func.name, + type="function", + file=func.filepath, + complexity=self._calculate_function_complexity(func), + depth=depth, + ) + + if func == target: + return True + + for call in func.function_calls: + if hasattr(call, "function_definition") and call.function_definition: + called_func = call.function_definition + if not isinstance(called_func, ExternalModule): + graph.add_edge( + func.name, called_func.name, relationship="calls" + ) + if trace_calls(called_func, target, depth + 
1): + return True + + return False + + trace_calls(start_func, end_func) + + def _build_data_flow_subgraph( + self, graph: nx.DiGraph, symbol: Symbol, max_depth: int, depth: int = 0 + ): + """Build data flow subgraph recursively""" + if depth >= max_depth or symbol in graph: # Avoid cycles and depth limit + return + + graph.add_node( + symbol.name, + type=type(symbol).__name__, + file=symbol.filepath if hasattr(symbol, "filepath") else None, + ) + + # Track variable usages and assignments + if hasattr(symbol, "variable_usages"): + for usage in symbol.variable_usages: + if hasattr(usage, "name"): + graph.add_node( + usage.name, + type="variable", + file=usage.file.filepath if hasattr(usage, "file") else None, + ) + graph.add_edge(symbol.name, usage.name, relationship="uses_data") + # Recurse into variable definition if available + if ( + hasattr(usage, "resolved_symbol") + and usage.resolved_symbol + and usage.resolved_symbol not in graph + ): + self._build_data_flow_subgraph( + graph, usage.resolved_symbol, max_depth, depth + 1 + ) + + if hasattr(symbol, "assignments"): # For symbols that are assigned values + for assignment in symbol.assignments: + if hasattr(assignment, "name"): + graph.add_node( + assignment.name, + type="assignment", + file=assignment.file.filepath + if hasattr(assignment, "file") + else None, + ) + graph.add_edge( + assignment.name, symbol.name, relationship="assigns_to" + ) + # Recurse into assigned value's dependencies + if hasattr(assignment, "value") and hasattr( + assignment.value, "dependencies" + ): + for dep in assignment.value.dependencies: + if dep not in graph: + self._build_data_flow_subgraph( + graph, dep, max_depth, depth + 1 + ) + + def _build_blast_radius_subgraph( + self, graph: nx.DiGraph, symbol: Symbol, max_depth: int, depth: int = 0 + ): + """Build blast radius subgraph showing impact of changes recursively""" + if depth >= max_depth or symbol in graph: # Avoid cycles and depth limit + return + + graph.add_node( + symbol.name, + type=type(symbol).__name__, + file=symbol.filepath if hasattr(symbol, "filepath") else None, + impact_level=depth, + ) + + # Add all usages (things that would be affected by changes) + for usage in symbol.usages: + if hasattr(usage, "usage_symbol"): + affected_symbol = usage.usage_symbol + if affected_symbol not in graph: # Avoid re-adding nodes + graph.add_node( + affected_symbol.name, + type=type(affected_symbol).__name__, + file=affected_symbol.filepath + if hasattr(affected_symbol, "filepath") + else None, + impact_level=depth + 1, + ) + graph.add_edge( + symbol.name, affected_symbol.name, relationship="impacts" + ) + + self._build_blast_radius_subgraph( + graph, affected_symbol, max_depth, depth + 1 + ) + + def _serialize_enhanced_graph( + self, graph: nx.DiGraph, target_info: Union[str, Symbol], graph_type: str + ) -> Dict[str, Any]: + """Enhanced graph serialization with additional metadata""" + base_result = self._serialize_graph(graph) + + # Convert target_info to string if it's a Symbol object + if isinstance(target_info, Symbol): + target_info_str = target_info.name + else: + target_info_str = str(target_info) + + # Add enhanced metadata + base_result["metadata"] = { + "target": target_info_str, + "graph_type": graph_type, + "created_at": datetime.now().isoformat(), + "node_types": Counter( + data.get("type", "unknown") for _, data in graph.nodes(data=True) + ), + "relationship_types": Counter( + data.get("relationship", "unknown") + for _, _, data in graph.edges(data=True) + ), + } + + # Add graph analysis + if 
len(graph.nodes) > 0: + base_result["analysis"] = { + "centrality": dict(nx.degree_centrality(graph)), + "clustering": dict(nx.clustering(graph.to_undirected())), + "shortest_paths": dict(nx.shortest_path_length(graph)) + if nx.is_connected(graph.to_undirected()) + else {}, + } + + return base_result + + def _serialize_graph(self, graph: nx.DiGraph) -> Dict[str, Any]: + """Serialize NetworkX graph to JSON-serializable format""" + return { + "nodes": [ + {"id": node, "label": node, **data} + for node, data in graph.nodes(data=True) + ], + "edges": [ + {"source": source, "target": target, **data} + for source, target, data in graph.edges(data=True) + ], + "metrics": { + "node_count": len(graph.nodes), + "edge_count": len(graph.edges), + "density": nx.density(graph), + "is_connected": nx.is_weakly_connected(graph), + }, + } + + # Re-implement helper functions from AnalysisEngine that are used by VisualizationEngine + # These are simplified versions, assuming they would be part of the main AnalysisEngine + def _calculate_function_complexity(self, func: Function) -> int: + """Calculate cyclomatic complexity for a function (simplified)""" + if hasattr(func, "complexity"): # If graph-sitter provides it directly + return func.complexity + if hasattr(func, "source") and func.source: + return ( + func.source.count("if ") + + func.source.count("for ") + + func.source.count("while ") + + 1 + ) + return 1 + + +# ============================================================================ +# TRANSFORMATION ENGINE +# ============================================================================ + + +class TransformationEngine: + """Advanced transformation engine for code modifications.""" + + def __init__(self, codebase: Codebase): + self.codebase = codebase + self.transformation_log = [] + + def move_symbol( + self, + symbol_name: str, + target_file: str, + include_dependencies: bool = True, + strategy: str = "update_all_imports", + ) -> Dict[str, Any]: + """Move a symbol to a different file""" + try: + symbol = self.codebase.get_symbol(symbol_name) + if not symbol: + raise ValueError(f"Symbol '{symbol_name}' not found") + + # Get or create target file + if not self.codebase.has_file(target_file): + target_file_obj = self.codebase.create_file(target_file) + else: + target_file_obj = self.codebase.get_file(target_file) + + # Record original location + original_file = symbol.filepath + + # Perform the move + symbol.move_to_file( + target_file_obj, + include_dependencies=include_dependencies, + strategy=strategy, + ) + + result = { + "success": True, + "symbol": symbol_name, + "from_file": original_file, + "to_file": target_file, + "strategy": strategy, + "include_dependencies": include_dependencies, + } + + self.transformation_log.append(result) + return result + + except Exception as e: + error_result = { + "success": False, + "symbol": symbol_name, + "error": str(e), + "error_type": type(e).__name__, + } + self.transformation_log.append(error_result) + return error_result + + def remove_symbol(self, symbol_name: str, safe_mode: bool = True) -> Dict[str, Any]: + """Remove a symbol from the codebase""" + try: + symbol = self.codebase.get_symbol(symbol_name) + if not symbol: + raise ValueError(f"Symbol '{symbol_name}' not found") + + # Check if symbol is used elsewhere + if safe_mode and len(symbol.usages) > 0: + return { + "success": False, + "symbol": symbol_name, + "error": "Symbol is still in use", + "usages": [usage.file.filepath for usage in symbol.usages[:5]], + } + + # Remove the symbol + 
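+            # Descriptive note (added): capture the symbol's file path before calling
+            # remove(), so the returned result can still report its original location.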
original_file = symbol.filepath + symbol.remove() + + result = { + "success": True, + "symbol": symbol_name, + "file": original_file, + "safe_mode": safe_mode, + } + + self.transformation_log.append(result) + return result + + except Exception as e: + error_result = { + "success": False, + "symbol": symbol_name, + "error": str(e), + "error_type": type(e).__name__, + } + self.transformation_log.append(error_result) + return error_result + + def rename_symbol(self, old_name: str, new_name: str) -> Dict[str, Any]: + """Rename a symbol and update all references""" + try: + symbol = self.codebase.get_symbol(old_name) + if not symbol: + raise ValueError(f"Symbol '{old_name}' not found") + + # Count usages before rename + usage_count = len(symbol.usages) + + # Perform the rename + symbol.rename(new_name) + + result = { + "success": True, + "old_name": old_name, + "new_name": new_name, + "file": symbol.filepath, + "usages_updated": usage_count, + } + + self.transformation_log.append(result) + return result + + except Exception as e: + error_result = { + "success": False, + "old_name": old_name, + "new_name": new_name, + "error": str(e), + "error_type": type(e).__name__, + } + self.transformation_log.append(error_result) + return error_result + + def resolve_imports(self, file_path: str) -> Dict[str, Any]: + """Resolve and fix import issues in a file""" + try: + file_obj = self.codebase.get_file(file_path) + if not file_obj: + raise ValueError(f"File '{file_path}' not found") + + resolved_imports = [] + unresolved_imports = [] + + for imp in file_obj.imports: + if hasattr(imp, "resolved_symbol") and imp.resolved_symbol: + resolved_imports.append(imp.name) + else: + unresolved_imports.append(imp.name) + + result = { + "success": True, + "file": file_path, + "resolved_imports": resolved_imports, + "unresolved_imports": unresolved_imports, + "total_imports": len(file_obj.imports), + } + + self.transformation_log.append(result) + return result + + except Exception as e: + error_result = { + "success": False, + "file": file_path, + "error": str(e), + "error_type": type(e).__name__, + } + self.transformation_log.append(error_result) + return error_result + + def add_type_annotations( + self, + symbol_name: str, + return_type: Optional[str] = None, + parameter_types: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """Add type annotations to a function""" + try: + symbol = self.codebase.get_function(symbol_name) + if not symbol: + raise ValueError(f"Function '{symbol_name}' not found") + + changes = [] + + # Add return type annotation + if return_type: + symbol.set_return_type(return_type) + changes.append(f"Added return type: {return_type}") + + # Add parameter type annotations + if parameter_types: + for param_name, param_type in parameter_types.items(): + param = symbol.get_parameter(param_name) + if param: + param.set_type(param_type) # Changed from set_type_annotation + changes.append(f"Added type for {param_name}: {param_type}") + + result = { + "success": True, + "function": symbol_name, + "file": symbol.filepath, + "changes": changes, + } + + self.transformation_log.append(result) + return result + + except Exception as e: + error_result = { + "success": False, + "function": symbol_name, + "error": str(e), + "error_type": type(e).__name__, + } + self.transformation_log.append(error_result) + return error_result + + def extract_function( + self, + source_function: str, + new_function_name: str, + start_line: int, + end_line: int, + ) -> Dict[str, Any]: + """Extract code into a new 
function""" + try: + func = self.codebase.get_function(source_function) + if not func: + raise ValueError(f"Function '{source_function}' not found") + + # Get source lines + source_lines = func.source.splitlines() + if start_line < 1 or end_line > len(source_lines): + raise ValueError("Invalid line range") + + # Extract the code block + extracted_code = "\n".join(source_lines[start_line - 1 : end_line]) + + # Create new function + new_function_code = f""" +def {new_function_name}(): + {extracted_code} +""" + + # Add new function after the original + # This is a placeholder; graph-sitter's API for code modification is more advanced + # func.insert_after(new_function_code) + + # Replace extracted code with function call + replacement = f"{new_function_name}()" + # This is a simplified replacement - in practice, you'd need more sophisticated logic + + result = { + "success": True, + "source_function": source_function, + "new_function": new_function_name, + "extracted_lines": f"{start_line}-{end_line}", + "file": func.filepath, + } + + self.transformation_log.append(result) + return result + + except Exception as e: + error_result = { + "success": False, + "source_function": source_function, + "new_function": new_function_name, + "error": str(e), + "error_type": type(e).__name__, + } + self.transformation_log.append(error_result) + return error_result + + def get_transformation_log(self) -> List[Dict[str, Any]]: + """Get the log of all transformations performed""" + return self.transformation_log.copy() + + def clear_transformation_log(self): + """Clear the transformation log""" + self.transformation_log.clear() + + +class EnhancedTransformationEngine(TransformationEngine): + """Enhanced transformation engine with comprehensive resolution methods""" + + def __init__(self, codebase: Codebase): + super().__init__(codebase) + self.resolution_methods = { + # These methods would need to be implemented to use graph-sitter's transformation APIs + "resolve_all_types": self._resolve_all_types, + "resolve_all_imports": self._resolve_all_imports, + "resolve_all_function_calls": self._resolve_all_function_calls, + "resolve_method_implementations": self._resolve_method_implementations, + "resolve_class_attributes": self._resolve_class_attributes, + "resolve_variable_definitions": self._resolve_variable_definitions, + "resolve_parameter_types": self._resolve_parameter_types, + "resolve_argument_types": self._resolve_argument_types, + } + + def _resolve_all_types(self) -> Dict[str, Any]: + """Automated resolution for all missing/incorrect types.""" + # Placeholder for actual graph-sitter type resolution logic + return {"status": "success", "message": "Attempted to resolve all types."} + + def _resolve_all_imports(self) -> Dict[str, Any]: + """Automated resolution for all import issues (unused, unresolved, circular).""" + # Placeholder for actual graph-sitter import resolution logic + return {"status": "success", "message": "Attempted to resolve all imports."} + + def _resolve_all_function_calls(self) -> Dict[str, Any]: + """Automated resolution for all function call issues (missing args, type mismatches).""" + # Placeholder for actual graph-sitter function call resolution logic + return { + "status": "success", + "message": "Attempted to resolve all function calls.", + } + + def _resolve_method_implementations(self) -> Dict[str, Any]: + """Automated resolution for unimplemented methods.""" + # Placeholder for actual graph-sitter method implementation logic + return { + "status": "success", + "message": 
"Attempted to resolve method implementations.", + } + + def _resolve_class_attributes(self) -> Dict[str, Any]: + """Automated resolution for missing class attributes.""" + # Placeholder for actual graph-sitter class attribute logic + return { + "status": "success", + "message": "Attempted to resolve class attributes.", + } + + def _resolve_variable_definitions(self) -> Dict[str, Any]: + """Automated resolution for undefined variable usages.""" + # Placeholder for actual graph-sitter variable definition logic + return { + "status": "success", + "message": "Attempted to resolve variable definitions.", + } + + def _resolve_parameter_types(self) -> Dict[str, Any]: + """Automated resolution for missing parameter types.""" + # Placeholder for actual graph-sitter parameter type logic + return {"status": "success", "message": "Attempted to resolve parameter types."} + + def _resolve_argument_types(self) -> Dict[str, Any]: + """Automated resolution for argument type mismatches.""" + # Placeholder for actual graph-sitter argument type logic + return {"status": "success", "message": "Attempted to resolve argument types."} + + +# ============================================================================ +# API ENDPOINTS +# ============================================================================ + + +@app.post("/analyze", response_model=Dict[str, Any]) +async def analyze_codebase(request: AnalyzeRequest, background_tasks: BackgroundTasks): + """Analyze a codebase comprehensively""" + try: + analysis_id = str(uuid.uuid4()) + + # Clone repository + repo_path = clone_repository(request.repo_url, request.branch) + + # Initialize AnalysisEngine + codebase = Codebase( + repo_path, + config=CodebaseConfig( + method_usages=True, + generics=True, + sync_enabled=True, + full_range_index=True, + py_resolve_syspath=True, + exp_lazy_graph=False, + ), + ) + analysis_engine = AnalysisEngine(codebase, request.language) + + # Perform analysis + analysis_result = await analysis_engine.perform_full_analysis() + + # Store analysis session + analysis_sessions[analysis_id] = { + "id": analysis_id, + "repo_url": request.repo_url, + "branch": request.branch, + "repo_path": repo_path, + "codebase_obj": codebase, # Store codebase object for later use + "analysis": analysis_result, + "created_at": datetime.now().isoformat(), + "config": request.config or {}, + } + + # Schedule cleanup + background_tasks.add_task(cleanup_temp_directory, repo_path) + + return { + "analysis_id": analysis_id, + "status": "completed", + "summary": { + "files": analysis_result["metrics"]["files"], + "functions": analysis_result["metrics"]["functions"], + "classes": analysis_result["metrics"]["classes"], + "errors": analysis_result["error_analysis"]["total"], + "dead_code_items": analysis_result["dead_code_analysis"]["total"], + }, + "analysis": analysis_result, + } + + except Exception as e: + logger.error(f"Analysis failed: {e}") + logger.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") + + +@app.get("/analysis/{analysis_id}/errors", response_model=ErrorAnalysisResponse) +async def get_error_analysis(analysis_id: str): + """Get detailed error analysis for a codebase""" + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + session = analysis_sessions[analysis_id] + error_analysis = session["analysis"]["error_analysis"] + + return ErrorAnalysisResponse( + total_errors=error_analysis["total"], + 
critical_errors=error_analysis["critical"], + major_errors=error_analysis["major"], + minor_errors=error_analysis["minor"], + errors_by_category=error_analysis["by_category"], + detailed_errors=error_analysis["detailed_errors"], + error_patterns=error_analysis.get("error_patterns", []), + suggestions=error_analysis.get("suggestions", []), + ) + + +@app.post("/analysis/{analysis_id}/fix-errors") +async def fix_errors_with_ai(analysis_id: str, max_fixes: int = 1): + """ + Attempts to fix errors in the codebase using AI. + Applies fixes directly to the cloned repository. + """ + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + session = analysis_sessions[analysis_id] + codebase: Codebase = session["codebase_obj"] # Retrieve codebase object + repo_path = session["repo_path"] + error_analysis = session["analysis"]["error_analysis"] + + if not error_analysis["detailed_errors"]: + return {"message": "No errors found to fix."} + + fixes_applied = 0 + fix_results = [] + + # Sort diagnostics by severity (critical > major > minor), then by file path, then by line number + # Assuming severity is 'critical', 'major', 'minor' for sorting + severity_order = {"critical": 0, "major": 1, "minor": 2, "unknown": 3} + sorted_errors = sorted( + error_analysis["detailed_errors"], + key=lambda ed: (severity_order.get(ed["severity"], 3), ed["file"], ed["line"]), + ) + + for error_data in sorted_errors: + if fixes_applied >= max_fixes: + break + + # Only attempt to fix LSP diagnostics for now, as they have the full Diagnostic object + # The 'context' field now holds the full EnhancedDiagnostic object + if "context" not in error_data or "diagnostic" not in error_data["context"]: + fix_results.append( + { + "error": error_data, + "status": "skipped", + "message": "Only LSP diagnostics with full context are currently supported for AI fixing.", + } + ) + continue + + enhanced_diag = error_data[ + "context" + ] # This is the full EnhancedDiagnostic object + diag: Diagnostic = enhanced_diag["diagnostic"] + file_path_abs = os.path.join(repo_path, enhanced_diag["relative_file_path"]) + + logger.info( + f"Attempting AI fix for: {diag.message} at {enhanced_diag['relative_file_path']}:{diag.range.line + 1}" + ) + + # Call AI resolution with the full enhanced diagnostic + ai_fix_response = resolve_diagnostic_with_ai(enhanced_diag, codebase) + + if ai_fix_response["status"] == "success": + original_content = Path(file_path_abs).read_text() + fixed_content = apply_fix_to_file( + file_path_abs, + original_content, + ai_fix_response["fixed_code"], + diag.range, + ) + + try: + Path(file_path_abs).write_text(fixed_content) + fixes_applied += 1 + fix_results.append( + { + "error": error_data, + "status": "applied", + "message": ai_fix_response["explanation"], + "fixed_code": ai_fix_response["fixed_code"], + } + ) + logger.info( + f"Successfully applied fix to {enhanced_diag['relative_file_path']}. 
Explanation: {ai_fix_response['explanation']}" + ) + + # Re-analyze the codebase after applying fixes to update diagnostics + # This is crucial for iterative fixing + # For simplicity, we'll just update the in-memory codebase object + codebase.reload() # Reload the codebase to reflect changes + # Re-run LSP diagnostics for the modified file + # This requires re-initializing LSP manager for the current codebase state + lsp_manager_temp = LSPDiagnosticsManager( + codebase, Language(session["language"]) + ) + lsp_manager_temp.start_server() + lsp_manager_temp.open_file( + enhanced_diag["relative_file_path"], fixed_content + ) + await asyncio.sleep(1) # Give LSP time to re-diagnose + updated_diags_for_file = lsp_manager_temp.get_diagnostics( + enhanced_diag["relative_file_path"] + ) + lsp_manager_temp.shutdown_server() + + # Update the session's analysis with new diagnostics + # This is a simplified update; a full re-analysis might be needed for accuracy + # Find the original enhanced diagnostic in the session's analysis and update it + for i, err in enumerate( + session["analysis"]["error_analysis"]["detailed_errors"] + ): + if ( + err["file"] == enhanced_diag["relative_file_path"] + and err["line"] == diag.range.line + 1 + ): + # For simplicity, we'll just mark it as fixed or remove it. + # A more robust solution would re-run the full error analysis. + session["analysis"]["error_analysis"]["detailed_errors"][i][ + "status" + ] = "fixed" + session["analysis"]["error_analysis"]["detailed_errors"][i][ + "fixed_code" + ] = ai_fix_response["fixed_code"] + break + + except Exception as e: + fix_results.append( + { + "error": error_data, + "status": "failed_to_apply", + "message": f"Failed to write fixed content: {e}", + } + ) + logger.error(f"Failed to write fixed content to {file_path_abs}: {e}") + else: + fix_results.append( + { + "error": error_data, + "status": "ai_failed", + "message": ai_fix_response.get( + "message", "AI could not generate a fix." + ), + } + ) + logger.warning( + f"AI failed to fix error: {ai_fix_response.get('message', 'Unknown error')}" + ) + + return { + "message": f"Attempted to fix {len(sorted_errors)} errors. 
Applied {fixes_applied} fixes.", + "fix_results": fix_results, + } + + +@app.get( + "/analysis/{analysis_id}/entrypoints", response_model=EntrypointAnalysisResponse +) +async def get_entrypoint_analysis(analysis_id: str): + """Get entrypoint analysis for a codebase""" + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + session = analysis_sessions[analysis_id] + entrypoint_analysis = session["analysis"]["entrypoint_analysis"] + + return EntrypointAnalysisResponse( + total_entrypoints=entrypoint_analysis["total_entrypoints"], + main_entrypoints=entrypoint_analysis["main_entrypoints"], + secondary_entrypoints=entrypoint_analysis["secondary_entrypoints"], + test_entrypoints=entrypoint_analysis["test_entrypoints"], + api_entrypoints=entrypoint_analysis["api_entrypoints"], # Added + cli_entrypoints=entrypoint_analysis["cli_entrypoints"], # Added + entrypoint_graph=entrypoint_analysis["entrypoint_graph"], + complexity_metrics=entrypoint_analysis.get("complexity_metrics", {}), + dependency_analysis=entrypoint_analysis.get("dependency_analysis", {}), # Added + call_flow_analysis=entrypoint_analysis.get("call_flow_analysis", {}), # Added + ) + + +@app.get("/analysis/{analysis_id}/dead-code", response_model=DeadCodeAnalysisResponse) +async def get_dead_code_analysis(analysis_id: str): + """Get dead code analysis for a codebase""" + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + session = analysis_sessions[analysis_id] + dead_code_analysis = session["analysis"]["dead_code_analysis"] + + return DeadCodeAnalysisResponse( + total_dead_items=dead_code_analysis["total"], + dead_functions=[ + item + for item in dead_code_analysis["detailed_items"] + if item["type"] == "function" + ], + dead_classes=[ + item + for item in dead_code_analysis["detailed_items"] + if item["type"] == "class" + ], + dead_imports=[ + item + for item in dead_code_analysis["detailed_items"] + if item["type"] == "import" + ], + dead_variables=[ + item + for item in dead_code_analysis["detailed_items"] + if item["type"] == "variable" + ], + potential_dead_code=dead_code_analysis["detailed_items"], + recommendations=dead_code_analysis["recommendations"], + ) + + +@app.get("/analysis/{analysis_id}/quality", response_model=CodeQualityMetrics) +async def get_code_quality_metrics(analysis_id: str): + """Get code quality metrics for a codebase""" + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + session = analysis_sessions[analysis_id] + quality_metrics = session["analysis"]["code_quality_metrics"] + + return CodeQualityMetrics( + complexity_score=quality_metrics["complexity_score"], + maintainability_index=quality_metrics["maintainability_index"], + technical_debt_ratio=quality_metrics["technical_debt_ratio"], + test_coverage_estimate=quality_metrics["test_coverage_estimate"], + documentation_coverage=quality_metrics["documentation_coverage"], + code_duplication_score=quality_metrics.get("code_duplication_score", 0.0), + type_coverage=quality_metrics.get("type_coverage", 0.0), # Added + function_metrics=quality_metrics.get("function_metrics", {}), # Added + class_metrics=quality_metrics.get("class_metrics", {}), # Added + file_metrics=quality_metrics.get("file_metrics", {}), # Added + ) + + +@app.post("/analysis/{analysis_id}/visualize") +async def create_visualization(analysis_id: str, request: VisualizationRequest): + """Create a visualization of 
the codebase""" + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + try: + session = analysis_sessions[analysis_id] + # Re-initialize codebase for visualization (or retrieve from session if stored) + codebase: Codebase = session["codebase_obj"] # Retrieve codebase object + viz_engine = EnhancedVisualizationEngine( + codebase + ) # Use EnhancedVisualizationEngine + + # Create visualization based on type + if request.viz_type == "dependency_graph": + result = viz_engine.create_dynamic_dependency_graph( # Changed to dynamic + target_type="codebase", # Default to codebase-wide + target_name="", # No specific target for codebase-wide + scope="codebase", + max_depth=request.max_depth, + include_external=request.include_external, + ) + elif request.viz_type == "call_flow": + if not request.entry_point: + raise HTTPException( + status_code=400, + detail="Entry point required for call flow visualization", + ) + result = viz_engine.create_function_call_trace( # Changed to function_call_trace + entry_function=request.entry_point, + target_function=request.entry_point, # For full call flow from entry point + max_depth=request.max_depth, + ) + elif request.viz_type == "data_flow": + if not request.entry_point: + raise HTTPException( + status_code=400, + detail="Entry point required for data flow visualization", + ) + result = viz_engine.create_data_flow_graph( + entry_point=request.entry_point, max_depth=request.max_depth + ) + elif request.viz_type == "blast_radius": + if not request.entry_point: + raise HTTPException( + status_code=400, + detail="Entry point required for blast radius visualization", + ) + result = viz_engine.create_blast_radius_graph( + entry_point=request.entry_point, max_depth=request.max_depth + ) + elif request.viz_type == "class_hierarchy": # Added + result = viz_engine.create_class_hierarchy_graph( + target_class=request.entry_point, # entry_point can be a class name + include_methods=True, + ) + elif request.viz_type == "module_dependency": # Added + if not request.entry_point: + raise HTTPException( + status_code=400, + detail="Entry point (module name) required for module dependency visualization", + ) + result = viz_engine.create_module_dependency_graph( + target_module=request.entry_point, max_depth=request.max_depth + ) + else: + raise HTTPException( + status_code=400, + detail=f"Unknown visualization type: {request.viz_type}", + ) + + return { + "visualization_id": str(uuid.uuid4()), + "type": request.viz_type, + "entry_point": request.entry_point, + "graph": result, + } + + except Exception as e: + logger.error(f"Visualization failed: {e}") + logger.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=f"Visualization failed: {str(e)}") + + +@app.post("/analysis/{analysis_id}/transform") +async def apply_transformation(analysis_id: str, request: TransformationRequest): + """Apply a transformation to the codebase""" + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + try: + session = analysis_sessions[analysis_id] + # Re-initialize codebase for transformation (or retrieve from session if stored) + codebase: Codebase = session["codebase_obj"] # Retrieve codebase object + + transform_engine = EnhancedTransformationEngine( + codebase + ) # Use EnhancedTransformationEngine + + # Apply transformation based on type + if request.transformation_type == "move_symbol": + result = transform_engine.move_symbol( + 
symbol_name=request.parameters.get("symbol_name"), + target_file=request.target_path, + include_dependencies=request.parameters.get( + "include_dependencies", True + ), + strategy=request.parameters.get("strategy", "update_all_imports"), + ) + elif request.transformation_type == "remove_symbol": + result = transform_engine.remove_symbol( + symbol_name=request.parameters.get("symbol_name"), + safe_mode=request.parameters.get("safe_mode", True), + ) + elif request.transformation_type == "rename_symbol": + result = transform_engine.rename_symbol( + old_name=request.parameters.get("old_name"), + new_name=request.parameters.get("new_name"), + ) + elif request.transformation_type == "resolve_imports": + result = transform_engine.resolve_imports(request.target_path) + elif request.transformation_type == "add_type_annotations": + result = transform_engine.add_type_annotations( + symbol_name=request.parameters.get("symbol_name"), + return_type=request.parameters.get("return_type"), + parameter_types=request.parameters.get("parameter_types"), + ) + elif request.transformation_type == "extract_function": + result = transform_engine.extract_function( + source_function=request.parameters.get("source_function"), + new_function_name=request.parameters.get("new_function_name"), + start_line=request.parameters.get("start_line"), + end_line=request.parameters.get("end_line"), + ) + else: + raise HTTPException( + status_code=400, + detail=f"Unknown transformation type: {request.transformation_type}", + ) + + # Commit changes if not dry run + if not request.dry_run and result.get("success"): + codebase.commit() + + return { + "transformation_id": str(uuid.uuid4()), + "type": request.transformation_type, + "dry_run": request.dry_run, + "result": result, + "log": transform_engine.get_transformation_log(), + } + + except Exception as e: + logger.error(f"Transformation failed: {e}") + logger.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=f"Transformation failed: {str(e)}") + + +@app.post("/analysis/{analysis_id}/generate-docs") +async def generate_documentation( + analysis_id: str, target_type: str = "codebase", target_name: Optional[str] = None +): + """ + Generates MDX documentation pages for the codebase, a specific class, or a function. 
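+    Generated MDX pages are written to a "generated_docs" directory inside the cloned repository.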
+ """ + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + session = analysis_sessions[analysis_id] + codebase: Codebase = session["codebase_obj"] + repo_path = session["repo_path"] + + docs_output_dir = Path(repo_path) / "generated_docs" + docs_output_dir.mkdir(exist_ok=True) + + generated_files = [] + + try: + # Generate structured JSON documentation + # For simplicity, we'll generate for the whole codebase and then filter for MDX + structured_docs = generate_docs_json( + codebase, head_commit="latest" + ) # Assuming 'latest' or a real commit hash + + if target_type == "codebase": + classes_to_document = structured_docs.classes + elif target_type == "class" and target_name: + classes_to_document = [ + cls_doc + for cls_doc in structured_docs.classes + if cls_doc.title == target_name + ] + if not classes_to_document: + raise HTTPException( + status_code=404, + detail=f"Class '{target_name}' not found for documentation.", + ) + elif target_type == "function" and target_name: + # For functions, we need to find the class they belong to, or handle standalone functions + # This is a simplified approach; a full implementation would need to find the function's parent + # and then generate docs for that function. For now, we'll just generate for classes. + raise HTTPException( + status_code=400, + detail="Direct function documentation generation not yet supported. Please specify a class.", + ) + else: + raise HTTPException( + status_code=400, detail="Invalid target_type or missing target_name." + ) + + for cls_doc in classes_to_document: + mdx_content = render_mdx_page_for_class(cls_doc) + mdx_route = ( + Path(cls_doc.path).with_suffix(".mdx").as_posix() + ) # Use class path for MDX route + + # Create subdirectories based on the route + output_path = docs_output_dir / mdx_route + output_path.parent.mkdir(parents=True, exist_ok=True) + + output_path.write_text(mdx_content) + generated_files.append(str(output_path.relative_to(repo_path))) + logger.info(f"Generated MDX for {cls_doc.title} at {output_path}") + + return { + "message": f"Documentation generated successfully for {target_type}.", + "generated_files": generated_files, + "output_directory": str(docs_output_dir.relative_to(repo_path)), + } + + except Exception as e: + logger.error(f"Documentation generation failed: {e}") + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Documentation generation failed: {str(e)}" + ) + + +@app.get("/analysis/{analysis_id}/tree") +async def get_tree_structure(analysis_id: str): + """Get the hierarchical tree structure of the codebase""" + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + session = analysis_sessions[analysis_id] + return session["analysis"]["tree_structure"] + + +@app.get("/analysis/{analysis_id}/dependencies") +async def get_dependency_graph(analysis_id: str): + """Get the dependency graph of the codebase""" + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + session = analysis_sessions[analysis_id] + return session["analysis"]["dependency_graph"] + + +@app.get("/analysis/{analysis_id}/architecture") +async def get_architectural_insights(analysis_id: str): + """Get architectural insights about the codebase""" + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + session = analysis_sessions[analysis_id] + 
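+    # Surface the architectural, security, and performance sections of the stored analysis in one response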
return { + "architectural_patterns": session["analysis"]["architectural_insights"], + "security_analysis": session["analysis"]["security_analysis"], + "performance_analysis": session["analysis"]["performance_analysis"], + } + + +@app.get("/analysis/{analysis_id}/summary") +async def get_analysis_summary(analysis_id: str): + """Get a summary of the analysis""" + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + session = analysis_sessions[analysis_id] + analysis = session["analysis"] + + return { + "analysis_id": analysis_id, + "repo_url": session["repo_url"], + "branch": session["branch"], + "created_at": session["created_at"], + "metrics": analysis["metrics"], + "summary": { + "total_errors": analysis["error_analysis"]["total"], + "critical_errors": analysis["error_analysis"]["critical"], + "dead_code_items": analysis["dead_code_analysis"]["total"], + "entrypoints": analysis["entrypoint_analysis"]["total_entrypoints"], + "complexity_score": analysis["code_quality_metrics"]["complexity_score"], + "maintainability_index": analysis["code_quality_metrics"][ + "maintainability_index" + ], + "technical_debt_ratio": analysis["code_quality_metrics"][ + "technical_debt_ratio" + ], + }, + } + + +@app.delete("/analysis/{analysis_id}") +async def delete_analysis(analysis_id: str): + """Delete an analysis session""" + if analysis_id not in analysis_sessions: + raise HTTPException(status_code=404, detail="Analysis not found") + + session = analysis_sessions.pop(analysis_id) + + # Clean up temporary directory + try: + if os.path.exists(session["repo_path"]): + shutil.rmtree(session["repo_path"]) + except Exception as e: + logger.warning(f"Failed to clean up temp directory: {e}") + + return {"message": "Analysis deleted successfully"} + + +@app.get("/analyses") +async def list_analyses(): + """List all analysis sessions""" + return { + "analyses": [ + { + "id": session_id, + "repo_url": session["repo_url"], + "branch": session["branch"], + "created_at": session["created_at"], + "metrics": session["analysis"]["metrics"], + } + for session_id, session in analysis_sessions.items() + ] + } + + +# ============================================================================ +# UTILITY ENDPOINTS +# ============================================================================ + + +@app.get("/health") +async def health_check(): + """Health check endpoint""" + return { + "status": "healthy", + "graph_sitter_available": GRAPH_SITTER_AVAILABLE, + "active_sessions": len(analysis_sessions), + "timestamp": datetime.now().isoformat(), + } + + +@app.get("/capabilities") +async def get_capabilities(): + """Get API capabilities""" + return { + "graph_sitter_available": GRAPH_SITTER_AVAILABLE, + "supported_languages": ["python", "typescript", "javascript"], + "analysis_features": [ + "error_analysis", + "dead_code_detection", + "entrypoint_analysis", + "dependency_analysis", + "code_quality_metrics", + "architectural_insights", + "security_analysis", + "performance_analysis", + ], + "visualization_types": [ + "dependency_graph", + "call_flow", + "data_flow", + "blast_radius", + "class_hierarchy", # Added + "module_dependency", # Added + ], + "transformation_types": [ + "move_symbol", + "remove_symbol", + "rename_symbol", + "resolve_imports", + "add_type_annotations", + "extract_function", + ], + "documentation_generation": ["mdx"], # Added + "ai_resolution": ["lsp_diagnostics"], # Added + } + + +# 
============================================================================ +# CLEANUP UTILITIES +# ============================================================================ + + +async def cleanup_temp_directory(repo_path: str): + """Clean up temporary directory after delay""" + await asyncio.sleep(3600) # Wait 1 hour before cleanup + try: + if os.path.exists(repo_path): + shutil.rmtree(repo_path) + logger.info(f"Cleaned up temporary directory: {repo_path}") + except Exception as e: + logger.warning(f"Failed to clean up temp directory {repo_path}: {e}") + + +def find_import_cycles(codebase: Codebase) -> List[List[str]]: + """Find import cycles in the codebase""" + import_graph = nx.DiGraph() + + for file_obj in codebase.files: + import_graph.add_node(file_obj.filepath) + for imp in file_obj.imports: + if hasattr(imp, "from_file") and imp.from_file: + import_graph.add_edge(file_obj.filepath, imp.from_file.filepath) + + return list(nx.simple_cycles(import_graph)) + + +def find_problematic_import_loops( + codebase: Codebase, cycles: List[List[str]] +) -> List[Dict[str, Any]]: + """Identify cycles with both static and dynamic imports between files""" + problematic_cycles = [] + + for i, cycle in enumerate(cycles): + mixed_imports = {} + + for from_file_path in cycle: + from_file = next( + (f for f in codebase.files if f.filepath == from_file_path), None + ) + if not from_file: + continue + + for to_file_path in cycle: + if from_file_path == to_file_path: + continue + + # Check imports from from_file to to_file + imports_to_file = [ + imp + for imp in from_file.imports + if hasattr(imp, "from_file") + and imp.from_file + and imp.from_file.filepath == to_file_path + ] + + if imports_to_file: + mixed_imports[(from_file_path, to_file_path)] = { + "imports": len(imports_to_file), + "import_names": [imp.name for imp in imports_to_file], + } + + if mixed_imports: + problematic_cycles.append( + {"files": cycle, "mixed_imports": mixed_imports, "index": i} + ) + + return problematic_cycles + + +def convert_all_calls_to_kwargs(codebase: Codebase): + """Convert all function calls to use keyword arguments""" + converted_count = 0 + + for file_obj in codebase.files: + for function_call in file_obj.function_calls: + try: + # Get function definition + func_def = function_call.function_definition + if not func_def or isinstance(func_def, ExternalModule): + continue + + # Convert positional args to kwargs + for i, arg in enumerate(function_call.args): + if not arg.is_named and i < len(func_def.parameters): + param = func_def.parameters[i] + arg.add_keyword(param.name) + converted_count += 1 + + except Exception as e: + logger.warning(f"Failed to convert call {function_call.name}: {e}") + + logger.info(f"Converted {converted_count} function call arguments to kwargs") + return converted_count + + +def apply_fix_to_file( + filepath: str, + original_content: str, + fixed_code_snippet: str, + diagnostic_range: Range, +) -> str: + """ + Applies a fixed code snippet to the original file content. + This function attempts to replace the lines covered by the diagnostic range. + If the fixed_code_snippet is a full file, it replaces the entire content. + """ + lines = original_content.splitlines( + keepends=True + ) # Keep newlines for accurate replacement + + # Heuristic: If the fixed_code_snippet is very large compared to the original file, + # or if it contains a shebang/imports, assume it's a full file replacement. + # This is a simplification; a more robust solution would involve diffing or AST. 
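+    # The checks below treat the snippet as a whole-file rewrite when it is roughly 80% of the original
+    # file's size, or when it starts like a module (shebang, or imports in the first 100 characters);
+    # these thresholds are heuristics, not exact rules.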
+ if ( + len(fixed_code_snippet) > len(original_content) * 0.8 + or fixed_code_snippet.startswith("#!") + or "import " in fixed_code_snippet[:100] + or "from " in fixed_code_snippet[:100] + ): + logger.info( + f"Assuming full file replacement for {filepath} based on fixed code size/content." + ) + return fixed_code_snippet + + start_line_idx = diagnostic_range.line + end_line_idx = diagnostic_range.end.line + + # Ensure indices are within bounds + start_line_idx = max(0, min(start_line_idx, len(lines) - 1)) + end_line_idx = max(0, min(end_line_idx, len(lines) - 1)) + + # Replace the block of lines + fixed_lines = fixed_code_snippet.splitlines(keepends=True) + + # Adjust fixed_lines to ensure they end with a newline if the original lines did + for i in range(len(fixed_lines)): + if not fixed_lines[i].endswith("\n") and ( + i + start_line_idx < len(lines) and lines[i + start_line_idx].endswith("\n") + ): + fixed_lines[i] += "\n" + + # If the fixed snippet is just one line, and the diagnostic covers multiple lines, + # it might be a replacement for a block. + # If the fixed snippet is multi-line, and diagnostic is single-line, it's an expansion. + # This simple replacement works for many cases but can be brittle. + # A proper solution would involve `difflib` or `tree-sitter` based patching. + + # Replace the range + new_lines = lines[:start_line_idx] + fixed_lines + lines[end_line_idx + 1 :] + return "".join(new_lines) + + +# ============================================================================ +# MAIN APPLICATION +# ============================================================================ + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Graph-Sitter Backend API") + parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") + parser.add_argument("--port", type=int, default=8000, help="Port to bind to") + parser.add_argument("--reload", action="store_true", help="Enable auto-reload") + parser.add_argument("--log-level", default="info", help="Log level") + + args = parser.parse_args() + + uvicorn.run( + "graph_sitter_backend:app", # Assuming this file is named graph_sitter_backend.py + host=args.host, + port=args.port, + reload=args.reload, + log_level=args.log_level, + ) diff --git a/Libraries/graph_sitter_lib/gsbuild/build.py b/Libraries/graph_sitter_lib/gsbuild/build.py new file mode 100644 index 00000000..fb5ae583 --- /dev/null +++ b/Libraries/graph_sitter_lib/gsbuild/build.py @@ -0,0 +1,24 @@ +import sys +from pathlib import Path +from typing import Any + +from hatchling.builders.hooks.plugin.interface import BuildHookInterface + + +def update_init_file(file: Path) -> None: + path = Path(__file__).parent.parent + sys.path.append(str(path)) + from graph_sitter.gscli.generate.runner_imports import generate_exported_modules, get_runner_imports + + content = file.read_text() + content = get_runner_imports(include_codegen=False) + "\n" + content + "\n" + generate_exported_modules() + file.write_text(content) + + +class SpecialBuildHook(BuildHookInterface): + PLUGIN_NAME = "codegen_build" + + def initialize(self, version: str, build_data: dict[str, Any]) -> None: + file = Path(self.root) / "src" / "graph_sitter" / "__init__.py" + update_init_file(file) + build_data["artifacts"].append(f"/{file}") diff --git a/Libraries/graph_sitter_lib/lsp_diagnostics.py b/Libraries/graph_sitter_lib/lsp_diagnostics.py new file mode 100644 index 00000000..9ae4ad4f --- /dev/null +++ b/Libraries/graph_sitter_lib/lsp_diagnostics.py @@ -0,0 
+1,563 @@ +#!/usr/bin/env python3 +"""Enhanced LSP Diagnostics Manager with Runtime Error Collection +Integrates with Graph-Sitter and AutoGenLib for comprehensive error context +""" + +import asyncio +import logging +import os +import re +import time +from typing import Any, TypedDict + +# Import GraphSitterAnalyzer for context enrichment +from graph_sitter import Codebase +from graph_sitter.extensions.lsp.solidlsp.ls import SolidLanguageServer +from graph_sitter.extensions.lsp.solidlsp.ls_config import Language, LanguageServerConfig +from graph_sitter.extensions.lsp.solidlsp.ls_logger import LanguageServerLogger +from graph_sitter.extensions.lsp.solidlsp.ls_utils import PathUtils +from graph_sitter.extensions.lsp.solidlsp.lsp_protocol_handler.lsp_types import Diagnostic, DocumentUri, Range + +logger = logging.getLogger(__name__) + + +class EnhancedDiagnostic(TypedDict): + """A diagnostic with comprehensive context for AI resolution.""" + + diagnostic: Diagnostic + file_content: str + relevant_code_snippet: str + file_path: str # Absolute path to the file + relative_file_path: str # Path relative to codebase root + + # Enhanced context fields + graph_sitter_context: dict[str, Any] + autogenlib_context: dict[str, Any] + runtime_context: dict[str, Any] + ui_interaction_context: dict[str, Any] + + +class RuntimeErrorCollector: + """Collects runtime errors from various sources.""" + + def __init__(self, codebase: Codebase): + self.codebase = codebase + self.runtime_errors = [] + self.ui_errors = [] + self.error_patterns = {} + + def collect_python_runtime_errors(self, log_file_path: str | None = None) -> list[dict[str, Any]]: + """Collect Python runtime errors from logs or exception handlers.""" + runtime_errors = [] + + # If log file is provided, parse it for errors + if log_file_path and os.path.exists(log_file_path): + try: + with open(log_file_path) as f: + log_content = f.read() + + # Parse Python tracebacks + traceback_pattern = r"Traceback \(most recent call last\):(.*?)(?=\n\w|\nTraceback|\Z)" + tracebacks = re.findall(traceback_pattern, log_content, re.DOTALL) + + for traceback in tracebacks: + # Extract file, line, and error info + file_pattern = r'File "([^"]+)", line (\d+), in (\w+)' + error_pattern = r"(\w+Error): (.+)" + + file_matches = re.findall(file_pattern, traceback) + error_matches = re.findall(error_pattern, traceback) + + if file_matches and error_matches: + file_path, line_num, function_name = file_matches[-1] # Last frame + error_type, error_message = error_matches[-1] + + runtime_errors.append( + { + "type": "runtime_error", + "error_type": error_type, + "message": error_message, + "file_path": file_path, + "line": int(line_num), + "function": function_name, + "traceback": traceback.strip(), + "severity": "critical", + "timestamp": time.time(), + } + ) + + except Exception as e: + logger.warning(f"Error parsing log file {log_file_path}: {e}") + + # Collect from in-memory exception handlers if available + # This would require integration with the target application's exception handling + runtime_errors.extend(self._collect_in_memory_errors()) + + return runtime_errors + + def collect_ui_interaction_errors(self, ui_log_path: str | None = None) -> list[dict[str, Any]]: + """Collect UI interaction errors from frontend logs or error boundaries.""" + ui_errors = [] + + # Parse JavaScript/TypeScript errors from UI logs + if ui_log_path and os.path.exists(ui_log_path): + try: + with open(ui_log_path) as f: + log_content = f.read() + + # Parse JavaScript errors + 
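+                # Assumes browser-style log lines such as "TypeError: x is undefined at app.js:10:5"; real frontend log formats may vary.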
js_error_pattern = r"(TypeError|ReferenceError|SyntaxError): (.+?) at (.+?):(\d+):(\d+)" + js_errors = re.findall(js_error_pattern, log_content) + + for error_type, message, file_path, line, column in js_errors: + ui_errors.append( + { + "type": "ui_error", + "error_type": error_type, + "message": message, + "file_path": file_path, + "line": int(line), + "column": int(column), + "severity": "major", + "timestamp": time.time(), + } + ) + + # Parse React component errors + react_error_pattern = r"Error: (.+?) in (\w+) \(at (.+?):(\d+):(\d+)\)" + react_errors = re.findall(react_error_pattern, log_content) + + for message, component, file_path, line, column in react_errors: + ui_errors.append( + { + "type": "react_error", + "error_type": "ComponentError", + "message": message, + "component": component, + "file_path": file_path, + "line": int(line), + "column": int(column), + "severity": "major", + "timestamp": time.time(), + } + ) + + # Parse console errors + console_error_pattern = r"console\.error: (.+)" + console_errors = re.findall(console_error_pattern, log_content) + + for error_message in console_errors: + ui_errors.append({"type": "console_error", "error_type": "ConsoleError", "message": error_message, "severity": "minor", "timestamp": time.time()}) + + except Exception as e: + logger.warning(f"Error parsing UI log file {ui_log_path}: {e}") + + # Collect from browser console if available + ui_errors.extend(self._collect_browser_console_errors()) + + return ui_errors + + def collect_network_errors(self) -> list[dict[str, Any]]: + """Collect network-related errors.""" + network_errors = [] + + # Look for network error patterns in code + for file_obj in self.codebase.files: + if hasattr(file_obj, "source") and file_obj.source: + # Find fetch/axios/request patterns + network_patterns = [r'fetch\(["\']([^"\']+)["\']', r'axios\.(get|post|put|delete)\(["\']([^"\']+)["\']', r'requests\.(get|post|put|delete)\(["\']([^"\']+)["\']'] + + for pattern in network_patterns: + matches = re.findall(pattern, file_obj.source) + for match in matches: + network_errors.append( + { + "type": "network_call", + "file_path": file_obj.filepath, + "endpoint": match[1] if isinstance(match, tuple) else match, + "method": match[0] if isinstance(match, tuple) else "unknown", + "potential_failure_point": True, + } + ) + + return network_errors + + def _collect_in_memory_errors(self) -> list[dict[str, Any]]: + """Collect runtime errors from in-memory exception handlers.""" + # This would integrate with the application's exception handling system + # For now, return empty list as this requires application-specific integration + return [] + + def _collect_browser_console_errors(self) -> list[dict[str, Any]]: + """Collect errors from browser console.""" + # This would require browser automation or console API integration + # For now, return empty list as this requires browser-specific integration + return [] + + +class LSPDiagnosticsManager: + """Enhanced LSP server lifecycle and diagnostic retrieval with comprehensive context enrichment.""" + + def __init__(self, codebase: Codebase, language: Language, log_level=logging.INFO): + self.codebase = codebase + self.language = language + self.logger = LanguageServerLogger(log_level=log_level) + self.lsp_server: SolidLanguageServer | None = None + self.repository_root_path = codebase.root # Use codebase root + self.runtime_collector = RuntimeErrorCollector(codebase) + + # Enhanced error tracking + self.error_history = [] + self.error_frequency = {} + self.resolution_attempts = 
{} + + def start_server(self) -> None: + """Starts the LSP server and initializes it.""" + if self.lsp_server is None: + self.lsp_server = SolidLanguageServer.create( + language=self.language, logger=self.logger, repository_root_path=self.repository_root_path, config=LanguageServerConfig(code_language=self.language, trace_lsp_communication=False) + ) + self.logger.log(f"Starting LSP server for {self.language.value} at {self.repository_root_path}", logging.INFO) + self.lsp_server.start() + self.logger.log("LSP server started.", logging.INFO) + + def open_file(self, relative_file_path: str, content: str) -> None: + """Notifies the LSP server that a file has been opened.""" + if self.lsp_server: + self.lsp_server.open_file(relative_file_path, content) + else: + self.logger.log("LSP server not started. Cannot open file.", logging.WARNING) + + def change_file(self, relative_file_path: str, content: str) -> None: + """Notifies the LSP server that a file has been changed.""" + if self.lsp_server: + self.lsp_server.change_file(relative_file_path, content) + else: + self.logger.log("LSP server not started. Cannot change file.", logging.WARNING) + + def get_diagnostics(self, relative_file_path: str) -> list[Diagnostic]: + """Retrieves diagnostics for a specific file.""" + if self.lsp_server: + uri = PathUtils.path_to_uri(os.path.join(self.repository_root_path, relative_file_path)) + return self.lsp_server.get_diagnostics_for_uri(uri) + else: + self.logger.log("LSP server not started. Cannot get diagnostics.", logging.WARNING) + return [] + + def get_all_enhanced_diagnostics(self, runtime_log_path: str | None = None, ui_log_path: str | None = None) -> list[EnhancedDiagnostic]: + """Retrieves all collected diagnostics from the LSP server, enriched with comprehensive context.""" + if not self.lsp_server: + self.logger.log("LSP server not started. No enhanced diagnostics available.", logging.WARNING) + return [] + + all_raw_diagnostics = self.lsp_server.get_all_diagnostics() + enhanced_diagnostics: list[EnhancedDiagnostic] = [] + + # Collect runtime errors + runtime_errors = self.runtime_collector.collect_python_runtime_errors(runtime_log_path) + ui_errors = self.runtime_collector.collect_ui_interaction_errors(ui_log_path) + network_errors = self.runtime_collector.collect_network_errors() + + # Import autogenlib_context here to avoid circular dependency at module level + from autogenlib_adapter import get_ai_fix_context + + for uri, diagnostics_list in all_raw_diagnostics.items(): + file_path = PathUtils.uri_to_path(uri) + relative_file_path = os.path.relpath(file_path, self.repository_root_path) + + try: + file_content = self.codebase.get_file(relative_file_path).content + except ValueError: + logger.warning(f"File {relative_file_path} not found in codebase. 
Skipping diagnostics for this file.") + continue + + for diag in diagnostics_list: + relevant_code = self._get_relevant_code_for_diagnostic(file_content, diag.range) + + # Find related runtime errors for this file/line + related_runtime_errors = [ + err + for err in runtime_errors + if err["file_path"].endswith(relative_file_path) and abs(err["line"] - (diag.range.line + 1)) <= 2 # Within 2 lines + ] + + # Find related UI errors + related_ui_errors = [ + err + for err in ui_errors + if err["file_path"].endswith(relative_file_path) and abs(err["line"] - (diag.range.line + 1)) <= 2 # Within 2 lines + ] + + # Find related network errors + related_network_errors = [err for err in network_errors if err["file_path"] == relative_file_path] + + # Track error frequency + error_key = f"{diag.code}:{relative_file_path}:{diag.range.line}" + self.error_frequency[error_key] = self.error_frequency.get(error_key, 0) + 1 + + # Create a partial EnhancedDiagnostic + partial_enhanced_diag: EnhancedDiagnostic = { + "diagnostic": diag, + "file_content": file_content, + "relevant_code_snippet": relevant_code, + "file_path": file_path, + "relative_file_path": relative_file_path, + "graph_sitter_context": {}, # Will be filled by get_ai_fix_context + "autogenlib_context": {}, # Will be filled by get_ai_fix_context + "runtime_context": { + "related_runtime_errors": related_runtime_errors, + "error_frequency": self.error_frequency.get(error_key, 0), + "last_runtime_error": related_runtime_errors[-1] if related_runtime_errors else None, + "network_errors": related_network_errors, + "error_history": self._get_error_history(error_key), + }, + "ui_interaction_context": { + "related_ui_errors": related_ui_errors, + "ui_error_frequency": len(related_ui_errors), + "last_ui_error": related_ui_errors[-1] if related_ui_errors else None, + "component_errors": self._extract_component_errors(related_ui_errors), + }, + } + + # Get the full enhanced context using autogenlib_context + full_enhanced_diag = get_ai_fix_context(partial_enhanced_diag, self.codebase) + enhanced_diagnostics.append(full_enhanced_diag) + + # Store in error history + self.error_history.append({"timestamp": time.time(), "diagnostic": diag, "file": relative_file_path, "resolved": False}) + + return enhanced_diagnostics + + def _get_relevant_code_for_diagnostic(self, file_content: str, diagnostic_range: Range, context_lines: int = 5) -> str: + """Extracts the code snippet directly related to the diagnostic, plus surrounding context.""" + lines = file_content.splitlines() + + start_line = max(0, diagnostic_range.line - context_lines) + end_line = min(len(lines), diagnostic_range.end.line + context_lines + 1) # +1 to include the end line + + snippet_lines = lines[start_line:end_line] + + # Simple highlighting: add markers around the problematic line + if diagnostic_range.line >= start_line and diagnostic_range.line < end_line: + line_in_snippet_index = diagnostic_range.line - start_line + original_line = snippet_lines[line_in_snippet_index] + + # Attempt to highlight the exact character range if it's within the same line + if diagnostic_range.line == diagnostic_range.end.line: + char_start = diagnostic_range.character + char_end = diagnostic_range.end.character + highlighted_segment = original_line[char_start:char_end] + + # Avoid empty highlights or out-of-bounds access + if highlighted_segment: + highlighted_line = original_line[:char_start] + "**>>>" + highlighted_segment + "<<<**" + original_line[char_end:] + snippet_lines[line_in_snippet_index] = 
highlighted_line + else: + snippet_lines[line_in_snippet_index] = ">>> " + original_line + " <<<" + else: + # For multi-line diagnostics, just mark the start line + snippet_lines[line_in_snippet_index] = ">>> " + original_line + " <<<" + + return "\n".join(snippet_lines) + + def _get_error_history(self, error_key: str) -> list[dict[str, Any]]: + """Get historical data for a specific error.""" + return [entry for entry in self.error_history if f"{entry['diagnostic'].code}:{entry['file']}:{entry['diagnostic'].range.line}" == error_key] + + def _extract_component_errors(self, ui_errors: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Extract component-specific error information.""" + component_errors = [] + for error in ui_errors: + if error.get("type") == "react_error": + component_errors.append( + { + "component": error.get("component"), + "error_type": error.get("error_type"), + "message": error.get("message"), + "frequency": 1, # Could be enhanced with actual frequency tracking + } + ) + return component_errors + + def collect_runtime_diagnostics(self, runtime_log_path: str | None = None, ui_log_path: str | None = None) -> list[dict[str, Any]]: + """Collect runtime errors and convert them to diagnostic-like format.""" + runtime_diagnostics = [] + + # Collect Python runtime errors + runtime_errors = self.runtime_collector.collect_python_runtime_errors(runtime_log_path) + for error in runtime_errors: + # Convert runtime error to diagnostic-like format + try: + relative_path = os.path.relpath(error["file_path"], self.repository_root_path) + file_content = self.codebase.get_file(relative_path).content + + # Create a mock Range for the error line + error_range = Range(start={"line": error["line"] - 1, "character": 0}, end={"line": error["line"] - 1, "character": 100}) + + # Create a mock Diagnostic + mock_diagnostic = Diagnostic( + uri=PathUtils.path_to_uri(error["file_path"]), + range=error_range, + severity=1, # Error severity + message=f"Runtime {error['error_type']}: {error['message']}", + code="runtime_error", + source="runtime_collector", + ) + + runtime_diagnostics.append( + { + "diagnostic": mock_diagnostic, + "file_content": file_content, + "relevant_code_snippet": self._get_relevant_code_for_diagnostic(file_content, error_range), + "file_path": error["file_path"], + "relative_file_path": relative_path, + "runtime_error_data": error, + "error_source": "runtime", + } + ) + + except Exception as e: + logger.warning(f"Error processing runtime error: {e}") + + # Collect UI interaction errors + ui_errors = self.runtime_collector.collect_ui_interaction_errors(ui_log_path) + for error in ui_errors: + try: + relative_path = os.path.relpath(error["file_path"], self.repository_root_path) + file_content = self.codebase.get_file(relative_path).content + + # Create a mock Range for the error line + error_range = Range(start={"line": error["line"] - 1, "character": error.get("column", 0)}, end={"line": error["line"] - 1, "character": error.get("column", 0) + 10}) + + # Create a mock Diagnostic + mock_diagnostic = Diagnostic( + uri=PathUtils.path_to_uri(error["file_path"]), + range=error_range, + severity=2, # Warning severity + message=f"UI {error['error_type']}: {error['message']}", + code="ui_error", + source="ui_collector", + ) + + runtime_diagnostics.append( + { + "diagnostic": mock_diagnostic, + "file_content": file_content, + "relevant_code_snippet": self._get_relevant_code_for_diagnostic(file_content, error_range), + "file_path": error["file_path"], + "relative_file_path": relative_path, 
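+                        # Keep the raw UI error payload alongside the synthesized diagnostic for downstream context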
+ "ui_error_data": error, + "error_source": "ui", + } + ) + + except Exception as e: + logger.warning(f"Error processing UI error: {e}") + + return runtime_diagnostics + + def get_error_statistics(self) -> dict[str, Any]: + """Get comprehensive error statistics.""" + if not self.lsp_server: + return {} + + all_diagnostics = self.lsp_server.get_all_diagnostics() + runtime_errors = self.lsp_server.get_runtime_errors() + ui_errors = self.lsp_server.get_ui_errors() + error_patterns = self.lsp_server.get_error_patterns() + + return { + "lsp_diagnostics": { + "total": sum(len(diags) for diags in all_diagnostics.values()), + "files_affected": len(all_diagnostics), + "by_severity": self._categorize_diagnostics_by_severity(all_diagnostics), + }, + "runtime_errors": { + "total": len(runtime_errors), + "by_type": Counter(err.get("type", "unknown") for err in runtime_errors), + "recent_errors": runtime_errors[-10:], # Last 10 errors + }, + "ui_errors": { + "total": len(ui_errors), + "by_type": Counter(err.get("type", "unknown") for err in ui_errors), + "component_errors": len([err for err in ui_errors if err.get("type") == "react_error"]), + }, + "error_patterns": error_patterns, + "error_frequency": self.error_frequency, + "resolution_success_rate": self._calculate_resolution_success_rate(), + } + + def add_runtime_error(self, error_data: dict[str, Any]) -> None: + """Add a runtime error to the LSP server's collection.""" + if self.lsp_server: + self.lsp_server.add_runtime_error(error_data) + + def add_ui_error(self, error_data: dict[str, Any]) -> None: + """Add a UI error to the LSP server's collection.""" + if self.lsp_server: + self.lsp_server.add_ui_error(error_data) + + def clear_diagnostics(self) -> None: + """Clears all stored diagnostics in the LSP server.""" + if self.lsp_server: + self.lsp_server.clear_diagnostics() + self.error_history.clear() + self.error_frequency.clear() + self.resolution_attempts.clear() + + def shutdown_server(self) -> None: + """Shuts down the LSP server.""" + if self.lsp_server: + self.logger.log("Shutting down LSP server.", logging.INFO) + self.lsp_server.stop() + self.lsp_server = None + self.logger.log("LSP server shut down.", logging.INFO) + + async def monitor_runtime_errors(self, callback_func=None, monitor_duration: int = 60): + """Monitor for runtime errors in real-time.""" + logger.info(f"Starting runtime error monitoring for {monitor_duration} seconds...") + + start_time = asyncio.get_event_loop().time() + collected_errors = [] + + while (asyncio.get_event_loop().time() - start_time) < monitor_duration: + # Collect new runtime errors + new_runtime_errors = self.runtime_collector.collect_python_runtime_errors() + new_ui_errors = self.runtime_collector.collect_ui_interaction_errors() + + all_new_errors = new_runtime_errors + new_ui_errors + + if all_new_errors: + collected_errors.extend(all_new_errors) + if callback_func: + await callback_func(all_new_errors) + + await asyncio.sleep(1) # Check every second + + logger.info(f"Runtime error monitoring completed. 
Collected {len(collected_errors)} errors.") + return collected_errors + + def _categorize_diagnostics_by_severity(self, all_diagnostics: dict[DocumentUri, list[Diagnostic]]) -> dict[str, int]: + """Categorize diagnostics by severity.""" + severity_counts = {"error": 0, "warning": 0, "information": 0, "hint": 0} + + for diagnostics_list in all_diagnostics.values(): + for diag in diagnostics_list: + if diag.severity: + severity_name = diag.severity.name.lower() + if severity_name in severity_counts: + severity_counts[severity_name] += 1 + + return severity_counts + + def _calculate_resolution_success_rate(self) -> float: + """Calculate the success rate of error resolutions.""" + if not self.resolution_attempts: + return 0.0 + + successful = sum(1 for attempt in self.resolution_attempts.values() if attempt.get("success", False)) + return successful / len(self.resolution_attempts) + + def mark_error_resolved(self, error_key: str, success: bool, method: str) -> None: + """Mark an error as resolved or failed.""" + self.resolution_attempts[error_key] = {"success": success, "method": method, "timestamp": time.time()} diff --git a/Libraries/serena/__init__.py b/Libraries/serena/__init__.py new file mode 100644 index 00000000..a2c1cf0a --- /dev/null +++ b/Libraries/serena/__init__.py @@ -0,0 +1,23 @@ +__version__ = "0.1.4" + +import logging + +log = logging.getLogger(__name__) + + +def serena_version() -> str: + """ + :return: the version of the package, including git status if available. + """ + from serena.util.git import get_git_status + + version = __version__ + try: + git_status = get_git_status() + if git_status is not None: + version += f"-{git_status.commit[:8]}" + if not git_status.is_clean: + version += "-dirty" + except: + pass + return version diff --git a/Libraries/serena/agent.py b/Libraries/serena/agent.py new file mode 100644 index 00000000..d8a491dd --- /dev/null +++ b/Libraries/serena/agent.py @@ -0,0 +1,606 @@ +""" +The Serena Model Context Protocol (MCP) Server +""" + +import multiprocessing +import os +import platform +import sys +import threading +import webbrowser +from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor +from logging import Logger +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional, TypeVar + +from sensai.util import logging +from sensai.util.logging import LogTime + +from interprompt.jinja_template import JinjaTemplate +from serena import serena_version +from serena.analytics import RegisteredTokenCountEstimator, ToolUsageStats +from serena.config.context_mode import RegisteredContext, SerenaAgentContext, SerenaAgentMode +from serena.config.serena_config import SerenaConfig, ToolInclusionDefinition, ToolSet, get_serena_managed_in_project_dir +from serena.dashboard import SerenaDashboardAPI +from serena.project import Project +from serena.prompt_factory import SerenaPromptFactory +from serena.tools import ActivateProjectTool, GetCurrentConfigTool, Tool, ToolMarker, ToolRegistry +from serena.util.inspection import iter_subclasses +from serena.util.logging import MemoryLogHandler +from solidlsp import SolidLanguageServer + +if TYPE_CHECKING: + from serena.gui_log_viewer import GuiLogViewer + +log = logging.getLogger(__name__) +TTool = TypeVar("TTool", bound="Tool") +T = TypeVar("T") +SUCCESS_RESULT = "OK" + + +class ProjectNotFoundError(Exception): + pass + + +class MemoriesManager: + def __init__(self, project_root: str): + self._memory_dir = Path(get_serena_managed_in_project_dir(project_root)) / 
"memories" + self._memory_dir.mkdir(parents=True, exist_ok=True) + + def _get_memory_file_path(self, name: str) -> Path: + # strip all .md from the name. Models tend to get confused, sometimes passing the .md extension and sometimes not. + name = name.replace(".md", "") + filename = f"{name}.md" + return self._memory_dir / filename + + def load_memory(self, name: str) -> str: + memory_file_path = self._get_memory_file_path(name) + if not memory_file_path.exists(): + return f"Memory file {name} not found, consider creating it with the `write_memory` tool if you need it." + with open(memory_file_path, encoding="utf-8") as f: + return f.read() + + def save_memory(self, name: str, content: str) -> str: + memory_file_path = self._get_memory_file_path(name) + with open(memory_file_path, "w", encoding="utf-8") as f: + f.write(content) + return f"Memory {name} written." + + def list_memories(self) -> list[str]: + return [f.name.replace(".md", "") for f in self._memory_dir.iterdir() if f.is_file()] + + def delete_memory(self, name: str) -> str: + memory_file_path = self._get_memory_file_path(name) + memory_file_path.unlink() + return f"Memory {name} deleted." + + +class AvailableTools: + def __init__(self, tools: list[Tool]): + """ + :param tools: the list of available tools + """ + self.tools = tools + self.tool_names = [tool.get_name_from_cls() for tool in tools] + self.tool_marker_names = set() + for marker_class in iter_subclasses(ToolMarker): + for tool in tools: + if isinstance(tool, marker_class): + self.tool_marker_names.add(marker_class.__name__) + + def __len__(self) -> int: + return len(self.tools) + + +class SerenaAgent: + def __init__( + self, + project: str | None = None, + project_activation_callback: Callable[[], None] | None = None, + serena_config: SerenaConfig | None = None, + context: SerenaAgentContext | None = None, + modes: list[SerenaAgentMode] | None = None, + memory_log_handler: MemoryLogHandler | None = None, + ): + """ + :param project: the project to load immediately or None to not load any project; may be a path to the project or a name of + an already registered project; + :param project_activation_callback: a callback function to be called when a project is activated. + :param serena_config: the Serena configuration or None to read the configuration from the default location. + :param context: the context in which the agent is operating, None for default context. + The context may adjust prompts, tool availability, and tool descriptions. + :param modes: list of modes in which the agent is operating (they will be combined), None for default modes. + The modes may adjust prompts, tool availability, and tool descriptions. + :param memory_log_handler: a MemoryLogHandler instance from which to read log messages; if None, a new one will be created + if necessary. 
+ """ + # obtain serena configuration using the decoupled factory function + self.serena_config = serena_config or SerenaConfig.from_config_file() + + # project-specific instances, which will be initialized upon project activation + self._active_project: Project | None = None + self.language_server: SolidLanguageServer | None = None + self.memories_manager: MemoriesManager | None = None + + # adjust log level + serena_log_level = self.serena_config.log_level + if Logger.root.level > serena_log_level: + log.info(f"Changing the root logger level to {serena_log_level}") + Logger.root.setLevel(serena_log_level) + + def get_memory_log_handler() -> MemoryLogHandler: + nonlocal memory_log_handler + if memory_log_handler is None: + memory_log_handler = MemoryLogHandler(level=serena_log_level) + Logger.root.addHandler(memory_log_handler) + return memory_log_handler + + # open GUI log window if enabled + self._gui_log_viewer: Optional["GuiLogViewer"] = None + if self.serena_config.gui_log_window_enabled: + if platform.system() == "Darwin": + log.warning("GUI log window is not supported on macOS") + else: + # even importing on macOS may fail if tkinter dependencies are unavailable (depends on Python interpreter installation + # which uv used as a base, unfortunately) + from serena.gui_log_viewer import GuiLogViewer + + self._gui_log_viewer = GuiLogViewer("dashboard", title="Serena Logs", memory_log_handler=get_memory_log_handler()) + self._gui_log_viewer.start() + + # set the agent context + if context is None: + context = SerenaAgentContext.load_default() + self._context = context + + # instantiate all tool classes + self._all_tools: dict[type[Tool], Tool] = {tool_class: tool_class(self) for tool_class in ToolRegistry().get_all_tool_classes()} + tool_names = [tool.get_name_from_cls() for tool in self._all_tools.values()] + + # If GUI log window is enabled, set the tool names for highlighting + if self._gui_log_viewer is not None: + self._gui_log_viewer.set_tool_names(tool_names) + + self._tool_usage_stats: ToolUsageStats | None = None + if self.serena_config.record_tool_usage_stats: + token_count_estimator = RegisteredTokenCountEstimator[self.serena_config.token_count_estimator] + log.info(f"Tool usage statistics recording is enabled with token count estimator: {token_count_estimator.name}.") + self._tool_usage_stats = ToolUsageStats(token_count_estimator) + + # start the dashboard (web frontend), registering its log handler + if self.serena_config.web_dashboard: + self._dashboard_thread, port = SerenaDashboardAPI( + get_memory_log_handler(), tool_names, agent=self, tool_usage_stats=self._tool_usage_stats + ).run_in_thread() + dashboard_url = f"http://127.0.0.1:{port}/dashboard/index.html" + log.info("Serena web dashboard started at %s", dashboard_url) + if self.serena_config.web_dashboard_open_on_launch: + # open the dashboard URL in the default web browser (using a separate process to control + # output redirection) + process = multiprocessing.Process(target=self._open_dashboard, args=(dashboard_url,)) + process.start() + process.join(timeout=1) + + # log fundamental information + log.info(f"Starting Serena server (version={serena_version()}, process id={os.getpid()}, parent process id={os.getppid()})") + log.info("Configuration file: %s", self.serena_config.config_file_path) + log.info("Available projects: {}".format(", ".join(self.serena_config.project_names))) + log.info(f"Loaded tools ({len(self._all_tools)}): {', '.join([tool.get_name_from_cls() for tool in self._all_tools.values()])}") + + 
self._check_shell_settings() + + # determine the base toolset defining the set of exposed tools (which e.g. the MCP shall see), + # limited by the Serena config, the context (which is fixed for the session) and JetBrains mode + tool_inclusion_definitions: list[ToolInclusionDefinition] = [self.serena_config, self._context] + if self._context.name == RegisteredContext.IDE_ASSISTANT.value: + tool_inclusion_definitions.extend(self._ide_assistant_context_tool_inclusion_definitions(project)) + if self.serena_config.jetbrains: + tool_inclusion_definitions.append(SerenaAgentMode.from_name_internal("jetbrains")) + + self._base_tool_set = ToolSet.default().apply(*tool_inclusion_definitions) + self._exposed_tools = AvailableTools([t for t in self._all_tools.values() if self._base_tool_set.includes_name(t.get_name())]) + log.info(f"Number of exposed tools: {len(self._exposed_tools)}") + + # create executor for starting the language server and running tools in another thread + # This executor is used to achieve linear task execution, so it is important to use a single-threaded executor. + self._task_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="SerenaAgentExecutor") + self._task_executor_lock = threading.Lock() + self._task_executor_task_index = 1 + + # Initialize the prompt factory + self.prompt_factory = SerenaPromptFactory() + self._project_activation_callback = project_activation_callback + + # set the active modes + if modes is None: + modes = SerenaAgentMode.load_default_modes() + self._modes = modes + + self._active_tools: dict[type[Tool], Tool] = {} + self._update_active_tools() + + # activate a project configuration (if provided or if there is only a single project available) + if project is not None: + try: + self.activate_project_from_path_or_name(project) + except Exception as e: + log.error(f"Error activating project '{project}' at startup: {e}", exc_info=e) + + def get_context(self) -> SerenaAgentContext: + return self._context + + def get_tool_description_override(self, tool_name: str) -> str | None: + return self._context.tool_description_overrides.get(tool_name, None) + + def _check_shell_settings(self) -> None: + # On Windows, Claude Code sets COMSPEC to Git-Bash (often even with a path containing spaces), + # which causes all sorts of trouble, preventing language servers from being launched correctly. + # So we make sure that COMSPEC is unset if it has been set to bash specifically. + if platform.system() == "Windows": + comspec = os.environ.get("COMSPEC", "") + if "bash" in comspec: + os.environ["COMSPEC"] = "" # force use of default shell + log.info("Adjusting COMSPEC environment variable to use the default shell instead of '%s'", comspec) + + def _ide_assistant_context_tool_inclusion_definitions(self, project_root_or_name: str | None) -> list[ToolInclusionDefinition]: + """ + In the IDE assistant context, the agent is assumed to work on a single project, and we thus + want to apply that project's tool exclusions/inclusions from the get-go, limiting the set + of tools that will be exposed to the client. + Furthermore, we disable tools that are only relevant for project activation. + So if the project exists, we apply all the aforementioned exclusions. 
+ + :param project_root_or_name: the project root path or project name + :return: + """ + tool_inclusion_definitions = [] + if project_root_or_name is not None: + # Note: Auto-generation is disabled, because the result must be returned instantaneously + # (project generation could take too much time), so as not to delay MCP server startup + # and provide responses to the client immediately. + project = self.load_project_from_path_or_name(project_root_or_name, autogenerate=False) + if project is not None: + tool_inclusion_definitions.append( + ToolInclusionDefinition( + excluded_tools=[ActivateProjectTool.get_name_from_cls(), GetCurrentConfigTool.get_name_from_cls()] + ) + ) + tool_inclusion_definitions.append(project.project_config) + return tool_inclusion_definitions + + def record_tool_usage_if_enabled(self, input_kwargs: dict, tool_result: str | dict, tool: Tool) -> None: + """ + Record the usage of a tool with the given input and output strings if tool usage statistics recording is enabled. + """ + tool_name = tool.get_name() + if self._tool_usage_stats is not None: + input_str = str(input_kwargs) + output_str = str(tool_result) + log.debug(f"Recording tool usage for tool '{tool_name}'") + self._tool_usage_stats.record_tool_usage(tool_name, input_str, output_str) + else: + log.debug(f"Tool usage statistics recording is disabled, not recording usage of '{tool_name}'.") + + @staticmethod + def _open_dashboard(url: str) -> None: + # Redirect stdout and stderr file descriptors to /dev/null, + # making sure that nothing can be written to stdout/stderr, even by subprocesses + null_fd = os.open(os.devnull, os.O_WRONLY) + os.dup2(null_fd, sys.stdout.fileno()) + os.dup2(null_fd, sys.stderr.fileno()) + os.close(null_fd) + + # open the dashboard URL in the default web browser + webbrowser.open(url) + + def get_project_root(self) -> str: + """ + :return: the root directory of the active project (if any); raises a ValueError if there is no active project + """ + project = self.get_active_project() + if project is None: + raise ValueError("Cannot get project root if no project is active.") + return project.project_root + + def get_exposed_tool_instances(self) -> list["Tool"]: + """ + :return: the tool instances which are exposed (e.g. to the MCP client). + Note that the set of exposed tools is fixed for the session, as + clients don't react to changes in the set of tools, so this is the superset + of tools that can be offered during the session. + If a client should attempt to use a tool that is dynamically disabled + (e.g. because a project is activated that disables it), it will receive an error. + """ + return list(self._exposed_tools.tools) + + def get_active_project(self) -> Project | None: + """ + :return: the active project or None if no project is active + """ + return self._active_project + + def get_active_project_or_raise(self) -> Project: + """ + :return: the active project or raises an exception if no project is active + """ + project = self.get_active_project() + if project is None: + raise ValueError("No active project. Please activate a project first.") + return project + + def set_modes(self, modes: list[SerenaAgentMode]) -> None: + """ + Set the current mode configurations. 
+ + :param modes: List of mode names or paths to use + """ + self._modes = modes + self._update_active_tools() + + log.info(f"Set modes to {[mode.name for mode in modes]}") + + def get_active_modes(self) -> list[SerenaAgentMode]: + """ + :return: the list of active modes + """ + return list(self._modes) + + def _format_prompt(self, prompt_template: str) -> str: + template = JinjaTemplate(prompt_template) + return template.render(available_tools=self._exposed_tools.tool_names, available_markers=self._exposed_tools.tool_marker_names) + + def create_system_prompt(self) -> str: + available_markers = self._exposed_tools.tool_marker_names + log.info("Generating system prompt with available_tools=(see exposed tools), available_markers=%s", available_markers) + system_prompt = self.prompt_factory.create_system_prompt( + context_system_prompt=self._format_prompt(self._context.prompt), + mode_system_prompts=[self._format_prompt(mode.prompt) for mode in self._modes], + available_tools=self._exposed_tools.tool_names, + available_markers=available_markers, + ) + log.info("System prompt:\n%s", system_prompt) + return system_prompt + + def _update_active_tools(self) -> None: + """ + Update the active tools based on enabled modes and the active project. + The base tool set already takes the Serena configuration and the context into account + (as well as any internal modes that are not handled dynamically, such as JetBrains mode). + """ + tool_set = self._base_tool_set.apply(*self._modes) + if self._active_project is not None: + tool_set = tool_set.apply(self._active_project.project_config) + if self._active_project.project_config.read_only: + tool_set = tool_set.without_editing_tools() + + self._active_tools = { + tool_class: tool_instance + for tool_class, tool_instance in self._all_tools.items() + if tool_set.includes_name(tool_instance.get_name()) + } + + log.info(f"Active tools ({len(self._active_tools)}): {', '.join(self.get_active_tool_names())}") + + def issue_task(self, task: Callable[[], Any], name: str | None = None) -> Future: + """ + Issue a task to the executor for asynchronous execution. + It is ensured that tasks are executed in the order they are issued, one after another. + + :param task: the task to execute + :param name: the name of the task for logging purposes; if None, use the task function's name + :return: a Future object representing the execution of the task + """ + with self._task_executor_lock: + task_name = f"Task-{self._task_executor_task_index}[{name or task.__name__}]" + self._task_executor_task_index += 1 + + def task_execution_wrapper() -> Any: + with LogTime(task_name, logger=log): + return task() + + log.info(f"Scheduling {task_name}") + return self._task_executor.submit(task_execution_wrapper) + + def execute_task(self, task: Callable[[], T]) -> T: + """ + Executes the given task synchronously via the agent's task executor. + This is useful for tasks that need to be executed immediately and whose results are needed right away. 
+ + :param task: the task to execute + :return: the result of the task execution + """ + future = self.issue_task(task) + return future.result() + + def is_using_language_server(self) -> bool: + """ + :return: whether this agent uses language server-based code analysis + """ + return not self.serena_config.jetbrains + + def _activate_project(self, project: Project) -> None: + log.info(f"Activating {project.project_name} at {project.project_root}") + self._active_project = project + self._update_active_tools() + + # initialize project-specific instances which do not depend on the language server + self.memories_manager = MemoriesManager(project.project_root) + + def init_language_server() -> None: + # start the language server + with LogTime("Language server initialization", logger=log): + self.reset_language_server() + assert self.language_server is not None + + # initialize the language server in the background (if in language server mode) + if self.is_using_language_server(): + self.issue_task(init_language_server) + + if self._project_activation_callback is not None: + self._project_activation_callback() + + def load_project_from_path_or_name(self, project_root_or_name: str, autogenerate: bool) -> Project | None: + """ + Get a project instance from a path or a name. + + :param project_root_or_name: the path to the project root or the name of the project + :param autogenerate: whether to autogenerate the project for the case where the first argument is a directory + which does not yet contain a Serena project configuration file + :return: the project instance if it was found/could be created, None otherwise + """ + project_instance: Project | None = self.serena_config.get_project(project_root_or_name) + if project_instance is not None: + log.info(f"Found registered project '{project_instance.project_name}' at path {project_instance.project_root}") + elif autogenerate and os.path.isdir(project_root_or_name): + project_instance = self.serena_config.add_project_from_path(project_root_or_name) + log.info(f"Added new project {project_instance.project_name} for path {project_instance.project_root}") + return project_instance + + def activate_project_from_path_or_name(self, project_root_or_name: str) -> Project: + """ + Activate a project from a path or a name. + If the project was already registered, it will just be activated. + If the argument is a path at which no Serena project previously existed, the project will be created beforehand. + Raises ProjectNotFoundError if the project could neither be found nor created. + + :return: the activated project instance + """ + project_instance: Project | None = self.load_project_from_path_or_name(project_root_or_name, autogenerate=True) + if project_instance is None: + raise ProjectNotFoundError( + f"Project '{project_root_or_name}' not found: Not a valid project name or directory. 
" + f"Existing project names: {self.serena_config.project_names}" + ) + self._activate_project(project_instance) + return project_instance + + def get_active_tool_classes(self) -> list[type["Tool"]]: + """ + :return: the list of active tool classes for the current project + """ + return list(self._active_tools.keys()) + + def get_active_tool_names(self) -> list[str]: + """ + :return: the list of names of the active tools for the current project + """ + return sorted([tool.get_name_from_cls() for tool in self.get_active_tool_classes()]) + + def tool_is_active(self, tool_class: type["Tool"] | str) -> bool: + """ + :param tool_class: the class or name of the tool to check + :return: True if the tool is active, False otherwise + """ + if isinstance(tool_class, str): + return tool_class in self.get_active_tool_names() + else: + return tool_class in self.get_active_tool_classes() + + def get_current_config_overview(self) -> str: + """ + :return: a string overview of the current configuration, including the active and available configuration options + """ + result_str = "Current configuration:\n" + result_str += f"Serena version: {serena_version()}\n" + result_str += f"Loglevel: {self.serena_config.log_level}, trace_lsp_communication={self.serena_config.trace_lsp_communication}\n" + if self._active_project is not None: + result_str += f"Active project: {self._active_project.project_name}\n" + else: + result_str += "No active project\n" + result_str += "Available projects:\n" + "\n".join(list(self.serena_config.project_names)) + "\n" + result_str += f"Active context: {self._context.name}\n" + + # Active modes + active_mode_names = [mode.name for mode in self.get_active_modes()] + result_str += "Active modes: {}\n".format(", ".join(active_mode_names)) + "\n" + + # Available but not active modes + all_available_modes = SerenaAgentMode.list_registered_mode_names() + inactive_modes = [mode for mode in all_available_modes if mode not in active_mode_names] + if inactive_modes: + result_str += "Available but not active modes: {}\n".format(", ".join(inactive_modes)) + "\n" + + # Active tools + result_str += "Active tools (after all exclusions from the project, context, and modes):\n" + active_tool_names = self.get_active_tool_names() + # print the tool names in chunks + chunk_size = 4 + for i in range(0, len(active_tool_names), chunk_size): + chunk = active_tool_names[i : i + chunk_size] + result_str += " " + ", ".join(chunk) + "\n" + + # Available but not active tools + all_tool_names = sorted([tool.get_name_from_cls() for tool in self._all_tools.values()]) + inactive_tool_names = [tool for tool in all_tool_names if tool not in active_tool_names] + if inactive_tool_names: + result_str += "Available but not active tools:\n" + for i in range(0, len(inactive_tool_names), chunk_size): + chunk = inactive_tool_names[i : i + chunk_size] + result_str += " " + ", ".join(chunk) + "\n" + + return result_str + + def is_language_server_running(self) -> bool: + return self.language_server is not None and self.language_server.is_running() + + def reset_language_server(self) -> None: + """ + Starts/resets the language server for the current project + """ + tool_timeout = self.serena_config.tool_timeout + if tool_timeout is None or tool_timeout < 0: + ls_timeout = None + else: + if tool_timeout < 10: + raise ValueError(f"Tool timeout must be at least 10 seconds, but is {tool_timeout} seconds") + ls_timeout = tool_timeout - 5 # the LS timeout is for a single call, it should be smaller than the tool timeout + + # stop the 
language server if it is running + if self.is_language_server_running(): + assert self.language_server is not None + log.info(f"Stopping the current language server at {self.language_server.repository_root_path} ...") + self.language_server.stop() + self.language_server = None + + # instantiate and start the language server + assert self._active_project is not None + self.language_server = self._active_project.create_language_server( + log_level=self.serena_config.log_level, + ls_timeout=ls_timeout, + trace_lsp_communication=self.serena_config.trace_lsp_communication, + ls_specific_settings=self.serena_config.ls_specific_settings, + ) + log.info(f"Starting the language server for {self._active_project.project_name}") + self.language_server.start() + if not self.language_server.is_running(): + raise RuntimeError( + f"Failed to start the language server for {self._active_project.project_name} at {self._active_project.project_root}" + ) + + def get_tool(self, tool_class: type[TTool]) -> TTool: + return self._all_tools[tool_class] # type: ignore + + def print_tool_overview(self) -> None: + ToolRegistry().print_tool_overview(self._active_tools.values()) + + def __del__(self) -> None: + """ + Destructor to clean up the language server instance and GUI logger + """ + if not hasattr(self, "_is_initialized"): + return + log.info("SerenaAgent is shutting down ...") + if self.is_language_server_running(): + log.info("Stopping the language server ...") + assert self.language_server is not None + self.language_server.save_cache() + self.language_server.stop() + if self._gui_log_viewer: + log.info("Stopping the GUI log window ...") + self._gui_log_viewer.stop() + + def get_tool_by_name(self, tool_name: str) -> Tool: + tool_class = ToolRegistry().get_tool_class_by_name(tool_name) + return self.get_tool(tool_class) diff --git a/Libraries/serena/agno.py b/Libraries/serena/agno.py new file mode 100644 index 00000000..07ab7371 --- /dev/null +++ b/Libraries/serena/agno.py @@ -0,0 +1,145 @@ +import argparse +import logging +import os +import threading +from pathlib import Path +from typing import Any + +from agno.agent import Agent +from agno.memory import AgentMemory +from agno.models.base import Model +from agno.storage.sqlite import SqliteStorage +from agno.tools.function import Function +from agno.tools.toolkit import Toolkit +from dotenv import load_dotenv +from sensai.util.logging import LogTime + +from serena.agent import SerenaAgent, Tool +from serena.config.context_mode import SerenaAgentContext +from serena.constants import REPO_ROOT +from serena.util.exception import show_fatal_exception_safe + +log = logging.getLogger(__name__) + + +class SerenaAgnoToolkit(Toolkit): + def __init__(self, serena_agent: SerenaAgent): + super().__init__("Serena") + for tool in serena_agent.get_exposed_tool_instances(): + self.functions[tool.get_name_from_cls()] = self._create_agno_function(tool) + log.info("Agno agent functions: %s", list(self.functions.keys())) + + @staticmethod + def _create_agno_function(tool: Tool) -> Function: + def entrypoint(**kwargs: Any) -> str: + if "kwargs" in kwargs: + # Agno sometimes passes a kwargs argument explicitly, so we merge it + kwargs.update(kwargs["kwargs"]) + del kwargs["kwargs"] + log.info(f"Calling tool {tool}") + return tool.apply_ex(log_call=True, catch_exceptions=True, **kwargs) + + function = Function.from_callable(tool.get_apply_fn()) + function.name = tool.get_name_from_cls() + function.entrypoint = entrypoint + function.skip_entrypoint_processing = True + return 
function + + +class SerenaAgnoAgentProvider: + _agent: Agent | None = None + _lock = threading.Lock() + + @classmethod + def get_agent(cls, model: Model) -> Agent: + """ + Returns the singleton instance of the Serena agent or creates it with the given parameters if it doesn't exist. + + NOTE: This is very ugly with poor separation of concerns, but the way in which the Agno UI works (reloading the + module that defines the `app` variable) essentially forces us to do something like this. + + :param model: the large language model to use for the agent + :return: the agent instance + """ + with cls._lock: + if cls._agent is not None: + return cls._agent + + # change to Serena root + os.chdir(REPO_ROOT) + + load_dotenv() + + parser = argparse.ArgumentParser(description="Serena coding assistant") + + # Create a mutually exclusive group + group = parser.add_mutually_exclusive_group() + + # Add arguments to the group, both pointing to the same destination + group.add_argument( + "--project-file", + required=False, + help="Path to the project (or project.yml file).", + ) + group.add_argument( + "--project", + required=False, + help="Path to the project (or project.yml file).", + ) + args = parser.parse_args() + + args_project_file = args.project or args.project_file + + if args_project_file: + project_file = Path(args_project_file).resolve() + # If project file path is relative, make it absolute by joining with project root + if not project_file.is_absolute(): + # Get the project root directory (parent of scripts directory) + project_root = Path(REPO_ROOT) + project_file = project_root / args_project_file + + # Ensure the path is normalized and absolute + project_file = str(project_file.resolve()) + else: + project_file = None + + with LogTime("Loading Serena agent"): + try: + serena_agent = SerenaAgent(project_file, context=SerenaAgentContext.load("agent")) + except Exception as e: + show_fatal_exception_safe(e) + raise + + # Even though we don't want to keep history between sessions, + # for agno-ui to work as a conversation, we use a persistent storage on disk. + # This storage should be deleted between sessions. + # Note that this might collide with custom options for the agent, like adding vector-search based tools. + # See here for an explanation: https://www.reddit.com/r/agno/comments/1jk6qea/regarding_the_built_in_memory/ + sql_db_path = (Path("temp") / "agno_agent_storage.db").absolute() + sql_db_path.parent.mkdir(exist_ok=True) + # delete the db file if it exists + log.info(f"Deleting DB from PID {os.getpid()}") + if sql_db_path.exists(): + sql_db_path.unlink() + + agno_agent = Agent( + name="Serena", + model=model, + # See explanation above on why storage is needed + storage=SqliteStorage(table_name="serena_agent_sessions", db_file=str(sql_db_path)), + description="A fully-featured coding assistant", + tools=[SerenaAgnoToolkit(serena_agent)], + # The tool calls will be shown in the UI anyway since whether to show them is configurable per tool + # To see detailed logs, you should use the serena logger (configure it in the project file path) + show_tool_calls=False, + markdown=True, + system_message=serena_agent.create_system_prompt(), + telemetry=False, + memory=AgentMemory(), + add_history_to_messages=True, + num_history_responses=100, # you might want to adjust this (expense vs. 
history awareness) + ) + cls._agent = agno_agent + log.info(f"Agent instantiated: {agno_agent}") + + return agno_agent diff --git a/Libraries/serena/analytics.py b/Libraries/serena/analytics.py new file mode 100644 index 00000000..c773fc08 --- /dev/null +++ b/Libraries/serena/analytics.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import logging +import threading +from abc import ABC, abstractmethod +from collections import defaultdict +from copy import copy +from dataclasses import asdict, dataclass +from enum import Enum + +from anthropic.types import MessageParam, MessageTokensCount +from dotenv import load_dotenv + +log = logging.getLogger(__name__) + + +class TokenCountEstimator(ABC): + @abstractmethod + def estimate_token_count(self, text: str) -> int: + """ + Estimate the number of tokens in the given text. + This is an abstract method that should be implemented by subclasses. + """ + + +class TiktokenCountEstimator(TokenCountEstimator): + """ + Approximate token count using tiktoken. + """ + + def __init__(self, model_name: str = "gpt-4o"): + """ + The tokenizer will be downloaded on the first initialization, which may take some time. + + :param model_name: see `tiktoken.model` to see available models. + """ + import tiktoken + + log.info(f"Loading tiktoken encoding for model {model_name}, this may take a while on the first run.") + self._encoding = tiktoken.encoding_for_model(model_name) + + def estimate_token_count(self, text: str) -> int: + return len(self._encoding.encode(text)) + + +class AnthropicTokenCount(TokenCountEstimator): + """ + The exact count using the Anthropic API. + Counting is free, but has a rate limit and will require an API key, + (typically, set through an env variable). + See https://docs.anthropic.com/en/docs/build-with-claude/token-counting + """ + + def __init__(self, model_name: str = "claude-sonnet-4-20250514", api_key: str | None = None): + import anthropic + + self._model_name = model_name + if api_key is None: + load_dotenv() + self._anthropic_client = anthropic.Anthropic(api_key=api_key) + + def _send_count_tokens_request(self, text: str) -> MessageTokensCount: + return self._anthropic_client.messages.count_tokens( + model=self._model_name, + messages=[MessageParam(role="user", content=text)], + ) + + def estimate_token_count(self, text: str) -> int: + return self._send_count_tokens_request(text).input_tokens + + +_registered_token_estimator_instances_cache: dict[RegisteredTokenCountEstimator, TokenCountEstimator] = {} + + +class RegisteredTokenCountEstimator(Enum): + TIKTOKEN_GPT4O = "TIKTOKEN_GPT4O" + ANTHROPIC_CLAUDE_SONNET_4 = "ANTHROPIC_CLAUDE_SONNET_4" + + @classmethod + def get_valid_names(cls) -> list[str]: + """ + Get a list of all registered token count estimator names. 
+ """ + return [estimator.name for estimator in cls] + + def _create_estimator(self) -> TokenCountEstimator: + match self: + case RegisteredTokenCountEstimator.TIKTOKEN_GPT4O: + return TiktokenCountEstimator(model_name="gpt-4o") + case RegisteredTokenCountEstimator.ANTHROPIC_CLAUDE_SONNET_4: + return AnthropicTokenCount(model_name="claude-sonnet-4-20250514") + case _: + raise ValueError(f"Unknown token count estimator: {self.value}") + + def load_estimator(self) -> TokenCountEstimator: + estimator_instance = _registered_token_estimator_instances_cache.get(self) + if estimator_instance is None: + estimator_instance = self._create_estimator() + _registered_token_estimator_instances_cache[self] = estimator_instance + return estimator_instance + + +class ToolUsageStats: + """ + A class to record and manage tool usage statistics. + """ + + def __init__(self, token_count_estimator: RegisteredTokenCountEstimator = RegisteredTokenCountEstimator.TIKTOKEN_GPT4O): + self._token_count_estimator = token_count_estimator.load_estimator() + self._token_estimator_name = token_count_estimator.value + self._tool_stats: dict[str, ToolUsageStats.Entry] = defaultdict(ToolUsageStats.Entry) + self._tool_stats_lock = threading.Lock() + + @property + def token_estimator_name(self) -> str: + """ + Get the name of the registered token count estimator used. + """ + return self._token_estimator_name + + @dataclass(kw_only=True) + class Entry: + num_times_called: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 + + def update_on_call(self, input_tokens: int, output_tokens: int) -> None: + """ + Update the entry with the number of tokens used for a single call. + """ + self.num_times_called += 1 + self.input_tokens += input_tokens + self.output_tokens += output_tokens + + def _estimate_token_count(self, text: str) -> int: + return self._token_count_estimator.estimate_token_count(text) + + def get_stats(self, tool_name: str) -> ToolUsageStats.Entry: + """ + Get (a copy of) the current usage statistics for a specific tool. 
+ """ + with self._tool_stats_lock: + return copy(self._tool_stats[tool_name]) + + def record_tool_usage(self, tool_name: str, input_str: str, output_str: str) -> None: + input_tokens = self._estimate_token_count(input_str) + output_tokens = self._estimate_token_count(output_str) + with self._tool_stats_lock: + entry = self._tool_stats[tool_name] + entry.update_on_call(input_tokens, output_tokens) + + def get_tool_stats_dict(self) -> dict[str, dict[str, int]]: + with self._tool_stats_lock: + return {name: asdict(entry) for name, entry in self._tool_stats.items()} + + def clear(self) -> None: + with self._tool_stats_lock: + self._tool_stats.clear() diff --git a/Libraries/serena/cli.py b/Libraries/serena/cli.py new file mode 100644 index 00000000..0db76046 --- /dev/null +++ b/Libraries/serena/cli.py @@ -0,0 +1,838 @@ +import glob +import json +import os +import shutil +import subprocess +import sys +from logging import Logger +from pathlib import Path +from typing import Any, Literal + +import click +from sensai.util import logging +from sensai.util.logging import FileLoggerContext, datetime_tag +from tqdm import tqdm + +from serena.agent import SerenaAgent +from serena.config.context_mode import SerenaAgentContext, SerenaAgentMode +from serena.config.serena_config import ProjectConfig, SerenaConfig, SerenaPaths +from serena.constants import ( + DEFAULT_CONTEXT, + DEFAULT_MODES, + PROMPT_TEMPLATES_DIR_IN_USER_HOME, + PROMPT_TEMPLATES_DIR_INTERNAL, + SERENA_LOG_FORMAT, + SERENA_MANAGED_DIR_IN_HOME, + SERENAS_OWN_CONTEXT_YAMLS_DIR, + SERENAS_OWN_MODE_YAMLS_DIR, + USER_CONTEXT_YAMLS_DIR, + USER_MODE_YAMLS_DIR, +) +from serena.mcp import SerenaMCPFactory, SerenaMCPFactorySingleProcess +from serena.project import Project +from serena.tools import FindReferencingSymbolsTool, FindSymbolTool, GetSymbolsOverviewTool, SearchForPatternTool, ToolRegistry +from serena.util.logging import MemoryLogHandler +from solidlsp.ls_config import Language +from solidlsp.util.subprocess_util import subprocess_kwargs + +log = logging.getLogger(__name__) + +# --------------------- Utilities ------------------------------------- + + +def _open_in_editor(path: str) -> None: + """Open the given file in the system's default editor or viewer.""" + editor = os.environ.get("EDITOR") + run_kwargs = subprocess_kwargs() + try: + if editor: + subprocess.run([editor, path], check=False, **run_kwargs) + elif sys.platform.startswith("win"): + try: + os.startfile(path) + except OSError: + subprocess.run(["notepad.exe", path], check=False, **run_kwargs) + elif sys.platform == "darwin": + subprocess.run(["open", path], check=False, **run_kwargs) + else: + subprocess.run(["xdg-open", path], check=False, **run_kwargs) + except Exception as e: + print(f"Failed to open {path}: {e}") + + +class ProjectType(click.ParamType): + """ParamType allowing either a project name or a path to a project directory.""" + + name = "[PROJECT_NAME|PROJECT_PATH]" + + def convert(self, value: str, param: Any, ctx: Any) -> str: + path = Path(value).resolve() + if path.exists() and path.is_dir(): + return str(path) + return value + + +PROJECT_TYPE = ProjectType() + + +class AutoRegisteringGroup(click.Group): + """ + A click.Group subclass that automatically registers any click.Command + attributes defined on the class into the group. + + After initialization, it inspects its own class for attributes that are + instances of click.Command (typically created via @click.command) and + calls self.add_command(cmd) on each. 
This lets you define your commands + as static methods on the subclass for IDE-friendly organization without + manual registration. + """ + + def __init__(self, name: str, help: str): + super().__init__(name=name, help=help) + # Scan class attributes for click.Command instances and register them. + for attr in dir(self.__class__): + cmd = getattr(self.__class__, attr) + if isinstance(cmd, click.Command): + self.add_command(cmd) + + +class TopLevelCommands(AutoRegisteringGroup): + """Root CLI group containing the core Serena commands.""" + + def __init__(self) -> None: + super().__init__(name="serena", help="Serena CLI commands. You can run ` --help` for more info on each command.") + + @staticmethod + @click.command("start-mcp-server", help="Starts the Serena MCP server.") + @click.option("--project", "project", type=PROJECT_TYPE, default=None, help="Path or name of project to activate at startup.") + @click.option("--project-file", "project", type=PROJECT_TYPE, default=None, help="[DEPRECATED] Use --project instead.") + @click.argument("project_file_arg", type=PROJECT_TYPE, required=False, default=None, metavar="") + @click.option( + "--context", type=str, default=DEFAULT_CONTEXT, show_default=True, help="Built-in context name or path to custom context YAML." + ) + @click.option( + "--mode", + "modes", + type=str, + multiple=True, + default=DEFAULT_MODES, + show_default=True, + help="Built-in mode names or paths to custom mode YAMLs.", + ) + @click.option( + "--transport", + type=click.Choice(["stdio", "sse", "streamable-http"]), + default="stdio", + show_default=True, + help="Transport protocol.", + ) + @click.option("--host", type=str, default="0.0.0.0", show_default=True) + @click.option("--port", type=int, default=8000, show_default=True) + @click.option("--enable-web-dashboard", type=bool, is_flag=False, default=None, help="Override dashboard setting in config.") + @click.option("--enable-gui-log-window", type=bool, is_flag=False, default=None, help="Override GUI log window setting in config.") + @click.option( + "--log-level", + type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), + default=None, + help="Override log level in config.", + ) + @click.option("--trace-lsp-communication", type=bool, is_flag=False, default=None, help="Whether to trace LSP communication.") + @click.option("--tool-timeout", type=float, default=None, help="Override tool execution timeout in config.") + def start_mcp_server( + project: str | None, + project_file_arg: str | None, + context: str, + modes: tuple[str, ...], + transport: Literal["stdio", "sse", "streamable-http"], + host: str, + port: int, + enable_web_dashboard: bool | None, + enable_gui_log_window: bool | None, + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | None, + trace_lsp_communication: bool | None, + tool_timeout: float | None, + ) -> None: + # initialize logging, using INFO level initially (will later be adjusted by SerenaAgent according to the config) + # * memory log handler (for use by GUI/Dashboard) + # * stream handler for stderr (for direct console output, which will also be captured by clients like Claude Desktop) + # * file handler + # (Note that stdout must never be used for logging, as it is used by the MCP server to communicate with the client.) 
+ Logger.root.setLevel(logging.INFO) + formatter = logging.Formatter(SERENA_LOG_FORMAT) + memory_log_handler = MemoryLogHandler() + Logger.root.addHandler(memory_log_handler) + stderr_handler = logging.StreamHandler(stream=sys.stderr) + stderr_handler.formatter = formatter + Logger.root.addHandler(stderr_handler) + log_path = SerenaPaths().get_next_log_file_path("mcp") + file_handler = logging.FileHandler(log_path, mode="w") + file_handler.formatter = formatter + Logger.root.addHandler(file_handler) + + log.info("Initializing Serena MCP server") + log.info("Storing logs in %s", log_path) + project_file = project_file_arg or project + factory = SerenaMCPFactorySingleProcess(context=context, project=project_file, memory_log_handler=memory_log_handler) + server = factory.create_mcp_server( + host=host, + port=port, + modes=modes, + enable_web_dashboard=enable_web_dashboard, + enable_gui_log_window=enable_gui_log_window, + log_level=log_level, + trace_lsp_communication=trace_lsp_communication, + tool_timeout=tool_timeout, + ) + if project_file_arg: + log.warning( + "Positional project arg is deprecated; use --project instead. Used: %s", + project_file, + ) + log.info("Starting MCP server …") + server.run(transport=transport) + + @staticmethod + @click.command("print-system-prompt", help="Print the system prompt for a project.") + @click.argument("project", type=click.Path(exists=True), default=os.getcwd(), required=False) + @click.option( + "--log-level", + type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), + default="WARNING", + help="Log level for prompt generation.", + ) + @click.option("--only-instructions", is_flag=True, help="Print only the initial instructions, without prefix/postfix.") + @click.option( + "--context", type=str, default=DEFAULT_CONTEXT, show_default=True, help="Built-in context name or path to custom context YAML." + ) + @click.option( + "--mode", + "modes", + type=str, + multiple=True, + default=DEFAULT_MODES, + show_default=True, + help="Built-in mode names or paths to custom mode YAMLs.", + ) + def print_system_prompt(project: str, log_level: str, only_instructions: bool, context: str, modes: tuple[str, ...]) -> None: + prefix = "You will receive access to Serena's symbolic tools. Below are instructions for using them, take them into account." + postfix = "You begin by acknowledging that you understood the above instructions and are ready to receive tasks." + from serena.tools.workflow_tools import InitialInstructionsTool + + lvl = logging.getLevelNamesMapping()[log_level.upper()] + logging.configure(level=lvl) + context_instance = SerenaAgentContext.load(context) + mode_instances = [SerenaAgentMode.load(mode) for mode in modes] + agent = SerenaAgent( + project=os.path.abspath(project), + serena_config=SerenaConfig(web_dashboard=False, log_level=lvl), + context=context_instance, + modes=mode_instances, + ) + tool = agent.get_tool(InitialInstructionsTool) + instr = tool.apply() + if only_instructions: + print(instr) + else: + print(f"{prefix}\n{instr}\n{postfix}") + + +class ModeCommands(AutoRegisteringGroup): + """Group for 'mode' subcommands.""" + + def __init__(self) -> None: + super().__init__(name="mode", help="Manage Serena modes. 
You can run `mode --help` for more info on each command.") + + @staticmethod + @click.command("list", help="List available modes.") + def list() -> None: + mode_names = SerenaAgentMode.list_registered_mode_names() + max_len_name = max(len(name) for name in mode_names) if mode_names else 20 + for name in mode_names: + mode_yml_path = SerenaAgentMode.get_path(name) + is_internal = Path(mode_yml_path).is_relative_to(SERENAS_OWN_MODE_YAMLS_DIR) + descriptor = "(internal)" if is_internal else f"(at {mode_yml_path})" + name_descr_string = f"{name:<{max_len_name + 4}}{descriptor}" + click.echo(name_descr_string) + + @staticmethod + @click.command("create", help="Create a new mode or copy an internal one.") + @click.option( + "--name", + "-n", + type=str, + default=None, + help="Name for the new mode. If --from-internal is passed may be left empty to create a mode of the same name, which will then override the internal mode.", + ) + @click.option("--from-internal", "from_internal", type=str, default=None, help="Copy from an internal mode.") + def create(name: str, from_internal: str) -> None: + if not (name or from_internal): + raise click.UsageError("Provide at least one of --name or --from-internal.") + mode_name = name or from_internal + dest = os.path.join(USER_MODE_YAMLS_DIR, f"{mode_name}.yml") + src = ( + os.path.join(SERENAS_OWN_MODE_YAMLS_DIR, f"{from_internal}.yml") + if from_internal + else os.path.join(SERENAS_OWN_MODE_YAMLS_DIR, "mode.template.yml") + ) + if not os.path.exists(src): + raise FileNotFoundError( + f"Internal mode '{from_internal}' not found in {SERENAS_OWN_MODE_YAMLS_DIR}. Available modes: {SerenaAgentMode.list_registered_mode_names()}" + ) + os.makedirs(os.path.dirname(dest), exist_ok=True) + shutil.copyfile(src, dest) + click.echo(f"Created mode '{mode_name}' at {dest}") + _open_in_editor(dest) + + @staticmethod + @click.command("edit", help="Edit a custom mode YAML file.") + @click.argument("mode_name") + def edit(mode_name: str) -> None: + path = os.path.join(USER_MODE_YAMLS_DIR, f"{mode_name}.yml") + if not os.path.exists(path): + if mode_name in SerenaAgentMode.list_registered_mode_names(include_user_modes=False): + click.echo( + f"Mode '{mode_name}' is an internal mode and cannot be edited directly. " + f"Use 'mode create --from-internal {mode_name}' to create a custom mode that overrides it before editing." + ) + else: + click.echo(f"Custom mode '{mode_name}' not found. Create it with: mode create --name {mode_name}.") + return + _open_in_editor(path) + + @staticmethod + @click.command("delete", help="Delete a custom mode file.") + @click.argument("mode_name") + def delete(mode_name: str) -> None: + path = os.path.join(USER_MODE_YAMLS_DIR, f"{mode_name}.yml") + if not os.path.exists(path): + click.echo(f"Custom mode '{mode_name}' not found.") + return + os.remove(path) + click.echo(f"Deleted custom mode '{mode_name}'.") + + +class ContextCommands(AutoRegisteringGroup): + """Group for 'context' subcommands.""" + + def __init__(self) -> None: + super().__init__( + name="context", help="Manage Serena contexts. You can run `context --help` for more info on each command." 
+ ) + + @staticmethod + @click.command("list", help="List available contexts.") + def list() -> None: + context_names = SerenaAgentContext.list_registered_context_names() + max_len_name = max(len(name) for name in context_names) if context_names else 20 + for name in context_names: + context_yml_path = SerenaAgentContext.get_path(name) + is_internal = Path(context_yml_path).is_relative_to(SERENAS_OWN_CONTEXT_YAMLS_DIR) + descriptor = "(internal)" if is_internal else f"(at {context_yml_path})" + name_descr_string = f"{name:<{max_len_name + 4}}{descriptor}" + click.echo(name_descr_string) + + @staticmethod + @click.command("create", help="Create a new context or copy an internal one.") + @click.option( + "--name", + "-n", + type=str, + default=None, + help="Name for the new context. If --from-internal is passed may be left empty to create a context of the same name, which will then override the internal context", + ) + @click.option("--from-internal", "from_internal", type=str, default=None, help="Copy from an internal context.") + def create(name: str, from_internal: str) -> None: + if not (name or from_internal): + raise click.UsageError("Provide at least one of --name or --from-internal.") + ctx_name = name or from_internal + dest = os.path.join(USER_CONTEXT_YAMLS_DIR, f"{ctx_name}.yml") + src = ( + os.path.join(SERENAS_OWN_CONTEXT_YAMLS_DIR, f"{from_internal}.yml") + if from_internal + else os.path.join(SERENAS_OWN_CONTEXT_YAMLS_DIR, "context.template.yml") + ) + if not os.path.exists(src): + raise FileNotFoundError( + f"Internal context '{from_internal}' not found in {SERENAS_OWN_CONTEXT_YAMLS_DIR}. Available contexts: {SerenaAgentContext.list_registered_context_names()}" + ) + os.makedirs(os.path.dirname(dest), exist_ok=True) + shutil.copyfile(src, dest) + click.echo(f"Created context '{ctx_name}' at {dest}") + _open_in_editor(dest) + + @staticmethod + @click.command("edit", help="Edit a custom context YAML file.") + @click.argument("context_name") + def edit(context_name: str) -> None: + path = os.path.join(USER_CONTEXT_YAMLS_DIR, f"{context_name}.yml") + if not os.path.exists(path): + if context_name in SerenaAgentContext.list_registered_context_names(include_user_contexts=False): + click.echo( + f"Context '{context_name}' is an internal context and cannot be edited directly. " + f"Use 'context create --from-internal {context_name}' to create a custom context that overrides it before editing." + ) + else: + click.echo(f"Custom context '{context_name}' not found. Create it with: context create --name {context_name}.") + return + _open_in_editor(path) + + @staticmethod + @click.command("delete", help="Delete a custom context file.") + @click.argument("context_name") + def delete(context_name: str) -> None: + path = os.path.join(USER_CONTEXT_YAMLS_DIR, f"{context_name}.yml") + if not os.path.exists(path): + click.echo(f"Custom context '{context_name}' not found.") + return + os.remove(path) + click.echo(f"Deleted custom context '{context_name}'.") + + +class SerenaConfigCommands(AutoRegisteringGroup): + """Group for 'config' subcommands.""" + + def __init__(self) -> None: + super().__init__(name="config", help="Manage Serena configuration.") + + @staticmethod + @click.command( + "edit", help="Edit serena_config.yml in your default editor. Will create a config file from the template if no config is found." 
+ ) + def edit() -> None: + config_path = os.path.join(SERENA_MANAGED_DIR_IN_HOME, "serena_config.yml") + if not os.path.exists(config_path): + SerenaConfig.generate_config_file(config_path) + _open_in_editor(config_path) + + +class ProjectCommands(AutoRegisteringGroup): + """Group for 'project' subcommands.""" + + def __init__(self) -> None: + super().__init__( + name="project", help="Manage Serena projects. You can run `project --help` for more info on each command." + ) + + @staticmethod + @click.command("generate-yml", help="Generate a project.yml file.") + @click.argument("project_path", type=click.Path(exists=True, file_okay=False), default=os.getcwd()) + @click.option("--language", type=str, default=None, help="Programming language; inferred if not specified.") + def generate_yml(project_path: str, language: str | None = None) -> None: + yml_path = os.path.join(project_path, ProjectConfig.rel_path_to_project_yml()) + if os.path.exists(yml_path): + raise FileExistsError(f"Project file {yml_path} already exists.") + lang_inst = None + if language: + try: + lang_inst = Language[language.upper()] + except KeyError: + all_langs = [l.name.lower() for l in Language.iter_all(include_experimental=True)] + raise ValueError(f"Unknown language '{language}'. Supported: {all_langs}") + generated_conf = ProjectConfig.autogenerate(project_root=project_path, project_language=lang_inst) + print(f"Generated project.yml with language {generated_conf.language.value} at {yml_path}.") + + @staticmethod + @click.command("index", help="Index a project by saving symbols to the LSP cache.") + @click.argument("project", type=click.Path(exists=True), default=os.getcwd(), required=False) + @click.option( + "--log-level", + type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), + default="WARNING", + help="Log level for indexing.", + ) + @click.option("--timeout", type=float, default=10, help="Timeout for indexing a single file.") + def index(project: str, log_level: str, timeout: float) -> None: + ProjectCommands._index_project(project, log_level, timeout=timeout) + + @staticmethod + @click.command("index-deprecated", help="Deprecated alias for 'serena project index'.") + @click.argument("project", type=click.Path(exists=True), default=os.getcwd(), required=False) + @click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="WARNING") + @click.option("--timeout", type=float, default=10, help="Timeout for indexing a single file.") + def index_deprecated(project: str, log_level: str, timeout: float) -> None: + click.echo("Deprecated! 
Use `serena project index` instead.") + ProjectCommands._index_project(project, log_level, timeout=timeout) + + @staticmethod + def _index_project(project: str, log_level: str, timeout: float) -> None: + lvl = logging.getLevelNamesMapping()[log_level.upper()] + logging.configure(level=lvl) + serena_config = SerenaConfig.from_config_file() + proj = Project.load(os.path.abspath(project)) + click.echo(f"Indexing symbols in project {project}…") + ls = proj.create_language_server(log_level=lvl, ls_timeout=timeout, ls_specific_settings=serena_config.ls_specific_settings) + log_file = os.path.join(project, ".serena", "logs", "indexing.txt") + + collected_exceptions: list[Exception] = [] + files_failed = [] + with ls.start_server(): + files = proj.gather_source_files() + for i, f in enumerate(tqdm(files, desc="Indexing")): + try: + ls.request_document_symbols(f, include_body=False) + ls.request_document_symbols(f, include_body=True) + except Exception as e: + log.error(f"Failed to index {f}, continuing.") + collected_exceptions.append(e) + files_failed.append(f) + if (i + 1) % 10 == 0: + ls.save_cache() + ls.save_cache() + click.echo(f"Symbols saved to {ls.cache_path}") + if len(files_failed) > 0: + os.makedirs(os.path.dirname(log_file), exist_ok=True) + with open(log_file, "w") as f: + for file, exception in zip(files_failed, collected_exceptions, strict=True): + f.write(f"{file}\n") + f.write(f"{exception}\n") + click.echo(f"Failed to index {len(files_failed)} files, see:\n{log_file}") + + @staticmethod + @click.command("is_ignored_path", help="Check if a path is ignored by the project configuration.") + @click.argument("path", type=click.Path(exists=False, file_okay=True, dir_okay=True)) + @click.argument("project", type=click.Path(exists=True, file_okay=False, dir_okay=True), default=os.getcwd()) + def is_ignored_path(path: str, project: str) -> None: + """ + Check if a given path is ignored by the project configuration. + + :param path: The path to check. + :param project: The path to the project directory, defaults to the current working directory. + """ + proj = Project.load(os.path.abspath(project)) + if os.path.isabs(path): + path = os.path.relpath(path, start=proj.project_root) + is_ignored = proj.is_ignored_path(path) + click.echo(f"Path '{path}' IS {'ignored' if is_ignored else 'IS NOT ignored'} by the project configuration.") + + @staticmethod + @click.command("index-file", help="Index a single file by saving its symbols to the LSP cache.") + @click.argument("file", type=click.Path(exists=True, file_okay=True, dir_okay=False)) + @click.argument("project", type=click.Path(exists=True, file_okay=False, dir_okay=True), default=os.getcwd()) + @click.option("--verbose", "-v", is_flag=True, help="Print detailed information about the indexed symbols.") + def index_file(file: str, project: str, verbose: bool) -> None: + """ + Index a single file by saving its symbols to the LSP cache, useful for debugging. + :param file: path to the file to index, must be inside the project directory. + :param project: path to the project directory, defaults to the current working directory. + :param verbose: if set, prints detailed information about the indexed symbols. 
+ """ + proj = Project.load(os.path.abspath(project)) + if os.path.isabs(file): + file = os.path.relpath(file, start=proj.project_root) + if proj.is_ignored_path(file, ignore_non_source_files=True): + click.echo(f"'{file}' is ignored or declared as non-code file by the project configuration, won't index.") + exit(1) + ls = proj.create_language_server() + with ls.start_server(): + symbols, _ = ls.request_document_symbols(file, include_body=False) + ls.request_document_symbols(file, include_body=True) + if verbose: + click.echo(f"Symbols in file '{file}':") + for symbol in symbols: + click.echo(f" - {symbol['name']} at line {symbol['selectionRange']['start']['line']} of kind {symbol['kind']}") + ls.save_cache() + click.echo(f"Successfully indexed file '{file}', {len(symbols)} symbols saved to {ls.cache_path}.") + + @staticmethod + @click.command("health-check", help="Perform a comprehensive health check of the project's tools and language server.") + @click.argument("project", type=click.Path(exists=True, file_okay=False, dir_okay=True), default=os.getcwd()) + def health_check(project: str) -> None: + """ + Perform a comprehensive health check of the project's tools and language server. + + :param project: path to the project directory, defaults to the current working directory. + """ + # NOTE: completely written by Claude Code, only functionality was reviewed, not implementation + logging.configure(level=logging.INFO) + project_path = os.path.abspath(project) + proj = Project.load(project_path) + + # Create log file with timestamp + timestamp = datetime_tag() + log_dir = os.path.join(project_path, ".serena", "logs", "health-checks") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"health_check_{timestamp}.log") + + with FileLoggerContext(log_file, append=False, enabled=True): + log.info("Starting health check for project: %s", project_path) + + try: + # Create SerenaAgent with dashboard disabled + log.info("Creating SerenaAgent with disabled dashboard...") + config = SerenaConfig(gui_log_window_enabled=False, web_dashboard=False) + agent = SerenaAgent(project=project_path, serena_config=config) + log.info("SerenaAgent created successfully") + + # Find first non-empty file that can be analyzed + log.info("Searching for analyzable files...") + files = proj.gather_source_files() + target_file = None + + for file_path in files: + try: + full_path = os.path.join(project_path, file_path) + if os.path.getsize(full_path) > 0: + target_file = file_path + log.info("Found analyzable file: %s", target_file) + break + except (OSError, FileNotFoundError): + continue + + if not target_file: + log.error("No analyzable files found in project") + click.echo("❌ Health check failed: No analyzable files found") + click.echo(f"Log saved to: {log_file}") + return + + # Get tools from agent + overview_tool = agent.get_tool(GetSymbolsOverviewTool) + find_symbol_tool = agent.get_tool(FindSymbolTool) + find_refs_tool = agent.get_tool(FindReferencingSymbolsTool) + search_pattern_tool = agent.get_tool(SearchForPatternTool) + + # Test 1: Get symbols overview + log.info("Testing GetSymbolsOverviewTool on file: %s", target_file) + overview_result = agent.execute_task(lambda: overview_tool.apply(target_file)) + overview_data = json.loads(overview_result) + log.info("GetSymbolsOverviewTool returned %d symbols", len(overview_data)) + + if not overview_data: + log.error("No symbols found in file %s", target_file) + click.echo("❌ Health check failed: No symbols found in target file") + 
click.echo(f"Log saved to: {log_file}") + return + + # Extract suitable symbol (prefer class or function over variables) + # LSP symbol kinds: 5=class, 12=function, 6=method, 9=constructor + preferred_kinds = [5, 12, 6, 9] # class, function, method, constructor + + selected_symbol = None + for symbol in overview_data: + if symbol.get("kind") in preferred_kinds: + selected_symbol = symbol + break + + # If no preferred symbol found, use first available + if not selected_symbol: + selected_symbol = overview_data[0] + log.info("No class or function found, using first available symbol") + + symbol_name = selected_symbol.get("name_path", "unknown") + symbol_kind = selected_symbol.get("kind", "unknown") + log.info("Using symbol for testing: %s (kind: %d)", symbol_name, symbol_kind) + + # Test 2: FindSymbolTool + log.info("Testing FindSymbolTool for symbol: %s", symbol_name) + find_symbol_result = agent.execute_task( + lambda: find_symbol_tool.apply(symbol_name, relative_path=target_file, include_body=True) + ) + find_symbol_data = json.loads(find_symbol_result) + log.info("FindSymbolTool found %d matches for symbol %s", len(find_symbol_data), symbol_name) + + # Test 3: FindReferencingSymbolsTool + log.info("Testing FindReferencingSymbolsTool for symbol: %s", symbol_name) + try: + find_refs_result = agent.execute_task(lambda: find_refs_tool.apply(symbol_name, relative_path=target_file)) + find_refs_data = json.loads(find_refs_result) + log.info("FindReferencingSymbolsTool found %d references for symbol %s", len(find_refs_data), symbol_name) + except Exception as e: + log.warning("FindReferencingSymbolsTool failed for symbol %s: %s", symbol_name, str(e)) + find_refs_data = [] + + # Test 4: SearchForPatternTool to verify references + log.info("Testing SearchForPatternTool for pattern: %s", symbol_name) + try: + search_result = agent.execute_task( + lambda: search_pattern_tool.apply(substring_pattern=symbol_name, restrict_search_to_code_files=True) + ) + search_data = json.loads(search_result) + pattern_matches = sum(len(matches) for matches in search_data.values()) + log.info("SearchForPatternTool found %d pattern matches for %s", pattern_matches, symbol_name) + except Exception as e: + log.warning("SearchForPatternTool failed for pattern %s: %s", symbol_name, str(e)) + pattern_matches = 0 + + # Verify tools worked as expected + tools_working = True + if not find_symbol_data: + log.error("FindSymbolTool returned no results") + tools_working = False + + if len(find_refs_data) == 0 and pattern_matches == 0: + log.warning("Both FindReferencingSymbolsTool and SearchForPatternTool found no matches - this might indicate an issue") + + log.info("Health check completed successfully") + + if tools_working: + click.echo("✅ Health check passed - All tools working correctly") + else: + click.echo("⚠️ Health check completed with warnings - Check log for details") + + except Exception as e: + log.exception("Health check failed with exception: %s", str(e)) + click.echo(f"❌ Health check failed: {e!s}") + + finally: + click.echo(f"Log saved to: {log_file}") + + +class ToolCommands(AutoRegisteringGroup): + """Group for 'tool' subcommands.""" + + def __init__(self) -> None: + super().__init__( + name="tools", + help="Commands related to Serena's tools. You can run `serena tools --help` for more info on each command.", + ) + + @staticmethod + @click.command( + "list", + help="Prints an overview of the tools that are active by default (not just the active ones for your project). 
For viewing all tools, pass `--all / -a`", + ) + @click.option("--quiet", "-q", is_flag=True) + @click.option("--all", "-a", "include_optional", is_flag=True, help="List all tools, including those not enabled by default.") + @click.option("--only-optional", is_flag=True, help="List only optional tools (those not enabled by default).") + def list(quiet: bool = False, include_optional: bool = False, only_optional: bool = False) -> None: + tool_registry = ToolRegistry() + if quiet: + if only_optional: + tool_names = tool_registry.get_tool_names_optional() + elif include_optional: + tool_names = tool_registry.get_tool_names() + else: + tool_names = tool_registry.get_tool_names_default_enabled() + for tool_name in tool_names: + click.echo(tool_name) + else: + ToolRegistry().print_tool_overview(include_optional=include_optional, only_optional=only_optional) + + @staticmethod + @click.command( + "description", + help="Print the description of a tool, optionally with a specific context (the latter may modify the default description).", + ) + @click.argument("tool_name", type=str) + @click.option("--context", type=str, default=None, help="Context name or path to context file.") + def description(tool_name: str, context: str | None = None) -> None: + # Load the context + serena_context = None + if context: + serena_context = SerenaAgentContext.load(context) + + agent = SerenaAgent( + project=None, + serena_config=SerenaConfig(web_dashboard=False, log_level=logging.INFO), + context=serena_context, + ) + tool = agent.get_tool_by_name(tool_name) + mcp_tool = SerenaMCPFactory.make_mcp_tool(tool) + click.echo(mcp_tool.description) + + +class PromptCommands(AutoRegisteringGroup): + def __init__(self) -> None: + super().__init__(name="prompts", help="Commands related to Serena's prompts that are outside of contexts and modes.") + + @staticmethod + def _get_user_prompt_yaml_path(prompt_yaml_name: str) -> str: + os.makedirs(PROMPT_TEMPLATES_DIR_IN_USER_HOME, exist_ok=True) + return os.path.join(PROMPT_TEMPLATES_DIR_IN_USER_HOME, prompt_yaml_name) + + @staticmethod + @click.command("list", help="Lists yamls that are used for defining prompts.") + def list() -> None: + serena_prompt_yaml_names = [os.path.basename(f) for f in glob.glob(PROMPT_TEMPLATES_DIR_INTERNAL + "/*.yml")] + for prompt_yaml_name in serena_prompt_yaml_names: + user_prompt_yaml_path = PromptCommands._get_user_prompt_yaml_path(prompt_yaml_name) + if os.path.exists(user_prompt_yaml_path): + click.echo(f"{user_prompt_yaml_path} merged with default prompts in {prompt_yaml_name}") + else: + click.echo(prompt_yaml_name) + + @staticmethod + @click.command("create-override", help="Create an override of an internal prompts yaml for customizing Serena's prompts") + @click.argument("prompt_yaml_name") + def create_override(prompt_yaml_name: str) -> None: + """ + :param prompt_yaml_name: The yaml name of the prompt you want to override. Call the `list` command for discovering valid prompt yaml names. 
+ :return: + """ + # for convenience, we can pass names without .yml + if not prompt_yaml_name.endswith(".yml"): + prompt_yaml_name = prompt_yaml_name + ".yml" + user_prompt_yaml_path = PromptCommands._get_user_prompt_yaml_path(prompt_yaml_name) + if os.path.exists(user_prompt_yaml_path): + raise FileExistsError(f"{user_prompt_yaml_path} already exists.") + serena_prompt_yaml_path = os.path.join(PROMPT_TEMPLATES_DIR_INTERNAL, prompt_yaml_name) + shutil.copyfile(serena_prompt_yaml_path, user_prompt_yaml_path) + _open_in_editor(user_prompt_yaml_path) + + @staticmethod + @click.command("edit-override", help="Edit an existing prompt override file") + @click.argument("prompt_yaml_name") + def edit_override(prompt_yaml_name: str) -> None: + """ + :param prompt_yaml_name: The yaml name of the prompt override to edit. + :return: + """ + # for convenience, we can pass names without .yml + if not prompt_yaml_name.endswith(".yml"): + prompt_yaml_name = prompt_yaml_name + ".yml" + user_prompt_yaml_path = PromptCommands._get_user_prompt_yaml_path(prompt_yaml_name) + if not os.path.exists(user_prompt_yaml_path): + click.echo(f"Override file '{prompt_yaml_name}' not found. Create it with: prompts create-override {prompt_yaml_name}") + return + _open_in_editor(user_prompt_yaml_path) + + @staticmethod + @click.command("list-overrides", help="List existing prompt override files") + def list_overrides() -> None: + os.makedirs(PROMPT_TEMPLATES_DIR_IN_USER_HOME, exist_ok=True) + serena_prompt_yaml_names = [os.path.basename(f) for f in glob.glob(PROMPT_TEMPLATES_DIR_INTERNAL + "/*.yml")] + override_files = glob.glob(os.path.join(PROMPT_TEMPLATES_DIR_IN_USER_HOME, "*.yml")) + for file_path in override_files: + if os.path.basename(file_path) in serena_prompt_yaml_names: + click.echo(file_path) + + @staticmethod + @click.command("delete-override", help="Delete a prompt override file") + @click.argument("prompt_yaml_name") + def delete_override(prompt_yaml_name: str) -> None: + """ + + :param prompt_yaml_name: The yaml name of the prompt override to delete." 
+ :return: + """ + # for convenience, we can pass names without .yml + if not prompt_yaml_name.endswith(".yml"): + prompt_yaml_name = prompt_yaml_name + ".yml" + user_prompt_yaml_path = PromptCommands._get_user_prompt_yaml_path(prompt_yaml_name) + if not os.path.exists(user_prompt_yaml_path): + click.echo(f"Override file '{prompt_yaml_name}' not found.") + return + os.remove(user_prompt_yaml_path) + click.echo(f"Deleted override file '{prompt_yaml_name}'.") + + +# Expose groups so we can reference them in pyproject.toml +mode = ModeCommands() +context = ContextCommands() +project = ProjectCommands() +config = SerenaConfigCommands() +tools = ToolCommands() +prompts = PromptCommands() + +# Expose toplevel commands for the same reason +top_level = TopLevelCommands() +start_mcp_server = top_level.start_mcp_server +index_project = project.index_deprecated + +# needed for the help script to work - register all subcommands to the top-level group +for subgroup in (mode, context, project, config, tools, prompts): + top_level.add_command(subgroup) + + +def get_help() -> str: + """Retrieve the help text for the top-level Serena CLI.""" + return top_level.get_help(click.Context(top_level, info_name="serena")) diff --git a/Libraries/serena/code_editor.py b/Libraries/serena/code_editor.py new file mode 100644 index 00000000..6fe41da1 --- /dev/null +++ b/Libraries/serena/code_editor.py @@ -0,0 +1,294 @@ +import json +import logging +import os +from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator, Reversible +from contextlib import contextmanager +from typing import TYPE_CHECKING, Generic, Optional, TypeVar + +from serena.symbol import JetBrainsSymbol, LanguageServerSymbol, LanguageServerSymbolRetriever, PositionInFile, Symbol +from solidlsp import SolidLanguageServer +from solidlsp.ls import LSPFileBuffer +from solidlsp.ls_utils import TextUtils + +from .project import Project +from .tools.jetbrains_plugin_client import JetBrainsPluginClient + +if TYPE_CHECKING: + from .agent import SerenaAgent + + +log = logging.getLogger(__name__) +TSymbol = TypeVar("TSymbol", bound=Symbol) + + +class CodeEditor(Generic[TSymbol], ABC): + def __init__(self, project_root: str, agent: Optional["SerenaAgent"] = None) -> None: + self.project_root = project_root + self.agent = agent + + class EditedFile(ABC): + @abstractmethod + def get_contents(self) -> str: + """ + :return: the contents of the file. + """ + + @abstractmethod + def delete_text_between_positions(self, start_pos: PositionInFile, end_pos: PositionInFile) -> None: + pass + + @abstractmethod + def insert_text_at_position(self, pos: PositionInFile, text: str) -> None: + pass + + @contextmanager + def _open_file_context(self, relative_path: str) -> Iterator["CodeEditor.EditedFile"]: + """ + Context manager for opening a file + """ + raise NotImplementedError("This method must be overridden for each subclass") + + @contextmanager + def _edited_file_context(self, relative_path: str) -> Iterator["CodeEditor.EditedFile"]: + """ + Context manager for editing a file. + """ + with self._open_file_context(relative_path) as edited_file: + yield edited_file + # save the file + abs_path = os.path.join(self.project_root, relative_path) + with open(abs_path, "w", encoding="utf-8") as f: + f.write(edited_file.get_contents()) + + @abstractmethod + def _find_unique_symbol(self, name_path: str, relative_file_path: str) -> TSymbol: + """ + Finds the unique symbol with the given name in the given file. + If no such symbol exists, raises a ValueError. 
+ + :param name_path: the name path + :param relative_file_path: the relative path of the file in which to search for the symbol. + :return: the unique symbol + """ + + def replace_body(self, name_path: str, relative_file_path: str, body: str) -> None: + """ + Replaces the body of the symbol with the given name_path in the given file. + + :param name_path: the name path of the symbol to replace. + :param relative_file_path: the relative path of the file in which the symbol is defined. + :param body: the new body + """ + symbol = self._find_unique_symbol(name_path, relative_file_path) + start_pos = symbol.get_body_start_position_or_raise() + end_pos = symbol.get_body_end_position_or_raise() + + with self._edited_file_context(relative_file_path) as edited_file: + # make sure the replacement adds no additional newlines (before or after) - all newlines + # and whitespace before/after should remain the same, so we strip it entirely + body = body.strip() + + edited_file.delete_text_between_positions(start_pos, end_pos) + edited_file.insert_text_at_position(start_pos, body) + + @staticmethod + def _count_leading_newlines(text: Iterable) -> int: + cnt = 0 + for c in text: + if c == "\n": + cnt += 1 + elif c == "\r": + continue + else: + break + return cnt + + @classmethod + def _count_trailing_newlines(cls, text: Reversible) -> int: + return cls._count_leading_newlines(reversed(text)) + + def insert_after_symbol(self, name_path: str, relative_file_path: str, body: str) -> None: + """ + Inserts content after the symbol with the given name in the given file. + """ + symbol = self._find_unique_symbol(name_path, relative_file_path) + + # make sure body always ends with at least one newline + if not body.endswith("\n"): + body += "\n" + + pos = symbol.get_body_end_position_or_raise() + + # start at the beginning of the next line + col = 0 + line = pos.line + 1 + + # make sure a suitable number of leading empty lines is used (at least 0/1 depending on the symbol type, + # otherwise as many as the caller wanted to insert) + original_leading_newlines = self._count_leading_newlines(body) + body = body.lstrip("\r\n") + min_empty_lines = 0 + if symbol.is_neighbouring_definition_separated_by_empty_line(): + min_empty_lines = 1 + num_leading_empty_lines = max(min_empty_lines, original_leading_newlines) + if num_leading_empty_lines: + body = ("\n" * num_leading_empty_lines) + body + + # make sure the one line break succeeding the original symbol, which we repurposed as prefix via + # `line += 1`, is replaced + body = body.rstrip("\r\n") + "\n" + + with self._edited_file_context(relative_file_path) as edited_file: + edited_file.insert_text_at_position(PositionInFile(line, col), body) + + def insert_before_symbol(self, name_path: str, relative_file_path: str, body: str) -> None: + """ + Inserts content before the symbol with the given name in the given file. 
+ """ + symbol = self._find_unique_symbol(name_path, relative_file_path) + symbol_start_pos = symbol.get_body_start_position_or_raise() + + # insert position is the start of line where the symbol is defined + line = symbol_start_pos.line + col = 0 + + original_trailing_empty_lines = self._count_trailing_newlines(body) - 1 + + # ensure eol is present at end + body = body.rstrip() + "\n" + + # add suitable number of trailing empty lines after the body (at least 0/1 depending on the symbol type, + # otherwise as many as the caller wanted to insert) + min_trailing_empty_lines = 0 + if symbol.is_neighbouring_definition_separated_by_empty_line(): + min_trailing_empty_lines = 1 + num_trailing_newlines = max(min_trailing_empty_lines, original_trailing_empty_lines) + body += "\n" * num_trailing_newlines + + # apply edit + with self._edited_file_context(relative_file_path) as edited_file: + edited_file.insert_text_at_position(PositionInFile(line=line, col=col), body) + + def insert_at_line(self, relative_path: str, line: int, content: str) -> None: + """ + Inserts content at the given line in the given file. + + :param relative_path: the relative path of the file in which to insert content + :param line: the 0-based index of the line to insert content at + :param content: the content to insert + """ + with self._edited_file_context(relative_path) as edited_file: + edited_file.insert_text_at_position(PositionInFile(line, 0), content) + + def delete_lines(self, relative_path: str, start_line: int, end_line: int) -> None: + """ + Deletes lines in the given file. + + :param relative_path: the relative path of the file in which to delete lines + :param start_line: the 0-based index of the first line to delete (inclusive) + :param end_line: the 0-based index of the last line to delete (inclusive) + """ + start_col = 0 + end_line_for_delete = end_line + 1 + end_col = 0 + with self._edited_file_context(relative_path) as edited_file: + start_pos = PositionInFile(line=start_line, col=start_col) + end_pos = PositionInFile(line=end_line_for_delete, col=end_col) + edited_file.delete_text_between_positions(start_pos, end_pos) + + def delete_symbol(self, name_path: str, relative_file_path: str) -> None: + """ + Deletes the symbol with the given name in the given file. 
+ """ + symbol = self._find_unique_symbol(name_path, relative_file_path) + start_pos = symbol.get_body_start_position_or_raise() + end_pos = symbol.get_body_end_position_or_raise() + with self._edited_file_context(relative_file_path) as edited_file: + edited_file.delete_text_between_positions(start_pos, end_pos) + + +class LanguageServerCodeEditor(CodeEditor[LanguageServerSymbol]): + def __init__(self, symbol_retriever: LanguageServerSymbolRetriever, agent: Optional["SerenaAgent"] = None): + super().__init__(project_root=symbol_retriever.get_language_server().repository_root_path, agent=agent) + self._symbol_retriever = symbol_retriever + + @property + def _lang_server(self) -> SolidLanguageServer: + return self._symbol_retriever.get_language_server() + + class EditedFile(CodeEditor.EditedFile): + def __init__(self, lang_server: SolidLanguageServer, relative_path: str, file_buffer: LSPFileBuffer): + self._lang_server = lang_server + self._relative_path = relative_path + self._file_buffer = file_buffer + + def get_contents(self) -> str: + return self._file_buffer.contents + + def delete_text_between_positions(self, start_pos: PositionInFile, end_pos: PositionInFile) -> None: + self._lang_server.delete_text_between_positions(self._relative_path, start_pos.to_lsp_position(), end_pos.to_lsp_position()) + + def insert_text_at_position(self, pos: PositionInFile, text: str) -> None: + self._lang_server.insert_text_at_position(self._relative_path, pos.line, pos.col, text) + + @contextmanager + def _open_file_context(self, relative_path: str) -> Iterator["CodeEditor.EditedFile"]: + with self._lang_server.open_file(relative_path) as file_buffer: + yield self.EditedFile(self._lang_server, relative_path, file_buffer) + + def _get_code_file_content(self, relative_path: str) -> str: + """Get the content of a file using the language server.""" + return self._lang_server.language_server.retrieve_full_file_content(relative_path) + + def _find_unique_symbol(self, name_path: str, relative_file_path: str) -> LanguageServerSymbol: + symbol_candidates = self._symbol_retriever.find_by_name(name_path, within_relative_path=relative_file_path) + if len(symbol_candidates) == 0: + raise ValueError(f"No symbol with name {name_path} found in file {relative_file_path}") + if len(symbol_candidates) > 1: + raise ValueError( + f"Found multiple {len(symbol_candidates)} symbols with name {name_path} in file {relative_file_path}. 
" + "Their locations are: \n " + json.dumps([s.location.to_dict() for s in symbol_candidates], indent=2) + ) + return symbol_candidates[0] + + +class JetBrainsCodeEditor(CodeEditor[JetBrainsSymbol]): + def __init__(self, project: Project, agent: Optional["SerenaAgent"] = None) -> None: + self._project = project + super().__init__(project_root=project.project_root, agent=agent) + + class EditedFile(CodeEditor.EditedFile): + def __init__(self, relative_path: str, project: Project): + path = os.path.join(project.project_root, relative_path) + log.info("Editing file: %s", path) + with open(path, encoding=project.project_config.encoding) as f: + self._content = f.read() + + def get_contents(self) -> str: + return self._content + + def delete_text_between_positions(self, start_pos: PositionInFile, end_pos: PositionInFile) -> None: + self._content, _ = TextUtils.delete_text_between_positions( + self._content, start_pos.line, start_pos.col, end_pos.line, end_pos.col + ) + + def insert_text_at_position(self, pos: PositionInFile, text: str) -> None: + self._content, _, _ = TextUtils.insert_text_at_position(self._content, pos.line, pos.col, text) + + @contextmanager + def _open_file_context(self, relative_path: str) -> Iterator["CodeEditor.EditedFile"]: + yield self.EditedFile(relative_path, self._project) + + def _find_unique_symbol(self, name_path: str, relative_file_path: str) -> JetBrainsSymbol: + with JetBrainsPluginClient.from_project(self._project) as client: + result = client.find_symbol(name_path, relative_path=relative_file_path, include_body=False, depth=0, include_location=True) + symbols = result["symbols"] + if not symbols: + raise ValueError(f"No symbol with name {name_path} found in file {relative_file_path}") + if len(symbols) > 1: + raise ValueError( + f"Found multiple {len(symbols)} symbols with name {name_path} in file {relative_file_path}. " + "Their locations are: \n " + json.dumps([s["location"] for s in symbols], indent=2) + ) + return JetBrainsSymbol(symbols[0], self._project) diff --git a/Libraries/serena/config/__init__.py b/Libraries/serena/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Libraries/serena/config/context_mode.py b/Libraries/serena/config/context_mode.py new file mode 100644 index 00000000..44cd42fb --- /dev/null +++ b/Libraries/serena/config/context_mode.py @@ -0,0 +1,232 @@ +""" +Context and Mode configuration loader +""" + +import os +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Self + +import yaml +from sensai.util import logging +from sensai.util.string import ToStringMixin + +from serena.config.serena_config import ToolInclusionDefinition +from serena.constants import ( + DEFAULT_CONTEXT, + DEFAULT_MODES, + INTERNAL_MODE_YAMLS_DIR, + SERENAS_OWN_CONTEXT_YAMLS_DIR, + SERENAS_OWN_MODE_YAMLS_DIR, + USER_CONTEXT_YAMLS_DIR, + USER_MODE_YAMLS_DIR, +) + +if TYPE_CHECKING: + pass + +log = logging.getLogger(__name__) + + +@dataclass(kw_only=True) +class SerenaAgentMode(ToolInclusionDefinition, ToStringMixin): + """Represents a mode of operation for the agent, typically read off a YAML file. + An agent can be in multiple modes simultaneously as long as they are not mutually exclusive. + The modes can be adjusted after the agent is running, for example for switching from planning to editing. + """ + + name: str + prompt: str + """ + a Jinja2 template for the generation of the system prompt. + It is formatted by the agent (see SerenaAgent._format_prompt()). 
+ """ + description: str = "" + + def _tostring_includes(self) -> list[str]: + return ["name"] + + def print_overview(self) -> None: + """Print an overview of the mode.""" + print(f"{self.name}:\n {self.description}") + if self.excluded_tools: + print(" excluded tools:\n " + ", ".join(sorted(self.excluded_tools))) + + @classmethod + def from_yaml(cls, yaml_path: str | Path) -> Self: + """Load a mode from a YAML file.""" + with open(yaml_path, encoding="utf-8") as f: + data = yaml.safe_load(f) + name = data.pop("name", Path(yaml_path).stem) + return cls(name=name, **data) + + @classmethod + def get_path(cls, name: str) -> str: + """Get the path to the YAML file for a mode.""" + fname = f"{name}.yml" + custom_mode_path = os.path.join(USER_MODE_YAMLS_DIR, fname) + if os.path.exists(custom_mode_path): + return custom_mode_path + + own_yaml_path = os.path.join(SERENAS_OWN_MODE_YAMLS_DIR, fname) + if not os.path.exists(own_yaml_path): + raise FileNotFoundError( + f"Mode {name} not found in {USER_MODE_YAMLS_DIR} or in {SERENAS_OWN_MODE_YAMLS_DIR}." + f"Available modes:\n{cls.list_registered_mode_names()}" + ) + return own_yaml_path + + @classmethod + def from_name(cls, name: str) -> Self: + """Load a registered Serena mode.""" + mode_path = cls.get_path(name) + return cls.from_yaml(mode_path) + + @classmethod + def from_name_internal(cls, name: str) -> Self: + """Loads an internal Serena mode""" + yaml_path = os.path.join(INTERNAL_MODE_YAMLS_DIR, f"{name}.yml") + if not os.path.exists(yaml_path): + raise FileNotFoundError(f"Internal mode '{name}' not found in {INTERNAL_MODE_YAMLS_DIR}") + return cls.from_yaml(yaml_path) + + @classmethod + def list_registered_mode_names(cls, include_user_modes: bool = True) -> list[str]: + """Names of all registered modes (from the corresponding YAML files in the serena repo).""" + modes = [f.stem for f in Path(SERENAS_OWN_MODE_YAMLS_DIR).glob("*.yml") if f.name != "mode.template.yml"] + if include_user_modes: + modes += cls.list_custom_mode_names() + return sorted(set(modes)) + + @classmethod + def list_custom_mode_names(cls) -> list[str]: + """Names of all custom modes defined by the user.""" + return [f.stem for f in Path(USER_MODE_YAMLS_DIR).glob("*.yml")] + + @classmethod + def load_default_modes(cls) -> list[Self]: + """Load the default modes (interactive and editing).""" + return [cls.from_name(mode) for mode in DEFAULT_MODES] + + @classmethod + def load(cls, name_or_path: str | Path) -> Self: + if str(name_or_path).endswith(".yml"): + return cls.from_yaml(name_or_path) + return cls.from_name(str(name_or_path)) + + +@dataclass(kw_only=True) +class SerenaAgentContext(ToolInclusionDefinition, ToStringMixin): + """Represents a context where the agent is operating (an IDE, a chat, etc.), typically read off a YAML file. + An agent can only be in a single context at a time. + The contexts cannot be changed after the agent is running. + """ + + name: str + prompt: str + """ + a Jinja2 template for the generation of the system prompt. + It is formatted by the agent (see SerenaAgent._format_prompt()). 
+ """ + description: str = "" + tool_description_overrides: dict[str, str] = field(default_factory=dict) + """Maps tool names to custom descriptions, default descriptions are extracted from the tool docstrings.""" + + def _tostring_includes(self) -> list[str]: + return ["name"] + + @classmethod + def from_yaml(cls, yaml_path: str | Path) -> Self: + """Load a context from a YAML file.""" + with open(yaml_path, encoding="utf-8") as f: + data = yaml.safe_load(f) + name = data.pop("name", Path(yaml_path).stem) + # Ensure backwards compatibility for tool_description_overrides + if "tool_description_overrides" not in data: + data["tool_description_overrides"] = {} + return cls(name=name, **data) + + @classmethod + def get_path(cls, name: str) -> str: + """Get the path to the YAML file for a context.""" + fname = f"{name}.yml" + custom_context_path = os.path.join(USER_CONTEXT_YAMLS_DIR, fname) + if os.path.exists(custom_context_path): + return custom_context_path + + own_yaml_path = os.path.join(SERENAS_OWN_CONTEXT_YAMLS_DIR, fname) + if not os.path.exists(own_yaml_path): + raise FileNotFoundError( + f"Context {name} not found in {USER_CONTEXT_YAMLS_DIR} or in {SERENAS_OWN_CONTEXT_YAMLS_DIR}." + f"Available contexts:\n{cls.list_registered_context_names()}" + ) + return own_yaml_path + + @classmethod + def from_name(cls, name: str) -> Self: + """Load a registered Serena context.""" + context_path = cls.get_path(name) + return cls.from_yaml(context_path) + + @classmethod + def load(cls, name_or_path: str | Path) -> Self: + if str(name_or_path).endswith(".yml"): + return cls.from_yaml(name_or_path) + return cls.from_name(str(name_or_path)) + + @classmethod + def list_registered_context_names(cls, include_user_contexts: bool = True) -> list[str]: + """Names of all registered contexts (from the corresponding YAML files in the serena repo).""" + contexts = [f.stem for f in Path(SERENAS_OWN_CONTEXT_YAMLS_DIR).glob("*.yml")] + if include_user_contexts: + contexts += cls.list_custom_context_names() + return sorted(set(contexts)) + + @classmethod + def list_custom_context_names(cls) -> list[str]: + """Names of all custom contexts defined by the user.""" + return [f.stem for f in Path(USER_CONTEXT_YAMLS_DIR).glob("*.yml")] + + @classmethod + def load_default(cls) -> Self: + """Load the default context.""" + return cls.from_name(DEFAULT_CONTEXT) + + def print_overview(self) -> None: + """Print an overview of the mode.""" + print(f"{self.name}:\n {self.description}") + if self.excluded_tools: + print(" excluded tools:\n " + ", ".join(sorted(self.excluded_tools))) + + +class RegisteredContext(Enum): + """A registered context.""" + + IDE_ASSISTANT = "ide-assistant" + """For Serena running within an assistant that already has basic tools, like Claude Code, Cline, Cursor, etc.""" + DESKTOP_APP = "desktop-app" + """For Serena running within Claude Desktop or a similar app which does not have built-in tools for code editing.""" + AGENT = "agent" + """For Serena running as a standalone agent, e.g. 
through agno.""" + + def load(self) -> SerenaAgentContext: + """Load the context.""" + return SerenaAgentContext.from_name(self.value) + + +class RegisteredMode(Enum): + """A registered mode.""" + + INTERACTIVE = "interactive" + """Interactive mode, for multi-turn interactions.""" + EDITING = "editing" + """Editing tools are activated.""" + PLANNING = "planning" + """Editing tools are deactivated.""" + ONE_SHOT = "one-shot" + """Non-interactive mode, where the goal is to finish a task autonomously.""" + + def load(self) -> SerenaAgentMode: + """Load the mode.""" + return SerenaAgentMode.from_name(self.value) diff --git a/Libraries/serena/config/serena_config.py b/Libraries/serena/config/serena_config.py new file mode 100644 index 00000000..00bbc901 --- /dev/null +++ b/Libraries/serena/config/serena_config.py @@ -0,0 +1,578 @@ +""" +The Serena Model Context Protocol (MCP) Server +""" + +import os +import shutil +from collections.abc import Iterable +from copy import deepcopy +from dataclasses import dataclass, field +from datetime import datetime +from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional, Self, TypeVar + +import yaml +from ruamel.yaml.comments import CommentedMap +from sensai.util import logging +from sensai.util.logging import LogTime, datetime_tag +from sensai.util.string import ToStringMixin + +from serena.constants import ( + DEFAULT_ENCODING, + PROJECT_TEMPLATE_FILE, + REPO_ROOT, + SERENA_CONFIG_TEMPLATE_FILE, + SERENA_MANAGED_DIR_IN_HOME, + SERENA_MANAGED_DIR_NAME, +) +from serena.util.general import load_yaml, save_yaml +from serena.util.inspection import determine_programming_language_composition +from solidlsp.ls_config import Language + +from ..analytics import RegisteredTokenCountEstimator +from ..util.class_decorators import singleton + +if TYPE_CHECKING: + from ..project import Project + +log = logging.getLogger(__name__) +T = TypeVar("T") +DEFAULT_TOOL_TIMEOUT: float = 240 + + +@singleton +class SerenaPaths: + """ + Provides paths to various Serena-related directories and files. 
+ """ + + def __init__(self) -> None: + self.user_config_dir: str = SERENA_MANAGED_DIR_IN_HOME + """ + the path to the user's Serena configuration directory, which is typically ~/.serena + """ + + def get_next_log_file_path(self, prefix: str) -> str: + """ + :param prefix: the filename prefix indicating the type of the log file + :return: the full path to the log file to use + """ + log_dir = os.path.join(self.user_config_dir, "logs", datetime.now().strftime("%Y-%m-%d")) + os.makedirs(log_dir, exist_ok=True) + return os.path.join(log_dir, prefix + "_" + datetime_tag() + ".txt") + + # TODO: Paths from constants.py should be moved here + + +class ToolSet: + def __init__(self, tool_names: set[str]) -> None: + self._tool_names = tool_names + + @classmethod + def default(cls) -> "ToolSet": + """ + :return: the default tool set, which contains all tools that are enabled by default + """ + from serena.tools import ToolRegistry + + return cls(set(ToolRegistry().get_tool_names_default_enabled())) + + def apply(self, *tool_inclusion_definitions: "ToolInclusionDefinition") -> "ToolSet": + """ + :param tool_inclusion_definitions: the definitions to apply + :return: a new tool set with the definitions applied + """ + from serena.tools import ToolRegistry + + registry = ToolRegistry() + tool_names = set(self._tool_names) + for definition in tool_inclusion_definitions: + included_tools = [] + excluded_tools = [] + for included_tool in definition.included_optional_tools: + if not registry.is_valid_tool_name(included_tool): + raise ValueError(f"Invalid tool name '{included_tool}' provided for inclusion") + if included_tool not in tool_names: + tool_names.add(included_tool) + included_tools.append(included_tool) + for excluded_tool in definition.excluded_tools: + if not registry.is_valid_tool_name(excluded_tool): + raise ValueError(f"Invalid tool name '{excluded_tool}' provided for exclusion") + if excluded_tool in tool_names: + tool_names.remove(excluded_tool) + excluded_tools.append(excluded_tool) + if included_tools: + log.info(f"{definition} included {len(included_tools)} tools: {', '.join(included_tools)}") + if excluded_tools: + log.info(f"{definition} excluded {len(excluded_tools)} tools: {', '.join(excluded_tools)}") + return ToolSet(tool_names) + + def without_editing_tools(self) -> "ToolSet": + """ + :return: a new tool set that excludes all tools that can edit + """ + from serena.tools import ToolRegistry + + registry = ToolRegistry() + tool_names = set(self._tool_names) + for tool_name in self._tool_names: + if registry.get_tool_class_by_name(tool_name).can_edit(): + tool_names.remove(tool_name) + return ToolSet(tool_names) + + def get_tool_names(self) -> set[str]: + """ + Returns the names of the tools that are currently included in the tool set. 
+ """ + return self._tool_names + + def includes_name(self, tool_name: str) -> bool: + return tool_name in self._tool_names + + +@dataclass +class ToolInclusionDefinition: + excluded_tools: Iterable[str] = () + included_optional_tools: Iterable[str] = () + + +class SerenaConfigError(Exception): + pass + + +def get_serena_managed_in_project_dir(project_root: str | Path) -> str: + return os.path.join(project_root, SERENA_MANAGED_DIR_NAME) + + +def is_running_in_docker() -> bool: + """Check if we're running inside a Docker container.""" + # Check for Docker-specific files + if os.path.exists("/.dockerenv"): + return True + # Check cgroup for docker references + try: + with open("/proc/self/cgroup") as f: + return "docker" in f.read() + except FileNotFoundError: + return False + + +@dataclass(kw_only=True) +class ProjectConfig(ToolInclusionDefinition, ToStringMixin): + project_name: str + language: Language + ignored_paths: list[str] = field(default_factory=list) + read_only: bool = False + ignore_all_files_in_gitignore: bool = True + initial_prompt: str = "" + encoding: str = DEFAULT_ENCODING + + SERENA_DEFAULT_PROJECT_FILE = "project.yml" + + def _tostring_includes(self) -> list[str]: + return ["project_name"] + + @classmethod + def autogenerate( + cls, project_root: str | Path, project_name: str | None = None, project_language: Language | None = None, save_to_disk: bool = True + ) -> Self: + """ + Autogenerate a project configuration for a given project root. + + :param project_root: the path to the project root + :param project_name: the name of the project; if None, the name of the project will be the name of the directory + containing the project + :param project_language: the programming language of the project; if None, it will be determined automatically + :param save_to_disk: whether to save the project configuration to disk + :return: the project configuration + """ + project_root = Path(project_root).resolve() + if not project_root.exists(): + raise FileNotFoundError(f"Project root not found: {project_root}") + with LogTime("Project configuration auto-generation", logger=log): + project_name = project_name or project_root.name + if project_language is None: + language_composition = determine_programming_language_composition(str(project_root)) + if len(language_composition) == 0: + raise ValueError( + f"No source files found in {project_root}\n\n" + f"To use Serena with this project, you need to either:\n" + f"1. Add source files in one of the supported languages (Python, JavaScript/TypeScript, Java, C#, Rust, Go, Ruby, C++, PHP, Swift, Elixir, Terraform, Bash)\n" + f"2. 
Create a project configuration file manually at:\n" + f" {os.path.join(project_root, cls.rel_path_to_project_yml())}\n\n" + f"Example project.yml:\n" + f" project_name: {project_name}\n" + f" language: python # or typescript, java, csharp, rust, go, ruby, cpp, php, swift, elixir, terraform, bash\n" + ) + # find the language with the highest percentage + dominant_language = max(language_composition.keys(), key=lambda lang: language_composition[lang]) + else: + dominant_language = project_language.value + config_with_comments = load_yaml(PROJECT_TEMPLATE_FILE, preserve_comments=True) + config_with_comments["project_name"] = project_name + config_with_comments["language"] = dominant_language + if save_to_disk: + save_yaml(str(project_root / cls.rel_path_to_project_yml()), config_with_comments, preserve_comments=True) + return cls._from_dict(config_with_comments) + + @classmethod + def rel_path_to_project_yml(cls) -> str: + return os.path.join(SERENA_MANAGED_DIR_NAME, cls.SERENA_DEFAULT_PROJECT_FILE) + + @classmethod + def _from_dict(cls, data: dict[str, Any]) -> Self: + """ + Create a ProjectConfig instance from a configuration dictionary + """ + language_str = data["language"].lower() + project_name = data["project_name"] + # backwards compatibility + if language_str == "javascript": + log.warning(f"Found deprecated project language `javascript` in project {project_name}, please change to `typescript`") + language_str = "typescript" + try: + language = Language(language_str) + except ValueError as e: + raise ValueError(f"Invalid language: {data['language']}.\nValid languages are: {[l.value for l in Language]}") from e + return cls( + project_name=project_name, + language=language, + ignored_paths=data.get("ignored_paths", []), + excluded_tools=data.get("excluded_tools", []), + included_optional_tools=data.get("included_optional_tools", []), + read_only=data.get("read_only", False), + ignore_all_files_in_gitignore=data.get("ignore_all_files_in_gitignore", True), + initial_prompt=data.get("initial_prompt", ""), + encoding=data.get("encoding", DEFAULT_ENCODING), + ) + + @classmethod + def load(cls, project_root: Path | str, autogenerate: bool = False) -> Self: + """ + Load a ProjectConfig instance from the path to the project root. + """ + project_root = Path(project_root) + yaml_path = project_root / cls.rel_path_to_project_yml() + if not yaml_path.exists(): + if autogenerate: + return cls.autogenerate(project_root) + else: + raise FileNotFoundError(f"Project configuration file not found: {yaml_path}") + with open(yaml_path, encoding="utf-8") as f: + yaml_data = yaml.safe_load(f) + if "project_name" not in yaml_data: + yaml_data["project_name"] = project_root.name + return cls._from_dict(yaml_data) + + +class RegisteredProject(ToStringMixin): + def __init__(self, project_root: str, project_config: "ProjectConfig", project_instance: Optional["Project"] = None) -> None: + """ + Represents a registered project in the Serena configuration. 
+ + :param project_root: the root directory of the project + :param project_config: the configuration of the project + """ + self.project_root = Path(project_root).resolve() + self.project_config = project_config + self._project_instance = project_instance + + def _tostring_exclude_private(self) -> bool: + return True + + @property + def project_name(self) -> str: + return self.project_config.project_name + + @classmethod + def from_project_instance(cls, project_instance: "Project") -> "RegisteredProject": + return RegisteredProject( + project_root=project_instance.project_root, + project_config=project_instance.project_config, + project_instance=project_instance, + ) + + def matches_root_path(self, path: str | Path) -> bool: + """ + Check if the given path matches the project root path. + + :param path: the path to check + :return: True if the path matches the project root, False otherwise + """ + return self.project_root == Path(path).resolve() + + def get_project_instance(self) -> "Project": + """ + Returns the project instance for this registered project, loading it if necessary. + """ + if self._project_instance is None: + from ..project import Project + + with LogTime(f"Loading project instance for {self}", logger=log): + self._project_instance = Project(project_root=str(self.project_root), project_config=self.project_config) + return self._project_instance + + +@dataclass(kw_only=True) +class SerenaConfig(ToolInclusionDefinition, ToStringMixin): + """ + Holds the Serena agent configuration, which is typically loaded from a YAML configuration file + (when instantiated via :method:`from_config_file`), which is updated when projects are added or removed. + For testing purposes, it can also be instantiated directly with the desired parameters. + """ + + projects: list[RegisteredProject] = field(default_factory=list) + gui_log_window_enabled: bool = False + log_level: int = logging.INFO + trace_lsp_communication: bool = False + web_dashboard: bool = True + web_dashboard_open_on_launch: bool = True + tool_timeout: float = DEFAULT_TOOL_TIMEOUT + loaded_commented_yaml: CommentedMap | None = None + config_file_path: str | None = None + """ + the path to the configuration file to which updates of the configuration shall be saved; + if None, the configuration is not saved to disk + """ + jetbrains: bool = False + """ + whether to apply JetBrains mode + """ + record_tool_usage_stats: bool = False + """Whether to record tool usage statistics, they will be shown in the web dashboard if recording is active. + """ + token_count_estimator: str = RegisteredTokenCountEstimator.TIKTOKEN_GPT4O.name + """Only relevant if `record_tool_usage` is True; the name of the token count estimator to use for tool usage statistics. + See the `RegisteredTokenCountEstimator` enum for available options. + + Note: some token estimators (like tiktoken) may require downloading data files + on the first run, which can take some time and require internet access. Others, like the Anthropic ones, may require an API key + and rate limits may apply. + """ + default_max_tool_answer_chars: int = 150_000 + """Used as default for tools where the apply method has a default maximal answer length. + Even though the value of the max_answer_chars can be changed when calling the tool, it may make sense to adjust this default + through the global configuration. 
+ """ + ls_specific_settings: dict = field(default_factory=dict) + """Advanced configuration option allowing to configure language server implementation specific options, see SolidLSPSettings for more info.""" + + CONFIG_FILE = "serena_config.yml" + CONFIG_FILE_DOCKER = "serena_config.docker.yml" # Docker-specific config file; auto-generated if missing, mounted via docker-compose for user customization + + def _tostring_includes(self) -> list[str]: + return ["config_file_path"] + + @classmethod + def generate_config_file(cls, config_file_path: str) -> None: + """ + Generates a Serena configuration file at the specified path from the template file. + + :param config_file_path: the path where the configuration file should be generated + """ + log.info(f"Auto-generating Serena configuration file in {config_file_path}") + loaded_commented_yaml = load_yaml(SERENA_CONFIG_TEMPLATE_FILE, preserve_comments=True) + save_yaml(config_file_path, loaded_commented_yaml, preserve_comments=True) + + @classmethod + def _determine_config_file_path(cls) -> str: + """ + :return: the location where the Serena configuration file is stored/should be stored + """ + if is_running_in_docker(): + return os.path.join(REPO_ROOT, cls.CONFIG_FILE_DOCKER) + else: + config_path = os.path.join(SERENA_MANAGED_DIR_IN_HOME, cls.CONFIG_FILE) + + # if the config file does not exist, check if we can migrate it from the old location + if not os.path.exists(config_path): + old_config_path = os.path.join(REPO_ROOT, cls.CONFIG_FILE) + if os.path.exists(old_config_path): + log.info(f"Moving Serena configuration file from {old_config_path} to {config_path}") + os.makedirs(os.path.dirname(config_path), exist_ok=True) + shutil.move(old_config_path, config_path) + + return config_path + + @classmethod + def from_config_file(cls, generate_if_missing: bool = True) -> "SerenaConfig": + """ + Static constructor to create SerenaConfig from the configuration file + """ + config_file_path = cls._determine_config_file_path() + + # create the configuration file from the template if necessary + if not os.path.exists(config_file_path): + if not generate_if_missing: + raise FileNotFoundError(f"Serena configuration file not found: {config_file_path}") + log.info(f"Serena configuration file not found at {config_file_path}, autogenerating...") + cls.generate_config_file(config_file_path) + + # load the configuration + log.info(f"Loading Serena configuration from {config_file_path}") + try: + loaded_commented_yaml = load_yaml(config_file_path, preserve_comments=True) + except Exception as e: + raise ValueError(f"Error loading Serena configuration from {config_file_path}: {e}") from e + + # create the configuration instance + instance = cls(loaded_commented_yaml=loaded_commented_yaml, config_file_path=config_file_path) + + # read projects + if "projects" not in loaded_commented_yaml: + raise SerenaConfigError("`projects` key not found in Serena configuration. 
Please update your `serena_config.yml` file.") + + # load list of known projects + instance.projects = [] + num_project_migrations = 0 + for path in loaded_commented_yaml["projects"]: + path = Path(path).resolve() + if not path.exists() or (path.is_dir() and not (path / ProjectConfig.rel_path_to_project_yml()).exists()): + log.warning(f"Project path {path} does not exist or does not contain a project configuration file, skipping.") + continue + if path.is_file(): + path = cls._migrate_out_of_project_config_file(path) + if path is None: + continue + num_project_migrations += 1 + project_config = ProjectConfig.load(path) + project = RegisteredProject( + project_root=str(path), + project_config=project_config, + ) + instance.projects.append(project) + + # set other configuration parameters + if is_running_in_docker(): + instance.gui_log_window_enabled = False # not supported in Docker + else: + instance.gui_log_window_enabled = loaded_commented_yaml.get("gui_log_window", False) + instance.log_level = loaded_commented_yaml.get("log_level", loaded_commented_yaml.get("gui_log_level", logging.INFO)) + instance.web_dashboard = loaded_commented_yaml.get("web_dashboard", True) + instance.web_dashboard_open_on_launch = loaded_commented_yaml.get("web_dashboard_open_on_launch", True) + instance.tool_timeout = loaded_commented_yaml.get("tool_timeout", DEFAULT_TOOL_TIMEOUT) + instance.trace_lsp_communication = loaded_commented_yaml.get("trace_lsp_communication", False) + instance.excluded_tools = loaded_commented_yaml.get("excluded_tools", []) + instance.included_optional_tools = loaded_commented_yaml.get("included_optional_tools", []) + instance.jetbrains = loaded_commented_yaml.get("jetbrains", False) + instance.record_tool_usage_stats = loaded_commented_yaml.get("record_tool_usage_stats", False) + instance.token_count_estimator = loaded_commented_yaml.get( + "token_count_estimator", RegisteredTokenCountEstimator.TIKTOKEN_GPT4O.name + ) + instance.default_max_tool_answer_chars = loaded_commented_yaml.get("default_max_tool_answer_chars", 150_000) + instance.ls_specific_settings = loaded_commented_yaml.get("ls_specific_settings", {}) + + # re-save the configuration file if any migrations were performed + if num_project_migrations > 0: + log.info( + f"Migrated {num_project_migrations} project configurations from legacy format to in-project configuration; re-saving configuration" + ) + instance.save() + + return instance + + @classmethod + def _migrate_out_of_project_config_file(cls, path: Path) -> Path | None: + """ + Migrates a legacy project configuration file (which is a YAML file containing the project root) to the + in-project configuration file (project.yml) inside the project root directory. + + :param path: the path to the legacy project configuration file + :return: the project root path if the migration was successful, None otherwise. 
+ """ + log.info(f"Found legacy project configuration file {path}, migrating to in-project configuration.") + try: + with open(path, encoding="utf-8") as f: + project_config_data = yaml.safe_load(f) + if "project_name" not in project_config_data: + project_name = path.stem + with open(path, "a", encoding="utf-8") as f: + f.write(f"\nproject_name: {project_name}") + project_root = project_config_data["project_root"] + shutil.move(str(path), str(Path(project_root) / ProjectConfig.rel_path_to_project_yml())) + return Path(project_root).resolve() + except Exception as e: + log.error(f"Error migrating configuration file: {e}") + return None + + @cached_property + def project_paths(self) -> list[str]: + return sorted(str(project.project_root) for project in self.projects) + + @cached_property + def project_names(self) -> list[str]: + return sorted(project.project_config.project_name for project in self.projects) + + def get_project(self, project_root_or_name: str) -> Optional["Project"]: + # look for project by name + project_candidates = [] + for project in self.projects: + if project.project_config.project_name == project_root_or_name: + project_candidates.append(project) + if len(project_candidates) == 1: + return project_candidates[0].get_project_instance() + elif len(project_candidates) > 1: + raise ValueError( + f"Multiple projects found with name '{project_root_or_name}'. Please activate it by location instead. " + f"Locations: {[p.project_root for p in project_candidates]}" + ) + # no project found by name; check if it's a path + if os.path.isdir(project_root_or_name): + for project in self.projects: + if project.matches_root_path(project_root_or_name): + return project.get_project_instance() + return None + + def add_project_from_path(self, project_root: Path | str) -> "Project": + """ + Add a project to the Serena configuration from a given path. Will raise a FileExistsError if a + project already exists at the path. + + :param project_root: the path to the project to add + :return: the project that was added + """ + from ..project import Project + + project_root = Path(project_root).resolve() + if not project_root.exists(): + raise FileNotFoundError(f"Error: Path does not exist: {project_root}") + if not project_root.is_dir(): + raise FileNotFoundError(f"Error: Path is not a directory: {project_root}") + + for already_registered_project in self.projects: + if str(already_registered_project.project_root) == str(project_root): + raise FileExistsError( + f"Project with path {project_root} was already added with name '{already_registered_project.project_name}'." 
+ ) + + project_config = ProjectConfig.load(project_root, autogenerate=True) + + new_project = Project(project_root=str(project_root), project_config=project_config, is_newly_created=True) + self.projects.append(RegisteredProject.from_project_instance(new_project)) + self.save() + + return new_project + + def remove_project(self, project_name: str) -> None: + # find the index of the project with the desired name and remove it + for i, project in enumerate(list(self.projects)): + if project.project_name == project_name: + del self.projects[i] + break + else: + raise ValueError(f"Project '{project_name}' not found in Serena configuration; valid project names: {self.project_names}") + self.save() + + def save(self) -> None: + """ + Saves the configuration to the file from which it was loaded (if any) + """ + if self.config_file_path is None: + return + assert self.loaded_commented_yaml is not None, "Cannot save configuration without loaded YAML" + loaded_original_yaml = deepcopy(self.loaded_commented_yaml) + # projects are unique absolute paths + # we also canonicalize them before saving + loaded_original_yaml["projects"] = sorted({str(project.project_root) for project in self.projects}) + save_yaml(self.config_file_path, loaded_original_yaml, preserve_comments=True) diff --git a/Libraries/serena/constants.py b/Libraries/serena/constants.py new file mode 100644 index 00000000..1e5d98a1 --- /dev/null +++ b/Libraries/serena/constants.py @@ -0,0 +1,35 @@ +from pathlib import Path + +_repo_root_path = Path(__file__).parent.parent.parent.resolve() +_serena_pkg_path = Path(__file__).parent.resolve() + +SERENA_MANAGED_DIR_NAME = ".serena" +_serena_in_home_managed_dir = Path.home() / ".serena" + +SERENA_MANAGED_DIR_IN_HOME = str(_serena_in_home_managed_dir) + +# TODO: Path-related constants should be moved to SerenaPaths; don't add further constants here. +REPO_ROOT = str(_repo_root_path) +PROMPT_TEMPLATES_DIR_INTERNAL = str(_serena_pkg_path / "resources" / "config" / "prompt_templates") +PROMPT_TEMPLATES_DIR_IN_USER_HOME = str(_serena_in_home_managed_dir / "prompt_templates") +SERENAS_OWN_CONTEXT_YAMLS_DIR = str(_serena_pkg_path / "resources" / "config" / "contexts") +"""The contexts that are shipped with the Serena package, i.e. the default contexts.""" +USER_CONTEXT_YAMLS_DIR = str(_serena_in_home_managed_dir / "contexts") +"""Contexts defined by the user. If a name of a context matches a name of a context in SERENAS_OWN_CONTEXT_YAMLS_DIR, the user context will override the default one.""" +SERENAS_OWN_MODE_YAMLS_DIR = str(_serena_pkg_path / "resources" / "config" / "modes") +"""The modes that are shipped with the Serena package, i.e. the default modes.""" +USER_MODE_YAMLS_DIR = str(_serena_in_home_managed_dir / "modes") +"""Modes defined by the user. 
If a name of a mode matches a name of a mode in SERENAS_OWN_MODE_YAMLS_DIR, the user mode will override the default one."""
+INTERNAL_MODE_YAMLS_DIR = str(_serena_pkg_path / "resources" / "config" / "internal_modes")
+"""Internal modes, never overridden by user modes."""
+SERENA_DASHBOARD_DIR = str(_serena_pkg_path / "resources" / "dashboard")
+SERENA_ICON_DIR = str(_serena_pkg_path / "resources" / "icons")
+
+DEFAULT_ENCODING = "utf-8"
+DEFAULT_CONTEXT = "desktop-app"
+DEFAULT_MODES = ("interactive", "editing")
+
+PROJECT_TEMPLATE_FILE = str(_serena_pkg_path / "resources" / "project.template.yml")
+SERENA_CONFIG_TEMPLATE_FILE = str(_serena_pkg_path / "resources" / "serena_config.template.yml")
+
+SERENA_LOG_FORMAT = "%(levelname)-5s %(asctime)-15s [%(threadName)s] %(name)s:%(funcName)s:%(lineno)d - %(message)s"
diff --git a/Libraries/serena/dashboard.py b/Libraries/serena/dashboard.py
new file mode 100644
index 00000000..c4a5f91d
--- /dev/null
+++ b/Libraries/serena/dashboard.py
@@ -0,0 +1,170 @@
+import os
+import socket
+import threading
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
+
+from flask import Flask, Response, request, send_from_directory
+from pydantic import BaseModel
+from sensai.util import logging
+
+from serena.analytics import ToolUsageStats
+from serena.constants import SERENA_DASHBOARD_DIR
+from serena.util.logging import MemoryLogHandler
+
+if TYPE_CHECKING:
+ from serena.agent import SerenaAgent
+
+log = logging.getLogger(__name__)
+
+# disable Werkzeug's logging to avoid cluttering the output
+logging.getLogger("werkzeug").setLevel(logging.WARNING)
+
+
+class RequestLog(BaseModel):
+ start_idx: int = 0
+
+
+class ResponseLog(BaseModel):
+ messages: list[str]
+ max_idx: int
+ active_project: str | None = None
+
+
+class ResponseToolNames(BaseModel):
+ tool_names: list[str]
+
+
+class ResponseToolStats(BaseModel):
+ stats: dict[str, dict[str, int]]
+
+
+class SerenaDashboardAPI:
+ log = logging.getLogger(__qualname__)
+
+ def __init__(
+ self,
+ memory_log_handler: MemoryLogHandler,
+ tool_names: list[str],
+ agent: "SerenaAgent",
+ shutdown_callback: Callable[[], None] | None = None,
+ tool_usage_stats: ToolUsageStats | None = None,
+ ) -> None:
+ self._memory_log_handler = memory_log_handler
+ self._tool_names = tool_names
+ self._agent = agent
+ self._shutdown_callback = shutdown_callback
+ self._app = Flask(__name__)
+ self._tool_usage_stats = tool_usage_stats
+ self._setup_routes()
+
+ @property
+ def memory_log_handler(self) -> MemoryLogHandler:
+ return self._memory_log_handler
+
+ def _setup_routes(self) -> None:
+ # Static files
+ @self._app.route("/dashboard/<path:filename>")
+ def serve_dashboard(filename: str) -> Response:
+ return send_from_directory(SERENA_DASHBOARD_DIR, filename)
+
+ @self._app.route("/dashboard/")
+ def serve_dashboard_index() -> Response:
+ return send_from_directory(SERENA_DASHBOARD_DIR, "index.html")
+
+ # API routes
+ @self._app.route("/get_log_messages", methods=["POST"])
+ def get_log_messages() -> dict[str, Any]:
+ request_data = request.get_json()
+ if not request_data:
+ request_log = RequestLog()
+ else:
+ request_log = RequestLog.model_validate(request_data)
+
+ result = self._get_log_messages(request_log)
+ return result.model_dump()
+
+ @self._app.route("/get_tool_names", methods=["GET"])
+ def get_tool_names() -> dict[str, Any]:
+ result = self._get_tool_names()
+ return result.model_dump()
+
+ @self._app.route("/get_tool_stats", methods=["GET"])
+ def get_tool_stats_route() -> dict[str, Any]:
+ result =
self._get_tool_stats() + return result.model_dump() + + @self._app.route("/clear_tool_stats", methods=["POST"]) + def clear_tool_stats_route() -> dict[str, str]: + self._clear_tool_stats() + return {"status": "cleared"} + + @self._app.route("/get_token_count_estimator_name", methods=["GET"]) + def get_token_count_estimator_name() -> dict[str, str]: + estimator_name = self._tool_usage_stats.token_estimator_name if self._tool_usage_stats else "unknown" + return {"token_count_estimator_name": estimator_name} + + @self._app.route("/shutdown", methods=["PUT"]) + def shutdown() -> dict[str, str]: + self._shutdown() + return {"status": "shutting down"} + + def _get_log_messages(self, request_log: RequestLog) -> ResponseLog: + all_messages = self._memory_log_handler.get_log_messages() + requested_messages = all_messages[request_log.start_idx :] if request_log.start_idx <= len(all_messages) else [] + project = self._agent.get_active_project() + project_name = project.project_name if project else None + return ResponseLog(messages=requested_messages, max_idx=len(all_messages) - 1, active_project=project_name) + + def _get_tool_names(self) -> ResponseToolNames: + return ResponseToolNames(tool_names=self._tool_names) + + def _get_tool_stats(self) -> ResponseToolStats: + if self._tool_usage_stats is not None: + return ResponseToolStats(stats=self._tool_usage_stats.get_tool_stats_dict()) + else: + return ResponseToolStats(stats={}) + + def _clear_tool_stats(self) -> None: + if self._tool_usage_stats is not None: + self._tool_usage_stats.clear() + + def _shutdown(self) -> None: + log.info("Shutting down Serena") + if self._shutdown_callback: + self._shutdown_callback() + else: + # noinspection PyProtectedMember + # noinspection PyUnresolvedReferences + os._exit(0) + + @staticmethod + def _find_first_free_port(start_port: int) -> int: + port = start_port + while port <= 65535: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("0.0.0.0", port)) + return port + except OSError: + port += 1 + + raise RuntimeError(f"No free ports found starting from {start_port}") + + def run(self, host: str = "0.0.0.0", port: int = 0x5EDA) -> int: + """ + Runs the dashboard on the given host and port and returns the port number. + """ + # patch flask.cli.show_server to avoid printing the server info + from flask import cli + + cli.show_server_banner = lambda *args, **kwargs: None + + self._app.run(host=host, port=port, debug=False, use_reloader=False, threaded=True) + return port + + def run_in_thread(self) -> tuple[threading.Thread, int]: + port = self._find_first_free_port(0x5EDA) + thread = threading.Thread(target=lambda: self.run(port=port), daemon=True) + thread.start() + return thread, port diff --git a/Libraries/serena/generated/generated_prompt_factory.py b/Libraries/serena/generated/generated_prompt_factory.py new file mode 100644 index 00000000..3a5e1442 --- /dev/null +++ b/Libraries/serena/generated/generated_prompt_factory.py @@ -0,0 +1,38 @@ +# ruff: noqa +# black: skip +# mypy: ignore-errors + +# NOTE: This module is auto-generated from interprompt.autogenerate_prompt_factory_module, do not edit manually! + +from interprompt.multilang_prompt import PromptList +from interprompt.prompt_factory import PromptFactoryBase +from typing import Any + + +class PromptFactory(PromptFactoryBase): + """ + A class for retrieving and rendering prompt templates and prompt lists. 
+ """ + + def create_onboarding_prompt(self, *, system: Any) -> str: + return self._render_prompt("onboarding_prompt", locals()) + + def create_think_about_collected_information(self) -> str: + return self._render_prompt("think_about_collected_information", locals()) + + def create_think_about_task_adherence(self) -> str: + return self._render_prompt("think_about_task_adherence", locals()) + + def create_think_about_whether_you_are_done(self) -> str: + return self._render_prompt("think_about_whether_you_are_done", locals()) + + def create_summarize_changes(self) -> str: + return self._render_prompt("summarize_changes", locals()) + + def create_prepare_for_new_conversation(self) -> str: + return self._render_prompt("prepare_for_new_conversation", locals()) + + def create_system_prompt( + self, *, available_markers: Any, available_tools: Any, context_system_prompt: Any, mode_system_prompts: Any + ) -> str: + return self._render_prompt("system_prompt", locals()) diff --git a/Libraries/serena/gui_log_viewer.py b/Libraries/serena/gui_log_viewer.py new file mode 100644 index 00000000..f46d8c30 --- /dev/null +++ b/Libraries/serena/gui_log_viewer.py @@ -0,0 +1,405 @@ +# mypy: ignore-errors +import logging +import os +import queue +import sys +import threading +import tkinter as tk +import traceback +from enum import Enum, auto +from pathlib import Path +from typing import Literal + +from serena import constants +from serena.util.logging import MemoryLogHandler + +log = logging.getLogger(__name__) + + +class LogLevel(Enum): + DEBUG = auto() + INFO = auto() + WARNING = auto() + ERROR = auto() + DEFAULT = auto() + + +class GuiLogViewer: + """ + A class that creates a Tkinter GUI for displaying log messages in a separate thread. + The log viewer supports coloring based on log levels (DEBUG, INFO, WARNING, ERROR). + It can also highlight tool names in boldface when they appear in log messages. + """ + + def __init__( + self, + mode: Literal["dashboard", "error"], + title="Log Viewer", + memory_log_handler: MemoryLogHandler | None = None, + width=800, + height=600, + ): + """ + :param mode: the mode; if "dashboard", run a dashboard with logs and some control options; if "error", run + a simple error log viewer (for fatal exceptions) + :param title: the window title + :param memory_log_handler: an optional log handler from which to obtain log messages; If not provided, + must pass the instance to a `GuiLogViewerHandler` to add log messages. 
+ :param width: the initial window width + :param height: the initial window height + """ + self.mode = mode + self.title = title + self.width = width + self.height = height + self.message_queue = queue.Queue() + self.running = False + self.log_thread = None + self.tool_names = [] # List to store tool names for highlighting + + # Define colors for different log levels + self.log_colors = { + LogLevel.DEBUG: "#808080", # Gray + LogLevel.INFO: "#000000", # Black + LogLevel.WARNING: "#FF8C00", # Dark Orange + LogLevel.ERROR: "#FF0000", # Red + LogLevel.DEFAULT: "#000000", # Black + } + + if memory_log_handler is not None: + for msg in memory_log_handler.get_log_messages(): + self.message_queue.put(msg) + memory_log_handler.add_emit_callback(lambda msg: self.message_queue.put(msg)) + + def start(self): + """Start the log viewer in a separate thread.""" + if not self.running: + self.log_thread = threading.Thread(target=self.run_gui) + self.log_thread.daemon = True + self.log_thread.start() + return True + return False + + def stop(self): + """Stop the log viewer.""" + if self.running: + # Add a sentinel value to the queue to signal the GUI to exit + self.message_queue.put(None) + return True + return False + + def set_tool_names(self, tool_names): + """ + Set or update the list of tool names to be highlighted in log messages. + + Args: + tool_names (list): A list of tool name strings to highlight + + """ + self.tool_names = tool_names + + def add_log(self, message): + """ + Add a log message to the viewer. + + Args: + message (str): The log message to display + + """ + self.message_queue.put(message) + + def _determine_log_level(self, message): + """ + Determine the log level from the message. + + Args: + message (str): The log message + + Returns: + LogLevel: The determined log level + + """ + message_upper = message.upper() + if message_upper.startswith("DEBUG"): + return LogLevel.DEBUG + elif message_upper.startswith("INFO"): + return LogLevel.INFO + elif message_upper.startswith("WARNING"): + return LogLevel.WARNING + elif message_upper.startswith("ERROR"): + return LogLevel.ERROR + else: + return LogLevel.DEFAULT + + def _process_queue(self): + """Process messages from the queue and update the text widget.""" + try: + while not self.message_queue.empty(): + message = self.message_queue.get_nowait() + + # Check for sentinel value to exit + if message is None: + self.root.quit() + return + + # Check if scrollbar is at the bottom before adding new text + # Get current scroll position + current_position = self.text_widget.yview() + # If near the bottom (allowing for small floating point differences) + was_at_bottom = current_position[1] > 0.99 + + log_level = self._determine_log_level(message) + + # Insert the message at the end of the text with appropriate log level tag + self.text_widget.configure(state=tk.NORMAL) + + # Find tool names in the message and highlight them + if self.tool_names: + # Capture start position (before insertion) + start_index = self.text_widget.index("end-1c") + + # Insert the message + self.text_widget.insert(tk.END, message + "\n", log_level.name) + + # Convert start index to line/char format + line, char = map(int, start_index.split(".")) + + # Search for tool names in the message string directly + for tool_name in self.tool_names: + start_offset = 0 + while True: + found_at = message.find(tool_name, start_offset) + if found_at == -1: + break + + # Calculate line/column from offset + offset_line = line + offset_char = char + for c in message[:found_at]: + if c == 
"\n": + offset_line += 1 + offset_char = 0 + else: + offset_char += 1 + + # Construct index positions + start_pos = f"{offset_line}.{offset_char}" + end_pos = f"{offset_line}.{offset_char + len(tool_name)}" + + # Add tag to highlight the tool name + self.text_widget.tag_add("TOOL_NAME", start_pos, end_pos) + + start_offset = found_at + len(tool_name) + + else: + # No tool names to highlight, just insert the message + self.text_widget.insert(tk.END, message + "\n", log_level.name) + + self.text_widget.configure(state=tk.DISABLED) + + # Auto-scroll to the bottom only if it was already at the bottom + if was_at_bottom: + self.text_widget.see(tk.END) + + # Schedule to check the queue again + if self.running: + self.root.after(100, self._process_queue) + + except Exception as e: + print(f"Error processing message queue: {e}", file=sys.stderr) + if self.running: + self.root.after(100, self._process_queue) + + def run_gui(self): + """Run the GUI""" + self.running = True + try: + # Set app id (avoid app being lumped together with other Python-based apps in Windows taskbar) + if sys.platform == "win32": + import ctypes + + ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("oraios.serena") + + self.root = tk.Tk() + self.root.title(self.title) + self.root.geometry(f"{self.width}x{self.height}") + + # Make the window resizable + self.root.columnconfigure(0, weight=1) + # We now have two rows - one for logo and one for text + self.root.rowconfigure(0, weight=0) # Logo row + self.root.rowconfigure(1, weight=1) # Text content row + + dashboard_path = Path(constants.SERENA_DASHBOARD_DIR) + + # Load and display the logo image + try: + # construct path relative to path of this file + image_path = dashboard_path / "serena-logs.png" + self.logo_image = tk.PhotoImage(file=image_path) + + # Create a label to display the logo + self.logo_label = tk.Label(self.root, image=self.logo_image) + self.logo_label.grid(row=0, column=0, sticky="ew") + except Exception as e: + print(f"Error loading logo image: {e}", file=sys.stderr) + + # Create frame to hold text widget and scrollbars + frame = tk.Frame(self.root) + frame.grid(row=1, column=0, sticky="nsew") + frame.columnconfigure(0, weight=1) + frame.rowconfigure(0, weight=1) + + # Create horizontal scrollbar + h_scrollbar = tk.Scrollbar(frame, orient=tk.HORIZONTAL) + h_scrollbar.grid(row=1, column=0, sticky="ew") + + # Create vertical scrollbar + v_scrollbar = tk.Scrollbar(frame, orient=tk.VERTICAL) + v_scrollbar.grid(row=0, column=1, sticky="ns") + + # Create text widget with horizontal scrolling + self.text_widget = tk.Text( + frame, wrap=tk.NONE, width=self.width, height=self.height, xscrollcommand=h_scrollbar.set, yscrollcommand=v_scrollbar.set + ) + self.text_widget.grid(row=0, column=0, sticky="nsew") + self.text_widget.configure(state=tk.DISABLED) # Make it read-only + + # Configure scrollbars + h_scrollbar.config(command=self.text_widget.xview) + v_scrollbar.config(command=self.text_widget.yview) + + # Configure tags for different log levels with appropriate colors + for level, color in self.log_colors.items(): + self.text_widget.tag_configure(level.name, foreground=color) + + # Configure tag for tool names + self.text_widget.tag_configure("TOOL_NAME", background="#ffff00") + + # Set up the queue processing + self.root.after(100, self._process_queue) + + # Handle window close event depending on mode + if self.mode == "dashboard": + self.root.protocol("WM_DELETE_WINDOW", lambda: self.root.iconify()) + else: + self.root.protocol("WM_DELETE_WINDOW", 
self.stop) + + # Create menu bar + if self.mode == "dashboard": + menubar = tk.Menu(self.root) + server_menu = tk.Menu(menubar, tearoff=0) + server_menu.add_command(label="Shutdown", command=self._shutdown_server) # type: ignore + menubar.add_cascade(label="Server", menu=server_menu) + self.root.config(menu=menubar) + + # Configure icons + icon_16 = tk.PhotoImage(file=dashboard_path / "serena-icon-16.png") + icon_32 = tk.PhotoImage(file=dashboard_path / "serena-icon-32.png") + icon_48 = tk.PhotoImage(file=dashboard_path / "serena-icon-48.png") + self.root.iconphoto(False, icon_48, icon_32, icon_16) + + # Start the Tkinter event loop + self.root.mainloop() + + except Exception as e: + print(f"Error in GUI thread: {e}", file=sys.stderr) + finally: + self.running = False + + def _shutdown_server(self) -> None: + log.info("Shutting down Serena") + # noinspection PyUnresolvedReferences + # noinspection PyProtectedMember + os._exit(0) + + +class GuiLogViewerHandler(logging.Handler): + """ + A logging handler that sends log records to a ThreadedLogViewer instance. + This handler can be integrated with Python's standard logging module + to direct log entries to a GUI log viewer. + """ + + def __init__( + self, + log_viewer: GuiLogViewer, + level=logging.NOTSET, + format_string: str | None = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s:%(lineno)d - %(message)s", + ): + """ + Initialize the handler with a ThreadedLogViewer instance. + + Args: + log_viewer: A ThreadedLogViewer instance that will display the logs + level: The logging level (default: NOTSET which captures all logs) + format_string: the format string + + """ + super().__init__(level) + self.log_viewer = log_viewer + self.formatter = logging.Formatter(format_string) + + # Start the log viewer if it's not already running + if not self.log_viewer.running: + self.log_viewer.start() + + @classmethod + def is_instance_registered(cls) -> bool: + for h in logging.Logger.root.handlers: + if isinstance(h, cls): + return True + return False + + def emit(self, record): + """ + Emit a log record to the ThreadedLogViewer. + + Args: + record: The log record to emit + + """ + try: + # Format the record according to the formatter + msg = self.format(record) + + # Convert the level name to a standard format for the viewer + level_prefix = record.levelname + + # Add the appropriate prefix if it's not already there + if not msg.startswith(level_prefix): + msg = f"{level_prefix}: {msg}" + + self.log_viewer.add_log(msg) + + except Exception: + self.handleError(record) + + def close(self): + """ + Close the handler and optionally stop the log viewer. + """ + # We don't automatically stop the log viewer here as it might + # be used by other handlers or directly by the application + super().close() + + def stop_viewer(self): + """ + Explicitly stop the associated log viewer. + """ + if self.log_viewer.running: + self.log_viewer.stop() + + +def show_fatal_exception(e: Exception): + """ + Makes sure the given exception is shown in the GUI log viewer, + either an existing instance or a new one. 
+ + :param e: the exception to display + """ + # show in new window in main thread (user must close it) + log_viewer = GuiLogViewer("error") + exc_info = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + log_viewer.add_log(f"ERROR Fatal exception: {e}\n{exc_info}") + log_viewer.run_gui() diff --git a/Libraries/serena/mcp.py b/Libraries/serena/mcp.py new file mode 100644 index 00000000..f8f75ae0 --- /dev/null +++ b/Libraries/serena/mcp.py @@ -0,0 +1,348 @@ +""" +The Serena Model Context Protocol (MCP) Server +""" + +import sys +from abc import abstractmethod +from collections.abc import AsyncIterator, Iterator, Sequence +from contextlib import asynccontextmanager +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Literal, cast + +import docstring_parser +from mcp.server.fastmcp import server +from mcp.server.fastmcp.server import FastMCP, Settings +from mcp.server.fastmcp.tools.base import Tool as MCPTool +from pydantic_settings import SettingsConfigDict +from sensai.util import logging + +from serena.agent import ( + SerenaAgent, + SerenaConfig, +) +from serena.config.context_mode import SerenaAgentContext, SerenaAgentMode +from serena.constants import DEFAULT_CONTEXT, DEFAULT_MODES, SERENA_LOG_FORMAT +from serena.tools import Tool +from serena.util.exception import show_fatal_exception_safe +from serena.util.logging import MemoryLogHandler + +log = logging.getLogger(__name__) + + +def configure_logging(*args, **kwargs) -> None: # type: ignore + # We only do something here if logging has not yet been configured. + # Normally, logging is configured in the MCP server startup script. + if not logging.is_enabled(): + logging.basicConfig(level=logging.INFO, stream=sys.stderr, format=SERENA_LOG_FORMAT) + + +# patch the logging configuration function in fastmcp, because it's hard-coded and broken +server.configure_logging = configure_logging # type: ignore + + +@dataclass +class SerenaMCPRequestContext: + agent: SerenaAgent + + +class SerenaMCPFactory: + def __init__(self, context: str = DEFAULT_CONTEXT, project: str | None = None): + """ + :param context: The context name or path to context file + :param project: Either an absolute path to the project directory or a name of an already registered project. + If the project passed here hasn't been registered yet, it will be registered automatically and can be activated by its name + afterward. + """ + self.context = SerenaAgentContext.load(context) + self.project = project + + @staticmethod + def _sanitize_for_openai_tools(schema: dict) -> dict: + """ + This method was written by GPT-5, I have not reviewed it in detail. + Only called when `openai_tool_compatible` is True. + + Make a Pydantic/JSON Schema object compatible with OpenAI tool schema. 
+ - 'integer' -> 'number' (+ multipleOf: 1) + - remove 'null' from union type arrays + - coerce integer-only enums to number + - best-effort simplify oneOf/anyOf when they only differ by integer/number + """ + s = deepcopy(schema) + + def walk(node): # type: ignore + if not isinstance(node, dict): + # lists get handled by parent calls + return node + + # ---- handle type ---- + t = node.get("type") + if isinstance(t, str): + if t == "integer": + node["type"] = "number" + # preserve existing multipleOf but ensure it's integer-like + if "multipleOf" not in node: + node["multipleOf"] = 1 + elif isinstance(t, list): + # remove 'null' (OpenAI tools don't support nullables) + t2 = [x if x != "integer" else "number" for x in t if x != "null"] + if not t2: + # fall back to object if it somehow becomes empty + t2 = ["object"] + node["type"] = t2[0] if len(t2) == 1 else t2 + if "integer" in t or "number" in t2: + # if integers were present, keep integer-like restriction + node.setdefault("multipleOf", 1) + + # ---- enums of integers -> number ---- + if "enum" in node and isinstance(node["enum"], list): + vals = node["enum"] + if vals and all(isinstance(v, int) for v in vals): + node.setdefault("type", "number") + # keep them as ints; JSON 'number' covers ints + node.setdefault("multipleOf", 1) + + # ---- simplify anyOf/oneOf if they only differ by integer/number ---- + for key in ("oneOf", "anyOf"): + if key in node and isinstance(node[key], list): + # Special case: anyOf or oneOf with "type X" and "null" + if len(node[key]) == 2: + types = [sub.get("type") for sub in node[key]] + if "null" in types: + non_null_type = next(t for t in types if t != "null") + if isinstance(non_null_type, str): + node["type"] = non_null_type + node.pop(key, None) + continue + simplified = [] + changed = False + for sub in node[key]: + sub = walk(sub) # recurse + simplified.append(sub) + # If all subs are the same after integer→number, collapse + try: + import json + + canon = [json.dumps(x, sort_keys=True) for x in simplified] + if len(set(canon)) == 1: + # copy the single schema up + only = simplified[0] + node.pop(key, None) + for k, v in only.items(): + if k not in node: + node[k] = v + changed = True + except Exception: + pass + if not changed: + node[key] = simplified + + # ---- recurse into known schema containers ---- + for child_key in ("properties", "patternProperties", "definitions", "$defs"): + if child_key in node and isinstance(node[child_key], dict): + for k, v in list(node[child_key].items()): + node[child_key][k] = walk(v) + + # arrays/items + if "items" in node: + node["items"] = walk(node["items"]) + + # allOf/if/then/else - pass through with integer→number conversions applied inside + for key in ("allOf",): + if key in node and isinstance(node[key], list): + node[key] = [walk(x) for x in node[key]] + + if "if" in node: + node["if"] = walk(node["if"]) + if "then" in node: + node["then"] = walk(node["then"]) + if "else" in node: + node["else"] = walk(node["else"]) + + return node + + return walk(s) + + @staticmethod + def make_mcp_tool(tool: Tool, openai_tool_compatible: bool = True) -> MCPTool: + """ + Create an MCP tool from a Serena Tool instance. + + :param tool: The Serena Tool instance to convert. + :param openai_tool_compatible: whether to process the tool schema to be compatible with OpenAI tools + (doesn't accept integer, needs number instead, etc.). This allows using Serena MCP within codex. 
+ """ + func_name = tool.get_name() + func_doc = tool.get_apply_docstring() or "" + func_arg_metadata = tool.get_apply_fn_metadata() + is_async = False + parameters = func_arg_metadata.arg_model.model_json_schema() + if openai_tool_compatible: + parameters = SerenaMCPFactory._sanitize_for_openai_tools(parameters) + + docstring = docstring_parser.parse(func_doc) + + # Mount the tool description as a combination of the docstring description and + # the return value description, if it exists. + overridden_description = tool.agent.get_context().tool_description_overrides.get(func_name, None) + + if overridden_description is not None: + func_doc = overridden_description + elif docstring.description: + func_doc = docstring.description + else: + func_doc = "" + func_doc = func_doc.strip().strip(".") + if func_doc: + func_doc += "." + if docstring.returns and (docstring_returns_descr := docstring.returns.description): + # Only add a space before "Returns" if func_doc is not empty + prefix = " " if func_doc else "" + func_doc = f"{func_doc}{prefix}Returns {docstring_returns_descr.strip().strip('.')}." + + # Parse the parameter descriptions from the docstring and add pass its description + # to the parameter schema. + docstring_params = {param.arg_name: param for param in docstring.params} + parameters_properties: dict[str, dict[str, Any]] = parameters["properties"] + for parameter, properties in parameters_properties.items(): + if (param_doc := docstring_params.get(parameter)) and param_doc.description: + param_desc = f"{param_doc.description.strip().strip('.') + '.'}" + properties["description"] = param_desc[0].upper() + param_desc[1:] + + def execute_fn(**kwargs) -> str: # type: ignore + return tool.apply_ex(log_call=True, catch_exceptions=True, **kwargs) + + return MCPTool( + fn=execute_fn, + name=func_name, + description=func_doc, + parameters=parameters, + fn_metadata=func_arg_metadata, + is_async=is_async, + context_kwarg=None, + annotations=None, + title=None, + ) + + @abstractmethod + def _iter_tools(self) -> Iterator[Tool]: + pass + + # noinspection PyProtectedMember + def _set_mcp_tools(self, mcp: FastMCP, openai_tool_compatible: bool = False) -> None: + """Update the tools in the MCP server""" + if mcp is not None: + mcp._tool_manager._tools = {} + for tool in self._iter_tools(): + mcp_tool = self.make_mcp_tool(tool, openai_tool_compatible=openai_tool_compatible) + mcp._tool_manager._tools[tool.get_name()] = mcp_tool + log.info(f"Starting MCP server with {len(mcp._tool_manager._tools)} tools: {list(mcp._tool_manager._tools.keys())}") + + @abstractmethod + def _instantiate_agent(self, serena_config: SerenaConfig, modes: list[SerenaAgentMode]) -> None: + pass + + def create_mcp_server( + self, + host: str = "0.0.0.0", + port: int = 8000, + modes: Sequence[str] = DEFAULT_MODES, + enable_web_dashboard: bool | None = None, + enable_gui_log_window: bool | None = None, + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | None = None, + trace_lsp_communication: bool | None = None, + tool_timeout: float | None = None, + ) -> FastMCP: + """ + Create an MCP server with process-isolated SerenaAgent to prevent asyncio contamination. + + :param host: The host to bind to + :param port: The port to bind to + :param modes: List of mode names or paths to mode files + :param enable_web_dashboard: Whether to enable the web dashboard. If not specified, will take the value from the serena configuration. + :param enable_gui_log_window: Whether to enable the GUI log window. 
It currently does not work on macOS, and setting this to True will be ignored then. + If not specified, will take the value from the serena configuration. + :param log_level: Log level. If not specified, will take the value from the serena configuration. + :param trace_lsp_communication: Whether to trace the communication between Serena and the language servers. + This is useful for debugging language server issues. + :param tool_timeout: Timeout in seconds for tool execution. If not specified, will take the value from the serena configuration. + """ + try: + config = SerenaConfig.from_config_file() + + # update configuration with the provided parameters + if enable_web_dashboard is not None: + config.web_dashboard = enable_web_dashboard + if enable_gui_log_window is not None: + config.gui_log_window_enabled = enable_gui_log_window + if log_level is not None: + log_level = cast(Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], log_level.upper()) + config.log_level = logging.getLevelNamesMapping()[log_level] + if trace_lsp_communication is not None: + config.trace_lsp_communication = trace_lsp_communication + if tool_timeout is not None: + config.tool_timeout = tool_timeout + + modes_instances = [SerenaAgentMode.load(mode) for mode in modes] + self._instantiate_agent(config, modes_instances) + + except Exception as e: + show_fatal_exception_safe(e) + raise + + # Override model_config to disable the use of `.env` files for reading settings, because user projects are likely to contain + # `.env` files (e.g. containing LOG_LEVEL) that are not supposed to override the MCP settings; + # retain only FASTMCP_ prefix for already set environment variables. + Settings.model_config = SettingsConfigDict(env_prefix="FASTMCP_") + instructions = self._get_initial_instructions() + mcp = FastMCP(lifespan=self.server_lifespan, host=host, port=port, instructions=instructions) + return mcp + + @asynccontextmanager + @abstractmethod + async def server_lifespan(self, mcp_server: FastMCP) -> AsyncIterator[None]: + """Manage server startup and shutdown lifecycle.""" + yield None # ensures MyPy understands we yield None + + @abstractmethod + def _get_initial_instructions(self) -> str: + pass + + +class SerenaMCPFactorySingleProcess(SerenaMCPFactory): + """ + MCP server factory where the SerenaAgent and its language server run in the same process as the MCP server + """ + + def __init__(self, context: str = DEFAULT_CONTEXT, project: str | None = None, memory_log_handler: MemoryLogHandler | None = None): + """ + :param context: The context name or path to context file + :param project: Either an absolute path to the project directory or a name of an already registered project. + If the project passed here hasn't been registered yet, it will be registered automatically and can be activated by its name + afterward. 
+ """ + super().__init__(context=context, project=project) + self.agent: SerenaAgent | None = None + self.memory_log_handler = memory_log_handler + + def _instantiate_agent(self, serena_config: SerenaConfig, modes: list[SerenaAgentMode]) -> None: + self.agent = SerenaAgent( + project=self.project, serena_config=serena_config, context=self.context, modes=modes, memory_log_handler=self.memory_log_handler + ) + + def _iter_tools(self) -> Iterator[Tool]: + assert self.agent is not None + yield from self.agent.get_exposed_tool_instances() + + def _get_initial_instructions(self) -> str: + assert self.agent is not None + # we don't use the tool (which at the time of writing calls this method), since the tool may be disabled by the config + return self.agent.create_system_prompt() + + @asynccontextmanager + async def server_lifespan(self, mcp_server: FastMCP) -> AsyncIterator[None]: + openai_tool_compatible = self.context.name in ["chatgpt", "codex", "oaicompat-agent"] + self._set_mcp_tools(mcp_server, openai_tool_compatible=openai_tool_compatible) + log.info("MCP server lifetime setup complete") + yield diff --git a/Libraries/serena/project.py b/Libraries/serena/project.py new file mode 100644 index 00000000..0673f05e --- /dev/null +++ b/Libraries/serena/project.py @@ -0,0 +1,314 @@ +import logging +import os +from pathlib import Path +from typing import Any + +import pathspec + +from serena.config.serena_config import DEFAULT_TOOL_TIMEOUT, ProjectConfig +from serena.constants import SERENA_MANAGED_DIR_IN_HOME, SERENA_MANAGED_DIR_NAME +from serena.text_utils import MatchedConsecutiveLines, search_files +from serena.util.file_system import GitignoreParser, match_path +from solidlsp import SolidLanguageServer +from solidlsp.ls_config import Language, LanguageServerConfig +from solidlsp.ls_logger import LanguageServerLogger +from solidlsp.settings import SolidLSPSettings + +log = logging.getLogger(__name__) + + +class Project: + def __init__(self, project_root: str, project_config: ProjectConfig, is_newly_created: bool = False): + self.project_root = project_root + self.project_config = project_config + self.is_newly_created = is_newly_created + + # create .gitignore file in the project's Serena data folder if not yet present + serena_data_gitignore_path = os.path.join(self.path_to_serena_data_folder(), ".gitignore") + if not os.path.exists(serena_data_gitignore_path): + os.makedirs(os.path.dirname(serena_data_gitignore_path), exist_ok=True) + log.info(f"Creating .gitignore file in {serena_data_gitignore_path}") + with open(serena_data_gitignore_path, "w", encoding="utf-8") as f: + f.write(f"/{SolidLanguageServer.CACHE_FOLDER_NAME}\n") + + # gather ignored paths from the project configuration and gitignore files + ignored_patterns = project_config.ignored_paths + if len(ignored_patterns) > 0: + log.info(f"Using {len(ignored_patterns)} ignored paths from the explicit project configuration.") + log.debug(f"Ignored paths: {ignored_patterns}") + if project_config.ignore_all_files_in_gitignore: + gitignore_parser = GitignoreParser(self.project_root) + for spec in gitignore_parser.get_ignore_specs(): + log.debug(f"Adding {len(spec.patterns)} patterns from {spec.file_path} to the ignored paths.") + ignored_patterns.extend(spec.patterns) + self._ignored_patterns = ignored_patterns + + # Set up the pathspec matcher for the ignored paths + # for all absolute paths in ignored_paths, convert them to relative paths + processed_patterns = [] + for pattern in set(ignored_patterns): + # Normalize separators 
(pathspec expects forward slashes) + pattern = pattern.replace(os.path.sep, "/") + processed_patterns.append(pattern) + log.debug(f"Processing {len(processed_patterns)} ignored paths") + self._ignore_spec = pathspec.PathSpec.from_lines(pathspec.patterns.GitWildMatchPattern, processed_patterns) + + @property + def project_name(self) -> str: + return self.project_config.project_name + + @property + def language(self) -> Language: + return self.project_config.language + + @classmethod + def load(cls, project_root: str | Path, autogenerate: bool = True) -> "Project": + project_root = Path(project_root).resolve() + if not project_root.exists(): + raise FileNotFoundError(f"Project root not found: {project_root}") + project_config = ProjectConfig.load(project_root, autogenerate=autogenerate) + return Project(project_root=str(project_root), project_config=project_config) + + def path_to_serena_data_folder(self) -> str: + return os.path.join(self.project_root, SERENA_MANAGED_DIR_NAME) + + def path_to_project_yml(self) -> str: + return os.path.join(self.project_root, self.project_config.rel_path_to_project_yml()) + + def read_file(self, relative_path: str) -> str: + """ + Reads a file relative to the project root. + + :param relative_path: the path to the file relative to the project root + :return: the content of the file + """ + abs_path = Path(self.project_root) / relative_path + if not abs_path.exists(): + raise FileNotFoundError(f"File not found: {abs_path}") + return abs_path.read_text(encoding=self.project_config.encoding) + + def get_ignore_spec(self) -> pathspec.PathSpec: + """ + :return: the pathspec matcher for the paths that were configured to be ignored, + either explicitly or implicitly through .gitignore files. + """ + return self._ignore_spec + + def _is_ignored_relative_path(self, relative_path: str | Path, ignore_non_source_files: bool = True) -> bool: + """ + Determine whether an existing path should be ignored based on file type and ignore patterns. + Raises `FileNotFoundError` if the path does not exist. 
+ + :param relative_path: Relative path to check + :param ignore_non_source_files: whether files that are not source files (according to the file masks + determined by the project's programming language) shall be ignored + + :return: whether the path should be ignored + """ + abs_path = os.path.join(self.project_root, relative_path) + if not os.path.exists(abs_path): + raise FileNotFoundError(f"File {abs_path} not found, the ignore check cannot be performed") + + # Check file extension if it's a file + is_file = os.path.isfile(abs_path) + if is_file and ignore_non_source_files: + fn_matcher = self.language.get_source_fn_matcher() + if not fn_matcher.is_relevant_filename(abs_path): + return True + + # Create normalized path for consistent handling + rel_path = Path(relative_path) + + # always ignore paths inside .git + if len(rel_path.parts) > 0 and rel_path.parts[0] == ".git": + return True + + return match_path(str(relative_path), self.get_ignore_spec(), root_path=self.project_root) + + def is_ignored_path(self, path: str | Path, ignore_non_source_files: bool = False) -> bool: + """ + Checks whether the given path is ignored + + :param path: the path to check, can be absolute or relative + :param ignore_non_source_files: whether to ignore files that are not source files + (according to the file masks determined by the project's programming language) + """ + path = Path(path) + if path.is_absolute(): + try: + relative_path = path.relative_to(self.project_root) + except ValueError: + # If the path is not relative to the project root, we consider it as an absolute path outside the project + # (which we ignore) + log.warning(f"Path {path} is not relative to the project root {self.project_root} and was therefore ignored") + return True + else: + relative_path = path + + return self._is_ignored_relative_path(str(relative_path), ignore_non_source_files=ignore_non_source_files) + + def is_path_in_project(self, path: str | Path) -> bool: + """ + Checks if the given (absolute or relative) path is inside the project directory. + Note that even relative paths may be outside if they contain ".." or point to symlinks. + """ + path = Path(path) + _proj_root = Path(self.project_root) + if not path.is_absolute(): + path = _proj_root / path + + path = path.resolve() + return path.is_relative_to(_proj_root) + + def relative_path_exists(self, relative_path: str) -> bool: + """ + Checks if the given relative path exists in the project directory. + + :param relative_path: the path to check, relative to the project root + :return: True if the path exists, False otherwise + """ + abs_path = Path(self.project_root) / relative_path + return abs_path.exists() + + def validate_relative_path(self, relative_path: str) -> None: + """ + Validates that the given relative path to an existing file/dir is safe to read or edit, + meaning it's inside the project directory and is not ignored by git. + + Passing a path to a non-existing file will lead to a `FileNotFoundError`. 
+ """ + if not self.is_path_in_project(relative_path): + raise ValueError(f"{relative_path=} points to path outside of the repository root; cannot access for safety reasons") + + if self.is_ignored_path(relative_path): + raise ValueError(f"Path {relative_path} is ignored; cannot access for safety reasons") + + def gather_source_files(self, relative_path: str = "") -> list[str]: + """Retrieves relative paths of all source files, optionally limited to the given path + + :param relative_path: if provided, restrict search to this path + """ + rel_file_paths = [] + start_path = os.path.join(self.project_root, relative_path) + if not os.path.exists(start_path): + raise FileNotFoundError(f"Relative path {start_path} not found.") + if os.path.isfile(start_path): + return [relative_path] + else: + for root, dirs, files in os.walk(start_path, followlinks=True): + # prevent recursion into ignored directories + dirs[:] = [d for d in dirs if not self.is_ignored_path(os.path.join(root, d))] + + # collect non-ignored files + for file in files: + abs_file_path = os.path.join(root, file) + try: + if not self.is_ignored_path(abs_file_path, ignore_non_source_files=True): + try: + rel_file_path = os.path.relpath(abs_file_path, start=self.project_root) + except Exception: + log.warning( + "Ignoring path '%s' because it appears to be outside of the project root (%s)", + abs_file_path, + self.project_root, + ) + continue + rel_file_paths.append(rel_file_path) + except FileNotFoundError: + log.warning( + f"File {abs_file_path} not found (possibly due it being a symlink), skipping it in request_parsed_files", + ) + return rel_file_paths + + def search_source_files_for_pattern( + self, + pattern: str, + relative_path: str = "", + context_lines_before: int = 0, + context_lines_after: int = 0, + paths_include_glob: str | None = None, + paths_exclude_glob: str | None = None, + ) -> list[MatchedConsecutiveLines]: + """ + Search for a pattern across all (non-ignored) source files + + :param pattern: Regular expression pattern to search for, either as a compiled Pattern or string + :param relative_path: + :param context_lines_before: Number of lines of context to include before each match + :param context_lines_after: Number of lines of context to include after each match + :param paths_include_glob: Glob pattern to filter which files to include in the search + :param paths_exclude_glob: Glob pattern to filter which files to exclude from the search. Takes precedence over paths_include_glob. + :return: List of matched consecutive lines with context + """ + relative_file_paths = self.gather_source_files(relative_path=relative_path) + return search_files( + relative_file_paths, + pattern, + root_path=self.project_root, + context_lines_before=context_lines_before, + context_lines_after=context_lines_after, + paths_include_glob=paths_include_glob, + paths_exclude_glob=paths_exclude_glob, + ) + + def retrieve_content_around_line( + self, relative_file_path: str, line: int, context_lines_before: int = 0, context_lines_after: int = 0 + ) -> MatchedConsecutiveLines: + """ + Retrieve the content of the given file around the given line. + + :param relative_file_path: The relative path of the file to retrieve the content from + :param line: The line number to retrieve the content around + :param context_lines_before: The number of lines to retrieve before the given line + :param context_lines_after: The number of lines to retrieve after the given line + + :return MatchedConsecutiveLines: A container with the desired lines. 
+ """ + file_contents = self.read_file(relative_file_path) + return MatchedConsecutiveLines.from_file_contents( + file_contents, + line=line, + context_lines_before=context_lines_before, + context_lines_after=context_lines_after, + source_file_path=relative_file_path, + ) + + def create_language_server( + self, + log_level: int = logging.INFO, + ls_timeout: float | None = DEFAULT_TOOL_TIMEOUT - 5, + trace_lsp_communication: bool = False, + ls_specific_settings: dict[Language, Any] | None = None, + ) -> SolidLanguageServer: + """ + Create a language server for a project. Note that you will have to start it + before performing any LS operations. + + :param project: either a path to the project root or a ProjectConfig instance. + If no project.yml is found, the default project configuration will be used. + :param log_level: the log level for the language server + :param ls_timeout: the timeout for the language server + :param trace_lsp_communication: whether to trace LSP communication + :param ls_specific_settings: optional LS specific configuration of the language server, + see docstrings in the inits of subclasses of SolidLanguageServer to see what values may be passed. + :return: the language server + """ + ls_config = LanguageServerConfig( + code_language=self.language, + ignored_paths=self._ignored_patterns, + trace_lsp_communication=trace_lsp_communication, + ) + ls_logger = LanguageServerLogger(log_level=log_level) + + log.info(f"Creating language server instance for {self.project_root}.") + return SolidLanguageServer.create( + ls_config, + ls_logger, + self.project_root, + timeout=ls_timeout, + solidlsp_settings=SolidLSPSettings( + solidlsp_dir=SERENA_MANAGED_DIR_IN_HOME, + project_data_relative_path=SERENA_MANAGED_DIR_NAME, + ls_specific_settings=ls_specific_settings or {}, + ), + ) diff --git a/Libraries/serena/prompt_factory.py b/Libraries/serena/prompt_factory.py new file mode 100644 index 00000000..40f51e7a --- /dev/null +++ b/Libraries/serena/prompt_factory.py @@ -0,0 +1,14 @@ +import os + +from serena.constants import PROMPT_TEMPLATES_DIR_IN_USER_HOME, PROMPT_TEMPLATES_DIR_INTERNAL +from serena.generated.generated_prompt_factory import PromptFactory + + +class SerenaPromptFactory(PromptFactory): + """ + A class for retrieving and rendering prompt templates and prompt lists. + """ + + def __init__(self) -> None: + os.makedirs(PROMPT_TEMPLATES_DIR_IN_USER_HOME, exist_ok=True) + super().__init__(prompts_dir=[PROMPT_TEMPLATES_DIR_IN_USER_HOME, PROMPT_TEMPLATES_DIR_INTERNAL]) diff --git a/Libraries/serena/symbol.py b/Libraries/serena/symbol.py new file mode 100644 index 00000000..9d89c5fb --- /dev/null +++ b/Libraries/serena/symbol.py @@ -0,0 +1,645 @@ +import json +import logging +import os +from abc import ABC, abstractmethod +from collections.abc import Iterator, Sequence +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any, Self, Union + +from sensai.util.string import ToStringMixin + +from solidlsp import SolidLanguageServer +from solidlsp.ls import ReferenceInSymbol as LSPReferenceInSymbol +from solidlsp.ls_types import Position, SymbolKind, UnifiedSymbolInformation + +from .project import Project + +if TYPE_CHECKING: + from .agent import SerenaAgent + +log = logging.getLogger(__name__) + + +@dataclass +class LanguageServerSymbolLocation: + """ + Represents the (start) location of a symbol identifier, which, within Serena, uniquely identifies the symbol. 
+ """ + + relative_path: str | None + """ + the relative path of the file containing the symbol; if None, the symbol is defined outside of the project's scope + """ + line: int | None + """ + the line number in which the symbol identifier is defined (if the symbol is a function, class, etc.); + may be None for some types of symbols (e.g. SymbolKind.File) + """ + column: int | None + """ + the column number in which the symbol identifier is defined (if the symbol is a function, class, etc.); + may be None for some types of symbols (e.g. SymbolKind.File) + """ + + def __post_init__(self) -> None: + if self.relative_path is not None: + self.relative_path = self.relative_path.replace("/", os.path.sep) + + def to_dict(self, include_relative_path: bool = True) -> dict[str, Any]: + result = asdict(self) + if not include_relative_path: + result.pop("relative_path", None) + return result + + def has_position_in_file(self) -> bool: + return self.relative_path is not None and self.line is not None and self.column is not None + + +@dataclass +class PositionInFile: + """ + Represents a character position within a file + """ + + line: int + """ + the 0-based line number in the file + """ + col: int + """ + the 0-based column + """ + + def to_lsp_position(self) -> Position: + """ + Convert to LSP Position. + """ + return Position(line=self.line, character=self.col) + + +class Symbol(ABC): + @abstractmethod + def get_body_start_position(self) -> PositionInFile | None: + pass + + @abstractmethod + def get_body_end_position(self) -> PositionInFile | None: + pass + + def get_body_start_position_or_raise(self) -> PositionInFile: + """ + Get the start position of the symbol body, raising an error if it is not defined. + """ + pos = self.get_body_start_position() + if pos is None: + raise ValueError(f"Body start position is not defined for {self}") + return pos + + def get_body_end_position_or_raise(self) -> PositionInFile: + """ + Get the end position of the symbol body, raising an error if it is not defined. + """ + pos = self.get_body_end_position() + if pos is None: + raise ValueError(f"Body end position is not defined for {self}") + return pos + + @abstractmethod + def is_neighbouring_definition_separated_by_empty_line(self) -> bool: + """ + :return: whether a symbol definition of this symbol's kind is usually separated from the + previous/next definition by at least one empty line. + """ + + +class LanguageServerSymbol(Symbol, ToStringMixin): + _NAME_PATH_SEP = "/" + + @staticmethod + def match_name_path( + name_path: str, + symbol_name_path_parts: list[str], + substring_matching: bool, + ) -> bool: + """ + Checks if a given `name_path` matches a symbol's qualified name parts. + See docstring of `Symbol.find` for more details. 
+ """ + assert name_path, "name_path must not be empty" + assert symbol_name_path_parts, "symbol_name_path_parts must not be empty" + name_path_sep = LanguageServerSymbol._NAME_PATH_SEP + + is_absolute_pattern = name_path.startswith(name_path_sep) + pattern_parts = name_path.lstrip(name_path_sep).rstrip(name_path_sep).split(name_path_sep) + + # filtering based on ancestors + if len(pattern_parts) > len(symbol_name_path_parts): + # can't possibly match if pattern has more parts than symbol + return False + if is_absolute_pattern and len(pattern_parts) != len(symbol_name_path_parts): + # for absolute patterns, the number of parts must match exactly + return False + if symbol_name_path_parts[-len(pattern_parts) : -1] != pattern_parts[:-1]: + # ancestors must match + return False + + # matching the last part of the symbol name + name_to_match = pattern_parts[-1] + symbol_name = symbol_name_path_parts[-1] + if substring_matching: + return name_to_match in symbol_name + else: + return name_to_match == symbol_name + + def __init__(self, symbol_root_from_ls: UnifiedSymbolInformation) -> None: + self.symbol_root = symbol_root_from_ls + + def _tostring_includes(self) -> list[str]: + return [] + + def _tostring_additional_entries(self) -> dict[str, Any]: + return dict(name=self.name, kind=self.kind, num_children=len(self.symbol_root["children"])) + + @property + def name(self) -> str: + return self.symbol_root["name"] + + @property + def kind(self) -> str: + return SymbolKind(self.symbol_kind).name + + @property + def symbol_kind(self) -> SymbolKind: + return self.symbol_root["kind"] + + def is_neighbouring_definition_separated_by_empty_line(self) -> bool: + return self.symbol_kind in (SymbolKind.Function, SymbolKind.Method, SymbolKind.Class, SymbolKind.Interface, SymbolKind.Struct) + + @property + def relative_path(self) -> str | None: + location = self.symbol_root.get("location") + if location: + return location.get("relativePath") + return None + + @property + def location(self) -> LanguageServerSymbolLocation: + """ + :return: the start location of the actual symbol identifier + """ + return LanguageServerSymbolLocation(relative_path=self.relative_path, line=self.line, column=self.column) + + @property + def body_start_position(self) -> Position | None: + location = self.symbol_root.get("location") + if location: + range_info = location.get("range") + if range_info: + start_pos = range_info.get("start") + if start_pos: + return start_pos + return None + + @property + def body_end_position(self) -> Position | None: + location = self.symbol_root.get("location") + if location: + range_info = location.get("range") + if range_info: + end_pos = range_info.get("end") + if end_pos: + return end_pos + return None + + def get_body_start_position(self) -> PositionInFile | None: + start_pos = self.body_start_position + if start_pos is None: + return None + return PositionInFile(line=start_pos["line"], col=start_pos["character"]) + + def get_body_end_position(self) -> PositionInFile | None: + end_pos = self.body_end_position + if end_pos is None: + return None + return PositionInFile(line=end_pos["line"], col=end_pos["character"]) + + def get_body_line_numbers(self) -> tuple[int | None, int | None]: + start_pos = self.body_start_position + end_pos = self.body_end_position + start_line = start_pos["line"] if start_pos else None + end_line = end_pos["line"] if end_pos else None + return start_line, end_line + + @property + def line(self) -> int | None: + """ + :return: the line in which the symbol identifier is 
defined. + """ + if "selectionRange" in self.symbol_root: + return self.symbol_root["selectionRange"]["start"]["line"] + else: + # line is expected to be undefined for some types of symbols (e.g. SymbolKind.File) + return None + + @property + def column(self) -> int | None: + if "selectionRange" in self.symbol_root: + return self.symbol_root["selectionRange"]["start"]["character"] + else: + # precise location is expected to be undefined for some types of symbols (e.g. SymbolKind.File) + return None + + @property + def body(self) -> str | None: + return self.symbol_root.get("body") + + def get_name_path(self) -> str: + """ + Get the name path of the symbol (e.g. "class/method/inner_function"). + """ + return self._NAME_PATH_SEP.join(self.get_name_path_parts()) + + def get_name_path_parts(self) -> list[str]: + """ + Get the parts of the name path of the symbol (e.g. ["class", "method", "inner_function"]). + """ + ancestors_within_file = list(self.iter_ancestors(up_to_symbol_kind=SymbolKind.File)) + ancestors_within_file.reverse() + return [a.name for a in ancestors_within_file] + [self.name] + + def iter_children(self) -> Iterator[Self]: + for c in self.symbol_root["children"]: + yield self.__class__(c) + + def iter_ancestors(self, up_to_symbol_kind: SymbolKind | None = None) -> Iterator[Self]: + """ + Iterate over all ancestors of the symbol, starting with the parent and going up to the root or + the given symbol kind. + + :param up_to_symbol_kind: if provided, iteration will stop *before* the first ancestor of the given kind. + A typical use case is to pass `SymbolKind.File` or `SymbolKind.Package`. + """ + parent = self.get_parent() + if parent is not None: + if up_to_symbol_kind is None or parent.symbol_kind != up_to_symbol_kind: + yield parent + yield from parent.iter_ancestors(up_to_symbol_kind=up_to_symbol_kind) + + def get_parent(self) -> Self | None: + parent_root = self.symbol_root.get("parent") + if parent_root is None: + return None + return self.__class__(parent_root) + + def find( + self, + name_path: str, + substring_matching: bool = False, + include_kinds: Sequence[SymbolKind] | None = None, + exclude_kinds: Sequence[SymbolKind] | None = None, + ) -> list[Self]: + """ + Find all symbols within the symbol's subtree that match the given `name_path`. + The matching behavior is determined by the structure of `name_path`, which can + either be a simple name (e.g. "method") or a name path like "class/method" (relative name path) + or "/class/method" (absolute name path). + + Key aspects of the name path matching behavior: + - Trailing slashes in `name_path` play no role and are ignored. + - The name of the retrieved symbols will match (either exactly or as a substring) + the last segment of `name_path`, while other segments will restrict the search to symbols that + have a desired sequence of ancestors. + - If there is no starting or intermediate slash in `name_path`, there is no + restriction on the ancestor symbols. For example, passing `method` will match + against symbols with name paths like `method`, `class/method`, `class/nested_class/method`, etc. + - If `name_path` contains a `/` but doesn't start with a `/`, the matching is restricted to symbols + with the same ancestors as the last segment of `name_path`. For example, passing `class/method` will match against + `class/method` as well as `nested_class/class/method` but not `method`. 
+ - If `name_path` starts with a `/`, it will be treated as an absolute name path pattern, meaning + that the first segment of it must match the first segment of the symbol's name path. + For example, passing `/class` will match only against top-level symbols like `class` but not against `nested_class/class`. + Passing `/class/method` will match against `class/method` but not `nested_class/class/method` or `method`. + + :param name_path: the name path to match against + :param substring_matching: whether to use substring matching (as opposed to exact matching) + of the last segment of `name_path` against the symbol name. + :param include_kinds: an optional sequence of ints representing the LSP symbol kind. + If provided, only symbols of the given kinds will be included in the result. + :param exclude_kinds: If provided, symbols of the given kinds will be excluded from the result. + """ + result = [] + + def should_include(s: "LanguageServerSymbol") -> bool: + if include_kinds is not None and s.symbol_kind not in include_kinds: + return False + if exclude_kinds is not None and s.symbol_kind in exclude_kinds: + return False + return LanguageServerSymbol.match_name_path( + name_path=name_path, + symbol_name_path_parts=s.get_name_path_parts(), + substring_matching=substring_matching, + ) + + def traverse(s: "LanguageServerSymbol") -> None: + if should_include(s): + result.append(s) + for c in s.iter_children(): + traverse(c) + + traverse(self) + return result + + def to_dict( + self, + kind: bool = False, + location: bool = False, + depth: int = 0, + include_body: bool = False, + include_children_body: bool = False, + include_relative_path: bool = True, + ) -> dict[str, Any]: + """ + Converts the symbol to a dictionary. + + :param kind: whether to include the kind of the symbol + :param location: whether to include the location of the symbol + :param depth: the depth of the symbol + :param include_body: whether to include the body of the top-level symbol. + :param include_children_body: whether to also include the body of the children. + Note that the body of the children is part of the body of the parent symbol, + so there is usually no need to set this to True unless you want process the output + and pass the children without passing the parent body to the LM. + :param include_relative_path: whether to include the relative path of the symbol in the location + entry. Relative paths of the symbol's children are always excluded. + :return: a dictionary representation of the symbol + """ + result: dict[str, Any] = {"name": self.name, "name_path": self.get_name_path()} + + if kind: + result["kind"] = self.kind + + if location: + result["location"] = self.location.to_dict(include_relative_path=include_relative_path) + body_start_line, body_end_line = self.get_body_line_numbers() + result["body_location"] = {"start_line": body_start_line, "end_line": body_end_line} + + if include_body: + if self.body is None: + log.warning("Requested body for symbol, but it is not present. 
The symbol might have been loaded with include_body=False.") + result["body"] = self.body + + def add_children(s: Self) -> list[dict[str, Any]]: + children = [] + for c in s.iter_children(): + children.append( + c.to_dict( + kind=kind, + location=location, + depth=depth - 1, + include_body=include_children_body, + include_children_body=include_children_body, + # all children have the same relative path as the parent + include_relative_path=False, + ) + ) + return children + + if depth > 0: + result["children"] = add_children(self) + + return result + + +@dataclass +class ReferenceInLanguageServerSymbol(ToStringMixin): + """ + Represents the location of a reference to another symbol within a symbol/file. + + The contained symbol is the symbol within which the reference is located, + not the symbol that is referenced. + """ + + symbol: LanguageServerSymbol + """ + the symbol within which the reference is located + """ + line: int + """ + the line number in which the reference is located (0-based) + """ + character: int + """ + the column number in which the reference is located (0-based) + """ + + @classmethod + def from_lsp_reference(cls, reference: LSPReferenceInSymbol) -> Self: + return cls(symbol=LanguageServerSymbol(reference.symbol), line=reference.line, character=reference.character) + + def get_relative_path(self) -> str | None: + return self.symbol.location.relative_path + + +class LanguageServerSymbolRetriever: + def __init__(self, lang_server: SolidLanguageServer, agent: Union["SerenaAgent", None] = None) -> None: + """ + :param lang_server: the language server to use for symbol retrieval as well as editing operations. + :param agent: the agent to use (only needed for marking files as modified). You can pass None if you don't + need an agent to be aware of file modifications performed by the symbol manager. + """ + self._lang_server = lang_server + self.agent = agent + + def set_language_server(self, lang_server: SolidLanguageServer) -> None: + """ + Set the language server to use for symbol retrieval and editing operations. + This is useful if you want to change the language server after initializing the SymbolManager. + """ + self._lang_server = lang_server + + def get_language_server(self) -> SolidLanguageServer: + return self._lang_server + + def find_by_name( + self, + name_path: str, + include_body: bool = False, + include_kinds: Sequence[SymbolKind] | None = None, + exclude_kinds: Sequence[SymbolKind] | None = None, + substring_matching: bool = False, + within_relative_path: str | None = None, + ) -> list[LanguageServerSymbol]: + """ + Find all symbols that match the given name. See docstring of `Symbol.find` for more details. + The only parameter not mentioned there is `within_relative_path`, which can be used to restrict the search + to symbols within a specific file or directory. 
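+ Example (illustrative, hypothetical names): find_by_name("MyClass/do_work", within_relative_path="src/my_module.py") returns + every symbol named "do_work" whose immediately enclosing symbol is named "MyClass", searching only within src/my_module.py.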
+ """ + symbols: list[LanguageServerSymbol] = [] + symbol_roots = self._lang_server.request_full_symbol_tree(within_relative_path=within_relative_path, include_body=include_body) + for root in symbol_roots: + symbols.extend( + LanguageServerSymbol(root).find( + name_path, include_kinds=include_kinds, exclude_kinds=exclude_kinds, substring_matching=substring_matching + ) + ) + return symbols + + def get_document_symbols(self, relative_path: str) -> list[LanguageServerSymbol]: + symbol_dicts, _roots = self._lang_server.request_document_symbols(relative_path, include_body=False) + symbols = [LanguageServerSymbol(s) for s in symbol_dicts] + return symbols + + def find_by_location(self, location: LanguageServerSymbolLocation) -> LanguageServerSymbol | None: + if location.relative_path is None: + return None + symbol_dicts, _roots = self._lang_server.request_document_symbols(location.relative_path, include_body=False) + for symbol_dict in symbol_dicts: + symbol = LanguageServerSymbol(symbol_dict) + if symbol.location == location: + return symbol + return None + + def find_referencing_symbols( + self, + name_path: str, + relative_file_path: str, + include_body: bool = False, + include_kinds: Sequence[SymbolKind] | None = None, + exclude_kinds: Sequence[SymbolKind] | None = None, + ) -> list[ReferenceInLanguageServerSymbol]: + """ + Find all symbols that reference the symbol with the given name. + If multiple symbols fit the name (e.g. for variables that are overwritten), will use the first one. + + :param name_path: the name path of the symbol to find + :param relative_file_path: the relative path of the file in which the referenced symbol is defined. + :param include_body: whether to include the body of all symbols in the result. + Not recommended, as the referencing symbols will often be files, and thus the bodies will be very long. + :param include_kinds: which kinds of symbols to include in the result. + :param exclude_kinds: which kinds of symbols to exclude from the result. + """ + symbol_candidates = self.find_by_name(name_path, substring_matching=False, within_relative_path=relative_file_path) + if len(symbol_candidates) == 0: + log.warning(f"No symbol with name {name_path} found in file {relative_file_path}") + return [] + if len(symbol_candidates) > 1: + log.error( + f"Found {len(symbol_candidates)} symbols with name {name_path} in file {relative_file_path}." + f"May be an overwritten variable, in which case you can ignore this error. Proceeding with the first one. " + f"Found symbols for {name_path=} in {relative_file_path=}: \n" + f"{json.dumps([s.location.to_dict() for s in symbol_candidates], indent=2)}" + ) + symbol = symbol_candidates[0] + return self.find_referencing_symbols_by_location( + symbol.location, include_body=include_body, include_kinds=include_kinds, exclude_kinds=exclude_kinds + ) + + def find_referencing_symbols_by_location( + self, + symbol_location: LanguageServerSymbolLocation, + include_body: bool = False, + include_kinds: Sequence[SymbolKind] | None = None, + exclude_kinds: Sequence[SymbolKind] | None = None, + ) -> list[ReferenceInLanguageServerSymbol]: + """ + Find all symbols that reference the symbol at the given location. + + :param symbol_location: the location of the symbol for which to find references. + Does not need to include an end_line, as it is unused in the search. + :param include_body: whether to include the body of all symbols in the result. 
+ Not recommended, as the referencing symbols will often be files, and thus the bodies will be very long. + Note: you can filter out the bodies of the children if you set include_children_body=False + in the to_dict method. + :param include_kinds: an optional sequence of ints representing the LSP symbol kind. + If provided, only symbols of the given kinds will be included in the result. + :param exclude_kinds: If provided, symbols of the given kinds will be excluded from the result. + Takes precedence over include_kinds. + :return: a list of symbols that reference the given symbol + """ + if not symbol_location.has_position_in_file(): + raise ValueError("Symbol location does not contain a valid position in a file") + assert symbol_location.relative_path is not None + assert symbol_location.line is not None + assert symbol_location.column is not None + references = self._lang_server.request_referencing_symbols( + relative_file_path=symbol_location.relative_path, + line=symbol_location.line, + column=symbol_location.column, + include_imports=False, + include_self=False, + include_body=include_body, + include_file_symbols=True, + ) + + if include_kinds is not None: + references = [s for s in references if s.symbol["kind"] in include_kinds] + + if exclude_kinds is not None: + references = [s for s in references if s.symbol["kind"] not in exclude_kinds] + + return [ReferenceInLanguageServerSymbol.from_lsp_reference(r) for r in references] + + @dataclass + class SymbolOverviewElement: + name_path: str + kind: int + + @classmethod + def from_symbol(cls, symbol: LanguageServerSymbol) -> Self: + return cls(name_path=symbol.get_name_path(), kind=int(symbol.symbol_kind)) + + def get_symbol_overview(self, relative_path: str) -> dict[str, list[SymbolOverviewElement]]: + path_to_unified_symbols = self._lang_server.request_overview(relative_path) + result = {} + for file_path, unified_symbols in path_to_unified_symbols.items(): + # TODO: maybe include not just top-level symbols? We could filter by kind to exclude variables + # The language server methods would need to be adjusted for this. + result[file_path] = [self.SymbolOverviewElement.from_symbol(LanguageServerSymbol(s)) for s in unified_symbols] + return result + + +class JetBrainsSymbol(Symbol): + def __init__(self, symbol_dict: dict, project: Project) -> None: + """ + :param symbol_dict: dictionary as returned by the JetBrains plugin client. 
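+ Illustrative structure (hypothetical values; only the fields accessed by this class are shown): + {"relative_path": "src/Main.kt", "text_range": {"start_pos": {"line": 10, "col": 0}, "end_pos": {"line": 20, "col": 1}}}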
+ """ + self._project = project + self._dict = symbol_dict + self._cached_file_content: str | None = None + self._cached_body_start_position: PositionInFile | None = None + self._cached_body_end_position: PositionInFile | None = None + + def get_relative_path(self) -> str: + return self._dict["relative_path"] + + def get_file_content(self) -> str: + if self._cached_file_content is None: + path = os.path.join(self._project.project_root, self.get_relative_path()) + with open(path, encoding=self._project.project_config.encoding) as f: + self._cached_file_content = f.read() + return self._cached_file_content + + def is_position_in_file_available(self) -> bool: + return "text_range" in self._dict + + def get_body_start_position(self) -> PositionInFile | None: + if not self.is_position_in_file_available(): + return None + if self._cached_body_start_position is None: + pos = self._dict["text_range"]["start_pos"] + line, col = pos["line"], pos["col"] + self._cached_body_start_position = PositionInFile(line=line, col=col) + return self._cached_body_start_position + + def get_body_end_position(self) -> PositionInFile | None: + if not self.is_position_in_file_available(): + return None + if self._cached_body_end_position is None: + pos = self._dict["text_range"]["end_pos"] + line, col = pos["line"], pos["col"] + self._cached_body_end_position = PositionInFile(line=line, col=col) + return self._cached_body_end_position + + def is_neighbouring_definition_separated_by_empty_line(self) -> bool: + # NOTE: Symbol types cannot really be differentiated, because types are not handled in a language-agnostic way. + return False diff --git a/Libraries/serena/text_utils.py b/Libraries/serena/text_utils.py new file mode 100644 index 00000000..19987cee --- /dev/null +++ b/Libraries/serena/text_utils.py @@ -0,0 +1,405 @@ +import fnmatch +import logging +import os +import re +from collections.abc import Callable +from dataclasses import dataclass, field +from enum import StrEnum +from typing import Any, Self + +from joblib import Parallel, delayed + +log = logging.getLogger(__name__) + + +class LineType(StrEnum): + """Enum for different types of lines in search results.""" + + MATCH = "match" + """Part of the matched lines""" + BEFORE_MATCH = "prefix" + """Lines before the match""" + AFTER_MATCH = "postfix" + """Lines after the match""" + + +@dataclass(kw_only=True) +class TextLine: + """Represents a line of text with information on how it relates to the match.""" + + line_number: int + line_content: str + match_type: LineType + """Represents the type of line (match, prefix, postfix)""" + + def get_display_prefix(self) -> str: + """Get the display prefix for this line based on the match type.""" + if self.match_type == LineType.MATCH: + return " >" + return "..." + + def format_line(self, include_line_numbers: bool = True) -> str: + """Format the line for display (e.g.,for logging or passing to an LLM). + + :param include_line_numbers: Whether to include the line number in the result. + """ + prefix = self.get_display_prefix() + if include_line_numbers: + line_num = str(self.line_number).rjust(4) + prefix = f"{prefix}{line_num}" + return f"{prefix}:{self.line_content}" + + +@dataclass(kw_only=True) +class MatchedConsecutiveLines: + """Represents a collection of consecutive lines found through some criterion in a text file or a string. + May include lines before, after, and matched. + """ + + lines: list[TextLine] + """All lines in the context of the match. 
At least one of them is of `match_type` `MATCH`.""" + source_file_path: str | None = None + """Path to the file where the match was found (Metadata).""" + + # set in post-init + lines_before_matched: list[TextLine] = field(default_factory=list) + matched_lines: list[TextLine] = field(default_factory=list) + lines_after_matched: list[TextLine] = field(default_factory=list) + + def __post_init__(self) -> None: + for line in self.lines: + if line.match_type == LineType.BEFORE_MATCH: + self.lines_before_matched.append(line) + elif line.match_type == LineType.MATCH: + self.matched_lines.append(line) + elif line.match_type == LineType.AFTER_MATCH: + self.lines_after_matched.append(line) + + assert len(self.matched_lines) > 0, "At least one matched line is required" + + @property + def start_line(self) -> int: + return self.lines[0].line_number + + @property + def end_line(self) -> int: + return self.lines[-1].line_number + + @property + def num_matched_lines(self) -> int: + return len(self.matched_lines) + + def to_display_string(self, include_line_numbers: bool = True) -> str: + return "\n".join([line.format_line(include_line_numbers) for line in self.lines]) + + @classmethod + def from_file_contents( + cls, file_contents: str, line: int, context_lines_before: int = 0, context_lines_after: int = 0, source_file_path: str | None = None + ) -> Self: + line_contents = file_contents.split("\n") + start_lineno = max(0, line - context_lines_before) + end_lineno = min(len(line_contents) - 1, line + context_lines_after) + text_lines: list[TextLine] = [] + # before the line + for lineno in range(start_lineno, line): + text_lines.append(TextLine(line_number=lineno, line_content=line_contents[lineno], match_type=LineType.BEFORE_MATCH)) + # the line + text_lines.append(TextLine(line_number=line, line_content=line_contents[line], match_type=LineType.MATCH)) + # after the line + for lineno in range(line + 1, end_lineno + 1): + text_lines.append(TextLine(line_number=lineno, line_content=line_contents[lineno], match_type=LineType.AFTER_MATCH)) + + return cls(lines=text_lines, source_file_path=source_file_path) + + +def glob_to_regex(glob_pat: str) -> str: + regex_parts: list[str] = [] + i = 0 + while i < len(glob_pat): + ch = glob_pat[i] + if ch == "*": + regex_parts.append(".*") + elif ch == "?": + regex_parts.append(".") + elif ch == "\\": + i += 1 + if i < len(glob_pat): + regex_parts.append(re.escape(glob_pat[i])) + else: + regex_parts.append("\\") + else: + regex_parts.append(re.escape(ch)) + i += 1 + return "".join(regex_parts) + + +def search_text( + pattern: str, + content: str | None = None, + source_file_path: str | None = None, + allow_multiline_match: bool = False, + context_lines_before: int = 0, + context_lines_after: int = 0, + is_glob: bool = False, +) -> list[MatchedConsecutiveLines]: + """ + Search for a pattern in text content. Supports both regex and glob-like patterns. + + :param pattern: Pattern to search for (regex or glob-like pattern) + :param content: The text content to search. May be None if source_file_path is provided. + :param source_file_path: Optional path to the source file. If content is None, + this has to be passed and the file will be read. + :param allow_multiline_match: Whether to search across multiple lines. Currently, the default + option (False) is very inefficient, so it is recommended to set this to True. 
+ :param context_lines_before: Number of context lines to include before matches + :param context_lines_after: Number of context lines to include after matches + :param is_glob: If True, pattern is treated as a glob-like pattern (e.g., "*.py", "test_??.py") + and will be converted to regex internally + + :return: List of `TextSearchMatch` objects + + :raises: ValueError if the pattern is not valid + + """ + if source_file_path and content is None: + with open(source_file_path) as f: + content = f.read() + + if content is None: + raise ValueError("Pass either content or source_file_path") + + matches = [] + lines = content.splitlines() + total_lines = len(lines) + + # Convert pattern to a compiled regex if it's a string + if is_glob: + pattern = glob_to_regex(pattern) + if allow_multiline_match: + # For multiline matches, we need to use the DOTALL flag to make '.' match newlines + compiled_pattern = re.compile(pattern, re.DOTALL) + # Search across the entire content as a single string + for match in compiled_pattern.finditer(content): + start_pos = match.start() + end_pos = match.end() + + # Find the line numbers for the start and end positions + start_line_num = content[:start_pos].count("\n") + 1 + end_line_num = content[:end_pos].count("\n") + 1 + + # Calculate the range of lines to include in the context + context_start = max(1, start_line_num - context_lines_before) + context_end = min(total_lines, end_line_num + context_lines_after) + + # Create TextLine objects for the context + context_lines = [] + for i in range(context_start - 1, context_end): + line_num = i + 1 + if context_start <= line_num < start_line_num: + match_type = LineType.BEFORE_MATCH + elif end_line_num < line_num <= context_end: + match_type = LineType.AFTER_MATCH + else: + match_type = LineType.MATCH + + context_lines.append(TextLine(line_number=line_num, line_content=lines[i], match_type=match_type)) + + matches.append(MatchedConsecutiveLines(lines=context_lines, source_file_path=source_file_path)) + else: + # TODO: extremely inefficient! Since we currently don't use this option in SerenaAgent or LanguageServer, + # it is not urgent to fix, but should be either improved or the option should be removed. + # Search line by line, normal compile without DOTALL + compiled_pattern = re.compile(pattern) + for i, line in enumerate(lines): + line_num = i + 1 + if compiled_pattern.search(line): + # Calculate the range of lines to include in the context + context_start = max(0, i - context_lines_before) + context_end = min(total_lines - 1, i + context_lines_after) + + # Create TextLine objects for the context + context_lines = [] + for j in range(context_start, context_end + 1): + context_line_num = j + 1 + if j < i: + match_type = LineType.BEFORE_MATCH + elif j > i: + match_type = LineType.AFTER_MATCH + else: + match_type = LineType.MATCH + + context_lines.append(TextLine(line_number=context_line_num, line_content=lines[j], match_type=match_type)) + + matches.append(MatchedConsecutiveLines(lines=context_lines, source_file_path=source_file_path)) + + return matches + + +def default_file_reader(file_path: str) -> str: + """Reads using utf-8 encoding.""" + with open(file_path, encoding="utf-8") as f: + return f.read() + + +def expand_braces(pattern: str) -> list[str]: + """ + Expands brace patterns in a glob string. + For example, "**/*.{js,jsx,ts,tsx}" becomes ["**/*.js", "**/*.jsx", "**/*.ts", "**/*.tsx"]. + Handles multiple brace sets as well. 
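+ For instance, "{a,b}/*.{py,txt}" expands to ["a/*.py", "a/*.txt", "b/*.py", "b/*.txt"].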
+ """ + patterns = [pattern] + while any("{" in p for p in patterns): + new_patterns = [] + for p in patterns: + match = re.search(r"\{([^{}]+)\}", p) + if match: + prefix = p[: match.start()] + suffix = p[match.end() :] + options = match.group(1).split(",") + for option in options: + new_patterns.append(f"{prefix}{option}{suffix}") + else: + new_patterns.append(p) + patterns = new_patterns + return patterns + + +def glob_match(pattern: str, path: str) -> bool: + """ + Match a file path against a glob pattern. + + Supports standard glob patterns: + - * matches any number of characters except / + - ** matches any number of directories (zero or more) + - ? matches a single character except / + - [seq] matches any character in seq + + Supports brace expansion: + - {a,b,c} expands to multiple patterns (including nesting) + + Unsupported patterns: + - Bash extended glob features are unavailable in Python's fnmatch + - Extended globs like !(), ?(), +(), *(), @() are not supported + + :param pattern: Glob pattern (e.g., 'src/**/*.py', '**agent.py') + :param path: File path to match against + :return: True if path matches pattern + """ + pattern = pattern.replace("\\", "/") # Normalize backslashes to forward slashes + path = path.replace("\\", "/") # Normalize path backslashes to forward slashes + + # Handle ** patterns that should match zero or more directories + if "**" in pattern: + # Method 1: Standard fnmatch (matches one or more directories) + regex1 = fnmatch.translate(pattern) + if re.match(regex1, path): + return True + + # Method 2: Handle zero-directory case by removing /** entirely + # Convert "src/**/test.py" to "src/test.py" + if "/**/" in pattern: + zero_dir_pattern = pattern.replace("/**/", "/") + regex2 = fnmatch.translate(zero_dir_pattern) + if re.match(regex2, path): + return True + + # Method 3: Handle leading ** case by removing **/ + # Convert "**/test.py" to "test.py" + if pattern.startswith("**/"): + zero_dir_pattern = pattern[3:] # Remove "**/" + regex3 = fnmatch.translate(zero_dir_pattern) + if re.match(regex3, path): + return True + + return False + else: + # Simple pattern without **, use fnmatch directly + return fnmatch.fnmatch(path, pattern) + + +def search_files( + relative_file_paths: list[str], + pattern: str, + root_path: str = "", + file_reader: Callable[[str], str] = default_file_reader, + context_lines_before: int = 0, + context_lines_after: int = 0, + paths_include_glob: str | None = None, + paths_exclude_glob: str | None = None, +) -> list[MatchedConsecutiveLines]: + """ + Search for a pattern in a list of files. + + :param relative_file_paths: List of relative file paths in which to search + :param pattern: Pattern to search for + :param root_path: Root path to resolve relative paths against (by default, current working directory). + :param file_reader: Function to read a file, by default will just use os.open. + All files that can't be read by it will be skipped. 
+ :param context_lines_before: Number of context lines to include before matches + :param context_lines_after: Number of context lines to include after matches + :param paths_include_glob: Optional glob pattern to include files from the list + :param paths_exclude_glob: Optional glob pattern to exclude files from the list + :return: List of MatchedConsecutiveLines objects + """ + # Pre-filter paths (done sequentially to avoid overhead) + # Use proper glob matching instead of gitignore patterns + include_patterns = expand_braces(paths_include_glob) if paths_include_glob else None + exclude_patterns = expand_braces(paths_exclude_glob) if paths_exclude_glob else None + + filtered_paths = [] + for path in relative_file_paths: + if include_patterns: + if not any(glob_match(p, path) for p in include_patterns): + log.debug(f"Skipping {path}: does not match include pattern {paths_include_glob}") + continue + + if exclude_patterns: + if any(glob_match(p, path) for p in exclude_patterns): + log.debug(f"Skipping {path}: matches exclude pattern {paths_exclude_glob}") + continue + + filtered_paths.append(path) + + log.info(f"Processing {len(filtered_paths)} files.") + + def process_single_file(path: str) -> dict[str, Any]: + """Process a single file - this function will be parallelized.""" + try: + abs_path = os.path.join(root_path, path) + file_content = file_reader(abs_path) + search_results = search_text( + pattern, + content=file_content, + source_file_path=path, + allow_multiline_match=True, + context_lines_before=context_lines_before, + context_lines_after=context_lines_after, + ) + if len(search_results) > 0: + log.debug(f"Found {len(search_results)} matches in {path}") + return {"path": path, "results": search_results, "error": None} + except Exception as e: + log.debug(f"Error processing {path}: {e}") + return {"path": path, "results": [], "error": str(e)} + + # Execute in parallel using joblib + results = Parallel( + n_jobs=-1, + backend="threading", + )(delayed(process_single_file)(path) for path in filtered_paths) + + # Collect results and errors + matches = [] + skipped_file_error_tuples = [] + + for result in results: + if result["error"]: + skipped_file_error_tuples.append((result["path"], result["error"])) + else: + matches.extend(result["results"]) + + if skipped_file_error_tuples: + log.debug(f"Failed to read {len(skipped_file_error_tuples)} files: {skipped_file_error_tuples}") + + log.info(f"Found {len(matches)} total matches across {len(filtered_paths)} files") + return matches diff --git a/Libraries/serena/tools/__init__.py b/Libraries/serena/tools/__init__.py new file mode 100644 index 00000000..e6fb4b08 --- /dev/null +++ b/Libraries/serena/tools/__init__.py @@ -0,0 +1,9 @@ +# ruff: noqa +from .tools_base import * +from .file_tools import * +from .symbol_tools import * +from .memory_tools import * +from .cmd_tools import * +from .config_tools import * +from .workflow_tools import * +from .jetbrains_tools import * diff --git a/Libraries/serena/tools/cmd_tools.py b/Libraries/serena/tools/cmd_tools.py new file mode 100644 index 00000000..e548903e --- /dev/null +++ b/Libraries/serena/tools/cmd_tools.py @@ -0,0 +1,52 @@ +""" +Tools supporting the execution of (external) commands +""" + +import os.path + +from serena.tools import Tool, ToolMarkerCanEdit +from serena.util.shell import execute_shell_command + + +class ExecuteShellCommandTool(Tool, ToolMarkerCanEdit): + """ + Executes a shell command. 
+ """ + + def apply( + self, + command: str, + cwd: str | None = None, + capture_stderr: bool = True, + max_answer_chars: int = -1, + ) -> str: + """ + Execute a shell command and return its output. If there is a memory about suggested commands, read that first. + Never execute unsafe shell commands! + IMPORTANT: Do not use this tool to start + * long-running processes (e.g. servers) that are not intended to terminate quickly, + * processes that require user interaction. + + :param command: the shell command to execute + :param cwd: the working directory to execute the command in. If None, the project root will be used. + :param capture_stderr: whether to capture and return stderr output + :param max_answer_chars: if the output is longer than this number of characters, + no content will be returned. -1 means using the default value, don't adjust unless there is no other way to get the content + required for the task. + :return: a JSON object containing the command's stdout and optionally stderr output + """ + if cwd is None: + _cwd = self.get_project_root() + else: + if os.path.isabs(cwd): + _cwd = cwd + else: + _cwd = os.path.join(self.get_project_root(), cwd) + if not os.path.isdir(_cwd): + raise FileNotFoundError( + f"Specified a relative working directory ({cwd}), but the resulting path is not a directory: {_cwd}" + ) + + result = execute_shell_command(command, cwd=_cwd, capture_stderr=capture_stderr) + result = result.json() + return self._limit_length(result, max_answer_chars) diff --git a/Libraries/serena/tools/config_tools.py b/Libraries/serena/tools/config_tools.py new file mode 100644 index 00000000..2110e33f --- /dev/null +++ b/Libraries/serena/tools/config_tools.py @@ -0,0 +1,83 @@ +import json + +from serena.config.context_mode import SerenaAgentMode +from serena.tools import Tool, ToolMarkerDoesNotRequireActiveProject, ToolMarkerOptional + + +class ActivateProjectTool(Tool, ToolMarkerDoesNotRequireActiveProject): + """ + Activates a project by name. + """ + + def apply(self, project: str) -> str: + """ + Activates the project with the given name. + + :param project: the name of a registered project to activate or a path to a project directory + """ + active_project = self.agent.activate_project_from_path_or_name(project) + if active_project.is_newly_created: + result_str = ( + f"Created and activated a new project with name '{active_project.project_name}' at {active_project.project_root}, language: {active_project.project_config.language.value}. " + "You can activate this project later by name.\n" + f"The project's Serena configuration is in {active_project.path_to_project_yml()}. In particular, you may want to edit the project name and the initial prompt." + ) + else: + result_str = f"Activated existing project with name '{active_project.project_name}' at {active_project.project_root}, language: {active_project.project_config.language.value}" + + if active_project.project_config.initial_prompt: + result_str += f"\nAdditional project information:\n {active_project.project_config.initial_prompt}" + result_str += ( + f"\nAvailable memories:\n {json.dumps(list(self.memories_manager.list_memories()))}" + + "You should not read these memories directly, but rather use the `read_memory` tool to read them later if needed for the task." 
+ ) + result_str += f"\nAvailable tools:\n {json.dumps(self.agent.get_active_tool_names())}" + return result_str + + +class RemoveProjectTool(Tool, ToolMarkerDoesNotRequireActiveProject, ToolMarkerOptional): + """ + Removes a project from the Serena configuration. + """ + + def apply(self, project_name: str) -> str: + """ + Removes a project from the Serena configuration. + + :param project_name: Name of the project to remove + """ + self.agent.serena_config.remove_project(project_name) + return f"Successfully removed project '{project_name}' from configuration." + + +class SwitchModesTool(Tool, ToolMarkerOptional): + """ + Activates modes by providing a list of their names + """ + + def apply(self, modes: list[str]) -> str: + """ + Activates the desired modes, like ["editing", "interactive"] or ["planning", "one-shot"] + + :param modes: the names of the modes to activate + """ + mode_instances = [SerenaAgentMode.load(mode) for mode in modes] + self.agent.set_modes(mode_instances) + + # Inform the Agent about the activated modes and the currently active tools + result_str = f"Successfully activated modes: {', '.join([mode.name for mode in mode_instances])}" + "\n" + result_str += "\n".join([mode_instance.prompt for mode_instance in mode_instances]) + "\n" + result_str += f"Currently active tools: {', '.join(self.agent.get_active_tool_names())}" + return result_str + + +class GetCurrentConfigTool(Tool): + """ + Prints the current configuration of the agent, including the active and available projects, tools, contexts, and modes. + """ + + def apply(self) -> str: + """ + Print the current configuration of the agent, including the active and available projects, tools, contexts, and modes. + """ + return self.agent.get_current_config_overview() diff --git a/Libraries/serena/tools/file_tools.py b/Libraries/serena/tools/file_tools.py new file mode 100644 index 00000000..f3db634e --- /dev/null +++ b/Libraries/serena/tools/file_tools.py @@ -0,0 +1,403 @@ +""" +File and file system-related tools, specifically for + * listing directory contents + * reading files + * creating files + * editing at the file level +""" + +import json +import os +import re +from collections import defaultdict +from fnmatch import fnmatch +from pathlib import Path + +from serena.text_utils import search_files +from serena.tools import SUCCESS_RESULT, EditedFileContext, Tool, ToolMarkerCanEdit, ToolMarkerOptional +from serena.util.file_system import scan_directory + + +class ReadFileTool(Tool): + """ + Reads a file within the project directory. + """ + + def apply(self, relative_path: str, start_line: int = 0, end_line: int | None = None, max_answer_chars: int = -1) -> str: + """ + Reads the given file or a chunk of it. Generally, symbolic operations + like find_symbol or find_referencing_symbols should be preferred if you know which symbols you are looking for. + + :param relative_path: the relative path to the file to read + :param start_line: the 0-based index of the first line to be retrieved. + :param end_line: the 0-based index of the last line to be retrieved (inclusive). If None, read until the end of the file. + :param max_answer_chars: if the file (chunk) is longer than this number of characters, + no content will be returned. Don't adjust unless there is really no other way to get the content + required for the task. 
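+ (Example with a hypothetical file: apply("src/app.py", start_line=10, end_line=19) returns lines 10 through 19 inclusive.)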
+ :return: the full text of the file at the given relative path + """ + self.project.validate_relative_path(relative_path) + + result = self.project.read_file(relative_path) + result_lines = result.splitlines() + if end_line is None: + result_lines = result_lines[start_line:] + else: + result_lines = result_lines[start_line : end_line + 1] + result = "\n".join(result_lines) + + return self._limit_length(result, max_answer_chars) + + +class CreateTextFileTool(Tool, ToolMarkerCanEdit): + """ + Creates/overwrites a file in the project directory. + """ + + def apply(self, relative_path: str, content: str) -> str: + """ + Write a new file or overwrite an existing file. + + :param relative_path: the relative path to the file to create + :param content: the (utf-8-encoded) content to write to the file + :return: a message indicating success or failure + """ + project_root = self.get_project_root() + abs_path = (Path(project_root) / relative_path).resolve() + will_overwrite_existing = abs_path.exists() + + if will_overwrite_existing: + self.project.validate_relative_path(relative_path) + else: + assert abs_path.is_relative_to( + self.get_project_root() + ), f"Cannot create file outside of the project directory, got {relative_path=}" + + abs_path.parent.mkdir(parents=True, exist_ok=True) + abs_path.write_text(content, encoding="utf-8") + answer = f"File created: {relative_path}." + if will_overwrite_existing: + answer += " Overwrote existing file." + return json.dumps(answer) + + +class ListDirTool(Tool): + """ + Lists files and directories in the given directory (optionally with recursion). + """ + + def apply(self, relative_path: str, recursive: bool, skip_ignored_files: bool = False, max_answer_chars: int = -1) -> str: + """ + Lists all non-gitignored files and directories in the given directory (optionally with recursion). + + :param relative_path: the relative path to the directory to list; pass "." to scan the project root + :param recursive: whether to scan subdirectories recursively + :param skip_ignored_files: whether to skip files and directories that are ignored + :param max_answer_chars: if the output is longer than this number of characters, + no content will be returned. -1 means the default value from the config will be used. + Don't adjust unless there is really no other way to get the content required for the task. 
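+ (On success the result has the shape {"dirs": [...], "files": [...]}; if the directory does not exist, an object with + "error", "project_root" and "hint" keys is returned instead.)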
+ :return: a JSON object with the names of directories and files within the given directory + """ + # Check if the directory exists before validation + if not self.project.relative_path_exists(relative_path): + error_info = { + "error": f"Directory not found: {relative_path}", + "project_root": self.get_project_root(), + "hint": "Check if the path is correct relative to the project root", + } + return json.dumps(error_info) + + self.project.validate_relative_path(relative_path) + + dirs, files = scan_directory( + os.path.join(self.get_project_root(), relative_path), + relative_to=self.get_project_root(), + recursive=recursive, + is_ignored_dir=self.project.is_ignored_path if skip_ignored_files else None, + is_ignored_file=self.project.is_ignored_path if skip_ignored_files else None, + ) + + result = json.dumps({"dirs": dirs, "files": files}) + return self._limit_length(result, max_answer_chars) + + +class FindFileTool(Tool): + """ + Finds files in the given relative paths + """ + + def apply(self, file_mask: str, relative_path: str) -> str: + """ + Finds non-gitignored files matching the given file mask within the given relative path + + :param file_mask: the filename or file mask (using the wildcards * or ?) to search for + :param relative_path: the relative path to the directory to search in; pass "." to scan the project root + :return: a JSON object with the list of matching files + """ + self.project.validate_relative_path(relative_path) + + dir_to_scan = os.path.join(self.get_project_root(), relative_path) + + # find the files by ignoring everything that doesn't match + def is_ignored_file(abs_path: str) -> bool: + if self.project.is_ignored_path(abs_path): + return True + filename = os.path.basename(abs_path) + return not fnmatch(filename, file_mask) + + _dirs, files = scan_directory( + path=dir_to_scan, + recursive=True, + is_ignored_dir=self.project.is_ignored_path, + is_ignored_file=is_ignored_file, + relative_to=self.get_project_root(), + ) + + result = json.dumps({"files": files}) + return result + + +class ReplaceRegexTool(Tool, ToolMarkerCanEdit): + """ + Replaces content in a file by using regular expressions. + """ + + def apply( + self, + relative_path: str, + regex: str, + repl: str, + allow_multiple_occurrences: bool = False, + ) -> str: + r""" + Replaces one or more occurrences of the given regular expression. + This is the preferred way to replace content in a file whenever the symbol-level + tools are not appropriate. + Even large sections of code can be replaced by providing a concise regular expression of + the form "beginning.*?end-of-text-to-be-replaced". + Always try to use wildcards to avoid specifying the exact content of the code to be replaced, + especially if it spans several lines. + + IMPORTANT: REMEMBER TO USE WILDCARDS WHEN APPROPRIATE! I WILL BE VERY UNHAPPY IF YOU WRITE UNNECESSARILY LONG REGEXES WITHOUT USING WILDCARDS! + + :param relative_path: the relative path to the file + :param regex: a Python-style regular expression, matches of which will be replaced. + Dot matches all characters, multi-line matching is enabled. + :param repl: the string to replace the matched content with, which may contain + backreferences like \1, \2, etc. + IMPORTANT: Make sure to escape special characters appropriately! + Use "\n" to insert a newline, but use "\\n" to insert the string "\n" within a string literal. + :param allow_multiple_occurrences: if True, the regex may match multiple occurrences in the file + and all of them will be replaced. 
+ If this is set to False and the regex matches multiple occurrences, an error will be returned + (and you may retry with a revised, more specific regex). + """ + self.project.validate_relative_path(relative_path) + with EditedFileContext(relative_path, self.agent) as context: + original_content = context.get_original_content() + updated_content, n = re.subn(regex, repl, original_content, flags=re.DOTALL | re.MULTILINE) + if n == 0: + return f"Error: No matches found for regex '{regex}' in file '{relative_path}'." + if not allow_multiple_occurrences and n > 1: + return ( + f"Error: Regex '{regex}' matches {n} occurrences in file '{relative_path}'. " + "Please revise the regex to be more specific or enable allow_multiple_occurrences if this is expected." + ) + context.set_updated_content(updated_content) + return SUCCESS_RESULT + + +class DeleteLinesTool(Tool, ToolMarkerCanEdit, ToolMarkerOptional): + """ + Deletes a range of lines within a file. + """ + + def apply( + self, + relative_path: str, + start_line: int, + end_line: int, + ) -> str: + """ + Deletes the given lines in the file. + Requires that the same range of lines was previously read using the `read_file` tool to verify correctness + of the operation. + + :param relative_path: the relative path to the file + :param start_line: the 0-based index of the first line to be deleted + :param end_line: the 0-based index of the last line to be deleted + """ + code_editor = self.create_code_editor() + code_editor.delete_lines(relative_path, start_line, end_line) + return SUCCESS_RESULT + + +class ReplaceLinesTool(Tool, ToolMarkerCanEdit, ToolMarkerOptional): + """ + Replaces a range of lines within a file with new content. + """ + + def apply( + self, + relative_path: str, + start_line: int, + end_line: int, + content: str, + ) -> str: + """ + Replaces the given range of lines in the given file. + Requires that the same range of lines was previously read using the `read_file` tool to verify correctness + of the operation. + + :param relative_path: the relative path to the file + :param start_line: the 0-based index of the first line to be deleted + :param end_line: the 0-based index of the last line to be deleted + :param content: the content to insert + """ + if not content.endswith("\n"): + content += "\n" + result = self.agent.get_tool(DeleteLinesTool).apply(relative_path, start_line, end_line) + if result != SUCCESS_RESULT: + return result + self.agent.get_tool(InsertAtLineTool).apply(relative_path, start_line, content) + return SUCCESS_RESULT + + +class InsertAtLineTool(Tool, ToolMarkerCanEdit, ToolMarkerOptional): + """ + Inserts content at a given line in a file. + """ + + def apply( + self, + relative_path: str, + line: int, + content: str, + ) -> str: + """ + Inserts the given content at the given line in the file, pushing existing content of the line down. + In general, symbolic insert operations like insert_after_symbol or insert_before_symbol should be preferred if you know which + symbol you are looking for. + However, this can also be useful for small targeted edits of the body of a longer symbol (without replacing the entire body). 
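+ Example (hypothetical file): inserting "import logging" at line 5 makes the inserted text the new 0-based line 5 and + shifts the previous line 5 and everything below it down by one line.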
+ + :param relative_path: the relative path to the file + :param line: the 0-based index of the line to insert content at + :param content: the content to be inserted + """ + if not content.endswith("\n"): + content += "\n" + code_editor = self.create_code_editor() + code_editor.insert_at_line(relative_path, line, content) + return SUCCESS_RESULT + + +class SearchForPatternTool(Tool): + """ + Performs a search for a pattern in the project. + """ + + def apply( + self, + substring_pattern: str, + context_lines_before: int = 0, + context_lines_after: int = 0, + paths_include_glob: str = "", + paths_exclude_glob: str = "", + relative_path: str = "", + restrict_search_to_code_files: bool = False, + max_answer_chars: int = -1, + ) -> str: + """ + Offers a flexible search for arbitrary patterns in the codebase, including the + possibility to search in non-code files. + Generally, symbolic operations like find_symbol or find_referencing_symbols + should be preferred if you know which symbols you are looking for. + + Pattern Matching Logic: + For each match, the returned result will contain the full lines where the + substring pattern is found, as well as optionally some lines before and after it. The pattern will be compiled with + DOTALL, meaning that the dot will match all characters including newlines. + This also means that it never makes sense to have .* at the beginning or end of the pattern, + but it may make sense to have it in the middle for complex patterns. + If a pattern matches multiple lines, all those lines will be part of the match. + Be careful to not use greedy quantifiers unnecessarily, it is usually better to use non-greedy quantifiers like .*? to avoid + matching too much content. + + File Selection Logic: + The files in which the search is performed can be restricted very flexibly. + Using `restrict_search_to_code_files` is useful if you are only interested in code symbols (i.e., those + symbols that can be manipulated with symbolic tools like find_symbol). + You can also restrict the search to a specific file or directory, + and provide glob patterns to include or exclude certain files on top of that. + The globs are matched against relative file paths from the project root (not to the `relative_path` parameter that + is used to further restrict the search). + Smartly combining the various restrictions allows you to perform very targeted searches. + + + :param substring_pattern: Regular expression for a substring pattern to search for + :param context_lines_before: Number of lines of context to include before each match + :param context_lines_after: Number of lines of context to include after each match + :param paths_include_glob: optional glob pattern specifying files to include in the search. + Matches against relative file paths from the project root (e.g., "*.py", "src/**/*.ts"). + Supports standard glob patterns (*, ?, [seq], **, etc.) and brace expansion {a,b,c}. + Only matches files, not directories. If left empty, all non-ignored files will be included. + :param paths_exclude_glob: optional glob pattern specifying files to exclude from the search. + Matches against relative file paths from the project root (e.g., "*test*", "**/*_generated.py"). + Supports standard glob patterns (*, ?, [seq], **, etc.) and brace expansion {a,b,c}. + Takes precedence over paths_include_glob. Only matches files, not directories. If left empty, no files are excluded. + :param relative_path: only subpaths of this path (relative to the repo root) will be analyzed. 
If a path to a single + file is passed, only that will be searched. The path must exist, otherwise a `FileNotFoundError` is raised. + :param max_answer_chars: if the output is longer than this number of characters, + no content will be returned. + -1 means the default value from the config will be used. + Don't adjust unless there is really no other way to get the content + required for the task. Instead, if the output is too long, you should + make a stricter query. + :param restrict_search_to_code_files: whether to restrict the search to only those files where + analyzed code symbols can be found. Otherwise, will search all non-ignored files. + Set this to True if your search is only meant to discover code that can be manipulated with symbolic tools. + For example, for finding classes or methods from a name pattern. + Setting to False is a better choice if you also want to search in non-code files, like in html or yaml files, + which is why it is the default. + :return: A mapping of file paths to lists of matched consecutive lines. + """ + abs_path = os.path.join(self.get_project_root(), relative_path) + if not os.path.exists(abs_path): + raise FileNotFoundError(f"Relative path {relative_path} does not exist.") + + if restrict_search_to_code_files: + matches = self.project.search_source_files_for_pattern( + pattern=substring_pattern, + relative_path=relative_path, + context_lines_before=context_lines_before, + context_lines_after=context_lines_after, + paths_include_glob=paths_include_glob.strip(), + paths_exclude_glob=paths_exclude_glob.strip(), + ) + else: + if os.path.isfile(abs_path): + rel_paths_to_search = [relative_path] + else: + _dirs, rel_paths_to_search = scan_directory( + path=abs_path, + recursive=True, + is_ignored_dir=self.project.is_ignored_path, + is_ignored_file=self.project.is_ignored_path, + relative_to=self.get_project_root(), + ) + # TODO (maybe): not super efficient to walk through the files again and filter if glob patterns are provided + # but it probably never matters and this version required no further refactoring + matches = search_files( + rel_paths_to_search, + substring_pattern, + root_path=self.get_project_root(), + paths_include_glob=paths_include_glob, + paths_exclude_glob=paths_exclude_glob, + ) + # group matches by file + file_to_matches: dict[str, list[str]] = defaultdict(list) + for match in matches: + assert match.source_file_path is not None + file_to_matches[match.source_file_path].append(match.to_display_string()) + result = json.dumps(file_to_matches) + return self._limit_length(result, max_answer_chars) diff --git a/Libraries/serena/tools/jetbrains_plugin_client.py b/Libraries/serena/tools/jetbrains_plugin_client.py new file mode 100644 index 00000000..5a845a6e --- /dev/null +++ b/Libraries/serena/tools/jetbrains_plugin_client.py @@ -0,0 +1,189 @@ +""" +Client for the Serena JetBrains Plugin +""" + +import json +import logging +from pathlib import Path +from typing import Any, Optional, Self, TypeVar + +import requests +from requests import Response +from sensai.util.string import ToStringMixin + +from serena.project import Project + +T = TypeVar("T") +log = logging.getLogger(__name__) + + +class SerenaClientError(Exception): + """Base exception for Serena client errors.""" + + +class ConnectionError(SerenaClientError): + """Raised when connection to the service fails.""" + + +class APIError(SerenaClientError): + """Raised when the API returns an error response.""" + + +class ServerNotFoundError(Exception): + """Raised when the plugin's 
service is not found.""" + + +class JetBrainsPluginClient(ToStringMixin): + """ + Python client for the Serena Backend Service. + + Provides simple methods to interact with all available endpoints. + """ + + BASE_PORT = 0x5EA2 + last_port: int | None = None + + def __init__(self, port: int, timeout: int = 30): + self.base_url = f"http://127.0.0.1:{port}" + self.timeout = timeout + self.session = requests.Session() + self.session.headers.update({"Content-Type": "application/json", "Accept": "application/json"}) + + def _tostring_includes(self) -> list[str]: + return ["base_url", "timeout"] + + @classmethod + def from_project(cls, project: Project) -> Self: + resolved_path = Path(project.project_root).resolve() + + if cls.last_port is not None: + client = JetBrainsPluginClient(cls.last_port) + if client.matches(resolved_path): + return client + + for port in range(cls.BASE_PORT, cls.BASE_PORT + 20): + client = JetBrainsPluginClient(port) + if client.matches(resolved_path): + log.info("Found JetBrains IDE service at port %d for project %s", port, resolved_path) + cls.last_port = port + return client + + raise ServerNotFoundError("Found no Serena service in a JetBrains IDE instance for the project at " + str(resolved_path)) + + def matches(self, resolved_path: Path) -> bool: + try: + return Path(self.project_root()).resolve() == resolved_path + except ConnectionError: + return False + + def _make_request(self, method: str, endpoint: str, data: Optional[dict] = None) -> dict[str, Any]: + url = f"{self.base_url}{endpoint}" + + response: Response | None = None + try: + if method.upper() == "GET": + response = self.session.get(url, timeout=self.timeout) + elif method.upper() == "POST": + json_data = json.dumps(data) if data else None + response = self.session.post(url, data=json_data, timeout=self.timeout) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + response.raise_for_status() + + # Try to parse JSON response + try: + return self._pythonify_response(response.json()) + except json.JSONDecodeError: + # If response is not JSON, return raw text + return {"response": response.text} + + except requests.exceptions.ConnectionError as e: + raise ConnectionError(f"Failed to connect to Serena service at {url}: {e}") + except requests.exceptions.Timeout as e: + raise ConnectionError(f"Request to {url} timed out: {e}") + except requests.exceptions.HTTPError as e: + if response is not None: + raise APIError(f"API request failed with status {response.status_code}: {response.text}") + raise APIError(f"API request failed with HTTP error: {e}") + except requests.exceptions.RequestException as e: + raise SerenaClientError(f"Request failed: {e}") + + @staticmethod + def _pythonify_response(response: T) -> T: + """ + Converts dictionary keys from camelCase to snake_case recursively. 
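+ For example, {"projectRoot": "...", "textRange": {...}} becomes {"project_root": "...", "text_range": {...}}.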
+ + :response: the response in which to convert keys (dictionary or list) + """ + to_snake_case = lambda s: "".join(["_" + c.lower() if c.isupper() else c for c in s]) + + def convert(x): # type: ignore + if isinstance(x, dict): + return {to_snake_case(k): convert(v) for k, v in x.items()} + elif isinstance(x, list): + return [convert(item) for item in x] + else: + return x + + return convert(response) + + def project_root(self) -> str: + response = self._make_request("GET", "/status") + return response["project_root"] + + def find_symbol( + self, name_path: str, relative_path: str | None = None, include_body: bool = False, depth: int = 0, include_location: bool = False + ) -> dict[str, Any]: + """ + Find symbols by name. + + :param name_path: the name path to match + :param relative_path: the relative path to which to restrict the search + :param include_body: whether to include symbol body content + :param depth: depth of children to include (0 = no children) + + :return: Dictionary containing 'symbols' list with matching symbols + """ + request_data = { + "namePath": name_path, + "relativePath": relative_path, + "includeBody": include_body, + "depth": depth, + "includeLocation": include_location, + } + return self._make_request("POST", "/findSymbol", request_data) + + def find_references(self, name_path: str, relative_path: str) -> dict[str, Any]: + """ + Find references to a symbol. + + :param name_path: the name path of the symbol + :param relative_path: the relative path + :return: dictionary containing 'symbols' list with symbol references + """ + request_data = {"namePath": name_path, "relativePath": relative_path} + return self._make_request("POST", "/findReferences", request_data) + + def get_symbols_overview(self, relative_path: str) -> dict[str, Any]: + """ + :param relative_path: the relative path to a source file + """ + request_data = {"relativePath": relative_path} + return self._make_request("POST", "/getSymbolsOverview", request_data) + + def is_service_available(self) -> bool: + try: + self.project_root() + return True + except (ConnectionError, APIError): + return False + + def close(self) -> None: + self.session.close() + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore + self.close() diff --git a/Libraries/serena/tools/jetbrains_tools.py b/Libraries/serena/tools/jetbrains_tools.py new file mode 100644 index 00000000..db3a97c7 --- /dev/null +++ b/Libraries/serena/tools/jetbrains_tools.py @@ -0,0 +1,132 @@ +import json + +from serena.tools import Tool, ToolMarkerOptional, ToolMarkerSymbolicRead +from serena.tools.jetbrains_plugin_client import JetBrainsPluginClient + + +class JetBrainsFindSymbolTool(Tool, ToolMarkerSymbolicRead, ToolMarkerOptional): + """ + Performs a global (or local) search for symbols with/containing a given name/substring (optionally filtered by type). + """ + + def apply( + self, + name_path: str, + depth: int = 0, + relative_path: str | None = None, + include_body: bool = False, + max_answer_chars: int = -1, + ) -> str: + """ + Retrieves information on all symbols/code entities (classes, methods, etc.) based on the given `name_path`, + which represents a pattern for the symbol's path within the symbol tree of a single file. + The returned symbol location can be used for edits or further queries. + Specify `depth > 0` to retrieve children (e.g., methods of a class). + + The matching behavior is determined by the structure of `name_path`, which can + either be a simple name (e.g. 
"method") or a name path like "class/method" (relative name path) + or "/class/method" (absolute name path). + Note that the name path is not a path in the file system but rather a path in the symbol tree + **within a single file**. Thus, file or directory names should never be included in the `name_path`. + For restricting the search to a single file or directory, pass the `relative_path` parameter. + The retrieved symbols' `name_path` attribute will always be composed of symbol names, never file + or directory names. + + Key aspects of the name path matching behavior: + - The name of the retrieved symbols will match the last segment of `name_path`, while preceding segments + will restrict the search to symbols that have a desired sequence of ancestors. + - If there is no `/` in `name_path`, there is no restriction on the ancestor symbols. + For example, passing `method` will match against all symbols with name paths like `method`, + `class/method`, `class/nested_class/method`, etc. + - If `name_path` contains at least one `/`, the matching is restricted to symbols + with the respective ancestors. For example, passing `class/method` will match against + `class/method` as well as `nested_class/class/method` but not `other_class/method`. + - If `name_path` starts with a `/`, it will be treated as an absolute name path pattern, i.e. + all ancestors are provided and must match. + For example, passing `/class` will match only against top-level symbols named `class` but + will not match `nested_class/class`. Passing `/class/method` will match `class/method` but + not `outer_class/class/method`. + + :param name_path: The name path pattern to search for, see above for details. + :param depth: Depth to retrieve descendants (e.g., 1 for class methods/attributes). + :param relative_path: Optional. Restrict search to this file or directory. + If None, searches entire codebase. + If a directory is passed, the search will be restricted to the files in that directory. + If a file is passed, the search will be restricted to that file. + If you have some knowledge about the codebase, you should use this parameter, as it will significantly + speed up the search as well as reduce the number of results. + :param include_body: If True, include the symbol's source code. Use judiciously. + :param max_answer_chars: max characters for the JSON result. If exceeded, no content is returned. + -1 means the default value from the config will be used. + :return: JSON string: a list of symbols (with locations) matching the name. + """ + with JetBrainsPluginClient.from_project(self.project) as client: + response_dict = client.find_symbol( + name_path=name_path, + relative_path=relative_path, + depth=depth, + include_body=include_body, + ) + result = json.dumps(response_dict) + return self._limit_length(result, max_answer_chars) + + +class JetBrainsFindReferencingSymbolsTool(Tool, ToolMarkerSymbolicRead, ToolMarkerOptional): + """ + Finds symbols that reference the given symbol + """ + + def apply( + self, + name_path: str, + relative_path: str, + max_answer_chars: int = -1, + ) -> str: + """ + Finds symbols that reference the symbol at the given `name_path`. + The result will contain metadata about the referencing symbols. + + :param name_path: name path of the symbol for which to find references; matching logic as described in find symbol tool. + :param relative_path: the relative path to the file containing the symbol for which to find references. + Note that here you can't pass a directory but must pass a file. 
+ :param max_answer_chars: max characters for the JSON result. If exceeded, no content is returned. -1 means the + default value from the config will be used. + :return: a list of JSON objects with the symbols referencing the requested symbol + """ + with JetBrainsPluginClient.from_project(self.project) as client: + response_dict = client.find_references( + name_path=name_path, + relative_path=relative_path, + ) + result = json.dumps(response_dict) + return self._limit_length(result, max_answer_chars) + + +class JetBrainsGetSymbolsOverviewTool(Tool, ToolMarkerSymbolicRead, ToolMarkerOptional): + """ + Retrieves an overview of the top-level symbols within a specified file + """ + + def apply( + self, + relative_path: str, + max_answer_chars: int = -1, + ) -> str: + """ + Gets an overview of the top-level symbols in the given file. + Calling this is often a good idea before more targeted reading, searching or editing operations on the code symbols. + Before requesting a symbol overview, it is usually a good idea to narrow down the scope of the overview + by first understanding the basic directory structure of the repository that you can get from memories + or by using the `list_dir` and `find_file` tools (or similar). + + :param relative_path: the relative path to the file to get the overview of + :param max_answer_chars: max characters for the JSON result. If exceeded, no content is returned. + -1 means the default value from the config will be used. + :return: a JSON object containing the symbols + """ + with JetBrainsPluginClient.from_project(self.project) as client: + response_dict = client.get_symbols_overview( + relative_path=relative_path, + ) + result = json.dumps(response_dict) + return self._limit_length(result, max_answer_chars) diff --git a/Libraries/serena/tools/memory_tools.py b/Libraries/serena/tools/memory_tools.py new file mode 100644 index 00000000..f52ae3f4 --- /dev/null +++ b/Libraries/serena/tools/memory_tools.py @@ -0,0 +1,64 @@ +import json + +from serena.tools import Tool + + +class WriteMemoryTool(Tool): + """ + Writes a named memory (for future reference) to Serena's project-specific memory store. + """ + + def apply(self, memory_name: str, content: str, max_answer_chars: int = -1) -> str: + """ + Write some information about this project that can be useful for future tasks to a memory in md format. + The memory name should be meaningful. + """ + if max_answer_chars == -1: + max_answer_chars = self.agent.serena_config.default_max_tool_answer_chars + if len(content) > max_answer_chars: + raise ValueError( + f"Content for {memory_name} is too long. Max length is {max_answer_chars} characters. " + "Please make the content shorter." + ) + + return self.memories_manager.save_memory(memory_name, content) + + +class ReadMemoryTool(Tool): + """ + Reads the memory with the given name from Serena's project-specific memory store. + """ + + def apply(self, memory_file_name: str, max_answer_chars: int = -1) -> str: + """ + Read the content of a memory file. This tool should only be used if the information + is relevant to the current task. You can infer whether the information + is relevant from the memory file name. + You should not read the same memory file multiple times in the same conversation. + """ + return self.memories_manager.load_memory(memory_file_name) + + +class ListMemoriesTool(Tool): + """ + Lists memories in Serena's project-specific memory store. + """ + + def apply(self) -> str: + """ + List available memories. 
Any memory can be read using the `read_memory` tool. + """ + return json.dumps(self.memories_manager.list_memories()) + + +class DeleteMemoryTool(Tool): + """ + Deletes a memory from Serena's project-specific memory store. + """ + + def apply(self, memory_file_name: str) -> str: + """ + Delete a memory file. Should only happen if a user asks for it explicitly, + for example by saying that the information retrieved from a memory file is no longer correct + or no longer relevant for the project. + """ + return self.memories_manager.delete_memory(memory_file_name) diff --git a/Libraries/serena/tools/symbol_tools.py b/Libraries/serena/tools/symbol_tools.py new file mode 100644 index 00000000..14099ef7 --- /dev/null +++ b/Libraries/serena/tools/symbol_tools.py @@ -0,0 +1,295 @@ +""" +Language server-related tools +""" + +import dataclasses +import json +import os +from collections.abc import Sequence +from copy import copy +from typing import Any + +from serena.tools import ( + SUCCESS_RESULT, + Tool, + ToolMarkerSymbolicEdit, + ToolMarkerSymbolicRead, +) +from serena.tools.tools_base import ToolMarkerOptional +from solidlsp.ls_types import SymbolKind + + +def _sanitize_symbol_dict(symbol_dict: dict[str, Any]) -> dict[str, Any]: + """ + Sanitize a symbol dictionary inplace by removing unnecessary information. + """ + # We replace the location entry, which repeats line information already included in body_location + # and has unnecessary information on column, by just the relative path. + symbol_dict = copy(symbol_dict) + s_relative_path = symbol_dict.get("location", {}).get("relative_path") + if s_relative_path is not None: + symbol_dict["relative_path"] = s_relative_path + symbol_dict.pop("location", None) + # also remove name, name_path should be enough + symbol_dict.pop("name") + return symbol_dict + + +class RestartLanguageServerTool(Tool, ToolMarkerOptional): + """Restarts the language server, may be necessary when edits not through Serena happen.""" + + def apply(self) -> str: + """Use this tool only on explicit user request or after confirmation. + It may be necessary to restart the language server if it hangs. + """ + self.agent.reset_language_server() + return SUCCESS_RESULT + + +class GetSymbolsOverviewTool(Tool, ToolMarkerSymbolicRead): + """ + Gets an overview of the top-level symbols defined in a given file. + """ + + def apply(self, relative_path: str, max_answer_chars: int = -1) -> str: + """ + Use this tool to get a high-level understanding of the code symbols in a file. + This should be the first tool to call when you want to understand a new file, unless you already know + what you are looking for. + + :param relative_path: the relative path to the file to get the overview of + :param max_answer_chars: if the overview is longer than this number of characters, + no content will be returned. -1 means the default value from the config will be used. + Don't adjust unless there is really no other way to get the content required for the task. + :return: a JSON object containing info about top-level symbols in the file + """ + symbol_retriever = self.create_language_server_symbol_retriever() + file_path = os.path.join(self.project.project_root, relative_path) + + # The symbol overview is capable of working with both files and directories, + # but we want to ensure that the user provides a file path. 
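+        # (Added note: the two checks below reject missing paths and directory paths up front,
+        # before the language server is queried for the overview.)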
+ if not os.path.exists(file_path): + raise FileNotFoundError(f"File or directory {relative_path} does not exist in the project.") + if os.path.isdir(file_path): + raise ValueError(f"Expected a file path, but got a directory path: {relative_path}. ") + result = symbol_retriever.get_symbol_overview(relative_path)[relative_path] + result_json_str = json.dumps([dataclasses.asdict(i) for i in result]) + return self._limit_length(result_json_str, max_answer_chars) + + +class FindSymbolTool(Tool, ToolMarkerSymbolicRead): + """ + Performs a global (or local) search for symbols with/containing a given name/substring (optionally filtered by type). + """ + + def apply( + self, + name_path: str, + depth: int = 0, + relative_path: str = "", + include_body: bool = False, + include_kinds: list[int] = [], # noqa: B006 + exclude_kinds: list[int] = [], # noqa: B006 + substring_matching: bool = False, + max_answer_chars: int = -1, + ) -> str: + """ + Retrieves information on all symbols/code entities (classes, methods, etc.) based on the given `name_path`, + which represents a pattern for the symbol's path within the symbol tree of a single file. + The returned symbol location can be used for edits or further queries. + Specify `depth > 0` to retrieve children (e.g., methods of a class). + + The matching behavior is determined by the structure of `name_path`, which can + either be a simple name (e.g. "method") or a name path like "class/method" (relative name path) + or "/class/method" (absolute name path). Note that the name path is not a path in the file system + but rather a path in the symbol tree **within a single file**. Thus, file or directory names should never + be included in the `name_path`. For restricting the search to a single file or directory, + the `within_relative_path` parameter should be used instead. The retrieved symbols' `name_path` attribute + will always be composed of symbol names, never file or directory names. + + Key aspects of the name path matching behavior: + - Trailing slashes in `name_path` play no role and are ignored. + - The name of the retrieved symbols will match (either exactly or as a substring) + the last segment of `name_path`, while other segments will restrict the search to symbols that + have a desired sequence of ancestors. + - If there is no starting or intermediate slash in `name_path`, there is no + restriction on the ancestor symbols. For example, passing `method` will match + against symbols with name paths like `method`, `class/method`, `class/nested_class/method`, etc. + - If `name_path` contains a `/` but doesn't start with a `/`, the matching is restricted to symbols + with the same ancestors as the last segment of `name_path`. For example, passing `class/method` will match against + `class/method` as well as `nested_class/class/method` but not `method`. + - If `name_path` starts with a `/`, it will be treated as an absolute name path pattern, meaning + that the first segment of it must match the first segment of the symbol's name path. + For example, passing `/class` will match only against top-level symbols like `class` but not against `nested_class/class`. + Passing `/class/method` will match against `class/method` but not `nested_class/class/method` or `method`. + + + :param name_path: The name path pattern to search for, see above for details. + :param depth: Depth to retrieve descendants (e.g., 1 for class methods/attributes). + :param relative_path: Optional. Restrict search to this file or directory. If None, searches entire codebase. 
+ If a directory is passed, the search will be restricted to the files in that directory. + If a file is passed, the search will be restricted to that file. + If you have some knowledge about the codebase, you should use this parameter, as it will significantly + speed up the search as well as reduce the number of results. + :param include_body: If True, include the symbol's source code. Use judiciously. + :param include_kinds: Optional. List of LSP symbol kind integers to include. (e.g., 5 for Class, 12 for Function). + Valid kinds: 1=file, 2=module, 3=namespace, 4=package, 5=class, 6=method, 7=property, 8=field, 9=constructor, 10=enum, + 11=interface, 12=function, 13=variable, 14=constant, 15=string, 16=number, 17=boolean, 18=array, 19=object, + 20=key, 21=null, 22=enum member, 23=struct, 24=event, 25=operator, 26=type parameter. + If not provided, all kinds are included. + :param exclude_kinds: Optional. List of LSP symbol kind integers to exclude. Takes precedence over `include_kinds`. + If not provided, no kinds are excluded. + :param substring_matching: If True, use substring matching for the last segment of `name`. + :param max_answer_chars: Max characters for the JSON result. If exceeded, no content is returned. + -1 means the default value from the config will be used. + :return: a list of symbols (with locations) matching the name. + """ + parsed_include_kinds: Sequence[SymbolKind] | None = [SymbolKind(k) for k in include_kinds] if include_kinds else None + parsed_exclude_kinds: Sequence[SymbolKind] | None = [SymbolKind(k) for k in exclude_kinds] if exclude_kinds else None + symbol_retriever = self.create_language_server_symbol_retriever() + symbols = symbol_retriever.find_by_name( + name_path, + include_body=include_body, + include_kinds=parsed_include_kinds, + exclude_kinds=parsed_exclude_kinds, + substring_matching=substring_matching, + within_relative_path=relative_path, + ) + symbol_dicts = [_sanitize_symbol_dict(s.to_dict(kind=True, location=True, depth=depth, include_body=include_body)) for s in symbols] + result = json.dumps(symbol_dicts) + return self._limit_length(result, max_answer_chars) + + +class FindReferencingSymbolsTool(Tool, ToolMarkerSymbolicRead): + """ + Finds symbols that reference the symbol at the given location (optionally filtered by type). + """ + + def apply( + self, + name_path: str, + relative_path: str, + include_kinds: list[int] = [], # noqa: B006 + exclude_kinds: list[int] = [], # noqa: B006 + max_answer_chars: int = -1, + ) -> str: + """ + Finds references to the symbol at the given `name_path`. The result will contain metadata about the referencing symbols + as well as a short code snippet around the reference. + + :param name_path: for finding the symbol to find references for, same logic as in the `find_symbol` tool. + :param relative_path: the relative path to the file containing the symbol for which to find references. + Note that here you can't pass a directory but must pass a file. + :param include_kinds: same as in the `find_symbol` tool. + :param exclude_kinds: same as in the `find_symbol` tool. + :param max_answer_chars: same as in the `find_symbol` tool. 
+ :return: a list of JSON objects with the symbols referencing the requested symbol + """ + include_body = False # It is probably never a good idea to include the body of the referencing symbols + parsed_include_kinds: Sequence[SymbolKind] | None = [SymbolKind(k) for k in include_kinds] if include_kinds else None + parsed_exclude_kinds: Sequence[SymbolKind] | None = [SymbolKind(k) for k in exclude_kinds] if exclude_kinds else None + symbol_retriever = self.create_language_server_symbol_retriever() + references_in_symbols = symbol_retriever.find_referencing_symbols( + name_path, + relative_file_path=relative_path, + include_body=include_body, + include_kinds=parsed_include_kinds, + exclude_kinds=parsed_exclude_kinds, + ) + reference_dicts = [] + for ref in references_in_symbols: + ref_dict = ref.symbol.to_dict(kind=True, location=True, depth=0, include_body=include_body) + ref_dict = _sanitize_symbol_dict(ref_dict) + if not include_body: + ref_relative_path = ref.symbol.location.relative_path + assert ref_relative_path is not None, f"Referencing symbol {ref.symbol.name} has no relative path, this is likely a bug." + content_around_ref = self.project.retrieve_content_around_line( + relative_file_path=ref_relative_path, line=ref.line, context_lines_before=1, context_lines_after=1 + ) + ref_dict["content_around_reference"] = content_around_ref.to_display_string() + reference_dicts.append(ref_dict) + result = json.dumps(reference_dicts) + return self._limit_length(result, max_answer_chars) + + +class ReplaceSymbolBodyTool(Tool, ToolMarkerSymbolicEdit): + """ + Replaces the full definition of a symbol. + """ + + def apply( + self, + name_path: str, + relative_path: str, + body: str, + ) -> str: + r""" + Replaces the body of the symbol with the given `name_path`. + + The tool shall be used to replace symbol bodies that have been previously retrieved + (e.g. via `find_symbol`). + IMPORTANT: Do not use this tool if you do not know what exactly constitutes the body of the symbol. + + :param name_path: for finding the symbol to replace, same logic as in the `find_symbol` tool. + :param relative_path: the relative path to the file containing the symbol + :param body: the new symbol body. The symbol body is the definition of a symbol + in the programming language, including e.g. the signature line for functions. + IMPORTANT: The body does NOT include any preceding docstrings/comments or imports, in particular. + """ + code_editor = self.create_code_editor() + code_editor.replace_body( + name_path, + relative_file_path=relative_path, + body=body, + ) + return SUCCESS_RESULT + + +class InsertAfterSymbolTool(Tool, ToolMarkerSymbolicEdit): + """ + Inserts content after the end of the definition of a given symbol. + """ + + def apply( + self, + name_path: str, + relative_path: str, + body: str, + ) -> str: + """ + Inserts the given body/content after the end of the definition of the given symbol (via the symbol's location). + A typical use case is to insert a new class, function, method, field or variable assignment. + + :param name_path: name path of the symbol after which to insert content (definitions in the `find_symbol` tool apply) + :param relative_path: the relative path to the file containing the symbol + :param body: the body/content to be inserted. The inserted code shall begin with the next line after + the symbol. 
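+            Illustrative example (hypothetical names): after locating `MyClass/existing_method` via
+            the find-symbol tool, `body` could be the complete source of a new method to insert below it.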
+ """ + code_editor = self.create_code_editor() + code_editor.insert_after_symbol(name_path, relative_file_path=relative_path, body=body) + return SUCCESS_RESULT + + +class InsertBeforeSymbolTool(Tool, ToolMarkerSymbolicEdit): + """ + Inserts content before the beginning of the definition of a given symbol. + """ + + def apply( + self, + name_path: str, + relative_path: str, + body: str, + ) -> str: + """ + Inserts the given content before the beginning of the definition of the given symbol (via the symbol's location). + A typical use case is to insert a new class, function, method, field or variable assignment; or + a new import statement before the first symbol in the file. + + :param name_path: name path of the symbol before which to insert content (definitions in the `find_symbol` tool apply) + :param relative_path: the relative path to the file containing the symbol + :param body: the body/content to be inserted before the line in which the referenced symbol is defined + """ + code_editor = self.create_code_editor() + code_editor.insert_before_symbol(name_path, relative_file_path=relative_path, body=body) + return SUCCESS_RESULT diff --git a/Libraries/serena/tools/tools_base.py b/Libraries/serena/tools/tools_base.py new file mode 100644 index 00000000..3f2d2c25 --- /dev/null +++ b/Libraries/serena/tools/tools_base.py @@ -0,0 +1,413 @@ +import inspect +import os +from abc import ABC +from collections.abc import Iterable +from dataclasses import dataclass +from types import TracebackType +from typing import TYPE_CHECKING, Any, Protocol, Self, TypeVar + +from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata +from sensai.util import logging +from sensai.util.string import dict_string + +from serena.project import Project +from serena.prompt_factory import PromptFactory +from serena.symbol import LanguageServerSymbolRetriever +from serena.util.class_decorators import singleton +from serena.util.inspection import iter_subclasses +from solidlsp.ls_exceptions import SolidLSPException + +if TYPE_CHECKING: + from serena.agent import MemoriesManager, SerenaAgent + from serena.code_editor import CodeEditor + +log = logging.getLogger(__name__) +T = TypeVar("T") +SUCCESS_RESULT = "OK" + + +class Component(ABC): + def __init__(self, agent: "SerenaAgent"): + self.agent = agent + + def get_project_root(self) -> str: + """ + :return: the root directory of the active project, raises a ValueError if no active project configuration is set + """ + return self.agent.get_project_root() + + @property + def prompt_factory(self) -> PromptFactory: + return self.agent.prompt_factory + + @property + def memories_manager(self) -> "MemoriesManager": + assert self.agent.memories_manager is not None + return self.agent.memories_manager + + def create_language_server_symbol_retriever(self) -> LanguageServerSymbolRetriever: + if not self.agent.is_using_language_server(): + raise Exception("Cannot create LanguageServerSymbolRetriever; agent is not in language server mode.") + language_server = self.agent.language_server + assert language_server is not None + return LanguageServerSymbolRetriever(language_server, agent=self.agent) + + @property + def project(self) -> Project: + return self.agent.get_active_project_or_raise() + + def create_code_editor(self) -> "CodeEditor": + from ..code_editor import JetBrainsCodeEditor, LanguageServerCodeEditor + + if self.agent.is_using_language_server(): + return LanguageServerCodeEditor(self.create_language_server_symbol_retriever(), agent=self.agent) + 
else: + return JetBrainsCodeEditor(project=self.project, agent=self.agent) + + +class ToolMarker: + """ + Base class for tool markers. + """ + + +class ToolMarkerCanEdit(ToolMarker): + """ + Marker class for all tools that can perform editing operations on files. + """ + + +class ToolMarkerDoesNotRequireActiveProject(ToolMarker): + pass + + +class ToolMarkerOptional(ToolMarker): + """ + Marker class for optional tools that are disabled by default. + """ + + +class ToolMarkerSymbolicRead(ToolMarker): + """ + Marker class for tools that perform symbol read operations. + """ + + +class ToolMarkerSymbolicEdit(ToolMarkerCanEdit): + """ + Marker class for tools that perform symbolic edit operations. + """ + + +class ApplyMethodProtocol(Protocol): + """Callable protocol for the apply method of a tool.""" + + def __call__(self, *args: Any, **kwargs: Any) -> str: + pass + + +class Tool(Component): + # NOTE: each tool should implement the apply method, which is then used in + # the central method of the Tool class `apply_ex`. + # Failure to do so will result in a RuntimeError at tool execution time. + # The apply method is not declared as part of the base Tool interface since we cannot + # know the signature of the (input parameters of the) method in advance. + # + # The docstring and types of the apply method are used to generate the tool description + # (which is use by the LLM, so a good description is important) + # and to validate the tool call arguments. + + @classmethod + def get_name_from_cls(cls) -> str: + name = cls.__name__ + if name.endswith("Tool"): + name = name[:-4] + # convert to snake_case + name = "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_") + return name + + def get_name(self) -> str: + return self.get_name_from_cls() + + def get_apply_fn(self) -> ApplyMethodProtocol: + apply_fn = getattr(self, "apply") + if apply_fn is None: + raise RuntimeError(f"apply not defined in {self}. Did you forget to implement it?") + return apply_fn + + @classmethod + def can_edit(cls) -> bool: + """ + Returns whether this tool can perform editing operations on code. + + :return: True if the tool can edit code, False otherwise + """ + return issubclass(cls, ToolMarkerCanEdit) + + @classmethod + def get_tool_description(cls) -> str: + docstring = cls.__doc__ + if docstring is None: + return "" + return docstring.strip() + + @classmethod + def get_apply_docstring_from_cls(cls) -> str: + """Get the docstring for the apply method from the class (static metadata). + Needed for creating MCP tools in a separate process without running into serialization issues. + """ + # First try to get from __dict__ to handle dynamic docstring changes + if "apply" in cls.__dict__: + apply_fn = cls.__dict__["apply"] + else: + # Fall back to getattr for inherited methods + apply_fn = getattr(cls, "apply", None) + if apply_fn is None: + raise AttributeError(f"apply method not defined in {cls}. Did you forget to implement it?") + + docstring = apply_fn.__doc__ + if not docstring: + raise AttributeError(f"apply method has no (or empty) docstring in {cls}. 
Did you forget to implement it?") + return docstring.strip() + + def get_apply_docstring(self) -> str: + """Gets the docstring for the tool application, used by the MCP server.""" + return self.get_apply_docstring_from_cls() + + def get_apply_fn_metadata(self) -> FuncMetadata: + """Gets the metadata for the tool application function, used by the MCP server.""" + return self.get_apply_fn_metadata_from_cls() + + @classmethod + def get_apply_fn_metadata_from_cls(cls) -> FuncMetadata: + """Get the metadata for the apply method from the class (static metadata). + Needed for creating MCP tools in a separate process without running into serialization issues. + """ + # First try to get from __dict__ to handle dynamic docstring changes + if "apply" in cls.__dict__: + apply_fn = cls.__dict__["apply"] + else: + # Fall back to getattr for inherited methods + apply_fn = getattr(cls, "apply", None) + if apply_fn is None: + raise AttributeError(f"apply method not defined in {cls}. Did you forget to implement it?") + + return func_metadata(apply_fn, skip_names=["self", "cls"]) + + def _log_tool_application(self, frame: Any) -> None: + params = {} + ignored_params = {"self", "log_call", "catch_exceptions", "args", "apply_fn"} + for param, value in frame.f_locals.items(): + if param in ignored_params: + continue + if param == "kwargs": + params.update(value) + else: + params[param] = value + log.info(f"{self.get_name_from_cls()}: {dict_string(params)}") + + def _limit_length(self, result: str, max_answer_chars: int) -> str: + if max_answer_chars == -1: + max_answer_chars = self.agent.serena_config.default_max_tool_answer_chars + if max_answer_chars <= 0: + raise ValueError(f"Must be positive or the default (-1), got: {max_answer_chars=}") + if (n_chars := len(result)) > max_answer_chars: + result = ( + f"The answer is too long ({n_chars} characters). " + + "Please try a more specific tool query or raise the max_answer_chars parameter." + ) + return result + + def is_active(self) -> bool: + return self.agent.tool_is_active(self.__class__) + + def apply_ex(self, log_call: bool = True, catch_exceptions: bool = True, **kwargs) -> str: # type: ignore + """ + Applies the tool with logging and exception handling, using the given keyword arguments + """ + + def task() -> str: + apply_fn = self.get_apply_fn() + + try: + if not self.is_active(): + return f"Error: Tool '{self.get_name_from_cls()}' is not active. Active tools: {self.agent.get_active_tool_names()}" + except Exception as e: + return f"RuntimeError while checking if tool {self.get_name_from_cls()} is active: {e}" + + if log_call: + self._log_tool_application(inspect.currentframe()) + try: + # check whether the tool requires an active project and language server + if not isinstance(self, ToolMarkerDoesNotRequireActiveProject): + if self.agent._active_project is None: + return ( + "Error: No active project. Ask the user to provide the project path or to select a project from this list of known projects: " + + f"{self.agent.serena_config.project_names}" + ) + if self.agent.is_using_language_server() and not self.agent.is_language_server_running(): + log.info("Language server is not running. Starting it ...") + self.agent.reset_language_server() + + # apply the actual tool + try: + result = apply_fn(**kwargs) + except SolidLSPException as e: + if e.is_language_server_terminated(): + log.error(f"Language server terminated while executing tool ({e}). 
Restarting the language server and retrying ...") + self.agent.reset_language_server() + result = apply_fn(**kwargs) + else: + raise + + # record tool usage + self.agent.record_tool_usage_if_enabled(kwargs, result, self) + + except Exception as e: + if not catch_exceptions: + raise + msg = f"Error executing tool: {e}" + log.error(f"Error executing tool: {e}", exc_info=e) + result = msg + + if log_call: + log.info(f"Result: {result}") + + try: + if self.agent.language_server is not None: + self.agent.language_server.save_cache() + except Exception as e: + log.error(f"Error saving language server cache: {e}") + + return result + + future = self.agent.issue_task(task, name=self.__class__.__name__) + return future.result(timeout=self.agent.serena_config.tool_timeout) + + +class EditedFileContext: + """ + Context manager for file editing. + + Create the context, then use `set_updated_content` to set the new content, the original content + being provided in `original_content`. + When exiting the context without an exception, the updated content will be written back to the file. + """ + + def __init__(self, relative_path: str, agent: "SerenaAgent"): + self._project = agent.get_active_project() + assert self._project is not None + self._abs_path = os.path.join(self._project.project_root, relative_path) + if not os.path.isfile(self._abs_path): + raise FileNotFoundError(f"File {self._abs_path} does not exist.") + with open(self._abs_path, encoding=self._project.project_config.encoding) as f: + self._original_content = f.read() + self._updated_content: str | None = None + + def __enter__(self) -> Self: + return self + + def get_original_content(self) -> str: + """ + :return: the original content of the file before any modifications. + """ + return self._original_content + + def set_updated_content(self, content: str) -> None: + """ + Sets the updated content of the file, which will be written back to the file + when the context is exited without an exception. + + :param content: the updated content of the file + """ + self._updated_content = content + + def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None) -> None: + if self._updated_content is not None and exc_type is None: + assert self._project is not None + with open(self._abs_path, "w", encoding=self._project.project_config.encoding) as f: + f.write(self._updated_content) + log.info(f"Updated content written to {self._abs_path}") + # Language servers should automatically detect the change and update its state accordingly. + # If they do not, we may have to add a call to notify it. + + +@dataclass(kw_only=True) +class RegisteredTool: + tool_class: type[Tool] + is_optional: bool + tool_name: str + + +@singleton +class ToolRegistry: + def __init__(self) -> None: + self._tool_dict: dict[str, RegisteredTool] = {} + for cls in iter_subclasses(Tool): + if not cls.__module__.startswith("serena.tools"): + continue + is_optional = issubclass(cls, ToolMarkerOptional) + name = cls.get_name_from_cls() + if name in self._tool_dict: + raise ValueError(f"Duplicate tool name found: {name}. 
Tool classes must have unique names.") + self._tool_dict[name] = RegisteredTool(tool_class=cls, is_optional=is_optional, tool_name=name) + + def get_tool_class_by_name(self, tool_name: str) -> type[Tool]: + return self._tool_dict[tool_name].tool_class + + def get_all_tool_classes(self) -> list[type[Tool]]: + return list(t.tool_class for t in self._tool_dict.values()) + + def get_tool_classes_default_enabled(self) -> list[type[Tool]]: + """ + :return: the list of tool classes that are enabled by default (i.e. non-optional tools). + """ + return [t.tool_class for t in self._tool_dict.values() if not t.is_optional] + + def get_tool_classes_optional(self) -> list[type[Tool]]: + """ + :return: the list of tool classes that are optional (i.e. disabled by default). + """ + return [t.tool_class for t in self._tool_dict.values() if t.is_optional] + + def get_tool_names_default_enabled(self) -> list[str]: + """ + :return: the list of tool names that are enabled by default (i.e. non-optional tools). + """ + return [t.tool_name for t in self._tool_dict.values() if not t.is_optional] + + def get_tool_names_optional(self) -> list[str]: + """ + :return: the list of tool names that are optional (i.e. disabled by default). + """ + return [t.tool_name for t in self._tool_dict.values() if t.is_optional] + + def get_tool_names(self) -> list[str]: + """ + :return: the list of all tool names. + """ + return list(self._tool_dict.keys()) + + def print_tool_overview( + self, tools: Iterable[type[Tool] | Tool] | None = None, include_optional: bool = False, only_optional: bool = False + ) -> None: + """ + Print a summary of the tools. If no tools are passed, a summary of the selection of tools (all, default or only optional) is printed. + """ + if tools is None: + if only_optional: + tools = self.get_tool_classes_optional() + elif include_optional: + tools = self.get_all_tool_classes() + else: + tools = self.get_tool_classes_default_enabled() + + tool_dict: dict[str, type[Tool] | Tool] = {} + for tool_class in tools: + tool_dict[tool_class.get_name_from_cls()] = tool_class + for tool_name in sorted(tool_dict.keys()): + tool_class = tool_dict[tool_name] + print(f" * `{tool_name}`: {tool_class.get_tool_description().strip()}") + + def is_valid_tool_name(self, tool_name: str) -> bool: + return tool_name in self._tool_dict diff --git a/Libraries/serena/tools/workflow_tools.py b/Libraries/serena/tools/workflow_tools.py new file mode 100644 index 00000000..be78d03d --- /dev/null +++ b/Libraries/serena/tools/workflow_tools.py @@ -0,0 +1,139 @@ +""" +Tools supporting the general workflow of the agent +""" + +import json +import platform + +from serena.tools import Tool, ToolMarkerDoesNotRequireActiveProject, ToolMarkerOptional + + +class CheckOnboardingPerformedTool(Tool): + """ + Checks whether project onboarding was already performed. + """ + + def apply(self) -> str: + """ + Checks whether project onboarding was already performed. + You should always call this tool before beginning to actually work on the project/after activating a project, + but after calling the initial instructions tool. + """ + from .memory_tools import ListMemoriesTool + + list_memories_tool = self.agent.get_tool(ListMemoriesTool) + memories = json.loads(list_memories_tool.apply()) + if len(memories) == 0: + return ( + "Onboarding not performed yet (no memories available). " + + "You should perform onboarding by calling the `onboarding` tool before proceeding with the task." 
+ ) + else: + return f"""The onboarding was already performed, below is the list of available memories. + Do not read them immediately, just remember that they exist and that you can read them later, if it is necessary + for the current task. + Some memories may be based on previous conversations, others may be general for the current project. + You should be able to tell which one you need based on the name of the memory. + + {memories}""" + + +class OnboardingTool(Tool): + """ + Performs onboarding (identifying the project structure and essential tasks, e.g. for testing or building). + """ + + def apply(self) -> str: + """ + Call this tool if onboarding was not performed yet. + You will call this tool at most once per conversation. + + :return: instructions on how to create the onboarding information + """ + system = platform.system() + return self.prompt_factory.create_onboarding_prompt(system=system) + + +class ThinkAboutCollectedInformationTool(Tool): + """ + Thinking tool for pondering the completeness of collected information. + """ + + def apply(self) -> str: + """ + Think about the collected information and whether it is sufficient and relevant. + This tool should ALWAYS be called after you have completed a non-trivial sequence of searching steps like + find_symbol, find_referencing_symbols, search_files_for_pattern, read_file, etc. + """ + return self.prompt_factory.create_think_about_collected_information() + + +class ThinkAboutTaskAdherenceTool(Tool): + """ + Thinking tool for determining whether the agent is still on track with the current task. + """ + + def apply(self) -> str: + """ + Think about the task at hand and whether you are still on track. + Especially important if the conversation has been going on for a while and there + has been a lot of back and forth. + + This tool should ALWAYS be called before you insert, replace, or delete code. + """ + return self.prompt_factory.create_think_about_task_adherence() + + +class ThinkAboutWhetherYouAreDoneTool(Tool): + """ + Thinking tool for determining whether the task is truly completed. + """ + + def apply(self) -> str: + """ + Whenever you feel that you are done with what the user has asked for, it is important to call this tool. + """ + return self.prompt_factory.create_think_about_whether_you_are_done() + + +class SummarizeChangesTool(Tool, ToolMarkerOptional): + """ + Provides instructions for summarizing the changes made to the codebase. + """ + + def apply(self) -> str: + """ + Summarize the changes you have made to the codebase. + This tool should always be called after you have fully completed any non-trivial coding task, + but only after the think_about_whether_you_are_done call. + """ + return self.prompt_factory.create_summarize_changes() + + +class PrepareForNewConversationTool(Tool): + """ + Provides instructions for preparing for a new conversation (in order to continue with the necessary context). + """ + + def apply(self) -> str: + """ + Instructions for preparing for a new conversation. This tool should only be called on explicit user request. + """ + return self.prompt_factory.create_prepare_for_new_conversation() + + +class InitialInstructionsTool(Tool, ToolMarkerDoesNotRequireActiveProject, ToolMarkerOptional): + """ + Gets the initial instructions for the current project. + Should only be used in settings where the system prompt cannot be set, + e.g. in clients you have no control over, like Claude Desktop. 
+ """ + + def apply(self) -> str: + """ + Get the initial instructions for the current coding project. + If you haven't received instructions on how to use Serena's tools in the system prompt, + you should always call this tool before starting to work (including using any other tool) on any programming task, + the only exception being when you are asked to call `activate_project`, which you should then call before. + """ + return self.agent.create_system_prompt() diff --git a/Libraries/serena/util/class_decorators.py b/Libraries/serena/util/class_decorators.py new file mode 100644 index 00000000..6beadbde --- /dev/null +++ b/Libraries/serena/util/class_decorators.py @@ -0,0 +1,15 @@ +from typing import Any + + +# duplicate of interprompt.class_decorators +# We don't want to depend on interprompt for this in serena, so we duplicate it here +def singleton(cls: type[Any]) -> Any: + instance = None + + def get_instance(*args: Any, **kwargs: Any) -> Any: + nonlocal instance + if instance is None: + instance = cls(*args, **kwargs) + return instance + + return get_instance diff --git a/Libraries/serena/util/exception.py b/Libraries/serena/util/exception.py new file mode 100644 index 00000000..f9cbb562 --- /dev/null +++ b/Libraries/serena/util/exception.py @@ -0,0 +1,65 @@ +import os +import sys + +from serena.agent import log + + +def is_headless_environment() -> bool: + """ + Detect if we're running in a headless environment where GUI operations would fail. + + Returns True if: + - No DISPLAY variable on Linux/Unix + - Running in SSH session + - Running in WSL without X server + - Running in Docker container + """ + # Check if we're on Windows - GUI usually works there + if sys.platform == "win32": + return False + + # Check for DISPLAY variable (required for X11) + if not os.environ.get("DISPLAY"): # type: ignore + return True + + # Check for SSH session + if os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT"): + return True + + # Check for common CI/container environments + if os.environ.get("CI") or os.environ.get("CONTAINER") or os.path.exists("/.dockerenv"): + return True + + # Check for WSL (only on Unix-like systems where os.uname exists) + if hasattr(os, "uname"): + if "microsoft" in os.uname().release.lower(): + # In WSL, even with DISPLAY set, X server might not be running + # This is a simplified check - could be improved + return True + + return False + + +def show_fatal_exception_safe(e: Exception) -> None: + """ + Shows the given exception in the GUI log viewer on the main thread and ensures that the exception is logged or at + least printed to stderr. + """ + # Log the error and print it to stderr + log.error(f"Fatal exception: {e}", exc_info=e) + print(f"Fatal exception: {e}", file=sys.stderr) + + # Don't attempt GUI in headless environments + if is_headless_environment(): + log.debug("Skipping GUI error display in headless environment") + return + + # attempt to show the error in the GUI + try: + # NOTE: The import can fail on macOS if Tk is not available (depends on Python interpreter installation, which uv + # used as a base); while tkinter as such is always available, its dependencies can be unavailable on macOS. 
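+        # (Added note: if this import or the dialog itself fails, the handler below only logs the
+        # failure at debug level; the original exception has already been logged above.)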
+ from serena.gui_log_viewer import show_fatal_exception + + show_fatal_exception(e) + except Exception as gui_error: + log.debug(f"Failed to show GUI error dialog: {gui_error}") diff --git a/Libraries/serena/util/file_system.py b/Libraries/serena/util/file_system.py new file mode 100644 index 00000000..784fc2cf --- /dev/null +++ b/Libraries/serena/util/file_system.py @@ -0,0 +1,351 @@ +import logging +import os +from collections.abc import Callable, Iterator +from dataclasses import dataclass, field +from pathlib import Path +from typing import NamedTuple + +import pathspec +from pathspec import PathSpec +from sensai.util.logging import LogTime + +log = logging.getLogger(__name__) + + +class ScanResult(NamedTuple): + """Result of scanning a directory.""" + + directories: list[str] + files: list[str] + + +def scan_directory( + path: str, + recursive: bool = False, + relative_to: str | None = None, + is_ignored_dir: Callable[[str], bool] | None = None, + is_ignored_file: Callable[[str], bool] | None = None, +) -> ScanResult: + """ + :param path: the path to scan + :param recursive: whether to recursively scan subdirectories + :param relative_to: the path to which the results should be relative to; if None, provide absolute paths + :param is_ignored_dir: a function with which to determine whether the given directory (abs. path) shall be ignored + :param is_ignored_file: a function with which to determine whether the given file (abs. path) shall be ignored + :return: the list of directories and files + """ + if is_ignored_file is None: + is_ignored_file = lambda x: False + if is_ignored_dir is None: + is_ignored_dir = lambda x: False + + files = [] + directories = [] + + abs_path = os.path.abspath(path) + rel_base = os.path.abspath(relative_to) if relative_to else None + + try: + with os.scandir(abs_path) as entries: + for entry in entries: + try: + entry_path = entry.path + + if rel_base: + result_path = os.path.relpath(entry_path, rel_base) + else: + result_path = entry_path + + if entry.is_file(): + if not is_ignored_file(entry_path): + files.append(result_path) + elif entry.is_dir(): + if not is_ignored_dir(entry_path): + directories.append(result_path) + if recursive: + sub_result = scan_directory( + entry_path, + recursive=True, + relative_to=relative_to, + is_ignored_dir=is_ignored_dir, + is_ignored_file=is_ignored_file, + ) + files.extend(sub_result.files) + directories.extend(sub_result.directories) + except PermissionError as ex: + # Skip files/directories that cannot be accessed due to permission issues + log.debug(f"Skipping entry due to permission error: {entry.path}", exc_info=ex) + continue + except PermissionError as ex: + # Skip the entire directory if it cannot be accessed + log.debug(f"Skipping directory due to permission error: {abs_path}", exc_info=ex) + return ScanResult([], []) + + return ScanResult(directories, files) + + +def find_all_non_ignored_files(repo_root: str) -> list[str]: + """ + Find all non-ignored files in the repository, respecting all gitignore files in the repository. 
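+    Illustrative usage (paths are hypothetical)::
+
+        files = find_all_non_ignored_files("/path/to/repo")
+        # e.g. ["/path/to/repo/src/app.py", ...]; the returned paths are absolute because
+        # scan_directory is called without relative_to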
+ + :param repo_root: The root directory of the repository + :return: A list of all non-ignored files in the repository + """ + gitignore_parser = GitignoreParser(repo_root) + _, files = scan_directory( + repo_root, recursive=True, is_ignored_dir=gitignore_parser.should_ignore, is_ignored_file=gitignore_parser.should_ignore + ) + return files + + +@dataclass +class GitignoreSpec: + file_path: str + """Path to the gitignore file.""" + patterns: list[str] = field(default_factory=list) + """List of patterns from the gitignore file. + The patterns are adjusted based on the gitignore file location. + """ + pathspec: PathSpec = field(init=False) + """Compiled PathSpec object for pattern matching.""" + + def __post_init__(self) -> None: + """Initialize the PathSpec from patterns.""" + self.pathspec = PathSpec.from_lines(pathspec.patterns.GitWildMatchPattern, self.patterns) + + def matches(self, relative_path: str) -> bool: + """ + Check if the given path matches any pattern in this gitignore spec. + + :param relative_path: Path to check (should be relative to repo root) + :return: True if path matches any pattern + """ + return match_path(relative_path, self.pathspec, root_path=os.path.dirname(self.file_path)) + + +class GitignoreParser: + """ + Parser for gitignore files in a repository. + + This class handles parsing multiple gitignore files throughout a repository + and provides methods to check if paths should be ignored. + """ + + def __init__(self, repo_root: str) -> None: + """ + Initialize the parser for a repository. + + :param repo_root: Root directory of the repository + """ + self.repo_root = os.path.abspath(repo_root) + self.ignore_specs: list[GitignoreSpec] = [] + self._load_gitignore_files() + + def _load_gitignore_files(self) -> None: + """Load all gitignore files from the repository.""" + with LogTime("Loading of .gitignore files", logger=log): + for gitignore_path in self._iter_gitignore_files(): + log.info("Processing .gitignore file: %s", gitignore_path) + spec = self._create_ignore_spec(gitignore_path) + if spec.patterns: # Only add non-empty specs + self.ignore_specs.append(spec) + + def _iter_gitignore_files(self, follow_symlinks: bool = False) -> Iterator[str]: + """ + Iteratively discover .gitignore files in a top-down fashion, starting from the repository root. + Directory paths are skipped if they match any already loaded ignore patterns. + + :return: an iterator yielding paths to .gitignore files (top-down) + """ + queue: list[str] = [self.repo_root] + + def scan(abs_path: str | None) -> Iterator[str]: + for entry in os.scandir(abs_path): + if entry.is_dir(follow_symlinks=follow_symlinks): + queue.append(entry.path) + elif entry.is_file(follow_symlinks=follow_symlinks) and entry.name == ".gitignore": + yield entry.path + + while queue: + next_abs_path = queue.pop(0) + if next_abs_path != self.repo_root: + rel_path = os.path.relpath(next_abs_path, self.repo_root) + if self.should_ignore(rel_path): + continue + yield from scan(next_abs_path) + + def _create_ignore_spec(self, gitignore_file_path: str) -> GitignoreSpec: + """ + Create a GitignoreSpec from a single gitignore file. 
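+        If the file cannot be read (I/O or decoding error), an empty spec that matches nothing is returned.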
+ + :param gitignore_file_path: Path to the .gitignore file + :return: GitignoreSpec object for the gitignore patterns + """ + try: + with open(gitignore_file_path, encoding="utf-8") as f: + content = f.read() + except (OSError, UnicodeDecodeError): + # If we can't read the file, return an empty spec + return GitignoreSpec(gitignore_file_path, []) + + gitignore_dir = os.path.dirname(gitignore_file_path) + patterns = self._parse_gitignore_content(content, gitignore_dir) + + return GitignoreSpec(gitignore_file_path, patterns) + + def _parse_gitignore_content(self, content: str, gitignore_dir: str) -> list[str]: + """ + Parse gitignore content and adjust patterns based on the gitignore file location. + + :param content: Content of the .gitignore file + :param gitignore_dir: Directory containing the .gitignore file (absolute path) + :return: List of adjusted patterns + """ + patterns = [] + + # Get the relative path from repo root to the gitignore directory + rel_dir = os.path.relpath(gitignore_dir, self.repo_root) + if rel_dir == ".": + rel_dir = "" + + for line in content.splitlines(): + # Strip trailing whitespace (but preserve leading whitespace for now) + line = line.rstrip() + + # Skip empty lines and comments + if not line or line.lstrip().startswith("#"): + continue + + # Store whether this is a negation pattern + is_negation = line.startswith("!") + if is_negation: + line = line[1:] + + # Strip leading/trailing whitespace after removing negation + line = line.strip() + + if not line: + continue + + # Handle escaped characters at the beginning + if line.startswith(("\\#", "\\!")): + line = line[1:] + + # Determine if pattern is anchored to the gitignore directory and remove leading slash for processing + is_anchored = line.startswith("/") + if is_anchored: + line = line[1:] + + # Adjust pattern based on gitignore file location + if rel_dir: + if is_anchored: + # Anchored patterns are relative to the gitignore directory + adjusted_pattern = os.path.join(rel_dir, line) + else: + # Non-anchored patterns can match anywhere below the gitignore directory + # We need to preserve this behavior + if line.startswith("**/"): + # Even if pattern starts with **, it should still be scoped to the subdirectory + adjusted_pattern = os.path.join(rel_dir, line) + else: + # Add the directory prefix but also allow matching in subdirectories + adjusted_pattern = os.path.join(rel_dir, "**", line) + else: + if is_anchored: + # Anchored patterns in root should only match at root level + # Add leading slash back to indicate root-only matching + adjusted_pattern = "/" + line + else: + # Non-anchored patterns can match anywhere + adjusted_pattern = line + + # Re-add negation if needed + if is_negation: + adjusted_pattern = "!" + adjusted_pattern + + # Normalize path separators to forward slashes (gitignore uses forward slashes) + adjusted_pattern = adjusted_pattern.replace(os.sep, "/") + + patterns.append(adjusted_pattern) + + return patterns + + def should_ignore(self, path: str) -> bool: + """ + Check if a path should be ignored based on the gitignore rules. 
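+        For example (hypothetical layout), "/repo/build/out.txt" and "build/out.txt" are checked the
+        same way, and any path lying outside the repository root is always treated as ignored.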
+ + :param path: Path to check (absolute or relative to repo_root) + :return: True if the path should be ignored, False otherwise + """ + # Convert to relative path from repo root + if os.path.isabs(path): + try: + rel_path = os.path.relpath(path, self.repo_root) + except Exception as e: + # If the path could not be converted to a relative path, + # it is outside the repository root, so we ignore it + log.info("Ignoring path '%s' which is outside of the repository root (%s)", path, e) + return True + else: + rel_path = path + + # Ignore paths inside .git + rel_path_first_path = Path(rel_path).parts[0] + if rel_path_first_path == ".git": + return True + + abs_path = os.path.join(self.repo_root, rel_path) + + # Normalize path separators + rel_path = rel_path.replace(os.sep, "/") + + if os.path.exists(abs_path) and os.path.isdir(abs_path) and not rel_path.endswith("/"): + rel_path = rel_path + "/" + + # Check against each ignore spec + for spec in self.ignore_specs: + if spec.matches(rel_path): + return True + + return False + + def get_ignore_specs(self) -> list[GitignoreSpec]: + """ + Get all loaded gitignore specs. + + :return: List of GitignoreSpec objects + """ + return self.ignore_specs + + def reload(self) -> None: + """Reload all gitignore files from the repository.""" + self.ignore_specs.clear() + self._load_gitignore_files() + + +def match_path(relative_path: str, path_spec: PathSpec, root_path: str = "") -> bool: + """ + Match a relative path against a given pathspec. Just pathspec.match_file() is not enough, + we need to do some massaging to fix issues with pathspec matching. + + :param relative_path: relative path to match against the pathspec + :param path_spec: the pathspec to match against + :param root_path: the root path from which the relative path is derived + :return: + """ + normalized_path = str(relative_path).replace(os.path.sep, "/") + + # We can have patterns like /src/..., which would only match corresponding paths from the repo root + # Unfortunately, pathspec can't know whether a relative path is relative to the repo root or not, + # so it will never match src/... + # The fix is to just always assume that the input path is relative to the repo root and to + # prefix it with /. + if not normalized_path.startswith("/"): + normalized_path = "/" + normalized_path + + # pathspec can't handle the matching of directories if they don't end with a slash! + # see https://github.com/cpburnz/python-pathspec/issues/89 + abs_path = os.path.abspath(os.path.join(root_path, relative_path)) + if os.path.isdir(abs_path) and not normalized_path.endswith("/"): + normalized_path = normalized_path + "/" + return path_spec.match_file(normalized_path) diff --git a/Libraries/serena/util/general.py b/Libraries/serena/util/general.py new file mode 100644 index 00000000..b350ae18 --- /dev/null +++ b/Libraries/serena/util/general.py @@ -0,0 +1,32 @@ +import os +from typing import Literal, overload + +from ruamel.yaml import YAML +from ruamel.yaml.comments import CommentedMap + + +def _create_YAML(preserve_comments: bool = False) -> YAML: + """ + Creates a YAML that can load/save with comments if preserve_comments is True. + """ + typ = None if preserve_comments else "safe" + result = YAML(typ=typ) + result.preserve_quotes = preserve_comments + return result + + +@overload +def load_yaml(path: str, preserve_comments: Literal[False]) -> dict: ... +@overload +def load_yaml(path: str, preserve_comments: Literal[True]) -> CommentedMap: ... 
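+# (Added note: the overloads above only refine the static return type; preserve_comments=True yields a
+# round-trippable CommentedMap, preserve_comments=False a plain dict.)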
+def load_yaml(path: str, preserve_comments: bool = False) -> dict | CommentedMap: + with open(path, encoding="utf-8") as f: + yaml = _create_YAML(preserve_comments) + return yaml.load(f) + + +def save_yaml(path: str, data: dict | CommentedMap, preserve_comments: bool = False) -> None: + yaml = _create_YAML(preserve_comments) + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + yaml.dump(data, f) diff --git a/Libraries/serena/util/git.py b/Libraries/serena/util/git.py new file mode 100644 index 00000000..865ae59e --- /dev/null +++ b/Libraries/serena/util/git.py @@ -0,0 +1,20 @@ +import logging + +from sensai.util.git import GitStatus + +from .shell import subprocess_check_output + +log = logging.getLogger(__name__) + + +def get_git_status() -> GitStatus | None: + try: + commit_hash = subprocess_check_output(["git", "rev-parse", "HEAD"]) + unstaged = bool(subprocess_check_output(["git", "diff", "--name-only"])) + staged = bool(subprocess_check_output(["git", "diff", "--staged", "--name-only"])) + untracked = bool(subprocess_check_output(["git", "ls-files", "--others", "--exclude-standard"])) + return GitStatus( + commit=commit_hash, has_unstaged_changes=unstaged, has_staged_uncommitted_changes=staged, has_untracked_files=untracked + ) + except: + return None diff --git a/Libraries/serena/util/inspection.py b/Libraries/serena/util/inspection.py new file mode 100644 index 00000000..668cf545 --- /dev/null +++ b/Libraries/serena/util/inspection.py @@ -0,0 +1,58 @@ +import logging +import os +from collections.abc import Generator +from typing import TypeVar + +from serena.util.file_system import find_all_non_ignored_files +from solidlsp.ls_config import Language + +T = TypeVar("T") + +log = logging.getLogger(__name__) + + +def iter_subclasses(cls: type[T], recursive: bool = True) -> Generator[type[T], None, None]: + """Iterate over all subclasses of a class. If recursive is True, also iterate over all subclasses of all subclasses.""" + for subclass in cls.__subclasses__(): + yield subclass + if recursive: + yield from iter_subclasses(subclass, recursive) + + +def determine_programming_language_composition(repo_path: str) -> dict[str, float]: + """ + Determine the programming language composition of a repository. 
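+    Illustrative example (numbers and keys are hypothetical): a repository with 8 Python files and
+    2 TypeScript files would yield roughly {"python": 80.0, "typescript": 20.0}; the exact keys are
+    whatever str(Language) produces for each detected language.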
+ + :param repo_path: Path to the repository to analyze + + :return: Dictionary mapping language names to percentages of files matching each language + """ + all_files = find_all_non_ignored_files(repo_path) + + if not all_files: + return {} + + # Count files for each language + language_counts: dict[str, int] = {} + total_files = len(all_files) + + for language in Language.iter_all(include_experimental=False): + matcher = language.get_source_fn_matcher() + count = 0 + + for file_path in all_files: + # Use just the filename for matching, not the full path + filename = os.path.basename(file_path) + if matcher.is_relevant_filename(filename): + count += 1 + + if count > 0: + language_counts[str(language)] = count + + # Convert counts to percentages + language_percentages: dict[str, float] = {} + for language_name, count in language_counts.items(): + percentage = (count / total_files) * 100 + language_percentages[language_name] = round(percentage, 2) + + return language_percentages diff --git a/Libraries/serena/util/logging.py b/Libraries/serena/util/logging.py new file mode 100644 index 00000000..75af996b --- /dev/null +++ b/Libraries/serena/util/logging.py @@ -0,0 +1,67 @@ +import queue +import threading +from collections.abc import Callable + +from sensai.util import logging + +from serena.constants import SERENA_LOG_FORMAT + + +class MemoryLogHandler(logging.Handler): + def __init__(self, level: int = logging.NOTSET) -> None: + super().__init__(level=level) + self.setFormatter(logging.Formatter(SERENA_LOG_FORMAT)) + self._log_buffer = LogBuffer() + self._log_queue: queue.Queue[str] = queue.Queue() + self._stop_event = threading.Event() + self._emit_callbacks: list[Callable[[str], None]] = [] + + # start background thread to process logs + self.worker_thread = threading.Thread(target=self._process_queue, daemon=True) + self.worker_thread.start() + + def add_emit_callback(self, callback: Callable[[str], None]) -> None: + """ + Adds a callback that will be called with each log message. + The callback should accept a single string argument (the log message). + """ + self._emit_callbacks.append(callback) + + def emit(self, record: logging.LogRecord) -> None: + msg = self.format(record) + self._log_queue.put_nowait(msg) + + def _process_queue(self) -> None: + while not self._stop_event.is_set(): + try: + msg = self._log_queue.get(timeout=1) + self._log_buffer.append(msg) + for callback in self._emit_callbacks: + try: + callback(msg) + except: + pass + self._log_queue.task_done() + except queue.Empty: + continue + + def get_log_messages(self) -> list[str]: + return self._log_buffer.get_log_messages() + + +class LogBuffer: + """ + A thread-safe buffer for storing log messages. 
+ """ + + def __init__(self) -> None: + self._log_messages: list[str] = [] + self._lock = threading.Lock() + + def append(self, msg: str) -> None: + with self._lock: + self._log_messages.append(msg) + + def get_log_messages(self) -> list[str]: + with self._lock: + return self._log_messages.copy() diff --git a/Libraries/serena/util/shell.py b/Libraries/serena/util/shell.py new file mode 100644 index 00000000..3b310852 --- /dev/null +++ b/Libraries/serena/util/shell.py @@ -0,0 +1,49 @@ +import os +import subprocess + +from pydantic import BaseModel + +from solidlsp.util.subprocess_util import subprocess_kwargs + + +class ShellCommandResult(BaseModel): + stdout: str + return_code: int + cwd: str + stderr: str | None = None + + +def execute_shell_command(command: str, cwd: str | None = None, capture_stderr: bool = False) -> ShellCommandResult: + """ + Execute a shell command and return the output. + + :param command: The command to execute. + :param cwd: The working directory to execute the command in. If None, the current working directory will be used. + :param capture_stderr: Whether to capture the stderr output. + :return: The output of the command. + """ + if cwd is None: + cwd = os.getcwd() + + process = subprocess.Popen( + command, + shell=True, + stdin=subprocess.DEVNULL, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE if capture_stderr else None, + text=True, + encoding="utf-8", + errors="replace", + cwd=cwd, + **subprocess_kwargs(), + ) + + stdout, stderr = process.communicate() + return ShellCommandResult(stdout=stdout, stderr=stderr, return_code=process.returncode, cwd=cwd) + + +def subprocess_check_output(args: list[str], encoding: str = "utf-8", strip: bool = True, timeout: float | None = None) -> str: + output = subprocess.check_output(args, stdin=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=timeout, env=os.environ.copy(), **subprocess_kwargs()).decode(encoding) # type: ignore + if strip: + output = output.strip() + return output diff --git a/Libraries/serena/util/thread.py b/Libraries/serena/util/thread.py new file mode 100644 index 00000000..7b779187 --- /dev/null +++ b/Libraries/serena/util/thread.py @@ -0,0 +1,69 @@ +import threading +from collections.abc import Callable +from enum import Enum +from typing import Generic, TypeVar + +from sensai.util.string import ToStringMixin + + +class TimeoutException(Exception): + def __init__(self, message: str, timeout: float) -> None: + super().__init__(message) + self.timeout = timeout + + +T = TypeVar("T") + + +class ExecutionResult(Generic[T], ToStringMixin): + + class Status(Enum): + SUCCESS = "success" + TIMEOUT = "timeout" + EXCEPTION = "error" + + def __init__(self) -> None: + self.result_value: T | None = None + self.status: ExecutionResult.Status | None = None + self.exception: Exception | None = None + + def set_result_value(self, value: T) -> None: + self.result_value = value + self.status = ExecutionResult.Status.SUCCESS + + def set_timed_out(self, exception: TimeoutException) -> None: + self.exception = exception + self.status = ExecutionResult.Status.TIMEOUT + + def set_exception(self, exception: Exception) -> None: + self.exception = exception + self.status = ExecutionResult.Status.EXCEPTION + + +def execute_with_timeout(func: Callable[[], T], timeout: float, function_name: str) -> ExecutionResult[T]: + """ + Executes the given function with a timeout + + :param func: the function to execute + :param timeout: the timeout in seconds + :param function_name: the name of the function (for error messages) + 
:returns: the execution result + """ + execution_result: ExecutionResult[T] = ExecutionResult() + + def target() -> None: + try: + value = func() + execution_result.set_result_value(value) + except Exception as e: + execution_result.set_exception(e) + + thread = threading.Thread(target=target, daemon=True) + thread.start() + thread.join(timeout=timeout) + + if thread.is_alive(): + timeout_exception = TimeoutException(f"Execution of '{function_name}' timed out after {timeout} seconds.", timeout) + execution_result.set_timed_out(timeout_exception) + + return execution_result diff --git a/sync_libraries.py b/sync_libraries.py new file mode 100644 index 00000000..48f1a813 --- /dev/null +++ b/sync_libraries.py @@ -0,0 +1,371 @@ +#!/usr/bin/env python3 +""" +Dynamic Library Sync System for Analyzer + +This script synchronizes external libraries (autogenlib, serena, graph-sitter) +into the Libraries folder, ensuring they stay up-to-date with source changes. + +Usage: + python sync_libraries.py # Sync all libraries + python sync_libraries.py --library autogenlib # Sync specific library + python sync_libraries.py --check # Check for changes without syncing +""" + +import argparse +import hashlib +import json +import logging +import os +import shutil +import subprocess +import sys +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Configuration +REPO_ROOT = Path(__file__).parent +LIBRARIES_DIR = REPO_ROOT / "Libraries" +TEMP_DIR = REPO_ROOT / ".lib_sync_temp" +SYNC_STATE_FILE = LIBRARIES_DIR / ".sync_state.json" + +LIBRARY_CONFIGS = { + "autogenlib": { + "repo_url": "https://github.com/Zeeeepa/autogenlib.git", + "source_path": "autogenlib", # Path within the repo + "target_path": LIBRARIES_DIR / "autogenlib", + "include_patterns": ["*.py", "*.pyi", "*.typed"], + "exclude_patterns": ["*test*", "*__pycache__*", "*.pyc"], + }, + "serena": { + "repo_url": "https://github.com/Zeeeepa/serena.git", + "source_path": "src/serena", # Main serena package + "target_path": LIBRARIES_DIR / "serena", + "include_patterns": ["*.py", "*.pyi", "*.typed"], + "exclude_patterns": ["*test*", "*__pycache__*", "*.pyc"], + }, + "graph_sitter": { + "repo_url": "https://github.com/Zeeeepa/graph-sitter.git", + "source_path": "src", # graph_sitter package is in src/ + "target_path": LIBRARIES_DIR / "graph_sitter_lib", + "include_patterns": ["*.py", "*.pyi", "*.typed"], + "exclude_patterns": ["*test*", "*__pycache__*", "*.pyc"], + }, +} + + +class LibrarySync: + """Manages synchronization of external libraries.""" + + def __init__(self, library_name: str, config: Dict): + self.library_name = library_name + self.config = config + self.repo_url = config["repo_url"] + self.source_path = config["source_path"] + self.target_path = Path(config["target_path"]) + self.include_patterns = config.get("include_patterns", ["*"]) + self.exclude_patterns = config.get("exclude_patterns", []) + + def clone_or_pull(self) -> Path: + """Clone or update the library repository.""" + temp_repo_path = TEMP_DIR / self.library_name + + if temp_repo_path.exists(): + logger.info(f"Updating {self.library_name} repository...") + try: + subprocess.run( + ["git", "pull", "--ff-only"], + cwd=temp_repo_path, + check=True, + capture_output=True, + text=True + ) + except subprocess.CalledProcessError as e: + logger.warning(f"Git pull failed, 
re-cloning: {e}") + shutil.rmtree(temp_repo_path) + return self.clone_or_pull() + else: + logger.info(f"Cloning {self.library_name} repository...") + TEMP_DIR.mkdir(parents=True, exist_ok=True) + subprocess.run( + ["git", "clone", "--depth", "1", self.repo_url, str(temp_repo_path)], + check=True, + capture_output=True, + text=True + ) + + return temp_repo_path + + def calculate_directory_hash(self, directory: Path) -> str: + """Calculate hash of directory contents for change detection.""" + hash_md5 = hashlib.md5() + + if not directory.exists(): + return "" + + for file_path in sorted(directory.rglob("*.py")): + if file_path.is_file() and not any(excl in str(file_path) for excl in self.exclude_patterns): + try: + with open(file_path, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + except Exception as e: + logger.warning(f"Error hashing {file_path}: {e}") + + return hash_md5.hexdigest() + + def should_sync(self, source_dir: Path) -> bool: + """Check if sync is needed based on content changes.""" + source_hash = self.calculate_directory_hash(source_dir) + target_hash = self.calculate_directory_hash(self.target_path) + + logger.info(f"{self.library_name}: source_hash={source_hash[:8]}, target_hash={target_hash[:8]}") + + return source_hash != target_hash + + def copy_filtered(self, source: Path, target: Path): + """Copy files from source to target with filtering.""" + if target.exists(): + shutil.rmtree(target) + + target.mkdir(parents=True, exist_ok=True) + + copied_count = 0 + for file_path in source.rglob("*"): + if not file_path.is_file(): + continue + + # Check exclude patterns + if any(excl in str(file_path) for excl in self.exclude_patterns): + continue + + # Check include patterns + if not any(file_path.match(pattern) for pattern in self.include_patterns): + continue + + # Calculate relative path and target location + rel_path = file_path.relative_to(source) + target_file = target / rel_path + + # Create parent directories + target_file.parent.mkdir(parents=True, exist_ok=True) + + # Copy file + shutil.copy2(file_path, target_file) + copied_count += 1 + + logger.info(f"Copied {copied_count} files to {target}") + + def sync(self, force: bool = False) -> bool: + """Synchronize the library.""" + logger.info(f"Syncing {self.library_name}...") + + # Clone or update repository + temp_repo_path = self.clone_or_pull() + source_dir = temp_repo_path / self.source_path + + if not source_dir.exists(): + logger.error(f"Source path {source_dir} does not exist!") + return False + + # Check if sync is needed + if not force and not self.should_sync(source_dir): + logger.info(f"{self.library_name} is up-to-date, skipping sync") + return True + + # Copy files + logger.info(f"Syncing {self.library_name} from {source_dir} to {self.target_path}") + self.copy_filtered(source_dir, self.target_path) + + # Create __init__.py if it doesn't exist + init_file = self.target_path / "__init__.py" + if not init_file.exists(): + init_file.write_text(f'"""Synced from {self.repo_url}"""\n') + + logger.info(f"✅ Successfully synced {self.library_name}") + return True + + def check_status(self) -> Dict: + """Check sync status without syncing.""" + temp_repo_path = TEMP_DIR / self.library_name + + if not temp_repo_path.exists(): + temp_repo_path = self.clone_or_pull() + + source_dir = temp_repo_path / self.source_path + + return { + "library": self.library_name, + "target_exists": self.target_path.exists(), + "source_hash": self.calculate_directory_hash(source_dir), + "target_hash": 
self.calculate_directory_hash(self.target_path), + "needs_sync": self.should_sync(source_dir), + } + + +class SyncManager: + """Manages synchronization of all libraries.""" + + def __init__(self): + self.libraries = { + name: LibrarySync(name, config) + for name, config in LIBRARY_CONFIGS.items() + } + + def sync_all(self, force: bool = False) -> bool: + """Sync all libraries.""" + logger.info("=" * 60) + logger.info("SYNCING ALL LIBRARIES") + logger.info("=" * 60) + + results = {} + for name, lib_sync in self.libraries.items(): + try: + results[name] = lib_sync.sync(force=force) + except Exception as e: + logger.error(f"Error syncing {name}: {e}", exc_info=True) + results[name] = False + + # Save sync state + self.save_sync_state(results) + + # Summary + logger.info("\n" + "=" * 60) + logger.info("SYNC SUMMARY") + logger.info("=" * 60) + for name, success in results.items(): + status = "✅ SUCCESS" if success else "❌ FAILED" + logger.info(f"{name:20} {status}") + + return all(results.values()) + + def sync_one(self, library_name: str, force: bool = False) -> bool: + """Sync a specific library.""" + if library_name not in self.libraries: + logger.error(f"Unknown library: {library_name}") + logger.info(f"Available libraries: {', '.join(self.libraries.keys())}") + return False + + lib_sync = self.libraries[library_name] + try: + return lib_sync.sync(force=force) + except Exception as e: + logger.error(f"Error syncing {library_name}: {e}", exc_info=True) + return False + + def check_all(self) -> Dict: + """Check status of all libraries.""" + statuses = {} + for name, lib_sync in self.libraries.items(): + try: + statuses[name] = lib_sync.check_status() + except Exception as e: + logger.error(f"Error checking {name}: {e}", exc_info=True) + statuses[name] = {"error": str(e)} + + return statuses + + def save_sync_state(self, results: Dict): + """Save sync state to file.""" + state = { + "last_sync": datetime.now().isoformat(), + "results": results, + } + + LIBRARIES_DIR.mkdir(parents=True, exist_ok=True) + with open(SYNC_STATE_FILE, 'w') as f: + json.dump(state, f, indent=2) + + logger.info(f"Sync state saved to {SYNC_STATE_FILE}") + + def load_sync_state(self) -> Optional[Dict]: + """Load sync state from file.""" + if not SYNC_STATE_FILE.exists(): + return None + + with open(SYNC_STATE_FILE) as f: + return json.load(f) + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Sync external libraries into analyzer" + ) + parser.add_argument( + "--library", + "-l", + help="Sync specific library (autogenlib, serena, graph_sitter)" + ) + parser.add_argument( + "--check", + "-c", + action="store_true", + help="Check status without syncing" + ) + parser.add_argument( + "--force", + "-f", + action="store_true", + help="Force sync even if no changes detected" + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Verbose output" + ) + + args = parser.parse_args() + + if args.verbose: + logger.setLevel(logging.DEBUG) + + manager = SyncManager() + + try: + if args.check: + # Check status + statuses = manager.check_all() + print("\n" + "=" * 60) + print("LIBRARY STATUS") + print("=" * 60) + for name, status in statuses.items(): + if "error" in status: + print(f"\n{name}: ❌ ERROR - {status['error']}") + else: + needs_sync = "⚠️ NEEDS SYNC" if status['needs_sync'] else "✅ UP-TO-DATE" + print(f"\n{name}: {needs_sync}") + print(f" Target exists: {status['target_exists']}") + print(f" Source hash: {status['source_hash'][:8]}") + print(f" Target hash: 
{status['target_hash'][:8]}") + return 0 + + elif args.library: + # Sync specific library + success = manager.sync_one(args.library, force=args.force) + return 0 if success else 1 + + else: + # Sync all libraries + success = manager.sync_all(force=args.force) + return 0 if success else 1 + + except KeyboardInterrupt: + logger.info("\nSync interrupted by user") + return 130 + except Exception as e: + logger.error(f"Unexpected error: {e}", exc_info=True) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) + diff --git a/validate_modules.py b/validate_modules.py new file mode 100644 index 00000000..65c63af4 --- /dev/null +++ b/validate_modules.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 +""" +Module and Adapter Validation Script + +This script validates that all modules and adapters in the analyzer +work correctly after library synchronization. + +Usage: + python validate_modules.py # Validate all + python validate_modules.py --quick # Quick validation + python validate_modules.py --verbose # Verbose output +""" + +import argparse +import importlib +import logging +import sys +from pathlib import Path +from typing import Dict, List, Tuple + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Add Libraries to path +REPO_ROOT = Path(__file__).parent +LIBRARIES_DIR = REPO_ROOT / "Libraries" +sys.path.insert(0, str(LIBRARIES_DIR)) + + +class ModuleValidator: + """Validates modules and adapters.""" + + def __init__(self, verbose: bool = False): + self.verbose = verbose + self.results = {} + + def validate_import(self, module_name: str, description: str = "") -> Tuple[bool, str]: + """Validate that a module can be imported.""" + try: + module = importlib.import_module(module_name) + if self.verbose: + logger.info(f"✅ {module_name}: Successfully imported") + if hasattr(module, '__file__'): + logger.info(f" Location: {module.__file__}") + return True, "OK" + except ImportError as e: + logger.error(f"❌ {module_name}: Import failed - {e}") + return False, str(e) + except Exception as e: + logger.error(f"⚠️ {module_name}: Unexpected error - {e}") + return False, str(e) + + def validate_library_modules(self) -> Dict[str, bool]: + """Validate synced library modules.""" + logger.info("\n" + "=" * 60) + logger.info("VALIDATING LIBRARY MODULES") + logger.info("=" * 60) + + modules_to_test = [ + ("autogenlib", "Autogenlib core module"), + ("serena", "Serena semantic editing module"), + ("graph_sitter", "Graph-sitter core from graph_sitter_lib"), + ] + + results = {} + for module_name, description in modules_to_test: + logger.info(f"\nTesting: {description}") + success, error = self.validate_import(module_name, description) + results[module_name] = success + + return results + + def validate_adapter_modules(self) -> Dict[str, bool]: + """Validate adapter modules.""" + logger.info("\n" + "=" * 60) + logger.info("VALIDATING ADAPTER MODULES") + logger.info("=" * 60) + + adapters = [ + ("static_libs", "Static libraries and utilities"), + ("lsp_adapter", "LSP adapter for real-time diagnostics"), + ("autogenlib_adapter", "Autogenlib integration adapter"), + ("graph_sitter_adapter", "Graph-sitter AST parsing adapter"), + ("analyzer", "Core analyzer orchestrator"), + ] + + results = {} + for adapter_name, description in adapters: + logger.info(f"\nTesting: {description}") + success, error = self.validate_import(adapter_name, description) + results[adapter_name] = success + + return results + + def 
validate_adapter_classes(self) -> Dict[str, bool]: + """Validate specific classes in adapters.""" + logger.info("\n" + "=" * 60) + logger.info("VALIDATING ADAPTER CLASSES") + logger.info("=" * 60) + + tests = [ + ("static_libs", "Severity", "Severity enum"), + ("static_libs", "ErrorCategory", "ErrorCategory enum"), + ("static_libs", "CacheManager", "Cache manager class"), + ("static_libs", "ErrorDatabase", "Error database class"), + ("lsp_adapter", "LSPDiagnosticsManager", "LSP diagnostics manager"), + ("lsp_adapter", "RuntimeErrorCollector", "Runtime error collector"), + ] + + results = {} + for module_name, class_name, description in tests: + test_id = f"{module_name}.{class_name}" + logger.info(f"\nTesting: {description}") + + try: + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + + if self.verbose: + logger.info(f"✅ {test_id}: Found") + logger.info(f" Type: {type(cls)}") + else: + logger.info(f"✅ {test_id}: Found") + + results[test_id] = True + except ImportError as e: + logger.error(f"❌ {test_id}: Module import failed - {e}") + results[test_id] = False + except AttributeError as e: + logger.error(f"❌ {test_id}: Class not found - {e}") + results[test_id] = False + except Exception as e: + logger.error(f"⚠️ {test_id}: Unexpected error - {e}") + results[test_id] = False + + return results + + def validate_adapter_functions(self) -> Dict[str, bool]: + """Validate specific functions in adapters.""" + logger.info("\n" + "=" * 60) + logger.info("VALIDATING ADAPTER FUNCTIONS") + logger.info("=" * 60) + + tests = [ + ("autogenlib_adapter", "get_llm_codebase_overview", "LLM codebase overview function"), + ("autogenlib_adapter", "get_file_context", "File context extraction function"), + ("autogenlib_adapter", "get_comprehensive_symbol_context", "Symbol context function"), + ] + + results = {} + for module_name, func_name, description in tests: + test_id = f"{module_name}.{func_name}" + logger.info(f"\nTesting: {description}") + + try: + module = importlib.import_module(module_name) + func = getattr(module, func_name) + + if callable(func): + if self.verbose: + logger.info(f"✅ {test_id}: Found and callable") + logger.info(f" Type: {type(func)}") + else: + logger.info(f"✅ {test_id}: Found and callable") + results[test_id] = True + else: + logger.error(f"❌ {test_id}: Found but not callable") + results[test_id] = False + except ImportError as e: + logger.error(f"❌ {test_id}: Module import failed - {e}") + results[test_id] = False + except AttributeError as e: + logger.error(f"❌ {test_id}: Function not found - {e}") + results[test_id] = False + except Exception as e: + logger.error(f"⚠️ {test_id}: Unexpected error - {e}") + results[test_id] = False + + return results + + def print_summary(self): + """Print validation summary.""" + logger.info("\n" + "=" * 60) + logger.info("VALIDATION SUMMARY") + logger.info("=" * 60) + + all_results = {} + for category_results in self.results.values(): + all_results.update(category_results) + + total = len(all_results) + passed = sum(1 for v in all_results.values() if v) + failed = total - passed + + logger.info(f"\nTotal Tests: {total}") + logger.info(f"Passed: {passed} ✅") + logger.info(f"Failed: {failed} ❌") + logger.info(f"Success Rate: {(passed/total*100):.1f}%") + + if failed > 0: + logger.info("\nFailed Tests:") + for test_name, success in all_results.items(): + if not success: + logger.info(f" ❌ {test_name}") + + return failed == 0 + + def run_all(self, quick: bool = False) -> bool: + """Run all validations.""" + 
self.results['library_modules'] = self.validate_library_modules() + self.results['adapter_modules'] = self.validate_adapter_modules() + + if not quick: + self.results['adapter_classes'] = self.validate_adapter_classes() + self.results['adapter_functions'] = self.validate_adapter_functions() + + return self.print_summary() + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Validate analyzer modules and adapters" + ) + parser.add_argument( + "--quick", + "-q", + action="store_true", + help="Quick validation (skip detailed checks)" + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Verbose output" + ) + + args = parser.parse_args() + + if args.verbose: + logger.setLevel(logging.DEBUG) + + validator = ModuleValidator(verbose=args.verbose) + + try: + success = validator.run_all(quick=args.quick) + return 0 if success else 1 + except KeyboardInterrupt: + logger.info("\nValidation interrupted by user") + return 130 + except Exception as e: + logger.error(f"Unexpected error: {e}", exc_info=True) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) +
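Adding a fourth library later should only require a new entry in `LIBRARY_CONFIGS` inside `sync_libraries.py`, mirroring the three existing ones. A minimal sketch, using a purely hypothetical repository name and paths:

```python
# Hypothetical entry only: "my_lib" and its URL/paths are placeholders,
# not a library the analyzer actually tracks.
LIBRARY_CONFIGS["my_lib"] = {
    "repo_url": "https://github.com/Zeeeepa/my-lib.git",
    "source_path": "src/my_lib",                        # package path inside the cloned repo
    "target_path": LIBRARIES_DIR / "my_lib",            # destination under Libraries/
    "include_patterns": ["*.py", "*.pyi", "*.typed"],   # same filters as the existing entries
    "exclude_patterns": ["*test*", "*__pycache__*", "*.pyc"],
}
```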
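Since both scripts expose their logic as classes (`SyncManager` and `ModuleValidator`), the sync-then-validate flow can also be driven from Python rather than the command line. A minimal sketch, assuming it runs from the repository root so that `sync_libraries.py` and `validate_modules.py` are importable as modules:

```python
# Minimal sketch: sync the libraries, then validate the modules, mirroring
# what the two CLI entry points do.
from sync_libraries import SyncManager
from validate_modules import ModuleValidator

manager = SyncManager()
if not manager.sync_all(force=False):      # clone/pull, hash comparison, filtered copy
    raise SystemExit(1)

validator = ModuleValidator(verbose=True)
ok = validator.run_all(quick=False)        # import, class, and function checks plus summary
raise SystemExit(0 if ok else 1)
```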