Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,5 @@
[u'UT-Battelle, LLC'], 1)
]

intersphinx_mapping = {'python': ('https://docs.python.org/3', None)}
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
'distributed': ('http://distributed.dask.org/en/stable', None)}
45 changes: 45 additions & 0 deletions doc/user_guides/dask.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,49 @@ You batch script should then look like:
ips.py --config=ips.conf --platform=platform.conf


Running with worker plugin
--------------------------

You can register a
:class:`~distributed.diagnostics.plugin.WorkerPlugin` on the dask
workers by passing the ``dask_worker_plugin`` option to
:meth:`~ipsframework.services.ServicesProxy.submit_tasks`.

Using a WorkerPlugin in combination with shifter allows you to do
things like copying files out of the `Temporary XFS
<https://docs.nersc.gov/development/shifter/how-to-use/#temporary-xfs-files-for-optimizing-io>`_
file system. For example:

.. code-block:: python

import os
from distributed.diagnostics.plugin import WorkerPlugin
from ipsframework import Component

class DaskWorkerPlugin(WorkerPlugin):
def __init__(self, tmp_dir, target_dir):
self.tmp_dir = tmp_dir
self.target_dir = target_dir

def teardown(self, worker):
os.system(f"cp {self.tmp_dir}/* {self.target_dir}")

class Worker(Component):
def step(self, timestamp=0.0):
cwd = self.services.get_working_dir()

self.services.create_task_pool('pool')
self.services.add_task('pool', 'task_1', 1, '/tmp/', 'executable')

worker_plugin = DaskWorkerPlugin('/tmp', cwd)

ret_val = self.services.submit_tasks('pool',
use_dask=True, use_shifter=True,
dask_worker_plugin=worker_plugin)

exit_status = self.services.get_finished_tasks('pool')


where the batch script has the temporary XFS filesystem mounted as

.. code-block:: bash

#SBATCH --volume="/global/cscratch1/sd/$USER/tmpfiles:/tmp:perNodeCache=size=1G"
18 changes: 12 additions & 6 deletions ipsframework/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -1848,7 +1848,7 @@ def add_task(self, task_pool_name, task_name, nproc, working_dir,
*args, keywords=keywords)

def submit_tasks(self, task_pool_name, block=True, use_dask=False, dask_nodes=1,
dask_ppn=None, launch_interval=0.0, use_shifter=False):
dask_ppn=None, launch_interval=0.0, use_shifter=False, dask_worker_plugin=None):
"""
Launch all unfinished tasks in task pool *task_pool_name*. If *block* is ``True``,
return when all tasks have been launched. If *block* is ``False``, return when all
Expand All @@ -1860,7 +1860,7 @@ def submit_tasks(self, task_pool_name, block=True, use_dask=False, dask_nodes=1,
start_time = time.time()
self._send_monitor_event('IPS_TASK_POOL_BEGIN', 'task_pool = %s ' % task_pool_name)
task_pool: TaskPool = self.task_pools[task_pool_name]
retval = task_pool.submit_tasks(block, use_dask, dask_nodes, dask_ppn, launch_interval, use_shifter)
retval = task_pool.submit_tasks(block, use_dask, dask_nodes, dask_ppn, launch_interval, use_shifter, dask_worker_plugin)
elapsed_time = time.time() - start_time
self._send_monitor_event('IPS_TASK_POOL_END', 'task_pool = %s elapsed time = %.2f S' %
(task_pool_name, elapsed_time),
Expand Down Expand Up @@ -2066,7 +2066,7 @@ def add_task(self, task_name, nproc, working_dir, binary, *args, **keywords):
self.queued_tasks[task_name] = Task(task_name, nproc, working_dir, binary_fullpath, *args,
**keywords["keywords"])

def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter=False):
def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter=False, dask_worker_plugin=None):
"""Launch tasks in *queued_tasks* using dask.

