Skip to content

Commit

Permalink
add bulk download input loader (#277)
Browse files Browse the repository at this point in the history
* add bulk download input loader initial

* typing

* linting

* fix list

* change bd paths

* remove print

* rename intermediate output reference

* fix type

* message status breaking

* fix status messsage

* update policies

* add prints

* add zip handler

* remove outputs

* ignore type

* linting

* if file_format is .txt treat like a fasta

* lint
  • Loading branch information
rzlim08 authored Apr 4, 2024
1 parent d643dd0 commit ca86f15
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 8 deletions.
12 changes: 11 additions & 1 deletion platformics/support/format_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from Bio import SeqIO
from typing import Protocol


from mypy_boto3_s3 import S3Client


Expand Down Expand Up @@ -101,18 +102,27 @@ class JsonHandler(FileFormatHandler):
def validate(self) -> None:
json.loads(self.contents()) # throws an exception for invalid JSON

class ZipHandler(FileFormatHandler):
"""
Validate ZIP files
"""

def validate(self) -> None:
assert self.key.endswith(".zip") # throws an exception if the file is not a zip file

def get_validator(format: str) -> type[FileFormatHandler]:
"""
Returns the validator for a given file format
"""
if format == "fasta":
if format in ["fa", "fasta"]:
return FastaHandler
elif format == "fastq":
return FastqHandler
elif format == "bed":
return BedHandler
elif format == "json":
return JsonHandler
elif format == "zip":
return ZipHandler
else:
raise Exception(f"Unknown file format '{format}'")
3 changes: 2 additions & 1 deletion workflows/.happy/terraform/envs/dev/iam_policies.tf
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ data "aws_iam_policy_document" "workflows" {
"states:StartExecution"
]
resources = [
"arn:aws:states:us-west-2:${var.aws_account_id}:stateMachine:idseq-swipe-dev-default-wdl"
"arn:aws:states:us-west-2:${var.aws_account_id}:stateMachine:idseq-swipe-dev-default-wdl",
"arn:aws:states:us-west-2:${var.aws_account_id}:execution:idseq-swipe-dev-default-wdl:*"
]
}
statement {
Expand Down
3 changes: 2 additions & 1 deletion workflows/.happy/terraform/envs/sandbox/iam_policies.tf
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ data "aws_iam_policy_document" "workflows" {
"states:StartExecution"
]
resources = [
"arn:aws:states:us-west-2:${var.aws_account_id}:stateMachine:idseq-swipe-sandbox-default-wdl"
"arn:aws:states:us-west-2:${var.aws_account_id}:stateMachine:idseq-swipe-sandbox-default-wdl",
"arn:aws:states:us-west-2:${var.aws_account_id}:execution:idseq-swipe-sandbox-default-wdl:*"
]
}
statement {
Expand Down
3 changes: 2 additions & 1 deletion workflows/.happy/terraform/envs/staging/iam_policies.tf
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ data "aws_iam_policy_document" "workflows" {
"states:StartExecution"
]
resources = [
"arn:aws:states:us-west-2:${var.aws_account_id}:stateMachine:idseq-swipe-staging-default-wdl"
"arn:aws:states:us-west-2:${var.aws_account_id}:stateMachine:idseq-swipe-staging-default-wdl",
"arn:aws:states:us-west-2:${var.aws_account_id}:execution:idseq-swipe-staging-default-wdl:*"
]
}
statement {
Expand Down
2 changes: 1 addition & 1 deletion workflows/plugins/event_bus/swipe/event_bus_swipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _parse_message(self, message: dict) -> WorkflowStatusMessage | None:
# TODO: handle aws.batch for step statuses
if not message.get("source") == "aws.states":
return None
status = self._create_workflow_status(message["status"])
status = self._create_workflow_status(message["detail"]["status"])
execution_arn = message["detail"]["executionArn"]
if status == "WORKFLOW_SUCCESS":
return WorkflowSucceededMessage(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from sgqlc.operation import Operation

from database.models.workflow_version import WorkflowVersion
from manifest.manifest import EntityInput, Primitive
from platformics.client.entities_schema import (
ConsensusGenomeWhereClause,
Query,
UUIDComparators,
)
from platformics.util.types_utils import JSONValue
from plugins.plugin_types import InputLoader

PUBLIC_REFERENCES_PREFIX = "s3://czid-public-references/consensus-genome"
CG_BULK_DOWNLOAD_OUTPUT = "consensus_genome_intermediate_output_files"
CG_BULK_DOWNLOAD_CONSENSUS = "consensus_genome"
CG_BULK_DOWNLOADS = [CG_BULK_DOWNLOAD_CONSENSUS, CG_BULK_DOWNLOAD_OUTPUT]


class BulkDownloadInputLoader(InputLoader):
async def load(
self,
workflow_version: WorkflowVersion,
entity_inputs: dict[str, EntityInput | list[EntityInput]],
raw_inputs: dict[str, Primitive | list[Primitive]],
requested_outputs: list[str] = [],
) -> dict[str, JSONValue]:
inputs: dict[str, JSONValue] = {}
if raw_inputs.get("bulk_download_type") in CG_BULK_DOWNLOADS:
consensus_genome_input = entity_inputs["consensus_genomes"]
op = Operation(Query)
if isinstance(consensus_genome_input, EntityInput):
# if single input
consensus_genome = op.consensus_genomes(
where=ConsensusGenomeWhereClause(id=UUIDComparators(_eq=consensus_genome_input.entity_id))
)
else:
# must be list of inputs
consensus_genome = op.consensus_genomes(
where=ConsensusGenomeWhereClause(
id=UUIDComparators(_in=[cg.entity_id for cg in consensus_genome_input])
)
)
consensus_genome.sequencing_read()
consensus_genome.sequencing_read.sample()
consensus_genome.sequencing_read.sample.id()
consensus_genome.sequencing_read.sample.name()
consensus_genome.accession()
consensus_genome.accession.accession_id()
if raw_inputs.get("bulk_download_type") == CG_BULK_DOWNLOAD_OUTPUT:
self._fetch_file(consensus_genome.intermediate_outputs())
elif raw_inputs.get("bulk_download_type") == CG_BULK_DOWNLOAD_CONSENSUS:
self._fetch_file(consensus_genome.sequence())
res = self._entities_gql(op)
files: list[dict[str, Primitive | None]] = []
for cg_res in res["consensusGenomes"]:
sample_name = f"{cg_res['sequencingRead']['sample']['name']}"
sample_id = f"{cg_res['sequencingRead']['sample']['id']}"
if cg_res["accession"]:
accession = f"{cg_res['accession']['accessionId']}"
output_name = f"{sample_name}_{sample_id}_{accession}"
else:
output_name = f"{sample_name}_{sample_id}"

if raw_inputs.get("bulk_download_type") == CG_BULK_DOWNLOAD_OUTPUT:
download_link = self._uri_file(cg_res["intermediateOutputs"])
suffix = ".zip"
elif raw_inputs.get("bulk_download_type") == CG_BULK_DOWNLOAD_CONSENSUS:
download_link = self._uri_file(cg_res["sequence"])
suffix = ".fa"
files.append(
{
"output_name": output_name + suffix,
"file_path": download_link,
}
)
inputs["files"] = files # type: ignore
return inputs
3 changes: 2 additions & 1 deletion workflows/plugins/input_loaders/czid_workflows/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
long_description="",
long_description_content_type="text/markdown",
author="Todd Morse",
py_modules=["consensus_genome_input"],
py_modules=["consensus_genome_input", "bulk_download_input"],
python_requires=">=3.6",
setup_requires=[],
reentry_register=True,
entry_points={
"czid.plugin.input_loader": [
"consensus_genome = consensus_genome_input:ConsensusGenomeInputLoader",
"bulk_download = bulk_download_input:BulkDownloadInputLoader",
],
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ async def load(
file_path = workflow_outputs["file"]
assert isinstance(file_path, str)

file_format = file_path.split(".")[-1]
# if file_path ends with .txt, change file_format to fasta
file_format = "fasta" if file_format == "txt" else file_format

bulk_download = op.create_bulk_download(
input=BulkDownloadCreateInput(
producing_run_id=ID(workflow_run.id),
Expand All @@ -43,7 +47,7 @@ async def load(
file = op.create_file(
entity_id=bulk_download_id,
entity_field_name="file",
file=FileCreate(name="file", file_format="fasta", **self._parse_uri(file_path)),
file=FileCreate(name="file", file_format=file_format, **self._parse_uri(file_path)),
)
file.id()
self._entities_gql(op)
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def load(
reference_genome_id: ID | None = None
if reference_genome_input:
assert isinstance(reference_genome_input, EntityInput)
reference_genome_id = ID(reference_genome_input.entity_id)
reference_genome_id = ID(reference_genome_input.entity_id) # type: ignore

consensus_genome = op.create_consensus_genome(
input=ConsensusGenomeCreateInput(
Expand Down

0 comments on commit ca86f15

Please sign in to comment.