|
15 | 15 | # ============================================================================== |
16 | 16 | """Language model input generator.""" |
17 | 17 |
|
| 18 | +import ast |
18 | 19 | from typing import List |
19 | 20 |
|
20 | 21 | from absl import logging |
@@ -329,6 +330,9 @@ def Params(cls) -> InstantiableParams: |
329 | 330 | p.Define( |
330 | 331 | 'num_samples', 0, 'Number of items contained in the input. 0 for ' |
331 | 332 | 'dynamically determined (slower).') |
| 333 | + p.Define( |
| 334 | + 'bytes_repr', True, 'Whether the texts are written as bytes ' |
| 335 | + r"representation, e.g. b'Q: Who directed?\n\nA:'") |
332 | 336 | return p |
333 | 337 |
|
334 | 338 | def __init__(self, p: InstantiableParams) -> None: |
@@ -382,9 +386,24 @@ def _to_nested_map(self, text) -> py_utils.NestedMap: |
382 | 386 | dtype=paddings.dtype) |
383 | 387 | return py_utils.NestedMap(ids=ids, labels=labels, paddings=new_paddings) |
384 | 388 |
|
def _remove_bytes_repr(self, ds):
  """Replaces each element of `ds` with the value of its Python literal text.

  Every element is decoded to text and evaluated with `ast.literal_eval`, so a
  line stored as a bytes representation (e.g. b'some text') is turned into
  the value that literal denotes.

  Args:
    ds: Dataset to transform. Presumably yields scalar string tensors (one
      text line per element) -- TODO confirm against callers.

  Returns:
    The result of `ds.map(...)`, whose elements are `tf.string` tensors that
    carry the same static shape as the corresponding input elements.
  """

  def _parse_literal(raw):
    # Invoked through tf.py_function, which executes the Python callable
    # eagerly; that is what makes `.numpy()` usable on the argument.
    text = raw.numpy().decode()
    return ast.literal_eval(text)

  def _map_fn(element):
    static_shape = element.shape
    parsed = tf.py_function(func=_parse_literal, inp=[element], Tout=tf.string)
    # Re-attach the input's static shape to the py_function output.
    parsed.set_shape(static_shape)
    return parsed

  return ds.map(_map_fn)
| 401 | + |
385 | 402 | def _gen_dataset(self) -> tf.data.Dataset: |
386 | 403 | p = self.params |
387 | 404 | lines = tf.data.TextLineDataset(p.input_file) |
| 405 | + if p.bytes_repr: |
| 406 | + lines = self._remove_bytes_repr(lines) |
388 | 407 | num_repeat = self._num_to_truncate() // self.num_samples + 1 |
389 | 408 | lines = lines.repeat(num_repeat).take(self._num_to_truncate()) |
390 | 409 | lines = lines.shard( |
|
0 commit comments