kotaemon/knowledgehub/loaders/utils/table.py
Tuan Anh Nguyen Dang (Tadashi_Cin) 4704e2c11a Add new OCRReader with PDF+OCR text merging (#66)
This change speeds up OCR extraction by allowing OCR to be bypassed for text that is irrelevant (i.e. not inside a table).

---------

Co-authored-by: Nguyen Trung Duc (john) <trungduc1992@gmail.com>
2023-11-13 17:43:02 +07:00

289 lines
8.5 KiB
Python

import csv
from io import StringIO
from typing import List, Optional, Tuple
from .box import get_rect_iou
def check_col_conflicts(
    col_a: List[str], col_b: List[str], thres: float = 0.15
) -> bool:
    """Tell whether columns A and B collide on too many rows.

    A row counts as a collision when both columns hold non-empty text in it.
    Intended as the guard to evaluate before calling merge_cols.

    Args:
        col_a: column A (list of str)
        col_b: column B (list of str), must have the same length as col_a
        thres: tolerated collisions, expressed as a fraction of the number
            of non-empty cells in column A

    Returns:
        True when the number of colliding rows exceeds the tolerated amount
    """
    # The tolerance is relative to how many cells of A actually carry text.
    filled_in_a = sum(1 for value in col_a if value)
    assert len(col_a) == len(col_b)
    collisions = sum(1 for left, right in zip(col_a, col_b) if left and right)
    return collisions > filled_in_a * thres
def merge_cols(col_a: List[str], col_b: List[str]) -> List[str]:
    """Fold the text of column B into column A, row by row.

    Column A is modified in place and is also the returned object. Every
    non-empty cell of B is appended to the matching cell of A after a single
    space separator (the separator is added even when the A cell is empty).

    Args:
        col_a: column A (list of str), updated in place
        col_b: column B (list of str), needs at least as many rows as col_a

    Returns:
        merged column (the same list object as col_a)
    """
    for idx, left in enumerate(col_a):
        right = col_b[idx]
        if not right:
            # nothing to bring over for this row
            continue
        col_a[idx] = " ".join((left, right))
    return col_a
def add_index_col(csv_rows: List[List[str]]) -> List[List[str]]:
    """Add index column as the first column of the table csv_rows

    A header row is prepended whose first cell is "row id" and whose other
    cells are empty; each input row is then prefixed with its 1-based
    position. The input rows themselves are not modified.

    Args:
        csv_rows: input table (may be empty)

    Returns:
        output table with index column; for an empty input this is just the
        header row ``[["row id"]]``
    """
    # Size the header from the widest row instead of the first one, so that
    # ragged tables still get a header covering every column; default=0
    # keeps an empty table from raising.
    width = max((len(row) for row in csv_rows), default=0)
    header = ["row id"] + [""] * width
    body = [[str(position)] + row for position, row in enumerate(csv_rows, start=1)]
    return [header] + body
def compress_csv(csv_rows: List[List[str]]) -> List[List[str]]:
    """Compress table csv_rows by merging sparse columns (merge_cols)

    Columns are scanned left to right. A column is folded into the closest
    kept column on its left whenever check_col_conflicts reports no conflict
    between the two; otherwise it is kept and becomes the new merge target.

    Args:
        csv_rows: input table; the column count is taken from the first row

    Returns:
        output: compressed table. A table with no rows, or whose first row
        has no cells, is returned as a shallow copy since there is nothing
        to merge.
    """
    # Guard the degenerate shapes that would otherwise raise IndexError.
    if not csv_rows or not csv_rows[0]:
        return [list(row) for row in csv_rows]

    # Transpose to column-major so whole columns can be compared and merged.
    columns = [[row[c_id] for row in csv_rows] for c_id in range(len(csv_rows[0]))]

    # Track kept columns directly; the last kept one is always the merge
    # target, so no list of removed ids has to be searched afterwards.
    kept = [columns[0]]
    for column in columns[1:]:
        if check_col_conflicts(kept[-1], column):
            kept.append(column)
        else:
            kept[-1] = merge_cols(kept[-1], column)

    # Transpose back to row-major.
    return [list(cells) for cells in zip(*kept)]
def get_table_from_ocr(ocr_list: List[dict], table_list: List[dict]):
    """Collect, for every table region, the OCR texts that fall inside it

    Only entries of table_list whose "type" is "table" are considered. An OCR
    item is assigned to a region when get_rect_iou (called with iou_type=1)
    on the two "location" values is above 0.8.

    Args:
        ocr_list: list of OCR output in Casia format (Flax); each item is
            read through its "location" and "text" keys
        table_list: list of table output in Casia format (Flax); each item
            is read through its "type" and "location" keys

    Returns:
        one list of text strings per table region, in table_list order
    """
    regions = (item for item in table_list if item["type"] == "table")
    return [
        [
            ocr["text"]
            for ocr in ocr_list
            if get_rect_iou(region["location"], ocr["location"], iou_type=1) > 0.8
        ]
        for region in regions
    ]
def make_markdown_table(array: List[List[str]]) -> str:
    """Render table rows (list of lists) as a markdown table string

    The table is first passed through compress_csv and add_index_col, so the
    rendered output carries a leading "row id" column and its first line is
    the header generated by add_index_col.

    Args:
        array: Python list with rows of table as lists.
            Example Input:
                [["Name", "Age", "Height"],
                 ["Jake", "20", "5'10"],
                 ["Mary", "21", "5'7"]]

    Returns:
        String to put into a .md file (starts with a newline and ends with
        an empty line)
    """
    table = add_index_col(compress_csv(array))
    header_cells = table[0]

    # Gather fragments and join once instead of growing a string with +=.
    pieces = ["\n", "| "]
    pieces.extend(" " + str(cell) + " |" for cell in header_cells)
    pieces.append("\n")

    # separator line: one "---" marker per header cell
    pieces.append("| ")
    pieces.append("--- | " * len(header_cells))
    pieces.append("\n")

    for record in table[1:]:
        pieces.append("| ")
        pieces.extend(str(cell) + " | " for cell in record)
        pieces.append("\n")

    pieces.append("\n")
    return "".join(pieces)
def parse_csv_string_to_list(csv_str: str) -> List[List[str]]:
    """Parse a comma-delimited CSV string into a list of rows

    Args:
        csv_str: input CSV string

    Returns:
        Output table in list format (one list of cell strings per CSV record)
    """
    reader = csv.reader(StringIO(csv_str), delimiter=",")
    return list(reader)
def format_cell(cell: str, length_limit: Optional[int] = None) -> str:
    """Clean up a cell: line breaks become spaces, then the text is truncated

    Args:
        cell: input cell text
        length_limit: limit of text length. A falsy value (None or 0) leaves
            the text untruncated.

    Returns:
        new cell text
    """
    flattened = cell.replace("\n", " ")
    if not length_limit:
        return flattened
    return flattened[:length_limit]
def extract_tables_from_csv_string(
    csv_content: str, table_texts: List[List[str]]
) -> Tuple[List[str], str]:
    """Extract list of table from FullOCR output
    (csv_content) with the specified table_texts

    Each CSV row is assigned to the first table for which more than half of
    the row's non-empty cells contain at least one of that table's text
    fragments. Rows assigned to a table are rendered together with
    make_markdown_table; all remaining rows are returned as plain text.

    Args:
        csv_content: CSV output from FullOCR pipeline
        table_texts: list of table texts extracted
            from get_table_from_ocr()

    Returns:
        List of tables (markdown strings) and the non-table text, where each
        unassigned row is one line with its cells joined by a space
    """
    rows = parse_csv_string_to_list(csv_content)
    # One flag per row gives O(1) "already claimed" checks.
    row_used = [False] * len(rows)
    table_csv_list = []

    for table in table_texts:
        # "" is a substring of every string, so an empty fragment would make
        # every cell look like a match; ignore such fragments.
        fragments = [fragment for fragment in table if fragment]
        cur_rows = []
        for row_id, row in enumerate(rows):
            if row_used[row_id]:
                continue
            filled_cells = [cell for cell in row if cell]
            if not filled_cells:
                # Blank CSV lines / all-empty rows have nothing to score;
                # skipping them avoids a division by zero below.
                continue
            matched = sum(
                any(fragment in cell for fragment in fragments)
                for cell in filled_cells
            )
            if matched / len(filled_cells) > 0.5:
                row_used[row_id] = True
                cur_rows.append([format_cell(cell) for cell in row])

        if cur_rows:
            table_csv_list.append(make_markdown_table(cur_rows))
        else:
            print("table not matched", table)

    non_table_rows = [row for row, used in zip(rows, row_used) if not used]
    non_table_text = "\n".join(
        " ".join(format_cell(cell) for cell in row) for row in non_table_rows
    )
    return table_csv_list, non_table_text
def strip_special_chars_markdown(text: str) -> str:
    """Remove markdown table markup ("|", ":---:", "---") from the input text"""
    # Order matters: pipes go first, and ":---:" must be removed before the
    # bare "---" so that its colons disappear with it.
    for token in ("|", ":---:", "---"):
        text = text.replace(token, "")
    return text
def parse_markdown_text_to_tables(text: str) -> Tuple[List[str], List[str]]:
    """Convert markdown text to list of non-table spans and table spans

    The text is split on line breaks and every line is stripped. Consecutive
    lines starting with "|" form one table span; consecutive other lines form
    one non-table span.

    Args:
        text: input markdown text

    Returns:
        list of table spans and list of non-table spans, each span being its
        lines joined with a line break. A trailing non-table span made only
        of blank lines is not reported.
    """
    tables: List[List[str]] = []
    texts: List[List[str]] = []
    cur_table: List[str] = []
    cur_text: List[str] = []

    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if line.startswith("|"):
            # a table row closes the text span collected so far
            if cur_text:
                texts.append(cur_text)
                cur_text = []
            cur_table.append(line)
        else:
            # a non-table line closes the table collected so far
            if cur_table:
                tables.append(cur_table)
                cur_table = []
            cur_text.append(line)

    # Flush whatever is still open once the input ends; without this, a
    # document ending on a table row (or on trailing text) loses that span.
    if cur_table:
        tables.append(cur_table)
    # any() skips a tail consisting solely of blank lines, e.g. the empty
    # piece produced by a final line break.
    if any(cur_text):
        texts.append(cur_text)

    table_texts = ["\n".join(table) for table in tables]
    non_table_texts = ["\n".join(span) for span in texts]
    return table_texts, non_table_texts
def table_cells_to_markdown(cells: List[dict]):
    """Lay out cells with attached OCR text on a grid and render it as Markdown

    Each cell is read through its "rows" and "columns" entries (a start/end
    pair of grid indices each) and its "ocr" entry (items with a "text" key,
    joined with spaces). The grid is sized from the largest row and column
    index seen and is rendered with make_markdown_table.

    Returns an empty string when there are no cells.
    """
    if not cells:
        return ""

    row_ids = []
    col_ids = []
    for cell in cells:
        row_ids += cell["rows"]
        col_ids += cell["columns"]
    n_rows = max(row_ids) + 1
    n_cols = max(col_ids) + 1
    grid = [[""] * n_cols for _ in range(n_rows)]

    # start filling in the grid
    for cell in cells:
        text = " ".join(item["text"] for item in cell["ocr"])
        first_row, last_row = cell["rows"]
        first_col, last_col = cell["columns"]

        is_span = not (first_row == last_row and first_col == last_col)
        # do not repeat long text in span cell to prevent context length issue
        is_short = len(text.replace(" ", "")) < 20
        if is_span and is_short and first_row > 0:
            # short spanning text is copied into every grid slot it covers
            targets = [
                (r, c)
                for r in range(first_row, last_row + 1)
                for c in range(first_col, last_col + 1)
            ]
        else:
            # otherwise only the top-left slot receives the text
            targets = [(first_row, first_col)]

        for r, c in targets:
            grid[r][c] += text + " "

    return make_markdown_table(grid)