-
Notifications
You must be signed in to change notification settings - Fork 644
Expand file tree
/
Copy pathinternvl_chat.py
More file actions
670 lines (587 loc) · 27.6 KB
/
internvl_chat.py
File metadata and controls
670 lines (587 loc) · 27.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
import torch
from transformers import AutoTokenizer, AutoConfig, AutoModel, CLIPImageProcessor
import warnings
from PIL import Image
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE, DATASET_MODALITY
import pandas as pd
import string
import torch.distributed as dist
import torchvision.transforms as T
import transformers
from torchvision.transforms.functional import InterpolationMode
import re
# Per-channel RGB mean / standard deviation passed to T.Normalize in build_transform
# (applied after T.ToTensor).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
    """Return the preprocessing pipeline applied to every square image tile.

    The pipeline forces RGB mode, resizes to ``input_size x input_size`` with bicubic
    interpolation, converts to a tensor and normalises with ``IMAGENET_MEAN`` /
    ``IMAGENET_STD``.
    """
    steps = [
        # Images already in RGB mode pass through without a conversion copy.
        T.Lambda(lambda pic: pic if pic.mode == 'RGB' else pic.convert('RGB')),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tile grid ``(cols, rows)`` whose shape best matches the image.

    Args:
        aspect_ratio: width / height of the source image.
        target_ratios: candidate ``(cols, rows)`` grids, examined in the given order.
        width: source image width in pixels.
        height: source image height in pixels.
        image_size: side length of one square tile in pixels.

    Returns:
        The candidate minimising ``|aspect_ratio - cols / rows|`` (``(1, 1)`` if there are
        no candidates). On an exact tie, a later candidate replaces the current pick only
        when the image area exceeds half of that candidate's total tile area.
    """
    pixel_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for grid in target_ratios:
        cols, rows = grid
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, chosen = gap, grid
        elif gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * cols * rows:
            # Tie: take the later grid only if the image has enough pixels to fill it.
            chosen = grid
    return chosen
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Cut ``image`` into square ``image_size`` tiles on a grid matched to its aspect ratio.

    Candidate grids are all ``(cols, rows)`` with ``min_num <= cols * rows <= max_num``,
    ordered by tile count; ``find_closest_aspect_ratio`` selects one. The image is resized
    to exactly cover that grid and cropped tile by tile in row-major order.

    Args:
        image: PIL-style image (``.size``, ``.resize`` and ``.crop`` are used).
        min_num: minimum number of grid tiles (inclusive).
        max_num: maximum number of grid tiles (inclusive).
        image_size: tile side length in pixels.
        use_thumbnail: when true and more than one tile was produced, append the whole
            image resized to ``image_size x image_size`` as a final tile.

    Returns:
        list of tile images.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Every grid whose tile count lies within [min_num, max_num].
    candidates = {
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    }
    candidates = sorted(candidates, key=lambda grid: grid[0] * grid[1])

    cols, rows = find_closest_aspect_ratio(
        aspect_ratio, candidates, orig_width, orig_height, image_size)

    resized_img = image.resize((image_size * cols, image_size * rows))

    tiles = []
    for row in range(rows):
        for col in range(cols):
            left, top = col * image_size, row * image_size
            tiles.append(resized_img.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == cols * rows

    if use_thumbnail and len(tiles) != 1:
        # Downscaled view of the whole original image, placed after the grid tiles.
        tiles.append(image.resize((image_size, image_size)))
    return tiles
def load_image(image_file, input_size=448, max_num=6, upscale=False):
    """Load an image and return its dynamic-resolution tiles as one stacked tensor.

    Args:
        image_file: path (or file object) accepted by ``PIL.Image.open``.
        input_size: tile side length in pixels.
        max_num: maximum number of grid tiles; a thumbnail tile is appended when more
            than one grid tile is produced.
        upscale: if true, double width and height (bilinear) before tiling.

    Returns:
        ``torch.stack`` of the per-tile tensors produced by ``build_transform``.
    """
    # Open in a context manager so the file handle is released as soon as the pixels are
    # decoded (convert() loads the data), instead of lingering until garbage collection.
    with Image.open(image_file) as raw:
        image = raw.convert('RGB')
    if upscale:
        image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR)
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])
# Build a layer-wise device_map for the InternVL2 checkpoints listed in num_layers_map
# (8B, 26B, 40B and Llama3-76B), spreading LLM layers over the GPUs of this rank.
def split_model(model_name):
    """Return a ``device_map`` for ``model_name``, or ``'cuda'`` if it is not a known split model.

    GPUs are shared among ranks by interleaving: this rank uses devices
    ``rank, rank + world_size, rank + 2 * world_size, ...``. The first of them also hosts
    the vision tower, ``mlp1``, embeddings, final norm, output head and the last decoder
    layer, so it is given only about half of a regular share of decoder layers.

    Args:
        model_name: bare checkpoint name such as ``'InternVL2-26B'``.

    Returns:
        ``'cuda'`` for unknown names, otherwise a dict mapping module-name prefixes to
        CUDA device indices.
    """
    import math
    device_map = {}
    num_gpus = torch.cuda.device_count()
    rank, world_size = get_rank_and_world_size()
    num_gpus = num_gpus // world_size
    num_layers_map = {
        'InternVL2-8B': 32,
        'InternVL2-26B': 48,
        'InternVL2-40B': 60,
        'InternVL2-Llama3-76B': 80
    }
    if model_name not in num_layers_map:
        return 'cuda'
    num_layers = num_layers_map[model_name]
    # Since the first GPU will be used for ViT, treat it as 0.5 GPU.
    num_layers_per_gpu = math.ceil(num_layers / (num_gpus - 0.5))
    num_layers_per_gpu = [num_layers_per_gpu] * num_gpus
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        # The ceil() rounding above can allocate more slots than there are layers; stop at
        # the real layer count so no entries are emitted for non-existent layers.
        for _ in range(min(num_layer, num_layers - layer_cnt)):
            device_map[f'language_model.model.layers.{layer_cnt}'] = rank + world_size * i
            layer_cnt += 1
    device_map['vision_model'] = rank
    device_map['mlp1'] = rank
    # Embedding / head module names differ between LLM backbones; all variants are listed.
    device_map['language_model.model.tok_embeddings'] = rank
    device_map['language_model.model.embed_tokens'] = rank
    device_map['language_model.output'] = rank
    device_map['language_model.model.norm'] = rank
    device_map['language_model.lm_head'] = rank
    device_map['language_model.model.rotary_emb'] = rank
    # The last decoder layer is kept on the same device as the output head and norm.
    device_map[f'language_model.model.layers.{num_layers - 1}'] = rank
    return device_map
def load_image_mmniah(image_file, dynamic_image_size=True, input_size=448, max_num=6):
    """Load an image and return it as a stacked tensor of tiles.

    Args:
        image_file: path (or file object) accepted by ``PIL.Image.open``.
        dynamic_image_size: if true, tile with ``dynamic_preprocess`` (plus thumbnail);
            otherwise the whole image is used as a single tile.
        input_size: tile side length in pixels.
        max_num: maximum number of grid tiles when tiling.

    Returns:
        ``torch.stack`` of the per-tile tensors produced by ``build_transform``.
    """
    # Context manager releases the file handle once the pixels are decoded.
    with Image.open(image_file) as raw:
        image = raw.convert('RGB')
    transform = build_transform(input_size=input_size)
    if dynamic_image_size:
        tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    else:
        tiles = [image]
    return torch.stack([transform(tile) for tile in tiles])
def split_model_mmniah(model_path):
    """Build a ``device_map`` spreading an InternVL model over the local GPUs.

    The vision encoder layers, vision embeddings and ``mlp1`` go to the first visible
    device. The language-model decoder layers are spread evenly over the remaining
    visible devices, with the token embeddings on the first LLM device and the final
    norm / output head on the last visible device. With a single visible GPU everything
    is placed on it.

    Args:
        model_path: checkpoint path or hub id; its config must provide
            ``llm_config.num_hidden_layers`` and ``vision_config.num_hidden_layers``.

    Returns:
        dict mapping module-name prefixes to CUDA device indices.

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    num_gpus_per_rank = 8
    num_gpus = torch.cuda.device_count()
    if num_gpus < 1:
        raise RuntimeError('split_model_mmniah needs at least one visible CUDA device.')
    local_rank = 0
    # Devices of this rank are strided by the number of local ranks. With fewer than
    # num_gpus_per_rank GPUs the division yields 0, an invalid range() step, so clamp to 1
    # and let the single rank use every local GPU.
    local_world_size = max(1, num_gpus // num_gpus_per_rank)
    visible_devices = list(range(local_rank, num_gpus, local_world_size))
    last_idx = len(visible_devices) - 1

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    num_gpus_for_vit = 1
    # Keep at least one LLM device even when the ViT device is the only GPU
    # (otherwise the layer split below divides by zero).
    num_gpus_for_llm = max(1, len(visible_devices) - num_gpus_for_vit)
    # Index of the first device holding LLM modules, clamped for the single-GPU case.
    llm_first_idx = min(num_gpus_for_vit, last_idx)

    device_map = {}
    num_layers = config.llm_config.num_hidden_layers
    num_layers_per_gpu = num_layers // num_gpus_for_llm + 1
    for i in range(num_layers):
        device_idx = min(i // num_layers_per_gpu + num_gpus_for_vit, last_idx)
        device_map[f'language_model.model.layers.{i}'] = visible_devices[device_idx]

    num_layers = config.vision_config.num_hidden_layers
    num_layers_per_gpu = num_layers // num_gpus_for_vit + 1
    for i in range(num_layers):
        device_idx = min(i // num_layers_per_gpu, num_gpus_for_vit - 1)
        device_map[f'vision_model.encoder.layers.{i}'] = visible_devices[device_idx]

    device_map['vision_model.embeddings'] = visible_devices[0]
    device_map['mlp1'] = visible_devices[num_gpus_for_vit - 1]
    # Module names differ between LLM backbones, so entries for both are emitted.
    # InternLM2 backbone:
    device_map['language_model.model.tok_embeddings'] = visible_devices[llm_first_idx]
    device_map['language_model.model.norm'] = visible_devices[-1]
    device_map['language_model.output'] = visible_devices[-1]
    # Qwen2 backbone (its final norm uses the same name as above):
    device_map['language_model.model.embed_tokens'] = visible_devices[llm_first_idx]
    device_map['language_model.lm_head'] = visible_devices[-1]
    return device_map
def extract_answer(text):
    """Return the answer stated after a ``Final answer:`` / ``Answer:`` marker.

    Markers are matched case-insensitively and the text captured runs to the end of that
    line. ``Final answer:`` (the format requested by the CoT prompt) takes precedence over
    a plain ``Answer:``. For each marker the *last* occurrence wins, so reasoning text that
    merely contains "answer:" before the conclusion is not mistaken for the answer.

    Args:
        text: model response.

    Returns:
        The stripped answer, or ``text`` unchanged when no marker is present.
    """
    for pattern in (r'Final answer:\s*(.*)', r'Answer:\s*(.*)'):
        matches = re.findall(pattern, text, re.IGNORECASE)
        if matches:
            return matches[-1].strip()
    return text
class InternVLChat(BaseModel):
    """VLMEvalKit model wrapper for OpenGVLab InternVL chat checkpoints.

    ``version`` selects the inference path used by ``generate_inner``:
    'V1.1' / 'V1.2' (one image resized to ``image_size`` through ``CLIPImageProcessor``),
    'V1.5', 'V2.0' (interleaved multi-image prompts, optional chain-of-thought) and
    'mmniah' (untiled images, tokenizer ``model_max_length`` raised to 256000).
    Only 'V2.0' supports multi-turn chat through ``chat_inner``.
    """

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self,
                 model_path='OpenGVLab/InternVL-Chat-V1-5',
                 load_in_8bit=False,
                 cot_prompt=False,
                 version='V1.0',
                 **kwargs):
        """Load tokenizer and model and set up the default generation config.

        Args:
            model_path: Hugging Face hub id or local checkpoint path.
            load_in_8bit: forwarded to ``AutoModel.from_pretrained`` (not used on the
                'mmniah' loading path).
            cot_prompt: wrap questions in a step-by-step reasoning prompt
                (``build_prompt``) and extract the final answer (``generate_v2``).
            version: inference path, see the class docstring. The default 'V1.0' is
                not accepted by ``generate_inner``.
            **kwargs: generation-config overrides merged over the defaults.
        """
        assert model_path is not None
        assert version_cmp(transformers.__version__, '4.36.2', 'ge')
        self.cot_prompt = cot_prompt
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)

        # Dataset prompts name images 'Image1', 'Image2', ...; InternVL uses 'Image-1',
        # so questions are rewritten with pattern -> replacement ...
        self.pattern = r'Image(\d+)'
        self.replacement = r'Image-\1'
        # ... and responses are mapped back to the dataset format (Image-1 -> Image1).
        self.reverse_pattern = r'Image-(\d+)'
        self.reverse_replacement = r'Image\1'

        self.device = 'cuda'

        # Large InternVL2 checkpoints are split layer-wise over several GPUs: the 76B model
        # always, the 8B/26B/40B models only when auto-splitting is enabled.
        split_by_layers = (
            (auto_split_flag() and listinstr(['InternVL2-8B', 'InternVL2-26B', 'InternVL2-40B'], model_path))
            or listinstr(['InternVL2-Llama3-76B'], model_path)
        )
        if split_by_layers:
            device_map = split_model(model_path.split('/')[-1])
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                load_in_8bit=load_in_8bit,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                device_map=device_map).eval()
        elif listinstr(['InternVL-Chat-V1-5'], model_path) and version == 'mmniah':
            device_map = split_model_mmniah(model_path)
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                use_flash_attn=True,
                trust_remote_code=True,
                device_map=device_map).eval()
        else:
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                load_in_8bit=load_in_8bit,
                trust_remote_code=True).eval()
            # Only this single-device path is moved explicitly: the device_map paths above
            # are already placed, and the 8-bit path is left where the loader put it.
            if not load_in_8bit:
                self.model = self.model.to('cuda')

        self.image_size = self.model.config.vision_config.image_size
        self.version = version
        kwargs_default = dict(do_sample=False, max_new_tokens=2048, top_p=None, num_beams=1)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

    def use_custom_prompt(self, dataset):
        """Return whether ``build_prompt`` should be used for ``dataset``.

        Multi-turn / long-context datasets listed below and all video datasets use the
        default prompt instead.
        """
        assert dataset is not None
        if listinstr(['MMDU', 'MME-RealWorld', 'MME-RealWorld-CN', "NIAH"], dataset):
            # For Multi-Turn we don't have custom prompt
            return False
        if DATASET_MODALITY(dataset) == 'VIDEO':
            # For Video benchmarks we don't have custom prompt at here
            return False
        else:
            return True

    def build_multi_choice_prompt(self, line, dataset=None):
        """Build a multiple-choice prompt: optional hint, question, lettered options, instruction.

        The closing instruction is in Chinese when ``cn_string`` detects Chinese text.
        """
        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            prompt += '\n请直接回答选项字母。' if cn_string(
                prompt) else "\nAnswer with the option's letter from the given choices directly."
        else:
            prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'

        return prompt

    def build_video_prompt(self, prompt, dataset=None, max_frames=64):
        """Adapt an interleaved image prompt to video frames.

        Removes runs of eight consecutive ``<Image-k>`` placeholders, renames
        ``Image-k`` to ``Frame-k`` and applies dataset-specific answer-format tweaks.
        """
        for start in range(0, max_frames, 8):
            images_to_remove = ''.join([f'<Image-{i}>' for i in range(start + 1, start + 9)])
            prompt = prompt.replace(images_to_remove, '')
        for i in range(max_frames):
            prompt = prompt.replace(f'Image-{i + 1}', f'Frame-{i + 1}')
        if listinstr(['MMBench-Video'], dataset):
            prompt = prompt.replace('\nAnswer:', '')
        elif listinstr(['Video-MME'], dataset):
            prompt = prompt.replace('\nAnswer:', '')
            prompt += "\nAnswer with the option's letter from the given choices directly."
        elif listinstr(['MVBench'], dataset):
            prompt = prompt.replace('Best option:(', '')

        return prompt

    def build_prompt(self, line, dataset=None):
        """Build the message list (text first, then images) for one dataset record."""
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        # NOTE(review): this replaces the generation kwargs configured in __init__
        # (including user overrides) with fixed defaults on every call — confirm intended.
        if self.version == 'V1.1':
            kwargs_default = dict(do_sample=False, max_new_tokens=1024, top_p=None, num_beams=5)
        else:
            kwargs_default = dict(do_sample=False, max_new_tokens=1024, top_p=None, num_beams=1)
        self.kwargs = kwargs_default

        if dataset is not None and DATASET_TYPE(dataset) == 'Y/N':
            question = line['question']
            if listinstr(['MME'], dataset):
                prompt = question + ' Answer the question using a single word or phrase.'
            elif listinstr(['HallusionBench'], dataset):
                prompt = question + ' Please answer yes or no. Answer the question using a single word or phrase.'
            else:
                prompt = question
        elif dataset is not None and DATASET_TYPE(dataset) == 'MCQ':
            prompt = self.build_multi_choice_prompt(line, dataset)
        elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
            question = line['question']
            if listinstr(['MathVista', 'MathVision', 'VCR', 'MTVQA', 'MMVet', 'MathVerse'], dataset):
                prompt = question
            elif listinstr(['LLaVABench'], dataset):
                prompt = question + '\nAnswer this question in detail.'
            else:
                prompt = question + '\nAnswer the question using a single word or phrase.'
        else:
            prompt = line['question']

        if self.cot_prompt and not listinstr(['LLaVABench'], dataset):
            cot_prompt_with_final_answer = (
                "Your task is to answer the question below. "
                "Give step by step reasoning before you answer, and when you're ready to answer, "
                "please use the format \"Final answer: ..\""
                "\n\n"
                "Question:"
                "\n\n"
                "{question}"
            )
            cot_prompt_wo_final_answer = (
                "Your task is to answer the question below. "
                "Give step by step reasoning. "
                "\n\n"
                "Question:"
                "\n\n"
                "{question}"
            )

            if listinstr(['MMVet'], dataset):
                cot_prompt = cot_prompt_wo_final_answer
            else:
                cot_prompt = cot_prompt_with_final_answer

            # The CoT template is filled with the raw question (plus options), not with the
            # dataset-specific prompt built above.
            question_orig = line['question']
            if listinstr(['MathVerse', 'MathVision'], dataset):
                question_orig = question_orig.split('Question:', 1)[-1].strip()
                question_orig = question_orig.replace('Choices:\n', '').strip()

            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = ''
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'

            if options_prompt.strip():
                question_orig = f'{question_orig}\n{options_prompt}'

            prompt = cot_prompt.format(question=question_orig)

        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])
        return message

    def set_max_num(self, dataset):
        """Set ``self.max_num`` (max image tiles) from the dataset name: 1 for video, else 6–24."""
        if dataset is None:
            self.max_num = 6
            return None
        # res_1_datasets = ['MMBench-Video', 'Video-MME', 'MVBench', 'Video']
        # listinstr matches substrings, so 'MME-RealWorld' also covers 'MME-RealWorld-CN'.
        res_12_datasets = ['ChartQA_TEST', 'MMMU_DEV_VAL', 'MMMU_TEST', 'MME-RealWorld',
                           'MME-RealWorld-CN', 'VCR_EN', 'VCR_ZH']
        res_18_datasets = ['DocVQA_VAL', 'DocVQA_TEST']
        res_24_datasets = ['InfoVQA_VAL', 'InfoVQA_TEST', 'OCRBench', 'HRBench4K', 'HRBench8K']
        if DATASET_MODALITY(dataset) == 'VIDEO':
            self.max_num = 1
        elif listinstr(res_12_datasets, dataset):
            self.max_num = 12
        elif listinstr(res_18_datasets, dataset):
            self.max_num = 18
        elif listinstr(res_24_datasets, dataset):
            self.max_num = 24
        else:
            self.max_num = 6

    def generate_v1_2(self, message, dataset=None):
        """Answer with the V1.1/V1.2 pipeline: a single image resized to ``self.image_size``."""
        self.INTERLEAVE = False
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
        # Context manager releases the file handle once the pixels are decoded.
        with Image.open(image_path) as raw:
            image = raw.convert('RGB')
        image = image.resize((self.image_size, self.image_size))
        image_processor = CLIPImageProcessor.from_pretrained(self.model_path)
        pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
        pixel_values = pixel_values.to(torch.bfloat16).to(self.device)
        with torch.no_grad():
            response = self.model.chat(self.tokenizer, pixel_values=pixel_values,
                                       question=prompt, generation_config=self.kwargs)
        return response

    def generate_v1_5(self, message, dataset=None):
        """Answer with the V1.5 pipeline: tiles of all images concatenated into one batch."""
        image_num = len([x for x in message if x['type'] == 'image'])
        prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])

        # Guard on None as generate_v2 does; there is no dataset to look up otherwise.
        if dataset is not None and DATASET_MODALITY(dataset) == 'VIDEO':
            prompt = self.build_video_prompt(prompt, dataset)

        if image_num > 1:
            image_path = [x['value'] for x in message if x['type'] == 'image']
            pixel_values_list = []
            for file_name in image_path:
                pixel_values_list.append(load_image(file_name, max_num=self.max_num).to(self.device).to(torch.bfloat16))
            pixel_values = torch.cat(pixel_values_list, dim=0)
        elif image_num == 1:
            image_path = [x['value'] for x in message if x['type'] == 'image'][0]
            pixel_values = load_image(image_path, max_num=self.max_num).to(self.device).to(torch.bfloat16)
        else:
            pixel_values = None
        with torch.no_grad():
            response = self.model.chat(
                self.tokenizer,
                pixel_values=pixel_values,
                question=prompt,
                generation_config=self.kwargs,
                verbose=False)
        return response

    def generate_mmniah(self, message, dataset=None):
        """Answer with the 'mmniah' pipeline: untiled images, long context, 32 new tokens max."""
        self.tokenizer.model_max_length = 256000
        image_num = len([x for x in message if x['type'] == 'image'])
        prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])

        num_patches_list = []
        if image_num > 1:
            image_path = [x['value'] for x in message if x['type'] == 'image']
            pixel_values_list = []
            for file_name in image_path:
                curr_pixel_values = load_image_mmniah(file_name, max_num=6, dynamic_image_size=False)
                curr_pixel_values = curr_pixel_values.to(self.device).to(torch.bfloat16)
                pixel_values_list.append(curr_pixel_values)
                num_patches_list.append(len(curr_pixel_values))
            pixel_values = torch.cat(pixel_values_list, dim=0)
        elif image_num == 1:
            image_path = [x['value'] for x in message if x['type'] == 'image'][0]
            pixel_values = load_image_mmniah(image_path, max_num=6, dynamic_image_size=False)
            pixel_values = pixel_values.to(self.device).to(torch.bfloat16)
            num_patches_list.append(len(pixel_values))
        else:
            pixel_values = None

        with torch.no_grad():
            response = self.model.chat(
                self.tokenizer,
                pixel_values=pixel_values,
                question=prompt,
                generation_config=dict(
                    do_sample=False,
                    num_beams=1,
                    max_new_tokens=32,
                ),
                num_patches_list=num_patches_list,
                history=None,
                return_history=False,
                verbose=False)
        return response

    def generate_v2(self, message, dataset=None):
        """Answer with the V2.0 pipeline (interleaved images, per-image tile counts, optional CoT)."""
        image_num = len([x for x in message if x['type'] == 'image'])
        if image_num == 1:
            prompt = '<image>\n' + '\n'.join([x['value'] for x in message if x['type'] == 'text'])
        else:
            # Each image is referenced inline as <Image-k>; a header line per image binds
            # 'Image-k' to its <image> placeholder.
            prompt, image_idx = '', 1
            for x in message:
                if x['type'] == 'text':
                    prompt += x['value']
                elif x['type'] == 'image':
                    prompt += f'<Image-{image_idx}>'
                    image_idx += 1
            prompt = '\n'.join([f'Image-{i + 1}: <image>' for i in range(image_num)]) + '\n' + prompt

        if dataset is not None and DATASET_MODALITY(dataset) == 'VIDEO':
            prompt = self.build_video_prompt(prompt, dataset)

        if image_num > 1:
            image_path = [x['value'] for x in message if x['type'] == 'image']
            num_patches_list = []
            pixel_values_list = []
            for image_idx, file_name in enumerate(image_path):
                upscale_flag = image_idx == 0 and dataset is not None and listinstr(['MMMU_DEV_VAL'], dataset)
                curr_pixel_values = load_image(
                    file_name, max_num=self.max_num, upscale=upscale_flag).to(self.device).to(torch.bfloat16)
                num_patches_list.append(curr_pixel_values.size(0))
                pixel_values_list.append(curr_pixel_values)
            pixel_values = torch.cat(pixel_values_list, dim=0)
        elif image_num == 1:
            image_path = [x['value'] for x in message if x['type'] == 'image'][0]
            upscale_flag = dataset is not None and listinstr(['MMMU_DEV_VAL'], dataset)
            pixel_values = load_image(
                image_path, max_num=self.max_num, upscale=upscale_flag).to(self.device).to(torch.bfloat16)
            num_patches_list = [pixel_values.size(0)]
        else:
            pixel_values = None
            num_patches_list = []

        with torch.no_grad():
            response = self.model.chat(
                self.tokenizer,
                pixel_values=pixel_values,
                num_patches_list=num_patches_list,
                question=prompt,
                generation_config=self.kwargs,
                verbose=False
            )

        if (
            self.cot_prompt
            and dataset is not None
            and (
                DATASET_TYPE(dataset) in ['Y/N', 'MCQ']
                or listinstr(['CRPE'], dataset)
            )
        ):
            response = extract_answer(response).strip()

        return response

    def generate_inner(self, message, dataset=None):
        """Dispatch single-turn generation to the pipeline selected by ``self.version``.

        Raises:
            ValueError: for versions without a generation pipeline.
        """
        self.set_max_num(dataset)
        print(f'InternVL model version: {self.version}')
        if self.version in ['V1.1', 'V1.2']:
            return self.generate_v1_2(message, dataset)
        elif self.version == 'V1.5':
            return self.generate_v1_5(message, dataset)
        elif self.version == 'V2.0':
            return self.generate_v2(message, dataset)
        elif self.version == 'mmniah':
            return self.generate_mmniah(message, dataset)
        else:
            raise ValueError(f'Unsupported version: {self.version}')

    def build_history(self, message):
        """Convert completed (user, assistant) turns into InternVL chat history.

        Returns:
            ``(history, image_path, image_cnt)``: list of ``(user_text, assistant_text)``
            pairs, the image paths in order of appearance, and the number of images.
        """
        # Accumulated across all turns by the nested helper.
        image_path = []
        image_cnt = 0

        def concat_tilist(tilist):
            nonlocal image_cnt  # Declare image_cnt as nonlocal to modify it
            prompt = ''
            for item in tilist:
                # Substitute the pattern in the text
                if item['type'] == 'text':
                    prompt += re.sub(self.pattern, self.replacement, item['value'])
                elif item['type'] == 'image':
                    image_cnt += 1
                    prompt += '<image>\n'
                    image_path.append(item['value'])
            return prompt

        # Only previous messages
        assert len(message) % 2 == 0
        history = []
        for i in range(len(message) // 2):
            m1, m2 = message[2 * i], message[2 * i + 1]
            assert m1['role'] == 'user' and m2['role'] == 'assistant'
            history.append((concat_tilist(m1['content']), concat_tilist(m2['content'])))

        return history, image_path, image_cnt

    def chat_inner_v2(self, message, dataset=None):
        """Run one V2.0 multi-turn step: earlier turns become history, the last one is asked.

        Images from all turns are loaded in order; their tile counts are passed as
        ``num_patches_list`` so each ``<image>`` placeholder maps to its own tiles.
        """
        if len(message) > 1:
            history, image_path, _ = self.build_history(message[:-1])
        else:
            history, image_path = None, []

        current_msg = message[-1]
        question = ''
        for msg in current_msg['content']:
            if msg['type'] == 'text':
                # Fix pattern as per InternVL (Image1 -> Image-1)
                question += re.sub(self.pattern, self.replacement, msg['value'])
            elif msg['type'] == 'image':
                question += '<image>\n'
                image_path.append(msg['value'])

        # Branch on the collected paths: a text-only conversation has no pixels, and a
        # single image must still be passed to load_image as one path, never as a list.
        if image_path:
            num_patches_list = []
            pixel_values_list = []
            for image_idx, file_name in enumerate(image_path):
                upscale_flag = image_idx == 0 and dataset is not None and listinstr(['MMMU_DEV_VAL'], dataset)
                curr_pixel_values = load_image(
                    file_name, max_num=self.max_num, upscale=upscale_flag).to(self.device).to(torch.bfloat16)
                num_patches_list.append(curr_pixel_values.size(0))
                pixel_values_list.append(curr_pixel_values)
            pixel_values = torch.cat(pixel_values_list, dim=0)
        else:
            pixel_values = None
            num_patches_list = []

        with torch.no_grad():
            response, history = self.model.chat(
                self.tokenizer,
                pixel_values=pixel_values,
                num_patches_list=num_patches_list,
                question=question,
                generation_config=self.kwargs,
                history=history,
                return_history=True
            )

        # Map the model's 'Image-k' references back to the dataset's 'Imagek' format.
        response = re.sub(self.reverse_pattern, self.reverse_replacement, response)
        return response

    def chat_inner(self, message, dataset=None):
        """Multi-turn entry point; only version 'V2.0' is supported.

        Raises:
            ValueError: for any other version.
        """
        self.set_max_num(dataset)
        if self.version == 'V2.0':
            self.kwargs = dict(do_sample=False, max_new_tokens=512, top_p=None, num_beams=1)
            return self.chat_inner_v2(message, dataset)
        raise ValueError(f'Unsupported version for Multi-Turn: {self.version}')