From 21fbb71aa1808f19a1357651992b0c2f9bb60239 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 1 Jun 2022 17:12:33 +0100 Subject: [PATCH 1/2] Check cuda.gpus.lst for available GPUs --- merlin/core/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/merlin/core/utils.py b/merlin/core/utils.py index d83ddfbca..41545e91a 100644 --- a/merlin/core/utils.py +++ b/merlin/core/utils.py @@ -30,8 +30,14 @@ _merlin_dask_client = ContextVar("_merlin_dask_client", default="auto") +HAS_GPU = False try: from numba import cuda + + try: + HAS_GPU = len(cuda.gpus.lst) > 0 + except cuda.cudadrv.error.CudaSupportError: + pass except ImportError: cuda = None @@ -254,7 +260,7 @@ class only supports the automatic generation of def __init__(self, client=None, cluster_type=None, force_new=False, **cluster_options): self._initial_client = global_dask_client() # Initial state self._client = client or "auto" # Cannot be `None` - self.cluster_type = cluster_type or ("cpu" if cuda is None else "cuda") + self.cluster_type = cluster_type or ("cuda" if HAS_GPU else "cpu") self.cluster_options = cluster_options # We can only shut down the cluster in `shutdown`/`__exit__` # if we are generating it internally From 0e6f300caea67715418756e1f77b3990d8010caf Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 1 Jun 2022 18:21:37 +0100 Subject: [PATCH 2/2] Move gpu check to compat module --- merlin/core/compat.py | 25 +++++++++++++++++++++++++ merlin/core/dispatch.py | 4 ++-- merlin/core/utils.py | 12 ++---------- 3 files changed, 29 insertions(+), 12 deletions(-) create mode 100644 merlin/core/compat.py diff --git a/merlin/core/compat.py b/merlin/core/compat.py new file mode 100644 index 000000000..4c70179e1 --- /dev/null +++ b/merlin/core/compat.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +HAS_GPU = False +try: + from numba import cuda + + try: + HAS_GPU = len(cuda.gpus.lst) > 0 + except cuda.cudadrv.error.CudaSupportError: + pass +except ImportError: + cuda = None diff --git a/merlin/core/dispatch.py b/merlin/core/dispatch.py index fb1db2446..ad55df1f3 100644 --- a/merlin/core/dispatch.py +++ b/merlin/core/dispatch.py @@ -24,6 +24,8 @@ import pyarrow as pa import pyarrow.parquet as pq +from merlin.core.compat import HAS_GPU + try: import cudf import cupy as cp @@ -40,9 +42,7 @@ from cudf.utils.dtypes import is_list_dtype as cudf_is_list_dtype from cudf.utils.dtypes import is_string_dtype as cudf_is_string_dtype - HAS_GPU = True except ImportError: - HAS_GPU = False cp = None cudf = None rmm = None diff --git a/merlin/core/utils.py b/merlin/core/utils.py index 41545e91a..5cdf53e92 100644 --- a/merlin/core/utils.py +++ b/merlin/core/utils.py @@ -28,18 +28,10 @@ from dask.distributed import Client, get_client from tqdm import tqdm -_merlin_dask_client = ContextVar("_merlin_dask_client", default="auto") +from merlin.core.compat import HAS_GPU, cuda -HAS_GPU = False -try: - from numba import cuda +_merlin_dask_client = ContextVar("_merlin_dask_client", default="auto") - try: - HAS_GPU = len(cuda.gpus.lst) > 0 - except cuda.cudadrv.error.CudaSupportError: - pass -except ImportError: - cuda = None try: import psutil