diff --git a/dask_jobqueue/__init__.py b/dask_jobqueue/__init__.py index dcee8ff8..55aa52bf 100644 --- a/dask_jobqueue/__init__.py +++ b/dask_jobqueue/__init__.py @@ -2,3 +2,4 @@ from .core import JobQueueCluster from .pbs import PBSCluster from .slurm import SLURMCluster +from .sge import SGECluster diff --git a/dask_jobqueue/core.py b/dask_jobqueue/core.py index 7806c937..f8ad52d1 100644 --- a/dask_jobqueue/core.py +++ b/dask_jobqueue/core.py @@ -4,6 +4,8 @@ import socket import os import sys +import shlex + import docrep from distributed.utils import tmpfile, ignoring, get_ip_interface, parse_bytes @@ -153,8 +155,8 @@ def start_workers(self, n=1): workers = [] for _ in range(n): with self.job_file() as fn: - out = self._call([self.submit_command, fn]) - job = out.decode().split('.')[0] + out = self._call(shlex.split(self.submit_command) + [fn]) + job = out.decode().split('.')[0].strip() self.jobs[self.n] = job workers.append(self.n) return workers diff --git a/dask_jobqueue/sge.py b/dask_jobqueue/sge.py new file mode 100644 index 00000000..1dadeb83 --- /dev/null +++ b/dask_jobqueue/sge.py @@ -0,0 +1,77 @@ +import logging + +from .core import JobQueueCluster, docstrings + +logger = logging.getLogger(__name__) + + +@docstrings.with_indent(4) +class SGECluster(JobQueueCluster): + """ Launch Dask on a SGE cluster + + Parameters + ---------- + queue : str + Destination queue for each worker job. Passed to `#$ -q` option. + project : str + Accounting string associated with each worker job. Passed to + `#$ -A` option. + resource_spec : str + Request resources and specify job placement. Passed to `#$ -l` + option. + walltime : str + Walltime for each worker job. + %(JobQueueCluster.parameters)s + + Examples + -------- + >>> from dask_jobqueue import SGECluster + >>> cluster = SGECluster(queue='regular') + >>> cluster.start_workers(10) # this may take a few seconds to launch + + >>> from dask.distributed import Client + >>> client = Client(cluster) + + This also works with adaptive clusters. This automatically launches and + kill workers based on load. + + >>> cluster.adapt() + """ + + #Override class variables + submit_command = 'qsub -terse' + cancel_command = 'qdel' + + def __init__(self, + queue=None, + project=None, + resource_spec=None, + walltime='0:30:00', + **kwargs): + + super(SGECluster, self).__init__(**kwargs) + + header_lines = ['#!/bin/bash'] + + if self.name is not None: + header_lines.append('#$ -N %(name)s') + if queue is not None: + header_lines.append('#$ -q %(queue)s') + if project is not None: + header_lines.append('#$ -P %(project)s') + if resource_spec is not None: + header_lines.append('#$ -l %(resource_spec)s') + if walltime is not None: + header_lines.append('#$ -l h_rt=%(walltime)s') + header_lines.extend(['#$ -cwd', '#$ -j y']) + header_template = '\n'.join(header_lines) + + config = {'name': self.name, + 'queue': queue, + 'project': project, + 'processes': self.worker_processes, + 'walltime': walltime, + 'resource_spec': resource_spec,} + self.job_header = header_template % config + + logger.debug("Job script: \n %s" % self.job_script()) diff --git a/dask_jobqueue/tests/test_sge.py b/dask_jobqueue/tests/test_sge.py index 15e984c8..ec40a8a1 100644 --- a/dask_jobqueue/tests/test_sge.py +++ b/dask_jobqueue/tests/test_sge.py @@ -1,9 +1,34 @@ +from time import time, sleep import pytest +from dask.distributed import Client +from distributed.utils_test import loop # noqa: F401 + +from dask_jobqueue import SGECluster + pytestmark = pytest.mark.env("sge") -def test_sge_placeholder(): - # to test that CI is working - pass +def test_basic(loop): # noqa: F811 + with SGECluster(walltime='00:02:00', threads=2, memory='7GB', + loop=loop) as cluster: + with Client(cluster, loop=loop) as client: + workers = cluster.start_workers(2) + future = client.submit(lambda x: x + 1, 10) + assert future.result(60) == 11 + assert cluster.jobs + + info = client.scheduler_info() + w = list(info['workers'].values())[0] + assert w['memory_limit'] == 7e9 + assert w['ncores'] == 2 + + cluster.stop_workers(workers) + + start = time() + while len(client.scheduler_info()['workers']) > 0: + sleep(0.100) + assert time() < start + 10 + + assert not cluster.jobs