diff --git a/merlin/core/compat.py b/merlin/core/compat.py new file mode 100644 index 000000000..4c70179e1 --- /dev/null +++ b/merlin/core/compat.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +HAS_GPU = False +try: + from numba import cuda + + try: + HAS_GPU = len(cuda.gpus.lst) > 0 + except cuda.cudadrv.error.CudaSupportError: + pass +except ImportError: + cuda = None diff --git a/merlin/core/dispatch.py b/merlin/core/dispatch.py index fb1db2446..ad55df1f3 100644 --- a/merlin/core/dispatch.py +++ b/merlin/core/dispatch.py @@ -24,6 +24,8 @@ import pyarrow as pa import pyarrow.parquet as pq +from merlin.core.compat import HAS_GPU + try: import cudf import cupy as cp @@ -40,9 +42,7 @@ from cudf.utils.dtypes import is_list_dtype as cudf_is_list_dtype from cudf.utils.dtypes import is_string_dtype as cudf_is_string_dtype - HAS_GPU = True except ImportError: - HAS_GPU = False cp = None cudf = None rmm = None diff --git a/merlin/core/utils.py b/merlin/core/utils.py index d83ddfbca..5cdf53e92 100644 --- a/merlin/core/utils.py +++ b/merlin/core/utils.py @@ -28,12 +28,10 @@ from dask.distributed import Client, get_client from tqdm import tqdm +from merlin.core.compat import HAS_GPU, cuda + _merlin_dask_client = ContextVar("_merlin_dask_client", default="auto") -try: - from numba import cuda -except ImportError: - cuda = None try: import psutil @@ -254,7 +252,7 @@ class only supports the automatic generation of def __init__(self, client=None, cluster_type=None, force_new=False, **cluster_options): self._initial_client = global_dask_client() # Initial state self._client = client or "auto" # Cannot be `None` - self.cluster_type = cluster_type or ("cpu" if cuda is None else "cuda") + self.cluster_type = cluster_type or ("cuda" if HAS_GPU else "cpu") self.cluster_options = cluster_options # We can only shut down the cluster in `shutdown`/`__exit__` # if we are generating it internally