diff --git a/protein_holography/coordinates/get_holograms.py b/protein_holography/coordinates/get_holograms.py index 20fc599..c6947f1 100755 --- a/protein_holography/coordinates/get_holograms.py +++ b/protein_holography/coordinates/get_holograms.py @@ -11,10 +11,10 @@ from progress.bar import Bar import traceback -def c(np_nh,L_max,ks,num_combi_channels,r_max): +def c(np_nh,L_max,ks,num_combi_channels,r_max,element_channels): #try: - hgm = get_hologram(np_nh,L_max,ks,num_combi_channels,r_max) + hgm = get_hologram(np_nh,L_max,ks,num_combi_channels,r_max,element_channels) #except Exception as e: # print(e) # print('Error with',np_nh[0]) @@ -50,7 +50,8 @@ def c(np_nh,L_max,ks,num_combi_channels,r_max): ds = PDBPreprocessor(args.hdf5_in,args.neighborhood_list) bad_neighborhoods = [] n = 0 - channels = ['C','N','O','S','H','SASA','charge'] + element_channels = [b'C',b'N',b'O',b'S',b'H',b"P",b"F",b"Cl",] + channels = np.concatenate((element_channels, [b"Unk", b'SASA',b'charge'])) num_combi_channels = len(channels) * len(args.ks) dt = np.dtype([(str(l),'complex64',(num_combi_channels,2*l+1)) for l in range(args.Lmax + 1)]) @@ -77,7 +78,8 @@ def c(np_nh,L_max,ks,num_combi_channels,r_max): params = {'L_max': args.Lmax, 'ks':args.ks, 'num_combi_channels': num_combi_channels, - 'r_max': args.r_max}, + 'r_max': args.r_max, + "element_channels":element_channels}, parallelism = args.parallelism)): if hgm is None: bar.next() diff --git a/protein_holography/coordinates/get_holograms_zach.py b/protein_holography/coordinates/get_holograms_zach.py index ed3d7f1..1e3510a 100755 --- a/protein_holography/coordinates/get_holograms_zach.py +++ b/protein_holography/coordinates/get_holograms_zach.py @@ -10,10 +10,10 @@ from progress.bar import Bar import traceback -def c(np_nh,L_max,ks,num_combi_channels,r_max): +def c(np_nh,L_max,ks,num_combi_channels,r_max,element_channels): #try: - hgm = get_hologram(np_nh,L_max,ks,num_combi_channels,r_max) + hgm = get_hologram(np_nh,L_max,ks,num_combi_channels,r_max,element_channels) #except Exception as e: # print(e) @@ -50,7 +50,8 @@ def c(np_nh,L_max,ks,num_combi_channels,r_max): ds = PDBPreprocessor(args.hdf5_in,args.neighborhood_list) bad_neighborhoods = [] n = 0 - channels = ['C','N','O','S','H','SASA','charge'] + element_channels = [b'C',b'N',b'O',b'S',b'H',b"P",b"F",b"Cl",] + channels = np.concatenate((element_channels, [b"Unk", b'SASA',b'charge'])) num_combi_channels = len(channels) * len(args.ks) dt = np.dtype([(str(l),'complex64',(num_combi_channels,2*l+1)) for l in range(args.Lmax + 1)]) @@ -77,7 +78,8 @@ def c(np_nh,L_max,ks,num_combi_channels,r_max): params = {'L_max': args.Lmax, 'ks':args.ks, 'num_combi_channels': num_combi_channels, - 'r_max': args.r_max}, + 'r_max': args.r_max, + "element_channels":element_channels}, parallelism = args.parallelism)): if hgm is None: bar.next() diff --git a/protein_holography/coordinates/get_neighborhoods.py b/protein_holography/coordinates/get_neighborhoods.py index 77a68a4..1535c7a 100755 --- a/protein_holography/coordinates/get_neighborhoods.py +++ b/protein_holography/coordinates/get_neighborhoods.py @@ -65,7 +65,7 @@ def get_neighborhoods( dt = np.dtype([ ('res_id','S5',(6)), ('atom_names', 'S4', (max_atoms)), - ('elements', 'S1', (max_atoms)), + ('elements', 'S3', (max_atoms)), ('res_ids', 'S5', (max_atoms,6)), ('coords', 'f8', (max_atoms,3)), ('SASAs', 'f8', (max_atoms)), diff --git a/protein_holography/coordinates/get_structural_info.py b/protein_holography/coordinates/get_structural_info.py index a68821b..6099dfb 100755 --- a/protein_holography/coordinates/get_structural_info.py +++ b/protein_holography/coordinates/get_structural_info.py @@ -66,7 +66,7 @@ def c(pose,padded_length=200000): dt = np.dtype([ ('pdb','S4',()), ('atom_names', 'S4', (max_atoms)), - ('elements', 'S1', (max_atoms)), + ('elements', 'S3', (max_atoms)), ('res_ids', 'S5', (max_atoms,6)), ('coords', 'f8', (max_atoms,3)), ('SASAs', 'f8', (max_atoms)), diff --git a/protein_holography/coordinates/get_zernikegrams.py b/protein_holography/coordinates/get_zernikegrams.py index 0a673c0..cc4f059 100755 --- a/protein_holography/coordinates/get_zernikegrams.py +++ b/protein_holography/coordinates/get_zernikegrams.py @@ -10,10 +10,10 @@ from progress.bar import Bar import traceback -def c(np_nh,L_max,ks,num_combi_channels,r_max): +def c(np_nh,L_max,ks,num_combi_channels,r_max,element_channels): #try: - hgm = get_hologram(np_nh,L_max,ks,num_combi_channels,r_max) + hgm = get_hologram(np_nh,L_max,ks,num_combi_channels,r_max,element_channels) #except Exception as e: # print(e) @@ -45,7 +45,8 @@ def get_zernikegrams( ds = PDBPreprocessor(hdf5_in,neighborhood_list) bad_neighborhoods = [] n = 0 - channels = ['C','N','O','S','H','SASA','charge'] + element_channels = [b'C',b'N',b'O',b'S',b'H',b"P",b"F",b"Cl",] + channels = np.concatenate((element_channels, [b"Unk", b'SASA',b'charge'])) num_combi_channels = len(channels) * len(ks) dt = np.dtype([(str(l),'complex64',(num_combi_channels,2*l+1)) for l in range(Lmax + 1)]) @@ -77,7 +78,8 @@ def get_zernikegrams( params = {'L_max': Lmax, 'ks':ks, 'num_combi_channels': num_combi_channels, - 'r_max': r_max}, + 'r_max': r_max, + "element_channels":element_channels}, parallelism = parallelism)): if hgm is None or hgm[0] is None: bar.next() diff --git a/protein_holography/coordinates/pyrosetta_hdf5_holograms.py b/protein_holography/coordinates/pyrosetta_hdf5_holograms.py index d8a2d2b..01820cc 100755 --- a/protein_holography/coordinates/pyrosetta_hdf5_holograms.py +++ b/protein_holography/coordinates/pyrosetta_hdf5_holograms.py @@ -66,12 +66,13 @@ def zernike_coeff_lm_new(r, t, p, n, r_max, l, m, weights): return coeffs -def get_hologram(nh,L_max,ks,num_combi_channels,r_max): +def get_hologram(nh,L_max,ks,num_combi_channels,r_max, + element_channels=[b'C',b'N',b'O',b'S',b'H',b"P",b"F",b"Cl",]): dt = np.dtype([(str(l),'complex64',(num_combi_channels,2*l+1)) for l in range(L_max + 1)]) arr = np.zeros(shape=(1,),dtype=dt) # get info from nh - channels = ['C','N','O','S','H','SASA','charge'] + channels = np.concatenate((element_channels, [b"Unk", b'SASA',b'charge'])) num_channels = len(channels) atom_names = nh['atom_names'] real_locs = np.logical_and(atom_names != b'',nh['coords'][:,0] <= r_max) @@ -104,31 +105,16 @@ def get_hologram(nh,L_max,ks,num_combi_channels,r_max): nonzero_len = np.count_nonzero(nonzero_idxs) nmax = len(ks) - - for i_ch,ch in enumerate(channels): - - if ch == 'C': - r,t,p = *atom_coords[elements == b'C'].T, - weights=np.ones(shape=(r.shape[-1],)) - if ch == 'N': - r,t,p = *atom_coords[elements == b'N'].T, - weights=np.ones(shape=(r.shape[-1],)) - if ch == 'O': - r,t,p = *atom_coords[elements == b'O'].T, - weights=np.ones(shape=(r.shape[-1],)) - if ch == 'S': - r,t,p = *atom_coords[elements == b'S'].T, - weights=np.ones(shape=(r.shape[-1],)) - if ch == 'H': - r,t,p = *atom_coords[elements == b'H'].T, - weights=np.ones(shape=(r.shape[-1],)) - if ch == 'SASA': - weights = curr_SASA - r,t,p = np.einsum('ij->ji',atom_coords) - #print(weights) - if ch == 'charge': - r,t,p = np.einsum('ij->ji',atom_coords) - weights = curr_charge + arr_weights = np.empty(shape=(num_channels,r.shape[-1],)) + which_channel = np.array( elements[:,None] == element_channels, dtype=float) + r,t,p = np.einsum('ij->ji',atom_coords) + + arr_weights[:len(element_channels)] = which_channel + arr_weights[-3] = np.logical_not( np.any( which_channel,axis=1)) + arr_weights[-2] = curr_SASA + arr_weights[-1] = curr_charge + + out_z = np.zeros(shape=(num_channels,ns.shape[0]), dtype=np.complex64) out_z = np.zeros(shape=ns.shape[0], dtype=np.complex64) diff --git a/protein_holography/coordinates/pyrosetta_hdf5_holograms_zach.py b/protein_holography/coordinates/pyrosetta_hdf5_holograms_zach.py index d8a2d2b..77bee7c 100755 --- a/protein_holography/coordinates/pyrosetta_hdf5_holograms_zach.py +++ b/protein_holography/coordinates/pyrosetta_hdf5_holograms_zach.py @@ -66,12 +66,13 @@ def zernike_coeff_lm_new(r, t, p, n, r_max, l, m, weights): return coeffs -def get_hologram(nh,L_max,ks,num_combi_channels,r_max): +def get_hologram(nh,L_max,ks,num_combi_channels,r_max, + element_channels = [b'C',b'N',b'O',b'S',b'H',b"P",b"F",b"Cl",]): dt = np.dtype([(str(l),'complex64',(num_combi_channels,2*l+1)) for l in range(L_max + 1)]) arr = np.zeros(shape=(1,),dtype=dt) # get info from nh - channels = ['C','N','O','S','H','SASA','charge'] + channels = np.concatenate((element_channels, [b"Unk", b'SASA',b'charge'])) num_channels = len(channels) atom_names = nh['atom_names'] real_locs = np.logical_and(atom_names != b'',nh['coords'][:,0] <= r_max) @@ -105,28 +106,21 @@ def get_hologram(nh,L_max,ks,num_combi_channels,r_max): nmax = len(ks) + which_channel = elements[:,None] == element_channels) for i_ch,ch in enumerate(channels): - - if ch == 'C': - r,t,p = *atom_coords[elements == b'C'].T, - weights=np.ones(shape=(r.shape[-1],)) - if ch == 'N': - r,t,p = *atom_coords[elements == b'N'].T, - weights=np.ones(shape=(r.shape[-1],)) - if ch == 'O': - r,t,p = *atom_coords[elements == b'O'].T, - weights=np.ones(shape=(r.shape[-1],)) - if ch == 'S': - r,t,p = *atom_coords[elements == b'S'].T, + + if ch in element_channels: + r,t,p = *atom_coords[np.where(which_channel[:,i_ch])].T, weights=np.ones(shape=(r.shape[-1],)) - if ch == 'H': - r,t,p = *atom_coords[elements == b'H'].T, + + elif ch == b'Unk': + r,t,p = *atom_coords[np.where(np.logical_not( np.any( which_channel,axis=1)))].T, weights=np.ones(shape=(r.shape[-1],)) - if ch == 'SASA': + + elif ch == b'SASA': weights = curr_SASA r,t,p = np.einsum('ij->ji',atom_coords) - #print(weights) - if ch == 'charge': + elif ch == b'charge': r,t,p = np.einsum('ij->ji',atom_coords) weights = curr_charge diff --git a/protein_holography/coordinates/pyrosetta_hdf5_zernikegrams.py b/protein_holography/coordinates/pyrosetta_hdf5_zernikegrams.py index 0c51df9..d4e8f51 100755 --- a/protein_holography/coordinates/pyrosetta_hdf5_zernikegrams.py +++ b/protein_holography/coordinates/pyrosetta_hdf5_zernikegrams.py @@ -86,12 +86,13 @@ def zernike_coeff_lm_new(r: np.ndarray, t: np.ndarray, p: np.ndarray, n: np.ndar return coeffs -def get_hologram(nh, L_max: int, ks, num_combi_channels, r_max: np.float64): +def get_hologram(nh, L_max: int, ks, num_combi_channels, r_max: np.float64, + element_channels=[b'C',b'N',b'O',b'S',b'H',b"P",b"F",b"Cl",]): dt = np.dtype([(str(l),'complex64',(num_combi_channels,2*l+1)) for l in range(L_max + 1)]) arr = np.zeros(shape=(1,),dtype=dt) # get info from nh - channels = ['C','N','O','S','H','SASA','charge'] + channels = np.concatenate((element_channels, [b"Unk", b'SASA',b'charge'])) num_channels = len(channels) atom_names = nh['atom_names'] real_locs = np.logical_and(atom_names != b'',nh['coords'][:,0] <= r_max) @@ -123,38 +124,17 @@ def get_hologram(nh, L_max: int, ks, num_combi_channels, r_max: np.float64): nonzero_idxs = ~(l_greater_n | odds) nonzero_len = np.count_nonzero(nonzero_idxs) nmax = len(ks) + + arr_weights = np.empty(shape=(num_channels,r.shape[-1],)) + which_channel = np.array( elements[:,None] == element_channels, dtype=float) + r,t,p = np.einsum('ij->ji',atom_coords) + + arr_weights[:len(element_channels)] = which_channel + arr_weights[-3] = np.logical_not( np.any( which_channel,axis=1)) + arr_weights[-2] = curr_SASA + arr_weights[-1] = curr_charge - arr_weights = np.empty(shape=(7,r.shape[-1],)) - # for i_ch,ch in enumerate(channels): - - - # r,t,p = np.einsum('ij->ji',atom_coords) - # if ch == 'C': - # weights=np.array(elements == b'C',dtype=float) - # if ch == 'N': - # weights=np.array(elements == b'N',dtype=float) - # if ch == 'O': - # weights=np.array(elements == b'O',dtype=float) - # if ch == 'S': - # weights=np.array(elements == b'S',dtype=float) - # if ch == 'H': - # weights=np.array(elements == b'H',dtype=float) - # if ch == 'SASA': - # weights = curr_SASA - # if ch == 'charge': - # weights = curr_charge - - # arr_weights[i_ch] = weights - arr_weights[0] = np.array(elements == b'C', dtype=float) - arr_weights[1] = np.array(elements == b'N', dtype=float) - arr_weights[2] = np.array(elements == b'O', dtype=float) - arr_weights[3] = np.array(elements == b'S', dtype=float) - arr_weights[4] = np.array(elements == b'H', dtype=float) - arr_weights[5] = curr_SASA - arr_weights[6] = curr_charge - - ch_num = len(channels) - out_z = np.zeros(shape=(ch_num,ns.shape[0]), dtype=np.complex64) + out_z = np.zeros(shape=(num_channels,ns.shape[0]), dtype=np.complex64) rs = np.tile(r, (nonzero_len, 1)) ts = np.tile(t, (nonzero_len, 1)) @@ -168,7 +148,7 @@ def get_hologram(nh, L_max: int, ks, num_combi_channels, r_max: np.float64): for l in range(L_max + 1): num_m = (2 * l + 1) high_idx = (nmax) * num_m + low_idx - arr[0][l][:,:] = out_z[:,low_idx:high_idx].reshape(nmax*ch_num, num_m, ) + arr[0][l][:,:] = out_z[:,low_idx:high_idx].reshape(nmax*num_channels, num_m, ) low_idx = high_idx return arr[0] @@ -240,8 +220,8 @@ def get_sparse_hologram(nh,L_max,ks,num_combi_channels,r_max): arr_weights[4] = np.array(atom_names == b'CB',dtype=float) - ch_num = len(channels) - out_z = np.zeros(shape=(ch_num,ns.shape[0]), dtype=np.complex64) + num_channels = len(channels) + out_z = np.zeros(shape=(num_channels,ns.shape[0]), dtype=np.complex64) rs = np.tile(r, (nonzero_len, 1)) ts = np.tile(t, (nonzero_len, 1)) @@ -258,7 +238,7 @@ def get_sparse_hologram(nh,L_max,ks,num_combi_channels,r_max): for l in range(L_max + 1): num_m = (2 * l + 1) high_idx = (nmax) * num_m + low_idx - arr[0][l][:,:] = out_z[:,low_idx:high_idx].reshape(nmax*ch_num, num_m, ) + arr[0][l][:,:] = out_z[:,low_idx:high_idx].reshape(nmax*num_channels, num_m, ) low_idx = high_idx @@ -343,8 +323,8 @@ def get_backbone_hologram(nh,L_max,ks,num_combi_channels,r_max): ]),dtype=float) - ch_num = len(channels) - out_z = np.zeros(shape=(ch_num,ns.shape[0]), dtype=np.complex64) + num_channels = len(channels) + out_z = np.zeros(shape=(num_channels,ns.shape[0]), dtype=np.complex64) rs = np.tile(r, (nonzero_len, 1)) ts = np.tile(t, (nonzero_len, 1)) @@ -361,7 +341,7 @@ def get_backbone_hologram(nh,L_max,ks,num_combi_channels,r_max): for l in range(L_max + 1): num_m = (2 * l + 1) high_idx = (nmax) * num_m + low_idx - arr[0][l][:,:] = out_z[:,low_idx:high_idx].reshape(nmax*ch_num, num_m, ) + arr[0][l][:,:] = out_z[:,low_idx:high_idx].reshape(nmax*num_channels, num_m, ) low_idx = high_idx