forked from NVIDIA-Merlin/models
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconftest.py
More file actions
104 lines (77 loc) · 2.88 KB
/
conftest.py
File metadata and controls
104 lines (77 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# from functools import lru_cache
from __future__ import absolute_import

from pathlib import Path
from typing import Iterator

import distributed
import pytest

from merlin.core.utils import Distributed
from merlin.datasets.synthetic import generate_data
from merlin.io import Dataset
# Directory two levels above this file. Presumably the repository root, assuming this
# conftest lives directly in ``tests/`` — confirm if the file is moved.
REPO_ROOT: Path = Path(__file__).parent.parent
# Number of rows generated by every synthetic-dataset fixture below. Change it here
# instead of editing each fixture individually.
_NUM_SYNTHETIC_ROWS = 100


@pytest.fixture
def ecommerce_data() -> Dataset:
    """Return a synthetic ``"e-commerce"`` dataset.

    Function-scoped, so each requesting test receives a freshly generated dataset.
    """
    return generate_data("e-commerce", num_rows=_NUM_SYNTHETIC_ROWS)


@pytest.fixture
def music_streaming_data() -> Dataset:
    """Return a synthetic ``"music-streaming"`` dataset.

    Function-scoped, so each requesting test receives a freshly generated dataset.
    """
    return generate_data("music-streaming", num_rows=_NUM_SYNTHETIC_ROWS)


@pytest.fixture
def sequence_testing_data() -> Dataset:
    """Return a synthetic ``"sequence-testing"`` dataset.

    Function-scoped, so each requesting test receives a freshly generated dataset.
    """
    return generate_data("sequence-testing", num_rows=_NUM_SYNTHETIC_ROWS)


@pytest.fixture
def social_data() -> Dataset:
    """Return a synthetic ``"social"`` dataset.

    Function-scoped, so each requesting test receives a freshly generated dataset.
    """
    return generate_data("social", num_rows=_NUM_SYNTHETIC_ROWS)


@pytest.fixture
def criteo_data() -> Dataset:
    """Return a synthetic ``"criteo"`` dataset.

    Function-scoped, so each requesting test receives a freshly generated dataset.
    """
    return generate_data("criteo", num_rows=_NUM_SYNTHETIC_ROWS)


@pytest.fixture
def testing_data() -> Dataset:
    """Return a synthetic ``"testing"`` dataset with its session/day columns removed.

    ``session_id``, ``session_start`` and ``day_idx`` are removed from the dataset's
    schema (presumably because tests using this fixture don't expect session-level
    features — confirm). The fixture is function-scoped, so the schema change only
    affects the dataset handed to the current test.
    """
    data = generate_data("testing", num_rows=_NUM_SYNTHETIC_ROWS)
    data.schema = data.schema.without(["session_id", "session_start", "day_idx"])
    return data
@pytest.fixture(scope="module")
def dask_client() -> Iterator[distributed.Client]:
    """Yield a Dask distributed client backed by a CPU cluster.

    Module-scoped: the cluster is created once per test module and shared by all of
    that module's tests. After the module's last test, the ``Distributed`` context
    manager is exited (presumably shutting the cluster down — see its implementation).

    The return annotation is ``Iterator[...]`` rather than ``distributed.Client``
    because this is a generator (``yield``) fixture.
    """
    with Distributed(cluster_type="cpu") as dist:
        yield dist.client
# Framework-specific fixtures are registered by star-importing each framework's private
# ``_conftest`` module. Both imports are best-effort: if the framework (or anything its
# ``_conftest`` imports) cannot be imported, its fixtures are simply not registered.
try:
    import tensorflow as tf  # noqa

    from tests.unit.tf._conftest import *  # noqa
except ImportError:
    pass

# ``torchmetrics`` is the availability probe for the PyTorch fixtures. Catch
# ``ImportError`` (the base class of ``ModuleNotFoundError``), matching the TensorFlow
# branch. A broken install can raise a plain ``ImportError`` (e.g. a failing
# ``from x import y``), which would otherwise abort collection of the entire suite.
try:
    import torchmetrics  # noqa

    from tests.unit.torch._conftest import *  # noqa
except ImportError:
    pass
# Directory directly under ``tests/unit`` -> marker applied to every test collected from
# it, so that e.g. ``pytest -m torch`` selects the PyTorch tests.
_UNIT_DIR_MARKERS = {
    "tf": "tensorflow",
    "torch": "torch",
    "implicit": "implicit",
    "lightfm": "lightfm",
    "xgb": "xgboost",
    "datasets": "datasets",
}

# Sub-directory of ``tests/unit/tf`` -> additional marker for tests collected from it.
_TF_SUBDIR_MARKERS = {
    "examples": "example",
    "integration": "integration",
}


def pytest_collection_modifyitems(items):
    """Add markers to collected tests based on where they live under ``tests/unit``.

    ``item.location[0]`` is the test file's path relative to the pytest rootdir. It is
    matched by whole path components rather than raw string prefixes. As a result,
    ``tests/unit/tfx/...`` is not mistaken for ``tests/unit/tf/...``, and the platform's
    native separator (``\\`` on Windows) is handled by ``Path.parts``.

    Args:
        items: The collected test items; markers are added to them in place.
    """
    for item in items:
        parts = Path(item.location[0]).parts
        if len(parts) < 3 or parts[:2] != ("tests", "unit"):
            continue

        framework_dir = parts[2]
        marker = _UNIT_DIR_MARKERS.get(framework_dir)
        if marker is None:
            continue
        item.add_marker(getattr(pytest.mark, marker))

        # TensorFlow tests are further split into examples/integration suites.
        if framework_dir == "tf" and len(parts) > 3:
            tf_marker = _TF_SUBDIR_MARKERS.get(parts[3])
            if tf_marker is not None:
                item.add_marker(getattr(pytest.mark, tf_marker))