Skip to content

Commit

Permalink
Update finetune
Browse files Browse the repository at this point in the history
  • Loading branch information
michellemli committed May 6, 2024
1 parent e183339 commit 785fe0e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion finetune_pinnacle/data_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def main():
data_split_path = args.data_split_path + ".json"

# Load data
embed, celltype_dict, celltype_protein_dict, positive_proteins, negative_proteins, all_relevant_proteins = load_data(embed_path, labels_path, args.positive_proteins_prefix, args.negative_proteins_prefix, args.raw_data_prefix)
embed, celltype_dict, celltype_protein_dict, positive_proteins, negative_proteins, all_relevant_proteins = load_data(embed_path, labels_path, args.positive_proteins_prefix, args.negative_proteins_prefix, args.raw_data_prefix, None)
for c, v in positive_proteins.items():
assert len(v) == len(set(v).intersection(set(all_relevant_proteins)))

Expand Down
5 changes: 4 additions & 1 deletion finetune_pinnacle/read_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def read_labels_from_evidence(positive_protein_prefix, negative_protein_prefix,
return {}, {}, {}


def load_data(embed_path: str, labels_path: str, positive_proteins_prefix: str, negative_proteins_prefix: str, raw_data_prefix: str):
def load_data(embed_path: str, labels_path: str, positive_proteins_prefix: str, negative_proteins_prefix: str, raw_data_prefix: str, task_name: str):

embed = torch.load(embed_path)
with open(labels_path, "r") as f:
Expand All @@ -60,6 +60,9 @@ def load_data(embed_path: str, labels_path: str, positive_proteins_prefix: str,

positive_proteins, negative_proteins, all_relevant_proteins = read_labels_from_evidence(positive_proteins_prefix, negative_proteins_prefix, raw_data_prefix)
assert len(positive_proteins) > 0
if task_name != None and len(positive_proteins) == 1:
positive_proteins = positive_proteins[task_name]
negative_proteins = negative_proteins[task_name]

return embed, celltype_dict, celltype_protein_dict, positive_proteins, negative_proteins, all_relevant_proteins

2 changes: 1 addition & 1 deletion finetune_pinnacle/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def main(args, hparams, wandb):
models_output_dir, metrics_output_dir, random_state, embed_path, labels_path = setup_paths(args)

# Load data
embed, celltype_dict, celltype_protein_dict, positive_proteins, negative_proteins, _ = load_data(embed_path, labels_path, args.positive_proteins_prefix, args.negative_proteins_prefix, None)
embed, celltype_dict, celltype_protein_dict, positive_proteins, negative_proteins, _ = load_data(embed_path, labels_path, args.positive_proteins_prefix, args.negative_proteins_prefix, None, args.task_name)
print("Finished reading data, evaluating...\n")

# Run model
Expand Down

0 comments on commit 785fe0e

Please sign in to comment.