Skip to content

Commit eabd412

Browse files
author
Yunhe Gao
committed
update args
1 parent 738eda2 commit eabd412

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

train_deep.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from dataset_domain import CMRDataset
1010

1111
from torch.utils import data
12-
import torchvision
1312
from losses import DiceLoss
1413
from utils.utils import *
1514
from utils import metrics
@@ -222,6 +221,8 @@ def get_comma_separated_int_args(option, opt, value, parser):
222221

223222
parser.add_option('-o', '--log-path', type='str', dest='log_path', default='./log/', help='log path')
224223
parser.add_option('-m', type='str', dest='model', default='UTNet', help='use which model')
224+
parser.add_option('--num_class', type='int', dest='num_class', default=4, help='number of segmentation classes')
225+
parser.add_option('--base_chan', type='int', dest='base_chan', default=32, help='number of channels of first expansion in UNet')
225226
parser.add_option('-u', '--unique_name', type='str', dest='unique_name', default='test', help='unique experiment name')
226227
parser.add_option('--rlt', type='float', dest='rlt', default=1, help='relation between CE/FL and dice')
227228
parser.add_option('--weight', type='float', dest='weight',
@@ -246,9 +247,10 @@ def get_comma_separated_int_args(option, opt, value, parser):
246247
print('Using model:', options.model)
247248

248249
if options.model == 'UTNet':
249-
net = UTNet(1, 32, 4, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)
250+
net = UTNet(1, options.base_chan, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)
250251
elif options.model == 'UTNet_encoder':
251-
net = UTNet_Encoderonly(1, 32, 4, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)
252+
# Apply transformer blocks only in the encoder
253+
net = UTNet_Encoderonly(1, options.base_chan, options.num_class, reduce_size=options.reduce_size, block_list=options.block_list, num_blocks=options.num_blocks, num_heads=[4,4,4,4], projection='interp', attn_drop=0.1, proj_drop=0.1, rel_pos=True, aux_loss=options.aux_loss, maxpool=True)
252254

253255

254256
else:

0 commit comments

Comments
 (0)