def add_branching_event(
    self,
    time,
    n_candidate_branches,
    n_sampled_branches,
    candidate_probs,
    sampled_indices,
    parent_prob,
    site,
    trace_distances=None,
):
    """
    Record one branching event, covering both the kept and the discarded branches.

    The event is appended to ``self.branching_events`` as a flat dict so the list can
    later be turned into a DataFrame with ``branching_events_to_dataframe``.

    Parameters:
    -----------
    time : float
        Time at which branching occurred
    n_candidate_branches : int
        Total number of candidate branches before sampling
    n_sampled_branches : int
        Number of branches that survived sampling
    candidate_probs : array-like
        Probabilities of all candidate branches (list, tuple or np.ndarray)
    sampled_indices : array-like
        Indices (or a boolean mask) selecting the sampled branches; may be empty
    parent_prob : float
        Probability of the parent branch before splitting
    site : int
        Site index where branching occurred
    trace_distances : dict, optional
        Extra quality metrics; merged into the event last, so a key that collides
        with one of the built-in fields replaces it
    """
    # Coerce once so plain Python sequences work for the reductions and the indexing.
    probs = np.asarray(candidate_probs)
    picked = np.asarray(sampled_indices)

    total_prob = probs.sum()
    # An empty selection has float dtype and cannot be used as an index, so
    # short-circuit it instead of letting numpy raise IndexError.
    sampled_prob = probs[picked].sum() if picked.size else 0.0

    # min/max raise on empty input; report NaN statistics for a candidate-less event.
    if probs.size:
        p_mean, p_std, p_min, p_max = probs.mean(), probs.std(), probs.min(), probs.max()
    else:
        p_mean = p_std = p_min = p_max = np.nan

    event = {
        "time": time,
        "site": site,
        "n_candidate_branches": n_candidate_branches,
        "n_sampled_branches": n_sampled_branches,
        "n_discarded_branches": n_candidate_branches - n_sampled_branches,
        "parent_prob": parent_prob,
        "total_candidate_prob": total_prob,
        "sampled_prob": sampled_prob,
        "discarded_prob": total_prob - sampled_prob,
        "candidate_probs_mean": p_mean,
        "candidate_probs_std": p_std,
        "candidate_probs_min": p_min,
        "candidate_probs_max": p_max,
    }

    if trace_distances is not None:
        event.update(trace_distances)

    self.branching_events.append(event)


def branching_events_to_dataframe(self):
    """
    Rebuild ``self.df_branching_events`` from ``self.branching_events`` and return it.

    When no events have been recorded the cached attribute is left untouched and
    returned as-is.
    """
    if self.branching_events:
        self.df_branching_events = pd.DataFrame.from_records(self.branching_events)
    return self.df_branching_events
def get_cumulative_branch_counts(self):
    """
    Calculate the cumulative number of branches at every measurement time.

    The tree starts with one branch. Every recorded branching event replaces one
    parent by its children, so it changes the running totals by
    ``n_candidate_branches - 1`` (candidates) and ``n_sampled_branches - 1`` (sampled).
    An event counts towards time ``t`` when ``event time <= t``.

    Returns:
    --------
    pd.DataFrame or None
        None when there are no measurements. Otherwise one row per unique time in
        ``self.df_combined_values`` with columns:
        - time: measurement times (ascending)
        - n_sampled_branches_actual: rows of ``self.df_branch_values`` at that time
        - n_sampled_branches_cumulative: running total of sampled branches
        - n_candidate_branches_cumulative: running total of candidate branches
        - n_discarded_branches_cumulative: running total of discarded branches
        Without any branching events the two running totals equal the actual
        count and the discarded total is 0.
    """
    # Always rebuild the events frame: it is cheap, and the cached copy goes stale
    # as soon as another event is recorded after a previous conversion.
    self.branching_events_to_dataframe()

    # The measurement frames are only built when missing; refreshing them is the
    # caller's responsibility.
    if self.df_combined_values is None:
        self.combine_measurements()
    if self.df_branch_values is None:
        self.branch_values_to_dataframe()

    combined = self.df_combined_values
    measured = self.df_branch_values
    if combined is None or len(combined) == 0 or measured is None or len(measured) == 0:
        return None

    times = np.sort(np.asarray(combined["time"].unique()))
    # Number of measured (active) branches per time; times without rows count as 0.
    active = measured["time"].value_counts().reindex(times, fill_value=0).to_numpy()

    events = self.df_branching_events
    if events is None or len(events) == 0:
        sampled_total = active
        candidate_total = active
        discarded_total = np.zeros(len(times), dtype=int)
    else:
        events = events.sort_values("time", kind="stable")
        # For each measurement time: how many events happened at or before it.
        n_events_so_far = np.searchsorted(events["time"].to_numpy(), times, side="right")

        def running_total(increments, start):
            # Entry k is the total after the first k events; entry 0 is the start value.
            totals = np.concatenate(([start], start + np.cumsum(increments)))
            return totals[n_events_so_far]

        candidate_total = running_total(events["n_candidate_branches"].to_numpy() - 1, 1)
        sampled_total = running_total(events["n_sampled_branches"].to_numpy() - 1, 1)
        discarded_total = running_total(events["n_discarded_branches"].to_numpy(), 0)

    return pd.DataFrame(
        {
            "time": times,
            "n_sampled_branches_actual": active,
            "n_sampled_branches_cumulative": sampled_total,
            "n_candidate_branches_cumulative": candidate_total,
            "n_discarded_branches_cumulative": discarded_total,
        }
    )


