@@ -1,12 +1,12 @@
 import logging
 import os
+import random
 from pathlib import Path
 
 import atom3.pair as pa
 import click
 import pandas as pd
 from atom3 import database as db
-from sklearn.model_selection import train_test_split
 from tqdm import tqdm
 
 from project.utils.constants import DB5_TEST_PDB_CODES, ATOM_COUNT_LIMIT
@@ -64,15 +64,25 @@ def main(output_dir: str, source_type: str, filter_by_atom_count: bool, max_atom
             if not os.path.exists(pairs_postprocessed_val_txt):  # Create val data list if not already existent
                 open(pairs_postprocessed_val_txt, 'w').close()
             # Write out training-validation partitions for DIPS
+            output_dirs = [filename
+                           for filename in os.listdir(output_dir)
+                           if os.path.isdir(os.path.join(output_dir, filename))]
+            # Get training and validation directories separately
+            num_train_dirs = int(0.8 * len(output_dirs))
+            train_dirs = random.sample(output_dirs, num_train_dirs)
+            val_dirs = list(set(output_dirs) - set(train_dirs))
+            # Ascertain training and validation filenames separately
             filenames_frame = pd.read_csv(pairs_postprocessed_txt, header=None)
-            train_filenames_frame, val_filenames_frame, _, _ = train_test_split(
-                filenames_frame,
-                # Ignore labels for now - will create feature vectors in dataset class
-                [None for _ in range(len(filenames_frame))],
-                train_size=(8 / 10),
-                test_size=(2 / 10)
-            )
+            train_filenames = [os.path.join(train_dir, filename)
+                               for train_dir in train_dirs
+                               for filename in os.listdir(os.path.join(output_dir, train_dir))
+                               if os.path.join(train_dir, filename) in filenames_frame.values]
+            val_filenames = [os.path.join(val_dir, filename)
+                             for val_dir in val_dirs
+                             for filename in os.listdir(os.path.join(output_dir, val_dir))
+                             if os.path.join(val_dir, filename) in filenames_frame.values]
             # Create separate .txt files to describe the training list and validation list, respectively
+            train_filenames_frame, val_filenames_frame = pd.DataFrame(train_filenames), pd.DataFrame(val_filenames)
             train_filenames_frame.to_csv(pairs_postprocessed_train_txt, header=None, index=None, sep=' ', mode='a')
             val_filenames_frame.to_csv(pairs_postprocessed_val_txt, header=None, index=None, sep=' ', mode='a')
 
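For context, below is a minimal standalone sketch of the directory-level 80/20 split this commit introduces in place of the row-level sklearn train_test_split, so that every pair file under a given subdirectory of output_dir lands on the same side of the partition. The helper name split_by_directory, the seed parameter, and the set-based membership test are assumptions added for this sketch: the commit itself leaves random unseeded and checks membership against filenames_frame.values directly.

    import os
    import random

    import pandas as pd


    def split_by_directory(output_dir, pairs_postprocessed_txt,
                           train_fraction=0.8, seed=42):
        """Hypothetical helper: partition pair files by parent directory."""
        # Assumption: the commit does not seed random, so its split is not
        # reproducible across runs; a fixed seed makes this sketch deterministic.
        random.seed(seed)
        # Keep only subdirectories of output_dir, mirroring the diff's
        # listdir/isdir filter
        output_dirs = [name for name in os.listdir(output_dir)
                       if os.path.isdir(os.path.join(output_dir, name))]
        train_dirs = random.sample(output_dirs, int(train_fraction * len(output_dirs)))
        val_dirs = list(set(output_dirs) - set(train_dirs))
        # A set gives O(1) membership tests; the commit instead scans
        # filenames_frame.values once per candidate file
        known_pairs = set(pd.read_csv(pairs_postprocessed_txt, header=None)[0])

        def collect(dirs):
            # Collect relative paths (subdir/filename) that appear in the
            # postprocessed pair list
            return [os.path.join(d, f)
                    for d in dirs
                    for f in os.listdir(os.path.join(output_dir, d))
                    if os.path.join(d, f) in known_pairs]

        return collect(train_dirs), collect(val_dirs)

Splitting at the directory level keeps all pair files from one subdirectory together, presumably so that related complexes cannot leak across the training/validation boundary the way a per-file random split allows.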