Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forced swapping methods #71

Open
wants to merge 38 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
90e4ec0
Add forced swap mode
ajfriedman22 Dec 10, 2024
e86ce0e
Update proposal options
ajfriedman22 Dec 11, 2024
c186287
fix error
ajfriedman22 Dec 11, 2024
1dfc70d
fix errors
ajfriedman22 Dec 11, 2024
0dc66ff
fix
ajfriedman22 Dec 11, 2024
b9de8d8
Add forced_random
ajfriedman22 Jan 6, 2025
163d4f4
Fix error in forced_random
ajfriedman22 Jan 6, 2025
81f8482
Fix forced_random l state assignment
ajfriedman22 Jan 9, 2025
db5f1e0
Fix linting
ajfriedman22 Jan 9, 2025
f305e03
Fix tests
ajfriedman22 Jan 14, 2025
cb8a343
Fix linting
ajfriedman22 Jan 14, 2025
ab9f394
Add MT-REXEE analysis functions
ajfriedman22 Jan 17, 2025
6f84779
Fix linting
ajfriedman22 Jan 17, 2025
1e9655c
Save transition matrix to array and fix file paths
ajfriedman22 Jan 17, 2025
ff11294
Add forced-swap test
ajfriedman22 Jan 19, 2025
92e2b89
Merge branch 'forced-swap' of github.com:wehs7661/ensemble_md into fo…
ajfriedman22 Jan 21, 2025
618346f
update analysis
ajfriedman22 Jan 31, 2025
67cc828
Merge branch 'forced-swap' of github.com:wehs7661/ensemble_md into fo…
ajfriedman22 Jan 31, 2025
a7cb9a3
update cli commands
ajfriedman22 Mar 6, 2025
fca1f0a
update utils
ajfriedman22 Mar 6, 2025
1a4b736
Merge branch 'master' into forced-swap
ajfriedman22 Mar 6, 2025
01f1a6d
Remove forced_swap option and rename forced_random to random_range
ajfriedman22 Mar 6, 2025
3739064
Fix linting
ajfriedman22 Mar 6, 2025
a9fb24c
Fix test errors
ajfriedman22 Mar 6, 2025
bc33e22
Remove redundancy
ajfriedman22 Mar 6, 2025
ff71715
Add checkpoint file for tracking the frame selected for the swap
ajfriedman22 Mar 6, 2025
40ce7ff
Add file closure
ajfriedman22 Mar 6, 2025
3a5db22
Merge branch 'forced-swap' of github.com:wehs7661/ensemble_md into fo…
ajfriedman22 Mar 6, 2025
471f884
Fix checkpoint error
ajfriedman22 Mar 7, 2025
0a68daf
Fix MTREXEE FE issue
ajfriedman22 Mar 7, 2025
8903a7d
Fix issue with swaps on checkpoints
ajfriedman22 Mar 7, 2025
f4f833e
Fix FE function for MTREXEE
ajfriedman22 Mar 7, 2025
703f1dd
Fix FE estimates for MTREXEE
ajfriedman22 Mar 7, 2025
18eff2f
Fix linting
ajfriedman22 Mar 7, 2025
ee284de
Fix error
ajfriedman22 Mar 7, 2025
228d211
fix merge
ajfriedman22 Mar 7, 2025
86f3413
update FE for mtrexee
ajfriedman22 Mar 14, 2025
760a0cb
update tests
ajfriedman22 Mar 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 51 additions & 7 deletions ensemble_md/analysis/analyze_free_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
A list of lists of free energy differences between adjacent states for all replicas.
state_ranges : list
A list of lists showing the state indices sampled by each replica.
n_tot : int
Number of lambda states
df_err_adjacent : list, Optional
A list of lists of uncertainties corresponding to the values of :code:`df_adjacent`. Notably, if
:code:`df_err_adjacent` is :code:`None`, simple means will be used. Otherwise, inverse-variance weighted
Expand Down Expand Up @@ -247,7 +249,37 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
return df, df_err, overlap_bool


