Skip to content

Commit 7039479

Browse files
rjzamorakarlhigley
andauthored
Avoid using numba to set device context in import (#145)
* avoid numba context in import * use pynvml instead of dask-cuda * reduce diff * pylint * use distributed.nvml Co-authored-by: Karl Higley <kmhigley@gmail.com>
1 parent e3286f8 commit 7039479

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

merlin/core/compat.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,19 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
HAS_GPU = False
1716
try:
1817
from numba import cuda
1918

20-
try:
21-
HAS_GPU = len(cuda.gpus.lst) > 0
22-
except cuda.cudadrv.error.CudaSupportError:
23-
pass
2419
except ImportError:
2520
cuda = None
21+
22+
HAS_GPU = False
23+
try:
24+
from dask.distributed.diagnostics import nvml
25+
26+
HAS_GPU = nvml.device_get_count() > 0
27+
except ImportError:
28+
# We can use `cuda` to set `HAS_GPU` now that we
29+
# know `distributed` is not installed (otherwise
30+
# the `nvml` import would have succeeded)
31+
HAS_GPU = cuda is not None

0 commit comments

Comments
 (0)