|
16 | 16 | from torchvision import transforms |
17 | 17 |
|
18 | 18 | from text_recognizer.data.util import BaseDataset, convert_strings_to_labels |
19 | | -from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info |
| 19 | +from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info, split_dataset |
20 | 20 | from text_recognizer.data.emnist import EMNIST |
21 | 21 | from text_recognizer.data.iam import IAM |
22 | 22 | from text_recognizer import util |
@@ -82,11 +82,7 @@ def setup(self, stage: str = None): |
82 | 82 | y_trainval = convert_strings_to_labels(labels_trainval, self.inverse_mapping, length=self.output_dims[0]) |
83 | 83 | data_trainval = BaseDataset(x_trainval, y_trainval, transform=get_transform(IMAGE_WIDTH, self.augment)) |
84 | 84 |
|
85 | | - train_size = int(TRAIN_FRAC * len(data_trainval)) |
86 | | - val_size = len(data_trainval) - train_size |
87 | | - self.data_train, self.data_val = torch.utils.data.random_split( |
88 | | - data_trainval, [train_size, val_size], generator=torch.Generator().manual_seed(42) |
89 | | - ) |
| 85 | + self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42) |
90 | 86 |
|
91 | 87 | # Note that test data does not go through augmentation transforms |
92 | 88 | if stage == "test" or stage is None: |
|
0 commit comments