Added random split for non-SMILES identifiers #452

Merged · 18 commits · Nov 17, 2023
109 changes: 86 additions & 23 deletions data/tabular/train_test_split.py
@@ -21,6 +21,7 @@
from typing import Dict, List

import fire
+import numpy as np
import pandas as pd
from rdkit import Chem, RDLogger
from rdkit.Chem.Scaffolds import MurckoScaffold
@@ -29,6 +30,20 @@
RDLogger.DisableLog("rdApp.*")


+REPRESENTATION_LIST = [
+    "SMILES",
+    "PSMILES",
+    "SELFIES",
+    "RXNSMILES",
+    "RXNSMILESWAdd",
+    "IUPAC",
+    "InChI",
+    "InChIKey",
+    "Compositions",
+    "Sentence",
+]


def print_sys(s):
"""system print

@@ -120,10 +135,10 @@ def create_scaffold_split(
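The body of create_scaffold_split is collapsed in this diff. For orientation, here is a minimal sketch of the Bemis-Murcko scaffold split it implements; this is a hypothetical simplification (the function name scaffold_split_sketch and the greedy fill order are assumptions, not the repository's exact code):

from collections import defaultdict
import random

from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold

def scaffold_split_sketch(smiles_list, frac=(0.8, 0.0, 0.2), seed=42):
    # Group molecule indices by Murcko scaffold so that all molecules
    # sharing a scaffold land in the same split.
    scaffolds = defaultdict(list)
    for idx, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue  # skip unparseable SMILES
        core = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=False)
        scaffolds[core].append(idx)
    groups = list(scaffolds.values())
    random.Random(seed).shuffle(groups)
    n = sum(len(g) for g in groups)
    train, valid, test = [], [], []
    for group in groups:
        # Greedily fill train, then valid; everything else goes to test.
        if len(train) + len(group) <= frac[0] * n:
            train.extend(group)
        elif len(valid) + len(group) <= frac[1] * n:
            valid.extend(group)
        else:
            test.extend(group)
    return {"train": train, "valid": valid, "test": test}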

def rewrite_data_with_splits(
    csv_paths: List[str],
+    repr_col: str,
    train_test_df: pd.DataFrame,
    override: bool = False,
    check: bool = True,
-    repr_col: str = "SMILES",
) -> None:
    """Rewrite dataframes with the correct split column

@@ -138,8 +153,12 @@
        repr_col (str): the column name for where SMILES representation is stored
            defaults to "SMILES"
    """

-    for path in csv_paths:
-        read_dataset = pd.read_csv(path)

    if check:
-        train_smiles = set(train_test_df.query("split == 'train'")["SMILES"].to_list())
+        train_smiles = set(train_test_df.query("split == 'train'")[repr_col].to_list())

+    for path in csv_paths:
+        read_dataset = pd.read_csv(path)
@@ -153,11 +172,8 @@
        except KeyError:
            print(f"No split column in {path}")

-        col_to_merge = "SMILES"
-        merged_data = pd.merge(
-            read_dataset, train_test_df, on=col_to_merge, how="left"
-        )
-        merged_data = merged_data.dropna()
+        merged_data = pd.merge(read_dataset, train_test_df, on=repr_col, how="left")
+        # merged_data = merged_data.dropna()
        if override:
            merged_data.to_csv(path, index=False)
        else:
@@ -172,23 +188,24 @@
raise ValueError("Split failed, no test data")
if check:
test_split_smiles = set(
merged_data.query("split == 'test'")["SMILES"].to_list()
merged_data.query("split == 'test'")[repr_col].to_list()
)
if len(train_smiles.intersection(test_split_smiles)) > 0:
raise ValueError("Split failed, train and test overlap")
else:
print(f"Skipping {path} as it does not contain {repr_col} column")


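The merge-back above joins the global split assignment onto each dataset via its representation column, then verifies that no train identifier leaks into the test set. A toy illustration of that contract, with made-up values in an InChI column:

import pandas as pd

dataset = pd.DataFrame({"InChI": ["a", "b", "c"], "y": [1, 0, 1]})
assignment = pd.DataFrame({"InChI": ["a", "b", "c"], "split": ["train", "train", "test"]})

# how="left" keeps every dataset row, even ones without an assignment
merged = pd.merge(dataset, assignment, on="InChI", how="left")

train = set(merged.query("split == 'train'")["InChI"])
test = set(merged.query("split == 'test'")["InChI"])
assert not train & test, "Split failed, train and test overlap"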
-def cli(
+TRACKED_DATASETS = []
+
+
+def per_repr(
+    repr_col: str,
    seed: int = 42,
    train_size: float = 0.8,
    val_size: float = 0.0,
    test_size: float = 0.2,
    path: str = "*/data_clean.csv",
    override: bool = False,
    check: bool = True,
-    repr_col: str = "SMILES",
):
    paths_to_data = glob(path)

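The default pattern "*/data_clean.csv" makes glob pick up one cleaned file per dataset directory. A minimal illustration (directory names are examples taken from the commented-out filter below):

from glob import glob

# Hypothetical layout: each dataset directory holds one cleaned CSV,
# e.g. "freesolv/data_clean.csv" or "flashpoint/data_clean.csv".
paths_to_data = glob("*/data_clean.csv")
print(paths_to_data)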
@@ -197,40 +214,86 @@ def cli(
    # for path in paths_to_data:
    #     if "flashpoint" in path:
    #         filtered_paths.append(path)
-    #     elif "freesolv" in path:
+    #     elif "BBBP" in path:
    #         filtered_paths.append(path)
    #     elif "peptide" in path:
    #         filtered_paths.append(path)
    #     elif "bicerano_dataset" in path:
    #         filtered_paths.append(path)
    # paths_to_data = filtered_paths

+    REPRESENTATION_LIST = []
Contributor: In LN33 and here you use the same variable name, but I guess it is for a different use case. Or am I wrong?

Contributor (author): Yes, that is a global variable in LN33, but I think it is a good idea to change this local one to the proper convention with a lower-case name 🤔

Contributor (author): @MicPie now it should be better.

-    for path in tqdm(paths_to_data):
+    for path in paths_to_data:
        df = pd.read_csv(path)

-        if repr_col in df.columns:
+        if repr_col in df.columns and path not in TRACKED_DATASETS:
            print("Processing", path.split("/")[0])
+            TRACKED_DATASETS.append(path)
            REPRESENTATION_LIST.extend(df[repr_col].to_list())

    REPR_DF = pd.DataFrame()
-    REPR_DF["SMILES"] = list(set(REPRESENTATION_LIST))
+    REPR_DF[repr_col] = list(set(REPRESENTATION_LIST))
Collaborator: In general, there should be no need to allocate a dataframe for this (it can cost more memory), but it should work for our needs.
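The collection loop above pools the representation column across all matching datasets; TRACKED_DATASETS guards against handling a file twice, and the set() call deduplicates identifiers so each one receives exactly one split label, even when it occurs in several datasets. A toy version of the pooling, with made-up frames:

import pandas as pd

frames = [
    pd.DataFrame({"InChI": ["a", "b"]}),
    pd.DataFrame({"InChI": ["b", "c"]}),  # "b" also appears here
]
pool = []
for frame in frames:
    pool.extend(frame["InChI"].to_list())
unique = list(set(pool))  # "b" is kept once, so it gets a single split label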
-    scaffold_split = create_scaffold_split(
-        REPR_DF, seed=seed, frac=[train_size, val_size, test_size]
-    )
+    if "SMILES" in REPR_DF.columns:
+        split = create_scaffold_split(
+            REPR_DF, seed=seed, frac=[train_size, val_size, test_size]
+        )
+    else:
+        split_ = pd.DataFrame(
+            np.random.choice(
+                ["train", "test", "valid"],
+                size=len(REPR_DF),
+                p=[1 - val_size - test_size, test_size, val_size],
+            )
+        )
+
+        REPR_DF["split"] = split_
+        train = REPR_DF.query("split == 'train'").reset_index(drop=True)
+        test = REPR_DF.query("split == 'test'").reset_index(drop=True)
+        valid = REPR_DF.query("split == 'valid'").reset_index(drop=True)
+
+        split = {"train": train, "test": test, "valid": valid}

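Note how the label probabilities line up with the label order in np.random.choice: p = [1 - val_size - test_size, test_size, val_size] matches ["train", "test", "valid"] and sums to 1. A quick sanity check of the proportions (a sketch; it uses NumPy's newer Generator API instead of the legacy np.random.choice):

import numpy as np

train_size, val_size, test_size = 0.75, 0.125, 0.125
labels = ["train", "test", "valid"]
p = [1 - val_size - test_size, test_size, val_size]
assert abs(sum(p) - 1.0) < 1e-9

rng = np.random.default_rng(42)
draw = rng.choice(labels, size=10_000, p=p)
print({label: round(float((draw == label).mean()), 3) for label in labels})
# expected roughly {'train': 0.75, 'test': 0.125, 'valid': 0.125}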
    # create train and test dataframes
-    train_df = scaffold_split["train"]
-    test_df = scaffold_split["test"]
+    train_df = split["train"]
+    test_df = split["test"]
+    valid_df = split["valid"]
    # add split columns to train and test dataframes
    train_df["split"] = len(train_df) * ["train"]
    test_df["split"] = len(test_df) * ["test"]
+    valid_df["split"] = len(valid_df) * ["valid"]

    # merge train and test across all datasets
-    merge = pd.concat([train_df, test_df], axis=0)
+    merge = pd.concat([train_df, test_df, valid_df], axis=0)
    # rewrite data_clean.csv for each dataset
    rewrite_data_with_splits(
-        paths_to_data, merge, override=override, check=check, repr_col=repr_col
+        csv_paths=paths_to_data,
+        repr_col=repr_col,
+        train_test_df=merge,
+        override=override,
+        check=check,
    )


+def cli(

Collaborator: Nice, this should do what we want. I'll walk through it once more tomorrow with a fresh mind.

+    seed: int = 42,
+    train_size: float = 0.75,
+    val_size: float = 0.125,
+    test_size: float = 0.125,
+    path: str = "*/data_clean.csv",
+    override: bool = False,
+    check: bool = True,
+):
+    for representation in tqdm(REPRESENTATION_LIST):
+        if representation == "SMILES":
+            print("Processing priority representation: SMILES")
+        else:
+            print("Processing datasets with the representation column:", representation)
+        per_repr(
+            representation, seed, train_size, val_size, test_size, path, override, check
+        )


if __name__ == "__main__":
fire.Fire(cli)
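Because fire.Fire(cli) maps the keyword arguments of cli to command-line flags, and because REPRESENTATION_LIST puts "SMILES" first while per_repr records every processed file in TRACKED_DATASETS, SMILES datasets are scaffold-split first and are never reprocessed under another representation. A possible invocation (flag values are examples only):

python data/tabular/train_test_split.py --seed=42 --train_size=0.75 --val_size=0.125 --test_size=0.125 --check=True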