Skip to content

Commit 6b1548f

Browse files
committed
feat: add max-length, new tasks and fix corner cases
1 parent 2c2413a commit 6b1548f

File tree

4 files changed

+150
-31
lines changed

4 files changed

+150
-31
lines changed

README.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,13 @@ We run this evaluation for various BeIR datasets with traditional chunking and o
6060
To split texts into chunks, we choose a straightforward method, which chunks the texts into strings of 256 tokens.
6161
Both the traditional and context-sensitive tests used the [jina-embeddings-v2-small-en](https://huggingface.co/jinaai/jina-embeddings-v2-small-en) model.
6262

63-
| Dataset | Traditional Chunking (nDCG@10) | Context-Sensitive Chunking (nDCG@10) |
64-
|-----------|--------------------------------|--------------------------------------|
65-
| SciFact | 64.20% | 66.10% |
66-
| TRECCOVID | TODO | TODO |
67-
68-
In (all|most|some) cases, context-sensitive chunking improved the score.
63+
| Dataset | AVG Document Length (characters) | Traditional Chunking (nDCG@10) | Context-Sensitive Chunking (nDCG@10) | No Chunking |
64+
|-----------|----------------------------------|--------------------------------|--------------------------------------|-------------|
65+
| SciFact | 1498.4 | 64.20% | **66.10%** | 63.89% |
66+
| TRECCOVID | 1116.7 | 63.36% | 64.70% | **65.18%** |
67+
| FiQA2018 | 767.2 | 33.25% | **33.84%** | 33.43% |
68+
| NFCorpus | 1589.8 | 23.46% | 29.98% | **30.40%** |
69+
| Quora | 62.2 | 87.19% | 87.19% | 87.19% |
70+
71+
In almost all cases, context-sensitive chunking improved the score (for Quora, whose documents are very short, the scores are identical). In some cases, it also outperforms encoding the whole document into a single embedding, while for other datasets, no chunking performs best. However, this only makes sense if one does not need to rank chunks. One can also see that the average length of the documents correlates with greater improvement in the nDCG scores through context-sensitive chunking.
6972

chunked_pooling/__init__.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,20 @@ def chunk_by_sentences(input_text: str, tokenizer: callable):
2929
return chunks, span_annotations
3030

3131

32-
def chunked_pooling(model_output: 'BatchEncoding', span_annotation: list):
32+
def chunked_pooling(
33+
model_output: 'BatchEncoding', span_annotation: list, max_length=None
34+
):
3335
token_embeddings = model_output[0]
3436
outputs = []
3537
for embeddings, annotations in zip(token_embeddings, span_annotation):
36-
if annotations[-1][1] > len(embeddings):
37-
raise RuntimeError(
38-
f'Not enough token embeddings {len(token_embeddings)} for your annotations {annotations}'
39-
)
38+
if (
39+
max_length is not None
40+
):  # remove annotations which go beyond the max-length of the model
41+
annotations = [
42+
(start, min(end, max_length - 1))
43+
for (start, end) in annotations
44+
if start < (max_length - 1)
45+
]
4046
pooled_embeddings = [
4147
embeddings[start:end].sum(dim=0) / (end - start)
4248
for start, end in annotations

chunked_pooling/chunked_eval_tasks.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,108 @@ def __init__(self, **kwargs):
7878
super().__init__(**kwargs)
7979

8080

81+
class NFCorpusChunked(AbsTaskChunkedRetrieval):
82+
metadata = TaskMetadata(
83+
name="NFCorpusChunked",
84+
dataset={
85+
"path": "mteb/nfcorpus",
86+
"revision": "ec0fa4fe99da2ff19ca1214b7966684033a58814",
87+
'name': 'NFCorpus',
88+
},
89+
description="NFCorpus: A Full-Text Learning to Rank Dataset for Medical Information Retrieval",
90+
reference="https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/",
91+
type="Retrieval",
92+
category="s2p",
93+
eval_splits=["test"],
94+
eval_langs=["eng-Latn"],
95+
main_score="ndcg_at_10",
96+
date=None,
97+
form=None,
98+
domains=None,
99+
task_subtypes=None,
100+
license=None,
101+
socioeconomic_status=None,
102+
annotations_creators=None,
103+
dialect=None,
104+
text_creation=None,
105+
bibtex_citation=None,
106+
n_samples=None,
107+
avg_character_length=None,
108+
)
109+
110+
def __init__(self, **kwargs):
111+
super().__init__(**kwargs)
112+
113+
114+
class QuoraChunked(AbsTaskChunkedRetrieval):
115+
metadata = TaskMetadata(
116+
name="QuoraChunked",
117+
dataset={
118+
"path": "mteb/quora",
119+
"revision": "e4e08e0b7dbe3c8700f0daef558ff32256715259",
120+
"name": "QuoraRetrieval",
121+
},
122+
description=(
123+
"QuoraRetrieval is based on questions that are marked as duplicates on the Quora platform. Given a"
124+
" question, find other (duplicate) questions."
125+
),
126+
reference="https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs",
127+
type="Retrieval",
128+
category="s2s",
129+
eval_splits=["dev", "test"],
130+
eval_langs=["eng-Latn"],
131+
main_score="ndcg_at_10",
132+
date=None,
133+
form=None,
134+
domains=None,
135+
task_subtypes=None,
136+
license=None,
137+
socioeconomic_status=None,
138+
annotations_creators=None,
139+
dialect=None,
140+
text_creation=None,
141+
bibtex_citation=None,
142+
n_samples=None,
143+
avg_character_length=None,
144+
)
145+
146+
def __init__(self, **kwargs):
147+
super().__init__(**kwargs)
148+
149+
150+
class FiQA2018Chunked(AbsTaskChunkedRetrieval):
151+
metadata = TaskMetadata(
152+
name="FiQA2018Chunked",
153+
description="Financial Opinion Mining and Question Answering",
154+
reference="https://sites.google.com/view/fiqa/",
155+
dataset={
156+
"path": "mteb/fiqa",
157+
"revision": "27a168819829fe9bcd655c2df245fb19452e8e06",
158+
'name': 'FiQA2018',
159+
},
160+
type="Retrieval",
161+
category="s2p",
162+
eval_splits=["train", "dev", "test"],
163+
eval_langs=["eng-Latn"],
164+
main_score="ndcg_at_10",
165+
date=None,
166+
form=None,
167+
domains=None,
168+
task_subtypes=None,
169+
license=None,
170+
socioeconomic_status=None,
171+
annotations_creators=None,
172+
dialect=None,
173+
text_creation=None,
174+
bibtex_citation=None,
175+
n_samples=None,
176+
avg_character_length=None,
177+
)
178+
179+
def __init__(self, **kwargs):
180+
super().__init__(**kwargs)
181+
182+
81183
class TRECCOVIDChunked(AbsTaskChunkedRetrieval):
82184
metadata = TaskMetadata(
83185
name='TRECCOVIDChunked',

chunked_pooling/mteb_chunked_eval.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from typing import Any, Optional
33

44
import numpy as np
5+
import torch
56
from mteb.abstasks import AbsTask
67
from mteb.evaluation.evaluators import RetrievalEvaluator
78
from mteb.tasks import Retrieval
8-
99
from tqdm import tqdm
1010

1111
from chunked_pooling import chunked_pooling
@@ -78,7 +78,7 @@ def evaluate(self, model, split='test', **kwargs):
7878
return scores
7979

8080
def _evaluate_monolingual(
81-
self, model, corpus, queries, relevant_docs, lang=None, batch_size=8, **kwargs
81+
self, model, corpus, queries, relevant_docs, lang=None, batch_size=1, **kwargs
8282
):
8383
# split corpus into chunks
8484
if not self.chunked_pooling_enabled:
@@ -118,24 +118,32 @@ def _evaluate_monolingual(
118118
]
119119

120120
corpus_embs = []
121-
for inputs in tqdm(
122-
self._batch_inputs(
123-
list(zip(corpus_texts, chunk_annotations)), batch_size=batch_size
124-
),
125-
total=(len(corpus_texts) // batch_size),
126-
):
127-
text_inputs = [x[0] for x in inputs]
128-
annotations = [x[1] for x in inputs]
129-
model_inputs = self.tokenizer(
130-
text_inputs, return_tensors='pt', padding=True
131-
)
132-
if model.device.type == 'cuda':
133-
model_inputs = {
134-
k: v.to(model.device) for k, v in model_inputs.items()
135-
}
136-
model_outputs = model(**model_inputs)
137-
138-
corpus_embs.extend(chunked_pooling(model_outputs, annotations))
121+
with torch.no_grad():
122+
for inputs in tqdm(
123+
self._batch_inputs(
124+
list(zip(corpus_texts, chunk_annotations)),
125+
batch_size=batch_size,
126+
),
127+
total=(len(corpus_texts) // batch_size),
128+
):
129+
text_inputs = [x[0] for x in inputs]
130+
annotations = [x[1] for x in inputs]
131+
model_inputs = self.tokenizer(
132+
text_inputs,
133+
return_tensors='pt',
134+
padding=True,
135+
truncation=True,
136+
max_length=8192,
137+
)
138+
if model.device.type == 'cuda':
139+
model_inputs = {
140+
k: v.to(model.device) for k, v in model_inputs.items()
141+
}
142+
model_outputs = model(**model_inputs)
143+
output_embs = chunked_pooling(
144+
model_outputs, annotations, max_length=8192
145+
)
146+
corpus_embs.extend(output_embs)
139147

140148
max_chunks = max([len(x) for x in corpus_embs])
141149
k_values = self._calculate_k_values(max_chunks)

0 commit comments

Comments
 (0)