Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions rag_factory/parser/Parser_Dotsocr/fig_recognize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import os
import glob
import json
import re
import fitz
from PIL import Image
from tqdm import tqdm
from dashscope import MultiModalConversation
import argparse
from pathlib import Path


def fig_understand(fig_path):
# prompt = '请给出图像中具体内容信息,并用json格式输出,仅输出json格式数据,其中,图片类型请从["chart","knowladge_map","other"]中选择'
prompt = '''
你是一个图像内容理解专家,任务是读取图像内容并生成结构化 JSON 数据。请遵循以下规则:

1. **仅输出 JSON 数据**,不要添加任何解释、前缀或后缀文字。
2. JSON 格式中必须包含两个字段:
- "type": 图像类型,只能从 ["chart", "knowladge_map", "other"] 中选择。
- "content": 图像的具体结构化内容描述。
3. 如果图像类型是:
- "chart": 请提取图表的标题、坐标轴标签、图例、系列等结构信息。
- "knowladge_map": 输出树状结构,所有节点使用 {"name": xxx, "children": [...]} 格式。
- "other": 尽可能准确描述图像的主要元素。

以下是几个示例,请模仿格式输出。

---

### 示例1(chart):
输入图像:柱状图,标题为“年度销售统计”,X轴为月份,Y轴为销售额,图例为“产品A”和“产品B”。

输出:
```json
{
"type": "chart",
"content": {

"title": "年度销售统计",
"x_axis": "月份",
"y_axis": "销售额",
"legend": ["产品A", "产品B"],
"series": [
{"name": "产品A", "data": [100, 120, 130]},
{"name": "产品B", "data": [80, 90, 100]}
]
}
}
示例2(knowladge_map):
输入图像:知识图谱,核心为“机器学习”,子节点有“监督学习”和“无监督学习”,监督学习下有“回归”和“分类”。

输出:
{
"type": "knowladge_map",
"content": {
"name": "机器学习",
"children": [
{
"name": "监督学习",
"children": [
{"name": "回归"},
{"name": "分类"}
]
},
{
"name": "无监督学习"
}
]
}
}
示例3(other):
输入图像:一张会议室内多个人开会的场景。

输出:
{
"type": "other",
"content": "一个会议室中有5个人正在围绕会议桌讨论,桌上有笔记本电脑和文件。"
}

请根据上面的示例输出格式,严格输出图像的内容识别结果,只返回符合格式的 JSON 数据。

'''
messages= [
{
"role": "user",
"content": [
{"image": f"file://{fig_path}"},
{"text": prompt}
]
}
]

response = MultiModalConversation.call(
api_key=os.environ.get('DASHSCOPE_API_KEY'),
model="qwen-vl-plus",
messages=messages,
)

# print(response)
return response["output"]["choices"][0]["message"].content[0]["text"].replace("```json",'').replace("```",'').strip()

def save_fig(file_path, page_no, index, bbox, scale):

file_name, file_ext = os.path.splitext(os.path.basename(file_path))
file_name = file_name.replace('_layout', '')
base_dir = os.path.dirname(file_path)
pdf_file = os.path.join(base_dir, f"{file_name}_original.pdf")
doc = fitz.open(pdf_file)
page = doc.load_page(page_no)

pdf_width = page.rect.width
pdf_height = page.rect.height

scale_x = scale[1] / pdf_width
scale_y = scale[0] / pdf_height
x1 = bbox[0] / scale_x
y1 = bbox[1] / scale_y
x2 = bbox[2] / scale_x
y2 = bbox[3] / scale_y
pdf_bbox = fitz.Rect(x1, y1, x2, y2)
zoom = 300 / 72 # 输出300 DPI
matrix = fitz.Matrix(zoom, zoom)
img = page.get_pixmap(matrix=matrix, clip=pdf_bbox, alpha=False)

save_dir = os.path.join(base_dir,f"{file_name}/image")
if not os.path.exists(save_dir):
os.mkdir(save_dir)


