Skip to content

Commit 5d208c5

Browse files
committed
Update paths to avoid conflicts in Hub
1 parent 26e62f7 commit 5d208c5

File tree

7 files changed

+15
-11
lines changed

7 files changed

+15
-11
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ The pre-trained model corresponds to `DS 4` with multi-objective optimization en
1313

1414
### Setup
1515

16-
1) Download the model weights [model.pt](https://github.com/intel-isl/MiDaS/releases/download/v2/model.pt) and place the
16+
1) Download the model weights [model-f45da743.pt](https://github.com/intel-isl/MiDaS/releases/download/v2/model-f46da743.pt) and place the
1717
file in the root folder.
1818

1919
2) Set up dependencies:

hubconf.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
dependencies = ['torch']
1+
dependencies = ["torch"]
22

33
import torch
44

5-
from models.midas_net import MidasNet
5+
from midas.midas_net import MidasNet
6+
67

78
def MiDaS(pretrained=True, **kwargs):
89
""" # This docstring shows up in hub.help()
@@ -13,8 +14,12 @@ def MiDaS(pretrained=True, **kwargs):
1314
model = MidasNet()
1415

1516
if pretrained:
16-
checkpoint = "https://github.com/intel-isl/MiDaS/releases/download/v2/model.pt"
17-
state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=True)
17+
checkpoint = (
18+
"https://github.com/intel-isl/MiDaS/releases/download/v2/model-f46da743.pt"
19+
)
20+
state_dict = torch.hub.load_state_dict_from_url(
21+
checkpoint, progress=True, check_hash=True
22+
)
1823
model.load_state_dict(state_dict)
1924

2025
return model
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
import torch.nn as nn
32

43

54
class BaseModel(torch.nn.Module):
File renamed without changes.
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import torch
66
import torch.nn as nn
77

8-
from models.base_model import BaseModel
9-
from models.blocks import FeatureFusionBlock, Interpolate, _make_encoder
8+
from .base_model import BaseModel
9+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
1010

1111

1212
class MidasNet(BaseModel):

run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import cv2
88

99
from torchvision.transforms import Compose
10-
from models.midas_net import MidasNet
11-
from models.transforms import Resize, NormalizeImage, PrepareForNet
10+
from midas.midas_net import MidasNet
11+
from midas.transforms import Resize, NormalizeImage, PrepareForNet
1212

1313

1414
def run(input_path, output_path, model_path):
@@ -95,7 +95,7 @@ def run(input_path, output_path, model_path):
9595
INPUT_PATH = "input"
9696
OUTPUT_PATH = "output"
9797
# MODEL_PATH = "model.pt"
98-
MODEL_PATH = "model.pt"
98+
MODEL_PATH = "model-f46da743.pt"
9999

100100
# set torch options
101101
torch.backends.cudnn.enabled = True

0 commit comments

Comments
 (0)