Experiment generating kinds concurrently #149

Draft · wants to merge 3 commits into main
78 changes: 58 additions & 20 deletions src/taskgraph/generator.py
@@ -2,6 +2,7 @@
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
+import asyncio
 import copy
 import logging
 import os
@@ -44,7 +45,8 @@ def _get_loader(self):
             loader = "taskgraph.loader.default:loader"
         return find_object(loader)
 
-    def load_tasks(self, parameters, loaded_tasks, write_artifacts):
+    async def load_tasks(self, parameters, loaded_tasks, write_artifacts):
+        logger.debug(f"Loading tasks for kind {self.name}")
         loader = self._get_loader()
         config = copy.deepcopy(self.config)
 
@@ -85,8 +87,9 @@ def load_tasks(self, parameters, loaded_tasks, write_artifacts):
                 soft_dependencies=task_dict.get("soft-dependencies"),
                 if_dependencies=task_dict.get("if-dependencies"),
             )
-            for task_dict in transforms(trans_config, inputs)
+            async for task_dict in await transforms(trans_config, inputs)
         ]
+        logger.info(f"Generated {len(tasks)} tasks for kind {self.name}")
         return tasks
 
     @classmethod
@@ -249,6 +252,57 @@ def _load_kinds(self, graph_config, target_kinds=None):
             except KindNotFound:
                 continue
 
+    async def _load_tasks(self, kinds, kind_graph, parameters):
+        all_tasks = {}
+        futures_to_kind = {}
+
+        def add_new_tasks(tasks):
+            for task in tasks:
+                if task.label in all_tasks:
+                    raise Exception("duplicate tasks with label " + task.label)
+                all_tasks[task.label] = task
+
+        def create_futures(kinds, edges):
+            """Create the next batch of tasks for kinds without dependencies."""
+            kinds_with_deps = {edge[0] for edge in edges}
+            ready_kinds = set(kinds) - kinds_with_deps
+            futures = set()
+            for name in ready_kinds:
+                task = asyncio.create_task(
+                    kinds[name].load_tasks(
+                        parameters,
+                        list(all_tasks.values()),
+                        self._write_artifacts,
+                    )
+                )
+                futures.add(task)
+                futures_to_kind[task] = name
+            return futures
+
+        edges = set(kind_graph.edges)
+        futures = create_futures(kinds, edges)
+        while len(kinds) > 0:
+            done, futures = await asyncio.wait(
+                futures, return_when=asyncio.FIRST_COMPLETED
+            )
+
+            for future in done:
+                add_new_tasks(future.result())
+                name = futures_to_kind[future]
+
+                # Update state for next batch of futures.
+                del kinds[name]
+                edges = {e for e in edges if e[1] != name}
+
+            futures |= create_futures(kinds, edges)
+
+        if futures:
+            done, _ = await asyncio.wait(futures, return_when=asyncio.ALL_COMPLETED)
+            for future in done:
+                add_new_tasks(future.result())
+
+        return all_tasks
+
     def _run(self):
         logger.info("Loading graph configuration.")
         graph_config = load_graph_config(self.root_dir)
@@ -303,24 +357,8 @@ def _run(self):
         )
 
         logger.info("Generating full task set")
-        all_tasks = {}
-        for kind_name in kind_graph.visit_postorder():
-            logger.debug(f"Loading tasks for kind {kind_name}")
-            kind = kinds[kind_name]
-            try:
-                new_tasks = kind.load_tasks(
-                    parameters,
-                    list(all_tasks.values()),
-                    self._write_artifacts,
-                )
-            except Exception:
-                logger.exception(f"Error loading tasks for kind {kind_name}:")
-                raise
-            for task in new_tasks:
-                if task.label in all_tasks:
-                    raise Exception("duplicate tasks with label " + task.label)
-                all_tasks[task.label] = task
-            logger.info(f"Generated {len(new_tasks)} tasks for kind {kind_name}")
+        all_tasks = asyncio.run(self._load_tasks(kinds, kind_graph, parameters))
 
         full_task_set = TaskGraph(all_tasks, Graph(set(all_tasks), set()))
         yield self.verify("full_task_set", full_task_set, graph_config, parameters)
 
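The core of this change is the scheduler in `_load_tasks`: kinds whose dependencies are already loaded run concurrently, and each completion can unlock further kinds. Below is a minimal, self-contained sketch of that pattern (the kind names, `load_kind` stub, and `load_all` helper are invented for illustration; they are not part of the patch). One detail the sketch adds: a `scheduled` set, so a kind still in flight is not launched a second time when an unrelated kind completes.

```python
import asyncio

# Toy dependency graph: kind -> set of kinds it depends on (names invented).
KINDS = {
    "fetch": set(),
    "toolchain": {"fetch"},
    "build": {"toolchain"},
    "lint": set(),
}

async def load_kind(name):
    # Stand-in for Kind.load_tasks(): pretend to do some I/O.
    await asyncio.sleep(0.01)
    return [f"{name}-task"]

async def load_all(kinds):
    results = {}
    pending = {name: set(deps) for name, deps in kinds.items()}
    scheduled = set()
    futures_to_kind = {}

    def schedule():
        # Launch every kind whose dependencies are satisfied and which has
        # not already been launched (it may still be in flight).
        new = set()
        for name, deps in pending.items():
            if not deps and name not in scheduled:
                fut = asyncio.create_task(load_kind(name))
                futures_to_kind[fut] = name
                scheduled.add(name)
                new.add(fut)
        return new

    futures = schedule()
    while pending:
        # As soon as any kind finishes, its dependents may become ready.
        done, futures = await asyncio.wait(
            futures, return_when=asyncio.FIRST_COMPLETED
        )
        for fut in done:
            name = futures_to_kind[fut]
            results[name] = fut.result()
            del pending[name]
            for deps in pending.values():
                deps.discard(name)
        futures |= schedule()
    return results

print(asyncio.run(load_all(KINDS)))
```

Here `fetch` and `lint` load in parallel immediately, and `toolchain` starts the moment `fetch` finishes, without waiting for `lint`.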
30 changes: 27 additions & 3 deletions src/taskgraph/transforms/base.py
@@ -2,7 +2,7 @@
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
-
+import inspect
 import re
 from dataclasses import dataclass, field
 from typing import Dict, List, Union
@@ -107,6 +107,12 @@ def repo_configs(self):
         return repo_configs
 
 
+async def convert_async(it):
+    """Convert a synchronous iterator to an async one."""
+    for i in it:
+        yield i
+
+
 @dataclass()
 class TransformSequence:
     """
@@ -121,11 +127,29 @@ class TransformSequence:
 
     _transforms: List = field(default_factory=list)
 
-    def __call__(self, config, items):
+    async def __call__(self, config, items):
         for xform in self._transforms:
-            items = xform(config, items)
+            if isinstance(xform, TransformSequence):
+                items = await xform(config, items)
+            elif inspect.isasyncgenfunction(xform):
+                # Async generator transforms require async generator inputs.
+                # This can happen if a synchronous transform ran immediately
+                # prior.
+                if not inspect.isasyncgen(items):
+                    items = convert_async(items)
+                items = xform(config, items)
+            else:
+                # Creating a synchronous generator from an asynchronous context
+                # doesn't appear possible, so unfortunately we need to convert
+                # to a list.
+                if inspect.isasyncgen(items):
+                    items = [i async for i in items]
+                items = xform(config, items)
             if items is None:
                 raise Exception(f"Transform {xform} is not a generator")
+
+        if not inspect.isasyncgen(items):
+            items = convert_async(items)
         return items
 
     def add(self, func):
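To see how the new `__call__` dispatch composes synchronous and asynchronous transforms, here is a hedged, runnable approximation (`double`, `exclaim`, and `run` are toy stand-ins; `run` mirrors the branching above but is not taskgraph code):

```python
import asyncio
import inspect

async def convert_async(it):
    # Wrap a synchronous iterable as an async generator.
    for i in it:
        yield i

def double(config, items):
    # A plain synchronous generator transform.
    for i in items:
        yield i * 2

async def exclaim(config, items):
    # An async generator transform.
    async for i in items:
        yield f"{i}!"

async def run(transforms, config, items):
    # Simplified mirror of TransformSequence.__call__ above.
    for xform in transforms:
        if inspect.isasyncgenfunction(xform):
            if not inspect.isasyncgen(items):
                items = convert_async(items)
        elif inspect.isasyncgen(items):
            # Sync transforms can't consume async generators, so drain first.
            items = [i async for i in items]
        items = xform(config, items)
    if not inspect.isasyncgen(items):
        items = convert_async(items)
    return items

async def main():
    items = await run([double, exclaim, double], None, [1, 2, 3])
    print([i async for i in items])

asyncio.run(main())  # prints ['2!2!', '4!4!', '6!6!']
```

The final `convert_async` guarantees callers can always `async for` over the result, which is what `Kind.load_tasks` relies on in generator.py.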
4 changes: 2 additions & 2 deletions src/taskgraph/transforms/cached_tasks.py
@@ -50,7 +50,7 @@ def format_task_digest(cached_task):
 
 
 @transforms.add
-def cache_task(config, tasks):
+async def cache_task(config, tasks):
     if taskgraph.fast:
         for task in tasks:
             yield task
@@ -61,7 +61,7 @@ def cache_task(config, tasks):
         if "cached_task" in task.attributes:
             digests[task.label] = format_task_digest(task.attributes["cached_task"])
 
-    for task in order_tasks(config, tasks):
+    for task in order_tasks(config, [t async for t in tasks]):
         cache = task.pop("cache", None)
         if cache is None:
             yield task
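The `[t async for t in tasks]` here exists because `order_tasks` needs the whole task set before it can order anything, so the stream cannot stay lazy. A small sketch of that drain-then-sort pattern (the `producer` and `order_then_yield` names and task dicts are invented, not taskgraph's `order_tasks`):

```python
import asyncio

async def producer():
    for i in (3, 1, 2):
        yield {"label": f"task-{i}", "rank": i}

async def order_then_yield(tasks):
    # Whole-set transforms must drain the async stream before sorting.
    materialized = [t async for t in tasks]
    for task in sorted(materialized, key=lambda t: t["rank"]):
        yield task

async def main():
    async for task in order_then_yield(producer()):
        print(task["label"])  # task-1, task-2, task-3

asyncio.run(main())
```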
4 changes: 2 additions & 2 deletions src/taskgraph/transforms/chunking.py
@@ -50,8 +50,8 @@
 
 
 @transforms.add
-def chunk_tasks(config, tasks):
-    for task in tasks:
+async def chunk_tasks(config, tasks):
+    async for task in tasks:
         chunk_config = task.pop("chunk", None)
         if not chunk_config:
             yield task
4 changes: 2 additions & 2 deletions src/taskgraph/transforms/code_review.py
@@ -12,8 +12,8 @@
 
 
 @transforms.add
-def add_dependencies(config, jobs):
-    for job in jobs:
+async def add_dependencies(config, jobs):
+    async for job in jobs:
         job.setdefault("soft-dependencies", [])
         job["soft-dependencies"] += [
             dep_task.label
6 changes: 2 additions & 4 deletions src/taskgraph/transforms/docker_image.py
@@ -65,7 +65,7 @@
 
 
 @transforms.add
-def fill_template(config, tasks):
+async def fill_template(config, tasks):
     available_packages = set()
     for task in config.kind_dependencies_tasks.values():
         if task.kind != "packages":
@@ -75,13 +75,11 @@ def fill_template(config, tasks):
 
     context_hashes = {}
 
-    tasks = list(tasks)
-
     if not taskgraph.fast and config.write_artifacts:
         if not os.path.isdir(CONTEXTS_DIR):
             os.makedirs(CONTEXTS_DIR)
 
-    for task in tasks:
+    async for task in tasks:
         image_name = task.pop("name")
         job_symbol = task.pop("symbol", None)
         args = task.pop("args", {})
8 changes: 4 additions & 4 deletions src/taskgraph/transforms/fetch.py
@@ -78,9 +78,9 @@ def wrap(func):
 
 
 @transforms.add
-def process_fetch_job(config, jobs):
+async def process_fetch_job(config, jobs):
     # Converts fetch-url entries to the job schema.
-    for job in jobs:
+    async for job in jobs:
         typ = job["fetch"]["type"]
         name = job["name"]
         fetch = job.pop("fetch")
@@ -103,15 +103,15 @@ def configure_fetch(config, typ, name, fetch):
 
 
 @transforms.add
-def make_task(config, jobs):
+async def make_task(config, jobs):
     # Fetch tasks are idempotent and immutable. Have them live for
     # essentially forever.
     if config.params["level"] == "3":
         expires = "1000 years"
     else:
         expires = "28 days"
 
-    for job in jobs:
+    async for job in jobs:
         name = job["name"]
         artifact_prefix = job.get("artifact-prefix", "public")
         env = job.get("env", {})
4 changes: 2 additions & 2 deletions src/taskgraph/transforms/from_deps.py
@@ -113,8 +113,8 @@
 
 
 @transforms.add
-def from_deps(config, tasks):
-    for task in tasks:
+async def from_deps(config, tasks):
+    async for task in tasks:
         # Setup and error handling.
         from_deps = task.pop("from-deps")
         kind_deps = config.config.get("kind-dependencies", [])
24 changes: 12 additions & 12 deletions src/taskgraph/transforms/job/__init__.py
@@ -112,8 +112,8 @@
 
 
 @transforms.add
-def rewrite_when_to_optimization(config, jobs):
-    for job in jobs:
+async def rewrite_when_to_optimization(config, jobs):
+    async for job in jobs:
         when = job.pop("when", {})
         if not when:
             yield job
@@ -132,8 +132,8 @@ def rewrite_when_to_optimization(config, jobs):
 
 
 @transforms.add
-def set_implementation(config, jobs):
-    for job in jobs:
+async def set_implementation(config, jobs):
+    async for job in jobs:
         impl, os = worker_type_implementation(config.graph_config, job["worker-type"])
         if os:
             job.setdefault("tags", {})["os"] = os
@@ -148,8 +148,8 @@ def set_implementation(config, jobs):
 
 
 @transforms.add
-def set_label(config, jobs):
-    for job in jobs:
+async def set_label(config, jobs):
+    async for job in jobs:
         if "label" not in job:
             if "name" not in job:
                 raise Exception("job has neither a name nor a label")
@@ -160,8 +160,8 @@ def set_label(config, jobs):
 
 
 @transforms.add
-def add_resource_monitor(config, jobs):
-    for job in jobs:
+async def add_resource_monitor(config, jobs):
+    async for job in jobs:
         if job.get("attributes", {}).get("resource-monitor"):
             worker_implementation, worker_os = worker_type_implementation(
                 config.graph_config, job["worker-type"]
@@ -204,13 +204,13 @@ def get_attribute(dict, key, attributes, attribute_name):
 
 
 @transforms.add
-def use_fetches(config, jobs):
+async def use_fetches(config, jobs):
     artifact_names = {}
     aliases = {}
     extra_env = {}
 
+    jobs = [j async for j in jobs]
     if config.kind in ("toolchain", "fetch"):
-        jobs = list(jobs)
         for job in jobs:
             run = job.get("run", {})
             label = job["label"]
@@ -353,12 +353,12 @@ def cmp_artifacts(a):
 
 
 @transforms.add
-def make_task_description(config, jobs):
+async def make_task_description(config, jobs):
     """Given a build description, create a task description"""
     # import plugin modules first, before iterating over jobs
     import_sibling_modules(exceptions=("common.py",))
 
-    for job in jobs:
+    async for job in jobs:
         # always-optimized tasks never execute, so have no workdir
         if job["worker"]["implementation"] in ("docker-worker", "generic-worker"):
             job["run"].setdefault("workdir", "/builds/worker")
4 changes: 2 additions & 2 deletions src/taskgraph/transforms/notify.py
@@ -140,8 +140,8 @@ def _convert_content(content):
 
 
 @transforms.add
-def add_notifications(config, tasks):
-    for task in tasks:
+async def add_notifications(config, tasks):
+    async for task in tasks:
         label = "{}-{}".format(config.kind, task["name"])
         if "notifications" in task:
             notify = _convert_legacy(config, task.pop("notifications"), label)