Skip to content
Merged

Dev #22

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ requires-python = '>=3.8'
dependencies = [
'numpy >= 1.21.5',
'matplotlib >= 3.5.1',
'scipy >= 1.8.0',
'tqdm >= 4.63.0',
'pymatgen >= 2022.3.29',
'plotly >= 5.6.0',
Expand Down
8 changes: 6 additions & 2 deletions src/pylattica/core/periodic_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self, lattice: Lattice):
self.lattice = lattice
self.dim = lattice.dim
self._sites = {}
self.site_ids = []
self._site_ids = []
self._location_lookup = {}
self._offset_vector = np.array([VEC_OFFSET for _ in range(self.dim)])

Expand All @@ -127,6 +127,10 @@ def as_dict(self):
"_sites": copied,
}

@property
def site_ids(self):
return copy.copy(self._site_ids)

@classmethod
def from_dict(cls, d):
struct = cls(Lattice.from_dict(d["lattice"]))
Expand Down Expand Up @@ -182,7 +186,7 @@ def add_site(self, site_class: str, location: Tuple[float]) -> int:
}

self._location_lookup[offset_periodized_coords] = new_site_id
self.site_ids.append(new_site_id)
self._site_ids.append(new_site_id)
return new_site_id

def site_at(self, location: Tuple[float]) -> Dict:
Expand Down
1 change: 1 addition & 0 deletions src/pylattica/core/runner/asynchronous_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _add_sites_to_queue():
state_updates = merge_updates(state_updates, site_id=site_id)
live_state.batch_update(state_updates)
site_queue.extend(next_sites)

result.add_step(state_updates)

if len(site_queue) == 0:
Expand Down
57 changes: 53 additions & 4 deletions src/pylattica/core/simulation_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from monty.serialization import dumpfn, loadfn
import datetime
from .simulation_state import SimulationState
from .constants import GENERAL, SITES

import copy


class SimulationResult:
Expand All @@ -23,13 +26,21 @@ def from_file(cls, fpath):
@classmethod
def from_dict(cls, res_dict):
diffs = res_dict["diffs"]
res = cls(SimulationState.from_dict(res_dict["initial_state"]))
compress_freq = res_dict.get("compress_freq", 1)
res = cls(
SimulationState.from_dict(res_dict["initial_state"]),
compress_freq=compress_freq,
)
for diff in diffs:
formatted = {int(k): v for k, v in diff.items() if k != "GENERAL"}
res.add_step(formatted)
if SITES in diff:
diff[SITES] = {int(k): v for k, v in diff[SITES].items()}
if GENERAL not in diff and SITES not in diff:
diff = {int(k): v for k, v in diff.items()}
res.add_step(diff)

return res

def __init__(self, starting_state: SimulationState):
def __init__(self, starting_state: SimulationState, compress_freq: int = 1):
"""Initializes a SimulationResult with the specified starting_state.

Parameters
Expand All @@ -38,6 +49,7 @@ def __init__(self, starting_state: SimulationState):
The state with which the simulation started.
"""
self.initial_state = starting_state
self.compress_freq = compress_freq
self._diffs: list[dict] = []
self._stored_states = {}

Expand All @@ -60,6 +72,10 @@ def add_step(self, updates: Dict[int, Dict]) -> None:
"""
self._diffs.append(updates)

@property
def original_length(self) -> int:
return int(len(self) * self.compress_freq)

def __len__(self) -> int:
return len(self._diffs) + 1

Expand Down Expand Up @@ -120,6 +136,9 @@ def get_step(self, step_no) -> SimulationState:
The simulation state at the requested step.
"""

# if step_no % self.compress_freq != 0:
# raise ValueError(f"Cannot retrieve step no {step_no} because this result has been compressed with sampling frequency {self.compress_freq}")

stored = self._stored_states.get(step_no)
if stored is not None:
return stored
Expand All @@ -133,6 +152,7 @@ def as_dict(self):
return {
"initial_state": self.initial_state.as_dict(),
"diffs": self._diffs,
"compress_freq": self.compress_freq,
"@module": self.__class__.__module__,
"@class": self.__class__.__name__,
}
Expand All @@ -152,3 +172,32 @@ def to_file(self, fpath: str = None) -> None:

dumpfn(self, fpath)
return fpath


