import os
import random
from itertools import product

import numpy as np
import torch
import torch.utils.data as data
import yaml
from sklearn.model_selection import KFold

from ..util.misc import get_file_name, str_localtime
from ..util.RNA_kit import read_SS, connects2arr, mut_seq
from ..util.postprocess import get_base_index
from ..util.base_pair_motif import BPM_energy
from ..util.base_pair_probability import read_BPPM, gen_BPPM

# Watson-Crick pairs plus GU/UG wobble pairs; all other combinations are noncanonical
CANONICAL_PAIRS = {'AU', 'UA', 'GC', 'CG', 'GU', 'UG'}
| 20 | + |
| 21 | +class RNAseq_data(data.Dataset): |
| 22 | + def __init__(self, |
| 23 | + data_dir, |
| 24 | + index_name='data_index.yaml', |
| 25 | + phase='train', |
| 26 | + Lmax=600, |
| 27 | + Lmin=0, |
| 28 | + fold=0, |
| 29 | + nfolds=4, |
| 30 | + seed=42, |
| 31 | + cache_dir='.cache_data', |
| 32 | + mask_only=False, |
| 33 | + method='CDPfold', |
| 34 | + trainall=False, |
| 35 | + predict_files=None, |
| 36 | + training_set=None, |
| 37 | + test_set=None, |
| 38 | + use_BPE=True, |
| 39 | + use_BPP=True, |
| 40 | + normalize_energy=False, |
| 41 | + verbose=False, |
| 42 | + para_dir='paras/', |
| 43 | + BPM_type='all', |
| 44 | + *args, |
| 45 | + **kargs, |
| 46 | + ): |
        # set all constructor args as attributes (drop `self` to avoid a self-reference)
        params = locals().copy()
        params.pop('self')
        self.__dict__.update(params)
        self.phase = phase.lower()
        if self.phase == 'predict':
            self.cache_dir = os.path.join(cache_dir, method)
        else:
            self.cache_dir = os.path.join(data_dir, cache_dir, method)
        os.makedirs(self.cache_dir, exist_ok=True)

        # data filter
        index_file = os.path.join(data_dir, index_name)
        if self.phase == 'predict':
            self.file_list = predict_files
            self.Lmax = max([f['length'] for f in self.file_list])
        else:
            if self.phase in {'train', 'validate'}:
                with open(index_file) as f:
                    index_dic = yaml.load(f, Loader=yaml.FullLoader)
                all_files = index_dic['train'][:]
                if 'validate' in index_dic:
                    all_files += index_dic['validate']
                if trainall:
                    all_files += index_dic['test']
                # train with specific datasets only
                if self.phase == 'train' and training_set:
                    training_set = set(training_set)
                    if 'PDB_test' in training_set:
                        training_set.update({'PDB_test-TS1', 'PDB_test-TS2', 'PDB_test-TS3'})
                    print(f'training sets: {training_set}')
                    all_files = [f for f in all_files if f['dataset'] in training_set]

                all_files.sort(key=lambda dic: dic['path'])

                # Kfold
                if nfolds <= 1:  # no kfold
                    if self.phase == 'validate':
                        self.file_list = index_dic['test'][:]
                    else:  # train
                        self.file_list = all_files[:]
                        random.shuffle(self.file_list)
                else:  # kfold training
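                    # KFold.split yields one (train_idx, val_idx) pair per fold:
                    # take this fold's pair, then the train indices for the
                    # train phase and the held-out indices for validation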
                    folds = list(KFold(n_splits=nfolds, random_state=seed, shuffle=True).split(range(len(all_files))))
                    split = folds[fold][0 if self.phase == 'train' else 1]
                    self.file_list = [all_files[i] for i in split]
                self.Lmax = Lmax
                # limit length
                self.file_list = [f for f in self.file_list if Lmin <= f['length'] <= self.Lmax]

            elif self.phase == 'test':
                with open(index_file) as f:
                    index_dic = yaml.load(f, Loader=yaml.FullLoader)
                self.file_list = index_dic['test'][:]
                if test_set:
                    test_set = set(test_set)
                    if 'PDB_test' in test_set:
                        test_set.update({'PDB_test-TS1', 'PDB_test-TS2', 'PDB_test-TS3'})
                    print(f'test sets: {test_set}')
                    self.file_list = [f for f in self.file_list if f['dataset'] in test_set]
                self.Lmax = max([f['length'] for f in self.file_list])
            else:
                raise NotImplementedError
        print(f'phase={self.phase}, num={len(self.file_list)}, nfolds={nfolds}: {index_file}')
        print(f'use_BPP={use_BPP}, use_BPE={use_BPE}')
        for dic in self.file_list:
            dic['path'] = os.path.join(self.data_dir, dic['path'])

        if self.use_BPE:
            self.BPM_ene = BPM_energy(path=os.path.join(para_dir, 'key.energy'))
        self.base_index = get_base_index()  # base -> index, for matrix embedding
        self.num_base = len(self.base_index)
        self.index_base = {v: k for k, v in self.base_index.items()}
        self.token_index = dict(self.base_index)  # for sequence embedding
        self.token_index.update({tok: len(self.base_index) + i for i, tok in enumerate(['START', 'END', 'EMPTY'])})
        # per-channel flag marking noncanonical base combinations; the channel
        # order matches `seq_embed_matrix` (channel n = num_base * i + j)
        self.noncanonical = [self.index_base[i] + self.index_base[j] not in CANONICAL_PAIRS
                             for i, j in product(range(self.num_base), range(self.num_base))]
        self.noncanonical_flag = np.array(self.noncanonical, dtype=bool)
        self.to_device_keywords = {'input', 'input_seqmat', 'mask', 'forward_mask', 'BPPM', 'BPEM', 'seq_onehot', 'nc_map'}
        if self.phase != 'predict':
            self.to_device_keywords.add('gt')
        print(self.token_index)

    def __len__(self):
        return len(self.file_list)

    def prepare_data(self, name, seq, connects=None):
        ret = {}
        L = len(seq)
        ret['ori_seq'] = seq.upper()  # AUGC and others
        ret['seq'] = mut_seq(ret['ori_seq'].replace('T', 'U'), connects)  # map unknown bases onto AUGC
        ret['length'] = L
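
        # padded layout along each axis (total length Lmax+2):
        #   index 0        -> START
        #   indices 1..L   -> nucleotides
        #   index L+1      -> END
        #   indices L+2..  -> EMPTY padding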
        # mask: (Lmax+2) x (Lmax+2), True over the L x L sequence block
        mask = torch.zeros(self.Lmax + 2, self.Lmax + 2, dtype=torch.bool)
        mask[1:L+1, 1:L+1] = True  # excludes the START and END rows/columns
        # forward_mask: Lmax+2, True over START, sequence, and END
        forward_mask = torch.zeros(self.Lmax + 2, dtype=torch.bool)
        forward_mask[0:L+2] = True
        ret['mask'] = mask
        ret['forward_mask'] = forward_mask

        if self.mask_only:
            return ret, None

        # pad to uniform size (Lmax+2) for batch loading
        rside_pad = self.Lmax + 1 - L

        # token-id sequence: Lmax+2
        ret['input'] = self.seq_embed_sequence(ret['seq'])

        # outer-product sequence embedding
        seqmat, seq_onehot = self.seq_embed_matrix(ret['seq'], return_onehot=True)  # NUM_BASE**2 x L x L
        seqmat_pad = np.pad(seqmat, ((0, 0), (1, rside_pad), (1, rside_pad)), constant_values=0)
        ret['input_seqmat'] = torch.FloatTensor(seqmat_pad)  # NUM_BASE**2 x (Lmax+2) x (Lmax+2)
        seq_onehot_pad = np.pad(seq_onehot, ((1, rside_pad), (0, 0)), constant_values=0)
        ret['seq_onehot'] = torch.FloatTensor(seq_onehot_pad)  # (Lmax+2) x NUM_BASE

        # nc_map: True where the base combination is noncanonical: (Lmax+2) x (Lmax+2)
        nc_map = seqmat[self.noncanonical_flag].sum(axis=0).astype(bool)  # L x L
        nc_map_pad = np.pad(nc_map, ((1, rside_pad), (1, rside_pad)), constant_values=0)
        ret['nc_map'] = torch.FloatTensor(nc_map_pad)

        # BPPM: 1 x (Lmax+2) x (Lmax+2)
        if self.use_BPP:
            bppm = self.load_BPPM(seq=seq, name=name, use_cache=(self.phase != 'predict'))
            bppm_pad = np.pad(bppm, ((1, rside_pad), (1, rside_pad)), constant_values=0)
            ret['BPPM'] = torch.FloatTensor(bppm_pad).unsqueeze(0)
            # ret['BPPM'] = torch.log(ret['BPPM']+1e-5)  # Note: Energy ~ log(p)
            # ret['BPPM'] = - ret['BPPM']
        # BPEM: 1 x (Lmax+2) x (Lmax+2), or C x (Lmax+2) x (Lmax+2) with normalized energies
        if self.use_BPE:
            bpem = self.BPM_ene.get_energy(ret['seq'], normalize_energy=self.normalize_energy, BPM_type=self.BPM_type)
            if self.normalize_energy:  # get_energy returns an extra channel axis here
                bpem_pad = np.pad(bpem, ((0, 0), (1, rside_pad), (1, rside_pad)), constant_values=0)
                ret['BPEM'] = torch.FloatTensor(bpem_pad)
            else:
                bpem_pad = np.pad(bpem, ((1, rside_pad), (1, rside_pad)), constant_values=0)
                ret['BPEM'] = torch.FloatTensor(bpem_pad).unsqueeze(0)

        y = {k: ret[k] for k in ['mask', 'forward_mask', 'nc_map', 'seq_onehot']}
        # gt contact map: (Lmax+2) x (Lmax+2)
        if self.phase != 'predict':
            gt = connects2arr(connects)
            gt_pad = np.pad(gt, ((1, rside_pad), (1, rside_pad)), constant_values=0)
            ret['gt'] = torch.FloatTensor(gt_pad)
            y['gt'] = ret['gt']
        return ret, y
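
    # prepare_data output summary (P = Lmax+2, N = num_base):
    #   ret['input']         (P,)            token ids
    #   ret['input_seqmat']  (N**2, P, P)    pairwise one-hot outer products
    #   ret['seq_onehot']    (P, N)          per-position one-hot
    #   ret['nc_map']        (P, P)          1 where a pair would be noncanonical
    #   ret['BPPM']          (1, P, P)       base-pair probabilities (if use_BPP)
    #   ret['BPEM']          (1 or C, P, P)  base-pair-motif energies (if use_BPE)
    #   ret['gt']            (P, P)          contact map (absent in predict)
    # y repeats the loss-side tensors: mask, forward_mask, nc_map, seq_onehot
    # (plus gt outside the predict phase)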

    def __getitem__(self, idx):
        info_dic = self.file_list[idx]

        dataset = info_dic.get('dataset', 'RNAseq')
        seq = name = connects = None

        if 'path' in info_dic:
            path = info_dic['path']
            name, suf = get_file_name(path, return_suf=True)
            if suf.lower() in {'.bpseq', '.ct', '.dbn'}:
                seq, connects = read_SS(path)
            else:  # FASTA-like file: skip the header line, read the sequence
                with open(path) as fp:
                    fp.readline()
                    seq = fp.readline().strip()
        elif 'seq' in info_dic:
            seq = info_dic['seq']
            name = info_dic['name'] if 'name' in info_dic else str_localtime()
        else:
            raise ValueError(f'[Error] seq or path needed: {info_dic}')
        if connects is None and self.phase != 'predict':
            raise ValueError(f'[Error] Invalid input: {info_dic} at {self.phase} stage, gt SS needed.')
        # # load_data with on-disk caching (disabled)
        # cache_path = os.path.join(self.cache_dir, name + '.pth')
        # if os.path.exists(cache_path):
        #     ret_data = torch.load(cache_path)
        # else:
        #     ret_data, y = self.prepare_data(name=name, seq=seq, connects=connects)
        #     if not self.mask_only:
        #         torch.save(ret_data, cache_path)
        ret_data, y = self.prepare_data(name=name, seq=seq, connects=connects)

        if self.mask_only:
            return {k: ret_data[k] for k in ['mask', 'forward_mask', 'length']}
        # attach metadata
        ret_data.update({'name': name, 'idx': idx, 'dataset': dataset})
        return ret_data, y

    def seq_embed_matrix(self, seq, return_onehot=False):
        '''
        seq: str, len=L, 'AUGC...'
        ret: ndarray, NUM_BASE**2 x L x L, 0/1 values
        '''
        L = len(seq)
        # seq one-hot: L x NUM_BASE
        seq_onehot = np.zeros((L, self.num_base), dtype=float)
        for i in range(L):  # must stay consistent with `postprocess` in ..util.postprocess
            seq_onehot[i][self.base_index[seq[i]]] = 1

        # pairwise embedding: NUM_BASE**2 x L x L, channel n = num_base*i + j
        seq_embed = np.zeros((self.num_base**2, L, L))
        for n, (i, j) in enumerate(product(range(self.num_base), range(self.num_base))):
            seq_embed[n] = np.matmul(seq_onehot[:, i].reshape(-1, 1), seq_onehot[:, j].reshape(1, -1))
        if return_onehot:
            return seq_embed, seq_onehot
        return seq_embed
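
    # Worked example (a sketch; the actual base order comes from
    # `get_base_index`, assumed here to be A->0, U->1, G->2, C->3):
    #   seq = 'AU' -> seq_onehot = [[1, 0, 0, 0],
    #                               [0, 1, 0, 0]]
    #   channel n = 4*i + j is the outer product of onehot[:, i] and
    #   onehot[:, j], so channel 1 (A, U) equals [[0, 1], [0, 0]]:
    #   the 1 at (0, 1) marks position 0 as A against position 1 as U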

    def seq_embed_sequence(self, seq):
        '''
        seq: str, len=L, 'AUGC...'
        ret: IntTensor, Lmax+2, values 0-6 representing 'START AUGC... END EMPTY...'
        '''
        ret = [self.token_index['START']]
        ret.extend(self.token_index[s] for s in seq)
        ret.append(self.token_index['END'])
        # pad with EMPTY up to the final length Lmax+2
        ret.extend([self.token_index['EMPTY']] * (self.Lmax - len(seq)))
        return torch.tensor(ret, dtype=torch.int)  # TODO: float?
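
    # Example with Lmax = 5 and seq = 'AUG' (assuming the special tokens get
    # ids after the bases, e.g. START=4, END=5, EMPTY=6):
    #   [START, A, U, G, END, EMPTY, EMPTY] -> int tensor of length Lmax+2 = 7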

    def load_BPPM(self, seq, name, use_cache=True):
        txt_path = os.path.join(self.cache_dir, name + '.txt')
        npy_path = os.path.join(self.cache_dir, name + '.npy')

        if use_cache and os.path.exists(npy_path):
            return np.load(npy_path, allow_pickle=True)
        if not os.path.exists(txt_path):
            try:
                if self.phase == 'predict' and self.verbose:
                    print(f'[Info] Using "{self.method}" to generate BPPM, saving at "{txt_path}"')
                gen_BPPM(txt_path, mut_seq(seq), name, self.method)
            except Exception as e:
                if self.phase == 'predict' and self.verbose:
                    print(f'[Warning] {e}, using CDPfold instead')
                gen_BPPM(txt_path, mut_seq(seq), name, 'CDPfold')
        BPPM = read_BPPM(txt_path, len(seq))
        np.save(npy_path, BPPM)
        return BPPM
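

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the module. Assumptions (not guaranteed
# here): `data/` holds a `data_index.yaml` whose 'train'/'test' lists contain
# {'path', 'length', 'dataset'} entries, and `paras/key.energy` exists for
# `BPM_energy`. Because this file uses relative imports, run it as a module
# (python -m <package>.<this_module>) rather than as a script.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = RNAseq_data(data_dir='data/', phase='train', Lmax=600, fold=0, nfolds=4)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    ret, y = next(iter(loader))
    print(ret['input'].shape)  # (B, Lmax+2) token ids
    print(ret['mask'].shape)   # (B, Lmax+2, Lmax+2) pairwise mask
    print(y['gt'].shape)       # (B, Lmax+2, Lmax+2) ground-truth contact map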