text = ''
save_path = os.path.join(save_dir, f'page_{page_no}_{index}.png')
if img is not None:
img.save(save_path)

text = fig_understand(save_path)

return save_path, text

def process_one_file(json_file):
file_name = os.path.basename(json_file)
base_dir = os.path.dirname(json_file)
output_path = str(json_file).replace("layout", "img_content")
data = []
with open(json_file, 'r', encoding='utf-8') as f:
json_data = json.load(f)
print(f"Processing file: {file_name}")
for row in tqdm(json_data):
if row.get('category','') == 'Picture':
bbox = row['bbox']
page_no = row['page_no']
if (bbox[2]-bbox[0])*(bbox[3]-bbox[1]) < 52000:
row['text'] = ""
else:
fig_path, text = save_fig(json_file, page_no=page_no, index=row['index'], bbox=bbox, scale=row['scale'])
# print(text)
row['text'] = json.loads(text)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(json_data, f, ensure_ascii=False, indent=4)
return json_data

def main():
parser = argparse.ArgumentParser(description="Use vlm to get parsed figure content.")
parser.add_argument(
"--output", type=str, default="output",
help="Output parsed directory (default: output)"
)
args = parser.parse_args()


if os.path.isdir(args.output):
for file in sorted(Path(args.output).glob('*_layout.json')):
data = process_one_file(file)
elif os.path.isfile(args.output):
data = process_one_file(args.output)
else:
print(f"'{args.output}' no exist")

if __name__ == "__main__":
os.environ["DASHSCOPE_API_KEY"] = "your api key"
main()


17 changes: 12 additions & 5 deletions rag_factory/parser/Parser_Dotsocr/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,28 @@ snapshot_download(repo_id="rednote-hilab/dots.ocr", local_dir="Parser_Dotsocr/we
or

from modelscope import snapshot_download
snapshot_download(repo_id="rednote-hilab/dots.ocr", local_dir="Parser_Dotsocr/weights/DotsOCR")
snapshot_download(repo_id="rednote-hilab/dots.ocr", local_dir=model_dir)
```

## 2. vLLM inference

Using vLLM for faster paser speed ( based on vllm==0.9.1 )
Using vLLM for faster speed ( based on vllm==0.9.1 )

```
python vllm_launch.py --model_path dots_model_path
python vllm_launch.py --model_path weights/DotsOCR
```

## 3. Document Parse
## 3. Document parse

```
python parser.py pdf_path.pdf
python parser.py pdf_path.pdf (or pdfs_dir)
```

If you want to parse document with transformers,add `--use_hf=True`

## 4. Figure understand

Use vl model to understand content in parsed picture. Please obtain pdf layout parsed result first.
```
python fig_recognize.py --output output
```
12 changes: 0 additions & 12 deletions rag_factory/parser/Parser_Dotsocr/requirements.txt

This file was deleted.

2 changes: 1 addition & 1 deletion rag_factory/parser/Parser_Dotsocr/vllm_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
import argparse

def launch_vllm_server(hf_model_path="/home/yangcehao/doc_analysis/dots.ocr/weights/DotsOCR", num_gpus="0", gpu_memory_utilization=0.95, port=8001):
def launch_vllm_server(hf_model_path="weights/DotsOCR", num_gpus="0", gpu_memory_utilization=0.95, port=8001):
# 1. 检查模型路径
model_path = Path(hf_model_path).resolve()
if not model_path.exists():
Expand Down
20 changes: 19 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,22 @@ neo4j
aioboto3
llama-index
llama-index-core
peewee
peewee

mineru[core]
rank_bm25
faiss_gpu



# streamlit
PyMuPDF
openai
qwen_vl_utils
transformers==4.51.3
huggingface_hub
modelscope
flash-attn==2.8.0.post2
# for GLIBC 2.31, please use flash-attn==2.7.4.post1 instead of flash-attn==2.8.0.post2
accelerate
dashscope