Skip to content

Remove refs to InfrahubServices for git ops #6406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions backend/infrahub/artifacts/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from infrahub.artifacts.models import CheckArtifactCreate
from infrahub.core.constants import InfrahubKind, ValidatorConclusion
from infrahub.core.timestamp import Timestamp
from infrahub.git import InfrahubReadOnlyRepository, InfrahubRepository
from infrahub.git.repository import get_initialized_repo
from infrahub.services import InfrahubServices
from infrahub.tasks.artifact import define_artifact
from infrahub.workflows.utils import add_tags
Expand All @@ -14,21 +14,13 @@ async def create(model: CheckArtifactCreate, service: InfrahubServices) -> Valid
await add_tags(branches=[model.branch_name], nodes=[model.target_id])
validator = await service.client.get(kind=InfrahubKind.ARTIFACTVALIDATOR, id=model.validator_id, include=["checks"])

repo: InfrahubReadOnlyRepository | InfrahubRepository
if InfrahubKind.READONLYREPOSITORY:
repo = await InfrahubReadOnlyRepository.init(
id=model.repository_id,
name=model.repository_name,
client=service.client,
service=service,
)
else:
repo = await InfrahubRepository.init(
id=model.repository_id,
name=model.repository_name,
client=service.client,
service=service,
)
repo = await get_initialized_repo(
client=service.client,
repository_id=model.repository_id,
name=model.repository_name,
repository_kind=model.repository_kind,
commit=model.commit,
)

artifact, artifact_created = await define_artifact(model=model, service=service)

Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/computed_attribute/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ async def process_transform(
)

repo = await get_initialized_repo(
client=service.client,
repository_id=transform.repository.peer.id,
name=transform.repository.peer.name.value,
service=service,
repository_kind=str(transform.repository.peer.typename),
commit=repo_node.commit.value,
) # type: ignore[misc]
Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/generators/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ async def run_generator(model: RequestGeneratorRun, service: InfrahubServices) -
await add_tags(branches=[model.branch_name], nodes=[model.target_id])

repository = await get_initialized_repo(
client=service.client,
repository_id=model.repository_id,
name=model.repository_name,
service=service,
repository_kind=model.repository_kind,
commit=model.commit,
)
Expand Down
11 changes: 4 additions & 7 deletions backend/infrahub/git/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from infrahub.git.directory import get_repositories_directory, initialize_repositories_directory
from infrahub.git.worktree import Worktree
from infrahub.log import get_logger
from infrahub.services import InfrahubServices # noqa: TC001
from infrahub.workers.dependencies import get_client

if TYPE_CHECKING:
from infrahub_sdk.branch import BranchData
Expand Down Expand Up @@ -153,9 +153,6 @@ class InfrahubRepositoryBase(BaseModel, ABC):
)

cache_repo: Repo | None = Field(None, description="Internal cache of the GitPython Repo object")
service: InfrahubServices = Field(
..., description="Service object with access to the message queue, the database etc.."
)
is_read_only: bool = Field(False, description="If true, changes will not be synced to remote")

internal_status: str = Field("active", description="Internal status: Active, Inactive, Staging")
Expand All @@ -164,10 +161,10 @@ class InfrahubRepositoryBase(BaseModel, ABC):

@property
def sdk(self) -> InfrahubClient:
if self.client:
return self.client
if not self.client:
self.client = get_client()

return self.service.client
return self.client
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could perhaps be a follow up PR but I think we should get rid of the sdk property.


@property
def default_branch(self) -> str:
Expand Down
17 changes: 9 additions & 8 deletions backend/infrahub/git/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from infrahub.exceptions import CheckError, RepositoryInvalidFileSystemError, TransformError
from infrahub.git.base import InfrahubRepositoryBase, extract_repo_file_information
from infrahub.log import get_logger
from infrahub.workers.dependencies import get_event_service
from infrahub.workflows.utils import add_tags

