diff --git a/pyproject.toml b/pyproject.toml index 5d8b1cd..501f8b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,6 @@ requires-python = '>=3.8' dependencies = [ 'numpy >= 1.21.5', 'matplotlib >= 3.5.1', - 'scipy >= 1.8.0', 'tqdm >= 4.63.0', 'pymatgen >= 2022.3.29', 'plotly >= 5.6.0', diff --git a/src/pylattica/core/periodic_structure.py b/src/pylattica/core/periodic_structure.py index e364bc0..0bd6ca2 100644 --- a/src/pylattica/core/periodic_structure.py +++ b/src/pylattica/core/periodic_structure.py @@ -113,7 +113,7 @@ def __init__(self, lattice: Lattice): self.lattice = lattice self.dim = lattice.dim self._sites = {} - self.site_ids = [] + self._site_ids = [] self._location_lookup = {} self._offset_vector = np.array([VEC_OFFSET for _ in range(self.dim)]) @@ -127,6 +127,10 @@ def as_dict(self): "_sites": copied, } + @property + def site_ids(self): + return copy.copy(self._site_ids) + @classmethod def from_dict(cls, d): struct = cls(Lattice.from_dict(d["lattice"])) @@ -182,7 +186,7 @@ def add_site(self, site_class: str, location: Tuple[float]) -> int: } self._location_lookup[offset_periodized_coords] = new_site_id - self.site_ids.append(new_site_id) + self._site_ids.append(new_site_id) return new_site_id def site_at(self, location: Tuple[float]) -> Dict: diff --git a/src/pylattica/core/runner/asynchronous_runner.py b/src/pylattica/core/runner/asynchronous_runner.py index fbfe19a..e7f985f 100644 --- a/src/pylattica/core/runner/asynchronous_runner.py +++ b/src/pylattica/core/runner/asynchronous_runner.py @@ -86,6 +86,7 @@ def _add_sites_to_queue(): state_updates = merge_updates(state_updates, site_id=site_id) live_state.batch_update(state_updates) site_queue.extend(next_sites) + result.add_step(state_updates) if len(site_queue) == 0: diff --git a/src/pylattica/core/simulation_result.py b/src/pylattica/core/simulation_result.py index 44d3440..ed28d53 100644 --- a/src/pylattica/core/simulation_result.py +++ b/src/pylattica/core/simulation_result.py @@ -5,6 +5,9 @@ from monty.serialization import dumpfn, loadfn import datetime from .simulation_state import SimulationState +from .constants import GENERAL, SITES + +import copy class SimulationResult: @@ -23,13 +26,21 @@ def from_file(cls, fpath): @classmethod def from_dict(cls, res_dict): diffs = res_dict["diffs"] - res = cls(SimulationState.from_dict(res_dict["initial_state"])) + compress_freq = res_dict.get("compress_freq", 1) + res = cls( + SimulationState.from_dict(res_dict["initial_state"]), + compress_freq=compress_freq, + ) for diff in diffs: - formatted = {int(k): v for k, v in diff.items() if k != "GENERAL"} - res.add_step(formatted) + if SITES in diff: + diff[SITES] = {int(k): v for k, v in diff[SITES].items()} + if GENERAL not in diff and SITES not in diff: + diff = {int(k): v for k, v in diff.items()} + res.add_step(diff) + return res - def __init__(self, starting_state: SimulationState): + def __init__(self, starting_state: SimulationState, compress_freq: int = 1): """Initializes a SimulationResult with the specified starting_state. Parameters @@ -38,6 +49,7 @@ def __init__(self, starting_state: SimulationState): The state with which the simulation started. """ self.initial_state = starting_state + self.compress_freq = compress_freq self._diffs: list[dict] = [] self._stored_states = {} @@ -60,6 +72,10 @@ def add_step(self, updates: Dict[int, Dict]) -> None: """ self._diffs.append(updates) + @property + def original_length(self) -> int: + return int(len(self) * self.compress_freq) + def __len__(self) -> int: return len(self._diffs) + 1 @@ -120,6 +136,9 @@ def get_step(self, step_no) -> SimulationState: The simulation state at the requested step. """ + # if step_no % self.compress_freq != 0: + # raise ValueError(f"Cannot retrieve step no {step_no} because this result has been compressed with sampling frequency {self.compress_freq}") + stored = self._stored_states.get(step_no) if stored is not None: return stored @@ -133,6 +152,7 @@ def as_dict(self): return { "initial_state": self.initial_state.as_dict(), "diffs": self._diffs, + "compress_freq": self.compress_freq, "@module": self.__class__.__module__, "@class": self.__class__.__name__, } @@ -152,3 +172,32 @@ def to_file(self, fpath: str = None) -> None: dumpfn(self, fpath) return fpath + + +def compress_result(result: SimulationResult, num_steps: int): + i_state = result.first_step + # total steps is the actual number of diffs stored, not the number of original simulation steps taken + total_steps = len(result) + if num_steps >= total_steps: + raise ValueError( + f"Cannot upsample SimulationResult of length {total_steps} to size {num_steps}." + ) + + exact_sample_freq = total_steps / (num_steps) + # print(total_steps, current_sample_freq) + total_compress_freq = exact_sample_freq * result.compress_freq + compressed_result = SimulationResult(i_state, compress_freq=total_compress_freq) + + live_state = SimulationState(copy.deepcopy(i_state._state)) + added = 0 + next_sample_step = exact_sample_freq + for i, diff in enumerate(result._diffs): + curr_step = i + 1 + live_state.batch_update(diff) + # if curr_step % current_sample_freq == 0: + if curr_step > next_sample_step: + # print(curr_step) + added += 1 + compressed_result.add_step(live_state.as_state_update()) + next_sample_step += exact_sample_freq + return compressed_result diff --git a/src/pylattica/core/simulation_state.py b/src/pylattica/core/simulation_state.py index b6a92ab..d94626b 100644 --- a/src/pylattica/core/simulation_state.py +++ b/src/pylattica/core/simulation_state.py @@ -100,7 +100,7 @@ def get_site_state(self, site_id: int) -> Dict: """ return self._state[SITES].get(site_id) - def get_general_state(self) -> Dict: + def get_general_state(self, key: str = None, default=None) -> Dict: """Returns the general state. Returns @@ -108,7 +108,10 @@ def get_general_state(self) -> Dict: Dict The general state. """ - return self._state.get(GENERAL) + if key is None: + return copy.deepcopy(self._state.get(GENERAL)) + else: + return copy.deepcopy(self._state.get(GENERAL)).get(key, default) def set_general_state(self, updates: Dict) -> None: """Updates the general state with the keys and values provided by the updates parameter. @@ -175,5 +178,8 @@ def copy(self) -> SimulationState: """ return SimulationState(self._state) + def as_state_update(self) -> Dict: + return copy.deepcopy(self._state) + def __eq__(self, other: SimulationState) -> bool: return self._state == other._state diff --git a/src/pylattica/models/game_of_life/__init__.py b/src/pylattica/models/game_of_life/__init__.py index 06f2f55..a21bc66 100644 --- a/src/pylattica/models/game_of_life/__init__.py +++ b/src/pylattica/models/game_of_life/__init__.py @@ -1 +1,2 @@ from .controller import GameOfLifeController, Life, Seeds, Anneal, Diamoeba, Maze +from .life_phase_set import LIFE_PHASE_SET \ No newline at end of file diff --git a/src/pylattica/models/game_of_life/life_phase_set.py b/src/pylattica/models/game_of_life/life_phase_set.py new file mode 100644 index 0000000..0c9e9ee --- /dev/null +++ b/src/pylattica/models/game_of_life/life_phase_set.py @@ -0,0 +1,3 @@ +from ...discrete.phase_set import PhaseSet + +LIFE_PHASE_SET = PhaseSet(["alive", "dead"]) \ No newline at end of file diff --git a/src/pylattica/structures/square_grid/grid_setup.py b/src/pylattica/structures/square_grid/grid_setup.py index 6249596..9404306 100644 --- a/src/pylattica/structures/square_grid/grid_setup.py +++ b/src/pylattica/structures/square_grid/grid_setup.py @@ -348,9 +348,8 @@ def setup_random_sites( while num_sites_planted < num_sites_desired: if total_attempts > 1000 * num_sites_desired: - raise RuntimeError( - f"Too many nucleation sites at the specified buffer: {total_attempts} made at placing nuclei" - ) + print(f"Only able to place {num_sites_planted} in {total_attempts} attempts") + break rand_site = random.choice(all_sites) rand_site_id = rand_site[SITE_ID] diff --git a/src/pylattica/visualization/result_artist.py b/src/pylattica/visualization/result_artist.py index 52c3c33..0ed2451 100644 --- a/src/pylattica/visualization/result_artist.py +++ b/src/pylattica/visualization/result_artist.py @@ -8,13 +8,24 @@ from PIL import Image +from typing import Callable + _dsr_globals = {} +def default_annotation_builder(step, step_no): + return f"Step {step_no}" + + class ResultArtist: """A class for rendering simulation results as animated GIFs.""" - def __init__(self, step_artist: StructureArtist, result: SimulationResult): + def __init__( + self, + step_artist: StructureArtist, + result: SimulationResult, + annotation_builder: Callable = default_annotation_builder, + ): """Instantiates the ResultArtist class. Parameters @@ -26,6 +37,7 @@ def __init__(self, step_artist: StructureArtist, result: SimulationResult): """ self._step_artist = step_artist self.result = result + self.annotation_builder = annotation_builder def _get_images(self, **kwargs): draw_freq = kwargs.get("draw_freq", 1) @@ -47,9 +59,10 @@ def _get_images(self, **kwargs): with mp.get_context("fork").Pool(PROCESSES) as pool: params = [] for idx in indices: - label = f"Step {idx}" - step_kwargs = {**kwargs, "label": label} step = self.result.get_step(idx) + label = self.annotation_builder(step, idx) + step_kwargs = {**kwargs, "label": label} + params.append([step, step_kwargs]) for img in pool.starmap(_get_img_parallel, params): @@ -57,11 +70,7 @@ def _get_images(self, **kwargs): return imgs - def jupyter_show_step( - self, - step_no: int, - cell_size=20, - ) -> None: + def jupyter_show_step(self, step_no: int, cell_size=20, **kwargs) -> None: """In a jupyter notebook environment, visualizes the step as a color coded phase grid. Parameters @@ -71,17 +80,13 @@ def jupyter_show_step( cell_size : int, optional The size of each simulation cell, in pixels, by default 20 """ - label = f"Step {step_no}" # pragma: no cover step = self.result.get_step(step_no) # pragma: no cover + label = self.annotation_builder(step, step_no) self._step_artist.jupyter_show( - step, label=label, cell_size=cell_size + step, label=label, cell_size=cell_size, **kwargs ) # pragma: no cover - def jupyter_play( - self, - cell_size: int = 20, - wait: int = 1, - ): + def jupyter_play(self, cell_size: int = 20, wait: int = 1, **kwargs): """In a jupyter notebook environment, plays the simulation visualization back by showing a series of images with {wait} seconds between each one. @@ -94,7 +99,7 @@ def jupyter_play( """ from IPython.display import clear_output, display # pragma: no cover - imgs = self._get_images(cell_size=cell_size) # pragma: no cover + imgs = self._get_images(cell_size=cell_size, **kwargs) # pragma: no cover for img in imgs: # pragma: no cover clear_output() # pragma: no cover display(img) # pragma: no cover diff --git a/src/pylattica/visualization/square_grid_artist_2D.py b/src/pylattica/visualization/square_grid_artist_2D.py index 016779e..e47843c 100644 --- a/src/pylattica/visualization/square_grid_artist_2D.py +++ b/src/pylattica/visualization/square_grid_artist_2D.py @@ -13,16 +13,27 @@ def _draw_image(self, state: SimulationState, **kwargs): label = kwargs.get("label", None) cell_size = kwargs.get("cell_size", 20) + show_legend = kwargs.get("show_legend", True) + legend = self.cell_artist.get_legend(state) legend_order = sorted(legend.keys()) state_size = int(self.structure.lattice.vec_lengths[0]) - width = state_size + 6 - legend_border_width = 5 - height = max(state_size, len(legend) + 1) + if show_legend: + width = state_size + 6 + legend_border_width = 5 + height = max(state_size, len(legend) + 1) + img_width = width * cell_size + legend_border_width + img_height = height * cell_size + else: + width = state_size + height = state_size + img_width = width * cell_size + img_height = height * cell_size + img = Image.new( "RGB", - (width * cell_size + legend_border_width, height * cell_size), + (img_width, img_height), "black", ) # Create a new black image @@ -39,29 +50,30 @@ def _draw_image(self, state: SimulationState, **kwargs): for p_y in range(p_y_start, p_y_start + cell_size): pixels[p_x, p_y] = cell_color - count = 0 - legend_hoffset = int(cell_size / 4) - legend_voffset = int(cell_size / 4) - - for p_y in range(height * cell_size): - for p_x in range(0, legend_border_width): - x = state_size * cell_size + p_x - pixels[x, p_y] = (255, 255, 255) - - for phase in legend_order: - color = legend.get(phase) - p_col_start = state_size * cell_size + legend_border_width + legend_hoffset - p_row_start = count * cell_size + legend_voffset - for p_x in range(p_col_start, p_col_start + cell_size): - for p_y in range(p_row_start, p_row_start + cell_size): - pixels[p_x, p_y] = color - - legend_label_loc = ( - int(p_col_start + cell_size + cell_size / 4), - int(p_row_start + cell_size / 4), - ) - draw.text(legend_label_loc, phase, (255, 255, 255)) - count += 1 + if show_legend: + count = 0 + legend_hoffset = int(cell_size / 4) + legend_voffset = int(cell_size / 4) + + for p_y in range(height * cell_size): + for p_x in range(0, legend_border_width): + x = state_size * cell_size + p_x + pixels[x, p_y] = (255, 255, 255) + + for phase in legend_order: + color = legend.get(phase) + p_col_start = state_size * cell_size + legend_border_width + legend_hoffset + p_row_start = count * cell_size + legend_voffset + for p_x in range(p_col_start, p_col_start + cell_size): + for p_y in range(p_row_start, p_row_start + cell_size): + pixels[p_x, p_y] = color + + legend_label_loc = ( + int(p_col_start + cell_size + cell_size / 4), + int(p_row_start + cell_size / 4), + ) + draw.text(legend_label_loc, phase, (255, 255, 255)) + count += 1 if label is not None: draw.text((5, 5), label, (255, 255, 255)) diff --git a/src/pylattica/visualization/square_grid_artist_3D.py b/src/pylattica/visualization/square_grid_artist_3D.py index 408c375..fcd8981 100644 --- a/src/pylattica/visualization/square_grid_artist_3D.py +++ b/src/pylattica/visualization/square_grid_artist_3D.py @@ -6,6 +6,7 @@ import io import matplotlib.pyplot as plt from PIL import Image +from matplotlib.lines import Line2D class SquareGridArtist3D(StructureArtist): @@ -46,11 +47,50 @@ def _draw_image(self, state: SimulationState, **kwargs): colors = [0.8, 0.8, 0.8, 0.2] ax.voxels(data, facecolors=colors, edgecolor="k", linewidth=0) else: - colors = np.array(color_cache[color]) / 255 + colors = list(np.array(color_cache[color]) / 255) ax.voxels(data, facecolors=colors, edgecolor="k", linewidth=0.25) - ax.legend() + if kwargs.get("show_legend") == True: + legend = self.cell_artist.get_legend(state) + legend_handles = [] + for phase, color in legend.items(): + legend_handles.append( + Line2D( + [0], + [0], + marker="s", + color="w", + markerfacecolor=list(np.array(color) / 255), + markersize=10, + label=phase, + ) + ) + + # Add custom legend to the plot + legend_font_props = {"family": "Lato", "size": 14} + + plt.legend( + handles=legend_handles, + loc="lower center", + prop=legend_font_props, + ncols=5, + frameon=False, + ) plt.axis("off") + if kwargs.get("label") is not None: + x_text, y_text, z_text = 18, -5, 30 + + # Add the text + annotation_font = { + "size": 16, + "family": "Lato", + "color": np.array([194, 29, 63]) / 255, + "weight": "bold", + } + ax.text( + x_text, y_text, z_text, kwargs.get("label"), fontdict=annotation_font + ) + fig = ax.get_figure() buf = io.BytesIO() fig.savefig(buf) diff --git a/tests/atomic/conftest.py b/tests/atomic/conftest.py index c8454c3..ee05ac8 100644 --- a/tests/atomic/conftest.py +++ b/tests/atomic/conftest.py @@ -4,21 +4,54 @@ from pylattica.core import Lattice, StructureBuilder + @pytest.fixture() def zr_pmg_struct(): - Zr = {"@module": "pymatgen.core.structure", "@class": "Structure", "charge": 0.0, "lattice": {"matrix": [[3.23923141, 0.0, 1.9834571889770337e-16], [-1.6196157049999995, 2.8052566897964866, 1.9834571889770337e-16], [0.0, 0.0, 5.17222]], "pbc": [True, True, True], "a": 3.23923141, "b": 3.2392314099999995, "c": 5.17222, "alpha": 90.0, "beta": 90.0, "gamma": 119.99999999999999, "volume": 46.99931962635987}, "properties": {}, "sites": [{"species": [{"element": "Zr", "occu": 1.0}], "abc": [0.3333333333333333, 0.6666666666666666, 0.25], "xyz": [3.5599729479122526e-16, 1.870171126530991, 1.2930550000000003], "properties": {}, "label": "Zr0"}, {"species": [{"element": "Zr", "occu": 1.0}], "abc": [0.6666666666666666, 0.3333333333333333, 0.75], "xyz": [1.6196157050000002, 0.9350855632654955, 3.8791650000000004], "properties": {}, "label": "Zr1"}]} + Zr = { + "@module": "pymatgen.core.structure", + "@class": "Structure", + "charge": 0.0, + "lattice": { + "matrix": [ + [3.23923141, 0.0, 1.9834571889770337e-16], + [-1.6196157049999995, 2.8052566897964866, 1.9834571889770337e-16], + [0.0, 0.0, 5.17222], + ], + "pbc": [True, True, True], + "a": 3.23923141, + "b": 3.2392314099999995, + "c": 5.17222, + "alpha": 90.0, + "beta": 90.0, + "gamma": 119.99999999999999, + "volume": 46.99931962635987, + }, + "properties": {}, + "sites": [ + { + "species": [{"element": "Zr", "occu": 1.0}], + "abc": [0.3333333333333333, 0.6666666666666666, 0.25], + "xyz": [3.5599729479122526e-16, 1.870171126530991, 1.2930550000000003], + "properties": {}, + "label": "Zr0", + }, + { + "species": [{"element": "Zr", "occu": 1.0}], + "abc": [0.6666666666666666, 0.3333333333333333, 0.75], + "xyz": [1.6196157050000002, 0.9350855632654955, 3.8791650000000004], + "properties": {}, + "label": "Zr1", + }, + ], + } return Structure.from_dict(Zr) + @pytest.fixture() def pyl_struct(): + lat = Lattice([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - lat = Lattice([[1,0,0], [0,1,0], [0,0,1]]) - - motif = { - "Zr": [[ - 0.5, 0.5, 0.5 - ]] - } + motif = {"Zr": [[0.5, 0.5, 0.5]]} builder = StructureBuilder(lat, motif) - return builder.build((1, 2, 3)) \ No newline at end of file + return builder.build((1, 2, 3)) diff --git a/tests/atomic/test_convert_pymatgen_structure.py b/tests/atomic/test_convert_pymatgen_structure.py index b41de24..8f30e7c 100644 --- a/tests/atomic/test_convert_pymatgen_structure.py +++ b/tests/atomic/test_convert_pymatgen_structure.py @@ -9,6 +9,7 @@ import numpy as np + def test_can_convert_lattice(zr_pmg_struct: PmgStructure): converter = PymatgenStructureConverter() @@ -22,6 +23,7 @@ def test_can_convert_lattice(zr_pmg_struct: PmgStructure): assert np.allclose(lattice.vecs[1], zr_pmg_struct.lattice.matrix[1]) assert np.allclose(lattice.vecs[2], zr_pmg_struct.lattice.matrix[2]) + def test_can_convert_pmg_struct_to_pyl_struct_builder(zr_pmg_struct: PmgStructure): converter = PymatgenStructureConverter() @@ -37,13 +39,22 @@ def test_can_convert_pmg_struct_to_pyl_struct_builder(zr_pmg_struct: PmgStructur tol = OFFSET_PRECISION # brittle - relies on the order in which sites are enumerated - assert np.allclose(struct.site_location(0),zr_pmg_struct.sites[0].coords, atol=OFFSET_PRECISION) - assert np.allclose(struct.site_location(1),zr_pmg_struct.sites[1].coords, atol=OFFSET_PRECISION) + assert np.allclose( + struct.site_location(0), zr_pmg_struct.sites[0].coords, atol=OFFSET_PRECISION + ) + assert np.allclose( + struct.site_location(1), zr_pmg_struct.sites[1].coords, atol=OFFSET_PRECISION + ) + + assert np.allclose( + struct.lattice.get_fractional_coords(struct.site_location(1)), + zr_pmg_struct.sites[1].frac_coords, + atol=OFFSET_PRECISION, + ) - assert np.allclose(struct.lattice.get_fractional_coords(struct.site_location(1)), zr_pmg_struct.sites[1].frac_coords, atol=OFFSET_PRECISION) def test_can_convert_pmg_struct_to_pyl_struct_and_state(zr_pmg_struct: PmgStructure): - zr_pmg_struct.make_supercell((2,2,2)) + zr_pmg_struct.make_supercell((2, 2, 2)) assert zr_pmg_struct.num_sites == 16 converter = PymatgenStructureConverter() @@ -56,6 +67,7 @@ def test_can_convert_pmg_struct_to_pyl_struct_and_state(zr_pmg_struct: PmgStruct assert pyl_state.get_site_state(sid) is not None assert pyl_state.get_site_state(sid)[DISCRETE_OCCUPANCY] == site.species_string + def test_can_convert_pyl_lat(pyl_struct: PeriodicStructure): converter = PymatgenStructureConverter() @@ -68,6 +80,7 @@ def test_can_convert_pyl_lat(pyl_struct: PeriodicStructure): assert np.allclose(pyl_struct.lattice.vecs[1], pmg_lat.matrix[1]) assert np.allclose(pyl_struct.lattice.vecs[2], pmg_lat.matrix[2]) + def test_can_convert_pyl_struct(pyl_struct: PeriodicStructure): converter = PymatgenStructureConverter() @@ -80,4 +93,7 @@ def test_can_convert_pyl_struct(pyl_struct: PeriodicStructure): assert matching_site is not None assert matching_site[SITE_CLASS] == site.species_string - assert np.all(pyl_struct.lattice.get_fractional_coords(matching_site[LOCATION]) == site.frac_coords) \ No newline at end of file + assert np.all( + pyl_struct.lattice.get_fractional_coords(matching_site[LOCATION]) + == site.frac_coords + ) diff --git a/tests/conftest.py b/tests/conftest.py index b80aa4a..eac13fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,31 +3,41 @@ from pylattica.discrete import PhaseSet from pylattica.structures.square_grid import DiscreteGridSetup -from pylattica.structures.square_grid import SimpleSquare2DStructureBuilder, DiscreteGridSetup +from pylattica.structures.square_grid import ( + SimpleSquare2DStructureBuilder, + DiscreteGridSetup, +) from pylattica.discrete import PhaseSet from pylattica.core import Lattice + @pytest.fixture(scope="module") def square_grid_2D_2x2(): return SimpleSquare2DStructureBuilder().build(2) + @pytest.fixture(scope="module") def square_grid_2D_4x4(): return SimpleSquare2DStructureBuilder().build(4) + @pytest.fixture(scope="module") def square_lattice(): - return Lattice([ - [0, 0, 1], - [0, 1, 0], - [1, 0, 0], - ]) + return Lattice( + [ + [0, 0, 1], + [0, 1, 0], + [1, 0, 0], + ] + ) + @pytest.fixture(scope="module") def simple_phase_set(): return PhaseSet(["A", "B", "C", "D"]) + @pytest.fixture(scope="module") def grid_setup(simple_phase_set): - return DiscreteGridSetup(simple_phase_set) \ No newline at end of file + return DiscreteGridSetup(simple_phase_set) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 4be910e..07e1fac 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -3,25 +3,22 @@ from pylattica.core import Lattice from pylattica.core import PeriodicStructure + @pytest.fixture(scope="module") def square_2D_basis_vecs(): - return [ - (1, 0), - (0, 1) - ] + return [(1, 0), (0, 1)] + @pytest.fixture(scope="module") -def simple_motif(): - return { - "A": [ - (0.5, 0.5) - ] - } +def simple_motif(): + return {"A": [(0.5, 0.5)]} + @pytest.fixture(scope="module") def square_2D_lattice(square_2D_basis_vecs): - return Lattice(square_2D_basis_vecs) + return Lattice(square_2D_basis_vecs) + @pytest.fixture(scope="module") def square_2x2_2D_grid_in_test(square_2D_lattice: Lattice, simple_motif: Dict): - return PeriodicStructure.build_from(square_2D_lattice, [2, 2], simple_motif) \ No newline at end of file + return PeriodicStructure.build_from(square_2D_lattice, [2, 2], simple_motif) diff --git a/tests/core/test_analyzer.py b/tests/core/test_analyzer.py index 4089ec9..4153344 100644 --- a/tests/core/test_analyzer.py +++ b/tests/core/test_analyzer.py @@ -6,82 +6,82 @@ import pytest + def test_analyze_get_sites_arb_criteria(square_grid_2D_4x4: PeriodicStructure): state = SimulationState() for idx, site in enumerate(square_grid_2D_4x4.sites()): - state.set_site_state(site[SITE_ID], { "trait": idx }) - + state.set_site_state(site[SITE_ID], {"trait": idx}) + analyzer = StateAnalyzer(square_grid_2D_4x4) def _criteria_1(state: Dict) -> bool: - return state.get('trait') >= 12 + return state.get("trait") >= 12 def _criteria_2(state: Dict) -> bool: - return state.get('trait') <= 15 + return state.get("trait") <= 15 sites = analyzer.get_sites(state, state_criteria=[_criteria_1, _criteria_2]) assert len(sites) == 4 + def test_analyze_get_sites_no_criteria(square_grid_2D_4x4: PeriodicStructure): state = SimulationState() for idx, site in enumerate(square_grid_2D_4x4.sites()): - state.set_site_state(site[SITE_ID], { "trait": idx }) - - analyzer = StateAnalyzer(square_grid_2D_4x4) + state.set_site_state(site[SITE_ID], {"trait": idx}) + + analyzer = StateAnalyzer(square_grid_2D_4x4) sites = analyzer.get_sites(state) - assert len(sites) == len(state.all_site_states()) - + assert len(sites) == len(state.all_site_states()) + def test_analyze_count_equal(grid_setup: DiscreteGridSetup): - periodic_state = grid_setup.setup_interface(4, 'A', 'B') + periodic_state = grid_setup.setup_interface(4, "A", "B") analyzer = StateAnalyzer(periodic_state.structure) - assert analyzer.get_site_count_where_equal(periodic_state.state, - { - DISCRETE_OCCUPANCY: "A" - } - ) == 8 + assert ( + analyzer.get_site_count_where_equal( + periodic_state.state, {DISCRETE_OCCUPANCY: "A"} + ) + == 8 + ) - assert analyzer.get_site_count_where_equal(periodic_state.state, - { - DISCRETE_OCCUPANCY: "B" - } - ) == 8 + assert ( + analyzer.get_site_count_where_equal( + periodic_state.state, {DISCRETE_OCCUPANCY: "B"} + ) + == 8 + ) + + assert ( + analyzer.get_site_count_where_equal( + periodic_state.state, {DISCRETE_OCCUPANCY: "C"} + ) + == 0 + ) - assert analyzer.get_site_count_where_equal(periodic_state.state, - { - DISCRETE_OCCUPANCY: "C" - } - ) == 0 def test_analyze_get_sites_where_equal(grid_setup: DiscreteGridSetup): - periodic_state = grid_setup.setup_interface(4, 'A', 'B') + periodic_state = grid_setup.setup_interface(4, "A", "B") analyzer = StateAnalyzer(periodic_state.structure) - a_site_ids = analyzer.get_sites_where_equal(periodic_state.state, - { - DISCRETE_OCCUPANCY: "A" - } + a_site_ids = analyzer.get_sites_where_equal( + periodic_state.state, {DISCRETE_OCCUPANCY: "A"} ) for site_id in a_site_ids: reretrieved_site = periodic_state.state.get_site_state(site_id) assert reretrieved_site[DISCRETE_OCCUPANCY] == "A" - a_site_ids = analyzer.get_sites_where_equal(periodic_state.state, - { - DISCRETE_OCCUPANCY: "A" - } - ) + a_site_ids = analyzer.get_sites_where_equal( + periodic_state.state, {DISCRETE_OCCUPANCY: "A"} + ) - b_site_ids = analyzer.get_sites_where_equal(periodic_state.state, - { - DISCRETE_OCCUPANCY: "B" - } + b_site_ids = analyzer.get_sites_where_equal( + periodic_state.state, {DISCRETE_OCCUPANCY: "B"} ) for site_id in b_site_ids: reretrieved_site = periodic_state.state.get_site_state(site_id) - assert reretrieved_site[DISCRETE_OCCUPANCY] == "B" \ No newline at end of file + assert reretrieved_site[DISCRETE_OCCUPANCY] == "B" diff --git a/tests/core/test_async_runner.py b/tests/core/test_async_runner.py index fbe6db9..35d15fa 100644 --- a/tests/core/test_async_runner.py +++ b/tests/core/test_async_runner.py @@ -2,31 +2,27 @@ from pylattica.core import AsynchronousRunner, BasicController from pylattica.core.simulation_state import SimulationState +from pylattica.core.simulation_result import SimulationResult from pylattica.core.periodic_structure import PeriodicStructure -from pylattica.core.constants import SITE_ID +from pylattica.core.constants import SITE_ID, GENERAL import random + def test_simple_async_controller(square_grid_2D_4x4: PeriodicStructure): - class SimpleAsyncController(BasicController): - def get_state_update(self, site_id: int, prev_state: SimulationState): new_state = 3 - return { - site_id: { - "value": new_state - } - } - + return {site_id: {"value": new_state}} + initial_state = SimulationState() for site in square_grid_2D_4x4.sites(): - initial_state.set_site_state(site[SITE_ID], { "value": 0 }) - + initial_state.set_site_state(site[SITE_ID], {"value": 0}) + runner = AsynchronousRunner() controller = SimpleAsyncController() - result = runner.run(initial_state, controller=controller, num_steps = 10) + result = runner.run(initial_state, controller=controller, num_steps=10) last_step = result.last_step @@ -34,30 +30,25 @@ def get_state_update(self, site_id: int, prev_state: SimulationState): for site_state in last_step.all_site_states(): if site_state["value"] == 3: num_converted += 1 - + assert num_converted > 0 assert num_converted <= 10 + def test_simple_async_controller_async_flag(square_grid_2D_4x4: PeriodicStructure): - class SimpleAsyncController(BasicController): - def get_state_update(self, site_id: int, prev_state: SimulationState): new_state = 3 - return { - site_id: { - "value": new_state - } - } - + return {site_id: {"value": new_state}} + initial_state = SimulationState() for site in square_grid_2D_4x4.sites(): - initial_state.set_site_state(site[SITE_ID], { "value": 0 }) - + initial_state.set_site_state(site[SITE_ID], {"value": 0}) + runner = AsynchronousRunner() controller = SimpleAsyncController() - result = runner.run(initial_state, controller=controller, num_steps = 10) + result = runner.run(initial_state, controller=controller, num_steps=10) last_step = result.last_step @@ -65,34 +56,62 @@ def get_state_update(self, site_id: int, prev_state: SimulationState): for site_state in last_step.all_site_states(): if site_state["value"] == 3: num_converted += 1 - + assert num_converted > 0 assert num_converted <= 10 -def test_async_controller_next_sites(square_grid_2D_4x4): +def test_general_state_async_controller(square_grid_2D_4x4: PeriodicStructure): + class GeneralStateAsyncController(BasicController): + def get_state_update(self, site_id: int, prev_state: SimulationState): + prev = prev_state.get_general_state("value") + if prev is not None: + nxt = prev + 1 + else: + nxt = 1 + return {GENERAL: {"value": nxt}} + + initial_state = SimulationState() + initial_state.set_site_state(1, {"value": 0}) + + runner = AsynchronousRunner() + + controller = GeneralStateAsyncController() + result = runner.run(initial_state, controller=controller, num_steps=10) + + for i in range(1, 10): + state = result.get_step(i) + gen = state.get_general_state("value") + assert gen == i + + fname = "tmp.json" + result.to_file(fname) + rehyd = SimulationResult.from_file(fname) + for i in range(1, 10): + state = rehyd.get_step(i) + gen = state.get_general_state("value") + assert gen == i + + +def test_async_controller_next_sites(square_grid_2D_4x4): CHANGED_STATE = 3 - class NextSiteAsyncController(BasicController): + class NextSiteAsyncController(BasicController): def get_random_site(self, _): return 0 def get_state_update(self, site_id: int, _): - return { - site_id: { - "value": CHANGED_STATE - } - }, [site_id + 1] - + return {site_id: {"value": CHANGED_STATE}}, [site_id + 1] + initial_state = SimulationState() for site in square_grid_2D_4x4.sites(): - initial_state.set_site_state(site[SITE_ID], { "value": 0 }) - + initial_state.set_site_state(site[SITE_ID], {"value": 0}) + runner = AsynchronousRunner() controller = NextSiteAsyncController() num_steps = 4 - result = runner.run(initial_state, controller=controller, num_steps = num_steps) + result = runner.run(initial_state, controller=controller, num_steps=num_steps) last_step = result.last_step @@ -102,31 +121,27 @@ def get_state_update(self, site_id: int, _): else: assert site_state["value"] == 0 -def test_async_controller_random_site_list(square_grid_2D_4x4): +def test_async_controller_random_site_list(square_grid_2D_4x4): CHANGED_STATE = 3 chosen_sites = [0, 3, 5, 9] - class NextSiteAsyncController(BasicController): + class NextSiteAsyncController(BasicController): def get_random_site(self, _): return chosen_sites def get_state_update(self, site_id: int, _): - return { - site_id: { - "value": CHANGED_STATE - } - } - + return {site_id: {"value": CHANGED_STATE}} + initial_state = SimulationState() for site in square_grid_2D_4x4.sites(): - initial_state.set_site_state(site[SITE_ID], { "value": 0 }) - + initial_state.set_site_state(site[SITE_ID], {"value": 0}) + runner = AsynchronousRunner() controller = NextSiteAsyncController() num_steps = 10 - result = runner.run(initial_state, controller=controller, num_steps = num_steps) + result = runner.run(initial_state, controller=controller, num_steps=num_steps) last_step = result.last_step @@ -136,40 +151,36 @@ def get_state_update(self, site_id: int, _): else: assert site_state["value"] == 0 -def test_async_controller_empty_random_site_list(square_grid_2D_4x4): +def test_async_controller_empty_random_site_list(square_grid_2D_4x4): CHANGED_STATE = 3 chosen_sites = [] - class NextSiteAsyncController(BasicController): + class NextSiteAsyncController(BasicController): def get_random_site(self, _): return chosen_sites def get_state_update(self, site_id: int, _): - return { - site_id: { - "value": CHANGED_STATE - } - } - + return {site_id: {"value": CHANGED_STATE}} + initial_state = SimulationState() for site in square_grid_2D_4x4.sites(): - initial_state.set_site_state(site[SITE_ID], { "value": 0 }) - + initial_state.set_site_state(site[SITE_ID], {"value": 0}) + runner = AsynchronousRunner() controller = NextSiteAsyncController() num_steps = 10 with pytest.raises(RuntimeError, match="Controller provided"): - result = runner.run(initial_state, controller=controller, num_steps = num_steps) + result = runner.run(initial_state, controller=controller, num_steps=num_steps) -def test_async_controller_limited_sites(square_grid_2D_4x4): +def test_async_controller_limited_sites(square_grid_2D_4x4): CHANGED_STATE = 3 chosen_sites = [0, 3, 5, 9] original_chosen_sites = chosen_sites.copy() - class NextSiteAsyncController(BasicController): + class NextSiteAsyncController(BasicController): steps_taken = 0 def get_random_site(self, _): @@ -180,21 +191,17 @@ def get_random_site(self, _): def get_state_update(self, site_id: int, _): self.steps_taken += 1 - return { - site_id: { - "value": CHANGED_STATE - } - } - + return {site_id: {"value": CHANGED_STATE}} + initial_state = SimulationState() for site in square_grid_2D_4x4.sites(): - initial_state.set_site_state(site[SITE_ID], { "value": 0 }) - + initial_state.set_site_state(site[SITE_ID], {"value": 0}) + runner = AsynchronousRunner() controller = NextSiteAsyncController() num_steps = 10 - result = runner.run(initial_state, controller=controller, num_steps = num_steps) + result = runner.run(initial_state, controller=controller, num_steps=num_steps) assert controller.steps_taken == 4 last_step = result.last_step @@ -202,4 +209,6 @@ def get_state_update(self, site_id: int, _): if site_state[SITE_ID] in original_chosen_sites: assert site_state["value"] == CHANGED_STATE else: - assert site_state["value"] == 0, f'{site_state[SITE_ID]} had the wrong state!' \ No newline at end of file + assert ( + site_state["value"] == 0 + ), f"{site_state[SITE_ID]} had the wrong state!" diff --git a/tests/core/test_basic_controller.py b/tests/core/test_basic_controller.py index 7b2ba1c..86f3bbc 100644 --- a/tests/core/test_basic_controller.py +++ b/tests/core/test_basic_controller.py @@ -4,14 +4,13 @@ from pylattica.core.periodic_structure import PeriodicStructure from pylattica.core.simulation_state import SimulationState -def test_simple_controller(square_grid_2D_4x4: PeriodicStructure): +def test_simple_controller(square_grid_2D_4x4: PeriodicStructure): class SimpleController(BasicController): - def get_state_update(self, site_id: int, prev_state: SimulationState): return {} sc = SimpleController() state = SimulationState.from_struct(square_grid_2D_4x4) - assert type(sc.get_random_site(state)) == int \ No newline at end of file + assert type(sc.get_random_site(state)) == int diff --git a/tests/core/test_coordinate_utils.py b/tests/core/test_coordinate_utils.py index 108d6d2..74b2059 100644 --- a/tests/core/test_coordinate_utils.py +++ b/tests/core/test_coordinate_utils.py @@ -2,11 +2,12 @@ from pylattica.core.coordinate_utils import get_points_in_box + def test_get_points_in_box(): - lbs = (0,0) + lbs = (0, 0) ubs = (2, 2) pts = get_points_in_box(lbs, ubs) assert len(pts) == 4 - assert (0,0) in pts \ No newline at end of file + assert (0, 0) in pts diff --git a/tests/core/test_distance_map.py b/tests/core/test_distance_map.py index 887576a..73b4c2b 100644 --- a/tests/core/test_distance_map.py +++ b/tests/core/test_distance_map.py @@ -1,24 +1,17 @@ from pylattica.core.distance_map import EuclideanDistanceMap, ManhattanDistanceMap + def test_distance_map_basics(): - dmap = EuclideanDistanceMap([ - (1, 1), - (0, 1) - ]) + dmap = EuclideanDistanceMap([(1, 1), (0, 1)]) assert dmap.get_dist((1, 1)) == 1.41 assert dmap.get_dist((0, 1)) == 1 def test_manhattan_distance_map(): - dmap = ManhattanDistanceMap([ - (1,1), - (0,1), - (0,-1), - (2.5,1.7) - ]) + dmap = ManhattanDistanceMap([(1, 1), (0, 1), (0, -1), (2.5, 1.7)]) - assert dmap.get_dist((1,1)) == 2 - assert dmap.get_dist((0,1)) == 1 - assert dmap.get_dist((0,-1)) == 1 - assert dmap.get_dist((2.5,1.7)) == 4.2 \ No newline at end of file + assert dmap.get_dist((1, 1)) == 2 + assert dmap.get_dist((0, 1)) == 1 + assert dmap.get_dist((0, -1)) == 1 + assert dmap.get_dist((2.5, 1.7)) == 4.2 diff --git a/tests/core/test_lattice.py b/tests/core/test_lattice.py index 5cfea92..9694750 100644 --- a/tests/core/test_lattice.py +++ b/tests/core/test_lattice.py @@ -3,23 +3,28 @@ import numpy as np from pylattica.core import Lattice, PeriodicStructure -from pylattica.core.lattice import periodize +from pylattica.core.lattice import periodize from pylattica.core.constants import SITE_ID + def assert_points_equal(pt1, pt2): assert (np.array(pt1) == np.array(pt2)).all() + def test_instantiate_lattice(square_2D_basis_vecs): assert Lattice(square_2D_basis_vecs) is not None + def test_basic_points(square_2D_basis_vecs): lat = Lattice(square_2D_basis_vecs) assert_points_equal(lat.get_cartesian_coords((1.5, 0.5)), (1.5, 0.5)) + def test_can_build_simple_2x2_grid(square_2D_lattice, simple_motif): structure = PeriodicStructure.build_from(square_2D_lattice, [2, 2], simple_motif) assert structure is not None + def test_simple_structure_has_correct_sites(square_2D_lattice, simple_motif): structure = PeriodicStructure.build_from(square_2D_lattice, [2, 2], simple_motif) @@ -35,71 +40,73 @@ def test_simple_structure_has_correct_sites(square_2D_lattice, simple_motif): assert len(set([s[SITE_ID] for s in [site_1, site_2, site_3, site_4]])) == 4 + def test_rectangular_lattice_point_conversions(): - lattice = Lattice( - [[1, 0], - [0, 1/2]] - ) + lattice = Lattice([[1, 0], [0, 1 / 2]]) # points - pt1 = (1/2, 1/4) - pt2 = (1/2, 3/4) - pt3 = (0, 1/2) + pt1 = (1 / 2, 1 / 4) + pt2 = (1 / 2, 3 / 4) + pt3 = (0, 1 / 2) # periodizing cartesian points - assert (np.array([1/2, 1/4]) == lattice.get_periodized_cartesian_coords(pt1)).all() - assert (np.array([1/2, 1/4]) == lattice.get_periodized_cartesian_coords(pt2)).all() + assert ( + np.array([1 / 2, 1 / 4]) == lattice.get_periodized_cartesian_coords(pt1) + ).all() + assert ( + np.array([1 / 2, 1 / 4]) == lattice.get_periodized_cartesian_coords(pt2) + ).all() assert (np.array([0, 0]) == lattice.get_periodized_cartesian_coords(pt3)).all() - + # conversion to fractional coords - assert (np.array([1/2, 1/2]) == lattice.get_fractional_coords(pt1)).all() - assert (np.array([1/2, 3/2]) == lattice.get_fractional_coords(pt2)).all() + assert (np.array([1 / 2, 1 / 2]) == lattice.get_fractional_coords(pt1)).all() + assert (np.array([1 / 2, 3 / 2]) == lattice.get_fractional_coords(pt2)).all() assert (np.array([0, 1]) == lattice.get_fractional_coords(pt3)).all() # conversion from fractional_coords coords - assert (np.array([1/2, 1/8]) == lattice.get_cartesian_coords(pt1)).all() - assert (np.array([1/2, 3/8]) == lattice.get_cartesian_coords(pt2)).all() - assert (np.array([0, 1/4]) == lattice.get_cartesian_coords(pt3)).all() + assert (np.array([1 / 2, 1 / 8]) == lattice.get_cartesian_coords(pt1)).all() + assert (np.array([1 / 2, 3 / 8]) == lattice.get_cartesian_coords(pt2)).all() + assert (np.array([0, 1 / 4]) == lattice.get_cartesian_coords(pt3)).all() + def test_scaled_rectangular_lattice_point_conversions(): - lattice = Lattice( - [[1, 0], - [0, 1/2]] - ) + lattice = Lattice([[1, 0], [0, 1 / 2]]) - lattice = lattice.get_scaled_lattice((2,1)) + lattice = lattice.get_scaled_lattice((2, 1)) # points - pt1 = (3/2, 1/4) - pt2 = (5/2, 3/4) - pt3 = (0, 1/2) - pt4 = (1,1/4) + pt1 = (3 / 2, 1 / 4) + pt2 = (5 / 2, 3 / 4) + pt3 = (0, 1 / 2) + pt4 = (1, 1 / 4) # periodizing cartesian points - assert (np.array([3/2, 1/4]) == lattice.get_periodized_cartesian_coords(pt1)).all() - assert (np.array([1/2, 1/4]) == lattice.get_periodized_cartesian_coords(pt2)).all() + assert ( + np.array([3 / 2, 1 / 4]) == lattice.get_periodized_cartesian_coords(pt1) + ).all() + assert ( + np.array([1 / 2, 1 / 4]) == lattice.get_periodized_cartesian_coords(pt2) + ).all() assert (np.array([0, 0]) == lattice.get_periodized_cartesian_coords(pt3)).all() - assert (np.array([1, 1/4]) == lattice.get_periodized_cartesian_coords(pt4)).all() - + assert (np.array([1, 1 / 4]) == lattice.get_periodized_cartesian_coords(pt4)).all() + # conversion to fractional coords - assert (np.array([3/4, 1/2]) == lattice.get_fractional_coords(pt1)).all() - assert (np.array([5/4, 3/2]) == lattice.get_fractional_coords(pt2)).all() + assert (np.array([3 / 4, 1 / 2]) == lattice.get_fractional_coords(pt1)).all() + assert (np.array([5 / 4, 3 / 2]) == lattice.get_fractional_coords(pt2)).all() assert (np.array([0, 1]) == lattice.get_fractional_coords(pt3)).all() - assert (np.array([1/2, 1/2]) == lattice.get_fractional_coords(pt4)).all() + assert (np.array([1 / 2, 1 / 2]) == lattice.get_fractional_coords(pt4)).all() # conversion from fractional_coords coords - assert (np.array([3, 1/8]) == lattice.get_cartesian_coords(pt1)).all() - assert (np.array([5, 3/8]) == lattice.get_cartesian_coords(pt2)).all() - assert (np.array([0, 1/4]) == lattice.get_cartesian_coords(pt3)).all() - assert (np.array([2, 1/8]) == lattice.get_cartesian_coords(pt4)).all() + assert (np.array([3, 1 / 8]) == lattice.get_cartesian_coords(pt1)).all() + assert (np.array([5, 3 / 8]) == lattice.get_cartesian_coords(pt2)).all() + assert (np.array([0, 1 / 4]) == lattice.get_cartesian_coords(pt3)).all() + assert (np.array([2, 1 / 8]) == lattice.get_cartesian_coords(pt4)).all() + def test_rectangular_lattice_pbc_distance(): - lattice = Lattice( - [[1, 0], - [0, 1/2]] - ) + lattice = Lattice([[1, 0], [0, 1 / 2]]) - scaled = lattice.get_scaled_lattice((2,2)) + scaled = lattice.get_scaled_lattice((2, 2)) pt1 = (0.5, 0.5) pt2 = (0.5, 0.75) @@ -109,12 +116,9 @@ def test_rectangular_lattice_pbc_distance(): assert scaled.cartesian_periodic_distance(pt1, pt3) == 0.0 assert scaled.cartesian_periodic_distance(pt2, pt3) == 0.25 + def test_square_lattice_non_pbc_distance(): - lattice = Lattice( - [[1, 0], - [0, 1]], - periodic=False - ) + lattice = Lattice([[1, 0], [0, 1]], periodic=False) pt1 = (0.1, 0.5) pt2 = (0.9, 0.5) @@ -126,11 +130,7 @@ def test_square_lattice_non_pbc_distance(): def test_square_lattice_non_pbc_coords(): - lattice = Lattice( - [[1, 0], - [0, 1]], - periodic=False - ) + lattice = Lattice([[1, 0], [0, 1]], periodic=False) pt1 = (0.1, 0.5) pt2 = (0.9, 0.5) @@ -140,13 +140,11 @@ def test_square_lattice_non_pbc_coords(): assert_points_equal(lattice.get_periodized_cartesian_coords(pt2), pt2) assert_points_equal(lattice.get_periodized_cartesian_coords(pt3), (1.5, 0.5)) + def test_canted_lattice_pbc_distance(): - lattice = Lattice( - [[1, 0], - [1, 1]] - ) + lattice = Lattice([[1, 0], [1, 1]]) - scaled = lattice.get_scaled_lattice((2,2)) + scaled = lattice.get_scaled_lattice((2, 2)) pt1 = (0.75, 0.5) pt2 = (1.75, 0.5) @@ -160,32 +158,35 @@ def test_canted_lattice_pbc_distance(): assert scaled.cartesian_periodic_distance(pt1, pt3) == 0.75 assert scaled.cartesian_periodic_distance(pt4, pt5) == 0.0 assert scaled.cartesian_periodic_distance(pt6, pt7) == 1 - + + def test_canted_rectangular_lattice_point_conversions(): - lattice = Lattice( - [[1, 0], - [1, 1]] - ) + lattice = Lattice([[1, 0], [1, 1]]) # points - pt1 = (1/2, 1/4) - pt2 = (1/2, 3/4) - pt3 = (0, 1/2) + pt1 = (1 / 2, 1 / 4) + pt2 = (1 / 2, 3 / 4) + pt3 = (0, 1 / 2) # periodizing cartesian points - assert (np.array([1/2, 1/4]) == lattice.get_periodized_cartesian_coords(pt1)).all() - assert (np.array([3/2, 3/4]) == lattice.get_periodized_cartesian_coords(pt2)).all() - assert (np.array([1, 1/2]) == lattice.get_periodized_cartesian_coords(pt3)).all() - + assert ( + np.array([1 / 2, 1 / 4]) == lattice.get_periodized_cartesian_coords(pt1) + ).all() + assert ( + np.array([3 / 2, 3 / 4]) == lattice.get_periodized_cartesian_coords(pt2) + ).all() + assert (np.array([1, 1 / 2]) == lattice.get_periodized_cartesian_coords(pt3)).all() + # conversion to fractional coords - assert (np.array([1/4, 1/4]) == lattice.get_fractional_coords(pt1)).all() - assert (np.array([-1/4, 3/4]) == lattice.get_fractional_coords(pt2)).all() - assert (np.array([-1/2, 1/2]) == lattice.get_fractional_coords(pt3)).all() + assert (np.array([1 / 4, 1 / 4]) == lattice.get_fractional_coords(pt1)).all() + assert (np.array([-1 / 4, 3 / 4]) == lattice.get_fractional_coords(pt2)).all() + assert (np.array([-1 / 2, 1 / 2]) == lattice.get_fractional_coords(pt3)).all() # conversion from fractional_coords coords - assert (np.array([3/4, 1/4]) == lattice.get_cartesian_coords(pt1)).all() - assert (np.array([5/4, 3/4]) == lattice.get_cartesian_coords(pt2)).all() - assert (np.array([1/2, 1/2]) == lattice.get_cartesian_coords(pt3)).all() + assert (np.array([3 / 4, 1 / 4]) == lattice.get_cartesian_coords(pt1)).all() + assert (np.array([5 / 4, 3 / 4]) == lattice.get_cartesian_coords(pt2)).all() + assert (np.array([1 / 2, 1 / 2]) == lattice.get_cartesian_coords(pt3)).all() + def test_periodizing_point(): pt1 = (0.5, 1.5, 0.8) @@ -203,6 +204,7 @@ def test_periodizing_point(): pt4 = (1.5, 1.5) assert_points_equal((0.5, 1.5), periodize(pt4, (True, False))) + def test_nonperiodic_lattice(square_2D_basis_vecs): lat = Lattice(square_2D_basis_vecs, False) @@ -212,25 +214,28 @@ def test_nonperiodic_lattice(square_2D_basis_vecs): assert_points_equal(lat.get_periodized_cartesian_coords([1.5, 1.5]), [0.5, 0.5]) lat_slab = Lattice(square_2D_basis_vecs, (True, False)) - assert_points_equal(lat_slab.get_periodized_cartesian_coords([1.5, 1.5]), [0.5, 1.5]) + assert_points_equal( + lat_slab.get_periodized_cartesian_coords([1.5, 1.5]), [0.5, 1.5] + ) def test_partially_periodic_lattice(square_2D_basis_vecs): - lat = Lattice(square_2D_basis_vecs, (True, False)).get_scaled_lattice((3,3)) + lat = Lattice(square_2D_basis_vecs, (True, False)).get_scaled_lattice((3, 3)) assert_points_equal(lat.get_periodized_cartesian_coords((-0.5, 1.5)), (2.5, 1.5)) assert_points_equal(lat.get_periodized_cartesian_coords((1.5, -0.5)), (1.5, -0.5)) assert_points_equal(lat.get_periodized_cartesian_coords((-0.5, -0.5)), (2.5, -0.5)) + def test_scaling_lattice_retains_periodicity(square_2D_basis_vecs): lat = Lattice(square_2D_basis_vecs, False) - scaled = lat.get_scaled_lattice((2,2)) + scaled = lat.get_scaled_lattice((2, 2)) assert scaled.periodic == (False, False) lat = Lattice(square_2D_basis_vecs, (False, True)) - scaled = lat.get_scaled_lattice((2,2)) + scaled = lat.get_scaled_lattice((2, 2)) assert scaled.periodic == (False, True) lat = Lattice(square_2D_basis_vecs, True) - scaled = lat.get_scaled_lattice((2,2)) - assert scaled.periodic == (True, True) \ No newline at end of file + scaled = lat.get_scaled_lattice((2, 2)) + assert scaled.periodic == (True, True) diff --git a/tests/core/test_lattice_utils.py b/tests/core/test_lattice_utils.py index 1dc5379..e77f675 100644 --- a/tests/core/test_lattice_utils.py +++ b/tests/core/test_lattice_utils.py @@ -2,6 +2,7 @@ import numpy as np + def test_pbc_diff_frac_vec(): pt1 = (0.1, 0) pt2 = (0.9, 0) @@ -28,11 +29,9 @@ def test_pbc_diff_frac_vec(): assert np.allclose(pbc_diff_frac_vec(pt7, pt6, np.array([0, 1])), [1.2, 0.2]) assert np.allclose(pbc_diff_frac_vec(pt7, pt6, np.array([0, 0])), [1.2, 1.2]) + def test_pbc_diff_cart(): - lvecs = [ - [1, 0], - [0, 1] - ] + lvecs = [[1, 0], [0, 1]] pt1 = (0.1, 0.1) pt2 = (0.1, 0.9) @@ -52,9 +51,3 @@ def test_pbc_diff_cart(): assert np.isclose(l2.cartesian_periodic_distance(pt1, pt5), 1.0) assert np.isclose(l3.cartesian_periodic_distance(pt1, pt5), 0) assert np.isclose(l4.cartesian_periodic_distance(pt1, pt5), 1.0) - - - - - - diff --git a/tests/core/test_neighborhood_builders.py b/tests/core/test_neighborhood_builders.py index 3a58be6..03f020e 100644 --- a/tests/core/test_neighborhood_builders.py +++ b/tests/core/test_neighborhood_builders.py @@ -3,11 +3,17 @@ import numpy as np import math -from pylattica.core.neighborhood_builders import DistanceNeighborhoodBuilder, MotifNeighborhoodBuilder, AnnularNeighborhoodBuilder -from pylattica.structures.square_grid.structure_builders import SimpleSquare2DStructureBuilder +from pylattica.core.neighborhood_builders import ( + DistanceNeighborhoodBuilder, + MotifNeighborhoodBuilder, + AnnularNeighborhoodBuilder, +) +from pylattica.structures.square_grid.structure_builders import ( + SimpleSquare2DStructureBuilder, +) -def test_distance_nb_builder(square_grid_2D_4x4): +def test_distance_nb_builder(square_grid_2D_4x4): builder = DistanceNeighborhoodBuilder(1.01) nbhood = builder.get(square_grid_2D_4x4) neighbors = nbhood.neighbors_of(1) @@ -20,7 +26,6 @@ def test_distance_nb_builder(square_grid_2D_4x4): for _, nb_dist in nbs_w_dists: assert nb_dist == 1.0 - builder = DistanceNeighborhoodBuilder(1.5) nbhood = builder.get(square_grid_2D_4x4) neighbors = nbhood.neighbors_of(1) @@ -33,8 +38,9 @@ def test_distance_nb_builder(square_grid_2D_4x4): for _, nb_dist in nbs_w_dists: assert nb_dist == 1.0 or np.isclose(nb_dist, root_2, 0.01) + def test_annular_nb_hood_builder(): - struct = SimpleSquare2DStructureBuilder().build((5,5)) + struct = SimpleSquare2DStructureBuilder().build((5, 5)) builder = AnnularNeighborhoodBuilder(1.2, 2.1) nb_hood = builder.get(struct) @@ -56,6 +62,6 @@ def test_struct_nb_hood_builder(square_grid_2D_4x4): assert len(nbs_w_dists) == 2 assert len(set([nb[0] for nb in nbs_w_dists])) == 2 - + for nb_id, nb_dist in nbs_w_dists: assert nb_dist == 1.0 diff --git a/tests/core/test_neighborhoods.py b/tests/core/test_neighborhoods.py index a969035..041bc1f 100644 --- a/tests/core/test_neighborhoods.py +++ b/tests/core/test_neighborhoods.py @@ -1,13 +1,16 @@ -from pylattica.core.neighborhood_builders import MotifNeighborhoodBuilder, SiteClassNeighborhoodBuilder, StochasticNeighborhoodBuilder, DistanceNeighborhoodBuilder +from pylattica.core.neighborhood_builders import ( + MotifNeighborhoodBuilder, + SiteClassNeighborhoodBuilder, + StochasticNeighborhoodBuilder, + DistanceNeighborhoodBuilder, +) from pylattica.core import Lattice, PeriodicStructure import numpy as np + def test_site_class_neighborhood(): - lattice = Lattice([ - [1, 0], - [0, 1] - ]) + lattice = Lattice([[1, 0], [0, 1]]) motif = { "A": [[0.25, 0.25]], @@ -15,19 +18,14 @@ def test_site_class_neighborhood(): "C": [[0.75, 0.75]], } - struct = PeriodicStructure.build_from( - lattice, - (3,3), - motif - ) + struct = PeriodicStructure.build_from(lattice, (3, 3), motif) A_builder = MotifNeighborhoodBuilder([(0.25, 0.25), (0.5, 0.5)]) C_builder = MotifNeighborhoodBuilder([(0.5, 0.5)]) - overall_nbhood_builder = SiteClassNeighborhoodBuilder({ - "A": A_builder, - "C": C_builder - }) + overall_nbhood_builder = SiteClassNeighborhoodBuilder( + {"A": A_builder, "C": C_builder} + ) overall_nbhood = overall_nbhood_builder.get(struct) A_site = struct.id_at((1.25, 1.25)) @@ -41,41 +39,37 @@ def test_site_class_neighborhood(): B_nbs = overall_nbhood.neighbors_of(B_site) assert len(list(B_nbs)) == 0 - + C_nbs = overall_nbhood.neighbors_of(C_site) assert len(list(C_nbs)) == 1 assert len([nbid for nbid in C_nbs if struct.site_class(nbid) == "A"]) == 1 assert len([nbid for nbid in C_nbs if struct.site_class(nbid) == "B"]) == 0 + def test_motif_nbhood(): - lattice = Lattice([ - [1, 0], - [0, 1] - ]) - + lattice = Lattice([[1, 0], [0, 1]]) + motif = [[0.5, 0.5]] - structure = PeriodicStructure.build_from(lattice, (3,3), motif) + structure = PeriodicStructure.build_from(lattice, (3, 3), motif) - motif_builder = MotifNeighborhoodBuilder([(0,1)]) + motif_builder = MotifNeighborhoodBuilder([(0, 1)]) nb = motif_builder.get(structure) assert len(nb.neighbors_of(4)) == 1 assert np.allclose(structure.site_location(nb.neighbors_of(4)[0]), [1.5, 2.5]) + def test_stochastic_nbhood(): - lattice = Lattice([ - [1, 0], - [0, 1] - ]) + lattice = Lattice([[1, 0], [0, 1]]) motif = [[0.5, 0.5]] - structure = PeriodicStructure.build_from(lattice, (5,5), motif) + structure = PeriodicStructure.build_from(lattice, (5, 5), motif) - motif1 = MotifNeighborhoodBuilder([(0,1)]) - motif2 = MotifNeighborhoodBuilder([(0,-1)]) - motif3 = MotifNeighborhoodBuilder([(1,0)]) + motif1 = MotifNeighborhoodBuilder([(0, 1)]) + motif2 = MotifNeighborhoodBuilder([(0, -1)]) + motif3 = MotifNeighborhoodBuilder([(1, 0)]) motif4 = MotifNeighborhoodBuilder([(-1, 0)]) stoch_nbbuilder = StochasticNeighborhoodBuilder([motif1, motif2, motif3, motif4]) @@ -83,18 +77,18 @@ def test_stochastic_nbhood(): assert len(stoch_nb.neighbors_of(0)) == 1 + def test_fully_periodic_neighborhoods(): - lattice_vecs = [ - [1, 0], - [0, 1] - ] + lattice_vecs = [[1, 0], [0, 1]] motif = [[0.5, 0.5]] von_neumann_nb_builder = DistanceNeighborhoodBuilder(1.01) full_periodic_lattice = Lattice(lattice_vecs, True) - full_periodic_struct = PeriodicStructure.build_from(full_periodic_lattice, (3,3), motif) + full_periodic_struct = PeriodicStructure.build_from( + full_periodic_lattice, (3, 3), motif + ) full_periodic_nbhood = von_neumann_nb_builder.get(full_periodic_struct) edge_coords = (0.5, 1.5) @@ -110,18 +104,18 @@ def test_fully_periodic_neighborhoods(): assert len(corner_nbs) == 4 + def test_partially_periodic_neighborhoods(): - lattice_vecs = [ - [1, 0], - [0, 1] - ] + lattice_vecs = [[1, 0], [0, 1]] motif = [[0.5, 0.5]] von_neumann_nb_builder = DistanceNeighborhoodBuilder(1.01) partial_periodic_lattice = Lattice(lattice_vecs, (False, True)) - partial_periodic_struct = PeriodicStructure.build_from(partial_periodic_lattice, (3,3), motif) + partial_periodic_struct = PeriodicStructure.build_from( + partial_periodic_lattice, (3, 3), motif + ) partial_periodic_nbhood = von_neumann_nb_builder.get(partial_periodic_struct) edge_coords = (0.5, 1.5) @@ -139,17 +133,16 @@ def test_partially_periodic_neighborhoods(): def test_non_periodic_neighborhoods(): - lattice_vecs = [ - [1, 0], - [0, 1] - ] + lattice_vecs = [[1, 0], [0, 1]] motif = [[0.5, 0.5]] von_neumann_nb_builder = DistanceNeighborhoodBuilder(1.01) non_periodic_lattice = Lattice(lattice_vecs, False) - non_periodic_struct = PeriodicStructure.build_from(non_periodic_lattice, (3,3), motif) + non_periodic_struct = PeriodicStructure.build_from( + non_periodic_lattice, (3, 3), motif + ) non_periodic_nbhood = von_neumann_nb_builder.get(non_periodic_struct) edge_coords = (0.5, 1.5) @@ -163,4 +156,4 @@ def test_non_periodic_neighborhoods(): corner_id = non_periodic_struct.id_at(corner_coords) corner_nbs = non_periodic_nbhood.neighbors_of(corner_id) - assert len(corner_nbs) == 2 \ No newline at end of file + assert len(corner_nbs) == 2 diff --git a/tests/core/test_parallel_runner.py b/tests/core/test_parallel_runner.py index 3bb277a..5ba0aa9 100644 --- a/tests/core/test_parallel_runner.py +++ b/tests/core/test_parallel_runner.py @@ -9,54 +9,42 @@ from helpers.helpers import skip_windows_due_to_parallel + @skip_windows_due_to_parallel def test_parallel_runner(square_grid_2D_4x4: PeriodicStructure): - class SimpleParallelController(BasicController): - def get_state_update(self, site_id: int, prev_state: SimulationState): prev = prev_state.get_site_state(site_id)["value"] new_state = prev + 1 - return { - site_id: { - "value": new_state - } - } - + return {site_id: {"value": new_state}} + initial_state = SimulationState() for site in square_grid_2D_4x4.sites(): - initial_state.set_site_state(site[SITE_ID], { "value": 0 }) - + initial_state.set_site_state(site[SITE_ID], {"value": 0}) + runner = SynchronousRunner(parallel=True) controller = SimpleParallelController() - result = runner.run(initial_state, controller=controller, num_steps = 1000) + result = runner.run(initial_state, controller=controller, num_steps=1000) last_step = result.last_step for site_state in last_step.all_site_states(): assert site_state["value"] == 1000 + @skip_windows_due_to_parallel def test_parallel_runner_speed(square_grid_2D_4x4: PeriodicStructure): - class SimpleParallelController(BasicController): - def get_state_update(self, site_id: int, prev_state: SimulationState): prev = prev_state.get_site_state(site_id)["value"] new_state = prev + 1 - return { - site_id: { - "value": new_state - } - } - + return {site_id: {"value": new_state}} + initial_state = SimulationState() for site in square_grid_2D_4x4.sites(): - initial_state.set_site_state(site[SITE_ID], { "value": 0 }) - + initial_state.set_site_state(site[SITE_ID], {"value": 0}) - parallel_runner = SynchronousRunner(parallel=True) series_runner = SynchronousRunner() @@ -65,11 +53,15 @@ def get_state_update(self, site_id: int, prev_state: SimulationState): num_steps = 1000 t0 = time.time() - parallel_result = parallel_runner.run(initial_state, controller=controller, num_steps = num_steps) + parallel_result = parallel_runner.run( + initial_state, controller=controller, num_steps=num_steps + ) t1 = time.time() t2 = time.time() - series_result = series_runner.run(initial_state, controller=controller, num_steps = num_steps) + series_result = series_runner.run( + initial_state, controller=controller, num_steps=num_steps + ) t3 = time.time() assert (t3 - t2) < (t1 - t0) @@ -78,4 +70,4 @@ def get_state_update(self, site_id: int, prev_state: SimulationState): assert site_state["value"] == num_steps for site_state in series_result.last_step.all_site_states(): - assert site_state["value"] == num_steps \ No newline at end of file + assert site_state["value"] == num_steps diff --git a/tests/core/test_periodic_structure.py b/tests/core/test_periodic_structure.py index a7a557e..c34e189 100644 --- a/tests/core/test_periodic_structure.py +++ b/tests/core/test_periodic_structure.py @@ -3,9 +3,11 @@ from pylattica.core import PeriodicStructure, Lattice + def test_can_instantiate_structure(square_lattice): assert PeriodicStructure(square_lattice) is not None + def test_serialization(square_2x2_2D_grid_in_test: PeriodicStructure): d = square_2x2_2D_grid_in_test.as_dict() @@ -15,34 +17,45 @@ def test_serialization(square_2x2_2D_grid_in_test: PeriodicStructure): for sid, site in square_2x2_2D_grid_in_test._sites.items(): r_site = reproduced._sites[sid] - assert r_site['_site_id'] == site['_site_id'] - assert r_site['_site_class'] == site['_site_class'] - assert (r_site['_location'] == site['_location']).all() + assert r_site["_site_id"] == site["_site_id"] + assert r_site["_site_class"] == site["_site_class"] + assert (r_site["_location"] == site["_location"]).all() assert reproduced.dim == square_2x2_2D_grid_in_test.dim assert reproduced.site_ids == square_2x2_2D_grid_in_test.site_ids -def test_simple_structure_has_correct_sites(square_2x2_2D_grid_in_test: PeriodicStructure): + +def test_simple_structure_has_correct_sites( + square_2x2_2D_grid_in_test: PeriodicStructure, +): assert square_2x2_2D_grid_in_test.site_at((0.5, 0.5)) is not None assert square_2x2_2D_grid_in_test.site_at((0.5, 1.5)) is not None assert square_2x2_2D_grid_in_test.site_at((1.5, 0.5)) is not None assert square_2x2_2D_grid_in_test.site_at((1.5, 1.5)) is not None -def test_simple_structure_doesnt_have_incorrect_sites(square_2x2_2D_grid_in_test: PeriodicStructure): + +def test_simple_structure_doesnt_have_incorrect_sites( + square_2x2_2D_grid_in_test: PeriodicStructure, +): assert square_2x2_2D_grid_in_test.site_at((0.5, 0.501)) is None -def test_simple_structure_has_periodic_sites(square_2x2_2D_grid_in_test: PeriodicStructure): + +def test_simple_structure_has_periodic_sites( + square_2x2_2D_grid_in_test: PeriodicStructure, +): assert square_2x2_2D_grid_in_test.site_at((-1.5, -1.5)) is not None assert square_2x2_2D_grid_in_test.site_at((-1.5, 1.5)) is not None assert square_2x2_2D_grid_in_test.site_at((-1.5, 2.5)) is not None assert square_2x2_2D_grid_in_test.site_at((2.5, -0.5)) is not None + def test_sites_have_unaltered_location(square_2x2_2D_grid_in_test: PeriodicStructure): location = (0.5, 0.5) site = square_2x2_2D_grid_in_test.site_at(location) assert site[LOCATION][0] == 0.5 assert site[LOCATION][1] == 0.5 + def test_structure_returns_sites(square_2x2_2D_grid_in_test): no_sites = square_2x2_2D_grid_in_test.sites("NOT_A_SITE") assert len(no_sites) == 0 @@ -50,15 +63,12 @@ def test_structure_returns_sites(square_2x2_2D_grid_in_test): all_sites = square_2x2_2D_grid_in_test.sites("A") assert len(all_sites) == 4 + def test_build_structure_from_list_motif(square_2D_lattice): - motif = [( - 0.5, 0.5 - )] + motif = [(0.5, 0.5)] struct = PeriodicStructure.build_from( - square_2D_lattice, - num_cells=(3,3), - site_motif=motif + square_2D_lattice, num_cells=(3, 3), site_motif=motif ) assert len(struct.site_ids) == 9 @@ -66,6 +76,7 @@ def test_build_structure_from_list_motif(square_2D_lattice): assert struct.site_at((-0.5, 2.5)) is not None assert struct.site_at((0.25, 0.25)) is None + def test_periodic_structure_class_iding(square_2D_lattice): motif = { "A": [(0.25, 0.25)], @@ -73,7 +84,7 @@ def test_periodic_structure_class_iding(square_2D_lattice): "C": [(0.35, 0.45)], } - struct = PeriodicStructure.build_from(square_2D_lattice, (2,2), motif) + struct = PeriodicStructure.build_from(square_2D_lattice, (2, 2), motif) assert struct.class_at((0.25, 0.25)) == "A" assert struct.class_at((1.35, 0.25)) == "B" @@ -84,7 +95,8 @@ def test_periodic_structure_class_iding(square_2D_lattice): assert "B" in struct.all_site_classes() assert "C" in struct.all_site_classes() - assert struct.class_at((0,0)) is None + assert struct.class_at((0, 0)) is None + def test_id_at(square_2D_lattice): motif = { @@ -93,19 +105,20 @@ def test_id_at(square_2D_lattice): "C": [(0.35, 0.45)], } - struct = PeriodicStructure.build_from(square_2D_lattice, (2,2), motif) + struct = PeriodicStructure.build_from(square_2D_lattice, (2, 2), motif) site_1 = struct.site_at((0.25, 0.25)) assert struct.id_at((0.25, 0.25)) == site_1["_site_id"] - assert struct.id_at((0,0)) is None + assert struct.id_at((0, 0)) is None + def test_partially_periodic_structure(square_2D_basis_vecs): lat = Lattice(square_2D_basis_vecs, (False, True)) motif = [[0.5, 0.5]] - struct = PeriodicStructure.build_from(lat, (3,3), motif) - + struct = PeriodicStructure.build_from(lat, (3, 3), motif) + assert struct.site_at((-0.5, 0.5)) is None assert struct.site_at((0.5, 1.5)) is not None - assert struct.site_at((0.5, -1.5)) is not None \ No newline at end of file + assert struct.site_at((0.5, -1.5)) is not None diff --git a/tests/core/test_runner_utils.py b/tests/core/test_runner_utils.py index 072e719..f131a15 100644 --- a/tests/core/test_runner_utils.py +++ b/tests/core/test_runner_utils.py @@ -3,136 +3,95 @@ from pylattica.core.runner.common import merge_updates from pylattica.core.constants import GENERAL, SITES + @pytest.fixture def curr_updates(): - return { SITES: { 0: { "a": 1 }}, GENERAL: {} } + return {SITES: {0: {"a": 1}}, GENERAL: {}} + def test_merge_updates_no_new(curr_updates): - new_updates = None + new_updates = None updated_updates = merge_updates(new_updates, curr_updates) assert updated_updates == curr_updates -def test_merge_updates_full_with_sites(curr_updates): - new_updates = { SITES: { 1: { "b": 1 }} } +def test_merge_updates_full_with_sites(curr_updates): + new_updates = {SITES: {1: {"b": 1}}} updated_updates = merge_updates(new_updates, curr_updates) - expected = { - SITES: { - 1: {"b": 1}, - 0: {"a": 1} - }, - GENERAL: {} - } + expected = {SITES: {1: {"b": 1}, 0: {"a": 1}}, GENERAL: {}} assert updated_updates == expected -def test_merge_updates_full_with_sites_and_general(curr_updates): - new_updates = { SITES: { 1: { "b": 1 }}, GENERAL: { "c": 2 } } +def test_merge_updates_full_with_sites_and_general(curr_updates): + new_updates = {SITES: {1: {"b": 1}}, GENERAL: {"c": 2}} updated_updates = merge_updates(new_updates, curr_updates) - expected = { - SITES: { - 1: {"b": 1}, - 0: {"a": 1} - }, - GENERAL: { - "c": 2 - } - } + expected = {SITES: {1: {"b": 1}, 0: {"a": 1}}, GENERAL: {"c": 2}} assert updated_updates == expected -def test_merge_updates_full_with_sites_overwrite(curr_updates): - new_updates = { SITES: { 1: { "b": 1 }, 0: { "a": 2 }} } +def test_merge_updates_full_with_sites_overwrite(curr_updates): + new_updates = {SITES: {1: {"b": 1}, 0: {"a": 2}}} updated_updates = merge_updates(new_updates, curr_updates) - expected = { - SITES: { - 1: {"b": 1}, - 0: {"a": 2} - }, - GENERAL: {} - } + expected = {SITES: {1: {"b": 1}, 0: {"a": 2}}, GENERAL: {}} assert updated_updates == expected -def test_merge_updates_implicit_sites(curr_updates): - new_updates = { 1: { "b": 1 }} +def test_merge_updates_implicit_sites(curr_updates): + new_updates = {1: {"b": 1}} updated_updates = merge_updates(new_updates, curr_updates) - expected = { - SITES: { - 1: {"b": 1}, - 0: {"a": 1} - }, - GENERAL: {} - } + expected = {SITES: {1: {"b": 1}, 0: {"a": 1}}, GENERAL: {}} assert updated_updates == expected -def test_merge_updates_implicit_sites_overwrite(curr_updates): - new_updates = { 1: { "b": 1 }, 0: { "a": 2 }} +def test_merge_updates_implicit_sites_overwrite(curr_updates): + new_updates = {1: {"b": 1}, 0: {"a": 2}} updated_updates = merge_updates(new_updates, curr_updates) - expected = { - SITES: { - 1: {"b": 1}, - 0: {"a": 2} - }, - GENERAL: {} - } + expected = {SITES: {1: {"b": 1}, 0: {"a": 2}}, GENERAL: {}} assert updated_updates == expected -def test_merge_updates_specific_site(curr_updates): - new_updates = { "b": 1 } +def test_merge_updates_specific_site(curr_updates): + new_updates = {"b": 1} updated_updates = merge_updates(new_updates, curr_updates, site_id=1) - expected = { - SITES: { - 1: {"b": 1}, - 0: {"a": 1} - }, - GENERAL: {} - } + expected = {SITES: {1: {"b": 1}, 0: {"a": 1}}, GENERAL: {}} assert updated_updates == expected -def test_merge_updates_specific_site_no_curr(): - new_updates = { "b": 1 } +def test_merge_updates_specific_site_no_curr(): + new_updates = {"b": 1} updated_updates = merge_updates(new_updates, site_id=1) expected = { SITES: { 1: {"b": 1}, }, - GENERAL: {} + GENERAL: {}, } assert updated_updates == expected -def test_merge_updates_specific_site_overwrite(curr_updates): - new_updates = { "a": 2 } +def test_merge_updates_specific_site_overwrite(curr_updates): + new_updates = {"a": 2} updated_updates = merge_updates(new_updates, curr_updates, site_id=0) - expected = { - SITES: { - 0: {"a": 2} - }, - GENERAL: {} - } + expected = {SITES: {0: {"a": 2}}, GENERAL: {}} assert updated_updates == expected -def test_merge_updates_bad_args(): +def test_merge_updates_bad_args(): with pytest.raises(ValueError, match="Bad combination"): - merge_updates({}, None, None) \ No newline at end of file + merge_updates({}, None, None) diff --git a/tests/core/test_simulation.py b/tests/core/test_simulation.py index f0b2e60..1a7c334 100644 --- a/tests/core/test_simulation.py +++ b/tests/core/test_simulation.py @@ -11,14 +11,16 @@ def test_can_instantiate_periodic_state(square_2x2_2D_grid_in_test): state = SimulationState() assert Simulation(state, square_2x2_2D_grid_in_test) is not None + def test_retrieves_state_correctly(square_2x2_2D_grid_in_test: PeriodicStructure): state = SimulationState() site = square_2x2_2D_grid_in_test.site_at((0.5, 0.5)) - state.set_site_state(site[SITE_ID], { 'my_state_key': 2 }) + state.set_site_state(site[SITE_ID], {"my_state_key": 2}) periodic_state = Simulation(state, square_2x2_2D_grid_in_test) site_state = periodic_state.state_at((0.5, 0.5)) - assert site_state['my_state_key'] == 2 + assert site_state["my_state_key"] == 2 + def test_returns_none_if_no_state(square_2x2_2D_grid_in_test: PeriodicStructure): state = SimulationState() @@ -26,10 +28,11 @@ def test_returns_none_if_no_state(square_2x2_2D_grid_in_test: PeriodicStructure) assert sim.state_at((0.6, 0.6)) is None + def test_serialization(square_2x2_2D_grid_in_test: PeriodicStructure): state = SimulationState() site = square_2x2_2D_grid_in_test.site_at((0.5, 0.5)) - state.set_site_state(site[SITE_ID], { 'my_state_key': 2 }) + state.set_site_state(site[SITE_ID], {"my_state_key": 2}) sim = Simulation(state, square_2x2_2D_grid_in_test) fname = "test_serialization.tmp.json" @@ -37,4 +40,4 @@ def test_serialization(square_2x2_2D_grid_in_test: PeriodicStructure): sim2 = Simulation.from_file(fname) os.remove(fname) assert sim2.state._state == state._state - assert sim2.structure._sites == sim2.structure._sites \ No newline at end of file + assert sim2.structure._sites == sim2.structure._sites diff --git a/tests/core/test_simulation_result.py b/tests/core/test_simulation_result.py index 1dd0ebe..1b51149 100644 --- a/tests/core/test_simulation_result.py +++ b/tests/core/test_simulation_result.py @@ -9,6 +9,7 @@ def initial_state(): return SimulationState() + @pytest.fixture def random_result_big(initial_state): result = SimulationResult(initial_state) @@ -17,15 +18,12 @@ def random_result_big(initial_state): site_id = random.randint(0, 10) val = random.random() - updates = { - site_id: { - "a": val - } - } + updates = {site_id: {"a": val}} result.add_step(updates) return result + @pytest.fixture def random_result_small(initial_state): result = SimulationResult(initial_state) @@ -34,15 +32,12 @@ def random_result_small(initial_state): site_id = random.randint(0, 10) val = random.random() - updates = { - site_id: { - "a": val - } - } + updates = {site_id: {"a": val}} result.add_step(updates) return result + @pytest.fixture def random_result_small_ordered(initial_state): result = SimulationResult(initial_state) @@ -50,45 +45,41 @@ def random_result_small_ordered(initial_state): for i in range(3): site_id = random.randint(0, 10) - updates = { - site_id: { - "a": i - } - } + updates = {site_id: {"a": i}} result.add_step(updates) return result + def test_can_add_step(initial_state): result = SimulationResult(initial_state) - updates = { - 24: { "a": 1 } - } - + updates = {24: {"a": 1}} + result.add_step(updates) assert len(result) == 2 first_step = result.first_step assert first_step.as_dict() == initial_state.as_dict() -def test_can_load_at_intervals(random_result_big): +def test_can_load_at_intervals(random_result_big): assert len(random_result_big) == 1000 - random_result_big.load_steps(interval = 10) + random_result_big.load_steps(interval=10) assert len(random_result_big._stored_states) == 100 + def test_serialization(random_result_big: SimulationResult): d = random_result_big.as_dict() - rehydrated = SimulationResult.from_dict(d) for idx, step in enumerate(rehydrated.steps()): orig = random_result_big.get_step(idx) assert step.as_dict() == orig.as_dict() + def test_write_file(random_result_small: SimulationResult): fname = "tmp_test_res.json" random_result_small.to_file(fname) @@ -98,6 +89,7 @@ def test_write_file(random_result_small: SimulationResult): os.remove(fname) assert random_result_small.as_dict() == rehydrated.as_dict() + def test_write_file_autoname(random_result_small: SimulationResult): fname = random_result_small.to_file() @@ -106,6 +98,7 @@ def test_write_file_autoname(random_result_small: SimulationResult): os.remove(fname) assert random_result_small.as_dict() == rehydrated.as_dict() + def test_diff_storage(random_result_small_ordered: SimulationResult): diff_one = random_result_small_ordered._diffs[0] assert len(diff_one.keys()) == 1 diff --git a/tests/core/test_simulation_state.py b/tests/core/test_simulation_state.py index 7ba98de..ad2d9a9 100644 --- a/tests/core/test_simulation_state.py +++ b/tests/core/test_simulation_state.py @@ -1,58 +1,54 @@ from pylattica.discrete import PhaseSet from pylattica.structures.square_grid import DiscreteGridSetup -from pylattica.structures.square_grid.neighborhoods import PseudoHexagonalNeighborhoodBuilder2D +from pylattica.structures.square_grid.neighborhoods import ( + PseudoHexagonalNeighborhoodBuilder2D, +) from pylattica.core.simulation_state import SimulationState from pylattica.core.constants import SITES, GENERAL + def test_can_run_growth_sim_series(): phases = PhaseSet(["A", "B", "C", "D"]) nb_spec = PseudoHexagonalNeighborhoodBuilder2D() setup = DiscreteGridSetup(phases) - nuc_amts = { - 'B': 1, - 'C': 1, - 'D': 1 - } + nuc_amts = {"B": 1, "C": 1, "D": 1} periodic_initial_state = setup.setup_random_sites(20, 20, "A", nuc_amts=nuc_amts) - assert 'SITES' not in periodic_initial_state.state.site_ids() + assert "SITES" not in periodic_initial_state.state.site_ids() + def test_can_serialize(): state = SimulationState() - state.set_site_state(0, { "a": 1 }) + state.set_site_state(0, {"a": 1}) d = state.as_dict() assert "state" in d assert "@module" in d assert "@class" in d - + assert SITES in d["state"] and GENERAL in d["state"] rehydrated = SimulationState.from_dict(d) assert rehydrated.get_site_state(0)["a"] == 1 + def test_batch_update(): state = SimulationState() - updates = { - 1: { "a": 3 }, - 2: { "b": 4 } - } - + updates = {1: {"a": 3}, 2: {"b": 4}} + state.batch_update(updates) assert state.get_site_state(1)["a"] == 3 assert state.get_site_state(2)["b"] == 4 + def test_states_equal(): state1 = SimulationState() state2 = SimulationState() - updates = { - 1: { "a": 3 }, - 2: { "b": 4 } - } - + updates = {1: {"a": 3}, 2: {"b": 4}} + state1.batch_update(updates) state2.batch_update(updates) - assert state1 == state2 \ No newline at end of file + assert state1 == state2 diff --git a/tests/core/test_structure_builder.py b/tests/core/test_structure_builder.py index 4d584be..117a1b5 100644 --- a/tests/core/test_structure_builder.py +++ b/tests/core/test_structure_builder.py @@ -5,124 +5,79 @@ from pylattica.core.constants import LOCATION, SITE_CLASS from pylattica.core.structure_builder import StructureBuilder + def test_basic_structure_building(): - lattice = Lattice([ - [1, 0], - [0, 1] - ]) - - motif = { - "A": [ - (0.5, 0.5) - ], - "B": [ - (0.25, 0.5) - ] - } + lattice = Lattice([[1, 0], [0, 1]]) + + motif = {"A": [(0.5, 0.5)], "B": [(0.25, 0.5)]} builder = StructureBuilder(lattice, motif) structure = builder.build((1, 1)) assert len(structure.site_ids) == 2 - assert structure.site_at((0,0)) is None - assert structure.site_at((0.5,0.5)) is not None + assert structure.site_at((0, 0)) is None + assert structure.site_at((0.5, 0.5)) is not None assert structure.site_at((0.5, 0.5))[SITE_CLASS] is "A" - assert structure.site_at((0.25,0.5)) is not None + assert structure.site_at((0.25, 0.5)) is not None assert structure.site_at((0.25, 0.5))[SITE_CLASS] is "B" + def test_basic_structure_building_size_as_int(): - lattice = Lattice([ - [1, 0], - [0, 1] - ]) - - motif = { - "A": [ - (0.5, 0.5) - ], - "B": [ - (0.25, 0.5) - ] - } + lattice = Lattice([[1, 0], [0, 1]]) + + motif = {"A": [(0.5, 0.5)], "B": [(0.25, 0.5)]} builder = StructureBuilder(lattice, motif) structure = builder.build(2) assert len(structure.site_ids) == 8 - assert structure.site_at((0,0)) is None - assert structure.site_at((0.5,0.5)) is not None + assert structure.site_at((0, 0)) is None + assert structure.site_at((0.5, 0.5)) is not None assert structure.site_at((0.5, 0.5))[SITE_CLASS] is "A" - assert structure.site_at((0.25,0.5)) is not None + assert structure.site_at((0.25, 0.5)) is not None assert structure.site_at((0.25, 0.5))[SITE_CLASS] is "B" + def test_basic_structure_building_bad_size_value(): - lattice = Lattice([ - [1, 0], - [0, 1] - ]) - - motif = { - "A": [ - (0.5, 0.5) - ], - "B": [ - (0.25, 0.5) - ] - } + lattice = Lattice([[1, 0], [0, 1]]) + + motif = {"A": [(0.5, 0.5)], "B": [(0.25, 0.5)]} builder = StructureBuilder(lattice, motif) with pytest.raises(ValueError, match="Desired structure dimensions"): - builder.build((2,2,2,2)) + builder.build((2, 2, 2, 2)) + def test_structure_building_unequal_dirs(): - lattice = Lattice([ - [1, 0], - [0, 1] - ]) - - motif = { - "A": [ - (0.5, 0.5) - ], - "B": [ - (0.25, 0.5) - ] - } + lattice = Lattice([[1, 0], [0, 1]]) + + motif = {"A": [(0.5, 0.5)], "B": [(0.25, 0.5)]} builder = StructureBuilder(lattice, motif) structure = builder.build((2, 1)) assert len(structure.site_ids) == 4 - assert structure.site_at((0,0)) is None - assert structure.site_at((1.5,0.5)) is not None + assert structure.site_at((0, 0)) is None + assert structure.site_at((1.5, 0.5)) is not None assert structure.site_at((1.5, 0.5))[SITE_CLASS] is "A" - assert structure.site_at((0.25,0.5)) is not None + assert structure.site_at((0.25, 0.5)) is not None assert structure.site_at((1.25, 0.5))[SITE_CLASS] is "B" + def test_structure_building_frac_coords(): - lattice = Lattice([ - [2, 0], - [0, 2] - ]) - - motif = { - "A": [ - (0.5, 0.5) - ], - "B": [ - (0.25, 0.5) - ] - } + lattice = Lattice([[2, 0], [0, 2]]) + + motif = {"A": [(0.5, 0.5)], "B": [(0.25, 0.5)]} builder = StructureBuilder(lattice, motif) builder.frac_coords = True structure = builder.build((2, 1)) assert len(structure.site_ids) == 4 - assert structure.site_at((0,0)) is None + assert structure.site_at((0, 0)) is None assert structure.site_at((3, 1)) is not None assert structure.site_at((3, 1))[SITE_CLASS] is "A" diff --git a/tests/discrete/test_discrete_result_analyzer.py b/tests/discrete/test_discrete_result_analyzer.py index 3736dec..cebbe31 100644 --- a/tests/discrete/test_discrete_result_analyzer.py +++ b/tests/discrete/test_discrete_result_analyzer.py @@ -5,20 +5,22 @@ from pylattica.discrete import PhaseSet, DiscreteResultAnalyzer from pylattica.structures.square_grid.grid_setup import DiscreteGridSetup + @pytest.fixture() def discrete_result(): phases = PhaseSet(["dead", "alive"]) setup = DiscreteGridSetup(phases) simulation = setup.setup_noise(10, ["dead", "alive"]) - controller = GameOfLifeController(structure = simulation.structure, - variant=Life) + controller = GameOfLifeController(structure=simulation.structure, variant=Life) runner = SynchronousRunner(parallel=False) return runner.run(simulation.state, controller, 10, verbose=False) + def test_plot_phase_fractions(discrete_result): analyzer = DiscreteResultAnalyzer(discrete_result) analyzer.plot_phase_fractions() + def test_final_phase_fractions(discrete_result): analyzer = DiscreteResultAnalyzer(discrete_result) fracs = analyzer.final_phase_fractions() @@ -26,16 +28,15 @@ def test_final_phase_fractions(discrete_result): assert type(phase) is str assert type(amt) is float + def test_phase_fraction_at_step(): phases = PhaseSet(["dead", "alive"]) setup = DiscreteGridSetup(phases) simulation = setup.setup_interface(10, "dead", "alive") - controller = GameOfLifeController(structure = simulation.structure, - variant=Life) + controller = GameOfLifeController(structure=simulation.structure, variant=Life) runner = SynchronousRunner(parallel=False) result = runner.run(simulation.state, controller, 10, verbose=False) analyzer = DiscreteResultAnalyzer(result) assert analyzer.phase_fraction_at(0, "dead") == 0.5 - diff --git a/tests/discrete/test_discrete_step_analyzer.py b/tests/discrete/test_discrete_step_analyzer.py index bde875e..7c43f86 100644 --- a/tests/discrete/test_discrete_step_analyzer.py +++ b/tests/discrete/test_discrete_step_analyzer.py @@ -4,11 +4,12 @@ from pylattica.discrete import DiscreteStepAnalyzer from pylattica.core import SimulationState + def test_cell_count(): state = SimulationState() - state.set_site_state(1, { DISCRETE_OCCUPANCY: "A"}) - state.set_site_state(3, { DISCRETE_OCCUPANCY: "A"}) - state.set_site_state(2, { DISCRETE_OCCUPANCY: "B" }) + state.set_site_state(1, {DISCRETE_OCCUPANCY: "A"}) + state.set_site_state(3, {DISCRETE_OCCUPANCY: "A"}) + state.set_site_state(2, {DISCRETE_OCCUPANCY: "B"}) analyzer = DiscreteStepAnalyzer() @@ -18,4 +19,4 @@ def test_cell_count(): assert analyzer.cell_ratio(state, "A", "B") == 2 assert analyzer.cell_ratio(state, "B", "A") == 0.5 - assert analyzer.phase_count(state) == 2 \ No newline at end of file + assert analyzer.phase_count(state) == 2 diff --git a/tests/helpers/helpers.py b/tests/helpers/helpers.py index 59094bf..400e19d 100644 --- a/tests/helpers/helpers.py +++ b/tests/helpers/helpers.py @@ -1,4 +1,7 @@ import sys import pytest -skip_windows_due_to_parallel = pytest.mark.skipif(sys.platform.startswith("win"), reason="Parallel simulation not supported on windows due to fork") +skip_windows_due_to_parallel = pytest.mark.skipif( + sys.platform.startswith("win"), + reason="Parallel simulation not supported on windows due to fork", +) diff --git a/tests/models/test_gol_models.py b/tests/models/test_gol_models.py index 2fc2758..150e6fb 100644 --- a/tests/models/test_gol_models.py +++ b/tests/models/test_gol_models.py @@ -1,9 +1,17 @@ from pylattica.core import SynchronousRunner from pylattica.discrete.state_constants import DISCRETE_OCCUPANCY -from pylattica.models.game_of_life import Maze, Anneal, Diamoeba, Seeds, Life, GameOfLifeController +from pylattica.models.game_of_life import ( + Maze, + Anneal, + Diamoeba, + Seeds, + Life, + GameOfLifeController, +) from pylattica.discrete import PhaseSet from pylattica.structures.square_grid.grid_setup import DiscreteGridSetup + def test_gol_variants(): variants = [Life, Maze, Anneal, Diamoeba, Seeds] for variant in variants: @@ -11,20 +19,18 @@ def test_gol_variants(): setup = DiscreteGridSetup(phases) simulation = setup.setup_noise(10, ["dead", "alive"]) controller = GameOfLifeController( - structure=simulation.structure, - variant=variant + structure=simulation.structure, variant=variant ) runner = SynchronousRunner(parallel=False) runner.run(simulation.state, controller, 10, verbose=False) + def test_gol_update_rule(): phases = PhaseSet(["dead", "alive"]) setup = DiscreteGridSetup(phases) - simulation = setup.setup_interface(10, "dead", "alive") + simulation = setup.setup_interface(10, "dead", "alive") controller = GameOfLifeController(structure=simulation.structure) controller.pre_run(None) - site_id = simulation.structure.site_at((4,4))['_site_id'] + site_id = simulation.structure.site_at((4, 4))["_site_id"] update = controller.get_state_update(site_id, simulation.state) assert update[DISCRETE_OCCUPANCY] == "alive" - - diff --git a/tests/models/test_growth_model.py b/tests/models/test_growth_model.py index 8c4edff..b207601 100644 --- a/tests/models/test_growth_model.py +++ b/tests/models/test_growth_model.py @@ -8,59 +8,57 @@ from helpers.helpers import skip_windows_due_to_parallel + @skip_windows_due_to_parallel def test_can_run_growth_sim_parallel(): phases = PhaseSet(["A", "B", "C", "D"]) nb_spec = MooreNbHoodBuilder() setup = DiscreteGridSetup(phases) - periodic_initial_state = setup.setup_coords(20, "A", - { - "B": [(10, 10)] - } - ) + periodic_initial_state = setup.setup_coords(20, "A", {"B": [(10, 10)]}) controller = GrowthController( phases, periodic_initial_state.structure, nb_builder=nb_spec, - background_phase="A" + background_phase="A", ) runner = SynchronousRunner(parallel=True) - res = runner.run(periodic_initial_state.state, controller, num_steps = 3) + res = runner.run(periodic_initial_state.state, controller, num_steps=3) analyzer = StateAnalyzer(periodic_initial_state.structure) - assert analyzer.get_site_count_where_equal(res.get_step(1), { - DISCRETE_OCCUPANCY: "B" - }) == 9 + assert ( + analyzer.get_site_count_where_equal(res.get_step(1), {DISCRETE_OCCUPANCY: "B"}) + == 9 + ) + + assert ( + analyzer.get_site_count_where_equal(res.get_step(2), {DISCRETE_OCCUPANCY: "B"}) + == 25 + ) - assert analyzer.get_site_count_where_equal(res.get_step(2), { - DISCRETE_OCCUPANCY: "B" - }) == 25 def test_can_run_growth_sim_series(): phases = PhaseSet(["A", "B", "C", "D"]) nb_spec = MooreNbHoodBuilder() setup = DiscreteGridSetup(phases) - periodic_initial_state = setup.setup_coords(20, "A", - { - "B": [(10, 10)] - } - ) + periodic_initial_state = setup.setup_coords(20, "A", {"B": [(10, 10)]}) controller = GrowthController( phases, periodic_initial_state.structure, nb_builder=nb_spec, - background_phase="A" + background_phase="A", ) runner = SynchronousRunner() - res = runner.run(periodic_initial_state.state, controller, num_steps = 3) + res = runner.run(periodic_initial_state.state, controller, num_steps=3) analyzer = StateAnalyzer(periodic_initial_state.structure) - assert analyzer.get_site_count_where_equal(res.get_step(1), { - DISCRETE_OCCUPANCY: "B" - }) == 9 + assert ( + analyzer.get_site_count_where_equal(res.get_step(1), {DISCRETE_OCCUPANCY: "B"}) + == 9 + ) - assert analyzer.get_site_count_where_equal(res.get_step(2), { - DISCRETE_OCCUPANCY: "B" - }) == 25 \ No newline at end of file + assert ( + analyzer.get_site_count_where_equal(res.get_step(2), {DISCRETE_OCCUPANCY: "B"}) + == 25 + ) diff --git a/tests/structures/honeycombs/test_hc_lattices.py b/tests/structures/honeycombs/test_hc_lattices.py index b6d24c5..be0c68e 100644 --- a/tests/structures/honeycombs/test_hc_lattices.py +++ b/tests/structures/honeycombs/test_hc_lattices.py @@ -2,20 +2,24 @@ import numpy as np -from pylattica.structures.honeycomb.lattice import RhombohedralLattice, HONEYCOMB_SIDE_LENGTH, ROOT_3 +from pylattica.structures.honeycomb.lattice import ( + RhombohedralLattice, + HONEYCOMB_SIDE_LENGTH, + ROOT_3, +) def test_rhombohedral_lattice(): lattice = RhombohedralLattice() - pt1 = (1/2, 1/2) + pt1 = (1 / 2, 1 / 2) periodized_pt1 = lattice.get_periodized_cartesian_coords(pt1) assert (np.array(pt1) == periodized_pt1).all() - pt2 = (0, 1/2) + pt2 = (0, 1 / 2) periodized_pt2 = lattice.get_periodized_cartesian_coords(pt2) - assert (np.array([1, 1/2]) == periodized_pt2).all() + assert (np.array([1, 1 / 2]) == periodized_pt2).all() - pt3 = (1, 3/2) + pt3 = (1, 3 / 2) periodized_pt3 = lattice.get_periodized_cartesian_coords(pt3) - assert np.allclose(np.array([1/2, 3/2 - ROOT_3 /2]), periodized_pt3) \ No newline at end of file + assert np.allclose(np.array([1 / 2, 3 / 2 - ROOT_3 / 2]), periodized_pt3) diff --git a/tests/structures/honeycombs/test_hc_neighborhoods.py b/tests/structures/honeycombs/test_hc_neighborhoods.py index 176a0a8..eb71e46 100644 --- a/tests/structures/honeycombs/test_hc_neighborhoods.py +++ b/tests/structures/honeycombs/test_hc_neighborhoods.py @@ -1,14 +1,18 @@ import pytest -from pylattica.structures.honeycomb import HoneycombTilingBuilder, HoneycombNeighborhoodBuilder +from pylattica.structures.honeycomb import ( + HoneycombTilingBuilder, + HoneycombNeighborhoodBuilder, +) + def test_basic_tiling_neighborhood(): # this test is for a hexagonal tiling - each point # has 6 neighbors - struct = HoneycombTilingBuilder().build((3,3)) + struct = HoneycombTilingBuilder().build((3, 3)) nbhood = HoneycombNeighborhoodBuilder().get(struct) nbs = nbhood.neighbors_of(0, True) - assert len(nbs) == 6 \ No newline at end of file + assert len(nbs) == 6 diff --git a/tests/structures/honeycombs/test_hc_structure_builders.py b/tests/structures/honeycombs/test_hc_structure_builders.py index c24b9bc..f0a34bb 100644 --- a/tests/structures/honeycombs/test_hc_structure_builders.py +++ b/tests/structures/honeycombs/test_hc_structure_builders.py @@ -7,6 +7,7 @@ from pylattica.core.constants import LOCATION from pylattica.core.periodic_structure import OFFSET_PRECISION + def test_small_honeycomb_tiling_builder(): builder = HoneycombTilingBuilder() @@ -16,40 +17,52 @@ def test_small_honeycomb_tiling_builder(): site_0_cart = tiling.get_site(0)[LOCATION] site_0_frac = tiling.lattice.get_fractional_coords(site_0_cart) - assert np.allclose(np.array(site_0_frac),np.array([0.5, 0.5]), atol=OFFSET_PRECISION) - assert np.allclose(np.array(site_0_cart),np.array([3/4, ROOT_3 / 4]), atol=OFFSET_PRECISION) + assert np.allclose( + np.array(site_0_frac), np.array([0.5, 0.5]), atol=OFFSET_PRECISION + ) + assert np.allclose( + np.array(site_0_cart), np.array([3 / 4, ROOT_3 / 4]), atol=OFFSET_PRECISION + ) + def test_medium_honeycomb_tiling_builder_uneven_size(): builder = HoneycombTilingBuilder() - tiling2 = builder.build((2,1)) + tiling2 = builder.build((2, 1)) site_0_cart = tiling2.get_site(0)[LOCATION] site_0_frac = tiling2.lattice.get_fractional_coords(site_0_cart) - assert np.allclose(np.array(site_0_frac), np.array([0.25, 0.5]), atol=OFFSET_PRECISION) - assert np.allclose(np.array(site_0_cart), np.array([3/4, ROOT_3 / 4]), atol=OFFSET_PRECISION) + assert np.allclose( + np.array(site_0_frac), np.array([0.25, 0.5]), atol=OFFSET_PRECISION + ) + assert np.allclose( + np.array(site_0_cart), np.array([3 / 4, ROOT_3 / 4]), atol=OFFSET_PRECISION + ) site_1_cart = tiling2.get_site(1)[LOCATION] site_1_frac = tiling2.lattice.get_fractional_coords(site_1_cart) - assert np.allclose(np.array(site_1_frac), np.array([3 / 4, 0.5]), atol=OFFSET_PRECISION) - assert np.allclose(np.array(site_1_cart), np.array([7/4, ROOT_3 / 4]), atol=OFFSET_PRECISION) + assert np.allclose( + np.array(site_1_frac), np.array([3 / 4, 0.5]), atol=OFFSET_PRECISION + ) + assert np.allclose( + np.array(site_1_cart), np.array([7 / 4, ROOT_3 / 4]), atol=OFFSET_PRECISION + ) def test_medium_honeycomb_tiling_builder_even_size(): builder = HoneycombTilingBuilder() - tiling = builder.build((3,3)) + tiling = builder.build((3, 3)) middle_loc = tiling.lattice.get_cartesian_coords((0.5, 0.5)) - assert (middle_loc == np.array((9/4, 3 * ROOT_3 / 4))).all() + assert (middle_loc == np.array((9 / 4, 3 * ROOT_3 / 4))).all() middle_site = tiling.site_at(middle_loc) assert middle_site is not None - upper_left_loc = tiling.lattice.get_cartesian_coords((1/6, 5/6)) - assert np.allclose(upper_left_loc, np.array((7/4, 5 * ROOT_3 / 4))) + upper_left_loc = tiling.lattice.get_cartesian_coords((1 / 6, 5 / 6)) + assert np.allclose(upper_left_loc, np.array((7 / 4, 5 * ROOT_3 / 4))) assert np.isclose(np.linalg.norm(middle_loc - upper_left_loc), 1.0) ul_site = tiling.site_at(upper_left_loc) assert ul_site is not None assert tiling.lattice.cartesian_periodic_distance(middle_loc, upper_left_loc) == 1 - \ No newline at end of file diff --git a/tests/structures/square_grid/test_grid_setup.py b/tests/structures/square_grid/test_grid_setup.py index f2c4d48..d874538 100644 --- a/tests/structures/square_grid/test_grid_setup.py +++ b/tests/structures/square_grid/test_grid_setup.py @@ -9,81 +9,77 @@ def test_can_instantiate_grid_setup(simple_phase_set): setup = DiscreteGridSetup(simple_phase_set) assert setup is not None + def test_setup_interface(grid_setup: DiscreteGridSetup): - simulation = grid_setup.setup_interface(4, 'A', 'B') + simulation = grid_setup.setup_interface(4, "A", "B") - A_site = simulation.state_at((1,1)) - assert A_site[DISCRETE_OCCUPANCY] == 'A' + A_site = simulation.state_at((1, 1)) + assert A_site[DISCRETE_OCCUPANCY] == "A" - A_site = simulation.state_at((1,2)) - assert A_site[DISCRETE_OCCUPANCY] == 'A' + A_site = simulation.state_at((1, 2)) + assert A_site[DISCRETE_OCCUPANCY] == "A" - B_site = simulation.state_at((2,2)) - assert B_site[DISCRETE_OCCUPANCY] == 'B' + B_site = simulation.state_at((2, 2)) + assert B_site[DISCRETE_OCCUPANCY] == "B" - B_site = simulation.state_at((2,3)) - assert B_site[DISCRETE_OCCUPANCY] == 'B' + B_site = simulation.state_at((2, 3)) + assert B_site[DISCRETE_OCCUPANCY] == "B" def test_setup_random_particles(grid_setup: DiscreteGridSetup): - state = grid_setup.setup_random_particles(4, radius = 2, num_particles = 3, bulk_phase = 'A', particle_phases = ['B']) + state = grid_setup.setup_random_particles( + 4, radius=2, num_particles=3, bulk_phase="A", particle_phases=["B"] + ) assert state is not None + def test_setup_noise(grid_setup: DiscreteGridSetup): - state = grid_setup.setup_noise(4, phases = ['A', 'B']) + state = grid_setup.setup_noise(4, phases=["A", "B"]) assert state is not None + def test_setup_random_sites(grid_setup: DiscreteGridSetup): num_sites = 2 - nuc_amts = { - 'B': 1, - 'C': 1 - } + nuc_amts = {"B": 1, "C": 1} simulation = grid_setup.setup_random_sites( - 4, - num_sites_desired = num_sites, - background_spec='A', - nuc_amts=nuc_amts, - buffer=1 + 4, num_sites_desired=num_sites, background_spec="A", nuc_amts=nuc_amts, buffer=1 ) analyzer = StateAnalyzer(simulation.structure) def count_criteria(state: Dict) -> bool: - return state[DISCRETE_OCCUPANCY] == 'B' or state[DISCRETE_OCCUPANCY] == 'C' + return state[DISCRETE_OCCUPANCY] == "B" or state[DISCRETE_OCCUPANCY] == "C" + + assert ( + analyzer.get_site_count(simulation.state, state_criteria=[count_criteria]) + == num_sites + ) - assert analyzer.get_site_count(simulation.state, state_criteria=[count_criteria]) == num_sites def test_setup_random_sites_with_ratios(grid_setup: DiscreteGridSetup): num_sites = 2 - nuc_amts = { - 'B': 1, - 'C': 1 - } + nuc_amts = {"B": 1, "C": 1} simulation = grid_setup.setup_random_sites( - 4, - num_sites_desired = num_sites, - background_spec='A', - nuc_amts = nuc_amts, - buffer=1 + 4, num_sites_desired=num_sites, background_spec="A", nuc_amts=nuc_amts, buffer=1 ) analyzer = StateAnalyzer(simulation.structure) def count_criteria(state: Dict) -> bool: - return state[DISCRETE_OCCUPANCY] == 'B' or state[DISCRETE_OCCUPANCY] == 'C' + return state[DISCRETE_OCCUPANCY] == "B" or state[DISCRETE_OCCUPANCY] == "C" - assert analyzer.get_site_count(simulation.state, state_criteria=[count_criteria]) == num_sites + assert ( + analyzer.get_site_count(simulation.state, state_criteria=[count_criteria]) + == num_sites + ) -def test_setup_specific_coords(grid_setup: DiscreteGridSetup): - simulation = grid_setup.setup_coords(4, background_state='A', coordinates = { - 'B': [[0, 0]], - 'C': [[1,1]] - }) - - assert simulation.state_at((0,0))[DISCRETE_OCCUPANCY] == 'B' - assert simulation.state_at((1,1))[DISCRETE_OCCUPANCY] == 'C' - assert simulation.state_at((1,0))[DISCRETE_OCCUPANCY] == 'A' +def test_setup_specific_coords(grid_setup: DiscreteGridSetup): + simulation = grid_setup.setup_coords( + 4, background_state="A", coordinates={"B": [[0, 0]], "C": [[1, 1]]} + ) + assert simulation.state_at((0, 0))[DISCRETE_OCCUPANCY] == "B" + assert simulation.state_at((1, 1))[DISCRETE_OCCUPANCY] == "C" + assert simulation.state_at((1, 0))[DISCRETE_OCCUPANCY] == "A" diff --git a/tests/structures/square_grid/test_grid_structure.py b/tests/structures/square_grid/test_grid_structure.py index 989f3b4..a6763ed 100644 --- a/tests/structures/square_grid/test_grid_structure.py +++ b/tests/structures/square_grid/test_grid_structure.py @@ -1,18 +1,24 @@ -from pylattica.structures.square_grid.structure_builders import SimpleSquare2DStructureBuilder, SimpleSquare3DStructureBuilder +from pylattica.structures.square_grid.structure_builders import ( + SimpleSquare2DStructureBuilder, + SimpleSquare3DStructureBuilder, +) from pylattica.core.constants import LOCATION + def test_grid_has_points_as_expected(square_grid_2D_2x2): assert square_grid_2D_2x2.site_at((0, 0)) is not None assert square_grid_2D_2x2.site_at((0, 1)) is not None assert square_grid_2D_2x2.site_at((1, 0)) is not None assert square_grid_2D_2x2.site_at((1, 1)) is not None + def test_grid_has_points_as_expected_outside_cell(square_grid_2D_2x2): assert square_grid_2D_2x2.site_at((0, -1)) is not None assert square_grid_2D_2x2.site_at((0, 2)) is not None assert square_grid_2D_2x2.site_at((4, 0)) is not None - assert square_grid_2D_2x2.site_at((12,31)) is not None + assert square_grid_2D_2x2.site_at((12, 31)) is not None + def test_grid_does_not_have_unexpected_points(square_grid_2D_2x2): assert square_grid_2D_2x2.site_at((0, 0.5)) is None @@ -20,5 +26,6 @@ def test_grid_does_not_have_unexpected_points(square_grid_2D_2x2): assert square_grid_2D_2x2.site_at((1, 1.5)) is None assert square_grid_2D_2x2.site_at((-12.001, 1)) is None + def test_retrieved_site_has_correct_location(square_grid_2D_2x2): - assert (square_grid_2D_2x2.site_at((0, 0))[LOCATION] == (0,0)).all() \ No newline at end of file + assert (square_grid_2D_2x2.site_at((0, 0))[LOCATION] == (0, 0)).all() diff --git a/tests/structures/square_grid/test_growth_setup.py b/tests/structures/square_grid/test_growth_setup.py index 6d03442..2c7bd9f 100644 --- a/tests/structures/square_grid/test_growth_setup.py +++ b/tests/structures/square_grid/test_growth_setup.py @@ -6,6 +6,7 @@ from helpers.helpers import skip_windows_due_to_parallel + @skip_windows_due_to_parallel def test_growth_setup(): phases = PhaseSet(["A", "B", "C"]) @@ -14,11 +15,8 @@ def test_growth_setup(): total_num_sites = 4 background_phase = "A" - nuc_amts = { - 'B': 1, - 'C': 1 - } - buffer = 2 # Each site should be at least 2 cells away from any other + nuc_amts = {"B": 1, "C": 1} + buffer = 2 # Each site should be at least 2 cells away from any other growth_setup = GrowthSetup(phases) simulation = growth_setup.grow( @@ -27,10 +25,11 @@ def test_growth_setup(): num_sites_desired=total_num_sites, nuc_amts=nuc_amts, buffer=buffer, - nb_builder=MooreNbHoodBuilder(1) + nb_builder=MooreNbHoodBuilder(1), ) analyzer = StateAnalyzer(simulation.structure) - assert analyzer.get_site_count_where_equal(simulation.state, { DISCRETE_OCCUPANCY: "A" }) == 0 - - + assert ( + analyzer.get_site_count_where_equal(simulation.state, {DISCRETE_OCCUPANCY: "A"}) + == 0 + ) diff --git a/tests/structures/square_grid/test_sg_neighborhoods_discrete.py b/tests/structures/square_grid/test_sg_neighborhoods_discrete.py index 30e2b0e..5cd5a65 100644 --- a/tests/structures/square_grid/test_sg_neighborhoods_discrete.py +++ b/tests/structures/square_grid/test_sg_neighborhoods_discrete.py @@ -7,10 +7,14 @@ PseudoHexagonalNeighborhoodBuilder2D, PseudoPentagonalNeighborhoodBuilder, PseudoHexagonalNeighborhoodBuilder3D, - VonNeumannNbHood3DBuilder + VonNeumannNbHood3DBuilder, ) from pylattica.core.constants import SITE_ID -from pylattica.structures.square_grid.structure_builders import SimpleSquare2DStructureBuilder, SimpleSquare3DStructureBuilder +from pylattica.structures.square_grid.structure_builders import ( + SimpleSquare2DStructureBuilder, + SimpleSquare3DStructureBuilder, +) + def test_von_neumann_neighborhood(): struct = SimpleSquare2DStructureBuilder().build(10) @@ -18,7 +22,7 @@ def test_von_neumann_neighborhood(): spec = VonNeumannNbHood2DBuilder() nb_hood = spec.get(struct) - site = struct.site_at((5,5)) + site = struct.site_at((5, 5)) nbs = nb_hood.neighbors_of(site[SITE_ID]) assert len(nbs) == 4 @@ -29,13 +33,14 @@ def test_moore_neighborhood(): spec = MooreNbHoodBuilder() nb_hood = spec.get(struct) - site = struct.site_at((5,5)) + site = struct.site_at((5, 5)) nbs = nb_hood.neighbors_of(site[SITE_ID]) assert len(nbs) == 8 + def test_circular_neighborhood(): struct = SimpleSquare2DStructureBuilder().build(20) - + nb_builder = CircularNeighborhoodBuilder(3) nbh = nb_builder.get(struct) nbs = nbh.neighbors_of(0) @@ -46,33 +51,37 @@ def test_circular_neighborhood(): nbs = nbh.neighbors_of(0) assert len(nbs) == 8 + def test_pseudo_hexagonal_nb_hood(): struct = SimpleSquare2DStructureBuilder().build(10) - + nb_builder = PseudoHexagonalNeighborhoodBuilder2D() nbh = nb_builder.get(struct) nbs = nbh.neighbors_of(0) assert len(nbs) == 6 + def test_pseudo_pentagonal_nb_hood(): struct = SimpleSquare2DStructureBuilder().build(10) - + nb_builder = PseudoPentagonalNeighborhoodBuilder() nbh = nb_builder.get(struct) nbs = nbh.neighbors_of(0) assert len(nbs) == 5 + def test_pseudo_hexagonal_nb_3d_hood(): struct = SimpleSquare3DStructureBuilder().build(10) - + nb_builder = PseudoHexagonalNeighborhoodBuilder3D() nbh = nb_builder.get(struct) nbs = nbh.neighbors_of(0) assert len(nbs) == 8 + def test_von_neumann_nb_3d_hood(): struct = SimpleSquare3DStructureBuilder().build(10) - + nb_builder = VonNeumannNbHood3DBuilder(1) nbh = nb_builder.get(struct) nbs = nbh.neighbors_of(0) diff --git a/tests/structures/square_grid/test_sg_structure_builder.py b/tests/structures/square_grid/test_sg_structure_builder.py index 3cb8d81..82761b3 100644 --- a/tests/structures/square_grid/test_sg_structure_builder.py +++ b/tests/structures/square_grid/test_sg_structure_builder.py @@ -1,7 +1,9 @@ -from pylattica.structures.square_grid.structure_builders import SimpleSquare2DStructureBuilder +from pylattica.structures.square_grid.structure_builders import ( + SimpleSquare2DStructureBuilder, +) import pytest def test_creates_square_grid(): struct = SimpleSquare2DStructureBuilder().build(4) - assert struct.site_at((0,0)) is not None \ No newline at end of file + assert struct.site_at((0, 0)) is not None diff --git a/tests/visualization/test_cell_artists.py b/tests/visualization/test_cell_artists.py index b66a8e8..769c528 100644 --- a/tests/visualization/test_cell_artists.py +++ b/tests/visualization/test_cell_artists.py @@ -3,68 +3,63 @@ from pylattica.visualization import DiscreteCellArtist from pylattica.core import SimulationState + def test_discrete_cell_artist_no_legend_no_cmap(): phases = ["a", "b"] artist = DiscreteCellArtist.from_phase_list(phases, state_key="x") - assert artist.get_color_from_cell_state({ "x": "a" }) != (0, 0, 0) - assert artist.get_color_from_cell_state({ "x": "c" }) == (0, 0, 0) - assert artist.get_color_from_cell_state({ "x": "b" }) != (0, 0, 0) - assert artist.get_color_from_cell_state({ "x": "b" }) != artist.get_color_from_cell_state({ "x": "a" }) - - assert artist.get_cell_legend_label({ "x": "a" }) == "a" - assert artist.get_cell_legend_label({ "x": "b" }) == "b" - assert artist.get_cell_legend_label({ "x": "c" }) == "c" - - state = SimulationState({ - "SITES": { - 1: { - - "x": "a" - }, - 2: { - "x": "b" - }, + assert artist.get_color_from_cell_state({"x": "a"}) != (0, 0, 0) + assert artist.get_color_from_cell_state({"x": "c"}) == (0, 0, 0) + assert artist.get_color_from_cell_state({"x": "b"}) != (0, 0, 0) + assert artist.get_color_from_cell_state( + {"x": "b"} + ) != artist.get_color_from_cell_state({"x": "a"}) + + assert artist.get_cell_legend_label({"x": "a"}) == "a" + assert artist.get_cell_legend_label({"x": "b"}) == "b" + assert artist.get_cell_legend_label({"x": "c"}) == "c" + + state = SimulationState( + { + "SITES": { + 1: {"x": "a"}, + 2: {"x": "b"}, + } } - }) + ) legend = artist.get_legend(state) assert "a" in legend assert "b" in legend - assert legend.get("a") == artist.get_color_from_cell_state({ "x": "a" }) - assert legend.get("b") == artist.get_color_from_cell_state({ "x": "b" }) + assert legend.get("a") == artist.get_color_from_cell_state({"x": "a"}) + assert legend.get("b") == artist.get_color_from_cell_state({"x": "b"}) + def test_discrete_cell_artist_no_legend_cmap(): a_color = (50, 60, 70) b_color = (110, 120, 130) - cmap = { - "a": a_color, - "b": b_color - } + cmap = {"a": a_color, "b": b_color} artist = DiscreteCellArtist(cmap, state_key="x") - assert artist.get_color_from_cell_state({ "x": "a" }) == a_color - assert artist.get_color_from_cell_state({ "x": "c" }) == (0, 0, 0) - assert artist.get_color_from_cell_state({ "x": "b" }) == b_color - - assert artist.get_cell_legend_label({ "x": "a" }) == "a" - assert artist.get_cell_legend_label({ "x": "b" }) == "b" - assert artist.get_cell_legend_label({ "x": "c" }) == "c" - - state = SimulationState({ - "SITES": { - 1: { - - "x": "a" - }, - 2: { - "x": "b" - }, + assert artist.get_color_from_cell_state({"x": "a"}) == a_color + assert artist.get_color_from_cell_state({"x": "c"}) == (0, 0, 0) + assert artist.get_color_from_cell_state({"x": "b"}) == b_color + + assert artist.get_cell_legend_label({"x": "a"}) == "a" + assert artist.get_cell_legend_label({"x": "b"}) == "b" + assert artist.get_cell_legend_label({"x": "c"}) == "c" + + state = SimulationState( + { + "SITES": { + 1: {"x": "a"}, + 2: {"x": "b"}, + } } - }) + ) legend = artist.get_legend(state) @@ -82,36 +77,27 @@ def test_discrete_cell_artist_legend_and_cmap(): a_color_leg = (200, 210, 220) b_color_leg = (230, 240, 250) - cmap = { - "a": a_color, - "b": b_color - } + cmap = {"a": a_color, "b": b_color} - legend = { - "a": a_color_leg, - "b": b_color_leg - } + legend = {"a": a_color_leg, "b": b_color_leg} artist = DiscreteCellArtist(cmap, state_key="x", legend=legend) - assert artist.get_color_from_cell_state({ "x": "a" }) == a_color - assert artist.get_color_from_cell_state({ "x": "c" }) == (0, 0, 0) - assert artist.get_color_from_cell_state({ "x": "b" }) == b_color - - assert artist.get_cell_legend_label({ "x": "a" }) == "a" - assert artist.get_cell_legend_label({ "x": "b" }) == "b" - assert artist.get_cell_legend_label({ "x": "c" }) == "c" - - state = SimulationState({ - "SITES": { - 1: { - - "x": "a" - }, - 2: { - "x": "b" - }, + assert artist.get_color_from_cell_state({"x": "a"}) == a_color + assert artist.get_color_from_cell_state({"x": "c"}) == (0, 0, 0) + assert artist.get_color_from_cell_state({"x": "b"}) == b_color + + assert artist.get_cell_legend_label({"x": "a"}) == "a" + assert artist.get_cell_legend_label({"x": "b"}) == "b" + assert artist.get_cell_legend_label({"x": "c"}) == "c" + + state = SimulationState( + { + "SITES": { + 1: {"x": "a"}, + 2: {"x": "b"}, + } } - }) + ) legend = artist.get_legend(state) @@ -119,4 +105,4 @@ def test_discrete_cell_artist_legend_and_cmap(): assert "b" in legend assert legend.get("a") == a_color_leg - assert legend.get("b") == b_color_leg \ No newline at end of file + assert legend.get("b") == b_color_leg diff --git a/tests/visualization/test_helpers.py b/tests/visualization/test_helpers.py index 29c4ee7..06da856 100644 --- a/tests/visualization/test_helpers.py +++ b/tests/visualization/test_helpers.py @@ -3,11 +3,12 @@ from pylattica.visualization.helpers import color_map, COLORS + def test_color_map_no_extra_phases(): len_colors = len(COLORS) - phases = [str(random.randint(0,100000)) for _ in range(10)] + phases = [str(random.randint(0, 100000)) for _ in range(10)] cmap = color_map(phases) all_colors = set(cmap.values()) - assert len(all_colors) == len(phases) \ No newline at end of file + assert len(all_colors) == len(phases) diff --git a/tests/visualization/test_square_grid_artists.py b/tests/visualization/test_square_grid_artists.py index 14b17eb..db359bb 100644 --- a/tests/visualization/test_square_grid_artists.py +++ b/tests/visualization/test_square_grid_artists.py @@ -2,19 +2,24 @@ from pylattica.core.simulation_state import SimulationState from pylattica.discrete import PhaseSet from pylattica.structures.square_grid.grid_setup import DiscreteGridSetup -from pylattica.visualization import SquareGridArtist2D, SquareGridArtist3D, ResultArtist, DiscreteCellArtist +from pylattica.visualization import ( + SquareGridArtist2D, + SquareGridArtist3D, + ResultArtist, + DiscreteCellArtist, +) from pylattica.models.game_of_life import Life, GameOfLifeController from pylattica.discrete.state_constants import DISCRETE_OCCUPANCY import os import random + def test_step_artist(): phases = PhaseSet(["dead", "alive"]) setup = DiscreteGridSetup(phases) simulation = setup.setup_noise(10, ["dead", "alive"]) - controller = GameOfLifeController(structure = simulation.structure, - variant=Life) + controller = GameOfLifeController(structure=simulation.structure, variant=Life) runner = SynchronousRunner(parallel=False) result = runner.run(simulation.state, controller, 10, verbose=False) cell_artist = DiscreteCellArtist.from_discrete_state(result.last_step) @@ -23,12 +28,12 @@ def test_step_artist(): artist.save_img(result.last_step, "tmp.png") os.remove("tmp.png") + def test_result_artist(): phases = PhaseSet(["dead", "alive"]) setup = DiscreteGridSetup(phases) simulation = setup.setup_noise(10, ["dead", "alive"]) - controller = GameOfLifeController(structure = simulation.structure, - variant=Life) + controller = GameOfLifeController(structure=simulation.structure, variant=Life) runner = SynchronousRunner(parallel=False) result = runner.run(simulation.state, controller, 10, verbose=False) cell_artist = DiscreteCellArtist.from_discrete_result(result) @@ -38,16 +43,14 @@ def test_result_artist(): result_artist.to_gif("out.gif", cell_size=5) os.remove("out.gif") -def test_step_artist_3D(): +def test_step_artist_3D(): class SimpleController(BasicController): - def __init__(self): pass def get_state_update(self, site_id: int, prev_state: SimulationState): - return { DISCRETE_OCCUPANCY: random.choice(["dead", "alive"])} - + return {DISCRETE_OCCUPANCY: random.choice(["dead", "alive"])} phases = PhaseSet(["dead", "alive"]) setup = DiscreteGridSetup(phases, dim=3) @@ -57,4 +60,4 @@ def get_state_update(self, site_id: int, prev_state: SimulationState): cell_artist = DiscreteCellArtist.from_discrete_state(result.last_step) artist = SquareGridArtist3D(simulation.structure, cell_artist) - artist.get_img(result.last_step, cell_size=5) \ No newline at end of file + artist.get_img(result.last_step, cell_size=5)