| -# TxGNN |
| 1 | +# TxGNN: Repurposing therapeutics for neglected diseases using geometric deep learning |
| 2 | + |
| 3 | +This repository hosts the official implementation of TxGNN, a method that predicts drug efficacy for diseases with limited molecular underpinnings and few treatments by applying geometric learning to a multi-scale disease knowledge graph.
| 4 | + |
| 5 | +### Installation |
| 6 | + |
| 7 | +Create a virtual environment with `virtualenv` or `conda`, then run `pip install TxGNN`.
| 8 | + |
| 9 | +### Core API Interface |
| 10 | +Using the API, you can (1) reproduce the results in our paper, (2) train TxGNN on your own drug repurposing dataset in a few lines of code, and (3) generate graph explanations.
| 11 | + |
| 12 | +```python |
| 13 | +from TxGNN import TxData, TxGNN, TxEval |
| 14 | + |
| 15 | +# Download/load knowledge graph dataset |
| 16 | +TxData = TxData(data_folder_path = './data') |
| 17 | +TxData.prepare_split(split = 'complex_disease', seed = 42) |
| 18 | +TxGNN = TxGNN(data = TxData, |
| 19 | + weight_bias_track = False, |
| 20 | + proj_name = 'TxGNN', |
| 21 | + exp_name = 'TxGNN' |
| 22 | + ) |
| 23 | + |
| 24 | +# Initialize a new model |
| 25 | +TxGNN.model_initialize(n_hid = 100, |
| 26 | + n_inp = 100, |
| 27 | + n_out = 100, |
| 28 | + proto = True, |
| 29 | + proto_num = 3, |
| 30 | + attention = False, |
| 31 | + sim_measure = 'all_nodes_profile', |
| 32 | + bert_measure = 'disease_name', |
| 33 | + agg_measure = 'rarity', |
| 34 | + num_walks = 200, |
| 35 | + walk_mode = 'bit', |
| 36 | + path_length = 2) |
| 37 | + |
| 38 | +``` |
| 39 | + |
| 40 | +Instead of initializing a new model, you can also load a saved model: |
| 41 | + |
| 42 | +```python |
| 43 | +TxGNN.load_pretrained('./model_ckpt') |
| 44 | +``` |
| 45 | + |
| 46 | +To pre-train with link prediction on all edge types, run:
| 47 | + |
| 48 | +```python |
| 49 | +TxGNN.pretrain(n_epoch = 2, |
| 50 | + learning_rate = 1e-3, |
| 51 | + batch_size = 1024, |
| 52 | + train_print_per_n = 20) |
| 53 | +``` |
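For readers unfamiliar with link prediction, the idea can be sketched in plain Python. This is only an illustration under stated assumptions, not the code that `pretrain` runs: the DistMult-style scoring function, the binary cross-entropy loss, and the helper names `distmult_score`, `sigmoid`, and `link_prediction_loss` are all hypothetical choices made for this example.

```python
import math


def distmult_score(head, rel, tail):
    """DistMult-style score: sum of element-wise products of three vectors.

    Assumption: this decoder is chosen for illustration only.
    """
    return sum(h * r * t for h, r, t in zip(head, rel, tail))


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def link_prediction_loss(pos_triples, neg_triples):
    """Mean binary cross-entropy over true edges (label 1) and sampled non-edges (label 0).

    Each triple is (head_embedding, relation_embedding, tail_embedding).
    """
    eps = 1e-12  # guards against log(0)
    losses = []
    for h, r, t in pos_triples:
        losses.append(-math.log(sigmoid(distmult_score(h, r, t)) + eps))
    for h, r, t in neg_triples:
        losses.append(-math.log(1.0 - sigmoid(distmult_score(h, r, t)) + eps))
    return sum(losses) / len(losses)
```

Training lowers this loss, so embeddings of nodes joined by a real edge score higher than those of randomly paired nodes.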
| 54 | + |
| 55 | +Lastly, to fine-tune on drug-disease relations with metric learning, run:
| 56 | + |
| 57 | +```python |
| 58 | +TxGNN.finetune(n_epoch = 500, |
| 59 | + learning_rate = 5e-4, |
| 60 | + train_print_per_n = 5, |
| 61 | + valid_per_n = 20, |
| 62 | + save_name = finetune_result_path) # finetune_result_path: a path variable you define beforehand
| 63 | +``` |
| 64 | + |
| 65 | +To save the trained model, you can type: |
| 66 | + |
| 67 | +```python |
| 68 | +TxGNN.save_model('./model_ckpt') |
| 69 | +``` |
| 70 | + |
| 71 | +To evaluate the model on the entire test set using disease-centric evaluation, you can type: |
| 72 | + |
| 73 | +```python |
| 74 | +result = TxEval.eval_disease_centric(disease_idxs = 'test_set', |
| 75 | + show_plot = False, |
| 76 | + verbose = True, |
| 77 | + save_result = True, |
| 78 | + return_raw = False, |
| 79 | + save_name = 'SAVE_PATH') |
| 80 | + |
| 81 | +``` |
| 82 | + |
| 83 | +To evaluate specific diseases only, pass their indices:
| 84 | + |
| 85 | +```python |
| 86 | +result = TxEval.eval_disease_centric(disease_idxs = [9907.0, 12787.0], |
| 87 | + relation = 'indication', |
| 88 | + save_result = False) |
| 89 | +``` |
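As a rough illustration of what a disease-centric metric can look like, the sketch below ranks candidate drugs for each disease and measures how many known drugs land in the top `k`. It is one plausible reading, not the logic inside `eval_disease_centric`: the choice of recall@k, the input layout, and the names `recall_at_k` and `disease_centric_recall` are assumptions made for this example.

```python
def recall_at_k(scores, known_drugs, k):
    """Fraction of a disease's known drugs that appear among its top-k ranked drugs.

    scores: dict mapping drug id -> predicted score for one disease (hypothetical layout).
    known_drugs: iterable of drug ids known to be related to that disease.
    """
    known = set(known_drugs)
    if not known:
        return 0.0
    ranked = sorted(scores, key=scores.get, reverse=True)
    top_k = set(ranked[:k])
    return len(top_k & known) / len(known)


def disease_centric_recall(all_scores, all_known, k):
    """Compute recall@k separately for every disease in all_scores."""
    return {
        disease: recall_at_k(drug_scores, all_known.get(disease, ()), k)
        for disease, drug_scores in all_scores.items()
    }
```

Scoring each disease separately keeps diseases with many known treatments from dominating the aggregate, which matters when the target diseases have few treatments.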
| 90 | + |
| 91 | + |
| 92 | +After training a satisfactory link prediction model, you can also train a graph XAI model:
| 93 | + |
| 94 | +```python |
| 95 | +TxGNN.train_graphmask(relation = 'indication', |
| 96 | + learning_rate = 3e-4, |
| 97 | + allowance = 0.005, |
| 98 | + epochs_per_layer = 3, |
| 99 | + penalty_scaling = 1, |
| 100 | + valid_per_n = 20) |
| 101 | +``` |
| 102 | + |
| 103 | +You can retrieve the graph XAI gates (which indicate whether an edge is important) and save them to a pickle file at `SAVED_PATH/graphmask_output_RELATION.pkl`:
| 104 | + |
| 105 | +```python |
| 106 | +gates = TxGNN.retrieve_save_gates('SAVED_PATH') |
| 107 | +``` |
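To read the saved gates back in a later session, a minimal sketch using the standard `pickle` module could look like the following. The helpers `gates_path` and `load_gates` are hypothetical and not part of the TxGNN API; they assume the file naming pattern described above, with `RELATION` replaced by the relation name. The structure of the unpickled object is not documented here, so inspect it before use.

```python
import pickle
from pathlib import Path


def gates_path(saved_path, relation):
    """Build the pickle path, assuming the graphmask_output_RELATION.pkl naming pattern."""
    return Path(saved_path) / f"graphmask_output_{relation}.pkl"


def load_gates(saved_path, relation):
    """Load a previously saved gates object from disk.

    Only unpickle files you created yourself; pickle can execute arbitrary code.
    """
    with open(gates_path(saved_path, relation), "rb") as f:
        return pickle.load(f)
```

For example, `load_gates('SAVED_PATH', 'indication')` would read `SAVED_PATH/graphmask_output_indication.pkl`. Any libraries whose objects are stored in the file need to be installed in the loading environment.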
| 108 | + |
| 109 | +You can also save and load the GraphMask model:
| 110 | + |
| 111 | +```python |
| 112 | +TxGNN.save_graphmask_model('./graphmask_model_ckpt') |
| 113 | +TxGNN.load_pretrained_graphmask('./graphmask_model_ckpt') |
| 114 | + |
| 115 | +``` |
| 116 | + |
| 117 | + |
| 118 | +### Cite Us |
| 119 | + |
| 120 | +``` |
| 121 | +``` |