Skip to content

Commit 70f81cb

Browse files
authored
feat: add semantic chunking to eval script; add wrapper for minilm (jina-ai#11)
* feat: add semantic chunking to eval script; add wrapper for minilm * fix: gaps in semantic chunking * feat: add option to pass custom model for chunking * refactor: add second model to semantic chunking test
1 parent d5a0fa6 commit 70f81cb

File tree

5 files changed

+54
-12
lines changed

5 files changed

+54
-12
lines changed

chunked_pooling/chunking.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def _setup_semantic_chunking(self, embedding_model_name):
3131
self.embed_model = HuggingFaceEmbedding(
3232
model_name=self.embedding_model_name,
3333
trust_remote_code=True,
34+
embed_batch_size=1,
3435
)
3536
self.splitter = SemanticSplitterNodeParser(
3637
embed_model=self.embed_model,
@@ -71,13 +72,12 @@ def chunk_semantically(
7172
start_chunk_index = bisect.bisect_left(
7273
[offset[0] for offset in token_offsets], char_start
7374
)
74-
end_chunk_index = (
75-
bisect.bisect_right([offset[1] for offset in token_offsets], char_end)
76-
- 1
75+
end_chunk_index = bisect.bisect_right(
76+
[offset[1] for offset in token_offsets], char_end
7777
)
7878

7979
# Add the chunk span if it's within the tokenized text
80-
if start_chunk_index < len(token_offsets) and end_chunk_index < len(
80+
if start_chunk_index < len(token_offsets) and end_chunk_index <= len(
8181
token_offsets
8282
):
8383
chunk_spans.append((start_chunk_index, end_chunk_index))

chunked_pooling/mteb_chunked_eval.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
chunk_size: Optional[int] = None,
2626
n_sentences: Optional[int] = None,
2727
model_has_instructions: bool = False,
28+
embedding_model_name: Optional[str] = None, # for semantic chunking
2829
**kwargs,
2930
):
3031
super().__init__(**kwargs)
@@ -45,6 +46,7 @@ def __init__(
4546
self.chunking_args = {
4647
'chunk_size': chunk_size,
4748
'n_sentences': n_sentences,
49+
'embedding_model_name': embedding_model_name,
4850
}
4951

5052
def load_data(self, **kwargs):

chunked_pooling/wrappers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
import torch.nn as nn
5+
from sentence_transformers import SentenceTransformer
56
from transformers import AutoModel
67

78

@@ -61,7 +62,10 @@ def has_instructions():
6162
return True
6263

6364

64-
MODEL_WRAPPERS = {'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper}
65+
MODEL_WRAPPERS = {
66+
'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper,
67+
'sentence-transformers/all-MiniLM-L6-v2': SentenceTransformer,
68+
}
6569
MODELS_WITHOUT_PROMPT_NAME_ARG = [
6670
'jinaai/jina-embeddings-v2-small-en',
6771
'jinaai/jina-embeddings-v2-base-en',
@@ -82,7 +86,10 @@ def wrapper(self, *args, **kwargs):
8286
def load_model(model_name, **model_kwargs):
8387
if model_name in MODEL_WRAPPERS:
8488
model = MODEL_WRAPPERS[model_name](model_name, **model_kwargs)
85-
has_instructions = MODEL_WRAPPERS[model_name].has_instructions()
89+
if hasattr(MODEL_WRAPPERS[model_name], 'has_instructions'):
90+
has_instructions = MODEL_WRAPPERS[model_name].has_instructions()
91+
else:
92+
has_instructions = False
8693
else:
8794
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
8895
has_instructions = False

run_chunked_eval.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@
2929
@click.option(
3030
'--eval-split', default='test', help='The name of the evaluation split in the task.'
3131
)
32-
def main(model_name, strategy, task_name, eval_split):
32+
@click.option(
33+
'--chunking-model',
34+
default=None,
35+
required=False,
36+
help='The name of the model used for semantic chunking.',
37+
)
38+
def main(model_name, strategy, task_name, eval_split, chunking_model):
3339
try:
3440
task_cls = globals()[task_name]
3541
except:
@@ -44,6 +50,7 @@ def main(model_name, strategy, task_name, eval_split):
4450
'n_sentences': DEFAULT_N_SENTENCES,
4551
'chunking_strategy': strategy,
4652
'model_has_instructions': has_instructions,
53+
'embedding_model_name': chunking_model if chunking_model else model_name,
4754
}
4855

4956
if torch.cuda.is_available():

tests/test_chunking_methods.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,42 @@ def test_chunk_by_tokens():
9898
assert end - start <= 10
9999

100100

101-
def test_chunk_semantically():
101+
@pytest.mark.parametrize(
102+
'model_name',
103+
['jinaai/jina-embeddings-v2-small-en', 'sentence-transformers/all-MiniLM-L6-v2'],
104+
)
105+
def test_chunk_semantically(model_name):
102106
chunker = Chunker(chunking_strategy="semantic")
103-
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
104-
chunks = chunker.chunk(
107+
tokenizer = AutoTokenizer.from_pretrained(model_name)
108+
tokens = tokenizer.encode_plus(
109+
EXAMPLE_TEXT_1, add_special_tokens=False, return_offsets_mapping=True
110+
)
111+
boundary_cues = chunker.chunk(
105112
EXAMPLE_TEXT_1,
106113
tokenizer=tokenizer,
107114
chunking_strategy='semantic',
108-
embedding_model_name='jinaai/jina-embeddings-v2-small-en',
115+
embedding_model_name=model_name,
116+
)
117+
118+
# check if it returns boundary cues
119+
assert len(boundary_cues) > 0
120+
121+
# test if boundaries are at the end of sentences
122+
for start_token_idx, end_token_idx in boundary_cues:
123+
assert (
124+
EXAMPLE_TEXT_1[tokens.offset_mapping[end_token_idx - 1][0]] in PUNCTATIONS
125+
)
126+
decoded_text_chunk = tokenizer.decode(
127+
tokens.input_ids[start_token_idx:end_token_idx]
128+
)
129+
130+
# check that the boundary cues are continuous (no token is missing)
131+
assert all(
132+
[
133+
boundary_cues[i][1] == boundary_cues[i + 1][0]
134+
for i in range(len(boundary_cues) - 1)
135+
]
109136
)
110-
assert len(chunks) > 0
111137

112138

113139
def test_empty_input():

0 commit comments

Comments
 (0)