|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import ast |
| 4 | +from collections import defaultdict |
| 5 | +from pathlib import Path |
| 6 | +from typing import Optional |
| 7 | + |
| 8 | +import typer |
| 9 | +from ariadne_codegen.client_generators.package import PackageGenerator, get_package_generator |
| 10 | +from ariadne_codegen.exceptions import ParsingError |
| 11 | +from ariadne_codegen.plugins.explorer import get_plugins_types |
| 12 | +from ariadne_codegen.plugins.manager import PluginManager |
| 13 | +from ariadne_codegen.schema import ( |
| 14 | + filter_fragments_definitions, |
| 15 | + filter_operations_definitions, |
| 16 | + get_graphql_schema_from_path, |
| 17 | +) |
| 18 | +from ariadne_codegen.settings import ClientSettings, CommentsStrategy |
| 19 | +from ariadne_codegen.utils import ast_to_str |
| 20 | +from graphql import DefinitionNode, GraphQLSchema, NoUnusedFragmentsRule, parse, specified_rules, validate |
| 21 | +from rich.console import Console |
| 22 | + |
| 23 | +from ..async_typer import AsyncTyper |
| 24 | +from ..ctl.client import initialize_client |
| 25 | +from ..ctl.utils import catch_exception |
| 26 | +from ..graphql.utils import insert_fragments_inline, remove_fragment_import |
| 27 | +from .parameters import CONFIG_PARAM |
| 28 | + |
| 29 | +app = AsyncTyper() |
| 30 | +console = Console() |
| 31 | + |
| 32 | +ARIADNE_PLUGINS = [ |
| 33 | + "infrahub_sdk.graphql.plugin.PydanticBaseModelPlugin", |
| 34 | + "infrahub_sdk.graphql.plugin.FutureAnnotationPlugin", |
| 35 | + "infrahub_sdk.graphql.plugin.StandardTypeHintPlugin", |
| 36 | +] |
| 37 | + |
| 38 | + |
| 39 | +def find_gql_files(query_path: Path) -> list[Path]: |
| 40 | + """ |
| 41 | + Find all files with .gql extension in the specified directory. |
| 42 | +
|
| 43 | + Args: |
| 44 | + query_path: Path to the directory to search for .gql files |
| 45 | +
|
| 46 | + Returns: |
| 47 | + List of Path objects for all .gql files found |
| 48 | + """ |
| 49 | + if not query_path.exists(): |
| 50 | + raise FileNotFoundError(f"File or directory not found: {query_path}") |
| 51 | + |
| 52 | + if not query_path.is_dir() and query_path.is_file(): |
| 53 | + return [query_path] |
| 54 | + |
| 55 | + return list(query_path.glob("**/*.gql")) |
| 56 | + |
| 57 | + |
| 58 | +def get_graphql_query(queries_path: Path, schema: GraphQLSchema) -> tuple[DefinitionNode, ...]: |
| 59 | + """Get GraphQL queries definitions from a single GraphQL file.""" |
| 60 | + |
| 61 | + if not queries_path.exists(): |
| 62 | + raise FileNotFoundError(f"File not found: {queries_path}") |
| 63 | + if not queries_path.is_file(): |
| 64 | + raise ValueError(f"{queries_path} is not a file") |
| 65 | + |
| 66 | + queries_str = queries_path.read_text(encoding="utf-8") |
| 67 | + queries_ast = parse(queries_str) |
| 68 | + validation_errors = validate( |
| 69 | + schema=schema, |
| 70 | + document_ast=queries_ast, |
| 71 | + rules=[r for r in specified_rules if r is not NoUnusedFragmentsRule], |
| 72 | + ) |
| 73 | + if validation_errors: |
| 74 | + raise ValueError("\n\n".join(error.message for error in validation_errors)) |
| 75 | + return queries_ast.definitions |
| 76 | + |
| 77 | + |
| 78 | +def generate_result_types(directory: Path, package: PackageGenerator, fragment: ast.Module) -> None: |
| 79 | + for file_name, module in package._result_types_files.items(): |
| 80 | + file_path = directory / file_name |
| 81 | + |
| 82 | + insert_fragments_inline(module, fragment) |
| 83 | + remove_fragment_import(module) |
| 84 | + |
| 85 | + code = package._add_comments_to_code(ast_to_str(module), package.queries_source) |
| 86 | + if package.plugin_manager: |
| 87 | + code = package.plugin_manager.generate_result_types_code(code) |
| 88 | + file_path.write_text(code) |
| 89 | + package._generated_files.append(file_path.name) |
| 90 | + |
| 91 | + |
| 92 | +@app.callback() |
| 93 | +def callback() -> None: |
| 94 | + """ |
| 95 | + Various GraphQL related commands. |
| 96 | + """ |
| 97 | + |
| 98 | + |
| 99 | +@app.command() |
| 100 | +@catch_exception(console=console) |
| 101 | +async def export_schema( |
| 102 | + destination: Path = typer.Option("schema.graphql", help="Path to the GraphQL schema file."), |
| 103 | + _: str = CONFIG_PARAM, |
| 104 | +) -> None: |
| 105 | + """Export the GraphQL schema to a file.""" |
| 106 | + |
| 107 | + client = initialize_client() |
| 108 | + schema_text = await client.schema.get_graphql_schema() |
| 109 | + |
| 110 | + destination.parent.mkdir(parents=True, exist_ok=True) |
| 111 | + destination.write_text(schema_text) |
| 112 | + console.print(f"[green]Schema exported to {destination}") |
| 113 | + |
| 114 | + |
| 115 | +@app.command() |
| 116 | +@catch_exception(console=console) |
| 117 | +async def generate_return_types( |
| 118 | + query: Optional[Path] = typer.Argument( |
| 119 | + None, help="Location of the GraphQL query file(s). Defaults to current directory if not specified." |
| 120 | + ), |
| 121 | + schema: Path = typer.Option("schema.graphql", help="Path to the GraphQL schema file."), |
| 122 | + _: str = CONFIG_PARAM, |
| 123 | +) -> None: |
| 124 | + """Create Pydantic Models for GraphQL query return types""" |
| 125 | + |
| 126 | + query = Path.cwd() if query is None else query |
| 127 | + |
| 128 | + # Load the GraphQL schema |
| 129 | + if not schema.exists(): |
| 130 | + raise FileNotFoundError(f"GraphQL Schema file not found: {schema}") |
| 131 | + graphql_schema = get_graphql_schema_from_path(schema_path=str(schema)) |
| 132 | + |
| 133 | + # Initialize the plugin manager |
| 134 | + plugin_manager = PluginManager( |
| 135 | + schema=graphql_schema, |
| 136 | + plugins_types=get_plugins_types(plugins_strs=ARIADNE_PLUGINS), |
| 137 | + ) |
| 138 | + |
| 139 | + # Find the GraphQL files and organize them by directory |
| 140 | + gql_files = find_gql_files(query) |
| 141 | + gql_per_directory: dict[Path, list[Path]] = defaultdict(list) |
| 142 | + for gql_file in gql_files: |
| 143 | + gql_per_directory[gql_file.parent].append(gql_file) |
| 144 | + |
| 145 | + # Generate the Pydantic Models for the GraphQL queries |
| 146 | + for directory, gql_files in gql_per_directory.items(): |
| 147 | + for gql_file in gql_files: |
| 148 | + try: |
| 149 | + definitions = get_graphql_query(queries_path=gql_file, schema=graphql_schema) |
| 150 | + except ValueError as exc: |
| 151 | + console.print(f"[red]Error generating result types for {gql_file}: {exc}") |
| 152 | + continue |
| 153 | + queries = filter_operations_definitions(definitions) |
| 154 | + fragments = filter_fragments_definitions(definitions) |
| 155 | + |
| 156 | + package_generator = get_package_generator( |
| 157 | + schema=graphql_schema, |
| 158 | + fragments=fragments, |
| 159 | + settings=ClientSettings( |
| 160 | + schema_path=str(schema), |
| 161 | + target_package_name=directory.name, |
| 162 | + queries_path=str(directory), |
| 163 | + include_comments=CommentsStrategy.NONE, |
| 164 | + ), |
| 165 | + plugin_manager=plugin_manager, |
| 166 | + ) |
| 167 | + |
| 168 | + parsing_failed = False |
| 169 | + try: |
| 170 | + for query_operation in queries: |
| 171 | + package_generator.add_operation(query_operation) |
| 172 | + except ParsingError as exc: |
| 173 | + console.print(f"[red]Unable to process {gql_file.name}: {exc}") |
| 174 | + parsing_failed = True |
| 175 | + |
| 176 | + if parsing_failed: |
| 177 | + continue |
| 178 | + |
| 179 | + module_fragment = package_generator.fragments_generator.generate() |
| 180 | + |
| 181 | + generate_result_types(directory=directory, package=package_generator, fragment=module_fragment) |
| 182 | + |
| 183 | + for file_name in package_generator._result_types_files.keys(): |
| 184 | + console.print(f"[green]Generated {file_name} in {directory}") |
0 commit comments