Skip to content

Commit 983d571

Browse files
authored
Merge pull request #16 from nekitmm/fix/weights-extraction
Fix weights extraction and TF version
2 parents cb79d29 + ea7f3f8 commit 983d571

File tree

3 files changed

+82
-200
lines changed

3 files changed

+82
-200
lines changed

dlpacker/dlpacker.py

+45-97
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,42 @@
3636

3737
import os
3838

39+
import numpy as np
40+
from Bio.PDB import (
41+
PDBParser,
42+
Selection,
43+
Superimposer,
44+
PDBIO,
45+
Atom,
46+
Residue,
47+
Structure,
48+
)
49+
50+
from dlpacker.utils import (
51+
DLPModel,
52+
InputBoxReader,
53+
THE20,
54+
SCH_ATOMS,
55+
BB_ATOMS,
56+
SIDE_CHAINS,
57+
BOX_SIZE,
58+
)
59+
3960
dir_path = os.path.dirname(os.path.realpath(__file__))
4061
DEFAULT_REFERENCE_PDB = os.path.join(dir_path, 'data', 'reference.pdb')
4162
DEFAULT_LIBRARY_NPZ = os.path.join(dir_path, 'data', 'library.npz')
4263
DEFAULT_CHARGES_RTP = os.path.join(dir_path, 'data', 'charges.rtp')
4364

44-
CUSTOMIZED_WEIGHTS_DIR=os.getenv('DLPACKER_PRETRAINED_WEIGHT')
65+
66+
CUSTOMIZED_WEIGHTS_DIR = os.getenv('DLPACKER_PRETRAINED_WEIGHT')
4567
if CUSTOMIZED_WEIGHTS_DIR:
4668
if not os.path.exists(CUSTOMIZED_WEIGHTS_DIR):
4769
os.makedirs(CUSTOMIZED_WEIGHTS_DIR)
4870
DEFAULT_WEIGHTS = os.path.join(CUSTOMIZED_WEIGHTS_DIR, 'DLPacker_weights')
4971
else:
5072
DEFAULT_WEIGHTS = os.path.join(dir_path, 'data', 'DLPacker_weights')
5173

74+
5275
if not os.path.exists(f'{DEFAULT_WEIGHTS}.h5'):
5376
from dlpacker.utils import unzip_weights
5477

@@ -58,27 +81,6 @@
5881
output_dir=os.path.dirname(DEFAULT_WEIGHTS),
5982
)
6083

61-
import numpy as np
62-
from Bio.PDB import (
63-
PDBParser,
64-
Selection,
65-
Superimposer,
66-
PDBIO,
67-
Atom,
68-
Residue,
69-
Structure,
70-
)
71-
from dlpacker.utils import (
72-
DLPModel,
73-
InputBoxReader,
74-
DataGenerator,
75-
THE20,
76-
SCH_ATOMS,
77-
BB_ATOMS,
78-
SIDE_CHAINS,
79-
BOX_SIZE,
80-
)
81-
8284