if TYPE_CHECKING:
Expand All @@ -59,7 +60,6 @@

from infrahub.artifacts.models import CheckArtifactCreate
from infrahub.git.models import RequestArtifactGenerate
from infrahub.services import InfrahubServices


class ArtifactGenerateResult(BaseModel):
Expand Down Expand Up @@ -138,8 +138,8 @@ class that uses an "InfrahubRepository" or "InfrahubReadOnlyRepository" as input
"""

@classmethod
async def init(cls, service: InfrahubServices, commit: str | None = None, **kwargs: Any) -> Self:
self = cls(service=service, **kwargs)
async def init(cls, commit: str | None = None, **kwargs: Any) -> Self:
self = cls(**kwargs)
log = get_logger()
try:
self.validate_local_directories()
Expand Down Expand Up @@ -209,7 +209,8 @@ async def import_objects_from_files(
raise error

infrahub_branch = registry.get_branch_from_registry(branch=infrahub_branch_name)
await self.service.event.send(
event_service = await get_event_service()
await event_service.send(
CommitUpdatedEvent(
commit=commit,
repository_name=self.name,
Expand Down Expand Up @@ -1183,7 +1184,6 @@ async def execute_python_transform(
infrahub_node=InfrahubNode,
)
return await transform.run(data=data)

except ModuleNotFoundError as exc:
error_msg = f"Unable to load the transform file {location}"
log.error(error_msg)
Expand Down Expand Up @@ -1233,11 +1233,11 @@ async def artifact_generate(
artifact_content = await self.execute_python_transform.with_options(
timeout_seconds=transformation.timeout.value
)(
client=self.sdk,
branch_name=branch_name,
commit=commit,
location=transformation_location,
data=response,
client=self.sdk,
convert_query_response=transformation.convert_query_response.value,
) # type: ignore[misc]

Expand Down Expand Up @@ -1293,11 +1293,11 @@ async def render_artifact(
) # type: ignore[misc]
elif message.transform_type == InfrahubKind.TRANSFORMPYTHON:
artifact_content = await self.execute_python_transform.with_options(timeout_seconds=message.timeout)(
client=self.sdk,
branch_name=message.branch_name,
commit=message.commit,
location=message.transform_location,
data=response,
client=self.sdk,
convert_query_response=message.convert_query_response,
) # type: ignore[misc]

Expand Down Expand Up @@ -1340,5 +1340,6 @@ async def render_artifact(
storage_id_previous=previous_storage_id,
)

await self.service.event.send(event=event)
event_service = await get_event_service()
await event_service.send(event=event)
return ArtifactGenerateResult(changed=True, checksum=checksum, storage_id=storage_id, artifact_id=artifact.id)
43 changes: 13 additions & 30 deletions backend/infrahub/git/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from git.exc import BadName, GitCommandError
from infrahub_sdk.exceptions import GraphQLError
from prefect import task
from pydantic import Field

from infrahub.core.constants import InfrahubKind, RepositoryInternalStatus
Expand All @@ -12,7 +13,7 @@
from infrahub.log import get_logger

if TYPE_CHECKING:
from infrahub.services import InfrahubServices
from infrahub_sdk.client import InfrahubClient

log = get_logger()

Expand All @@ -25,10 +26,8 @@ class InfrahubRepository(InfrahubRepositoryIntegrator):
"""

@classmethod
async def new(
cls, service: InfrahubServices, update_commit_value: bool = True, **kwargs: Any
) -> InfrahubRepository:
self = cls(service=service, **kwargs)
async def new(cls, update_commit_value: bool = True, **kwargs: Any) -> InfrahubRepository:
self = cls(**kwargs)
await self.create_locally(
infrahub_branch_name=self.infrahub_branch_name, update_commit_value=update_commit_value
)
Expand Down Expand Up @@ -209,11 +208,11 @@ class InfrahubReadOnlyRepository(InfrahubRepositoryIntegrator):
ref: str | None = Field(None, description="Ref to track on the external repository")

