Skip to content

Commit f6651f4

Browse files
committed
compatible with MAC
1 parent a853e63 commit f6651f4

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

tests/test_model_utils.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""
2+
Author: Wenyu Ouyang
3+
Date: 2025-04-15 13:07:20
4+
LastEditTime: 2025-04-15 13:09:25
5+
LastEditors: Wenyu Ouyang
6+
Description:
7+
FilePath: /torchhydro/tests/test_model_utils.py
8+
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
9+
"""
10+
11+
import pytest
12+
import torch
13+
from torchhydro.models.model_utils import get_the_device
14+
15+
16+
@pytest.mark.parametrize(
17+
"device_num, expected_device",
18+
[
19+
(-1, "cpu"),
20+
([-1], "cpu"),
21+
(["-1"], "cpu"),
22+
],
23+
)
24+
def test_get_the_device_cpu(device_num, expected_device):
25+
device = get_the_device(device_num)
26+
assert device.type == expected_device
27+
28+
29+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
30+
@pytest.mark.parametrize(
31+
"device_num, expected_device",
32+
[
33+
(0, "cuda"),
34+
([0], "cuda"),
35+
(1, "cuda"),
36+
],
37+
)
38+
def test_get_the_device_cuda(device_num, expected_device):
39+
device = get_the_device(device_num)
40+
assert device.type == expected_device
41+
assert device.index == (
42+
device_num[0] if isinstance(device_num, list) else device_num
43+
)
44+
45+
46+
@pytest.mark.skipif(
47+
not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(),
48+
reason="MPS is not available",
49+
)
50+
@pytest.mark.parametrize(
51+
"device_num, expected_device",
52+
[
53+
(0, "mps"),
54+
([0], "mps"),
55+
(1, "mps"), # Should warn and default to mps:0
56+
],
57+
)
58+
def test_get_the_device_mps(device_num, expected_device):
59+
if device_num != 0:
60+
with pytest.warns(UserWarning, match="MPS only supports device 0"):
61+
device = get_the_device(device_num)
62+
else:
63+
device = get_the_device(device_num)
64+
assert device.type == expected_device
65+
assert device.index == 0

torchhydro/models/model_utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""
22
Author: Wenyu Ouyang
33
Date: 2021-08-09 10:19:13
4-
LastEditTime: 2023-09-21 16:47:05
4+
LastEditTime: 2025-04-15 12:59:35
55
LastEditors: Wenyu Ouyang
66
Description: Some util functions for modeling
77
FilePath: /torchhydro/torchhydro/models/model_utils.py
88
Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
99
"""
1010

11+
import contextlib
1112
from typing import Union
1213
import warnings
1314
import torch
@@ -20,7 +21,7 @@ def get_the_device(device_num: Union[list, int]):
2021
Parameters
2122
----------
2223
device_num : Union[list, int]
23-
number of the device -- -1 means "cpu" or 0, 1, ... means "cuda:x"
24+
number of the device -- -1 means "cpu" or 0, 1, ... means "cuda:x" or "mps:x"
2425
"""
2526
if device_num in [[-1], -1, ["-1"]]:
2627
return torch.device("cpu")
@@ -30,6 +31,16 @@ def get_the_device(device_num: Union[list, int]):
3031
if type(device_num) is not list
3132
else torch.device(f"cuda:{str(device_num[0])}")
3233
)
34+
# Check for MPS (MacOS)
35+
mps_available = False
36+
with contextlib.suppress(AttributeError):
37+
mps_available = torch.backends.mps.is_available()
38+
if mps_available:
39+
if device_num != 0:
40+
warnings.warn(
41+
f"MPS only supports device 0. Using 'mps:0' instead of {device_num}."
42+
)
43+
return torch.device("mps:0")
3344
if device_num not in [[-1], -1, ["-1"]]:
3445
warnings.warn("You don't have GPU, so have to choose cpu for models")
3546
return torch.device("cpu")

0 commit comments

Comments
 (0)