Added random split for non-SMILES identifiers #452
```diff
@@ -21,6 +21,7 @@
 from typing import Dict, List

 import fire
+import numpy as np
 import pandas as pd
 from rdkit import Chem, RDLogger
 from rdkit.Chem.Scaffolds import MurckoScaffold

@@ -29,6 +30,20 @@
 RDLogger.DisableLog("rdApp.*")


+REPRESENTATION_LIST = [
+    "SMILES",
+    "PSMILES",
+    "SELFIES",
+    "RXNSMILES",
+    "RXNSMILESWAdd",
+    "IUPAC",
+    "InChI",
+    "InChIKey",
+    "Compositions",
+    "Sentence",
+]
+
+
 def print_sys(s):
     """system print

@@ -120,10 +135,10 @@ def create_scaffold_split(


 def rewrite_data_with_splits(
     csv_paths: List[str],
+    repr_col: str,
     train_test_df: pd.DataFrame,
     override: bool = False,
     check: bool = True,
-    repr_col: str = "SMILES",
 ) -> None:
     """Rewrite dataframes with the correct split column

@@ -138,8 +153,12 @@ def rewrite_data_with_splits(
         repr_col (str): the column name for where SMILES representation is stored
             defaults to "SMILES"
     """
-    for path in csv_paths:
-        read_dataset = pd.read_csv(path)
-
-        train_smiles = set(train_test_df.query("split == 'train'")["SMILES"].to_list())
+    if check:
+        train_smiles = set(train_test_df.query("split == 'train'")[repr_col].to_list())
+
+    for path in csv_paths:
+        read_dataset = pd.read_csv(path)

@@ -153,11 +172,8 @@ def rewrite_data_with_splits(
             except KeyError:
                 print(f"No split column in {path}")

-            col_to_merge = "SMILES"
-            merged_data = pd.merge(
-                read_dataset, train_test_df, on=col_to_merge, how="left"
-            )
-            merged_data = merged_data.dropna()
+            merged_data = pd.merge(read_dataset, train_test_df, on=repr_col, how="left")
+            # merged_data = merged_data.dropna()
             if override:
                 merged_data.to_csv(path, index=False)
             else:

@@ -172,23 +188,24 @@ def rewrite_data_with_splits(
                 raise ValueError("Split failed, no test data")
             if check:
                 test_split_smiles = set(
-                    merged_data.query("split == 'test'")["SMILES"].to_list()
+                    merged_data.query("split == 'test'")[repr_col].to_list()
                 )
                 if len(train_smiles.intersection(test_split_smiles)) > 0:
                     raise ValueError("Split failed, train and test overlap")
         else:
             print(f"Skipping {path} as it does not contain {repr_col} column")


-def cli(
+TRACKED_DATASETS = []
+
+
+def per_repr(
+    repr_col: str,
     seed: int = 42,
     train_size: float = 0.8,
     val_size: float = 0.0,
     test_size: float = 0.2,
     path: str = "*/data_clean.csv",
     override: bool = False,
     check: bool = True,
-    repr_col: str = "SMILES",
 ):
     paths_to_data = glob(path)

@@ -197,40 +214,86 @@ def cli(
     # for path in paths_to_data:
     #     if "flashpoint" in path:
     #         filtered_paths.append(path)
-    #     elif "freesolv" in path:
+    #     elif "BBBP" in path:
     #         filtered_paths.append(path)
    #     elif "peptide" in path:
     #         filtered_paths.append(path)
     #     elif "bicerano_dataset" in path:
     #         filtered_paths.append(path)
     # paths_to_data = filtered_paths

     REPRESENTATION_LIST = []

-    for path in tqdm(paths_to_data):
+    for path in paths_to_data:
         df = pd.read_csv(path)
-        if repr_col in df.columns:
+
+        if repr_col in df.columns and path not in TRACKED_DATASETS:
             print("Processing", path.split("/")[0])
+            TRACKED_DATASETS.append(path)
             REPRESENTATION_LIST.extend(df[repr_col].to_list())

     REPR_DF = pd.DataFrame()
-    REPR_DF["SMILES"] = list(set(REPRESENTATION_LIST))
+    REPR_DF[repr_col] = list(set(REPRESENTATION_LIST))
```
> **Review comment:** in general, there should be no need to allocate a dataframe for this (which can cost more memory), but it should work for our needs.
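For illustration, a minimal sketch of the DataFrame-free variant the comment alludes to (the helper name and arguments are hypothetical, not part of the PR): draw the split labels directly over the deduplicated identifiers and keep them in a plain dict.

```python
import numpy as np

def assign_random_splits(identifiers, val_size=0.125, test_size=0.125, seed=42):
    """Hypothetical helper: label each unique identifier without building a DataFrame."""
    unique_ids = list(set(identifiers))
    rng = np.random.default_rng(seed)
    # Same probability vector as the diff: train gets the remainder.
    labels = rng.choice(
        ["train", "test", "valid"],
        size=len(unique_ids),
        p=[1 - val_size - test_size, test_size, val_size],
    )
    return dict(zip(unique_ids, labels))

# Example: two InChI identifiers receive split labels directly.
splits = assign_random_splits(["InChI=1S/H2O/h1H2", "InChI=1S/CH4/h1H4"])
```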
```diff
-    scaffold_split = create_scaffold_split(
-        REPR_DF, seed=seed, frac=[train_size, val_size, test_size]
-    )
+    if "SMILES" in REPR_DF.columns:
+        split = create_scaffold_split(
+            REPR_DF, seed=seed, frac=[train_size, val_size, test_size]
+        )
+    else:
+        split_ = pd.DataFrame(
+            np.random.choice(
+                ["train", "test", "valid"],
+                size=len(REPR_DF),
+                p=[1 - val_size - test_size, test_size, val_size],
+            )
+        )
+
+        REPR_DF["split"] = split_
+        train = REPR_DF.query("split == 'train'").reset_index(drop=True)
+        test = REPR_DF.query("split == 'test'").reset_index(drop=True)
+        valid = REPR_DF.query("split == 'valid'").reset_index(drop=True)
+
+        split = {"train": train, "test": test, "valid": valid}

     # create train and test dataframes
-    train_df = scaffold_split["train"]
-    test_df = scaffold_split["test"]
+    train_df = split["train"]
+    test_df = split["test"]
+    valid_df = split["valid"]
     # add split columns to train and test dataframes
     train_df["split"] = len(train_df) * ["train"]
     test_df["split"] = len(test_df) * ["test"]
+    valid_df["split"] = len(valid_df) * ["valid"]

     # merge train and test across all datasets
-    merge = pd.concat([train_df, test_df], axis=0)
+    merge = pd.concat([train_df, test_df, valid_df], axis=0)
     # rewrite data_clean.csv for each dataset
     rewrite_data_with_splits(
-        paths_to_data, merge, override=override, check=check, repr_col=repr_col
+        csv_paths=paths_to_data,
+        repr_col=repr_col,
+        train_test_df=merge,
+        override=override,
+        check=check,
     )


+def cli(
```
> **Review comment:** nice, this should do what we want it to do. I'll walk through it once more tomorrow with a fresh mind.
```diff
+    seed: int = 42,
+    train_size: float = 0.75,
+    val_size: float = 0.125,
+    test_size: float = 0.125,
+    path: str = "*/data_clean.csv",
+    override: bool = False,
+    check: bool = True,
+):
+    for representation in tqdm(REPRESENTATION_LIST):
+        if representation == "SMILES":
+            print("Processing priority representation: SMILES")
+        else:
+            print("Processing datasets with the representation column:", representation)
+        per_repr(
+            representation, seed, train_size, val_size, test_size, path, override, check
+        )


 if __name__ == "__main__":
     fire.Fire(cli)
```
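Because the entry point is wrapped in `fire.Fire(cli)`, each keyword argument of `cli` becomes a command-line flag. A sketch of an equivalent direct call with the new defaults (the script filename is assumed, not taken from the PR):

```python
# Equivalent to e.g. `python <script>.py --train_size 0.75 --val_size 0.125 --test_size 0.125`,
# which fire derives from cli()'s signature; calling cli() directly does the same.
cli(seed=42, train_size=0.75, val_size=0.125, test_size=0.125, path="*/data_clean.csv")
```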
> **Review comment (MicPie):** In LN33 and here you use the same variable name, but I guess for a different use case, or am I wrong?

> **Author reply:** yes, that is a global variable in LN33, but I think it is a good idea to change this local one to proper convention with lower-case names 🤔

> **Author reply:** @MicPie now should be better
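A sketch of the rename the thread settles on, pulled out of `per_repr` for clarity (the function and variable names here are illustrative, not from the PR): the module-level `REPRESENTATION_LIST` constant keeps its upper-case name, while the local accumulator gets a lower-case one so it no longer shadows it.

```python
import pandas as pd

def collect_representation_values(paths_to_data, repr_col, tracked_datasets):
    # Lower-case local name, so it no longer shadows the module-level
    # REPRESENTATION_LIST constant (LN33 in the diff).
    representation_values = []
    for path in paths_to_data:
        df = pd.read_csv(path)
        if repr_col in df.columns and path not in tracked_datasets:
            tracked_datasets.append(path)
            representation_values.extend(df[repr_col].to_list())
    return representation_values
```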