diff --git a/merlin/core/dispatch.py b/merlin/core/dispatch.py index ad55df1f3..72463eb53 100644 --- a/merlin/core/dispatch.py +++ b/merlin/core/dispatch.py @@ -26,26 +26,30 @@ from merlin.core.compat import HAS_GPU -try: - import cudf - import cupy as cp - import dask_cudf - import rmm - from cudf.core.column import as_column, build_column +cp = None +cudf = None +rmm = None +if HAS_GPU: try: - # cudf >= 21.08 - from cudf.api.types import is_list_dtype as cudf_is_list_dtype - from cudf.api.types import is_string_dtype as cudf_is_string_dtype + import cudf + import cupy as cp + import dask_cudf + import rmm + from cudf.core.column import as_column, build_column + + try: + # cudf >= 21.08 + from cudf.api.types import is_list_dtype as cudf_is_list_dtype + from cudf.api.types import is_string_dtype as cudf_is_string_dtype + except ImportError: + # cudf < 21.08 + from cudf.utils.dtypes import is_list_dtype as cudf_is_list_dtype + from cudf.utils.dtypes import is_string_dtype as cudf_is_string_dtype + except ImportError: - # cudf < 21.08 - from cudf.utils.dtypes import is_list_dtype as cudf_is_list_dtype - from cudf.utils.dtypes import is_string_dtype as cudf_is_string_dtype + pass -except ImportError: - cp = None - cudf = None - rmm = None try: # Dask >= 2021.5.1