Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html as html_lib
from typing import Dict

import numpy as np
Expand Down Expand Up @@ -235,6 +236,22 @@ def match_table_and_ocr(
return matched


def _split_outer_bold_wrapper(content: str) -> tuple[bool, str]:
"""Return whether content is wrapped by a single outer ``<b>...</b>`` pair."""
if content.startswith("<b>") and content.endswith("</b>"):
return True, content[3:-4]
return False, content


def _escape_cell_content(content: str) -> str:
"""Escape OCR cell text while preserving a single outer bold wrapper."""
is_bold, inner_content = _split_outer_bold_wrapper(content)
escaped_content = html_lib.escape(inner_content, quote=True)
if is_bold:
return f"<b>{escaped_content}</b>"
return escaped_content


def get_html_result(
matched_index: dict, ocr_contents: dict, pred_structures: list
) -> str:
Expand All @@ -261,7 +278,7 @@ def get_html_result(
if td_index in matched_index.keys():
b_with = False
if (
"<b>" in ocr_contents[matched_index[td_index][0]]
ocr_contents[matched_index[td_index][0]].startswith("<b>")
and len(matched_index[td_index]) > 1
):
b_with = True
Expand All @@ -273,15 +290,17 @@ def get_html_result(
continue
if content[0] == " ":
content = content[1:]
if "<b>" in content:
if content.startswith("<b>"):
content = content[3:]
if "</b>" in content:
if content.endswith("</b>"):
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
content += " "
pred_html.extend(content)
pred_html.append(html_lib.escape(content, quote=True))
else:
pred_html.append(_escape_cell_content(content))
if b_with:
pred_html.extend("</b>")
if "<td></td>" == tag:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html as html_lib
import math

import numpy as np
Expand Down Expand Up @@ -188,6 +189,22 @@ def match_table_and_ocr(cell_box_list, ocr_dt_boxes, table_cells_flag, row_start
return all_matched


def _split_outer_bold_wrapper(content: str) -> tuple[bool, str]:
"""Return whether content is wrapped by a single outer ``<b>...</b>`` pair."""
if content.startswith("<b>") and content.endswith("</b>"):
return True, content[3:-4]
return False, content


def _escape_cell_content(content: str) -> str:
"""Escape OCR cell text while preserving a single outer bold wrapper."""
is_bold, inner_content = _split_outer_bold_wrapper(content)
escaped_content = html_lib.escape(inner_content, quote=True)
if is_bold:
return f"<b>{escaped_content}</b>"
return escaped_content


def get_html_result(
all_matched_index: dict, ocr_contents: dict, pred_structures: list, table_cells_flag
) -> str:
Expand Down Expand Up @@ -219,7 +236,7 @@ def get_html_result(
continue
b_with = False
if (
"<b>" in ocr_contents[matched_index[td_index][0]]
ocr_contents[matched_index[td_index][0]].startswith("<b>")
and len(matched_index[td_index]) > 1
):
b_with = True
Expand All @@ -231,15 +248,17 @@ def get_html_result(
continue
if content[0] == " ":
content = content[1:]
if "<b>" in content:
if content.startswith("<b>"):
content = content[3:]
if "</b>" in content:
if content.endswith("</b>"):
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
content += " "
pred_html.extend(content)
pred_html.append(html_lib.escape(content, quote=True))
else:
pred_html.append(_escape_cell_content(content))
if b_with:
pred_html.extend("</b>")
if "<td></td>" == tag:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
#
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Regression tests for table markdown HTML escaping."""

import sys
import types
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path

import pytest


def _repo_root() -> Path:
current = Path(__file__).resolve()
return next(
parent for parent in current.parents if (parent / "pyproject.toml").exists()
)


def _install_fake_package_tree():
module_names = [
"paddlex",
"paddlex.inference",
"paddlex.inference.pipelines",
"paddlex.inference.pipelines.components",
"paddlex.inference.pipelines.layout_parsing",
"paddlex.inference.pipelines.layout_parsing.utils",
"paddlex.inference.pipelines.ocr",
"paddlex.inference.pipelines.ocr.result",
"paddlex.inference.pipelines.table_recognition",
"paddlex.inference.pipelines.table_recognition.result",
]
originals = {name: sys.modules.get(name) for name in module_names}

paddlex_mod = types.ModuleType("paddlex")
paddlex_mod.__path__ = []
inference_mod = types.ModuleType("paddlex.inference")
inference_mod.__path__ = []
pipelines_mod = types.ModuleType("paddlex.inference.pipelines")
pipelines_mod.__path__ = []
components_mod = types.ModuleType("paddlex.inference.pipelines.components")
layout_parsing_mod = types.ModuleType("paddlex.inference.pipelines.layout_parsing")
layout_parsing_mod.__path__ = []
layout_utils_mod = types.ModuleType(
"paddlex.inference.pipelines.layout_parsing.utils"
)
ocr_mod = types.ModuleType("paddlex.inference.pipelines.ocr")
ocr_mod.__path__ = []
ocr_result_mod = types.ModuleType("paddlex.inference.pipelines.ocr.result")
table_pkg_mod = types.ModuleType("paddlex.inference.pipelines.table_recognition")
table_pkg_mod.__path__ = []
table_result_mod = types.ModuleType(
"paddlex.inference.pipelines.table_recognition.result"
)

class OCRResult(dict):
pass

class SingleTableRecognitionResult(dict):
pass

components_mod.convert_points_to_boxes = lambda *args, **kwargs: None
layout_utils_mod.get_sub_regions_ocr_res = lambda *args, **kwargs: None
ocr_result_mod.OCRResult = OCRResult
table_result_mod.SingleTableRecognitionResult = SingleTableRecognitionResult

sys.modules["paddlex"] = paddlex_mod
sys.modules["paddlex.inference"] = inference_mod
sys.modules["paddlex.inference.pipelines"] = pipelines_mod
sys.modules["paddlex.inference.pipelines.components"] = components_mod
sys.modules["paddlex.inference.pipelines.layout_parsing"] = layout_parsing_mod
sys.modules["paddlex.inference.pipelines.layout_parsing.utils"] = layout_utils_mod
sys.modules["paddlex.inference.pipelines.ocr"] = ocr_mod
sys.modules["paddlex.inference.pipelines.ocr.result"] = ocr_result_mod
sys.modules["paddlex.inference.pipelines.table_recognition"] = table_pkg_mod
sys.modules["paddlex.inference.pipelines.table_recognition.result"] = (
table_result_mod
)

return originals


def _restore_modules(originals, loaded_name):
for name, original in originals.items():
if original is None:
sys.modules.pop(name, None)
else:
sys.modules[name] = original

sys.modules.pop(loaded_name, None)


def _load_table_module(module_filename: str):
repo_root = _repo_root()
module_name = (
f"paddlex.inference.pipelines.table_recognition.{module_filename[:-3]}"
)
module_path = (
repo_root
/ "paddlex"
/ "inference"
/ "pipelines"
/ "table_recognition"
/ module_filename
)
originals = _install_fake_package_tree()
try:
spec = spec_from_file_location(module_name, module_path)
module = module_from_spec(spec)
sys.modules[module_name] = module
assert spec.loader is not None
spec.loader.exec_module(module)
return module
finally:
_restore_modules(originals, module_name)


def _minimal_table_structure():
return [
"<html>",
"<body>",
"<table>",
"<tr>",
"<td></td>",
"</tr>",
"</table>",
"</body>",
"</html>",
]


def _render_v1(ocr_contents):
module = _load_table_module("table_recognition_post_processing.py")
return module.get_html_result(
{0: list(range(len(ocr_contents)))},
ocr_contents,
_minimal_table_structure(),
)


def _render_v2(ocr_contents):
module = _load_table_module("table_recognition_post_processing_v2.py")
return module.get_html_result(
[{0: list(range(len(ocr_contents)))}],
ocr_contents,
_minimal_table_structure(),
[0, 1],
)


@pytest.mark.parametrize("render_html", [_render_v1, _render_v2])
def test_escapes_html_sensitive_ocr_text_in_single_cell(render_html):
html = render_html(['<recv response="200" response_txn="invite" />'])

assert (
'<td>&lt;recv response=&quot;200&quot; response_txn=&quot;invite&quot; /&gt;</td>'
in html
)


@pytest.mark.parametrize("render_html", [_render_v1, _render_v2])
def test_preserves_single_outer_bold_wrapper_when_escaping_cell_text(render_html):
html = render_html(['<b><pause milliseconds="5000"/></b>'])

assert '<td><b>&lt;pause milliseconds=&quot;5000&quot;/&gt;</b></td>' in html


@pytest.mark.parametrize("render_html", [_render_v1, _render_v2])
def test_escapes_multi_fragment_bold_cell_text(render_html):
html = render_html(['<b><recv response="200"', 'response_txn="invite"/></b>'])

assert (
'<td><b>&lt;recv response=&quot;200&quot; '
'response_txn=&quot;invite&quot;/&gt;</b></td>' in html
)
Loading