Skip to content

Commit 96a858d

Browse files
Commit message: "fixes" — 1 parent (299697b), commit 96a858d

File tree

8 files changed: +16 additions, −30 deletions

lab2/text_recognizer/data/emnist.py

Lines changed: 2 additions & 6 deletions

```diff
@@ -14,7 +14,7 @@
 import toml
 import torch

-from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info
+from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info, split_dataset
 from text_recognizer.data.util import BaseDataset

 NUM_SPECIAL_TOKENS = 4
@@ -68,11 +68,7 @@ def setup(self, stage: str = None):
             self.y_trainval = f["y_train"][:].squeeze().astype(int)

         data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform)
-        train_size = int(TRAIN_FRAC * len(data_trainval))
-        val_size = len(data_trainval) - train_size
-        self.data_train, self.data_val = torch.utils.data.random_split(
-            data_trainval, [train_size, val_size], generator=torch.Generator().manual_seed(42)
-        )
+        self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42)

         if stage == "test" or stage is None:
             with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
```

lab3/text_recognizer/data/emnist.py

Lines changed: 2 additions & 6 deletions

```diff
@@ -14,7 +14,7 @@
 import toml
 import torch

-from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info
+from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info, split_dataset
 from text_recognizer.data.util import BaseDataset

 NUM_SPECIAL_TOKENS = 4
@@ -68,11 +68,7 @@ def setup(self, stage: str = None):
             self.y_trainval = f["y_train"][:].squeeze().astype(int)

         data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform)
-        train_size = int(TRAIN_FRAC * len(data_trainval))
-        val_size = len(data_trainval) - train_size
-        self.data_train, self.data_val = torch.utils.data.random_split(
-            data_trainval, [train_size, val_size], generator=torch.Generator().manual_seed(42)
-        )
+        self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42)

         if stage == "test" or stage is None:
             with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
```

lab3/text_recognizer/models/line_cnn.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -128,6 +128,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.dropout(x)
         x = self.fc2(x)  # (B, S, C)
         x = x.permute(0, 2, 1)  # -> (B, C, S)
+        if self.limit_output_length:
+            x = x[:, :, :self.output_length]
         return x

     @staticmethod
```

lab4/text_recognizer/data/emnist.py

Lines changed: 2 additions & 6 deletions

```diff
@@ -14,7 +14,7 @@
 import toml
 import torch

-from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info
+from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info, split_dataset
 from text_recognizer.data.util import BaseDataset

 NUM_SPECIAL_TOKENS = 4
@@ -68,11 +68,7 @@ def setup(self, stage: str = None):
             self.y_trainval = f["y_train"][:].squeeze().astype(int)

         data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform)
-        train_size = int(TRAIN_FRAC * len(data_trainval))
-        val_size = len(data_trainval) - train_size
-        self.data_train, self.data_val = torch.utils.data.random_split(
-            data_trainval, [train_size, val_size], generator=torch.Generator().manual_seed(42)
-        )
+        self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42)

         if stage == "test" or stage is None:
             with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
```

lab4/text_recognizer/models/line_cnn.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -128,6 +128,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.dropout(x)
         x = self.fc2(x)  # (B, S, C)
         x = x.permute(0, 2, 1)  # -> (B, C, S)
+        if self.limit_output_length:
+            x = x[:, :, :self.output_length]
         return x

     @staticmethod
```

lab5/text_recognizer/data/emnist.py

Lines changed: 2 additions & 6 deletions

```diff
@@ -14,7 +14,7 @@
 import toml
 import torch

-from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info
+from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info, split_dataset
 from text_recognizer.data.util import BaseDataset

 NUM_SPECIAL_TOKENS = 4
@@ -68,11 +68,7 @@ def setup(self, stage: str = None):
             self.y_trainval = f["y_train"][:].squeeze().astype(int)

         data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform)
-        train_size = int(TRAIN_FRAC * len(data_trainval))
-        val_size = len(data_trainval) - train_size
-        self.data_train, self.data_val = torch.utils.data.random_split(
-            data_trainval, [train_size, val_size], generator=torch.Generator().manual_seed(42)
-        )
+        self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42)

         if stage == "test" or stage is None:
             with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
```

lab5/text_recognizer/data/iam_lines.py

Lines changed: 2 additions & 6 deletions

```diff
@@ -16,7 +16,7 @@
 from torchvision import transforms

 from text_recognizer.data.util import BaseDataset, convert_strings_to_labels
-from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info, split_dataset
 from text_recognizer.data.emnist import EMNIST
 from text_recognizer.data.iam import IAM
 from text_recognizer import util
@@ -82,11 +82,7 @@ def setup(self, stage: str = None):
         y_trainval = convert_strings_to_labels(labels_trainval, self.inverse_mapping, length=self.output_dims[0])
         data_trainval = BaseDataset(x_trainval, y_trainval, transform=get_transform(IMAGE_WIDTH, self.augment))

-        train_size = int(TRAIN_FRAC * len(data_trainval))
-        val_size = len(data_trainval) - train_size
-        self.data_train, self.data_val = torch.utils.data.random_split(
-            data_trainval, [train_size, val_size], generator=torch.Generator().manual_seed(42)
-        )
+        self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42)

         # Note that test data does not go through augmentation transforms
         if stage == "test" or stage is None:
```

lab5/text_recognizer/models/line_cnn.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -128,6 +128,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.dropout(x)
         x = self.fc2(x)  # (B, S, C)
         x = x.permute(0, 2, 1)  # -> (B, C, S)
+        if self.limit_output_length:
+            x = x[:, :, :self.output_length]
         return x

     @staticmethod
```

0 commit comments

Comments
 (0)