Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dummy commit #346

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/conda-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
echo "Creating Conda Environment from environment.yml"
conda env create -f environment.yml
conda activate tdc-conda-env
python run_tests.py tdc.test.test_model_server.TestModelServer.testscVI
python run_tests.py
yapf --style=google -r -d tdc
conda deactivate
10 changes: 10 additions & 0 deletions tdc/model_server/model_loaders/scvi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

class scVILoader:

def __init__(self):
pass

def load(self):
"""load scVI model
calling download_scvi() and any other helper functions
"""
20 changes: 20 additions & 0 deletions tdc/model_server/models/scvi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

class scVI:
"""class to load and perform inference w/ scvi

adding any additional utils from scvi that facilitate inference / data processing
"""

def __init__(self):
import scvi
pass

def forward(self, **kwargs):
"""
loads self.model if needed
calls inference on these arguments
"""

def load(self):
"""import the model loader
-> then, LOAD the MODEL CLASS and return it and also save it into self.model"""
5 changes: 4 additions & 1 deletion tdc/model_server/tdc_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
'CYP3A4_Veith-AttentiveFP',
]

model_hub = ["Geneformer", "scGPT"]
model_hub = ["Geneformer", "scGPT", "scVI"]


class tdc_hf_interface:
Expand Down Expand Up @@ -66,6 +66,9 @@ def load(self):
AutoModel.register(ScGPTConfig, ScGPTModel)
model = AutoModel.from_pretrained("tdc/scGPT")
return model
elif self.model_name == "scVI":
# import scVI model and return the model class
pass
raise Exception("Not implemented yet!")

def load_deeppurpose(self, save_path):
Expand Down
13 changes: 13 additions & 0 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@ def setUp(self):
print(os.getcwd())
self.resource = cellxgene_census.CensusResource()

def testscVI(self):
from tdc.multi_pred.anndata_dataset import DataLoader
from tdc import tdc_hf_interface
# run scvi test
adata = DataLoader("cellxgene_sample_small",
"./data",
dataset_names=["cellxgene_sample_small"],
no_convert=True).adata
# run scVI on a single sample
scvi = tdc_hf_interface("scVI")
model = scvi.load() # this line can cause segmentation fault
pass

def testscGPT(self):
from tdc.multi_pred.anndata_dataset import DataLoader
from tdc import tdc_hf_interface
Expand Down
Loading