99from dataset_domain import CMRDataset
1010
1111from torch .utils import data
12- import torchvision
1312from losses import DiceLoss
1413from utils .utils import *
1514from utils import metrics
@@ -222,6 +221,8 @@ def get_comma_separated_int_args(option, opt, value, parser):
222221
223222 parser .add_option ('-o' , '--log-path' , type = 'str' , dest = 'log_path' , default = './log/' , help = 'log path' )
224223 parser .add_option ('-m' , type = 'str' , dest = 'model' , default = 'UTNet' , help = 'use which model' )
224+ parser .add_option ('--num_class' , type = 'int' , dest = 'num_class' , default = 4 , help = 'number of segmentation classes' )
225+ parser .add_option ('--base_chan' , type = 'int' , dest = 'base_chan' , default = 32 , help = 'number of channels of first expansion in UNet' )
225226 parser .add_option ('-u' , '--unique_name' , type = 'str' , dest = 'unique_name' , default = 'test' , help = 'unique experiment name' )
226227 parser .add_option ('--rlt' , type = 'float' , dest = 'rlt' , default = 1 , help = 'relation between CE/FL and dice' )
227228 parser .add_option ('--weight' , type = 'float' , dest = 'weight' ,
@@ -246,9 +247,10 @@ def get_comma_separated_int_args(option, opt, value, parser):
246247 print ('Using model:' , options .model )
247248
248249 if options .model == 'UTNet' :
249- net = UTNet (1 , 32 , 4 , reduce_size = options .reduce_size , block_list = options .block_list , num_blocks = options .num_blocks , num_heads = [4 ,4 ,4 ,4 ], projection = 'interp' , attn_drop = 0.1 , proj_drop = 0.1 , rel_pos = True , aux_loss = options .aux_loss , maxpool = True )
250+ net = UTNet (1 , options . base_chan , options . num_class , reduce_size = options .reduce_size , block_list = options .block_list , num_blocks = options .num_blocks , num_heads = [4 ,4 ,4 ,4 ], projection = 'interp' , attn_drop = 0.1 , proj_drop = 0.1 , rel_pos = True , aux_loss = options .aux_loss , maxpool = True )
250251 elif options .model == 'UTNet_encoder' :
251- net = UTNet_Encoderonly (1 , 32 , 4 , reduce_size = options .reduce_size , block_list = options .block_list , num_blocks = options .num_blocks , num_heads = [4 ,4 ,4 ,4 ], projection = 'interp' , attn_drop = 0.1 , proj_drop = 0.1 , rel_pos = True , aux_loss = options .aux_loss , maxpool = True )
252+ # Apply transformer blocks only in the encoder
253+ net = UTNet_Encoderonly (1 , options .base_chan , options .num_class , reduce_size = options .reduce_size , block_list = options .block_list , num_blocks = options .num_blocks , num_heads = [4 ,4 ,4 ,4 ], projection = 'interp' , attn_drop = 0.1 , proj_drop = 0.1 , rel_pos = True , aux_loss = options .aux_loss , maxpool = True )
252254
253255
254256 else :
0 commit comments