Skip to content

Commit

Permalink
Fix: Inference of python model names from the file system (#3844)
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman authored Feb 14, 2025
1 parent 64af77b commit ea860e4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
15 changes: 10 additions & 5 deletions sqlmesh/core/model/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
render_meta_fields,
)
from sqlmesh.core.model.kind import ModelKindName, _ModelKind
from sqlmesh.utils import registry_decorator
from sqlmesh.utils import registry_decorator, DECORATOR_RETURN_TYPE
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.metaprogramming import build_env, serialize_env

Expand All @@ -39,6 +39,7 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs:
if not is_sql and "columns" not in kwargs:
raise ConfigError("Python model must define column schema.")

self.name_provided = bool(name)
self.name = name or ""
self.is_sql = is_sql
self.kwargs = kwargs
Expand Down Expand Up @@ -76,6 +77,13 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs:
for column_name, column_type in self.kwargs.pop("columns", {}).items()
}

def __call__(
self, func: t.Callable[..., DECORATOR_RETURN_TYPE]
) -> t.Callable[..., DECORATOR_RETURN_TYPE]:
if not self.name_provided:
self.name = get_model_name(Path(inspect.getfile(func)))
return super().__call__(func)

def model(
self,
*,
Expand All @@ -97,10 +105,7 @@ def model(
env: t.Dict[str, t.Any] = {}
entrypoint = self.func.__name__

if not self.name and infer_names:
self.name = get_model_name(Path(inspect.getfile(self.func)))

if not self.name:
if not self.name_provided and not infer_names:
raise ConfigError("Python model must have a name.")

kind = self.kwargs.get("kind", None)
Expand Down
26 changes: 26 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6024,6 +6024,32 @@ def my_model(context, **kwargs):
assert isinstance(context.get_model(expected_name), PythonModel)


def test_python_model_name_inference_multiple_models(tmp_path: Path) -> None:
init_example_project(tmp_path, dialect="duckdb")
config = Config(
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
model_naming=NameInferenceConfig(infer_names=True),
)

path_a = tmp_path / "models/test_schema/test_model_a.py"
path_b = tmp_path / "models/test_schema/test_model_b.py"

model_payload = """from sqlmesh import model
@model(
columns={'"COL"': "int"},
)
def my_model(context, **kwargs):
pass"""

path_a.parent.mkdir(parents=True, exist_ok=True)
path_a.write_text(model_payload)
path_b.write_text(model_payload)

context = Context(paths=tmp_path, config=config)
assert context.get_model("test_schema.test_model_a").name == "test_schema.test_model_a"
assert context.get_model("test_schema.test_model_b").name == "test_schema.test_model_b"


def test_custom_kind():
from sqlmesh import CustomMaterialization

Expand Down

0 comments on commit ea860e4

Please sign in to comment.