@classmethod
async def new(cls, service: InfrahubServices, **kwargs: Any) -> InfrahubReadOnlyRepository:
async def new(cls, **kwargs: Any) -> InfrahubReadOnlyRepository:
if "ref" not in kwargs or "infrahub_branch_name" not in kwargs:
raise ValueError("ref and infrahub_branch_name are mandatory to initialize a new Read-Only repository")

self = cls(service=service, **kwargs)
self = cls(**kwargs)
await self.create_locally(checkout_ref=self.ref, infrahub_branch_name=self.infrahub_branch_name)
log.info("Created new repository locally.", repository=self.name)
return self
Expand Down Expand Up @@ -248,33 +247,17 @@ async def sync_from_remote(self, commit: str | None = None) -> None:
await self.update_commit_value(branch_name=self.infrahub_branch_name, commit=commit)


@task(
name="Fetch repository commit",
description="Retrieve a git repository at a given commit, if it does not already exist locally",
)
async def get_initialized_repo(
repository_id: str, name: str, service: InfrahubServices, repository_kind: str, commit: str | None = None
client: InfrahubClient, repository_id: str, name: str, repository_kind: str, commit: str | None = None
) -> InfrahubReadOnlyRepository | InfrahubRepository:
if repository_kind == InfrahubKind.REPOSITORY:
return await InfrahubRepository.init(
id=repository_id, name=name, commit=commit, client=service._client, service=service
)

if repository_kind == InfrahubKind.READONLYREPOSITORY:
return await InfrahubReadOnlyRepository.init(
id=repository_id, name=name, commit=commit, client=service._client, service=service
)

raise NotImplementedError(f"The repository kind {repository_kind} has not been implemented")


async def initialize_repo(
location: str, repository_id: str, name: str, service: InfrahubServices, repository_kind: str
) -> InfrahubReadOnlyRepository | InfrahubRepository:
if repository_kind == InfrahubKind.REPOSITORY:
return await InfrahubRepository.new(
location=location, id=repository_id, name=name, client=service._client, service=service
)
return await InfrahubRepository.init(id=repository_id, name=name, commit=commit, client=client)

if repository_kind == InfrahubKind.READONLYREPOSITORY:
return await InfrahubReadOnlyRepository.new(
location=location, id=repository_id, name=name, client=service._client, service=service
)
return await InfrahubReadOnlyRepository.init(id=repository_id, name=name, commit=commit, client=client)

