Source code for texteller.api.inference

import re
import time
from collections import Counter
from typing import Literal

import cv2
import numpy as np
import torch
from onnxruntime import InferenceSession
from optimum.onnxruntime import ORTModelForVision2Seq
from transformers import GenerationConfig, RobertaTokenizerFast

from texteller.constants import MAX_TOKEN_SIZE
from texteller.logger import get_logger
from texteller.paddleocr import predict_det, predict_rec
from texteller.types import Bbox, TexTellerModel
from texteller.utils import (
    bbox_merge,
    get_device,
    mask_img,
    readimgs,
    remove_style,
    slice_from_image,
    split_conflict,
    transform,
    add_newlines,
)

from .detection import latex_detect
from .format import format_latex
from .katex import to_katex

_logger = get_logger()


def img2latex(
    model: TexTellerModel,
    tokenizer: RobertaTokenizerFast,
    images: list[str] | list[np.ndarray],
    device: torch.device | None = None,
    out_format: Literal["latex", "katex"] = "latex",
    keep_style: bool = False,
    max_tokens: int = MAX_TOKEN_SIZE,
    num_beams: int = 1,
    no_repeat_ngram_size: int = 0,
) -> list[str]:
    """
    Convert images to LaTeX or KaTeX formatted strings.

    Args:
        model: The TexTeller or ORTModelForVision2Seq model instance
        tokenizer: The tokenizer for the model
        images: List of image paths or numpy arrays (RGB format). All elements
            must be of the same kind (all paths or all arrays).
        device: The torch device to use (defaults to available GPU or CPU).
            A torch model is moved to it when the device type differs; an
            ORTModelForVision2Seq is not moved, a warning is logged and the
            model's own device is used instead.
        out_format: Output format, either "latex" or "katex"
        keep_style: Whether to keep the style of the LaTeX; when False,
            ``remove_style`` is applied to every result
        max_tokens: Maximum number of new tokens to generate
        num_beams: Number of beams for beam search
        no_repeat_ngram_size: Size of n-grams to prevent repetition

    Returns:
        List of LaTeX or KaTeX strings corresponding to each input image, in
        input order. An empty ``images`` list yields an empty list.

    Raises:
        TypeError: If ``images`` is not a list, or its elements are neither all
            ``str`` paths nor all ``np.ndarray`` images.

    Example:
        >>> import torch
        >>> from texteller import load_model, load_tokenizer, img2latex
        >>>
        >>> model = load_model(model_path=None, use_onnx=False)
        >>> tokenizer = load_tokenizer(tokenizer_path=None)
        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        >>>
        >>> res = img2latex(model, tokenizer, ["path/to/image.png"], device=device, out_format="katex")
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not isinstance(images, list):
        raise TypeError(f"images must be a list, got {type(images).__name__}")
    if not images:
        # Nothing to recognize; torch.stack() below cannot build an empty batch.
        return []

    if device is None:
        device = get_device()

    if device.type != model.device.type:
        if isinstance(model, ORTModelForVision2Seq):
            # ONNX Runtime models are not moved here; keep the model's device.
            _logger.warning(
                f"Onnxruntime device mismatch: detected {str(device)} but model is on {str(model.device)}, using {str(model.device)} instead"
            )
        else:
            model = model.to(device=device)

    if all(isinstance(img, str) for img in images):
        images = readimgs(images)
    elif not all(isinstance(img, np.ndarray) for img in images):
        raise TypeError("images must be all file paths (str) or all numpy arrays (RGB format)")
    # Otherwise the caller already supplied numpy arrays (documented as RGB).

    pixel_values = torch.stack(transform(images))

    generate_config = GenerationConfig(
        max_new_tokens=max_tokens,
        num_beams=num_beams,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    pred = model.generate(
        pixel_values.to(model.device),
        generation_config=generate_config,
    )

    res = tokenizer.batch_decode(pred, skip_special_tokens=True)

    if out_format == "katex":
        res = [to_katex(r) for r in res]

    if not keep_style:
        res = [remove_style(r) for r in res]

    # Formatting and newline insertion are applied for both output formats.
    res = [format_latex(r) for r in res]
    res = [add_newlines(r) for r in res]
    return res
def paragraph2md(
    img_path: str,
    latexdet_model: InferenceSession,
    textdet_model: predict_det.TextDetector,
    textrec_model: predict_rec.TextRecognizer,
    latexrec_model: TexTellerModel,
    tokenizer: RobertaTokenizerFast,
    device: torch.device | None = None,
    num_beams: int = 1,
) -> str:
    """
    Convert an image containing both text and mathematical formulas to markdown format.

    This function processes a mixed-content image by:
    1. Detecting mathematical formulas using a latex detection model
    2. Masking detected formula areas and detecting text regions using OCR
    3. Recognizing text in the detected regions
    4. Converting formula regions to LaTeX using the latex recognition model
    5. Combining all detected elements into a properly formatted markdown string

    Inline ("embedding") formulas are wrapped in ``$...$`` and display
    ("isolated") formulas in ``$$...$$``. Text on the same row directly after a
    display formula (e.g. an equation number such as "(1)") is folded into that
    formula as ``\\tag{...}``.

    Args:
        img_path: Path to the input image containing text and formulas
        latexdet_model: ONNX InferenceSession for LaTeX formula detection
        textdet_model: OCR text detector model
        textrec_model: OCR text recognition model
        latexrec_model: TexTeller model for LaTeX formula recognition
        tokenizer: Tokenizer for the LaTeX recognition model
        device: The torch device to use (defaults to available GPU or CPU)
        num_beams: Number of beams for beam search during LaTeX generation

    Returns:
        Markdown formatted string containing the recognized text and formulas,
        or ``""`` if no text or formula regions were detected.

    Raises:
        ValueError: If OpenCV cannot read ``img_path`` as an image.

    Example:
        >>> from texteller import load_latexdet_model, load_textdet_model, load_textrec_model, load_model, load_tokenizer, paragraph2md
        >>>
        >>> # Load all required models
        >>> latexdet_model = load_latexdet_model()
        >>> textdet_model = load_textdet_model()
        >>> textrec_model = load_textrec_model()
        >>> latexrec_model = load_model()
        >>> tokenizer = load_tokenizer()
        >>>
        >>> # Convert image to markdown
        >>> markdown_text = paragraph2md(
        ...     img_path="path/to/mixed_content_image.jpg",
        ...     latexdet_model=latexdet_model,
        ...     textdet_model=textdet_model,
        ...     textrec_model=textrec_model,
        ...     latexrec_model=latexrec_model,
        ...     tokenizer=tokenizer,
        ... )
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing/undecodable file by returning None.
        raise ValueError(f"Failed to read image: {img_path}")

    # The most common corner colour is used as the background colour that
    # detected formulas are painted over with before text detection.
    corners = [tuple(img[0, 0]), tuple(img[0, -1]), tuple(img[-1, 0]), tuple(img[-1, -1])]
    bg_color = np.array(Counter(corners).most_common(1)[0][0])

    start_time = time.perf_counter()
    latex_bboxes = latex_detect(img_path, latexdet_model)
    _logger.info(f"latex_det_model time: {time.perf_counter() - start_time:.2f}s")
    latex_bboxes = bbox_merge(sorted(latex_bboxes))
    masked_img = mask_img(img, latex_bboxes, bg_color)

    start_time = time.perf_counter()
    det_prediction, _ = textdet_model(masked_img)
    _logger.info(f"ocr_det_model time: {time.perf_counter() - start_time:.2f}s")
    # Each detection is a list of points; point 0 gives the box origin, point 3
    # its height and point 1 its width. NOTE(review): presumably the points are
    # ordered TL, TR, BR, BL -- confirm against predict_det.
    ocr_bboxes = [
        Bbox(
            p[0][0],
            p[0][1],
            p[3][1] - p[0][1],
            p[1][0] - p[0][0],
            label="text",
            confidence=None,
            content=None,
        )
        for p in det_prediction
    ]
    ocr_bboxes = bbox_merge(sorted(ocr_bboxes))
    # Presumably resolves overlaps between text and formula boxes; only the
    # boxes labelled "text" are kept for OCR recognition.
    ocr_bboxes = split_conflict(ocr_bboxes, latex_bboxes)
    ocr_bboxes = [bbox for bbox in ocr_bboxes if bbox.label == "text"]

    # Skip recognition entirely when there is nothing to recognize.
    if ocr_bboxes:
        sliced_imgs: list[np.ndarray] = slice_from_image(img, ocr_bboxes)
        start_time = time.perf_counter()
        rec_predictions, _ = textrec_model(sliced_imgs)
        _logger.info(f"ocr_rec_model time: {time.perf_counter() - start_time:.2f}s")

        assert len(rec_predictions) == len(ocr_bboxes)
        for content, bbox in zip(rec_predictions, ocr_bboxes):
            bbox.content = content[0]

    # img2latex rejects an empty image list, so only call it when formulas exist.
    if latex_bboxes:
        # img2latex expects RGB arrays but cv2.imread returned BGR: reverse the
        # channel axis. ascontiguousarray copies away the negative stride of ::-1.
        latex_imgs = [
            np.ascontiguousarray(
                img[bbox.p.y : bbox.p.y + bbox.h, bbox.p.x : bbox.p.x + bbox.w, ::-1]
            )
            for bbox in latex_bboxes
        ]
        start_time = time.perf_counter()
        latex_rec_res = img2latex(
            model=latexrec_model,
            tokenizer=tokenizer,
            images=latex_imgs,
            num_beams=num_beams,
            out_format="katex",
            device=device,
            keep_style=False,
        )
        _logger.info(f"latex_rec_model time: {time.perf_counter() - start_time:.2f}s")

        for bbox, content in zip(latex_bboxes, latex_rec_res):
            if bbox.label == "embedding":
                bbox.content = " $" + content + "$ "
            elif bbox.label == "isolated":
                bbox.content = "\n\n" + r"$$" + content + r"$$" + "\n\n"

    bboxes = sorted(ocr_bboxes + latex_bboxes)
    if not bboxes:
        return ""
    return _bboxes_to_markdown(bboxes)


