Skip to content

Commit 21fbb71

Browse files
Check cuda.gpus.lst for available GPUs
1 parent 3f8e8f4 commit 21fbb71

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

merlin/core/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,14 @@
3030

3131
_merlin_dask_client = ContextVar("_merlin_dask_client", default="auto")
3232

33+
HAS_GPU = False
3334
try:
3435
from numba import cuda
36+
37+
try:
38+
HAS_GPU = len(cuda.gpus.lst) > 0
39+
except cuda.cudadrv.error.CudaSupportError:
40+
pass
3541
except ImportError:
3642
cuda = None
3743

@@ -254,7 +260,7 @@ class only supports the automatic generation of
254260
def __init__(self, client=None, cluster_type=None, force_new=False, **cluster_options):
255261
self._initial_client = global_dask_client() # Initial state
256262
self._client = client or "auto" # Cannot be `None`
257-
self.cluster_type = cluster_type or ("cpu" if cuda is None else "cuda")
263+
self.cluster_type = cluster_type or ("cuda" if HAS_GPU else "cpu")
258264
self.cluster_options = cluster_options
259265
# We can only shut down the cluster in `shutdown`/`__exit__`
260266
# if we are generating it internally

0 commit comments

Comments
 (0)