diff --git a/ci/environment.yml b/ci/environment.yml index 8734d9c5..df4cacb2 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -10,3 +10,4 @@ dependencies: - black - pytest - pytest-asyncio + - cryptography diff --git a/ci/htcondor.sh b/ci/htcondor.sh index 142b5b34..ccc6b5d3 100755 --- a/ci/htcondor.sh +++ b/ci/htcondor.sh @@ -16,7 +16,7 @@ function jobqueue_before_install { function jobqueue_install { cd ./ci/htcondor - docker-compose exec -T submit /bin/bash -c "cd /dask-jobqueue; pip3 install -e .;chown -R submituser ." + docker-compose exec -T submit /bin/bash -c "cd /dask-jobqueue; pip3 install -e .[test]; chown -R submituser ." cd - } diff --git a/conftest.py b/conftest.py index 751b4b15..a191a990 100644 --- a/conftest.py +++ b/conftest.py @@ -20,15 +20,17 @@ def pytest_addoption(parser): def pytest_configure(config): # register an additional marker config.addinivalue_line( - "markers", "env(name): mark test to run only on named environment" + "markers", "env(NAME): only run test if environment NAME matches" ) def pytest_runtest_setup(item): envnames = [mark.args[0] for mark in item.iter_markers(name="env")] - if envnames: - if item.config.getoption("-E") not in envnames: - pytest.skip("test requires env in %r" % envnames) + if (item.config.getoption("-E") is None and envnames) or ( + item.config.getoption("-E") is not None + and item.config.getoption("-E") not in envnames + ): + pytest.skip("test requires env in %r" % envnames) @pytest.fixture(autouse=True) diff --git a/dask_jobqueue/core.py b/dask_jobqueue/core.py index 22d1bbf2..d896cae0 100644 --- a/dask_jobqueue/core.py +++ b/dask_jobqueue/core.py @@ -8,6 +8,8 @@ import sys import weakref import abc +import tempfile +import copy import dask @@ -17,6 +19,7 @@ from distributed.deploy.spec import ProcessInterface, SpecCluster from distributed.deploy.local import nprocesses_nthreads from distributed.scheduler import Scheduler +from distributed.security import Security from distributed.utils import tmpfile logger = logging.getLogger(__name__) @@ -220,6 +223,7 @@ def __init__( extra = extra + ["--protocol", protocol] if security: worker_security_dict = security.get_tls_config_for_role("worker") + security_command_line_list = [ ["--tls-" + key.replace("_", "-"), value] for key, value in worker_security_dict.items() @@ -450,7 +454,7 @@ def __init__( scheduler_cls=Scheduler, # Use local scheduler for now # Options for both scheduler and workers interface=None, - protocol="tcp://", + protocol=None, # Job keywords config_name=None, **job_kwargs @@ -500,6 +504,17 @@ def __init__( "jobqueue.%s.scheduler-options" % config_name, {} ) + if protocol is None and security is not None: + protocol = "tls://" + if security is None and protocol is not None and protocol.startswith("tls"): + try: + security = Security.temporary() + except ImportError: + raise ImportError( + "In order to use TLS without pregenerated certificates `cryptography` is required," + "please install it using either pip or conda" + ) + default_scheduler_options = { "protocol": protocol, "dashboard_address": ":8787", @@ -521,7 +536,26 @@ def __init__( job_kwargs["config_name"] = config_name job_kwargs["interface"] = interface job_kwargs["protocol"] = protocol - job_kwargs["security"] = security + job_kwargs["security"] = copy.copy(security) + + if security is not None: + worker_security_dict = job_kwargs["security"].get_tls_config_for_role( + "worker" + ) + for key, value in worker_security_dict.items(): + # dump worker in-memory keys for use in job_script + if value is not None and "\n" in value: + f = tempfile.NamedTemporaryFile(mode="wt") + # make sure that tmpfile survives by keeping a reference + setattr(self, "_job_" + key, f) + f.write(value) + f.flush() + setattr( + job_kwargs["security"], + "tls_" + ("worker_" if key != "ca_file" else "") + key, + f.name, + ) + self._job_kwargs = job_kwargs worker = {"cls": self.job_cls, "options": self._job_kwargs} diff --git a/dask_jobqueue/tests/test_jobqueue_core.py b/dask_jobqueue/tests/test_jobqueue_core.py index 3290e94b..2a82201b 100644 --- a/dask_jobqueue/tests/test_jobqueue_core.py +++ b/dask_jobqueue/tests/test_jobqueue_core.py @@ -427,12 +427,11 @@ def test_security(): require_encryption=True, ) - with LocalCluster( - cores=1, memory="1GB", security=security, protocol="tls" - ) as cluster: + with LocalCluster(cores=1, memory="1GB", security=security) as cluster: assert cluster.security == security assert cluster.scheduler_spec["options"]["security"] == security job_script = cluster.job_script() + assert "tls://" in job_script assert "--tls-key {}".format(key) in job_script assert "--tls-cert {}".format(cert) in job_script assert "--tls-ca-file {}".format(cert) in job_script @@ -442,3 +441,20 @@ def test_security(): future = client.submit(lambda x: x + 1, 10) result = future.result() assert result == 11 + + +def test_security_temporary(): + with LocalCluster(cores=1, memory="1GB", protocol="tls") as cluster: + assert cluster.security + assert cluster.scheduler_spec["options"]["security"] == cluster.security + job_script = cluster.job_script() + assert "tls://" in job_script + assert "--tls-key" in job_script + assert "--tls-cert" in job_script + assert "--tls-ca-file" in job_script + + cluster.scale(jobs=1) + with Client(cluster) as client: + future = client.submit(lambda x: x + 1, 10) + result = future.result() + assert result == 11 diff --git a/setup.py b/setup.py index e5966511..60818565 100755 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ packages=["dask_jobqueue"], include_package_data=True, install_requires=install_requires, - tests_require=["pytest >= 2.7.1"], + extras_require={"test": ["pytest >= 2.7.1", "pytest-asyncio", "cryptography"]}, long_description=long_description, zip_safe=False, )