def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="propagate", n_bootstrap=None, seed=None):
def _calculate_df(estimators):
    """
    An internal function used in :func:`calculate_free_energy` to read the free energy difference between
    the two end states from the first entry of :code:`estimators`.

    Parameters
    ----------
    estimators : list
        A list whose first element is a square, DataFrame-like table (it must expose :code:`index`,
        :code:`columns` and :code:`loc`). Only :code:`estimators[0]` is used. Presumably this is a free
        energy matrix derived from :func:`_apply_estimators` -- TODO confirm against the caller.

    Returns
    -------
    est : float
        The entry of :code:`estimators[0]` in the first row (label 0) and last column (label 1) after
        relabeling.

    Notes
    -----
    The first element of :code:`estimators` is modified in place: its row and column labels are replaced
    by evenly spaced values between 0 and 1.

    NOTE(review): a single value is returned, while :func:`calculate_free_energy` unpacks two values
    (:code:`df, df_err`) from this function -- confirm the intended return shape.

    See also
    --------
    :func:`calculate_free_energy`
    """
    table = estimators[0]

    # Relabel both axes with evenly spaced values so that the end states carry the labels 0 and 1.
    lambdas = np.linspace(0, 1, num=len(table.index))
    table.index = lambdas
    table.columns = lambdas

    return table.loc[0, 1]


def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="propagate", n_bootstrap=None, seed=None, MTREXEE=False): # noqa: E501
"""
Calculates the averaged free energy profile with the chosen method given :math:`u_{nk}` or :math:`dH/dλ` data
obtained from all replicas of the REXEE simulation. Available methods include TI, BAR, and MBAR. TI
Expand Down Expand Up @@ -275,6 +307,8 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
seed : int, Optional
The random seed for bootstrapping. Only relevant when :code:`err_method` is :code:`"bootstrap"`.
The default is :code:`None`.
MTREXEE : bool
Whether this is a MT-REXEE simulation or not

Returns
-------
Expand All @@ -299,10 +333,17 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
>>> f, _, _ = analyze_free_energy.calculate_free_energy(data_list, state_ranges, "MBAR", "propagate")
"""
n_sim = len(data)
n_tot = state_ranges[-1][-1] + 1
if MTREXEE is False:
n_tot = state_ranges[-1][-1] + 1
else:
n_tot = state_ranges[-1] + 1
estimators = _apply_estimators(data, df_method)
df_adjacent, df_err_adjacent = _calculate_df_adjacent(estimators)
df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate')
print(estimators)
if MTREXEE is False:
df_adjacent, df_err_adjacent = _calculate_df_adjacent(estimators)
df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # noqa: E501
else:
df, df_err = _calculate_df(estimators)

if err_method == 'bootstrap':
if seed is not None:
Expand All @@ -314,15 +355,18 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
for b in range(n_bootstrap):
sampled_data = [sampled_data_all[i].iloc[b * len(data[i]):(b + 1) * len(data[i])] for i in range(n_sim)]
bootstrap_estimators = _apply_estimators(sampled_data, df_method)
df_adjacent, df_err_adjacent = _calculate_df_adjacent(bootstrap_estimators)
df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501
if MTREXEE is False:
df_adjacent, df_err_adjacent = _calculate_df_adjacent(bootstrap_estimators)
df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501
else:
df_sampled, _ = _calculate_df(bootstrap_estimators)
df_bootstrap.append(df_sampled)
error_bootstrap = np.std(df_bootstrap, axis=0, ddof=1)

