Run tasks in parallel
Introduction
In this tutorial, you will learn how to run tasks in parallel.
Load the AiiDA profile.
from aiida import load_profile
load_profile()
Profile<uuid='90da34ae855c481f9f23e4e2526238f1' name='presto'>
First workflow
Suppose we have a dictionary of values X and want to multiply each value by y. The multiplications do not depend on each other, so we want to run them in parallel.
Create task
To run tasks in parallel, we create a WorkGraph inside a graph-builder function and then use that WorkGraph as a task.
from aiida_workgraph import task, WorkGraph
# define multiply task
@task()
def multiply(x, y):
    return x * y
# Create a WorkGraph as a task
@task.graph_builder()
def multiply_parallel(X, y):
    wg = WorkGraph()
    # here the task `multiply` is created and will run in parallel
    for key, value in X.items():
        wg.add_task(multiply, name=f"multiply_{key}", x=value, y=y)
    return wg
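Outside AiiDA, the fan-out performed by `multiply_parallel` resembles submitting one function call per dictionary entry to a thread pool and then collecting the results. The sketch below uses only the standard library (`concurrent.futures`) to illustrate that pattern. It is an analogy, not how aiida-workgraph actually schedules its tasks, and the helper `multiply_all` is hypothetical; the `multiply_<key>` labels just mirror the task names used above.

```python
from concurrent.futures import ThreadPoolExecutor


def multiply(x, y):
    """Plain-Python stand-in for the `multiply` task."""
    return x * y


def multiply_all(X, y):
    """Hypothetical helper: run one `multiply` per entry of X concurrently.

    Results are keyed like the task names in the tutorial (multiply_<key>).
    """
    with ThreadPoolExecutor() as pool:
        # Submit every multiplication first so they can run concurrently.
        futures = {
            f"multiply_{key}": pool.submit(multiply, value, y)
            for key, value in X.items()
        }
        # Wait for each future and collect its result.
        return {name: future.result() for name, future in futures.items()}


X = {"a": 1, "b": 2, "c": 3}
results = multiply_all(X, 2)
print(results)  # {'multiply_a': 2, 'multiply_b': 4, 'multiply_c': 6}
```

Submitting all calls before waiting on any of them is what lets the independent multiplications overlap, which is the same property the WorkGraph relies on when tasks have no links between them.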
Create the workflow
from aiida_workgraph import WorkGraph
from aiida.orm import Int
X = {"a": Int(1), "b": Int(2), "c": Int(3)}
y = Int(2)
wg = WorkGraph("parallel_tasks")
multiply_parallel1 = wg.add_task(multiply_parallel, name="multiply_parallel1", X=X, y=y)
wg.submit(wait=True)
WorkGraph process created, PK: 323
Process 323 finished with state: FINISHED
<WorkGraphNode: uuid: e8bc8839-25c0-47c5-8b17-eb58e9349f2d (pk: 323) (aiida.workflows:workgraph.engine)>
Check the status and results
print("State of WorkGraph: {}".format(wg.state))
State of WorkGraph: FINISHED
Generate node graph from the AiiDA process:
from aiida_workgraph.utils import generate_node_graph
generate_node_graph(wg.pk)
Second workflow: gather results
Next, we want to gather the results of the parallel multiply tasks and calculate their sum. To do this, we update the multiply_parallel function into multiply_parallel_gather, which collects the results and exposes them as an output.
@task.graph_builder(outputs=[{"name": "result"}])
def multiply_parallel_gather(X, y):
    wg = WorkGraph()
    for key, value in X.items():
        multiply1 = wg.add_task(multiply, x=value, y=y)
        # store the result of multiply1 in the context under `mul.<key>`,
        # so that wg.ctx.mul becomes a dict {"a": value1, "b": value2, "c": value3}
        wg.update_ctx({f"mul.{key}": multiply1.outputs.result})
    # expose the gathered dict as the `result` output of this graph task
    wg.outputs.result = wg.ctx.mul
    return wg
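The comments in `multiply_parallel_gather` state that writing to the keys `mul.a`, `mul.b` and `mul.c` produces a nested dictionary `mul = {"a": ..., "b": ..., "c": ...}`. The sketch below reproduces that nesting rule in plain Python. `set_dotted` is a hypothetical helper written only for illustration; it is not part of the aiida-workgraph API.

```python
def set_dotted(ctx, dotted_key, value):
    """Hypothetical helper: store value under a dotted path such as 'mul.a'.

    Each segment before the last becomes (or reuses) a nested dict.
    """
    *parents, leaf = dotted_key.split(".")
    node = ctx
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


ctx = {}
for key, result in {"a": 2, "b": 4, "c": 6}.items():
    # Mirrors wg.update_ctx({f"mul.{key}": ...}) from the tutorial.
    set_dotted(ctx, f"mul.{key}", result)

print(ctx["mul"])  # {'a': 2, 'b': 4, 'c': 6}
```

Because every parallel branch writes to its own key under the shared `mul` prefix, the branches never overwrite each other, and the complete dictionary is available once all of them have finished.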
# the number of inputs is dynamic, so we must use variable keyword arguments (**datas)
@task()
def sum(**datas):
    from aiida.orm import Float
    total = 0
    for data in datas.values():
        total += data
    return Float(total)
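Because the number of gathered results depends on the size of X, `sum` accepts `**datas`. Judging from the final result of 12.0, each gathered entry appears to arrive as one keyword argument. The plain-Python version below shows the same mechanics, assuming ordinary numbers instead of AiiDA `Int`/`Float` nodes; it is renamed `sum_values` so it does not shadow the built-in `sum`.

```python
def sum_values(**datas):
    """Add up an arbitrary number of keyword arguments and return a float."""
    total = 0
    for data in datas.values():
        total += data
    return float(total)


# The multiply results for X = {"a": 1, "b": 2, "c": 3} and y = 2.
gathered = {"a": 2, "b": 4, "c": 6}
print(sum_values(**gathered))  # 12.0
```

Unpacking the dictionary with `**` turns each key into a parameter name, so the function works unchanged whether X has three entries or three hundred.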
Now, let’s create a WorkGraph to use the new task:
from aiida_workgraph import WorkGraph
from aiida.orm import Int
X = {"a": Int(1), "b": Int(2), "c": Int(3)}
y = Int(2)
wg = WorkGraph("parallel_tasks")
multiply_parallel_gather1 = wg.add_task(multiply_parallel_gather, X=X, y=y)
sum1 = wg.add_task(sum, name="sum1")
# link the gathered results to the input of sum1
wg.add_link(multiply_parallel_gather1.outputs[0], sum1.inputs[0])
wg.submit(wait=True)
WorkGraph process created, PK: 341
Process 341 finished with state: FINISHED
<WorkGraphNode: uuid: ce659e8a-15ca-4079-b05b-f97e4acf2df5 (pk: 341) (aiida.workflows:workgraph.engine)>
Get the result of the tasks:
print("State of WorkGraph: {}".format(wg.state))
print("Result of task sum1: {}".format(wg.tasks.sum1.outputs.result.value))
State of WorkGraph: FINISHED
Result of task sum1: uuid: 6b972a52-f1cb-4493-9687-0e9d8b61c47b (pk: 358) value: 12.0
Generate node graph from the AiiDA process:
from aiida_workgraph.utils import generate_node_graph
generate_node_graph(wg.pk)
You can see that the output of the multiply_parallel_gather workgraph is linked to the input of the sum task.