diff --git a/merlin/core/compat.py b/merlin/core/compat.py index 4c70179e1..d372f8a73 100644 --- a/merlin/core/compat.py +++ b/merlin/core/compat.py @@ -13,13 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # -HAS_GPU = False try: from numba import cuda - try: - HAS_GPU = len(cuda.gpus.lst) > 0 - except cuda.cudadrv.error.CudaSupportError: - pass except ImportError: cuda = None + +HAS_GPU = False +try: + from dask.distributed.diagnostics import nvml + + HAS_GPU = nvml.device_get_count() > 0 +except ImportError: + # We can use `cuda` to set `HAS_GPU` now that we + # know `distributed` is not installed (otherwise + # the `nvml` import would have succeeded) + HAS_GPU = cuda is not None