# Replace the value in df_err with value in error_bootstrap if df_err corresponds to
# the df between overlapping states
for i in range(n_tot - 1):
if overlap_bool[i] is True:
if MTREXEE is True or overlap_bool[i] is True:
print(f'Replaced the propagated error with the bootstrapped error for states {i} and {i + 1}: {df_err[i]:.5f} -> {error_bootstrap[i]:.5f}.') # noqa: E501
df_err[i] = error_bootstrap[i]
elif err_method == 'propagate':
Expand Down
163 changes: 162 additions & 1 deletion ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,15 @@ def plot_state_hist(trajs, state_ranges, fig_name, stack=True, figsize=None, pre
hist, bins = np.histogram(traj, bins=np.arange(lower_bound, upper_bound + 1, 1))
hist_data.append(hist)
if save_hist is True:
np.save('hist_data.npy', hist_data)
if len(fig_name.split('/')) > 1:
dir_list = []
for i in fig_name.split('/')[:-1]:
dir_list.append(i)
dir_list.append('/')
dir_path = ''.join(dir_list)
np.save(f'{dir_path}/hist_data.npy', hist_data)
else:
np.save('hist_data.npy', hist_data)

# Use the same bins for all histograms
bins = bins[:-1] # Remove the last bin edge because there are n+1 bin edges for n bins
Expand Down Expand Up @@ -685,6 +693,8 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
units : str
The units of the time.
"""
import pandas as pd

if dt is None:
x = np.arange(len(trajs[0]))
units = 'step'
Expand Down Expand Up @@ -824,6 +834,14 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
plt.savefig(f'{folder}/hist_{fig_names[t]}', dpi=600)
else:
plt.savefig(f'{folder}/{fig_prefix}_hist_{fig_names[t]}', dpi=600)
# Save to csv
sim_list, rt_list = [], []
for n in range(len(t_roundtrip_list)):
for rt in t_roundtrip_list[n]:
sim_list.append(n)
rt_list.append(rt)
df_rt = pd.DataFrame({'Sim': sim_list, 'Round Trip Time': rt_list})
df_rt.to_csv(f'{folder}/roundtrip_times.csv')

return t_0k_list, t_k0_list, t_roundtrip_list, units

Expand Down Expand Up @@ -1330,3 +1348,146 @@ def get_delta_w_updates(log_file, plot=False):
plt.savefig('delta_w_updates.png', dpi=600)

return t_updates, delta_w_updates, equil


def end_states_only_traj(working_dir, n_sim, n_iter, l0_states, l1_states, swap_rep_pattern, ps_per_frame):
    """
    Create, for each unique end state, trajectories that concatenate the frames in which a replica
    sampled that end state.

    For every (simulation, end) pair assigned to an end state, the frames found in the dhdl files are
    extracted from the per-iteration :code:`traj.trr` files and written to
    :code:`{working_dir}/analysis/{state}_{sim}.xtc`. Pairs for which no frame was found are skipped.

    Parameters
    ----------
    working_dir : str
        Path for the current working directory. It must contain :code:`sim_{n}/iteration_{i}` folders
        with :code:`dhdl.xvg` and :code:`traj.trr` files, and an :code:`analysis` folder for the output.
    n_sim : int
        The number of simulations run.
    n_iter : int
        The number of iterations run.
    l0_states : list of int
        The lambda states which correspond to lambda=0.
    l1_states : list of int
        The lambda states which correspond to lambda=1.
    swap_rep_pattern : list
        The replica swapping pattern which indicates which end states are common. Each element is a
        pair of :code:`[sim, end]` entries (with :code:`end` being 0 or 1) that represent the same
        end state.
    ps_per_frame : float
        The time (in the units of the GROMACS dhdl file) between trajectory frames, used to convert
        the time in the dhdl file to frame indices in the trajectory.

    Returns
    -------
    None
    """
    import os

    import mdtraj as md
    import pandas as pd

    # Determine how many end states are present, which simulations and lambdas those end states correspond to
    state_name, considered_swaps = _label_end_states(n_sim, swap_rep_pattern)

    # Determine which frames correspond to which end states. Rows are collected first and the
    # DataFrame is built once, which avoids repeatedly copying it with pd.concat.
    records = []
    for n in range(n_sim):
        for i in range(n_iter):
            dhdl_path = f'{working_dir}/sim_{n}/iteration_{i}/dhdl.xvg'
            l0_frame, l1_frame = _end_state_frames_from_dhdl(dhdl_path, l0_states, l1_states, ps_per_frame)
            records.extend((n, i, frame, 0) for frame in l0_frame)
            records.extend((n, i, frame, 1) for frame in l1_frame)
    state_frame_df = pd.DataFrame(records, columns=['Sim', 'Iteration', 'Frame', 'Lambda'])

    # Concatenate all frames from each set of trajectories for each end state
    for state in sorted(set(state_name)):
        for index in (k for k, value in enumerate(state_name) if value == state):
            rep, lambda_rep = considered_swaps[index]
            if os.path.exists(f'{working_dir}/sim_{rep}/iteration_0/confout_backup.gro'):
                name = 'confout_backup'
            else:
                name = 'confout'
            top = f'{working_dir}/sim_{rep}/iteration_0/{name}.gro'

            rep_frames = state_frame_df[(state_frame_df['Sim'] == rep) & (state_frame_df['Lambda'] == lambda_rep)]
            pieces = []
            for iteration in range(n_iter):
                frames_select = rep_frames[rep_frames['Iteration'] == iteration]['Frame'].to_numpy()
                if len(frames_select) == 0:
                    continue
                traj_iter = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=top)
                # Ignore frame indices that are not present in the trajectory file rather than failing.
                frames_select = frames_select[frames_select < traj_iter.n_frames]
                if len(frames_select) != 0:
                    # Keep only the frames in which the end state was sampled.
                    pieces.append(traj_iter[frames_select])

            if not pieces:
                # Nothing was sampled for this end state in this simulation, so there is nothing to save.
                continue
            # md.join expects an iterable of trajectories as its first argument.
            traj = pieces[0] if len(pieces) == 1 else md.join(pieces)
            traj.save_xtc(f'{working_dir}/analysis/{state}_{rep}.xtc')


def _label_end_states(n_sim, swap_rep_pattern):
    """
    Assign a letter to every [sim, end] pair such that pairs connected by a swap share the same letter.

    Returns the list of letters and the list of [sim, end] pairs, which are index-aligned.
    """
    state_name = ['A']
    considered_swaps = [[0, 0]]
    cat = ord('A') + 1
    for swap in swap_rep_pattern:
        # Normalize to lists so that membership tests work for tuples as well.
        part_1, part_2 = (list(part) for part in swap)
        known_1 = part_1 in considered_swaps
        known_2 = part_2 in considered_swaps
        if known_1 and known_2:
            continue
        if known_1:
            state_name.append(state_name[considered_swaps.index(part_1)])
            considered_swaps.append(part_2)
        elif known_2:
            state_name.append(state_name[considered_swaps.index(part_2)])
            considered_swaps.append(part_1)
        else:
            label = chr(cat)
            state_name.extend([label, label])
            considered_swaps.extend([part_1, part_2])
            cat += 1

    # End states that do not take part in any swap each get their own letter.
    for i in range(n_sim):
        for j in (0, 1):
            if [i, j] not in considered_swaps:
                state_name.append(chr(cat))
                considered_swaps.append([i, j])
                cat += 1

    return state_name, considered_swaps


def _end_state_frames_from_dhdl(dhdl_path, l0_states, l1_states, ps_per_frame):
    """
    Parse one dhdl file and return the frame indices at which the lambda=0 and lambda=1 states were sampled.

    Frame indices are counted from the first data line of the file.
    """
    tol = 1e-6  # tolerance for deciding whether a time coincides with a trajectory frame
    l0_frame, l1_frame = [], []
    start_time = None
    with open(dhdl_path, 'r') as f:
        for line in f:
            fields = line.split()
            # Skip blank lines and the xvg header/comment lines.
            if not fields or fields[0].startswith(('#', '@')):
                continue
            time = float(fields[0])
            if start_time is None:
                start_time = time
            state = float(fields[1])

            # Only times that are multiples of ps_per_frame have a frame in the trajectory. Compare with
            # a tolerance since an exact float modulo fails for values such as 0.6 % 0.2.
            ratio = time / ps_per_frame
            if abs(ratio - round(ratio)) > tol:
                continue
            # round() rather than int() so that e.g. 2.9999999999999996 maps to frame 3, not 2.
            frame = round((time - start_time) / ps_per_frame)
            if state in l0_states:
                l0_frame.append(frame)
            elif state in l1_states:
                l1_frame.append(frame)

    return l0_frame, l1_frame


def concat_sim_traj(working_dir, n_sim, n_iter):
    """
    Create, for each simulation, a single trajectory that concatenates the trajectories of all iterations.

    The result for simulation :code:`rep` is written to :code:`{working_dir}/analysis/sim{rep}_concat.xtc`.

    Parameters
    ----------
    working_dir : str
        Path for the current working directory.
    n_sim : int
        The number of simulations run.
    n_iter : int
        The number of iterations run.

    Returns
    -------
    None
    """
    import mdtraj as md
    import os

    for rep in range(n_sim):
        # Use the backup structure of the first iteration as the topology when it exists.
        backup_exists = os.path.exists(f'{working_dir}/sim_{rep}/iteration_0/confout_backup.gro')
        name = 'confout_backup' if backup_exists else 'confout'
        topology = f'{working_dir}/sim_{rep}/iteration_0/{name}.gro'

        # Start from the first iteration and append the remaining iterations in order.
        combined = md.load(f'{working_dir}/sim_{rep}/iteration_0/traj.trr', top=topology)
        for iteration in range(1, n_iter):
            next_part = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=topology)
            combined = md.join([combined, next_part])

        combined.save_xtc(f'{working_dir}/analysis/sim{rep}_concat.xtc')
Loading
Loading