Skip to content

Commit 5e4cfe6

Browse files
committed
[fix] removed redundant neighbor-related input
1 parent 9278dc4 commit 5e4cfe6

File tree

3 files changed

+4
-42
lines changed

3 files changed

+4
-42
lines changed

flow_planner/model/flow_planner_model/flow_planner.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ def extract_encoder_inputs(self, inputs):
102102
return encoder_inputs
103103

104104
def extract_decoder_inputs(self, encoder_outputs, inputs):
105-
model_extra = dict(neighbor_current_mask=inputs['neighbor_current_mask'],
106-
cfg_flags=inputs['cfg_flags'] if 'cfg_flags' in inputs.keys() else None,)
105+
model_extra = dict(cfg_flags=inputs['cfg_flags'] if 'cfg_flags' in inputs.keys() else None,)
107106
model_extra.update(encoder_outputs)
108107
return model_extra
109108

@@ -151,12 +150,10 @@ def forward_train(self, data: NuPlanDataSample):
151150
prediction = self.decoder(noised_traj_tokens, t, **decoder_model_extra)
152151

153152
loss_dict = {}
154-
155153
batch_loss = self.basic_loss(prediction, target_tokens)
156154
loss_dict['batch_loss'] = batch_loss
157155

158156
loss = torch.sum(batch_loss, dim=-1) # (B, action_num, action_length, dim)
159-
loss_dict['neighbor_pred_loss'] = torch.tensor(0.0, device=loss.device)
160157
loss_dict['ego_planning_loss'] = loss.mean()
161158

162159
if self.planner_params['action_overlap'] > 0: # TODO:

flow_planner/model/model_utils/input_preprocess.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,6 @@ def sample_to_model_input(
4444
if ego_future.numel() != 0:
4545
ego_future = ego_future[..., 1:1+self.future_len, :3] # (x, y, heading)
4646

47-
# in the original input, the neighbor_future only include 10 neighbors, while the neighbor_agent_current include 32 neighbors
48-
49-
neighbor_future = data.neighbor_future
50-
if neighbor_future.numel() != 0:
51-
neighbor_future_mask = torch.sum(torch.ne(neighbor_future[..., :self.future_len, :3],0), dim=-1) == 0
52-
neighbor_future = neighbor_future[..., :self.future_len, :3] # (x, y, heading)
53-
54-
neighbor_future[neighbor_future_mask] = 0.
55-
neighbor_future_valid = ~neighbor_future_mask
56-
5747
model_inputs = {}
5848
model_inputs['ego_past'] = data.ego_past.to(device)
5949
model_inputs['neighbor_past'] = data.neighbor_past.to(device)
@@ -63,10 +53,6 @@ def sample_to_model_input(
6353
model_inputs['routes'] = data.routes.to(device)
6454
model_inputs['map_objects'] = data.map_objects.to(device)
6555

66-
neighbor_current = data.neighbor_past[..., :self.neighbor_pred_num, -1, :4]
67-
neighbor_current_mask = torch.sum(torch.ne(neighbor_current[..., :4], 0), dim=-1) == 0
68-
model_inputs['neighbor_current_mask'] = neighbor_current_mask.to(device)
69-
7056

7157
ego_current_state = data.ego_current
7258
model_inputs['ego_current'] = ego_current_state
@@ -75,38 +61,20 @@ def sample_to_model_input(
7561
ego_current_xy_cos_sin[..., :2],
7662
torch.atan2(ego_current_xy_cos_sin[..., 3:4], ego_current_xy_cos_sin[..., 2:3])
7763
], dim=-1)
78-
neighbor_current = torch.cat([
79-
neighbor_current[..., :2],
80-
torch.atan2(neighbor_current[..., 3:4], neighbor_current[..., 2:3])
81-
], dim=-1)
82-
current_states = torch.cat([
83-
ego_current[:, None],
84-
neighbor_current
85-
], dim=1)
64+
65+
current_states = ego_current[:, None]
8666

8767
if is_training:
88-
gt_future = torch.cat([
89-
ego_future[:, None, :, :],
90-
neighbor_future,
91-
], dim=1)
68+
gt_future = ego_future[:, None, :, :]
9269

9370
gt_with_current = torch.cat([
9471
current_states[:, :, None, :],
9572
gt_future
9673
], dim=2)
97-
98-
neighbor_mask = torch.cat([
99-
neighbor_current_mask.unsqueeze(-1),
100-
neighbor_future_mask
101-
], dim=-1)
102-
gt_with_current[:, 1:][neighbor_mask] = 0.
10374

10475
gt_with_current.to(device)
105-
neighbor_future_valid.to(device)
106-
10776
else:
10877
gt_with_current = current_states[:, :, None, :].repeat(1, 1, self.future_len + 1, 1)
109-
neighbor_future_valid = None
11078

11179
if kinematic == 'waypoints':
11280
gt_with_current = torch.cat([
@@ -127,6 +95,4 @@ def sample_to_model_input(
12795

12896
gt_with_current = self.state_preprocess(gt_with_current)
12997

130-
model_inputs.update({'neighbor_future_valid': neighbor_future_valid})
131-
13298
return model_inputs, gt_with_current

flow_planner/script/core/flow_matching.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,4 @@ input_aug:
1111
device: ${device}
1212

1313
ego_planning_loss: 1.0
14-
neighbor_pred_loss: 0.0
1514
consistency_loss: 0.5

0 commit comments

Comments
 (0)