# (removed GitHub page-scrape artifacts: navigation text and line-number gutter)
# -*- coding: utf-8 -*-
"""colpali_visual_retrieval_Marktechpost.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1xiNO91MBZA3ylQY8CL4P9347fCYvA9Zu
"""
import subprocess, sys, os, json, hashlib


def pip(cmd):
    """Run `python -m pip` with the supplied argument list, raising on failure."""
    args = [sys.executable, "-m", "pip"]
    args.extend(cmd)
    subprocess.check_call(args)


# Swap out any preinstalled, possibly incompatible wheels for pinned ones,
# then install the retrieval stack used by the rest of the script.
for pip_args in (
    ["uninstall", "-y", "pillow", "PIL", "torchaudio", "colpali-engine"],
    ["install", "-q", "--upgrade", "pip"],
    ["install", "-q", "pillow<12", "torchaudio==2.8.0"],
    ["install", "-q", "colpali-engine", "pypdfium2", "matplotlib", "tqdm", "requests"],
):
    pip(pip_args)
import torch
import requests
import pypdfium2 as pdfium
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from transformers.utils.import_utils import is_flash_attn_2_available
from colpali_engine.models import ColPali, ColPaliProcessor
# Pick the fastest available device; half precision is only safe on GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
MODEL_NAME = "vidore/colpali-v1.3"
# Load the ColPali vision retriever in eval mode. FlashAttention-2 is
# requested only when running on CUDA *and* the kernel package is installed;
# otherwise transformers falls back to its default attention implementation.
model = ColPali.from_pretrained(
    MODEL_NAME,
    torch_dtype=dtype,
    device_map=device,
    attn_implementation="flash_attention_2" if device == "cuda" and is_flash_attn_2_available() else None,
).eval()
# Paired processor: prepares page images and text queries for the model.
processor = ColPaliProcessor.from_pretrained(MODEL_NAME)
# Source document: the ColPali paper itself.
PDF_URL = "https://arxiv.org/pdf/2407.01449.pdf"
# Bounded timeout + explicit status check: without these, a slow or failing
# server would hang the script or hand an HTML error page to the PDF parser.
_resp = requests.get(PDF_URL, timeout=60)
_resp.raise_for_status()
pdf_bytes = _resp.content
pdf = pdfium.PdfDocument(pdf_bytes)
pages = []
MAX_PAGES = 15  # cap pages so embedding stays fast on modest hardware
# Rasterize each page at 2x scale and force RGB, the format the processor expects.
for i in range(min(len(pdf), MAX_PAGES)):
    page = pdf[i]
    img = page.render(scale=2).to_pil().convert("RGB")
    pages.append(img)
# Embed every rendered page. ColPali returns a multi-vector embedding per
# image (one vector per visual token), which is what the late-interaction
# scorer below consumes.
page_embeddings = []
batch_size = 2 if device == "cuda" else 1  # small batches to limit peak memory
for i in tqdm(range(0, len(pages), batch_size)):
    batch_imgs = pages[i:i+batch_size]
    batch = processor.process_images(batch_imgs)
    # Move every tensor in the processor output onto the model's device.
    batch = {k: v.to(model.device) for k, v in batch.items()}
    with torch.no_grad():
        emb = model(**batch)
    # Keep embeddings on CPU so GPU memory is freed between batches.
    page_embeddings.extend(list(emb.cpu()))
# Stack into a single tensor. NOTE(review): assumes every page produces the
# same number of token vectors — confirm against the processor's padding.
page_embeddings = torch.stack(page_embeddings)
def retrieve(query, top_k=3):
    """Embed *query* and return the best-matching page indices.

    Parameters
    ----------
    query : str
        Natural-language search query.
    top_k : int, optional
        Maximum number of results; clamped to the number of indexed
        pages so `torch.topk` cannot raise when fewer pages exist.

    Returns
    -------
    list[tuple[int, float]]
        (page_index, similarity_score) pairs, best match first.
    """
    q = processor.process_queries([query])
    q = {k: v.to(model.device) for k, v in q.items()}
    with torch.no_grad():
        q_emb = model(**q).cpu()
    # Late-interaction (MaxSim) score of the query against every page.
    scores = processor.score_multi_vector(q_emb, page_embeddings)[0]
    # Clamp: asking for more results than indexed pages must not crash.
    k = min(top_k, scores.numel())
    vals, idxs = torch.topk(scores, k)
    return [(int(i), float(v)) for i, v in zip(idxs, vals)]
def show(img, title):
    """Render *img* in a 6x6-inch figure with *title* and no axes."""
    fig = plt.figure(figsize=(6, 6))
    axes = fig.add_subplot(1, 1, 1)
    axes.imshow(img)
    axes.set_axis_off()
    axes.set_title(title)
    plt.show()
# Demo: retrieve the top-3 pages for a sample question and display them
# in rank order (page numbers shown 1-based for human readability).
query = "What is ColPali and what problem does it solve?"
results = retrieve(query, top_k=3)
for rank, (idx, score) in enumerate(results, 1):
    show(pages[idx], f"Rank {rank} — Page {idx+1}")
def search(query, k=5):
    """Wrap `retrieve` into JSON-friendly results.

    Returns a list of ``{"page": <1-based page number>, "score": <float>}``
    dicts, best match first.
    """
    hits = retrieve(query, k)
    results = []
    for page_idx, score in hits:
        results.append({"page": page_idx + 1, "score": score})
    return results


print(json.dumps(search("late interaction retrieval"), indent=2))