Skip to content

Commit

Permalink
Merge pull request #114 from bowang-lab/bugfix-remove-celltypekey-fro…
Browse files Browse the repository at this point in the history
…m-embed-data

Bugfix remove celltypekey from embed data
  • Loading branch information
subercui authored Nov 8, 2023
2 parents 9f62f16 + 4dea6f6 commit f7e5d52
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
15 changes: 7 additions & 8 deletions scgpt/tasks/cell_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def __getitem__(self, idx):
def embed_data(
adata_or_file: Union[AnnData, PathLike],
model_dir: PathLike,
cell_type_key: str = "cell_type",
gene_col: str = "feature_name",
max_length=1200,
batch_size=64,
Expand All @@ -164,13 +163,11 @@ def embed_data(
adata_or_file (Union[AnnData, PathLike]): The AnnData object or the path to the
AnnData object.
model_dir (PathLike): The path to the model directory.
cell_type_key (str): The key in adata.obs that contains the cell type labels.
Defaults to "cell_type".
gene_col (str): The column in adata.var that contains the gene names.
max_length (int): The maximum length of the input sequence. Defaults to 1200.
batch_size (int): The batch size for inference. Defaults to 64.
obs_to_save (Optional[list]): The list of obs columns to save in the output adata.
If None, will only keep the column of :attr:`cell_type_key`. Defaults to None.
Useful for retaining meta data to output. Defaults to None.
device (Union[str, torch.device]): The device to use. Defaults to "cuda".
use_fast_transformer (bool): Whether to use flash-attn. Defaults to True.
return_new_adata (bool): Whether to return a new AnnData object. If False, will
Expand All @@ -184,8 +181,11 @@ def embed_data(
else:
adata = sc.read_h5ad(adata_or_file)

# verify cell type key and gene col
assert cell_type_key in adata.obs
if isinstance(obs_to_save, str):
assert obs_to_save in adata.obs, f"obs_to_save {obs_to_save} not in adata.obs"
obs_to_save = [obs_to_save]

# verify gene col
if gene_col == "index":
adata.var["index"] = adata.var.index
else:
Expand Down Expand Up @@ -273,8 +273,7 @@ def embed_data(
)

if return_new_adata:
obs_to_save = [cell_type_key] if obs_to_save is None else obs_to_save
obs_df = adata.obs[obs_to_save]
obs_df = adata.obs[obs_to_save] if obs_to_save is not None else None
return sc.AnnData(X=cell_embeddings, obs=obs_df, dtype="float32")

adata.obsm["X_scGPT"] = cell_embeddings
Expand Down
34 changes: 23 additions & 11 deletions tutorials/Tutorial_Reference_Mapping.ipynb

Large diffs are not rendered by default.

0 comments on commit f7e5d52

Please sign in to comment.