|
15 | 15 | # ============================================================================== |
16 | 16 | """Language model input generator.""" |
17 | 17 |
|
| 18 | +from typing import List |
| 19 | + |
18 | 20 | from absl import logging |
19 | 21 | from lingvo.core import base_input_generator |
20 | 22 | from lingvo.core import layers |
21 | 23 | from lingvo.core import ops |
22 | 24 | from lingvo.jax import base_input |
23 | 25 | from lingvo.jax import py_utils |
| 26 | +from lingvo.jax import pytypes |
24 | 27 | import tensorflow.compat.v2 as tf |
25 | 28 |
|
26 | 29 | InstantiableParams = py_utils.InstantiableParams |
@@ -305,3 +308,86 @@ def _InputBatch(self): |
305 | 308 | input_batch.segment_pos = tf.tile( |
306 | 309 | tf.range(0, p.seq_len)[tf.newaxis, :], [p.batch_size, 1]) |
307 | 310 | return input_batch |
| 311 | + |
| 312 | + |
class TextInput(base_input.BaseInput):
  """Input generator reading plain text, used for eval.

  Each row in the batch corresponds to a line in the input file. This input
  raises out of range after all input data are returned at least once.
  Depending on the number of infeed hosts and the batch size, duplicate input
  may be returned to pad to full, synchronized batches on all infeed hosts.
  """

  @classmethod
  def Params(cls) -> InstantiableParams:
    """Returns the params for this input generator."""
    p = super().Params()
    p.Define('input_file', None, 'String, path of a (small) input file.')
    p.Define('tokenizer', None, 'Lingvo tokenizer param.')
    p.Define('max_sequence_length', 512,
             'Maximum number of tokens to be present in a single example.')
    p.Define(
        'num_samples', 0, 'Number of items contained in the input. 0 for '
        'dynamically determined (slower).')
    return p

  def __init__(self, p: InstantiableParams) -> None:
    super().__init__(p)
    self.tokenizer = p.tokenizer.Instantiate()
    self._dataset = self._gen_dataset()
    self._iterator = iter(self._dataset)

  def get_next(self) -> NestedMap:
    """Returns a batch with .ids, .paddings, and .labels."""
    ret = self._iterator.get_next()
    # Convert eager tensors to numpy arrays for the JAX-side consumer.
    return tf.nest.map_structure(lambda x: x.numpy(), ret)

  def reset(self) -> None:
    """Restarts iteration from the beginning of the dataset."""
    self._iterator = iter(self._dataset)

  @property
  def num_samples(self) -> int:
    """Number of samples (lines) contained in the dataset.

    When p.num_samples is 0, the count is computed by scanning p.input_file
    once and cached back onto the params so the file is not re-read.
    """
    p = self.params
    if p.num_samples > 0:
      return p.num_samples
    lines = tf.data.TextLineDataset(p.input_file)
    # Cache the dynamically determined count to avoid rescanning the file.
    p.num_samples = len(list(lines.as_numpy_iterator()))
    return p.num_samples

  def _num_to_truncate(self) -> int:
    """Smallest multiple of global batch size that covers the entire data."""
    p = self.params
    n = p.num_infeed_hosts * p.batch_size
    num_global_batches = (self.num_samples + n - 1) // n
    return num_global_batches * n

  def ids_to_strings(self, ids: pytypes.NpTensor,
                     lengths: pytypes.NpTensor) -> List[str]:
    """Converts int ids into strings, one per row, decoded as UTF-8."""
    bytes_list = self.tokenizer.IdsToStrings(ids, lengths).numpy()
    return [b.decode('utf-8') for b in bytes_list]

  def _to_nested_map(self, text) -> py_utils.NestedMap:
    """Tokenizes a batch of text lines into ids, labels, and paddings."""
    ids, labels, paddings = self.tokenizer.StringsToIds(
        text, max_length=self.params.max_sequence_length)
    # Unfortunately some tokenizers don't return the correct paddings.
    # We recompute it by looking at when the labels sequence terminates.
    indices = tf.where(tf.math.equal(labels, self.tokenizer.eos_id))
    # Use unsorted_segment_min with an explicit segment count so that rows
    # containing no eos_id still get an entry; plain segment_min would emit
    # fewer rows than the batch when trailing rows lack an EOS, misaligning
    # the paddings with ids/labels.
    lengths = tf.math.unsorted_segment_min(
        indices[:, 1], indices[:, 0], num_segments=tf.shape(labels)[0])
    # Rows without an EOS receive the dtype-max sentinel; clamp before the
    # +1 (which would otherwise overflow) so they are treated as full-length
    # sequences with no padding.
    max_len = tf.cast(self.params.max_sequence_length, lengths.dtype)
    lengths = tf.minimum(lengths, max_len - 1) + 1
    new_paddings = tf.cast(
        1.0 - tf.sequence_mask(
            lengths,
            maxlen=self.params.max_sequence_length,
            dtype=paddings.dtype),
        dtype=paddings.dtype)
    return py_utils.NestedMap(ids=ids, labels=labels, paddings=new_paddings)

  def _gen_dataset(self) -> tf.data.Dataset:
    """Builds the dataset of batched, tokenized lines for this infeed host."""
    p = self.params
    lines = tf.data.TextLineDataset(p.input_file)
    num_to_take = self._num_to_truncate()
    # Repeat enough times to cover the padded total, then truncate exactly,
    # so every infeed host sees the same number of full batches.
    num_repeat = num_to_take // self.num_samples + 1
    lines = lines.repeat(num_repeat).take(num_to_take)
    lines = lines.shard(
        num_shards=p.num_infeed_hosts, index=p.infeed_host_index)
    lines = lines.batch(p.batch_size)
    return lines.map(self._to_nested_map)
0 commit comments