def merge_with_other(self, other):
    """
    Append the measurements and branching events of ``other`` to this object,
    rebuild all derived DataFrames and return ``self``.

    Objects that predate branching-event tracking (e.g. loaded from an older
    pickle) have no ``branching_events`` attribute; they are treated as having
    recorded no events.
    """
    self.branch_values += other.branch_values
    if not hasattr(self, "branching_events"):
        self.branching_events = []
    self.branching_events += getattr(other, "branching_events", [])
    self.branch_values_to_dataframe()
    self.branching_events_to_dataframe()
    self.combine_measurements()
    return self
def _save_show_close(fig, outfolder, save, stem, description):
    """
    Write ``fig`` as PDF and PNG into ``outfolder`` (when ``save`` is set and a folder
    is given), show it, and always close it afterwards so repeated plotting does not
    accumulate open figures.
    """
    try:
        if save and outfolder is not None:
            from pathlib import Path

            outfolder = Path(outfolder)
            outfolder.mkdir(exist_ok=True, parents=True)
            for suffix in ("pdf", "png"):
                fig.savefig(outfolder / f"{NOW}_{stem}.{suffix}")
            print(f"Saved {description} plot to {outfolder}")
        plt.show()
    finally:
        plt.close(fig)


def plot_branch_counts_over_time(branch_values, name="", outfolder=None, save=True):
    """
    Plot the number of branches over time, including both sampled and unsampled.

    Parameters:
    -----------
    branch_values : BranchValues
        The BranchValues object containing branching event data
    name : str
        Name for the plot title and filename
    outfolder : Path or str
        Directory to save the plot (created if missing); nothing is written when None
    save : bool
        Whether to save the plot to file
    """
    df_cumulative = branch_values.get_cumulative_branch_counts()
    if df_cumulative is None or len(df_cumulative) == 0:
        print("No branching event data available to plot")
        return

    times = df_cumulative["time"]
    candidates = df_cumulative["n_candidate_branches_cumulative"]
    sampled = df_cumulative["n_sampled_branches_cumulative"]

    fig, ax = plt.subplots(figsize=(12, 6), dpi=150)

    # Cumulative candidate branches (everything that was ever proposed)
    ax.plot(
        times,
        candidates,
        label="Total candidate branches (including unsampled)",
        color="#d62728",
        linewidth=2,
        linestyle="--",
        alpha=0.8,
    )
    # Cumulative sampled branches
    ax.plot(times, sampled, label="Sampled branches (kept)", color="#2ca02c", linewidth=2.5)
    # Branches that actually have measurements at each time
    ax.plot(
        times,
        df_cumulative["n_sampled_branches_actual"],
        label="Active branches at time t",
        color="#1f77b4",
        linewidth=2,
        marker="o",
        markersize=3,
        alpha=0.7,
    )
    # The gap between the candidate and sampled curves is what sampling discarded
    ax.fill_between(
        times,
        sampled,
        candidates,
        alpha=0.3,
        color="#ff7f0e",
        label="Discarded branches (unsampled)",
    )

    ax.set_xlabel("Time", fontsize=12)
    ax.set_ylabel("Number of Branches", fontsize=12)
    ax.set_title(f"Branch Counts Over Time: {name}", fontsize=14)
    ax.legend(loc="best", fontsize=10)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    _save_show_close(fig, outfolder, save, f"{name}_branch_counts", "branch counts")


