Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions paddlenlp/datasets/embedding_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@
pos_ids = list(range(len(token_ids)))
return token_ids, pos_ids

def _postprocess_sequence(self, example: Example):
def _postprocess_sequence(self, example: Example, rng):
"""Post process sequence: tokenization & truncation."""
query = example.query
pos_passage = random.choice(example.pos_passage)
pos_passage = rng.choice(example.pos_passage)

Check warning on line 123 in paddlenlp/datasets/embedding_dataset.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/datasets/embedding_dataset.py#L123

Added line #L123 was not covered by tests
neg_passage = example.neg_passage
if len(neg_passage) > 0:
if len(neg_passage) < self.group_size - 1:
Expand All @@ -132,12 +132,12 @@
selected_neg_passage = neg_passage * full_sets_needed

# Ensure the remainder part is filled; randomly select from neg_passage
selected_neg_passage += random.sample(neg_passage, remainder)
selected_neg_passage += rng.sample(neg_passage, remainder)

Check warning on line 135 in paddlenlp/datasets/embedding_dataset.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/datasets/embedding_dataset.py#L135

Added line #L135 was not covered by tests

# Shuffle the result to ensure randomness
random.shuffle(selected_neg_passage)
rng.shuffle(selected_neg_passage)

Check warning on line 138 in paddlenlp/datasets/embedding_dataset.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/datasets/embedding_dataset.py#L138

Added line #L138 was not covered by tests
else:
selected_neg_passage = random.sample(neg_passage, self.group_size - 1)
selected_neg_passage = rng.sample(neg_passage, self.group_size - 1)

Check warning on line 140 in paddlenlp/datasets/embedding_dataset.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/datasets/embedding_dataset.py#L140

Added line #L140 was not covered by tests
else:
selected_neg_passage = []
# Process query tokens
Expand Down Expand Up @@ -241,9 +241,11 @@
"""Iterates through one epoch of the dataset."""

num_sequences = 0
for index, example in enumerate(self.example_dataset):
rng = random.Random()
for _, example in enumerate(self.example_dataset):

Check warning on line 245 in paddlenlp/datasets/embedding_dataset.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/datasets/embedding_dataset.py#L244-L245

Added lines #L244 - L245 were not covered by tests
example = self.convert_example(example)
sequence = self._postprocess_sequence(example)
rng.seed(num_sequences)
sequence = self._postprocess_sequence(example, rng)

Check warning on line 248 in paddlenlp/datasets/embedding_dataset.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/datasets/embedding_dataset.py#L247-L248

Added lines #L247 - L248 were not covered by tests
if sequence is None:
continue
num_sequences += 1
Expand Down