-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathnodes.py
More file actions
72 lines (58 loc) · 1.93 KB
/
nodes.py
File metadata and controls
72 lines (58 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from typing import Dict, Tuple
import onnxruntime as ort
import torch
from .detector.human_parts import get_mask, labels
from .utils import model_path
class HumanParts:
    """
    This node is used to get a mask of the human parts in the image.
    The model used is DeepLabV3+ with a ResNet50 backbone trained
    by Keras-io, converted to ONNX format.

    Inputs are one IMAGE plus one BOOLEAN toggle per named entry of
    ``detector.human_parts.labels``; the node outputs a single MASK.
    """

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "get_mask"
    CATEGORY = "Metal3d"
    # Must be spelled exactly ``OUTPUT_NODE``: ComfyUI looks this attribute up
    # by name, so a misspelling (e.g. ``OUTPU_NODE``) is silently ignored.
    OUTPUT_NODE = True

    @classmethod
    def INPUT_TYPES(cls):
        """
        Describe the node's inputs for ComfyUI.

        Returns:
            A dict whose "required" section holds the "image" IMAGE input
            followed by one BOOLEAN toggle (disabled by default) for every
            entry of ``labels`` with a non-empty name, and whose "optional"
            section is empty.
        """

        def _bool_widget(
            is_enabled=False, tooltip: str | None = None
        ) -> Tuple[str, dict]:
            """Helper function to create a boolean widget"""
            return (
                "BOOLEAN",
                {
                    "default": is_enabled,
                    "label_on": "Enabled",
                    "label_off": "Disabled",
                    "tooltip": tooltip,
                },
            )

        # Automate the creation of the inputs using the known labels.
        # NOTE(review): assumes each value of ``labels`` is a sequence whose
        # item 0 is the input name and item 1 a human-readable description;
        # entries with an empty name are skipped. Confirm against
        # detector.human_parts.
        entries: Dict[str, tuple] = {
            segment[0]: _bool_widget(False, f"{segment[1]}")
            for segment in labels.values()
            if segment[0] != ""
        }
        inputs = {
            "required": {
                "image": (
                    "IMAGE",
                    {
                        "label": "Image",
                        "tooltip": "The image in which to detect human parts",
                    },
                )
            },
            "optional": {},
        }
        inputs["required"].update(entries)
        return inputs

    def get_mask(self, image: torch.Tensor, **kwargs) -> Tuple[torch.Tensor]:
        """
        Return a Tensor with the mask of the human parts in the image.

        Args:
            image: The IMAGE input declared in ``INPUT_TYPES``.
            **kwargs: The per-part BOOLEAN toggles declared in
                ``INPUT_TYPES``, forwarded unchanged to the detector.

        Returns:
            A 1-tuple holding the mask tensor produced by the detector.
        """
        # NOTE(review): a new InferenceSession (and a fresh load of the model
        # file) is created on every call; consider caching it if load time
        # matters.
        model = ort.InferenceSession(model_path)
        # ``get_mask`` here is the module-level detector function imported at
        # the top of the file, not this method. It returns a pair, and only
        # the first element is passed on to ComfyUI.
        ret_tensor, _ = get_mask(image, model=model, rotation=0, **kwargs)
        return (ret_tensor,)