import os
from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer

from gector import GECToR, predict, load_verb_dict


class EndpointHandler:
    """Request handler that runs GECToR grammatical error correction.

    On construction it loads the GECToR model, its tokenizer and the
    verb-form vocabulary from ``path``, and moves the model to CUDA when
    available (CPU otherwise). Calling the instance forwards the request's
    ``"inputs"`` to ``gector.predict``.
    """

    def __init__(self, path: str = "") -> None:
        """Load model artifacts from ``path``.

        Args:
            path: Directory containing the pretrained GECToR weights, the
                tokenizer files and ``data/verb-form-vocab.txt``.
        """
        self.model = GECToR.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.encode, self.decode = load_verb_dict(
            os.path.join(path, "data/verb-form-vocab.txt")
        )
        # Keep the selected device on the instance so it can be inspected
        # later instead of being thrown away after the .to() call.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # This handler only does inference; set eval mode explicitly
        # (harmless if from_pretrained already did so).
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Correct the input text(s) and return the predictions.

        Args:
            data: Request payload with the following keys:
                - "inputs" (str | List[str]): one string or a list of strings
                  to correct. A single string is treated as a one-element list.
                - "n_iterations" (int, optional): number of prediction
                  iterations. Defaults to 5. Must be >= 1.
                - "batch_size" (int, optional): batch size for prediction.
                  Defaults to 2. Must be >= 1.
                - "keep_confidence" (float, optional): confidence threshold for
                  keeping predictions. Defaults to 0.0.
                - "min_error_prob" (float, optional): minimum error probability
                  for keeping predictions. Defaults to 0.0.

        Returns:
            Whatever ``gector.predict`` returns for the inputs, one entry per
            input (presumably the corrected sentence strings — confirm against
            the installed gector version). An empty list when no inputs are
            given.

        Raises:
            KeyError: if ``"inputs"`` is missing.
            ValueError: if a numeric parameter cannot be converted, or if
                ``n_iterations`` or ``batch_size`` is less than 1.
        """
        srcs = data["inputs"]
        # A bare string would otherwise be processed character by character.
        if isinstance(srcs, str):
            srcs = [srcs]
        else:
            srcs = list(srcs)
        if not srcs:
            return []

        # JSON clients may send numbers as strings; coerce them explicitly.
        n_iterations = int(data.get("n_iterations", 5))
        batch_size = int(data.get("batch_size", 2))
        keep_confidence = float(data.get("keep_confidence", 0.0))
        min_error_prob = float(data.get("min_error_prob", 0.0))

        if n_iterations < 1:
            raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # No gradients are needed for inference; this avoids building
        # autograd state and reduces memory use.
        with torch.no_grad():
            return predict(
                model=self.model,
                tokenizer=self.tokenizer,
                srcs=srcs,
                encode=self.encode,
                decode=self.decode,
                keep_confidence=keep_confidence,
                min_error_prob=min_error_prob,
                n_iteration=n_iterations,
                batch_size=batch_size,
            )