Skip to content

Commit 136f688

Browse files
Update GPU detection in merlin.core.utils for Distributed class (#98)
* Check cuda.gpus.lst for available GPUs * Move gpu check to compat module
1 parent 3f8e8f4 commit 136f688

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

merlin/core/compat.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#
2+
# Copyright (c) 2022, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
HAS_GPU = False
17+
try:
18+
from numba import cuda
19+
20+
try:
21+
HAS_GPU = len(cuda.gpus.lst) > 0
22+
except cuda.cudadrv.error.CudaSupportError:
23+
pass
24+
except ImportError:
25+
cuda = None

merlin/core/dispatch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import pyarrow as pa
2525
import pyarrow.parquet as pq
2626

27+
from merlin.core.compat import HAS_GPU
28+
2729
try:
2830
import cudf
2931
import cupy as cp
@@ -40,9 +42,7 @@
4042
from cudf.utils.dtypes import is_list_dtype as cudf_is_list_dtype
4143
from cudf.utils.dtypes import is_string_dtype as cudf_is_string_dtype
4244

43-
HAS_GPU = True
4445
except ImportError:
45-
HAS_GPU = False
4646
cp = None
4747
cudf = None
4848
rmm = None

merlin/core/utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,10 @@
2828
from dask.distributed import Client, get_client
2929
from tqdm import tqdm
3030

31+
from merlin.core.compat import HAS_GPU, cuda
32+
3133
_merlin_dask_client = ContextVar("_merlin_dask_client", default="auto")
3234

33-
try:
34-
from numba import cuda
35-
except ImportError:
36-
cuda = None
3735

3836
try:
3937
import psutil
@@ -254,7 +252,7 @@ class only supports the automatic generation of
254252
def __init__(self, client=None, cluster_type=None, force_new=False, **cluster_options):
255253
self._initial_client = global_dask_client() # Initial state
256254
self._client = client or "auto" # Cannot be `None`
257-
self.cluster_type = cluster_type or ("cpu" if cuda is None else "cuda")
255+
self.cluster_type = cluster_type or ("cuda" if HAS_GPU else "cpu")
258256
self.cluster_options = cluster_options
259257
# We can only shut down the cluster in `shutdown`/`__exit__`
260258
# if we are generating it internally

0 commit comments

Comments
 (0)