@@ -44,16 +44,6 @@ def sample_to_model_input(
4444 if ego_future .numel () != 0 :
4545 ego_future = ego_future [..., 1 :1 + self .future_len , :3 ] # (x, y, heading)
4646
47- # in the original input, the neighbor_future only include 10 neighbors, while the neighbor_agent_current include 32 neighbors
48-
49- neighbor_future = data .neighbor_future
50- if neighbor_future .numel () != 0 :
51- neighbor_future_mask = torch .sum (torch .ne (neighbor_future [..., :self .future_len , :3 ],0 ), dim = - 1 ) == 0
52- neighbor_future = neighbor_future [..., :self .future_len , :3 ] # (x, y, heading)
53-
54- neighbor_future [neighbor_future_mask ] = 0.
55- neighbor_future_valid = ~ neighbor_future_mask
56-
5747 model_inputs = {}
5848 model_inputs ['ego_past' ] = data .ego_past .to (device )
5949 model_inputs ['neighbor_past' ] = data .neighbor_past .to (device )
@@ -63,10 +53,6 @@ def sample_to_model_input(
6353 model_inputs ['routes' ] = data .routes .to (device )
6454 model_inputs ['map_objects' ] = data .map_objects .to (device )
6555
66- neighbor_current = data .neighbor_past [..., :self .neighbor_pred_num , - 1 , :4 ]
67- neighbor_current_mask = torch .sum (torch .ne (neighbor_current [..., :4 ], 0 ), dim = - 1 ) == 0
68- model_inputs ['neighbor_current_mask' ] = neighbor_current_mask .to (device )
69-
7056
7157 ego_current_state = data .ego_current
7258 model_inputs ['ego_current' ] = ego_current_state
@@ -75,38 +61,20 @@ def sample_to_model_input(
7561 ego_current_xy_cos_sin [..., :2 ],
7662 torch .atan2 (ego_current_xy_cos_sin [..., 3 :4 ], ego_current_xy_cos_sin [..., 2 :3 ])
7763 ], dim = - 1 )
78- neighbor_current = torch .cat ([
79- neighbor_current [..., :2 ],
80- torch .atan2 (neighbor_current [..., 3 :4 ], neighbor_current [..., 2 :3 ])
81- ], dim = - 1 )
82- current_states = torch .cat ([
83- ego_current [:, None ],
84- neighbor_current
85- ], dim = 1 )
64+
65+ current_states = ego_current [:, None ]
8666
8767 if is_training :
88- gt_future = torch .cat ([
89- ego_future [:, None , :, :],
90- neighbor_future ,
91- ], dim = 1 )
68+ gt_future = ego_future [:, None , :, :]
9269
9370 gt_with_current = torch .cat ([
9471 current_states [:, :, None , :],
9572 gt_future
9673 ], dim = 2 )
97-
98- neighbor_mask = torch .cat ([
99- neighbor_current_mask .unsqueeze (- 1 ),
100- neighbor_future_mask
101- ], dim = - 1 )
102- gt_with_current [:, 1 :][neighbor_mask ] = 0.
10374
10475 gt_with_current .to (device )
105- neighbor_future_valid .to (device )
106-
10776 else :
10877 gt_with_current = current_states [:, :, None , :].repeat (1 , 1 , self .future_len + 1 , 1 )
109- neighbor_future_valid = None
11078
11179 if kinematic == 'waypoints' :
11280 gt_with_current = torch .cat ([
@@ -127,6 +95,4 @@ def sample_to_model_input(
12795
12896 gt_with_current = self .state_preprocess (gt_with_current )
12997
130- model_inputs .update ({'neighbor_future_valid' : neighbor_future_valid })
131-
13298 return model_inputs , gt_with_current
0 commit comments