def plot_branching_events_detail(branch_values, name="", outfolder=None, save=True):
    """
    Plot detailed information about each branching event.

    Top panel: candidate vs. sampled branch counts per event. Bottom panel: total
    candidate probability vs. sampled probability per event. Events are placed at
    their position in ``df_branching_events`` and labelled with their time.

    Parameters:
    -----------
    branch_values : BranchValues
        The BranchValues object containing branching event data
    name : str
        Name for the plot title and filename
    outfolder : Path or str
        Directory to save the plot (created if missing); nothing is written when None
    save : bool
        Whether to save the plot to file
    """
    branch_values.branching_events_to_dataframe()
    df_events = branch_values.df_branching_events
    if df_events is None or len(df_events) == 0:
        print("No branching event data available")
        return

    fig, (ax_counts, ax_probs) = plt.subplots(2, 1, figsize=(12, 8), dpi=150, sharex=True)
    width = 0.35
    x = np.arange(len(df_events))

    # Top plot: number of branches at each event, side by side
    ax_counts.bar(
        x - width / 2,
        df_events["n_candidate_branches"],
        width,
        label="Candidate branches",
        color="#d62728",
        alpha=0.7,
    )
    ax_counts.bar(
        x + width / 2,
        df_events["n_sampled_branches"],
        width,
        label="Sampled branches",
        color="#2ca02c",
        alpha=0.7,
    )
    ax_counts.set_ylabel("Number of Branches", fontsize=11)
    ax_counts.set_title(f"Branching Events Detail: {name}", fontsize=13)
    ax_counts.legend()
    ax_counts.grid(True, alpha=0.3, axis="y")

    # Bottom plot: sampled probability drawn over the total candidate probability
    ax_probs.bar(
        x,
        df_events["total_candidate_prob"],
        width * 1.5,
        label="Total candidate prob",
        color="#ff7f0e",
        alpha=0.5,
    )
    ax_probs.bar(
        x, df_events["sampled_prob"], width * 1.5, label="Sampled prob", color="#1f77b4", alpha=0.7
    )
    ax_probs.set_xlabel("Branching Event", fontsize=11)
    ax_probs.set_ylabel("Probability", fontsize=11)
    ax_probs.legend()
    ax_probs.grid(True, alpha=0.3, axis="y")

    # Label each event with the time at which it happened
    ax_probs.set_xticks(x)
    ax_probs.set_xticklabels([f"{t:.2f}" for t in df_events["time"]], rotation=45)

    fig.tight_layout()

    _save_show_close(fig, outfolder, save, f"{name}_branching_events", "branching events")


def plot_sampling_efficiency(branch_values, name="", outfolder=None, save=True):
    """
    Plot the sampling efficiency over time.

    Shows, per branching event, the fraction of candidate branches that were kept
    and the fraction of candidate probability that was kept. Events with zero
    candidates (or zero candidate probability) have no defined ratio and are left
    out of the respective curve. The cached events DataFrame is not modified.

    Parameters:
    -----------
    branch_values : BranchValues
        The BranchValues object containing branching event data
    name : str
        Name for the plot title and filename
    outfolder : Path or str
        Directory to save the plot (created if missing); nothing is written when None
    save : bool
        Whether to save the plot to file
    """
    branch_values.branching_events_to_dataframe()
    df_events = branch_values.df_branching_events
    if df_events is None or len(df_events) == 0:
        print("No branching event data available")
        return

    # Mask zero denominators with NaN instead of producing inf, and keep the
    # ratios local rather than writing helper columns into the shared frame.
    n_candidates = df_events["n_candidate_branches"]
    candidate_prob = df_events["total_candidate_prob"]
    count_efficiency = df_events["n_sampled_branches"] / n_candidates.where(n_candidates != 0)
    prob_efficiency = df_events["sampled_prob"] / candidate_prob.where(candidate_prob != 0)

    fig, ax = plt.subplots(figsize=(12, 6), dpi=150)
    ax.plot(
        df_events["time"],
        count_efficiency,
        marker="o",
        label="Branch count efficiency",
        linewidth=2,
        markersize=8,
        color="#1f77b4",
    )
    ax.plot(
        df_events["time"],
        prob_efficiency,
        marker="s",
        label="Probability efficiency",
        linewidth=2,
        markersize=8,
        color="#2ca02c",
    )
    ax.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5, label="No sampling (100%)")

    ax.set_xlabel("Time", fontsize=12)
    ax.set_ylabel("Sampling Efficiency (sampled / candidates)", fontsize=12)
    ax.set_title(f"Sampling Efficiency Over Time: {name}", fontsize=14)
    ax.legend(loc="best", fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1.1)
    fig.tight_layout()

    _save_show_close(fig, outfolder, save, f"{name}_sampling_efficiency", "sampling efficiency")
