Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 57 additions & 48 deletions distributed/comm/tests/test_ws.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
import tempfile
import warnings

import numpy as np
import pytest

import dask
Expand All @@ -18,12 +16,11 @@
get_client_ssl_context,
get_server_ssl_context,
inc,
xfail_ssl_issue5601,
)

from .test_comms import check_tls_extra

security = Security.temporary()


def test_registered():
assert "ws" in backends
Expand Down Expand Up @@ -77,22 +74,24 @@ async def test_expect_ssl_context():


@gen_test()
async def test_expect_scheduler_ssl_when_sharing_server():
with tempfile.TemporaryDirectory() as tempdir:
key_path = os.path.join(tempdir, "dask.pem")
cert_path = os.path.join(tempdir, "dask.crt")
with open(key_path, "w") as f:
f.write(security.tls_scheduler_key)
with open(cert_path, "w") as f:
f.write(security.tls_scheduler_cert)
c = {
"distributed.scheduler.dashboard.tls.key": key_path,
"distributed.scheduler.dashboard.tls.cert": cert_path,
}
with dask.config.set(c):
with pytest.raises(RuntimeError):
async with Scheduler(protocol="ws://", dashboard=True, port=8787):
pass
async def test_expect_scheduler_ssl_when_sharing_server(tmpdir):
xfail_ssl_issue5601()
pytest.importorskip("cryptography")
security = Security.temporary()
key_path = os.path.join(str(tmpdir), "dask.pem")
cert_path = os.path.join(str(tmpdir), "dask.crt")
with open(key_path, "w") as f:
f.write(security.tls_scheduler_key)
with open(cert_path, "w") as f:
f.write(security.tls_scheduler_cert)
c = {
"distributed.scheduler.dashboard.tls.key": key_path,
"distributed.scheduler.dashboard.tls.cert": cert_path,
}
with dask.config.set(c):
with pytest.raises(RuntimeError):
async with Scheduler(protocol="ws://", dashboard=True, port=8787):
pass


@gen_cluster(client=True, scheduler_kwargs={"protocol": "ws://"})
Expand All @@ -117,8 +116,8 @@ async def test_large_transfer(c, s, a, b):
await c.scatter(np.random.random(1_000_000))


@pytest.mark.asyncio
async def test_large_transfer_with_no_compression(cleanup):
@gen_test()
async def test_large_transfer_with_no_compression():
np = pytest.importorskip("numpy")
with dask.config.set({"distributed.comm.compression": None}):
async with Scheduler(protocol="ws://") as s:
Expand All @@ -132,16 +131,20 @@ async def test_large_transfer_with_no_compression(cleanup):
"dashboard,protocol,security,port",
[
(True, "ws://", None, 8787),
(True, "wss://", security, 8787),
(True, "wss://", True, 8787),
(False, "ws://", None, 8787),
(False, "wss://", security, 8787),
(False, "wss://", True, 8787),
(True, "ws://", None, 8786),
(True, "wss://", security, 8786),
(True, "wss://", True, 8786),
(False, "ws://", None, 8786),
(False, "wss://", security, 8786),
(False, "wss://", True, 8786),
],
)
async def test_http_and_comm_server(cleanup, dashboard, protocol, security, port):
if security:
xfail_ssl_issue5601()
pytest.importorskip("cryptography")
security = Security.temporary()
async with Scheduler(
protocol=protocol, dashboard=dashboard, port=port, security=security
) as s:
Expand All @@ -156,22 +159,18 @@ async def test_http_and_comm_server(cleanup, dashboard, protocol, security, port


@pytest.mark.asyncio
@pytest.mark.parametrize(
"protocol,security",
[
(
"ws://",
Security(extra_conn_args={"headers": {"Authorization": "Token abcd"}}),
),
(
"wss://",
Security.temporary(
extra_conn_args={"headers": {"Authorization": "Token abcd"}}
),
),
],
)
async def test_connection_made_with_extra_conn_args(cleanup, protocol, security):
@pytest.mark.parametrize("protocol", ["ws://", "wss://"])
async def test_connection_made_with_extra_conn_args(cleanup, protocol):
if protocol == "ws://":
security = Security(
extra_conn_args={"headers": {"Authorization": "Token abcd"}}
)
else:
xfail_ssl_issue5601()
pytest.importorskip("cryptography")
security = Security.temporary(
extra_conn_args={"headers": {"Authorization": "Token abcd"}}
)
async with Scheduler(
protocol=protocol, security=security, dashboard_address=":0"
) as s:
Expand All @@ -197,15 +196,25 @@ async def test_quiet_close():

@gen_cluster(client=True, scheduler_kwargs={"protocol": "ws://"})
async def test_ws_roundtrip(c, s, a, b):
np = pytest.importorskip("numpy")
x = np.arange(100)
future = await c.scatter(x)
y = await future
assert (x == y).all()


@gen_cluster(client=True, security=security, scheduler_kwargs={"protocol": "wss://"})
async def test_wss_roundtrip(c, s, a, b):
x = np.arange(100)
future = await c.scatter(x)
y = await future
assert (x == y).all()
@gen_test()
async def test_wss_roundtrip():
np = pytest.importorskip("numpy")
xfail_ssl_issue5601()
pytest.importorskip("cryptography")
security = Security.temporary()
async with Scheduler(
protocol="wss://", security=security, dashboard_address=":0"
) as s:
async with Worker(s.address, security=security) as w:
async with Client(s.address, security=security, asynchronous=True) as c:
x = np.arange(100)
future = await c.scatter(x)
y = await future
assert (x == y).all()
4 changes: 4 additions & 0 deletions distributed/deploy/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
inc,
slowinc,
tls_only_security,
xfail_ssl_issue5601,
)


