Skip to content

Commit 8068cfe

Browse files
Use HAS_GPU in dispatch module to avoid some imports
1 parent 136f688 commit 8068cfe

File tree

1 file changed

+28
-23
lines changed

1 file changed

+28
-23
lines changed

merlin/core/dispatch.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,33 +26,30 @@
2626

2727
from merlin.core.compat import HAS_GPU
2828

29-
try:
30-
import cudf
31-
import cupy as cp
32-
import dask_cudf
33-
import rmm
34-
from cudf.core.column import as_column, build_column
29+
cp = None
30+
cudf = None
31+
rmm = None
3532

33+
if HAS_GPU:
3634
try:
37-
# cudf >= 21.08
38-
from cudf.api.types import is_list_dtype as cudf_is_list_dtype
39-
from cudf.api.types import is_string_dtype as cudf_is_string_dtype
35+
import cudf
36+
import cupy as cp
37+
import dask_cudf
38+
import rmm
39+
from cudf.core.column import as_column, build_column
40+
41+
try:
42+
# cudf >= 21.08
43+
from cudf.api.types import is_list_dtype as cudf_is_list_dtype
44+
from cudf.api.types import is_string_dtype as cudf_is_string_dtype
45+
except ImportError:
46+
# cudf < 21.08
47+
from cudf.utils.dtypes import is_list_dtype as cudf_is_list_dtype
48+
from cudf.utils.dtypes import is_string_dtype as cudf_is_string_dtype
49+
4050
except ImportError:
41-
# cudf < 21.08
42-
from cudf.utils.dtypes import is_list_dtype as cudf_is_list_dtype
43-
from cudf.utils.dtypes import is_string_dtype as cudf_is_string_dtype
51+
pass
4452

45-
except ImportError:
46-
cp = None
47-
cudf = None
48-
rmm = None
49-
50-
try:
51-
# Dask >= 2021.5.1
52-
from dask.dataframe.core import hash_object_dispatch
53-
except ImportError:
54-
# Dask < 2021.5.1
55-
from dask.dataframe.utils import hash_object_dispatch
5653

5754
try:
5855
import nvtx
@@ -71,6 +68,14 @@ def inner2(*args, **kwargs):
7168
return inner1
7269

7370

71+
try:
72+
# Dask >= 2021.5.1
73+
from dask.dataframe.core import hash_object_dispatch
74+
except ImportError:
75+
# Dask < 2021.5.1
76+
from dask.dataframe.utils import hash_object_dispatch
77+
78+
7479
if HAS_GPU:
7580
DataFrameType = Union[pd.DataFrame, cudf.DataFrame] # type: ignore
7681
SeriesType = Union[pd.Series, cudf.Series] # type: ignore

0 commit comments

Comments
 (0)