def compress_result(result: SimulationResult, num_steps: int):
i_state = result.first_step
# total steps is the actual number of diffs stored, not the number of original simulation steps taken
total_steps = len(result)
if num_steps >= total_steps:
raise ValueError(
f"Cannot upsample SimulationResult of length {total_steps} to size {num_steps}."
)

exact_sample_freq = total_steps / (num_steps)
# print(total_steps, current_sample_freq)
total_compress_freq = exact_sample_freq * result.compress_freq
compressed_result = SimulationResult(i_state, compress_freq=total_compress_freq)

live_state = SimulationState(copy.deepcopy(i_state._state))
added = 0
next_sample_step = exact_sample_freq
for i, diff in enumerate(result._diffs):
curr_step = i + 1
live_state.batch_update(diff)
# if curr_step % current_sample_freq == 0:
if curr_step > next_sample_step:
# print(curr_step)
added += 1
compressed_result.add_step(live_state.as_state_update())
next_sample_step += exact_sample_freq
return compressed_result
10 changes: 8 additions & 2 deletions src/pylattica/core/simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,18 @@ def get_site_state(self, site_id: int) -> Dict:
"""
return self._state[SITES].get(site_id)

def get_general_state(self) -> Dict:
def get_general_state(self, key: str = None, default=None) -> Dict:
"""Returns the general state.

Returns
-------
Dict
The general state.
"""
return self._state.get(GENERAL)
if key is None:
return copy.deepcopy(self._state.get(GENERAL))
else:
return copy.deepcopy(self._state.get(GENERAL)).get(key, default)

def set_general_state(self, updates: Dict) -> None:
"""Updates the general state with the keys and values provided by the updates parameter.
Expand Down Expand Up @@ -175,5 +178,8 @@ def copy(self) -> SimulationState:
"""
return SimulationState(self._state)

def as_state_update(self) -> Dict:
return copy.deepcopy(self._state)

def __eq__(self, other: SimulationState) -> bool:
return self._state == other._state
1 change: 1 addition & 0 deletions src/pylattica/models/game_of_life/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .controller import GameOfLifeController, Life, Seeds, Anneal, Diamoeba, Maze
from .life_phase_set import LIFE_PHASE_SET
3 changes: 3 additions & 0 deletions src/pylattica/models/game_of_life/life_phase_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ...discrete.phase_set import PhaseSet

LIFE_PHASE_SET = PhaseSet(["alive", "dead"])
5 changes: 2 additions & 3 deletions src/pylattica/structures/square_grid/grid_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,8 @@ def setup_random_sites(

while num_sites_planted < num_sites_desired:
if total_attempts > 1000 * num_sites_desired:
raise RuntimeError(
f"Too many nucleation sites at the specified buffer: {total_attempts} made at placing nuclei"
)
print(f"Only able to place {num_sites_planted} in {total_attempts} attempts")
break

rand_site = random.choice(all_sites)
rand_site_id = rand_site[SITE_ID]
Expand Down
37 changes: 21 additions & 16 deletions src/pylattica/visualization/result_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,24 @@

from PIL import Image

from typing import Callable

_dsr_globals = {}


def default_annotation_builder(step, step_no):
return f"Step {step_no}"


class ResultArtist:
"""A class for rendering simulation results as animated GIFs."""

def __init__(self, step_artist: StructureArtist, result: SimulationResult):
def __init__(
self,
step_artist: StructureArtist,
result: SimulationResult,
annotation_builder: Callable = default_annotation_builder,
):
"""Instantiates the ResultArtist class.

Parameters
Expand All @@ -26,6 +37,7 @@ def __init__(self, step_artist: StructureArtist, result: SimulationResult):
"""
self._step_artist = step_artist
self.result = result
self.annotation_builder = annotation_builder

def _get_images(self, **kwargs):
draw_freq = kwargs.get("draw_freq", 1)
Expand All @@ -47,21 +59,18 @@ def _get_images(self, **kwargs):
with mp.get_context("fork").Pool(PROCESSES) as pool:
params = []
for idx in indices:
label = f"Step {idx}"
step_kwargs = {**kwargs, "label": label}
step = self.result.get_step(idx)
label = self.annotation_builder(step, idx)
step_kwargs = {**kwargs, "label": label}

params.append([step, step_kwargs])

for img in pool.starmap(_get_img_parallel, params):
imgs.append(img)

return imgs

def jupyter_show_step(
self,
step_no: int,
cell_size=20,
) -> None:
def jupyter_show_step(self, step_no: int, cell_size=20, **kwargs) -> None:
"""In a jupyter notebook environment, visualizes the step as a color coded phase grid.

