
Commit ef2dacc (initial commit, 0 parents)
Update at 20250115_23h41m06s
32 files changed, +101318 -0 lines

.gitignore

+18 lines

.*
!.gitignore
log*
*log
__pycache__/
*.pyc
lightning_logs/
MNIST/
weights/
/backup
*.*seq
*.yaml
*results
examples
*txt
*png
*jpg
*.fasta

BPfold/__init__.py

Whitespace-only changes.

BPfold/dataset/RNAseq.py

+298 lines

import os
import yaml
import random
from itertools import product

import numpy as np
import torch
import torch.utils.data as data
from sklearn.model_selection import KFold

from ..util.misc import get_file_name, str_localtime
from ..util.RNA_kit import read_SS, connects2arr, mut_seq
from ..util.postprocess import get_base_index
from ..util.base_pair_motif import BPM_energy
from ..util.base_pair_probability import read_BPPM, gen_BPPM


CANONICAL_PAIRS = {'AU', 'UA', 'GC', 'CG', 'GU', 'UG'}


class RNAseq_data(data.Dataset):
    def __init__(self,
                 data_dir,
                 index_name='data_index.yaml',
                 phase='train',
                 Lmax=600,
                 Lmin=0,
                 fold=0,
                 nfolds=4,
                 seed=42,
                 cache_dir='.cache_data',
                 mask_only=False,
                 method='CDPfold',
                 trainall=False,
                 predict_files=None,
                 training_set=None,
                 test_set=None,
                 use_BPE=True,
                 use_BPP=True,
                 normalize_energy=False,
                 verbose=False,
                 para_dir='paras/',
                 BPM_type='all',
                 *args,
                 **kargs,
                 ):
        # Set all input args as attributes
        self.__dict__.update(locals())
        self.use_BPE = use_BPE
        self.use_BPP = use_BPP
        self.BPM_type = BPM_type
        self.phase = phase.lower()
        self.verbose = verbose
        if self.phase == 'predict':
            self.cache_dir = os.path.join(cache_dir, method)
        else:
            self.cache_dir = os.path.join(data_dir, cache_dir, method)
        self.method = method
        self.data_dir = data_dir
        self.mask_only = mask_only
        if not os.path.exists(self.cache_dir):
            os.makedirs(self.cache_dir)

        # data filter
        index_file = os.path.join(data_dir, index_name)
        if self.phase == 'predict':
            self.file_list = predict_files
            self.Lmax = max([f['length'] for f in self.file_list])
        else:
            if self.phase in {'train', 'validate'}:
                with open(index_file) as f:
                    index_dic = yaml.load(f.read(), Loader=yaml.FullLoader)
                all_files = index_dic['train'][:]
                if 'validate' in index_dic:
                    all_files += index_dic['validate']
                if trainall:
                    all_files += index_dic['test']
                # train with specific datasets
                if self.phase == 'train' and training_set:
                    training_set = set(training_set)
                    if 'PDB_test' in training_set:
                        training_set.update({'PDB_test-TS1', 'PDB_test-TS2', 'PDB_test-TS3'})
                    print(f'training sets: {training_set}')
                    all_files = [f for f in all_files if f['dataset'] in training_set]

                all_files.sort(key=lambda dic: dic['path'])

                # Kfold
                if nfolds <= 1:  # no kfold
                    if self.phase == 'validate':
                        self.file_list = index_dic['test'][:]
                    else:  # train
                        self.file_list = all_files[:]
                        random.shuffle(self.file_list)
                else:  # kfold training
                    split = list(KFold(n_splits=nfolds, random_state=seed, shuffle=True).split(range(len(all_files))))[fold][0 if self.phase == 'train' else 1]
                    self.file_list = [all_files[i] for i in split]
                self.Lmax = Lmax
                # limit length
                self.file_list = [f for f in self.file_list if Lmin <= f['length'] <= self.Lmax]

            elif self.phase in ['test']:
                with open(index_file) as f:
                    index_dic = yaml.load(f.read(), Loader=yaml.FullLoader)
                self.file_list = index_dic['test'][:]
                if self.phase == 'test' and test_set:
                    test_set = set(test_set)
                    if 'PDB_test' in test_set:
                        test_set.update({'PDB_test-TS1', 'PDB_test-TS2', 'PDB_test-TS3'})
                    print(f'test sets: {test_set}')
                    self.file_list = [f for f in self.file_list if f['dataset'] in test_set]
                self.Lmax = max([f['length'] for f in self.file_list])
            else:
                raise NotImplementedError
        print(f'phase={self.phase}, num={len(self.file_list)}, nfolds={nfolds}: {index_file}')
        print(f'use_BPP={use_BPP}, use_BPE={use_BPE}')
        for dic in self.file_list:
            dic['path'] = os.path.join(self.data_dir, dic['path'])

        if self.use_BPE:
            self.normalize_energy = normalize_energy
            self.BPM_ene = BPM_energy(path=os.path.join(para_dir, 'key.energy'))
        self.base_index = get_base_index()  # for matrix embed
        self.num_base = len(self.base_index)
        self.index_base = {v: k for k, v in self.base_index.items()}
        self.token_index = {k: v for k, v in self.base_index.items()}
        self.token_index.update({tok: len(self.base_index)+i for i, tok in enumerate(['START', 'END', 'EMPTY'])})  # for sequence embed
        self.noncanonical = [self.index_base[i]+self.index_base[j] not in CANONICAL_PAIRS for i, j in product(range(self.num_base), range(self.num_base))]
        self.noncanonical_flag = np.array(self.noncanonical, dtype=bool)
        self.to_device_keywords = {'input', 'input_seqmat', 'mask', 'forward_mask', 'BPPM', 'BPEM', 'seq_onehot', 'nc_map'}
        if self.phase != 'predict':
            self.to_device_keywords.add('gt')
        print(self.token_index)

    def __len__(self):
        return len(self.file_list)

    def prepare_data(self, name, seq, connects=None):
        ret = {}
        L = len(seq)
        ret['ori_seq'] = seq.upper()  # AUGC and others
        ret['seq'] = mut_seq(ret['ori_seq'].replace('T', 'U'), connects)  # unknown -> AUGC
        ret['length'] = L

        # mask: (Lmax+2)x(Lmax+2)
        mask = torch.zeros(self.Lmax+2, self.Lmax+2, dtype=torch.bool)
        for row in range(1, L+1):  # not including START and END
            mask[row, 1:L+1] = True
        # forward_mask: Lmax+2
        forward_mask = torch.zeros(self.Lmax+2, dtype=torch.bool)  # START, seq, END
        forward_mask[0:L+2] = True  # including START and END
        ret['mask'] = mask
        ret['forward_mask'] = forward_mask

        if self.mask_only:
            return ret, None

        ## pad to uniform size=(Lmax+2) when batch-loading
        rside_pad = self.Lmax + 1 - L

        # seq_embed seq: Lmax+2
        ret['input'] = self.seq_embed_sequence(ret['seq'])

        # seq_embed outer product
        seqmat, seq_onehot = self.seq_embed_matrix(ret['seq'], return_onehot=True)  # NUM_BASE**2 x L x L
        seqmat_pad = np.pad(seqmat, ((0, 0), (1, rside_pad), (1, rside_pad)), constant_values=0)
        ret['input_seqmat'] = torch.FloatTensor(seqmat_pad)  # NUM_BASE**2 x (Lmax+2) x (Lmax+2)
        seq_onehot_pad = np.pad(seq_onehot, ((1, rside_pad), (0, 0)), constant_values=0)
        ret['seq_onehot'] = torch.FloatTensor(seq_onehot_pad)  # (Lmax+2) x NUM_BASE

        # nc_map: noncanonical pairs: (Lmax+2) x (Lmax+2)
        nc_map = seqmat[self.noncanonical_flag].sum(axis=0).astype(bool)  # LxL
        nc_map_pad = np.pad(nc_map, ((1, rside_pad), (1, rside_pad)), constant_values=0)
        ret['nc_map'] = torch.FloatTensor(nc_map_pad)

        # BPPM: 1x(Lmax+2)x(Lmax+2)
        if self.use_BPP:
            bppm = self.load_BPPM(seq=seq, name=name, use_cache=(self.phase != 'predict'))
            bppm_pad = np.pad(bppm, ((1, rside_pad), (1, rside_pad)), constant_values=0)
            ret['BPPM'] = torch.FloatTensor(bppm_pad).unsqueeze(0)
            # ret['BPPM'] = torch.log(ret['BPPM']+1e-5)  # Note: Energy ~ log(p)
            # ret['BPPM'] = - ret['BPPM']
        # BPEM: 1x(Lmax+2)x(Lmax+2)
        if self.use_BPE:
            bpem = self.BPM_ene.get_energy(ret['seq'], normalize_energy=self.normalize_energy, BPM_type=self.BPM_type)
            if self.normalize_energy:
                bpem_pad = np.pad(bpem, ((0, 0), (1, rside_pad), (1, rside_pad)), constant_values=0)
                ret['BPEM'] = torch.FloatTensor(bpem_pad)
            else:
                bpem_pad = np.pad(bpem, ((1, rside_pad), (1, rside_pad)), constant_values=0)
                ret['BPEM'] = torch.FloatTensor(bpem_pad).unsqueeze(0)

        y = {k: ret[k] for k in ['mask', 'forward_mask', 'nc_map', 'seq_onehot']}
        # gt, contact map: (Lmax+2)x(Lmax+2)
        if self.phase != 'predict':
            gt = connects2arr(connects)
            gt_pad = np.pad(gt, ((1, rside_pad), (1, rside_pad)), constant_values=0)
            ret['gt'] = torch.FloatTensor(gt_pad)
            y['gt'] = ret['gt']
        return ret, y

    def __getitem__(self, idx):
        info_dic = self.file_list[idx]

        dataset = info_dic['dataset'] if 'dataset' in info_dic else 'RNAseq'
        seq = name = connects = None

        if 'path' in info_dic:
            path = info_dic['path']
            name, suf = get_file_name(path, return_suf=True)
            if suf.lower() in {'.bpseq', '.ct', '.dbn'}:
                seq, connects = read_SS(path)
            else:
                with open(path) as fp:
                    fp.readline()
                    seq = fp.readline().strip(' \n')
        else:
            if 'seq' in info_dic:
                seq = info_dic['seq']
            else:
                raise Exception(f'[Error] seq or path needed: {info_dic}')
            name = info_dic['name'] if 'name' in info_dic else str_localtime()
        if connects is None and self.phase != 'predict':
            raise Exception(f'[Error] Invalid input: {info_dic} at {self.phase} stage, gt SS needed.')
        # # load_data
        # ret_data = {}
        # cache_path = os.path.join(self.cache_dir, name+'.pth')
        # if os.path.exists(cache_path):
        #     ret_data = torch.load(cache_path)
        # else:
        #     ret_data, y = self.prepare_data(name=name, seq=seq, connects=connects)
        #     if not self.mask_only:
        #         torch.save(ret_data, cache_path)
        ret_data, y = self.prepare_data(name=name, seq=seq, connects=connects)

        if self.mask_only:
            return {k: ret_data[k] for k in ['mask', 'forward_mask', 'length']}
        # update ret dic
        ret_data.update({'name': name, 'idx': idx, 'dataset': dataset})
        return ret_data, y

    def seq_embed_matrix(self, seq, return_onehot=False):
        '''
        seq: str, len=L, 'AUGC...'
        ret: ndarray, NUM_BASE**2 x L x L, 0/1 values
        '''
        L = len(seq)
        # seq onehot: L x NUM_BASE
        seq_onehot = np.zeros((L, self.num_base), dtype=float)
        for i in range(L):  # should be consistent with function `postprocess` in ..util.postprocess
            seq_onehot[i][self.base_index[seq[i]]] = 1

        # seq embedding: NUM_BASE*NUM_BASE x L x L
        seq_embed = np.zeros((self.num_base**2, L, L))  # AUGC
        for n, (i, j) in enumerate(product(range(self.num_base), range(self.num_base))):
            seq_embed[n] = np.matmul(seq_onehot[:, i].reshape(-1, 1), seq_onehot[:, j].reshape(1, -1))
        if return_onehot:
            return seq_embed, seq_onehot
        else:
            return seq_embed

    def seq_embed_sequence(self, seq):
        '''
        seq: str, len=L, 'AUGC...'
        ret: IntTensor, length Lmax+2, values 0-6, representing 'START AUGC... END EMPTY...'
        '''
        ret = [self.token_index['START']]
        ret.extend(self.token_index[s] for s in seq)
        ret.append(self.token_index['END'])

        # pad with EMPTY to the final length Lmax+2
        for i in range(self.Lmax - len(seq)):
            ret.append(self.token_index['EMPTY'])
        ret = torch.Tensor(ret).int()  # float? TODO
        return ret

    def load_BPPM(self, seq, name, use_cache=True):
        txt_path = os.path.join(self.cache_dir, name+'.txt')
        npy_path = os.path.join(self.cache_dir, name+'.npy')

        if use_cache and os.path.exists(npy_path):
            return np.load(npy_path, allow_pickle=True)
        else:
            if not os.path.exists(txt_path):
                try:
                    if self.phase == 'predict' and self.verbose:
                        print(f'[Info] Using "{self.method}" to generate BPPM, saving at "{txt_path}"')
                    gen_BPPM(txt_path, mut_seq(seq), name, self.method)
                except Exception as e:
                    if self.phase == 'predict' and self.verbose:
                        print(f'[Warning] {e}, using CDPfold instead')
                    gen_BPPM(txt_path, mut_seq(seq), name, 'CDPfold')
            BPPM = read_BPPM(txt_path, len(seq))
            np.save(npy_path, BPPM)
            return BPPM
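
The core of this file is the pairwise sequence representation: seq_embed_matrix turns a length-L sequence into NUM_BASE**2 channels of LxL outer products, one channel per ordered base pair. A minimal standalone sketch of the same construction (the A/U/G/C index order below is an assumption made for illustration; BPfold takes the real mapping from get_base_index()):

import numpy as np
from itertools import product

BASE_INDEX = {'A': 0, 'U': 1, 'G': 2, 'C': 3}  # assumed order, for illustration only

def outer_product_embed(seq):
    # one-hot encode: L x 4
    L = len(seq)
    onehot = np.zeros((L, 4))
    for p, s in enumerate(seq):
        onehot[p, BASE_INDEX[s]] = 1
    # channel i*4+j is 1 at (p, q) iff seq[p] is base i and seq[q] is base j
    embed = np.zeros((16, L, L))
    for n, (i, j) in enumerate(product(range(4), range(4))):
        embed[n] = np.outer(onehot[:, i], onehot[:, j])
    return embed

emb = outer_product_embed('GCAU')
print(emb[BASE_INDEX['G'] * 4 + BASE_INDEX['C']])  # G-C channel: 1 at (0, 1)

The noncanonical_flag built in __init__ indexes these same channels, so summing the flagged channels (as prepare_data does for nc_map) marks every position pair that cannot form a canonical or wobble pair.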
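
For orientation, a hedged usage sketch (not part of this commit): the class is a plain torch.utils.data.Dataset, so it batches with a standard DataLoader. The data/ directory and its data_index.yaml are hypothetical stand-ins for whatever layout the index file describes; note that use_BPE=True (the default) expects para_dir ('paras/') to contain key.energy, and the first access may invoke the configured BPPM generator (CDPfold by default) before caching.

import torch.utils.data as data
from BPfold.dataset.RNAseq import RNAseq_data

# hypothetical paths; data_dir must hold the YAML index (index_name='data_index.yaml')
train_set = RNAseq_data(data_dir='data/', phase='train', Lmax=600, fold=0, nfolds=4)
loader = data.DataLoader(train_set, batch_size=8, shuffle=True)

ret, y = next(iter(loader))
print(ret['input'].shape)  # (8, Lmax+2) token ids
print(y['gt'].shape)       # (8, Lmax+2, Lmax+2) padded contact map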

BPfold/dataset/__init__.py

+7 lines

from .RNAseq import RNAseq_data


def get_dataset(s):
    return {
        'rnaseq': RNAseq_data,
    }[s.lower()]
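
get_dataset is a one-entry name-to-class registry; the lookup is case-insensitive, and an unknown name raises KeyError. A quick sketch of the intended call:

from BPfold.dataset import get_dataset

DataSet = get_dataset('RNAseq')  # any casing works, since the key is s.lower()
assert DataSet.__name__ == 'RNAseq_data'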
