Skip to content

Commit 35c6086

Browse files
move from private repo
1 parent 985b5bb commit 35c6086

29 files changed

+294776
-1
lines changed

README.md

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,121 @@
1-
# TxGNN
1+
# TxGNN: Repurposing therapeutics for neglected diseases using geometric deep learning
2+
3+
This repository hosts the official implementation of TxGNN, a method that can predict drug efficacy to disease with limited molecular underpinnings and few treatments by applying geomtric learning on multi-scale disease knowledge graph.
4+
5+
### Installation
6+
7+
Create your virtual environment using `virtualenv` or `conda` and then do `pip install TxGNN`
8+
9+
### Core API Interface
10+
Using the API, you can (1) reproduce the results in our paper and (2) train TxGNN on your own drug repurposing dataset using a few lines of code, and also generate graph explanations.
11+
12+
```python
13+
from TxGNN import TxData, TxGNN, TxEval
14+
15+
# Download/load knowledge graph dataset
16+
TxData = TxData(data_folder_path = './data')
17+
TxData.prepare_split(split = 'complex_disease', seed = 42)
18+
TxGNN = TxGNN(data = TxData,
19+
weight_bias_track = False,
20+
proj_name = 'TxGNN',
21+
exp_name = 'TxGNN'
22+
)
23+
24+
# Initialize a new model
25+
TxGNN.model_initialize(n_hid = 100,
26+
n_inp = 100,
27+
n_out = 100,
28+
proto = True,
29+
proto_num = 3,
30+
attention = False,
31+
sim_measure = 'all_nodes_profile',
32+
bert_measure = 'disease_name',
33+
agg_measure = 'rarity',
34+
num_walks = 200,
35+
walk_mode = 'bit',
36+
path_length = 2)
37+
38+
```
39+
40+
Instead of initializing a new model, you can also load a saved model:
41+
42+
```python
43+
TxGNN.load_pretrained('./model_ckpt')
44+
```
45+
46+
To do pre-training using link prediction for all edge types, you can type:
47+
48+
```python
49+
TxGNN.pretrain(n_epoch = 2,
50+
learning_rate = 1e-3,
51+
batch_size = 1024,
52+
train_print_per_n = 20)
53+
```
54+
55+
Lastly, to do finetuning on drug-disease relation with metric learning, you can type:
56+
57+
```python
58+
TxGNN.finetune(n_epoch = 500,
59+
learning_rate = 5e-4,
60+
train_print_per_n = 5,
61+
valid_per_n = 20,
62+
save_name = finetune_result_path)
63+
```
64+
65+
To save the trained model, you can type:
66+
67+
```python
68+
TxGNN.save_model('./model_ckpt')
69+
```
70+
71+
To evaluate the model on the entire test set using disease-centric evaluation, you can type:
72+
73+
```python
74+
result = TxEval.eval_disease_centric(disease_idxs = 'test_set',
75+
show_plot = False,
76+
verbose = True,
77+
save_result = True,
78+
return_raw = False,
79+
save_name = 'SAVE_PATH')
80+
81+
```
82+
83+
If you want to look at specific disease, you can also do:
84+
85+
```python
86+
result = TxEval.eval_disease_centric(disease_idxs = [9907.0, 12787.0],
87+
relation = 'indication',
88+
save_result = False)
89+
```
90+
91+
92+
After training a satisfying link prediction model, we can also train graph XAI model by:
93+
94+
```python
95+
TxGNN.train_graphmask(relation = 'indication',
96+
learning_rate = 3e-4,
97+
allowance = 0.005,
98+
epochs_per_layer = 3,
99+
penalty_scaling = 1,
100+
valid_per_n = 20)
101+
```
102+
103+
You can retrieve and save the graph XAI gates (whether or not an edge is important) into a pkl file located as `SAVED_PATH/'graphmask_output_RELATION.pkl'`:
104+
105+
```python
106+
gates = TxGNN.retrieve_save_gates('SAVED_PATH')
107+
```
108+
109+
Of course, you can save and load graphmask model as well via:
110+
111+
```python
112+
TxGNN.save_graphmask_model('./graphmask_model_ckpt')
113+
TxGNN.load_pretrained_graphmask('./graphmask_model_ckpt')
114+
115+
```
116+
117+
118+
### Cite Us
119+
120+
```
121+
```

0 commit comments

Comments
 (0)