Parameters
Expand All @@ -71,17 +80,13 @@ def jupyter_show_step(
cell_size : int, optional
The size of each simulation cell, in pixels, by default 20
"""
label = f"Step {step_no}" # pragma: no cover
step = self.result.get_step(step_no) # pragma: no cover
label = self.annotation_builder(step, step_no)
self._step_artist.jupyter_show(
step, label=label, cell_size=cell_size
step, label=label, cell_size=cell_size, **kwargs
) # pragma: no cover

def jupyter_play(
self,
cell_size: int = 20,
wait: int = 1,
):
def jupyter_play(self, cell_size: int = 20, wait: int = 1, **kwargs):
"""In a jupyter notebook environment, plays the simulation visualization back by showing a
series of images with {wait} seconds between each one.

Expand All @@ -94,7 +99,7 @@ def jupyter_play(
"""
from IPython.display import clear_output, display # pragma: no cover

imgs = self._get_images(cell_size=cell_size) # pragma: no cover
imgs = self._get_images(cell_size=cell_size, **kwargs) # pragma: no cover
for img in imgs: # pragma: no cover
clear_output() # pragma: no cover
display(img) # pragma: no cover
Expand Down
66 changes: 39 additions & 27 deletions src/pylattica/visualization/square_grid_artist_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,27 @@ def _draw_image(self, state: SimulationState, **kwargs):
label = kwargs.get("label", None)
cell_size = kwargs.get("cell_size", 20)

show_legend = kwargs.get("show_legend", True)

legend = self.cell_artist.get_legend(state)
legend_order = sorted(legend.keys())
state_size = int(self.structure.lattice.vec_lengths[0])
width = state_size + 6

legend_border_width = 5
height = max(state_size, len(legend) + 1)
if show_legend:
width = state_size + 6
legend_border_width = 5
height = max(state_size, len(legend) + 1)
img_width = width * cell_size + legend_border_width
img_height = height * cell_size
else:
width = state_size
height = state_size
img_width = width * cell_size
img_height = height * cell_size

img = Image.new(
"RGB",
(width * cell_size + legend_border_width, height * cell_size),
(img_width, img_height),
"black",
) # Create a new black image

Expand All @@ -39,29 +50,30 @@ def _draw_image(self, state: SimulationState, **kwargs):
for p_y in range(p_y_start, p_y_start + cell_size):
pixels[p_x, p_y] = cell_color

count = 0
legend_hoffset = int(cell_size / 4)
legend_voffset = int(cell_size / 4)

for p_y in range(height * cell_size):
for p_x in range(0, legend_border_width):
x = state_size * cell_size + p_x
pixels[x, p_y] = (255, 255, 255)

for phase in legend_order:
color = legend.get(phase)
p_col_start = state_size * cell_size + legend_border_width + legend_hoffset
p_row_start = count * cell_size + legend_voffset
for p_x in range(p_col_start, p_col_start + cell_size):
for p_y in range(p_row_start, p_row_start + cell_size):
pixels[p_x, p_y] = color

legend_label_loc = (
int(p_col_start + cell_size + cell_size / 4),
int(p_row_start + cell_size / 4),
)
draw.text(legend_label_loc, phase, (255, 255, 255))
count += 1
if show_legend:
count = 0
legend_hoffset = int(cell_size / 4)
legend_voffset = int(cell_size / 4)

for p_y in range(height * cell_size):
for p_x in range(0, legend_border_width):
x = state_size * cell_size + p_x
pixels[x, p_y] = (255, 255, 255)

for phase in legend_order:
color = legend.get(phase)
p_col_start = state_size * cell_size + legend_border_width + legend_hoffset
p_row_start = count * cell_size + legend_voffset
for p_x in range(p_col_start, p_col_start + cell_size):
for p_y in range(p_row_start, p_row_start + cell_size):
pixels[p_x, p_y] = color

legend_label_loc = (
int(p_col_start + cell_size + cell_size / 4),
int(p_row_start + cell_size / 4),
)
draw.text(legend_label_loc, phase, (255, 255, 255))
count += 1

if label is not None:
draw.text((5, 5), label, (255, 255, 255))
Expand Down
Loading
Loading