# c2t / app.py
# realkun's picture
# Update app.py
# eee8d68 verified
import os
import sys
import logging
from datetime import datetime
import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image
# Configure root logging once at import time: every record is echoed to
# stdout and also written to app.log in the current working directory.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        # Explicit UTF-8: without it the file uses the locale encoding, and
        # non-ASCII text (this file raises Chinese error messages that end up
        # in log records) may fail to encode or be written garbled.
        logging.FileHandler('app.log', encoding='utf-8'),
    ],
)
# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
def print_section(title, char='='):
    """Print a section banner: a rule, the centered title, and another rule.

    The rule is ``char`` repeated 50 times and the title is centered in a
    50-column field. One blank line is emitted before the banner and one
    after it.

    Args:
        title: Text shown in the middle line of the banner.
        char: String repeated to build the rule lines (default ``'='``).
    """
    rule = char * 50
    print(f"\n{rule}")
    print(title.center(50))
    print(f"{rule}\n")
def print_table(data):
    """Print rows as an aligned, pipe-separated text table.

    The first row is treated as the header and is followed by a dashed rule
    exactly as wide as the rendered header line; all remaining rows are
    printed beneath it. Every cell is converted with ``str`` and
    left-justified to the widest cell found in its column.

    Rows may have different lengths: a column's width is taken from the rows
    that actually contain that column, and each row is rendered using only
    the cells it has, so ragged input no longer raises ``IndexError``.

    Args:
        data: Sequence of rows, each a sequence of cells. When it is empty
            (or otherwise falsy), ``"No data available"`` is printed instead.
    """
    if not data:
        print("No data available")
        return

    # Widest cell per column, growing the list whenever a row turns out to
    # be longer than any row seen so far.
    col_widths = []
    for row in data:
        for index, cell in enumerate(row):
            cell_width = len(str(cell))
            if index < len(col_widths):
                col_widths[index] = max(col_widths[index], cell_width)
            else:
                col_widths.append(cell_width)

    def render(row):
        # zip stops at the row's own length, so short rows are simply shorter.
        return " | ".join(
            str(cell).ljust(width) for cell, width in zip(row, col_widths)
        )

    header_str = render(data[0])
    print(header_str)
    print("-" * len(header_str))

    for row in data[1:]:
        print(render(row))
class ChartAnalyzer:
    """Turn a chart image into a table of strings using a Pix2Struct model.

    The model and its processor are loaded once in ``__init__`` and reused by
    every ``process_image`` call.
    """

    def __init__(self, model_name="google/deplot"):
        """Load the model and processor.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the model and the processor. Defaults to
                ``"google/deplot"``.

        Raises:
            Exception: Whatever ``from_pretrained`` raises; it is logged with
                its traceback and re-raised unchanged.
        """
        try:
            print_section("初始化模型")
            print("正在加载模型和处理器...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
            print("✓ 模型加载完成")
        except Exception as e:
            print("✗ 模型加载失败")
            # logger.exception records the traceback alongside the message.
            logger.exception("Error initializing model: %s", e)
            raise

    @staticmethod
    def _parse_table(raw_output):
        """Split decoded model text into rows of stripped cell strings.

        Rows are separated by the literal ``<0x0A>`` marker and cells by
        ``|``. Segments that are empty or whitespace-only are dropped.
        """
        return [
            [cell.strip() for cell in line.split("|")]
            for line in raw_output.split("<0x0A>")
            if line.strip()
        ]

    def process_image(self, image_path, prompt=None):
        """Run the model on one image and return the extracted table.

        The parsed rows are also written, one ``" | "``-joined row per line,
        to ``results_<YYYYmmdd_HHMMSS>.log`` in the current working directory.

        Args:
            image_path: Path of the image file to analyze.
            prompt: Text prompt given to the processor. When ``None`` the
                default table-extraction prompt is used.

        Returns:
            A list of rows, each a list of stripped cell strings.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            Exception: Any other failure while loading the image, running
                the model or saving the result; it is logged with its
                traceback and re-raised unchanged.
        """
        try:
            print_section("图片处理", char='-')

            if not os.path.exists(image_path):
                raise FileNotFoundError(f"找不到图片文件: {image_path}")

            print(f"正在处理图片: {image_path}")
            # Image.open is lazy and keeps the file open. Convert inside the
            # with-block to get a fully loaded copy, so the file handle is
            # released as soon as the block exits.
            with Image.open(image_path) as source_image:
                image = source_image.convert("RGB")

            if prompt is None:
                prompt = "Generate underlying data table of the figure below:"

            inputs = self.processor(
                images=image,
                text=prompt,
                return_tensors="pt"
            )

            print("\n正在生成数据分析...")
            # Inference only: no gradients are needed.
            with torch.no_grad():
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=4,
                    length_penalty=1.0
                )

            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)
            result_array = self._parse_table(raw_output)

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_file = f'results_{timestamp}.log'
            with open(output_file, mode='w', encoding='utf-8') as file:
                file.writelines(" | ".join(row) + "\n" for row in result_array)

            print(f"\n✓ 结果已保存至: {output_file}")
            return result_array
        except Exception as e:
            print("\n✗ 处理失败")
            # logger.exception records the traceback alongside the message.
            logger.exception("Error processing image: %s", e)
            raise
def main(image_path=None):
    """Run the full pipeline: load the model, analyze one image, print the table.

    Args:
        image_path: Path of the chart image to analyze. When ``None`` the
            bundled sample file name is used, which keeps a bare ``main()``
            call behaving as before.

    Raises:
        Exception: Any failure from model loading or image processing; it is
            logged, a short notice is printed, and it is re-raised.
    """
    try:
        print_section("图表数据提取系统", char='*')

        analyzer = ChartAnalyzer()

        if image_path is None:
            image_path = '05e57f1c9acff69f1eb6fa72d4805d0.jpg'

        results = analyzer.process_image(image_path)

        print_section("分析结果")
        print_table(results)
        print_section("处理完成", char='*')
    except Exception as e:
        logger.error("Application error: %s", e)
        print("\n✗ 程序执行出错,请查看日志获取详细信息")
        raise


if __name__ == "__main__":
    # An optional first command-line argument overrides the default image.
    main(sys.argv[1] if len(sys.argv) > 1 else None)