8385
class DLPacker:
8486
# This is the meat of our code.
@@ -126,9 +128,7 @@ def __init__(
126128

127129
self.input_reader = input_reader
128130
if not self.input_reader:
129-
self.input_reader = InputBoxReader(
130-
charges_filename=charges_filename
131-
)
131+
self.input_reader = InputBoxReader(charges_filename=charges_filename)
132132

133133
def _load_library(self):
134134
# Loads library of rotamers.
@@ -137,9 +137,7 @@ def _load_library(self):
137137
self.library = np.load(self.lib_name, allow_pickle=True)
138138
self.library = self.library['arr_0'].item()
139139
for k in self.library['grids']:
140-
self.library['grids'][k] = self.library['grids'][k].astype(
141-
np.float32
142-
)
140+
self.library['grids'][k] = self.library['grids'][k].astype(np.float32)
143141

144142
def _read_structures(self):
145143
# Reads in main PDB structure and reference structure.
@@ -218,13 +216,9 @@ def _remove_altloc(self, structure: Structure):
218216
disordered_list.append(atom)
219217
# sometimes one of the altlocs just does not exist!
220218
try:
221-
selected_list.append(
222-
atom.disordered_get(self.altloc[0])
223-
)
219+
selected_list.append(atom.disordered_get(self.altloc[0]))
224220
except:
225-
selected_list.append(
226-
atom.disordered_get(self.altloc[1])
227-
)
221+
selected_list.append(atom.disordered_get(self.altloc[1]))
228222
selected_list[-1].set_altloc(' ')
229223
selected_list[-1].disordered_flag = 0
230224

@@ -249,11 +243,7 @@ def _align_residue(self, residue: Residue):
249243
# In order to generate input box properly
250244
# we first need to align selected residue
251245
# to reference atoms from reference.pdb
252-
if (
253-
not residue.has_id('N')
254-
or not residue.has_id('C')
255-
or not residue.has_id('CA')
256-
):
246+
if not residue.has_id('N') or not residue.has_id('C') or not residue.has_id('CA'):
257247
print(
258248
'Missing backbone atoms: residue',
259249
self._get_residue_tuple(residue),
@@ -265,9 +255,7 @@ def _align_residue(self, residue: Residue):
265255
self.sup.apply(self._get_parent_structure(residue))
266256
return True
267257

268-
def _align_structures(
269-
self, structure_a: Structure, structure_b: Structure
270-
):
258+
def _align_structures(self, structure_a: Structure, structure_b: Structure):
271259
# Aligns two structures using backbone atoms
272260
bb_a, bb_b = [], []
273261
residues_a = Selection.unfold_entities(structure_a, 'R')
@@ -297,20 +285,11 @@ def _get_box_atoms(self, residue: Residue):
297285
b = self.box_size + 1 # one angstrom offset to include more atoms
298286
for a in self._get_parent_structure(residue).get_atoms():
299287
xyz = a.coord
300-
if (
301-
xyz[0] < b
302-
and xyz[0] > -b
303-
and xyz[1] < b
304-
and xyz[1] > -b
305-
and xyz[2] < b
306-
and xyz[2] > -b
307-
):
288+
if xyz[0] < b and xyz[0] > -b and xyz[1] < b and xyz[1] > -b and xyz[2] < b and xyz[2] > -b:
308289
atoms.append(a)
309290
return atoms
310291

311-
def _genetare_input_box(
312-
self, residue: Residue, allow_missing_atoms: bool = False
313-
):
292+
def _genetare_input_box(self, residue: Residue, allow_missing_atoms: bool = False):
314293
# Takes a residue and generates a special
315294
# dictionary that is then given to InputReader,
316295
# which uses this dictionary to generate the actual input
@@ -400,42 +379,29 @@ def _get_sorted_residues(
400379
for residue in Selection.unfold_entities(structure, 'R'):
401380
if not targets or self._get_residue_tuple(residue) in targets:
402381
if residue.get_resname() in THE20:
403-
if (
404-
residue.has_id('CA')
405-
and residue.has_id('C')
406-
and residue.has_id('N')
407-
):
382+
if residue.has_id('CA') and residue.has_id('C') and residue.has_id('N'):
408383
atoms = self._get_box_atoms(residue)
409384
tuples.append((residue, len(atoms)))
410385
tuples.sort(key=lambda x: -x[1])
411386

412387
elif method == 'score':
413388
tuples = []
414-
for i, residue in enumerate(
415-
Selection.unfold_entities(structure, 'R')
416-
):
389+
for i, residue in enumerate(Selection.unfold_entities(structure, 'R')):
417390
if not targets or self._get_residue_tuple(residue) in targets:
418-
if (
419-
residue.get_resname() in THE20
420-
and residue.get_resname() != 'GLY'
421-
):
391+
if residue.get_resname() in THE20 and residue.get_resname() != 'GLY':
422392
name = self._get_residue_tuple(residue)
423393
print("Scoring residue:", i, name, end='\r')
424394

425395
r, s, n = self._get_residue_tuple(residue)
426396
box = self._genetare_input_box(residue, True)
427397

428398
if not box:
429-
print(
430-
"\nSkipping residue:", i, residue.get_resname()
431-
)
399+
print("\nSkipping residue:", i, residue.get_resname())
432400
continue
433401

434402
pred = self._get_prediction(box, n)
435403
scores = np.abs(self.library['grids'][n] - pred)
436-
scores = np.mean(
437-
scores, axis=tuple(range(1, pred.ndim + 1))
438-
)
404+
scores = np.mean(scores, axis=tuple(range(1, pred.ndim + 1)))
439405
best_ind = np.argmin(scores)
440406
best_score = np.min(scores)
441407
tuples.append((residue, best_score / SCH_ATOMS[n]))
@@ -512,9 +478,7 @@ def mutate_sequence(self, target: tuple, new_label: str):
512478
# and mutates it in the sequence to the new one given by the new_label argument
513479
# IMPORTANT: this function just renames a residue without
514480
# doing anything else at all
515-
assert (
516-
new_label in THE20
517-
), 'Only mutations to canonical 20 amino acids are supported!'
481+
assert new_label in THE20, 'Only mutations to canonical 20 amino acids are supported!'
518482
for residue in Selection.unfold_entities(self.structure, 'R'):
519483
if target == self._get_residue_tuple(residue):
520484
residue.resname = new_label
@@ -558,9 +522,7 @@ def reconstruct_residue(self, residue: Residue, refine_only: bool = False):
558522
residue[name].coord = best_match[i]
559523
else:
560524
# most values are dummy here
561-
new_atom = Atom.Atom(
562-
name, best_match[i], 0, 1, ' ', name, 2, element=name[:1]
563-
)
525+
new_atom = Atom.Atom(name, best_match[i], 0, 1, ' ', name, 2, element=name[:1])
564526
residue.add(new_atom)
565527

566528
def reconstruct_protein(
@@ -587,21 +549,14 @@ def reconstruct_protein(
587549
if not self.reconstructed:
588550
self.reconstructed = self.structure.copy()
589551
else:
590-
print(
591-
'Reconstructed structure already exists, something might be wrong!'
592-
)
552+
print('Reconstructed structure already exists, something might be wrong!')
593553
if not refine_only:
594554
self._remove_sidechains(self.reconstructed)
595555

596556
# run reconstruction for all residues in selected order
597-
sorted_residues = self._get_sorted_residues(
598-
self.reconstructed, method=order
599-
)
557+
sorted_residues = self._get_sorted_residues(self.reconstructed, method=order)
600558
for i, residue in enumerate(sorted_residues):
601-
if (
602-
residue.get_resname() in THE20
603-
and residue.get_resname() != 'GLY'
604-
):
559+
if residue.get_resname() in THE20 and residue.get_resname() != 'GLY':
605560
name = self._get_residue_tuple(residue)
606561
print("Working on residue:", i, name, end='\r')
607562
self.reconstruct_residue(residue, refine_only)
@@ -636,9 +591,7 @@ def reconstruct_region(
636591
if not self.reconstructed:
637592
self.reconstructed = self.structure.copy()
638593
else:
639-
print(
640-
'Reconstructed structure already exists, something might be wrong!'
641-
)
594+
print('Reconstructed structure already exists, something might be wrong!')
642595

643596
# remove side chains for target amino acids is refine_only is False
644597
if not refine_only:
@@ -647,15 +600,10 @@ def reconstruct_region(
647600
self._remove_sidechain(residue)
648601

649602
# run reconstruction for specified list of residues
650-
sorted_residues = self._get_sorted_residues(
651-
self.reconstructed, targets, method=order
652-
)
603+
sorted_residues = self._get_sorted_residues(self.reconstructed, targets, method=order)
653604
for i, residue in enumerate(sorted_residues):
654605
if self._get_residue_tuple(residue) in targets:
655-
if (
656-
residue.get_resname() in THE20
657-
and residue.get_resname() != 'GLY'
658-
):
606+
if residue.get_resname() in THE20 and residue.get_resname() != 'GLY':
659607
name = self._get_residue_tuple(residue)
660608
print("Working on residue:", i, name, end='\r')
661609
self.reconstruct_residue(residue, refine_only)

0 commit comments

Comments
 (0)