@@ -508,6 +508,34 @@ def test_any_pad():
508508 verify_any_pad (any_dims (3 ), ((0 , 0 ), (1 , 1 ), (2 , 2 )), (1 , 2 , 3 ))
509509 verify_any_pad (any_dims (4 ), ((1 , 0 ), (1 , 3 ), (0 , 2 ), (9 , 0 )), (13 , 11 , 3 , 1 ))
510510
511+ def verify_any_dilate (data_shape , strides , static_data_shape ):
512+ assert len (data_shape ) == len (strides )
513+ mod = tvm .IRModule ()
514+ dtype = "float32"
515+ data = relay .var ('data' , shape = data_shape , dtype = dtype )
516+ y = relay .nn .dilate (data , strides )
517+ mod ["main" ] = relay .Function ([data ], y )
518+ data_np = np .random .uniform (size = static_data_shape ).astype (dtype )
519+ ref_shape = tuple ((static_data_shape [i ] - 1 ) * strides [i ] + 1
520+ for i in range (len (static_data_shape )))
521+ ref_out = np .zeros (shape = ref_shape , dtype = dtype )
522+ ref_out [tuple (slice (None , None , strides [i ]) for i in range (len (data_shape )))] = data_np
523+
524+ for kind in ["debug" , "vm" ]:
525+ ex = relay .create_executor (kind , mod = mod , ctx = tvm .cpu (), target = "llvm" )
526+ result = ex .evaluate ()(data_np )
527+ tvm .testing .assert_allclose (result .asnumpy (), ref_out )
528+
529+ def test_any_dilate ():
530+ verify_any_dilate (any_dims (1 ), (1 ,), (1 ,))
531+ verify_any_dilate (any_dims (1 ), (1 ,), (5 ,))
532+ verify_any_dilate (any_dims (1 ), (5 ,), (5 ,))
533+ verify_any_dilate (any_dims (3 ), (1 , 1 , 1 ), (1 , 2 , 3 ))
534+ verify_any_dilate (any_dims (3 ), (1 , 1 , 2 ), (1 , 2 , 3 ))
535+ verify_any_dilate (any_dims (3 ), (1 , 1 , 5 ), (1 , 2 , 3 ))
536+ verify_any_dilate (any_dims (3 ), (3 , 7 , 5 ), (1 , 2 , 3 ))
537+ verify_any_dilate (any_dims (4 ), (3 , 7 , 1 , 5 ), (1 , 2 , 3 , 4 ))
538+
511539def verify_any_softmax (data_shape , axis , static_data_shape , ref_out_shape ):
512540 mod = tvm .IRModule ()
513541 dtype = "float32"
0 commit comments