Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Star generator refactor #131

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ select = [
"UP",
]
ignore = [
"F401", # unused import - (TODO: fix this when refactoring tests)
"F811", # redefinition (TODO: fix this when refactoring tests)
"E402", # module-level import not at top (conflicts w/ isort)
"E501", # line-length violations (black enforces these)
"E731", # lambda expressions (TODO: we should fix these)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,9 @@

mappings = []
for component_pair in compound_pairs:
best_score = 0.0
best_mapping = None
molA = component_pair[0]
molB = component_pair[1]

for mapper in mappers:
mapping_generator = mapper.suggest_mappings(molA, molB)

if scorer:
try:
tmp_mappings = [
mapping.with_annotations({"score": scorer(mapping)})
for mapping in mapping_generator
]
except:
continue
if len(tmp_mappings) > 0:
tmp_best_mapping = min(tmp_mappings, key=lambda m: m.annotations["score"])

if tmp_best_mapping.annotations["score"] < best_score or best_mapping is None:
best_score = tmp_best_mapping.annotations["score"]
best_mapping = tmp_best_mapping

else:
try:
best_mapping = next(mapping_generator)
except:
continue
best_mapping = _determine_best_mapping(
component_pair=component_pair, mappers=mappers, scorer=scorer
)
if best_mapping is not None:
mappings.append(best_mapping)

Expand Down Expand Up @@ -100,10 +75,6 @@

possible_edges = list(possible_edges)
n_batches = 10 * n_processes
# total = len(possible_edges)

# # size of each batch +fetch division rest
# batch_num = (total // n_batches) + 1

# Prepare parallel execution.
# suboptimal implementation, but itertools.batched is python 3.12,
Expand All @@ -118,3 +89,64 @@
mappings.extend(sub_result)

return mappings


def _serial_map_scoring(
possible_edges: list[tuple[SmallMoleculeComponent, SmallMoleculeComponent]],
scorer: Callable[[AtomMapping], float],
mappers: list[AtomMapper],
edges_to_score: int,
show_progress: bool = True,
):
if show_progress is True:
progress = functools.partial(tqdm, total=edges_to_score, delay=1.5, desc="Mapping")
else:
progress = lambda x: x

mappings = []
for component_pair in progress(possible_edges):
best_mapping = _determine_best_mapping(
component_pair=component_pair, mappers=mappers, scorer=scorer
)

if best_mapping is not None:
mappings.append(best_mapping)

return mappings


def _determine_best_mapping(
component_pair: tuple[SmallMoleculeComponent],
mappers: AtomMapper | list[AtomMapper],
scorer: Callable,
):
best_score = 0.0
best_mapping = None
molA = component_pair[0]
molB = component_pair[1]

for mapper in mappers:
try:
mapping_generator = mapper.suggest_mappings(molA, molB)
except:
continue

if scorer:
tmp_mappings = [
mapping.with_annotations({"score": scorer(mapping)})
for mapping in mapping_generator
]

if len(tmp_mappings) > 0:
tmp_best_mapping = min(tmp_mappings, key=lambda m: m.annotations["score"])

if tmp_best_mapping.annotations["score"] < best_score or best_mapping is None:
best_score = tmp_best_mapping.annotations["score"]
best_mapping = tmp_best_mapping
else:
try:
best_mapping = next(mapping_generator)
except:
continue

Check warning on line 150 in src/konnektor/network_planners/_map_scoring.py

View check run for this annotation

Codecov / codecov/patch

src/konnektor/network_planners/_map_scoring.py#L149-L150

Added lines #L149 - L150 were not covered by tests

return best_mapping
57 changes: 8 additions & 49 deletions src/konnektor/network_planners/concatenators/max_concatenator.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

import functools
import itertools
import logging
from collections.abc import Iterable

from gufe import AtomMapper, LigandNetwork
from tqdm import tqdm

from ..generators._parallel_mapping_pattern import _parallel_map_scoring
from ...network_planners._map_scoring import _parallel_map_scoring, _serial_map_scoring
from ._abstract_network_concatenator import NetworkConcatenator

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -90,52 +88,13 @@ def concatenate_networks(self, ligand_networks: Iterable[LigandNetwork]) -> Liga
)

else: # serial variant
if self.progress is True:
progress = functools.partial(
tqdm, total=len(pedges), delay=1.5, desc="Mapping Subnets"
)
else:
progress = lambda x: x

bipartite_graph_mappings = []
for component_pair in progress(pedges):
best_score = 0.0
best_mapping = None
molA = component_pair[0]
molB = component_pair[1]

for mapper in self.mappers:
try:
mapping_generator = mapper.suggest_mappings(molA, molB)
except:
continue

if self.scorer:
tmp_mappings = [
mapping.with_annotations({"score": self.scorer(mapping)})
for mapping in mapping_generator
]

if len(tmp_mappings) > 0:
tmp_best_mapping = min(
tmp_mappings, key=lambda m: m.annotations["score"]
)

