diff --git a/rag_factory/parser/Parser_Dotsocr/fig_recognize.py b/rag_factory/parser/Parser_Dotsocr/fig_recognize.py new file mode 100644 index 0000000..175529e --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/fig_recognize.py @@ -0,0 +1,183 @@ +import os +import glob +import json +import re +import fitz +from PIL import Image +from tqdm import tqdm +from dashscope import MultiModalConversation +import argparse +from pathlib import Path + + +def fig_understand(fig_path): + # prompt = '请给出图像中具体内容信息,并用json格式输出,仅输出json格式数据,其中,图片类型请从["chart","knowladge_map","other"]中选择' + prompt = ''' +你是一个图像内容理解专家,任务是读取图像内容并生成结构化 JSON 数据。请遵循以下规则: + +1. **仅输出 JSON 数据**,不要添加任何解释、前缀或后缀文字。 +2. JSON 格式中必须包含两个字段: + - "type": 图像类型,只能从 ["chart", "knowladge_map", "other"] 中选择。 + - "content": 图像的具体结构化内容描述。 +3. 如果图像类型是: + - "chart": 请提取图表的标题、坐标轴标签、图例、系列等结构信息。 + - "knowladge_map": 输出树状结构,所有节点使用 {"name": xxx, "children": [...]} 格式。 + - "other": 尽可能准确描述图像的主要元素。 + +以下是几个示例,请模仿格式输出。 + +--- + +### 示例1(chart): +输入图像:柱状图,标题为“年度销售统计”,X轴为月份,Y轴为销售额,图例为“产品A”和“产品B”。 + +输出: +```json +{ + "type": "chart", + "content": { + + "title": "年度销售统计", + "x_axis": "月份", + "y_axis": "销售额", + "legend": ["产品A", "产品B"], + "series": [ + {"name": "产品A", "data": [100, 120, 130]}, + {"name": "产品B", "data": [80, 90, 100]} + ] + } +} +示例2(knowladge_map): +输入图像:知识图谱,核心为“机器学习”,子节点有“监督学习”和“无监督学习”,监督学习下有“回归”和“分类”。 + +输出: +{ + "type": "knowladge_map", + "content": { + "name": "机器学习", + "children": [ + { + "name": "监督学习", + "children": [ + {"name": "回归"}, + {"name": "分类"} + ] + }, + { + "name": "无监督学习" + } + ] + } +} +示例3(other): +输入图像:一张会议室内多个人开会的场景。 + +输出: +{ + "type": "other", + "content": "一个会议室中有5个人正在围绕会议桌讨论,桌上有笔记本电脑和文件。" +} + +请根据上面的示例输出格式,严格输出图像的内容识别结果,只返回符合格式的 JSON 数据。 + +''' + messages= [ + { + "role": "user", + "content": [ + {"image": f"file://{fig_path}"}, + {"text": prompt} + ] + } +] + + response = MultiModalConversation.call( + api_key=os.environ.get('DASHSCOPE_API_KEY'), + model="qwen-vl-plus", + messages=messages, + ) + + # print(response) + return response["output"]["choices"][0]["message"].content[0]["text"].replace("```json",'').replace("```",'').strip() + +def save_fig(file_path, page_no, index, bbox, scale): + + file_name, file_ext = os.path.splitext(os.path.basename(file_path)) + file_name = file_name.replace('_layout', '') + base_dir = os.path.dirname(file_path) + pdf_file = os.path.join(base_dir, f"{file_name}_original.pdf") + doc = fitz.open(pdf_file) + page = doc.load_page(page_no) + + pdf_width = page.rect.width + pdf_height = page.rect.height + + scale_x = scale[1] / pdf_width + scale_y = scale[0] / pdf_height + x1 = bbox[0] / scale_x + y1 = bbox[1] / scale_y + x2 = bbox[2] / scale_x + y2 = bbox[3] / scale_y + pdf_bbox = fitz.Rect(x1, y1, x2, y2) + zoom = 300 / 72 # 输出300 DPI + matrix = fitz.Matrix(zoom, zoom) + img = page.get_pixmap(matrix=matrix, clip=pdf_bbox, alpha=False) + + save_dir = os.path.join(base_dir,f"{file_name}/image") + if not os.path.exists(save_dir): + os.mkdir(save_dir) + + + text = '' + save_path = os.path.join(save_dir, f'page_{page_no}_{index}.png') + if img is not None: + img.save(save_path) + + text = fig_understand(save_path) + + return save_path, text + +def process_one_file(json_file): + file_name = os.path.basename(json_file) + base_dir = os.path.dirname(json_file) + output_path = str(json_file).replace("layout", "img_content") + data = [] + with open(json_file, 'r', encoding='utf-8') as f: + json_data = json.load(f) + print(f"Processing file: {file_name}") + for row in tqdm(json_data): + if row.get('category','') == 'Picture': + bbox = row['bbox'] + page_no = row['page_no'] + if (bbox[2]-bbox[0])*(bbox[3]-bbox[1]) < 52000: + row['text'] = "" + else: + fig_path, text = save_fig(json_file, page_no=page_no, index=row['index'], bbox=bbox, scale=row['scale']) + # print(text) + row['text'] = json.loads(text) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(json_data, f, ensure_ascii=False, indent=4) + return json_data + +def main(): + parser = argparse.ArgumentParser(description="Use vlm to get parsed figure content.") + parser.add_argument( + "--output", type=str, default="output", + help="Output parsed directory (default: output)" + ) + args = parser.parse_args() + + + if os.path.isdir(args.output): + for file in sorted(Path(args.output).glob('*_layout.json')): + data = process_one_file(file) + elif os.path.isfile(args.output): + data = process_one_file(args.output) + else: + print(f"'{args.output}' no exist") + +if __name__ == "__main__": + os.environ["DASHSCOPE_API_KEY"] = "your api key" + main() + + diff --git a/rag_factory/parser/Parser_Dotsocr/readme.md b/rag_factory/parser/Parser_Dotsocr/readme.md index 35f5294..d83b416 100644 --- a/rag_factory/parser/Parser_Dotsocr/readme.md +++ b/rag_factory/parser/Parser_Dotsocr/readme.md @@ -20,21 +20,28 @@ snapshot_download(repo_id="rednote-hilab/dots.ocr", local_dir="Parser_Dotsocr/we or from modelscope import snapshot_download -snapshot_download(repo_id="rednote-hilab/dots.ocr", local_dir="Parser_Dotsocr/weights/DotsOCR") +snapshot_download(repo_id="rednote-hilab/dots.ocr", local_dir=model_dir) ``` ## 2. vLLM inference -Using vLLM for faster paser speed ( based on vllm==0.9.1 ) +Using vLLM for faster speed ( based on vllm==0.9.1 ) ``` -python vllm_launch.py --model_path dots_model_path +python vllm_launch.py --model_path weights/DotsOCR ``` -## 3. Document Parse +## 3. Document parse ``` -python parser.py pdf_path.pdf +python parser.py pdf_path.pdf (or pdfs_dir) ``` If you want to parse document with transformers,add `--use_hf=True` + +## 4. Figure understand + +Use vl model to understand content in parsed picture. Please obtain pdf layout parsed result first. +``` +python fig_recognize.py --output output +``` diff --git a/rag_factory/parser/Parser_Dotsocr/requirements.txt b/rag_factory/parser/Parser_Dotsocr/requirements.txt deleted file mode 100644 index 8159c46..0000000 --- a/rag_factory/parser/Parser_Dotsocr/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -# streamlit -gradio -gradio_image_annotation -PyMuPDF -openai -qwen_vl_utils -transformers==4.51.3 -huggingface_hub -modelscope -flash-attn==2.8.0.post2 -# for GLIBC 2.31, please use flash-attn==2.7.4.post1 instead of flash-attn==2.8.0.post2 -accelerate diff --git a/rag_factory/parser/Parser_Dotsocr/vllm_launch.py b/rag_factory/parser/Parser_Dotsocr/vllm_launch.py index d79bc38..95d6fbc 100644 --- a/rag_factory/parser/Parser_Dotsocr/vllm_launch.py +++ b/rag_factory/parser/Parser_Dotsocr/vllm_launch.py @@ -4,7 +4,7 @@ from pathlib import Path import argparse -def launch_vllm_server(hf_model_path="/home/yangcehao/doc_analysis/dots.ocr/weights/DotsOCR", num_gpus="0", gpu_memory_utilization=0.95, port=8001): +def launch_vllm_server(hf_model_path="weights/DotsOCR", num_gpus="0", gpu_memory_utilization=0.95, port=8001): # 1. 检查模型路径 model_path = Path(hf_model_path).resolve() if not model_path.exists(): diff --git a/requirements.txt b/requirements.txt index ac808e5..823e164 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,22 @@ neo4j aioboto3 llama-index llama-index-core -peewee \ No newline at end of file +peewee + +mineru[core] +rank_bm25 +faiss_gpu + + + +# streamlit +PyMuPDF +openai +qwen_vl_utils +transformers==4.51.3 +huggingface_hub +modelscope +flash-attn==2.8.0.post2 +# for GLIBC 2.31, please use flash-attn==2.7.4.post1 instead of flash-attn==2.8.0.post2 +accelerate +dashscope