Source code for texteller.api.load

from pathlib import Path

import wget
from onnxruntime import InferenceSession
from transformers import RobertaTokenizerFast

from texteller.constants import LATEX_DET_MODEL_URL, TEXT_DET_MODEL_URL, TEXT_REC_MODEL_URL
from texteller.globals import Globals
from texteller.logger import get_logger
from texteller.models import TexTeller
from texteller.paddleocr import predict_det, predict_rec
from texteller.paddleocr.utility import parse_args
from texteller.utils import cuda_available, mkdir, resolve_path
from texteller.types import TexTellerModel

_logger = get_logger(__name__)


def load_model(model_dir: str | None = None, use_onnx: bool = False) -> TexTellerModel:
    """
    Load the main TexTeller image-to-LaTeX model.

    Depending on ``use_onnx``, either the standard PyTorch weights or the
    ONNX-optimized variant (requires the 'optimum' package and ONNX Runtime)
    is loaded.

    Args:
        model_dir: Directory with the model files; ``None`` selects the
            default pretrained model.
        use_onnx: If True, load the ONNX version for faster inference.

    Returns:
        The loaded TexTeller model instance.

    Example:
        >>> from texteller import load_model
        >>>
        >>> model = load_model(use_onnx=True)
    """
    model = TexTeller.from_pretrained(model_dir, use_onnx=use_onnx)
    return model
def load_tokenizer(tokenizer_dir: str | None = None) -> RobertaTokenizerFast:
    """
    Load the tokenizer used by the TexTeller model.

    The tokenizer encodes and decodes LaTeX token sequences for the model.

    Args:
        tokenizer_dir: Directory with the tokenizer files; ``None`` selects
            the default tokenizer.

    Returns:
        A ``RobertaTokenizerFast`` instance.

    Example:
        >>> from texteller import load_tokenizer
        >>>
        >>> tokenizer = load_tokenizer()
    """
    tokenizer = TexTeller.get_tokenizer(tokenizer_dir)
    return tokenizer
def load_latexdet_model() -> InferenceSession:
    """
    Load the LaTeX-formula detection model.

    Downloads the detection weights on first use, then opens them as an
    ONNX Runtime ``InferenceSession`` (CUDA provider when a GPU is
    available, CPU otherwise).

    Returns:
        ONNX ``InferenceSession`` for LaTeX detection.

    Example:
        >>> from texteller import load_latexdet_model
        >>>
        >>> detector = load_latexdet_model()
    """
    model_path = _maybe_download(LATEX_DET_MODEL_URL)
    # Prefer GPU execution when CUDA is usable on this machine.
    provider = "CUDAExecutionProvider" if cuda_available() else "CPUExecutionProvider"
    return InferenceSession(resolve_path(model_path), providers=[provider])
def load_textrec_model() -> predict_rec.TextRecognizer:
    """
    Load the (PaddleOCR-based) text recognition model.

    Downloads the recognition weights on first use and builds a PaddleOCR
    ``TextRecognizer`` configured to run via ONNX, on GPU when available.

    Returns:
        PaddleOCR ``TextRecognizer`` instance.

    Example:
        >>> from texteller import load_textrec_model
        >>>
        >>> text_recognizer = load_textrec_model()
    """
    weights_path = _maybe_download(TEXT_REC_MODEL_URL)
    args = parse_args()
    args.use_onnx = True
    args.rec_model_dir = resolve_path(weights_path)
    args.use_gpu = cuda_available()
    return predict_rec.TextRecognizer(args)
def load_textdet_model() -> predict_det.TextDetector:
    """
    Load the (PaddleOCR-based) text detection model.

    Downloads the detection weights on first use and builds a PaddleOCR
    ``TextDetector`` configured to run via ONNX, on GPU when available.

    Returns:
        PaddleOCR ``TextDetector`` instance.

    Example:
        >>> from texteller import load_textdet_model
        >>>
        >>> text_detector = load_textdet_model()
    """
    weights_path = _maybe_download(TEXT_DET_MODEL_URL)
    args = parse_args()
    args.use_onnx = True
    args.det_model_dir = resolve_path(weights_path)
    args.use_gpu = cuda_available()
    return predict_det.TextDetector(args)
def _maybe_download(url: str, dirpath: str | None = None, force: bool = False) -> Path:
    """
    Download a file unless it is already cached.

    Args:
        url: URL to download from; the final path component becomes the
            local file name.
        dirpath: Directory to save the file in. If ``None``, uses the
            default cache directory.
        force: Re-download even if the file already exists.

    Returns:
        Path to the cached (or freshly downloaded) file.
    """
    if dirpath is None:
        dirpath = Globals().cache_dir
    mkdir(dirpath)
    fname = Path(url).name
    fpath = Path(dirpath) / fname
    if not fpath.exists() or force:
        # The `wget` package never overwrites an existing target: it would
        # save under a mangled name like "file (1).ext" and the stale file
        # would keep being returned. Remove the old file first so that
        # force=True actually refreshes the cache.
        fpath.unlink(missing_ok=True)
        # Lazy %-style args so formatting only happens if INFO is enabled.
        _logger.info("Downloading %s from %s to %s", fname, url, fpath)
        wget.download(url, resolve_path(fpath))
    return fpath