Skip to content

Commit e365820

Browse files
authored
Update merlin.core.compat to use HAS_GPU and add add'l libraries (#262)
1 parent ec9a360 commit e365820

File tree

1 file changed

+29
-13
lines changed

1 file changed

+29
-13
lines changed

merlin/core/compat.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# Copyright (c) 2022, NVIDIA CORPORATION.
2+
# Copyright (c) 2023, NVIDIA CORPORATION.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -17,11 +17,6 @@
1717
# pylint: disable=unused-import
1818
import os
1919

20-
try:
21-
from numba import cuda
22-
except ImportError:
23-
cuda = None
24-
2520
from dask.distributed.diagnostics import nvml
2621

2722

@@ -50,21 +45,42 @@ def _get_gpu_count():
5045

5146
HAS_GPU = _get_gpu_count() > 0
5247

48+
if HAS_GPU:
49+
try:
50+
from numba import cuda
51+
except ImportError:
52+
cuda = None
53+
54+
try:
55+
import cudf
56+
except ImportError:
57+
cudf = None
58+
59+
try:
60+
import cupy
61+
except ImportError:
62+
cupy = None
63+
64+
try:
65+
import dask_cudf
66+
except ImportError:
67+
dask_cudf = None
68+
69+
else:
70+
cuda = None
71+
cudf = None
72+
cupy = None
73+
dask_cudf = None
5374

5475
try:
5576
import numpy
5677
except ImportError:
5778
numpy = None
5879

5980
try:
60-
import cupy
81+
import pandas
6182
except ImportError:
62-
cupy = None
63-
64-
try:
65-
import cudf
66-
except ImportError:
67-
cudf = None
83+
pandas = None
6884

6985
try:
7086
import tensorflow

0 commit comments

Comments
 (0)