Skip to content

Commit 6ea7773

Browse files
lingvo-botcopybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 415079304
1 parent c08852d commit 6ea7773

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

lingvo/core/batch_major_attention.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2431,10 +2431,9 @@ def _StreamStepStaticLength(self, theta, query_vec, query_paddings, key_vec,
24312431
p = self.params
24322432
dims = self._StreamStepDimensions(query_vec)
24332433
h, s, b, q = dims.h, dims.s, dims.b, dims.q
2434-
assert query_vec.shape[1] is not None, 'query_vec.shape[1] must be static.'
2435-
assert q <= p.inference_step_max_length, (
2436-
f'q: {q} should be less than p.inference_step_max_length: '
2437-
f'{p.inference_step_max_length}')
2434+
assert q is not None
2435+
query_vec = py_utils.with_dependencies(
2436+
[py_utils.assert_less_equal(q, p.inference_step_max_length)], query_vec)
24382437

24392438
b, k = py_utils.GetShape(key_vec, 2)
24402439
q = (k + p.query_stride - 1) // p.query_stride

0 commit comments

Comments
 (0)