formR, sL, sR): return einsum((sL) ** (1.0 - formL), theta, (sR) ** (1.0 - formR), "l, b p l r, r -> b p l r") @@ -264,8 +643,10 @@ class BranchingMPS: def __init__( self, tebd_engine: tenpy.TEBDEngine - | ExpMPOEvolution, # The TEBD engine to use for time evolution - cfg: BranchingMPSConfig, # The configuration for splitting the wavefunction into branches + | ExpMPOEvolution + | None = None, # The TEBD engine to use for time evolution (None for unsampled branches) + cfg: BranchingMPSConfig + | None = None, # The configuration for splitting the wavefunction into branches branch_values: BranchValues | None = None, # The structure for storing the measurements of all the branches over time branch_function: Callable @@ -281,10 +662,37 @@ def __init__( name="", wandb_project: str | None = None, info={}, + # New parameters for unsampled branches + sampled: bool = True, # Whether this branch was sampled (False for discarded branches) + created_time: float + | None = None, # Time when branch was created (evolved_time at creation) + branching_site: int | None = None, # Site where branching occurred (for child nodes) + prob_override: float + | None = None, # Probability override (for unsampled branches without tebd_engine) + norm_override: float + | None = None, # Norm override (for unsampled branches without tebd_engine) ): - self.tebd_engine = tebd_engine # The TEBD engine to use for time evolution - self.norm = self.tebd_engine.psi.norm - self.prob = abs(self.norm**2) + self.sampled = sampled # Track whether this branch was sampled or discarded + self.created_time = created_time # Time when branch was created + self.branching_site = branching_site # Site where branching occurred + + # For unsampled branches, tebd_engine may be None + if tebd_engine is None: + # This is an unsampled branch - use override values + self.tebd_engine = None + if prob_override is not None: + self.prob = prob_override + self.norm = np.sqrt(prob_override) if norm_override is None else 
norm_override + elif norm_override is not None: + self.norm = norm_override + self.prob = abs(norm_override**2) + else: + self.prob = 0.0 + self.norm = 0.0 + else: + self.tebd_engine = tebd_engine # The TEBD engine to use for time evolution + self.norm = self.tebd_engine.psi.norm + self.prob = abs(self.norm**2) self.cfg = cfg # The configuration for splitting the wavefunction into branches self.pickle_file = pickle_file self.outfolder = outfolder @@ -297,6 +705,14 @@ def __init__( else: self.branch_values = branch_values + # Track root creation walltime for relative time calculations + # This needs to be set before we use it, so handle it after parent is assigned + # (We'll set it at the end of __init__) + pass + + # Track final walltime (updated when branch finishes or at save time) + self.final_walltime: datetime | None = None + # a BrancingMPS from which we split (or None) self.parent = parent @@ -306,8 +722,14 @@ def __init__( else: self.children = children + # List of unsampled (discarded) branches from accepted decompositions + self.unsampled_children: list[BranchingMPS] = [] + if max_children is None: - self.max_children = int(self.cfg.max_branches) + if cfg is not None: + self.max_children = int(cfg.max_branches) + else: + self.max_children = 8 # Default else: self.max_children = max_children @@ -316,19 +738,30 @@ def __init__( print(f"Starting name = {self.name}, ID = {self.ID}") self.trunc_err = tenpy.linalg.truncation.TruncationError(eps=0.0, ov=1.0) - self.dt = tebd_engine.options["dt"] if "dt" in tebd_engine.options else 1 + if self.tebd_engine is not None: + self.dt = tebd_engine.options["dt"] if "dt" in tebd_engine.options else 1 + else: + self.dt = 1 # Default for unsampled branches if self.parent is None: self.costFun_LM_MR_trace_distance = 0.0 self.global_reconstruction_error_trace_distance = 0.0 - self.t_last_attempted_branching_sites = np.zeros(len(self.tebd_engine.psi.chi)) + if self.tebd_engine is not None: + 
self.t_last_attempted_branching_sites = np.zeros(len(self.tebd_engine.psi.chi)) + else: + # For unsampled root branch, create empty array (will be initialized later if needed) + self.t_last_attempted_branching_sites = np.array([]) self.site_last_attempted_branching: int | None = None + if self.tebd_engine is not None: + num_sites = len(self.tebd_engine.psi.chi) + else: + num_sites = 0 # For unsampled branches, will be set if needed self.last_attempted_branching_trunc_bond_dims_sites: list[ None | tuple[int, int, int] - ] = [None] * len(self.tebd_engine.psi.chi) + ] = [None] * num_sites self.last_attempted_branching_trunc_trace_distance_sites: list[None | float] = [ None - ] * len(self.tebd_engine.psi.chi) + ] * num_sites self.trace_distances = {} self.trace_distances["estimated_interference_error"] = 0.0 self.trace_distances["global_reconstruction_error_trace_distance"] = ( @@ -394,10 +827,27 @@ def __init__( self.depth = self.parent.depth + 1 self.branching_attempts = self.parent.branching_attempts - self.evolved_time = float(abs(self.tebd_engine.evolved_time)) + if self.tebd_engine is not None: + self.evolved_time = float(abs(self.tebd_engine.evolved_time)) + else: + # For unsampled branches, use created_time if available + self.evolved_time = self.created_time if self.created_time is not None else 0.0 + self.t_last_attempted_branching = self.evolved_time self.finished = False + # Set created_time if not provided (for sampled branches) + if self.created_time is None: + self.created_time = self.evolved_time + + # Set root_created_walltime (after parent is assigned) + if self.parent is None: + # This is the root branch - store its creation time as reference + self.root_created_walltime = self.created_walltime + else: + # Inherit root creation time from parent + self.root_created_walltime = self.parent.root_created_walltime + def branch_and_sample( self, formL=1.0, @@ -721,6 +1171,8 @@ def branch_and_sample( print(f"{self.ID}No further branch_indices were selected: 
len(branch_indices) = 0") print(f"{self.ID}TERMINATING.") self.finished = True + # Update final walltime when branch terminates + self.final_walltime = datetime.now() self.branch_values.add_measurements_tebd( self.tebd_engine, extra_measurements=self.trace_distances ) @@ -747,6 +1199,25 @@ def branch_and_sample( elif len(branch_indices) > 1: print(f"{self.ID}Creating {num_kept_branches} children nodes.") + # Record the branching event including unsampled branches + if hasattr(self, "branch_values") and self.branch_values is not None: + self.branch_values.add_branching_event( + time=self.evolved_time, + n_candidate_branches=num_candidates, + n_sampled_branches=len(survivor_indices), + candidate_probs=branch_probs, + sampled_indices=survivor_indices, + parent_prob=abs(self.norm**2), + site=coarsegrain_from, + trace_distances={ + "costFun_LM_MR_trace_distance": costFun_LM_MR_trace_distance, + "global_reconstruction_error_trace_distance": global_reconstruction_error_trace_distance, + }, + ) + print( + f"{self.ID}Branching summary: Total candidates={num_candidates}, Sampled={len(survivor_indices)}, Discarded={num_candidates - len(survivor_indices)}" + ) + if np.isclose(total_prob_survived, 0.0): print( f"{self.ID}ERROR: Total probability of surviving branches is zero, cannot renormalize weights. TERMINATING." 
def to_tree_dict(self):
    """
    Convert this node and its whole subtree to a plain, JSON-serialisable dictionary.

    Tensors and the TEBD engine are left out; only metadata, walltimes and the
    tree structure (sampled ``children`` and ``unsampled_children``) are kept.
    Numpy scalars are converted to built-in Python numbers so that a JSON encoder
    with a ``default=str`` fallback writes them as numbers rather than strings.
    Attributes that may be missing on objects created by older code fall back to
    neutral defaults.

    Returns:
    --------
    dict: Dictionary representation of the tree node; child nodes are nested
    recursively under "children" and "unsampled_children"
    """

    def builtin(value):
        # np.int64 / np.bool_ are not subclasses of the Python types and would
        # otherwise be stringified by a default=str JSON fallback.
        return value.item() if isinstance(value, np.generic) else value

    created = self.created_walltime
    # Without an explicit root reference, measure relative to this node's own creation.
    root_created = getattr(self, "root_created_walltime", created)
    final = getattr(self, "final_walltime", None)

    def seconds_since_root(moment):
        if moment is None or root_created is None:
            return None
        return (moment - root_created).total_seconds()

    unsampled = getattr(self, "unsampled_children", [])

    return {
        "ID": self.ID,
        "sampled": builtin(self.sampled),
        "created_time": builtin(self.created_time),
        "evolved_time": builtin(self.evolved_time),
        "created_walltime": created.isoformat() if created is not None else None,
        "created_walltime_rel_seconds": seconds_since_root(created),
        "final_walltime": final.isoformat() if final is not None else None,
        "final_walltime_rel_seconds": seconds_since_root(final),
        "prob": float(self.prob),
        "norm": float(self.norm),
        "depth": builtin(self.depth),
        "max_children": builtin(self.max_children),
        "finished": builtin(self.finished),
        "synchronized": builtin(getattr(self, "synchronized", False)),
        "branching_site": builtin(self.branching_site),
        "costFun_LM_MR_trace_distance": float(getattr(self, "costFun_LM_MR_trace_distance", 0.0)),
        "global_reconstruction_error_trace_distance": float(
            getattr(self, "global_reconstruction_error_trace_distance", 0.0)
        ),
        "has_tebd_engine": self.tebd_engine is not None,
        "n_children": len(self.children),
        "n_unsampled_children": len(unsampled),
        "children": [child.to_tree_dict() for child in self.children],
        "unsampled_children": [child.to_tree_dict() for child in unsampled],
    }


def _update_final_walltime_recursive(self, current_time: datetime):
    """
    Recursively stamp ``final_walltime`` on every branch of the subtree.

    A branch that is finished and already has a ``final_walltime`` keeps it; every
    other branch (including unsampled ones) is set to ``current_time``.
    """
    already_stamped = self.finished and getattr(self, "final_walltime", None) is not None
    if not already_stamped:
        self.final_walltime = current_time

    for child in self.children:
        child._update_final_walltime_recursive(current_time)
    for child in getattr(self, "unsampled_children", []):
        child._update_final_walltime_recursive(current_time)
str(self.pickle_file) + "tmp"]: + for tmp_file in [ + branchvals_file + "tmp", + str(self.pickle_file) + "tmp", + tree_file + "tmp", + ]: Path(tmp_file).unlink(missing_ok=True) t1 = time.time() print(f"{self.ID}Saved in {t1 - t0} seconds to {self.pickle_file}") @@ -1178,6 +1795,25 @@ def trace_distance_colors(key): plt.clf() plt.cla() + # NEW PLOTS: Branch counts including unsampled branches + try: + plot_branch_counts_over_time( + self.branch_values, name=self.name, outfolder=plots_dir, save=True + ) + + plot_branching_events_detail( + self.branch_values, name=self.name, outfolder=plots_dir, save=True + ) + + plot_sampling_efficiency( + self.branch_values, name=self.name, outfolder=plots_dir, save=True + ) + except Exception as e: + print(f"Error plotting branch counts: {e}") + import traceback + + traceback.print_exc() + # Log to wandb if self.wandb_project is not None: # Select just the most recent combined values @@ -1519,6 +2155,8 @@ def evolve_and_branch_leaf(self, stop_before_branching=False, t_evo=None, **kwar if self.evolved_time >= self.cfg.t_evo: print(f"{self.ID}Finished.") self.finished = True + # Update final walltime when branch finishes + self.final_walltime = datetime.now() # Measure self.branch_values.add_measurements_tebd( self.tebd_engine, extra_measurements=self.trace_distances