Skip to content

Commit 1737e0d

Browse files
lingvo-bot and copybara-github
authored and committed
Add an input to read plain text data.
PiperOrigin-RevId: 415406418
1 parent 5d4d881 commit 1737e0d

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

lingvo/jax/tasks/lm/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pytype_strict_library(
2020
"//lingvo/core/ops",
2121
"//lingvo/jax:base_input",
2222
"//lingvo/jax:py_utils",
23+
"//lingvo/jax:pytypes",
2324
# Implicit tensorflow dependency.
2425
],
2526
)

lingvo/jax/tasks/lm/input_generator.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
# ==============================================================================
1616
"""Language model input generator."""
1717

18+
from typing import List
19+
1820
from absl import logging
1921
from lingvo.core import base_input_generator
2022
from lingvo.core import layers
2123
from lingvo.core import ops
2224
from lingvo.jax import base_input
2325
from lingvo.jax import py_utils
26+
from lingvo.jax import pytypes
2427
import tensorflow.compat.v2 as tf
2528

2629
InstantiableParams = py_utils.InstantiableParams
@@ -305,3 +308,86 @@ def _InputBatch(self):
305308
input_batch.segment_pos = tf.tile(
306309
tf.range(0, p.seq_len)[tf.newaxis, :], [p.batch_size, 1])
307310
return input_batch
311+
312+
313+
class TextInput(base_input.BaseInput):
  """Input generator reading plain text, used for eval.

  Each row in the batch corresponds to a line in the input file. This input
  raises out-of-range after all input data are returned at least once.
  Depending on the number of infeed hosts and the batch size, duplicate input
  may be returned to pad to full, synchronized batches on all infeed hosts.
  """

  @classmethod
  def Params(cls) -> InstantiableParams:
    """Returns the params for this input, extending the base input params."""
    p = super().Params()
    p.Define('input_file', None, 'String, path of a (small) input file.')
    p.Define('tokenizer', None, 'Lingvo tokenizer param.')
    p.Define('max_sequence_length', 512,
             'Maximum number of tokens to be present in a single example.')
    p.Define(
        'num_samples', 0, 'Number of items contained in the input. 0 for '
        'dynamically determined (slower).')
    return p

  def __init__(self, p: InstantiableParams) -> None:
    super().__init__(p)
    self.tokenizer = p.tokenizer.Instantiate()
    self._dataset = self._gen_dataset()
    self._iterator = iter(self._dataset)

  def get_next(self) -> NestedMap:
    """Returns a batch with .ids, .paddings, and .labels.

    Raises tf.errors.OutOfRangeError once the (padded) data is exhausted.
    """
    ret = self._iterator.get_next()
    # Convert eager tensors to numpy so downstream jax code can consume them.
    return tf.nest.map_structure(lambda x: x.numpy(), ret)

  def reset(self) -> None:
    """Rewinds the input to the beginning of the data."""
    self._iterator = iter(self._dataset)

  @property
  def num_samples(self):
    """Number of samples contained in the dataset."""
    p = self.params
    if p.num_samples > 0:
      return p.num_samples
    # Slow path: count lines by iterating the whole file once, then cache the
    # result on the params so subsequent calls are cheap.
    lines = tf.data.TextLineDataset(p.input_file)
    p.num_samples = len(list(lines.as_numpy_iterator()))
    return p.num_samples

  def _num_to_truncate(self):
    """Smallest multiple of global batch size that covers the entire data."""
    p = self.params
    n = p.num_infeed_hosts * p.batch_size
    # Ceiling division: number of global batches needed to cover all samples.
    num_global_batches = (self.num_samples + n - 1) // n
    return num_global_batches * n

  def ids_to_strings(self, ids: pytypes.NpTensor,
                     lengths: pytypes.NpTensor) -> List[str]:
    """Detokenizes token id rows back into UTF-8 strings.

    Args:
      ids: token ids, one row per example.
      lengths: the valid length of each row in `ids`.

    Returns:
      A list of decoded strings, one per row.
    """
    bytes_list = self.tokenizer.IdsToStrings(ids, lengths).numpy()
    return [b.decode('utf-8') for b in bytes_list]

  def _to_nested_map(self, text) -> py_utils.NestedMap:
    """Tokenizes a batch of text lines into a NestedMap of model inputs."""
    ids, labels, paddings = self.tokenizer.StringsToIds(
        text, max_length=self.params.max_sequence_length)
    # Unfortunately some tokenizers don't return the correct paddings.
    # We recompute it by looking at when the labels sequence terminates:
    # for each row, find the first EOS position and mark everything after
    # it (exclusive of the EOS itself) as padding.
    indices = tf.where(tf.math.equal(labels, self.tokenizer.eos_id))
    lengths = tf.math.segment_min(indices[:, 1], indices[:, 0]) + 1
    new_paddings = tf.cast(
        1.0 - tf.sequence_mask(
            lengths,
            maxlen=self.params.max_sequence_length,
            dtype=paddings.dtype),
        dtype=paddings.dtype)
    return py_utils.NestedMap(ids=ids, labels=labels, paddings=new_paddings)

  def _gen_dataset(self) -> tf.data.Dataset:
    """Builds the tf.data pipeline: read, pad, shard, batch, tokenize."""
    p = self.params
    lines = tf.data.TextLineDataset(p.input_file)
    # Repeat-then-take pads the data with duplicates up to the smallest
    # multiple of the global batch size, so every infeed host sees the same
    # number of full batches. Hoisted so the truncation size is computed once.
    num_to_take = self._num_to_truncate()
    num_repeat = num_to_take // self.num_samples + 1
    lines = lines.repeat(num_repeat).take(num_to_take)
    lines = lines.shard(
        num_shards=p.num_infeed_hosts, index=p.infeed_host_index)
    lines = lines.batch(p.batch_size)
    return lines.map(self._to_nested_map)

0 commit comments

Comments
 (0)