if (
tmp_best_mapping.annotations["score"] < best_score
or best_mapping is None
):
best_score = tmp_best_mapping.annotations["score"]
best_mapping = tmp_best_mapping
else:
try:
best_mapping = next(mapping_generator)
except:
print("warning")
continue

if best_mapping is not None:
bipartite_graph_mappings.append(best_mapping)
bipartite_graph_mappings = _serial_map_scoring(
possible_edges=pedges,
scorer=self.scorer,
mappers=self.mappers,
edges_to_score=len(pedges),
show_progress=self.progress,
)

# Add network connecting edges
selected_edges.extend(bipartite_graph_mappings)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from gufe import AtomMapper, LigandNetwork

from ...network_planners._map_scoring import _parallel_map_scoring
from .._networkx_implementations import MstNetworkAlgorithm
from ..generators._parallel_mapping_pattern import _parallel_map_scoring
from ._abstract_network_concatenator import NetworkConcatenator

log = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from gufe import AtomMapper, Component, LigandNetwork

from .._map_scoring import _parallel_map_scoring
from ._abstract_network_generator import NetworkGenerator
from ._parallel_mapping_pattern import _parallel_map_scoring


class ExplicitNetworkGenerator(NetworkGenerator):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

import functools
import itertools
from collections.abc import Iterable

import numpy as np
from gufe import AtomMapper, Component, LigandNetwork
from tqdm.auto import tqdm

from .._map_scoring import _parallel_map_scoring, _serial_map_scoring
from ._abstract_network_generator import NetworkGenerator
from ._parallel_mapping_pattern import _parallel_map_scoring

# Todo: is graph connectivity ensured?

Expand Down Expand Up @@ -103,45 +101,13 @@ def generate_ligand_network(self, components: Iterable[Component]) -> LigandNetw
show_progress=self.progress,
)
else: # serial variant
if self.progress is True:
progress = functools.partial(tqdm, total=total, delay=1.5, desc="Mapping")
else:
progress = lambda x: x

mappings = []
for component_pair in progress(sample_combinations):
best_score = 0.0
best_mapping = None
molA = component_pair[0]
molB = component_pair[1]

for mapper in self.mappers:
mapping_generator = mapper.suggest_mappings(molA, molB)

if self.scorer:
tmp_mappings = [
mapping.with_annotations({"score": self.scorer(mapping)})
for mapping in mapping_generator
]

if len(tmp_mappings) > 0:
tmp_best_mapping = min(
tmp_mappings, key=lambda m: m.annotations["score"]
)

if (
tmp_best_mapping.annotations["score"] < best_score
or best_mapping is None
):
best_score = tmp_best_mapping.annotations["score"]
best_mapping = tmp_best_mapping
else:
try:
best_mapping = next(mapping_generator)
except:
continue
if best_mapping is not None:
mappings.append(best_mapping)
mappings = _serial_map_scoring(
possible_edges=sample_combinations,
scorer=self.scorer,
mappers=self.mappers,
edges_to_score=total,
show_progress=self.progress,
)

if len(mappings) == 0:
raise RuntimeError("Could not generate any mapping!")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

import functools
import itertools
from collections.abc import Iterable

from gufe import AtomMapper, Component, LigandNetwork
from tqdm.auto import tqdm

from .._map_scoring import _parallel_map_scoring, _serial_map_scoring
from ._abstract_network_generator import NetworkGenerator
from ._parallel_mapping_pattern import _parallel_map_scoring


class MaximalNetworkGenerator(NetworkGenerator):
Expand Down Expand Up @@ -91,50 +89,13 @@ def generate_ligand_network(self, components: Iterable[Component]) -> LigandNetw
show_progress=self.progress,
)
else: # serial variant
if self.progress is True:
progress = functools.partial(tqdm, total=total, delay=1.5, desc="Mapping")
else:
progress = lambda x: x

mappings = []
for component_pair in progress(itertools.combinations(components, 2)):
best_score = 0.0
best_mapping = None
molA = component_pair[0]
molB = component_pair[1]

for mapper in self.mappers:
try:
mapping_generator = mapper.suggest_mappings(molA, molB)
except:
continue

if self.scorer:
tmp_mappings = [
mapping.with_annotations({"score": self.scorer(mapping)})
for mapping in mapping_generator
]

if len(tmp_mappings) > 0:
tmp_best_mapping = min(
tmp_mappings, key=lambda m: m.annotations["score"]
)

if (
tmp_best_mapping.annotations["score"] < best_score
or best_mapping is None
):
best_score = tmp_best_mapping.annotations["score"]
best_mapping = tmp_best_mapping
else:
try:
best_mapping = next(mapping_generator)
except:
print("warning")
continue

if best_mapping is not None:
mappings.append(best_mapping)
mappings = _serial_map_scoring(
possible_edges=itertools.combinations(components, 2),
scorer=self.scorer,
mappers=self.mappers,
edges_to_score=total,
show_progress=self.progress,
)

if len(mappings) == 0:
raise RuntimeError("Could not generate any mapping!")
Expand Down
Loading
Loading