Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 66 additions & 5 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@
import pytest

import dask_jobqueue.lsf
import dask

from dask_jobqueue import (
PBSCluster,
MoabCluster,
SLURMCluster,
SGECluster,
LSFCluster,
OARCluster,
HTCondorCluster,
)

from dask_jobqueue.local import LocalCluster


def pytest_addoption(parser):
Expand All @@ -18,19 +31,35 @@ def pytest_addoption(parser):


def pytest_configure(config):
# register an additional marker
# register additional markers
config.addinivalue_line(
"markers", "env(NAME): only run test if environment NAME matches"
)
config.addinivalue_line(
"markers", "xfail_env(NAME): known failure for environment NAME"
)


def pytest_runtest_setup(item):
envnames = [mark.args[0] for mark in item.iter_markers(name="env")]
if (item.config.getoption("-E") is None and envnames) or (
item.config.getoption("-E") is not None
and item.config.getoption("-E") not in envnames
env = item.config.getoption("-E")
envnames = sum(
[
mark.args[0] if isinstance(mark.args[0], list) else [mark.args[0]]
for mark in item.iter_markers(name="env")
],
[],
)
if (
None not in envnames
and (env is None and envnames)
or (env is not None and env not in envnames)
):
pytest.skip("test requires env in %r" % envnames)
else:
xfail = {}
[xfail.update(mark.args[0]) for mark in item.iter_markers(name="xfail_env")]
if env in xfail:
item.add_marker(pytest.mark.xfail(reason=xfail[env]))


@pytest.fixture(autouse=True)
Expand All @@ -46,3 +75,35 @@ def mock_lsf_version(monkeypatch, request):
except OSError:
# Provide a fake implementation of lsf_version()
monkeypatch.setattr(dask_jobqueue.lsf, "lsf_version", lambda: "10")


all_envs = {
None: LocalCluster,
"pbs": PBSCluster,
"moab": MoabCluster,
"slurm": SLURMCluster,
"sge": SGECluster,
"lsf": LSFCluster,
"oar": OARCluster,
"htcondor": HTCondorCluster,
}


@pytest.fixture(
params=[pytest.param(v, marks=[pytest.mark.env(k)]) for (k, v) in all_envs.items()]
)
def EnvSpecificCluster(request):
"""Run test only with the specific cluster class set by the environment"""
if request.param == HTCondorCluster:
# HTCondor requires explicitly specifying requested disk space
dask.config.set({"jobqueue.htcondor.disk": "1GB"})
return request.param


@pytest.fixture(params=list(all_envs.values()))
def Cluster(request):
Comment thread
guillaumeeb marked this conversation as resolved.
"""Run test for each cluster class when no environment is set (test should not require the actual scheduler)"""
if request.param == HTCondorCluster:
# HTCondor requires explicitly specifying requested disk space
dask.config.set({"jobqueue.htcondor.disk": "1GB"})
Comment thread
guillaumeeb marked this conversation as resolved.
return request.param
62 changes: 12 additions & 50 deletions dask_jobqueue/tests/test_job.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,24 @@
import asyncio
from time import time

from dask_jobqueue import (
PBSCluster,
SGECluster,
SLURMCluster,
LSFCluster,
HTCondorCluster,
MoabCluster,
OARCluster,
)
from dask_jobqueue.local import LocalJob, LocalCluster
from dask_jobqueue.local import LocalCluster
from dask_jobqueue.pbs import PBSJob
from dask_jobqueue.sge import SGEJob
from dask_jobqueue.slurm import SLURMJob
from dask_jobqueue.lsf import LSFJob
from dask_jobqueue.moab import MoabJob
from dask_jobqueue.htcondor import HTCondorJob
from dask_jobqueue.oar import OARJob

from dask_jobqueue.core import JobQueueCluster
from dask.distributed import Scheduler, Client
from distributed.core import Status

import pytest


def test_basic():
job = PBSJob(scheduler="127.0.0.1:12345", cores=1, memory="1 GB")
def test_basic(Cluster):
job_cls = Cluster.job_cls
job = job_cls(scheduler="127.0.0.1:12345", cores=1, memory="1 GB")
assert "127.0.0.1:12345" in job.job_script()


job_protected = [
pytest.param(SGEJob, marks=[pytest.mark.env("sge")]),
pytest.param(PBSJob, marks=[pytest.mark.env("pbs")]),
pytest.param(SLURMJob, marks=[pytest.mark.env("slurm")]),
pytest.param(LSFJob, marks=[pytest.mark.env("lsf")]),
LocalJob,
]


all_jobs = [SGEJob, PBSJob, SLURMJob, LSFJob, HTCondorJob, MoabJob, OARJob]
all_clusters = [
SGECluster,
PBSCluster,
SLURMCluster,
LSFCluster,
HTCondorCluster,
MoabCluster,
OARCluster,
HTCondorCluster,
]


@pytest.mark.parametrize("job_cls", job_protected)
@pytest.mark.asyncio
async def test_job(job_cls):
async def test_job(EnvSpecificCluster):
job_cls = EnvSpecificCluster.job_cls
Comment thread
guillaumeeb marked this conversation as resolved.
async with Scheduler(port=0) as s:
job = job_cls(scheduler=s.address, name="foo", cores=1, memory="1GB")
job = await job
Expand All @@ -71,9 +34,9 @@ async def test_job(job_cls):
assert time() < start + 10


@pytest.mark.parametrize("job_cls", job_protected)
@pytest.mark.asyncio
async def test_cluster(job_cls):
async def test_cluster(EnvSpecificCluster):
job_cls = EnvSpecificCluster.job_cls
async with JobQueueCluster(
1, cores=1, memory="1GB", job_cls=job_cls, asynchronous=True, name="foo"
) as cluster:
Expand All @@ -94,9 +57,9 @@ async def test_cluster(job_cls):
assert time() < start + 10


@pytest.mark.parametrize("job_cls", job_protected)
@pytest.mark.asyncio
async def test_adapt(job_cls):
async def test_adapt(EnvSpecificCluster):
job_cls = EnvSpecificCluster.job_cls
async with JobQueueCluster(
1, cores=1, memory="1GB", job_cls=job_cls, asynchronous=True, name="foo"
) as cluster:
Expand Down Expand Up @@ -124,9 +87,9 @@ async def test_adapt(job_cls):
assert not cluster.workers


@pytest.mark.parametrize("job_cls", job_protected)
@pytest.mark.asyncio
async def test_adapt_parameters(job_cls):
async def test_adapt_parameters(EnvSpecificCluster):
job_cls = EnvSpecificCluster.job_cls
async with JobQueueCluster(
cores=2, memory="1GB", processes=2, job_cls=job_cls, asynchronous=True
) as cluster:
Expand Down Expand Up @@ -176,7 +139,6 @@ async def test_nprocs_scale():
assert len(cluster.worker_spec) == 1


@pytest.mark.parametrize("Cluster", all_clusters)
def test_docstring_cluster(Cluster):
assert "cores :" in Cluster.__doc__
assert Cluster.__name__[: -len("Cluster")] in Cluster.__doc__
Loading