"""Extract the underlying data table from a chart image using google/deplot.

The script loads the DePlot (Pix2Struct) model, runs it on one image, writes
the parsed table to a timestamped ``results_*.log`` file and prints it to
stdout as an aligned text table.
"""

import logging
import os
import sys
from datetime import datetime
from itertools import zip_longest

import torch
from PIL import Image
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

# Image analysed when no path is supplied to main() or on the command line.
DEFAULT_IMAGE_PATH = '05e57f1c9acff69f1eb6fa72d4805d0.jpg'

# Token the model emits in place of a newline between table rows.
ROW_SEPARATOR = "<0x0A>"

# Configure logging: mirror every record to stdout and to app.log.
# The log file is opened as UTF-8 explicitly because messages may contain
# non-ASCII text and the platform default encoding is not guaranteed to cope.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('app.log', encoding='utf-8'),
    ],
)
logger = logging.getLogger(__name__)


def print_section(title, char='=', width=50):
    """Print ``title`` centred between two rules made of ``char``.

    Args:
        title: Text shown between the rules.
        char: Character repeated to draw each rule.
        width: Rule length and centring width, in characters.
    """
    rule = char * width
    print(f"\n{rule}")
    print(f"{title.center(width)}")
    print(f"{rule}\n")


def print_table(data):
    """Print ``data`` as an aligned text table; the first row is the header.

    Rows may have different lengths: column widths are computed over every
    row that has a value in that column, and each row prints only the cells
    it actually contains.

    Args:
        data: Sequence of rows, each a sequence of cells. Cells are rendered
            with ``str()``.
    """
    if not data:
        print("No data available")
        return

    # zip_longest pads short rows with "" so ragged tables do not raise.
    widths = [
        max(len(str(cell)) for cell in column)
        for column in zip_longest(*data, fillvalue="")
    ]

    lines = [
        " | ".join(str(cell).ljust(width) for cell, width in zip(row, widths))
        for row in data
    ]
    header_line, *body_lines = lines

    print(header_line)
    # For a rectangular table every line has the header's length; for a
    # ragged one the rule spans the widest line.
    print("-" * max(len(line) for line in lines))
    for line in body_lines:
        print(line)


def _parse_table_output(raw_output):
    """Split decoded model output into rows of stripped cells.

    Rows are separated by ``ROW_SEPARATOR`` and cells by ``|``; rows that are
    empty or whitespace-only are dropped.
    """
    return [
        [cell.strip() for cell in line.split("|")]
        for line in raw_output.split(ROW_SEPARATOR)
        if line.strip()
    ]


class ChartAnalyzer:
    """Wrap the google/deplot model and processor for chart-to-table extraction."""

    def __init__(self):
        """Load the model and processor.

        Raises:
            Exception: Whatever ``from_pretrained`` raises; it is logged and
                re-raised unchanged.
        """
        try:
            print_section("初始化模型")
            print("正在加载模型和处理器...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
            self.processor = AutoProcessor.from_pretrained("google/deplot")
            print("✓ 模型加载完成")
        except Exception as e:
            print("✗ 模型加载失败")
            logger.exception("Error initializing model: %s", e)
            raise

    def process_image(self, image_path, prompt=None):
        """Run the model on an image and return the extracted table.

        The parsed rows are also written to ``results_<timestamp>.log`` in
        the current working directory.

        Args:
            image_path: Path to the chart image.
            prompt: Text prompt passed to the processor; a default
                table-extraction prompt is used when ``None``.

        Returns:
            A list of rows, each a list of stripped cell strings.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        try:
            print_section("图片处理", char='-')

            # Fail early with a clear message when the file is missing.
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"找不到图片文件: {image_path}")

            if prompt is None:
                prompt = "Generate underlying data table of the figure below:"

            print(f"正在处理图片: {image_path}")
            # The context manager closes the underlying file once the
            # processor has turned the image into tensors.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    images=image,
                    text=prompt,
                    return_tensors="pt",
                )

            print("\n正在生成数据分析...")
            with torch.no_grad():
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=4,
                    length_penalty=1.0,
                )

            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)
            result_array = _parse_table_output(raw_output)

            # Persist the table, one " | "-joined row per line.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_file = f'results_{timestamp}.log'
            with open(output_file, mode='w', encoding='utf-8') as file:
                file.writelines(" | ".join(row) + "\n" for row in result_array)

            print(f"\n✓ 结果已保存至: {output_file}")
            return result_array

        except Exception as e:
            print("\n✗ 处理失败")
            logger.exception("Error processing image: %s", e)
            raise


def main(image_path=None):
    """Analyse one chart image and print the extracted table.

    Args:
        image_path: Path of the image to analyse; ``DEFAULT_IMAGE_PATH`` is
            used when ``None``.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    if image_path is None:
        image_path = DEFAULT_IMAGE_PATH

    try:
        print_section("图表数据提取系统", char='*')

        analyzer = ChartAnalyzer()
        results = analyzer.process_image(image_path)

        print_section("分析结果")
        print_table(results)

        print_section("处理完成", char='*')

    except Exception as e:
        # The traceback was already logged where the error originated.
        logger.error("Application error: %s", e)
        print("\n✗ 程序执行出错,请查看日志获取详细信息")
        raise


if __name__ == "__main__":
    # An optional first command-line argument overrides the default image.
    main(sys.argv[1] if len(sys.argv) > 1 else None)