def _bboxes_to_markdown(bboxes: list[Bbox]) -> str:
    """Concatenate sorted, content-filled boxes into one markdown string.

    Args:
        bboxes: Non-empty list of boxes, already sorted, whose ``content`` has
            been filled in (text for "text" boxes, ``$``/``$$``-wrapped LaTeX
            for "embedding"/"isolated" boxes). ``content`` is modified in place.

    Returns:
        The markdown string with surrounding whitespace stripped.
    """
    md = ""
    prev = Bbox(bboxes[0].p.x, bboxes[0].p.y, -1, -1, label="guard")
    for curr in bboxes:
        # Add the formula number back to the isolated formula
        if prev.label == "isolated" and curr.label == "text" and prev.same_row(curr):
            curr.content = curr.content.strip()
            if curr.content.startswith("(") and curr.content.endswith(")"):
                curr.content = curr.content[1:-1]

            # An isolated formula's content ends with "$$\n\n" (4 chars), so the
            # tag is spliced in just before the closing "$$".
            if re.search(r"\\tag\{.*\}$", md[:-4]) is not None:
                # in case of multiple tag: reopen the existing \tag{...}
                md = md[:-5] + f", {curr.content}" + "}" + md[-4:]
            else:
                md = md[:-4] + f"\\tag{{{curr.content}}}" + md[-4:]
            # `prev` deliberately stays the isolated formula so further numbers
            # on the same row are appended to the same tag.
            continue

        if not prev.same_row(curr):
            md += " "

        if curr.label == "embedding":
            # remove the bold effect from inline formulas
            curr.content = remove_style(curr.content)

            # change split environment into aligned
            curr.content = curr.content.replace(r"\begin{split}", r"\begin{aligned}")
            curr.content = curr.content.replace(r"\end{split}", r"\end{aligned}")

            # remove extra spaces (keeping only one)
            curr.content = re.sub(r" +", " ", curr.content)
            # NOTE(review): this assert relies on remove_style stripping the
            # " $...$ " padding added in paragraph2md; the padding is restored below.
            assert curr.content.startswith("$") and curr.content.endswith("$")
            curr.content = " $" + curr.content.strip("$") + "$ "
        md += curr.content
        prev = curr
    return md.strip()