:param block: Unused, this will always return after tasks are submitted
Expand All @@ -2077,6 +2077,8 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
:type dask_ppn: int
:param use_shifter: Option to launch dask scheduler and workers in shifter container
:type use_shifter: bool
:param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client
:type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin
"""
services: ServicesProxy = self.services
self.dask_file_name = os.path.join(os.getcwd(),
Expand Down Expand Up @@ -2115,6 +2117,9 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter

self.dask_client = self.dask.distributed.Client(scheduler_file=self.dask_file_name)

if dask_worker_plugin is not None:
self.dask_client.register_worker_plugin(dask_worker_plugin)

try:
self.worker_event_logfile = services.sim_name + '_' + services.get_config_param("PORTAL_RUNID") + '_' + self.name + '_{}.json'
except KeyError:
Expand All @@ -2135,7 +2140,7 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
self.queued_tasks = {}
return len(self.futures)

def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None, launch_interval=0.0, use_shifter=False):
def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None, launch_interval=0.0, use_shifter=False, dask_worker_plugin=None):
"""Launch tasks in *queued_tasks*. Finished tasks are handled before
launching new ones. If *block* is ``True``, the number of
tasks submitted is returned after all tasks have been launched
Expand All @@ -2157,7 +2162,8 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None,
:type launch_interval: float
:param use_shifter: Option to launch dask scheduler and workers in shifter container
:type use_shifter: bool

:param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client
:type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin
"""

if use_dask:
Expand All @@ -2167,7 +2173,7 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None,
self.services.error("Requested to run dask within shifter but shifter not available")
raise Exception("shifter not found")
else:
return self.submit_dask_tasks(block, dask_nodes, dask_ppn, use_shifter)
return self.submit_dask_tasks(block, dask_nodes, dask_ppn, use_shifter, dask_worker_plugin)
elif not TaskPool.dask:
self.services.warning("Requested use_dask but cannot because import dask failed")
elif not self.serial_pool:
Expand Down
14 changes: 13 additions & 1 deletion tests/helloworld/hello_worker_task_pool_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# -------------------------------------------------------------------------------
from time import sleep
import copy
from distributed.diagnostics.plugin import WorkerPlugin
from ipsframework import Component


Expand All @@ -12,6 +13,14 @@ def myFun(*args):
return 0


class DaskWorkerPlugin(WorkerPlugin):
    """Worker plugin that prints a marker line from each lifecycle hook.

    It keeps no state; the printed lines only make it visible in the
    program output that the ``setup`` and ``teardown`` hooks were invoked.
    """

    def setup(self, worker):
        """Print the setup marker line; *worker* is not used."""
        print("Running setup of worker")

    def teardown(self, worker):
        """Print the teardown marker line; *worker* is not used."""
        print("Running teardown of worker")


class HelloWorker(Component):
def __init__(self, services, config):
super().__init__(services, config)
Expand All @@ -32,7 +41,10 @@ def step(self, timestamp=0.0, **keywords):
self.services.add_task('pool', 'func_' + str(i), 1,
cwd, myFun, duration)

ret_val = self.services.submit_tasks('pool', use_dask=True, dask_nodes=1, dask_ppn=10)
worker_plugin = DaskWorkerPlugin()

ret_val = self.services.submit_tasks('pool', use_dask=True, dask_nodes=1, dask_ppn=10,
dask_worker_plugin=worker_plugin)
print('ret_val = ', ret_val)
exit_status = self.services.get_finished_tasks('pool')
print(exit_status)
Expand Down
4 changes: 4 additions & 0 deletions tests/helloworld/test_helloworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ def test_helloworld_task_pool_dask(tmpdir, capfd):
assert captured_out[4] == 'HelloDriver: finished worker init call'
assert captured_out[5] == 'HelloDriver: beginning step call'
assert captured_out[6] == 'Hello from HelloWorker'

assert "Running setup of worker" in captured_out
assert "Running teardown of worker" in captured_out

assert 'ret_val = 9' in captured_out

for duration in ("0.2", "0.4", "0.6"):
Expand Down