Added random split for non-SMILES identifiers #452

Merged: kjappelbaum merged 18 commits into OpenBioML:main from AdrianM0:train_test_split_addition on Nov 17, 2023.

Commits (18):
- 7173116 feat: added random split for non-SMILES identifiers (AdrianM0)
- 089073c fix: commented out debug lines (AdrianM0)
- 1b6dfc7 feat: add non-SMILES for debug. add checks for non-SMILES + data-saver (AdrianM0)
- 517a084 feat: generalized train-test split (AdrianM0)
- 32e70c4 lint (AdrianM0)
- 0b88457 fix: error (AdrianM0)
- 06b3d07 lint (AdrianM0)
- 912f6b8 fix: double reading and prioritize SMILES (AdrianM0)
- 0d407c9 chore: pre-commit (AdrianM0)
- 30881de chore: var names more pythonic (AdrianM0)
- 1cd82fc update composition type
- 2edaccf add AS_SEQUENCE as type
- 973d08e update representation list
- 27dac3e factor out scaffold split
- abe158e update scaffold split code
- 3b8edda Merge branch 'main' into train_test_split_addition (kjappelbaum)
- 4e39a73 lint
- 1130991 lint and `low_memory=False`
New file (+184 lines):

```python
import os
import subprocess
from functools import partial
from glob import glob
from typing import List

import fire
import pandas as pd
import yaml
from pandarallel import pandarallel

from chemnlp.data.split import _create_scaffold_split

pandarallel.initialize(progress_bar=True)

to_scaffold_split = [
    "blood_brain_barrier_martins_et_al",
    "serine_threonine_kinase_33_butkiewicz",
    "ld50_zhu",
    "herg_blockers",
    "clearance_astrazeneca",
    "half_life_obach",
    "drug_induced_liver_injury",
    "kcnq2_potassium_channel_butkiewicz",
    "sarscov2_3clpro_diamond",
    "volume_of_distribution_at_steady_state_lombardo_et_al",
    "sr_hse_tox21",
    "nr_er_tox21",
    "cyp2c9_substrate_carbonmangels",
    "nr_aromatase_tox21",
    "cyp_p450_2d6_inhibition_veith_et_al",
    "cyp_p450_1a2_inhibition_veith_et_al",
    "cyp_p450_2c9_inhibition_veith_et_al",
    "m1_muscarinic_receptor_antagonists_butkiewicz",
    "nr_ar_tox21",
    "sr_atad5_tox21",
    "tyrosyl-dna_phosphodiesterase_butkiewicz",
    "cav3_t-type_calcium_channels_butkiewicz",
    "clintox",
    "sr_p53_tox21",
    "nr_er_lbd_tox21",
    "pampa_ncats",
    "sr_mmp_tox21",
    "caco2_wang",
    "sarscov2_vitro_touret",
    "choline_transporter_butkiewicz",
    "orexin1_receptor_butkiewicz",
    "human_intestinal_absorption",
    "nr_ahr_tox21",
    "cyp3a4_substrate_carbonmangels",
    "herg_karim_et_al",
    "hiv",
    "carcinogens",
    "sr_are_tox21",
    "nr_ppar_gamma_tox21",
    "solubility_aqsoldb",
    "m1_muscarinic_receptor_agonists_butkiewicz",
    "ames_mutagenicity",
    "potassium_ion_channel_kir2_1_butkiewicz",
    "cyp_p450_3a4_inhibition_veith_et_al",
    "skin_reaction",
    "cyp2d6_substrate_carbonmangels",
    "cyp_p450_2c19_inhibition_veith_et_al",
    "nr_ar_lbd_tox21",
    "p_glycoprotein_inhibition_broccatelli_et_al",
    "bioavailability_ma_et_al",
    "BACE",
    "lipophilicity",
    "thermosol",
    "MUV_466",
    "MUV_548",
    "MUV_600",
    "MUV_644",
    "MUV_652",
    "MUV_689",
    "MUV_692",
    "MUV_712",
    "MUV_713",
    "MUV_733",
    "MUV_737",
    "MUV_810",
    "MUV_832",
    "MUV_846",
    "MUV_852",
    "MUV_858",
    "MUV_859",
]


def split_for_smiles(smiles: str, train_smiles: List[str], val_smiles: List[str]) -> str:
    """Assign a SMILES string to train/valid/test based on membership in the split lists."""
    # note: membership tests on lists are O(n); passing sets is faster for large corpora
    if smiles in train_smiles:
        return "train"
    elif smiles in val_smiles:
        return "valid"
    else:
        return "test"

def main(
    data_dir,
    override: bool = False,
    train_frac: float = 0.7,
    val_frac: float = 0.15,
    test_frac: float = 0.15,
    seed: int = 42,
    debug: bool = False,
):
    all_yaml_files = glob(os.path.join(data_dir, "**", "*.yaml"), recursive=True)
    if debug:
        all_yaml_files = all_yaml_files[:5]
    transformed_files = []
    for file in all_yaml_files:
        with open(file, "r") as f:
            meta = yaml.safe_load(f)

        if meta["name"] in to_scaffold_split:
            # run the transform.py script in this dataset's directory (as cwd)
            # if `data_clean.csv` does not exist yet
            if not os.path.exists(
                os.path.join(os.path.dirname(file), "data_clean.csv")
            ):
                subprocess.run(
                    ["python", "transform.py"],
                    cwd=os.path.dirname(file),
                )

            transformed_files.append(
                os.path.join(os.path.dirname(file), "data_clean.csv")
            )

    all_smiles = set()
    for file in transformed_files:
        df = pd.read_csv(file, low_memory=False)
        all_smiles.update(df["SMILES"].tolist())
        del df  # ensure memory is freed

    all_smiles = list(all_smiles)
    # pass the user-supplied seed through so the CLI flag takes effect
    splits = _create_scaffold_split(
        all_smiles, frac=[train_frac, val_frac, test_frac], seed=seed
    )

    # select the right indices for each split
    train_smiles = [all_smiles[i] for i in splits["train"]]
    val_smiles = [all_smiles[i] for i in splits["valid"]]
    print(
        "Train smiles:",
        len(train_smiles),
        "Val smiles:",
        len(val_smiles),
        "Test smiles:",
        len(splits["test"]),
    )
```

> **Review question** (on the `Test smiles` line): why no `test_smiles`?
>
> **Reply:** because we do not allocate a list of SMILES for the test split; we only have the indices in the dict.
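If a concrete list of test SMILES were ever needed, it could presumably be materialized the same way as for train and valid; a one-line sketch, not part of this PR:

```python
# hypothetical, mirroring the train/valid list comprehensions above
test_smiles = [all_smiles[i] for i in splits["test"]]
```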
```python
    # now, for each dataframe, add a split column based on which list in the
    # `splits` dict contains the SMILES (taken from the `SMILES` column).
    # if the override option is true, we overwrite `data_clean.csv` in place;
    # otherwise, we first rename the old `data_clean.csv` to `data_clean_old.csv`
    # and then write the new file to `data_clean.csv`.
    # we print the train/val/test split sizes and fractions for each dataset.

    split_for_smiles_curried = partial(
        split_for_smiles, train_smiles=train_smiles, val_smiles=val_smiles
    )
    for file in transformed_files:
        df = pd.read_csv(file, low_memory=False)
        df["split"] = df["SMILES"].parallel_apply(split_for_smiles_curried)

        # check that the global scaffold split does not distort the
        # train/val/test proportions within each individual dataset
        print(
            f"Dataset {file} has {len(df)} datapoints. Split sizes: {len(df[df['split'] == 'train'])} train, {len(df[df['split'] == 'valid'])} valid, {len(df[df['split'] == 'test'])} test."  # noqa: E501
        )
        print(
            f"Dataset {file} has {len(df)} datapoints. Split fractions: {len(df[df['split'] == 'train']) / len(df)} train, {len(df[df['split'] == 'valid']) / len(df)} valid, {len(df[df['split'] == 'test']) / len(df)} test."  # noqa: E501
        )
        if override:
            # write the new data_clean.csv file
            df.to_csv(file, index=False)
        else:
            # keep the old data_clean.csv as data_clean_old.csv
            os.rename(file, file.replace(".csv", "_old.csv"))
            # write the new data_clean.csv file
            df.to_csv(file, index=False)


if __name__ == "__main__":
    fire.Fire(main)
```
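`_create_scaffold_split` itself is imported from `chemnlp.data.split` and is not shown in this diff (see the "factor out scaffold split" commit). For orientation only, here is a minimal sketch of what a scaffold split with this call signature typically looks like, assuming RDKit Murcko scaffolds and the index-dict return format used above; the actual implementation may differ:

```python
import random
from collections import defaultdict
from typing import Dict, List

from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold


def _scaffold_split_sketch(
    smiles: List[str], frac: List[float], seed: int = 42
) -> Dict[str, List[int]]:
    """Group molecules by Murcko scaffold, then assign whole groups to splits."""
    scaffold_to_indices = defaultdict(list)
    for idx, smi in enumerate(smiles):
        mol = Chem.MolFromSmiles(smi)
        # unparseable SMILES all land in one empty-scaffold bucket
        scaffold = (
            MurckoScaffold.MurckoScaffoldSmiles(mol=mol) if mol is not None else ""
        )
        scaffold_to_indices[scaffold].append(idx)

    # shuffle whole scaffold groups so the assignment depends on the seed
    groups = list(scaffold_to_indices.values())
    random.Random(seed).shuffle(groups)

    n = len(smiles)
    train_cap = frac[0] * n
    valid_cap = (frac[0] + frac[1]) * n
    splits: Dict[str, List[int]] = {"train": [], "valid": [], "test": []}
    for group in groups:
        if len(splits["train"]) + len(group) <= train_cap:
            splits["train"].extend(group)
        elif len(splits["train"]) + len(splits["valid"]) + len(group) <= valid_cap:
            splits["valid"].extend(group)
        else:
            splits["test"].extend(group)
    return splits
```

Keeping each scaffold group intact in a single split is the point of the technique: molecules sharing a core never leak across train and test.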
> **Review comment:** When you are running `transform.py`, you are not ensuring that the additional molecular representations are present. We can keep it as it is, but it is worth mentioning here.
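Such a check could look roughly like the sketch below. The required column set is a placeholder assumption (the diff only ever reads the `SMILES` column), and `check_representations` is a hypothetical helper, not part of the PR:

```python
import os

import pandas as pd


def check_representations(
    data_dir: str, required: frozenset = frozenset({"SMILES"})
) -> None:
    """Hypothetical guard: verify data_clean.csv contains the expected columns."""
    csv_path = os.path.join(data_dir, "data_clean.csv")
    header = pd.read_csv(csv_path, nrows=0)  # read only the header row
    missing = required - set(header.columns)
    if missing:
        raise ValueError(f"{csv_path} is missing columns: {sorted(missing)}")
```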