|
import os |
|
import sys |
|
import logging |
|
from datetime import datetime |
|
import torch |
|
from transformers import AutoProcessor, Pix2StructForConditionalGeneration |
|
from PIL import Image |
|
|
|
|
|
# Configure root logging once for the script: INFO and above go to both the
# console (stdout) and ``app.log``. The file handler is opened as UTF-8
# explicitly because messages may contain non-ASCII text (e.g. the Chinese
# error messages raised in this module); without an explicit encoding the
# handler uses the platform's locale encoding and may fail to encode them.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('app.log', encoding='utf-8'),
    ],
)

# Module-level logger used by the functions and classes below.
logger = logging.getLogger(__name__)
|
|
|
def print_section(title, char='='):
    """Print a section heading: *title* centred between two rules of *char*.

    The rules are *char* repeated 50 times; a blank line precedes the first
    rule and follows the second.
    """
    width = 50
    rule = char * width
    print(f"\n{rule}")
    print(title.center(width))
    print(f"{rule}", end="\n\n")
|
|
|
def print_table(data):
    """Print rows as a left-aligned, ``|``-separated text table.

    The first row is treated as the header and is followed by a dashed
    separator line as wide as the header line. Cells are converted with
    ``str()``. Rows may have different lengths (DePlot output often mixes a
    short ``TITLE | ...`` row with wider data rows); shorter rows are padded
    with empty cells so every row has as many columns as the widest one.

    Args:
        data: sequence of rows, each a sequence of cell values. ``None`` or an
            empty sequence prints "No data available".
    """
    if not data:
        print("No data available")
        return

    # Work on string copies so the caller's rows are never mutated.
    rows = [[str(cell) for cell in row] for row in data]

    # Pad ragged rows up to the widest row; sizing columns from the header
    # alone would raise IndexError on rows of a different length.
    n_cols = max(len(row) for row in rows)
    for row in rows:
        row.extend([""] * (n_cols - len(row)))

    col_widths = [max(len(row[i]) for row in rows) for i in range(n_cols)]

    def _format_row(row):
        # Left-justify every cell to its column width and join with " | ".
        return " | ".join(cell.ljust(width) for cell, width in zip(row, col_widths))

    header_str = _format_row(rows[0])
    print(header_str)
    print("-" * len(header_str))

    for row in rows[1:]:
        print(_format_row(row))
|
|
|
class ChartAnalyzer:
    """Extract the underlying data table from a chart image with a Pix2Struct model.

    The model and processor are loaded once in ``__init__``;
    ``process_image`` can then be called for any number of images.
    """

    def __init__(self, model_name="google/deplot"):
        """Load the Pix2Struct model and its processor.

        Args:
            model_name: model id or local path passed to ``from_pretrained``
                for both the model and the processor. Defaults to
                ``"google/deplot"``.

        Raises:
            Exception: any loading error is logged (with traceback) and re-raised.
        """
        try:
            print_section("初始化模型")
            print("正在加载模型和处理器...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
            print("✓ 模型加载完成")
        except Exception as e:
            print("✗ 模型加载失败")
            logger.exception("Error initializing model: %s", e)
            raise

    def process_image(self, image_path, prompt=None):
        """Run the model on one image and return the extracted table.

        The parsed rows are also written to ``results_<YYYYmmdd_HHMMSS>.log``
        in the current working directory.

        Args:
            image_path: path of the chart image to analyse.
            prompt: text prompt for the model; defaults to
                "Generate underlying data table of the figure below:".

        Returns:
            list[list[str]]: one list of stripped cell strings per non-blank
            output line.

        Raises:
            FileNotFoundError: if ``image_path`` does not exist.
            Exception: any other failure is logged (with traceback) and re-raised.
        """
        try:
            print_section("图片处理", char='-')

            if not os.path.exists(image_path):
                raise FileNotFoundError(f"找不到图片文件: {image_path}")

            print(f"正在处理图片: {image_path}")

            if prompt is None:
                prompt = "Generate underlying data table of the figure below:"

            # PIL opens files lazily; the context manager guarantees the file
            # handle is released once the processor has read the pixels.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    images=image,
                    text=prompt,
                    return_tensors="pt",
                )

            print("\n正在生成数据分析...")
            with torch.no_grad():
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    num_beams=4,
                    length_penalty=1.0,
                )

            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)
            result_array = self._parse_table(raw_output)

            output_file = self._save_results(result_array)
            print(f"\n✓ 结果已保存至: {output_file}")
            return result_array

        except Exception as e:
            print("\n✗ 处理失败")
            logger.exception("Error processing image: %s", e)
            raise

    @staticmethod
    def _parse_table(raw_output):
        """Split the model's linearised table text into rows of stripped cells.

        Lines are separated by the literal token ``<0x0A>`` (as this code has
        always expected) or by a real newline; blank lines are dropped and
        cells are split on ``|``.
        """
        lines = raw_output.replace("<0x0A>", "\n").split("\n")
        return [[cell.strip() for cell in line.split("|")] for line in lines if line.strip()]

    @staticmethod
    def _save_results(result_array):
        """Write rows as ``" | "``-joined lines to a timestamped file; return its name."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f'results_{timestamp}.log'
        with open(output_file, mode='w', encoding='utf-8') as file:
            file.writelines(" | ".join(row) + "\n" for row in result_array)
        return output_file
|
|
|
def main(image_path=None):
    """Run the chart-data extraction pipeline and print the resulting table.

    Args:
        image_path: path of the chart image to analyse. When omitted, the
            first command-line argument is used if one is given, otherwise the
            hard-coded sample file name.

    Raises:
        Exception: any failure is logged (with traceback) and re-raised.
    """
    if image_path is None:
        image_path = sys.argv[1] if len(sys.argv) > 1 else '05e57f1c9acff69f1eb6fa72d4805d0.jpg'

    try:
        print_section("图表数据提取系统", char='*')

        analyzer = ChartAnalyzer()
        results = analyzer.process_image(image_path)

        print_section("分析结果")
        print_table(results)

        print_section("处理完成", char='*')

    except Exception as e:
        # logger.exception records the traceback, which the message below
        # tells the user to look for in the log.
        logger.exception("Application error: %s", e)
        print("\n✗ 程序执行出错,请查看日志获取详细信息")
        raise
|
|
|
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|