Skip to content

Commit 5bf5601

Browse files
lingvo-botcopybara-github
authored andcommitted
Read the TriviaQA text as bytes representations by calling ast.literal_eval on it first.
PiperOrigin-RevId: 415648113
1 parent bcf61e3 commit 5bf5601

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

lingvo/jax/tasks/lm/input_generator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# ==============================================================================
1616
"""Language model input generator."""
1717

18+
import ast
1819
from typing import List
1920

2021
from absl import logging
@@ -329,6 +330,9 @@ def Params(cls) -> InstantiableParams:
329330
p.Define(
330331
'num_samples', 0, 'Number of items contained in the input. 0 for '
331332
'dynamically determined (slower).')
333+
p.Define(
334+
'bytes_repr', True, 'Whether the texts are written as bytes '
335+
r"representation, e.g. b'Q: Who directed?\n\nA:'")
332336
return p
333337

334338
def __init__(self, p: InstantiableParams) -> None:
@@ -382,9 +386,24 @@ def _to_nested_map(self, text) -> py_utils.NestedMap:
382386
dtype=paddings.dtype)
383387
return py_utils.NestedMap(ids=ids, labels=labels, paddings=new_paddings)
384388

389+
def _remove_bytes_repr(self, ds):
390+
391+
def eval_bytes(s):
392+
return ast.literal_eval(s.numpy().decode())
393+
394+
def tf_eval_bytes(x):
395+
x_shape = x.shape
396+
y = tf.py_function(eval_bytes, [x], tf.string)
397+
y.set_shape(x_shape)
398+
return y
399+
400+
return ds.map(tf_eval_bytes)
401+
385402
def _gen_dataset(self) -> tf.data.Dataset:
386403
p = self.params
387404
lines = tf.data.TextLineDataset(p.input_file)
405+
if p.bytes_repr:
406+
lines = self._remove_bytes_repr(lines)
388407
num_repeat = self._num_to_truncate() // self.num_samples + 1
389408
lines = lines.repeat(num_repeat).take(self._num_to_truncate())
390409
lines = lines.shard(

0 commit comments

Comments
 (0)