raise NotImplementedError(f"The repository kind {repository_kind} has not been implemented")
21 changes: 15 additions & 6 deletions backend/infrahub/git/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,9 @@ async def generate_artifact(model: RequestArtifactGenerate, service: InfrahubSer
await add_tags(branches=[model.branch_name], nodes=[model.target_id])
log = get_run_logger()
repo = await get_initialized_repo(
client=service.client,
repository_id=model.repository_id,
name=model.repository_name,
service=service,
repository_kind=model.repository_kind,
commit=model.commit,
)
Expand Down Expand Up @@ -477,9 +477,9 @@ async def merge_git_repository(model: GitRepositoryMerge, service: InfrahubServi
async def import_objects_from_git_repository(model: GitRepositoryImportObjects, service: InfrahubServices) -> None:
await add_branch_tag(model.infrahub_branch_name)
repo = await get_initialized_repo(
client=service.client,
repository_id=model.repository_id,
name=model.repository_name,
service=service,
repository_kind=model.repository_kind,
commit=model.commit,
)
Expand All @@ -495,9 +495,9 @@ async def git_repository_diff_names_only(
model: GitDiffNamesOnly, service: InfrahubServices
) -> GitDiffNamesOnlyResponse:
repo = await get_initialized_repo(
client=service.client,
repository_id=model.repository_id,
name=model.repository_name,
service=service,
repository_kind=model.repository_kind,
)
files_changed: list[str] = []
Expand Down Expand Up @@ -756,7 +756,12 @@ async def run_check_merge_conflicts(
validator = await service.client.get(kind=InfrahubKind.REPOSITORYVALIDATOR, id=model.validator_id)
await validator.checks.fetch()

repo = await InfrahubRepository.init(id=model.repository_id, name=model.repository_name, service=service)
repo = await get_initialized_repo(
client=service.client,
repository_id=model.repository_id,
name=model.repository_name,
repository_kind=InfrahubKind.REPOSITORY,
)
async with lock.registry.get(name=model.repository_name, namespace="repository"):
conflicts = await repo.get_conflicts(source_branch=model.source_branch, dest_branch=model.target_branch)

Expand Down Expand Up @@ -830,8 +835,12 @@ async def run_user_check(model: UserCheckData, service: InfrahubServices) -> Val
validator = await service.client.get(kind=InfrahubKind.USERVALIDATOR, id=model.validator_id)
await validator.checks.fetch()

repo = await InfrahubRepository.init(
id=model.repository_id, name=model.repository_name, commit=model.commit, service=service
repo = await get_initialized_repo(
client=service.client,
repository_id=model.repository_id,
name=model.repository_name,
repository_kind=InfrahubKind.REPOSITORY,
commit=model.commit,
)
conclusion = ValidatorConclusion.FAILURE
severity = "critical"
Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/message_bus/operations/git/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ async def get(message: messages.GitFileGet, service: InfrahubServices) -> None:
log.info("Collecting file from repository", repository=message.repository_name, file=message.file)

repo = await get_initialized_repo(
client=service.client,
repository_id=message.repository_id,
name=message.repository_name,
service=service,
repository_kind=message.repository_kind,
commit=message.commit,
)
Expand Down
23 changes: 7 additions & 16 deletions backend/infrahub/message_bus/operations/git/repository.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from prefect import flow

from infrahub.exceptions import RepositoryError
from infrahub.git.repository import InfrahubRepository, get_initialized_repo, initialize_repo
from infrahub.git.repository import InfrahubRepository, get_initialized_repo
from infrahub.log import get_logger
from infrahub.message_bus import messages
from infrahub.message_bus.messages.git_repository_connectivity import (
Expand Down Expand Up @@ -37,21 +37,12 @@ async def fetch(message: messages.RefreshGitFetch, service: InfrahubServices) ->
log.info("Ignoring git fetch request originating from self", worker=WORKER_IDENTITY)
return

try:
repo = await get_initialized_repo(
repository_id=message.repository_id,
name=message.repository_name,
service=service,
repository_kind=message.repository_kind,
)
except RepositoryError:
repo = await initialize_repo(
location=message.location,
repository_id=message.repository_id,
name=message.repository_name,
service=service,
repository_kind=message.repository_kind,
)
repo = await get_initialized_repo(
client=service.client,
repository_id=message.repository_id,
name=message.repository_name,
repository_kind=message.repository_kind,
)

await repo.fetch()
await repo.pull(
Expand Down
4 changes: 2 additions & 2 deletions backend/infrahub/proposed_change/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,9 @@ def _execute(
for repository in model.branch_diff.repositories:
if model.source_branch_sync_with_git:
repo = await get_initialized_repo(
client=service.client,
repository_id=repository.repository_id,
name=repository.repository_name,
service=service,
repository_kind=repository.kind,
)
commit = repo.get_commit_value(proposed_change.source_branch.value)
Expand Down Expand Up @@ -706,9 +706,9 @@ async def run_generator_as_check(
log = get_run_logger()

repository = await get_initialized_repo(
client=service.client,
repository_id=model.repository_id,
name=model.repository_name,
service=service,
repository_kind=model.repository_kind,
commit=model.commit,
)
Expand Down
Loading
Loading