Add new OCRReader with PDF+OCR text merging (#66)
This change speeds up OCR extraction by allowing bypassing OCR for texts that are irrelevant (not in table). --------- Co-authored-by: Nguyen Trung Duc (john) <trungduc1992@gmail.com>
This commit is contained in:
committed by
GitHub
parent
d79b3744cb
commit
4704e2c11a
@@ -2,6 +2,8 @@ import csv
|
||||
from io import StringIO
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from .box import get_rect_iou
|
||||
|
||||
|
||||
def check_col_conflicts(
|
||||
col_a: List[str], col_b: List[str], thres: float = 0.15
|
||||
@@ -77,61 +79,6 @@ def compress_csv(csv_rows: List[List[str]]) -> List[List[str]]:
|
||||
return csv_rows
|
||||
|
||||
|
||||
def _get_rect_iou(gt_box: List[tuple], pd_box: List[tuple], iou_type=0) -> int:
|
||||
"""Intersection over union on layout rectangle
|
||||
|
||||
Args:
|
||||
gt_box: List[tuple]
|
||||
A list contains bounding box coordinates of ground truth
|
||||
pd_box: List[tuple]
|
||||
A list contains bounding box coordinates of prediction
|
||||
iou_type: int
|
||||
0: intersection / union, normal IOU
|
||||
1: intersection / min(areas), useful when boxes are under/over-segmented
|
||||
|
||||
Input format: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
||||
Annotation for each element in bbox:
|
||||
(x1, y1) (x2, y1)
|
||||
+-------+
|
||||
| |
|
||||
| |
|
||||
+-------+
|
||||
(x1, y2) (x2, y2)
|
||||
|
||||
Returns:
|
||||
Intersection over union value
|
||||
"""
|
||||
|
||||
assert iou_type in [0, 1], "Only support 0: origin iou, 1: intersection / min(area)"
|
||||
|
||||
# determine the (x, y)-coordinates of the intersection rectangle
|
||||
# gt_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
||||
# pd_box: [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
|
||||
x_left = max(gt_box[0][0], pd_box[0][0])
|
||||
y_top = max(gt_box[0][1], pd_box[0][1])
|
||||
x_right = min(gt_box[2][0], pd_box[2][0])
|
||||
y_bottom = min(gt_box[2][1], pd_box[2][1])
|
||||
|
||||
# compute the area of intersection rectangle
|
||||
interArea = max(0, x_right - x_left) * max(0, y_bottom - y_top)
|
||||
|
||||
# compute the area of both the prediction and ground-truth
|
||||
# rectangles
|
||||
gt_area = (gt_box[2][0] - gt_box[0][0]) * (gt_box[2][1] - gt_box[0][1])
|
||||
pd_area = (pd_box[2][0] - pd_box[0][0]) * (pd_box[2][1] - pd_box[0][1])
|
||||
|
||||
# compute the intersection over union by taking the intersection
|
||||
# area and dividing it by the sum of prediction + ground-truth
|
||||
# areas - the interesection area
|
||||
if iou_type == 0:
|
||||
iou = interArea / float(gt_area + pd_area - interArea)
|
||||
elif iou_type == 1:
|
||||
iou = interArea / max(min(gt_area, pd_area), 1)
|
||||
|
||||
# return the intersection over union value
|
||||
return iou
|
||||
|
||||
|
||||
def get_table_from_ocr(ocr_list: List[dict], table_list: List[dict]):
|
||||
"""Get list of text lines belong to table regions specified by table_list
|
||||
|
||||
@@ -148,7 +95,7 @@ def get_table_from_ocr(ocr_list: List[dict], table_list: List[dict]):
|
||||
continue
|
||||
cur_table_texts = []
|
||||
for ocr in ocr_list:
|
||||
_iou = _get_rect_iou(table["location"], ocr["location"], iou_type=1)
|
||||
_iou = get_rect_iou(table["location"], ocr["location"], iou_type=1)
|
||||
if _iou > 0.8:
|
||||
cur_table_texts.append(ocr["text"])
|
||||
table_texts.append(cur_table_texts)
|
||||
@@ -272,33 +219,6 @@ def strip_special_chars_markdown(text: str) -> str:
|
||||
return text.replace("|", "").replace(":---:", "").replace("---", "")
|
||||
|
||||
|
||||
def markdown_to_list(markdown_text: str, pad_to_max_col: Optional[bool] = True):
|
||||
rows = []
|
||||
lines = markdown_text.split("\n")
|
||||
markdown_lines = [line.strip() for line in lines if " | " in line]
|
||||
|
||||
for row in markdown_lines:
|
||||
tmp = row
|
||||
# Get rid of leading and trailing '|'
|
||||
if tmp.startswith("|"):
|
||||
tmp = tmp[1:]
|
||||
if tmp.endswith("|"):
|
||||
tmp = tmp[:-1]
|
||||
|
||||
# Split line and ignore column whitespace
|
||||
clean_line = tmp.split("|")
|
||||
if not all(c == "" for c in clean_line):
|
||||
# Append clean row data to rows variable
|
||||
rows.append(clean_line)
|
||||
|
||||
# Get rid of syntactical sugar to indicate header (2nd row)
|
||||
rows = [row for row in rows if "---" not in " ".join(row)]
|
||||
max_cols = max(len(row) for row in rows)
|
||||
if pad_to_max_col:
|
||||
rows = [row + [""] * (max_cols - len(row)) for row in rows]
|
||||
return rows
|
||||
|
||||
|
||||
def parse_markdown_text_to_tables(text: str) -> Tuple[List[str], List[str]]:
|
||||
"""Convert markdown text to list of non-table spans and table spans
|
||||
|
||||
@@ -333,3 +253,36 @@ def parse_markdown_text_to_tables(text: str) -> Tuple[List[str], List[str]]:
|
||||
table_texts = ["\n".join(table) for table in tables]
|
||||
non_table_texts = ["\n".join(text) for text in texts]
|
||||
return table_texts, non_table_texts
|
||||
|
||||
|
||||
def table_cells_to_markdown(cells: List[dict]):
|
||||
"""Convert list of cells with attached text to Markdown table"""
|
||||
|
||||
if len(cells) == 0:
|
||||
return ""
|
||||
|
||||
all_row_ids = []
|
||||
all_col_ids = []
|
||||
for cell in cells:
|
||||
all_row_ids.extend(cell["rows"])
|
||||
all_col_ids.extend(cell["columns"])
|
||||
|
||||
num_rows, num_cols = max(all_row_ids) + 1, max(all_col_ids) + 1
|
||||
table_rows = [["" for c in range(num_cols)] for r in range(num_rows)]
|
||||
|
||||
# start filling in the grid
|
||||
for cell in cells:
|
||||
cell_text = " ".join(item["text"] for item in cell["ocr"])
|
||||
start_row_id, end_row_id = cell["rows"]
|
||||
start_col_id, end_col_id = cell["columns"]
|
||||
span_cell = end_row_id != start_row_id or end_col_id != start_col_id
|
||||
|
||||
# do not repeat long text in span cell to prevent context length issue
|
||||
if span_cell and len(cell_text.replace(" ", "")) < 20 and start_row_id > 0:
|
||||
for row in range(start_row_id, end_row_id + 1):
|
||||
for col in range(start_col_id, end_col_id + 1):
|
||||
table_rows[row][col] += cell_text + " "
|
||||
else:
|
||||
table_rows[start_row_id][start_col_id] += cell_text + " "
|
||||
|
||||
return make_markdown_table(table_rows)
|
||||
|
Reference in New Issue
Block a user