Skip to content

Commit

Permalink
feat(de): interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
atarashansky committed May 22, 2024
1 parent 6c72fef commit 15dd8fc
Show file tree
Hide file tree
Showing 21 changed files with 5,578 additions and 2,688 deletions.
13 changes: 13 additions & 0 deletions backend/de/api/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os

from backend.common.utils.secret_config import SecretConfig


class DeConfig(SecretConfig):
def __init__(self, *args, **kwargs):
super().__init__("backend", secret_name="de_config", **kwargs)

Check warning on line 8 in backend/de/api/config.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/config.py#L8

Added line #L8 was not covered by tests

def get_defaults_template(self):
deployment_stage = os.getenv("DEPLOYMENT_STAGE", "test")
defaults_template = {"bucket": f"wmg-{deployment_stage}", "data_path_prefix": "", "tiledb_config_overrides": {}}
return defaults_template
96 changes: 95 additions & 1 deletion backend/de/api/de-api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,101 @@ paths:
maxLength: 4
minimum: 0.0
maximum: 1.0

/v1/interpretDeResults:
post:
summary: Interpret differential expression results using OpenAI GPT-4o.
tags:
- de
operationId: backend.de.api.v1.interpretDeResults
parameters: []
requestBody:
content:
application/json:
schema:
type: object
properties:
queryGroup1Filters:
type: object
properties:
organism_ontology_term_id:
type: string
tissue_ontology_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
publication_citations:
type: array
items:
type: string
disease_ontology_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
sex_ontology_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
development_ontology_stage_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
self_reported_ethnicity_ontology_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
cell_type_ontology_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
required:
- organism_ontology_term_id
queryGroup2Filters:
type: object
properties:
organism_ontology_term_id:
type: string
tissue_ontology_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
publication_citations:
type: array
items:
type: string
disease_ontology_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
sex_ontology_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
development_ontology_stage_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
self_reported_ethnicity_ontology_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
cell_type_ontology_term_ids:
$ref: "#/components/schemas/de_ontology_term_id_list"
required:
- organism_ontology_term_id
deGenes1:
description: ->
Differentially expressed genes for group 1
type: array
items:
description: ->
Gene symbol
type: string
deGenes2:
description: ->
Differentially expressed genes for group 2
type: array
items:
description: ->
Gene symbol
type: string
required:
- queryGroup1Filters
- queryGroup2Filters
- deGenes1
- deGenes2
responses:
"200":
description: OK
content:
application/json:
schema:
type: object
required:
- message
- prompt
properties:
message:
type: string
prompt:
type: string
components:
schemas:
problem:
Expand Down
83 changes: 83 additions & 0 deletions backend/de/api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import openai

from backend.common.utils.ontology_parser import ontology_parser
from backend.de.api.config import DeConfig
from backend.wmg.data.query import DeQueryCriteria


def interpret_de_results(
criteria1: DeQueryCriteria,
criteria2: DeQueryCriteria,
genes1: list[str],
genes2: list[str],
) -> str:
prompt = _craft_de_interpretation_prompt(criteria1, criteria2, genes1, genes2)

messages = [{"role": "user", "content": prompt}]

Check warning on line 16 in backend/de/api/utils.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/utils.py#L16

Added line #L16 was not covered by tests

response = openai.ChatCompletion.create(

Check warning on line 18 in backend/de/api/utils.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/utils.py#L18

Added line #L18 was not covered by tests
model="gpt-4o-2024-05-13",
messages=messages,
api_key=DeConfig().openai_api_key,
)
return response["choices"][0]["message"]["content"], prompt

Check warning on line 23 in backend/de/api/utils.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/utils.py#L23

Added line #L23 was not covered by tests


def _craft_de_interpretation_prompt(
criteria1: DeQueryCriteria, criteria2: DeQueryCriteria, genes1: list[str], genes2: list[str]
) -> str:
def get_term_labels(ontology_term_ids: list[str]) -> list[str]:
return [ontology_parser.get_term_label(term_id) for term_id in ontology_term_ids]

Check warning on line 30 in backend/de/api/utils.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/utils.py#L29-L30

Added lines #L29 - L30 were not covered by tests