Expand Down Expand Up @@ -264,6 +265,7 @@ def test_Client_twice(loop):

@gen_test()
async def test_client_constructor_with_temporary_security():
xfail_ssl_issue5601()
pytest.importorskip("cryptography")
async with Client(
security=True, silence_logs=False, dashboard_address=":0", asynchronous=True
Expand Down Expand Up @@ -707,6 +709,7 @@ def test_adapt_then_manual(loop):
@pytest.mark.parametrize("temporary", [True, False])
def test_local_tls(loop, temporary):
if temporary:
xfail_ssl_issue5601()
pytest.importorskip("cryptography")
security = True
else:
Expand Down Expand Up @@ -989,6 +992,7 @@ async def test_threads_per_worker_set_to_0():
@pytest.mark.parametrize("temporary", [True, False])
async def test_capture_security(cleanup, temporary):
if temporary:
xfail_ssl_issue5601()
pytest.importorskip("cryptography")
security = True
else:
Expand Down
13 changes: 9 additions & 4 deletions distributed/tests/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from distributed.comm import connect, listen
from distributed.security import Security
from distributed.utils_test import get_cert
from distributed.utils_test import gen_test, get_cert, xfail_ssl_issue5601

ca_file = get_cert("tls-ca-cert.pem")

Expand Down Expand Up @@ -111,6 +111,8 @@ def test_kwargs():


def test_repr_temp_keys():
xfail_ssl_issue5601()
pytest.importorskip("cryptography")
sec = Security.temporary()
representation = repr(sec)
assert "Temporary (In-memory)" in representation
Expand Down Expand Up @@ -282,7 +284,7 @@ def basic_checks(ctx):
assert len(tls_13_ciphers) in (0, 3)


@pytest.mark.asyncio
@gen_test()
async def test_tls_listen_connect():
"""
Functional test for TLS connection args.
Expand Down Expand Up @@ -330,7 +332,7 @@ async def handle_comm(comm):
comm.abort()


@pytest.mark.asyncio
@gen_test()
async def test_require_encryption():
"""
Functional test for "require_encryption" setting.
Expand Down Expand Up @@ -394,6 +396,7 @@ def check_encryption_error():


def test_temporary_credentials():
xfail_ssl_issue5601()
pytest.importorskip("cryptography")

sec = Security.temporary()
Expand All @@ -411,14 +414,16 @@ def test_temporary_credentials():


def test_extra_conn_args_in_temporary_credentials():
xfail_ssl_issue5601()
pytest.importorskip("cryptography")

sec = Security.temporary(extra_conn_args={"headers": {"X-Request-ID": "abcd"}})
assert sec.extra_conn_args == {"headers": {"X-Request-ID": "abcd"}}


@pytest.mark.asyncio
@gen_test()
async def test_tls_temporary_credentials_functional():
xfail_ssl_issue5601()
pytest.importorskip("cryptography")

async def handle_comm(comm):
Expand Down
14 changes: 14 additions & 0 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
if TYPE_CHECKING:
from typing_extensions import Literal

from distributed.compatibility import MACOS
from distributed.scheduler import Scheduler

try:
Expand Down Expand Up @@ -1763,6 +1764,19 @@ async def connect(self, *args, **kwargs):
)


def xfail_ssl_issue5601():
"""Work around https://github.com/dask/distributed/issues/5601 where any test that
inits Security.temporary() crashes on MacOS GitHub Actions CI
"""
pytest.importorskip("cryptography")
try:
Security.temporary()
except ImportError:
if MACOS:
pytest.xfail(reason="distributed#5601")
raise


def assert_worker_story(
story: list[tuple], expect: list[tuple], *, strict: bool = False
) -> None:
Expand Down