Skip to content

Commit

Permalink
Add typing to src/aiida/workflows/arithmetic/multiply_add.py
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhollas committed Feb 5, 2025
1 parent c975b80 commit d9db4e2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ repos:
src/aiida/transports/cli.py|
src/aiida/transports/plugins/local.py|
src/aiida/transports/plugins/ssh.py|
src/aiida/workflows/arithmetic/multiply_add.py|
)$
Expand Down Expand Up @@ -323,7 +322,6 @@ repos:
src/aiida/transports/cli.py|
src/aiida/transports/plugins/local.py|
src/aiida/transports/plugins/ssh.py|
src/aiida/workflows/arithmetic/multiply_add.py|
)$
- id: generate-conda-environment
Expand Down
21 changes: 14 additions & 7 deletions src/aiida/workflows/arithmetic/multiply_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@
# start-marker for docs
"""Implementation of the MultiplyAddWorkChain for testing and demonstration purposes."""

from __future__ import annotations

from typing import TypeVar

from aiida.engine import ToContext, WorkChain, calcfunction
from aiida.orm import AbstractCode, Int
from aiida.orm import AbstractCode, Int, ProcessNode
from aiida.plugins.factories import CalculationFactory

ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add')

T = TypeVar('T')


@calcfunction
def multiply(x, y):
Expand All @@ -25,7 +31,7 @@ class MultiplyAddWorkChain(WorkChain):
"""WorkChain to multiply two numbers and add a third, for testing and demonstration purposes."""

@classmethod
def define(cls, spec):
def define(cls, spec) -> None:
"""Specify inputs and outputs."""
super().define(spec)
spec.input('x', valid_type=Int)
Expand All @@ -41,24 +47,25 @@ def define(cls, spec):
spec.output('result', valid_type=Int)
spec.exit_code(400, 'ERROR_NEGATIVE_NUMBER', message='The result is a negative number.')

def multiply(self):
def multiply(self) -> None:
"""Multiply two integers."""
self.ctx.product = multiply(self.inputs.x, self.inputs.y)

def add(self):
def add(self) -> dict[str, ProcessNode]:
"""Add two numbers using the `ArithmeticAddCalculation` calculation job plugin."""
inputs = {'x': self.ctx.product, 'y': self.inputs.z, 'code': self.inputs.code}
future = self.submit(ArithmeticAddCalculation, **inputs)
future = self.submit(ArithmeticAddCalculation, **inputs) # type: ignore[arg-type]
self.report(f'Submitted the `ArithmeticAddCalculation`: {future}')
return ToContext(addition=future)

def validate_result(self):
def validate_result(self) -> int | None:
"""Make sure the result is not negative."""
result = self.ctx.addition.outputs.sum

if result.value < 0:
return self.exit_codes.ERROR_NEGATIVE_NUMBER
return None

def result(self):
def result(self) -> None:
"""Add the result to the outputs."""
self.out('result', self.ctx.addition.outputs.sum)

0 comments on commit d9db4e2

Please sign in to comment.