def format_criteria(criteria: DeQueryCriteria) -> str:
parts = [

Check warning on line 33 in backend/de/api/utils.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/utils.py#L32-L33

Added lines #L32 - L33 were not covered by tests
f"Organism: {ontology_parser.get_term_label(criteria.organism_ontology_term_id)}",
(
f"Tissues: {', '.join(get_term_labels(criteria.tissue_ontology_term_ids))}"
if criteria.tissue_ontology_term_ids
else ""
),
(
f"Cell Types: {', '.join(get_term_labels(criteria.cell_type_ontology_term_ids))}"
if criteria.cell_type_ontology_term_ids
else ""
),
(
f"Diseases: {', '.join(get_term_labels(criteria.disease_ontology_term_ids))}"
if criteria.disease_ontology_term_ids
else ""
),
(
f"Ethnicities: {', '.join(get_term_labels(criteria.self_reported_ethnicity_ontology_term_ids))}"
if criteria.self_reported_ethnicity_ontology_term_ids
else ""
),
(
f"Sexes: {', '.join(get_term_labels(criteria.sex_ontology_term_ids))}"
if criteria.sex_ontology_term_ids
else ""
),
]
return ", ".join([part for part in parts if part])

Check warning on line 61 in backend/de/api/utils.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/utils.py#L61

Added line #L61 was not covered by tests

formatted_criteria1 = format_criteria(criteria1)
formatted_criteria2 = format_criteria(criteria2)
formatted_genes1 = "\n ".join([f"{i+1}. {gene}" for i, gene in enumerate(genes1)])
formatted_genes2 = "\n ".join([f"{i+1}. {gene}" for i, gene in enumerate(genes2)])

Check warning on line 66 in backend/de/api/utils.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/utils.py#L63-L66

Added lines #L63 - L66 were not covered by tests

prompt = f"""

Check warning on line 68 in backend/de/api/utils.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/utils.py#L68

Added line #L68 was not covered by tests
Please analyze the following differential expression (DE) results and provide a detailed interpretation focusing on
biologically interesting and relevant signals, such as pathway enrichment.
- **Group 1 Criteria**: {formatted_criteria1}
- **Group 2 Criteria**: {formatted_criteria2}
- **Top Upregulated Genes for Group 1 (Downregulated for Group 2)**:
{formatted_genes1}
- **Top Downregulated Genes for Group 1 (Upregulated for Group 2)**:
{formatted_genes2}
Please include in your analysis any significant pathways, biological processes, or functional annotations that are enriched
in these top differentially expressed genes. Also, describe any notable patterns or trends that emerge from these results,
especially given the biological context of the query.
"""
return prompt.strip()

Check warning on line 83 in backend/de/api/utils.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/utils.py#L83

Added line #L83 was not covered by tests
19 changes: 19 additions & 0 deletions backend/de/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from backend.common.marker_gene_files.blacklist import marker_gene_blacklist
from backend.common.utils.rollup import descendants
from backend.de.api.utils import interpret_de_results
from backend.wmg.api.wmg_api_config import (
WMG_API_FORCE_LOAD_SNAPSHOT_ID,
WMG_API_READ_FS_CACHED_SNAPSHOT,
Expand Down Expand Up @@ -298,6 +299,24 @@ def run_differential_expression(q: WmgQuery, criteria1, criteria2) -> Tuple[List
return statistics, n_overlap


@tracer.wrap(name="interpretDeResults", service="de-api", resource="interpretDeResults", span_type="de-api")
def interpretDeResults():
request = connexion.request.json

Check warning on line 304 in backend/de/api/v1.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/v1.py#L304

Added line #L304 was not covered by tests

queryGroup1Filters = request["queryGroup1Filters"]
queryGroup2Filters = request["queryGroup2Filters"]
de_genes1 = request["deGenes1"]
de_genes2 = request["deGenes2"]

Check warning on line 309 in backend/de/api/v1.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/v1.py#L306-L309

Added lines #L306 - L309 were not covered by tests

criteria1 = DeQueryCriteria(**queryGroup1Filters)
criteria2 = DeQueryCriteria(**queryGroup2Filters)

Check warning on line 312 in backend/de/api/v1.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/v1.py#L311-L312

Added lines #L311 - L312 were not covered by tests

with ServerTiming.time("interpret differential expression results"):
message, prompt = interpret_de_results(criteria1, criteria2, de_genes1, de_genes2)

Check warning on line 315 in backend/de/api/v1.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/v1.py#L314-L315

Added lines #L314 - L315 were not covered by tests

return jsonify(dict(message=message, prompt=prompt))

Check warning on line 317 in backend/de/api/v1.py

View check run for this annotation

Codecov / codecov/patch

backend/de/api/v1.py#L317

Added line #L317 was not covered by tests


def _get_cell_counts_for_query(q: WmgQuery, criteria: WmgFiltersQueryCriteria) -> pd.DataFrame:
if criteria.cell_type_ontology_term_ids:
criteria.cell_type_ontology_term_ids = list(
Expand Down
Loading

0 comments on commit 15dd8fc

Please sign in to comment.