"""Gradio chat application with file analysis and a small TF-IDF wiki retriever.

The module loads a causal language model lazily, builds a TF-IDF index over a
sample of a Korean Wikipedia Q&A dataset at import time, and exposes a Gradio
UI in which users can chat and upload TXT / CSV / PDF / Parquet files whose
contents are injected into the prompt.
"""

import os

# 1) Disable TorchDynamo entirely.
# NOTE(review): PyTorch documents this switch as TORCHDYNAMO_DISABLE — confirm
# that the spelling used here is actually honored.
os.environ["TORCH_DYNAMO_DISABLE"] = "1"
# 2) Disable Triton's cudagraphs optimization.
os.environ["TRITON_DISABLE_CUDAGRAPHS"] = "1"

# (Optional) silence known-noisy warnings.
import warnings
warnings.filterwarnings("ignore", message="skipping cudagraphs due to mutated inputs")
warnings.filterwarnings("ignore", message="Not enough SMs to use max_autotune_gemm mode")

import torch
# Enable TensorFloat32 matmuls (performance optimization).
torch.set_float32_matmul_precision('high')

import torch._inductor
torch._inductor.config.triton.cudagraphs = False

import torch._dynamo
# suppress_errors: fall back to eager execution when compilation fails.
torch._dynamo.config.suppress_errors = True

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from datasets import load_dataset
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
import json
from datetime import datetime
import pyarrow.parquet as pq
import pypdf
import io
import platform
import subprocess
import pytesseract
from pdf2image import convert_from_path
import queue
import time

# -------------------- Imports for PDF -> Markdown conversion --------------------
try:
    import re
    import requests
    from bs4 import BeautifulSoup
    import urllib.request
    import ocrmypdf
    import pytz
    import urllib.parse
    from pypdf import PdfReader
except ModuleNotFoundError as e:
    # Any of the imports above may be the one that failed, so report the real
    # module name in addition to the original hint and keep the cause chained.
    raise ModuleNotFoundError(
        "필수 모듈이 누락되었습니다. 'beautifulsoup4' 패키지를 설치해주세요.\n"
        "예: pip install beautifulsoup4\n"
        f"(missing module: {e.name})",
        name=e.name,
    ) from e
# ---------------------------------------------------------------------------

# Global state: context string of the most recently analyzed upload.
current_file_context = None

# Environment / model configuration.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
MODEL_NAME = MODEL_ID.split("/")[-1]

model = None  # Loaded lazily on the first chat request and then reused.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Message placed in the textbox by the upload event; stream_chat recognizes it
# as "analyze the file that was just uploaded".
FILE_ANALYSIS_PLACEHOLDER = "파일을 분석하고 있습니다..."

# Encodings tried, in order, when reading CSV and plain-text uploads.
TEXT_ENCODINGS = ['utf-8', 'cp949', 'euc-kr', 'latin1']

# Number of dataset rows indexed by the TF-IDF retriever.
WIKI_SAMPLE_SIZE = 10000

# (1) Load the Wikipedia Q&A dataset.
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
print("Wikipedia dataset loaded:", wiki_dataset)

# (2) Fit the TF-IDF vectorizer on a sample of the questions.
print("TF-IDF 벡터화 시작...")
# Slice the rows first so only the sample is materialized, not whole columns.
_wiki_sample = wiki_dataset['train'][:WIKI_SAMPLE_SIZE]
questions = _wiki_sample['question']
answers = _wiki_sample['answer']
vectorizer = TfidfVectorizer(max_features=1000)
question_vectors = vectorizer.fit_transform(questions)
print("TF-IDF 벡터화 완료")


# ------------------------- ChatHistory class -------------------------
class ChatHistory:
    """Conversation log persisted as JSON in ``/tmp/chat_history.json``."""

    def __init__(self):
        self.history = []
        self.history_file = "/tmp/chat_history.json"
        self.load_history()

    def add_conversation(self, user_msg: str, assistant_msg: str):
        """Append one user/assistant exchange and write the log to disk."""
        conversation = {
            "timestamp": datetime.now().isoformat(),
            "messages": [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg}
            ]
        }
        self.history.append(conversation)
        self.save_history()

    def format_for_display(self):
        """Return the log as a list of ``[user_text, assistant_text]`` pairs."""
        return [
            [conv["messages"][0]["content"], conv["messages"][1]["content"]]
            for conv in self.history
        ]

    def get_messages_for_api(self):
        """Return the log flattened into ``{"role", "content"}`` dicts."""
        messages = []
        for conv in self.history:
            messages.extend([
                {"role": "user", "content": conv["messages"][0]["content"]},
                {"role": "assistant", "content": conv["messages"][1]["content"]}
            ])
        return messages

    def clear_history(self):
        """Drop every stored exchange and persist the empty log."""
        self.history = []
        self.save_history()

    def save_history(self):
        """Write the log to ``self.history_file``; failures are only printed."""
        try:
            with open(self.history_file, 'w', encoding='utf-8') as f:
                json.dump(self.history, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"히스토리 저장 실패: {e}")

    def load_history(self):
        """Load the log from disk if the file exists; reset to empty on error."""
        try:
            if os.path.exists(self.history_file):
                with open(self.history_file, 'r', encoding='utf-8') as f:
                    self.history = json.load(f)
        except Exception as e:
            print(f"히스토리 로드 실패: {e}")
            self.history = []


chat_history = ChatHistory()


# ------------------------- Wiki retrieval (TF-IDF) -------------------------
def find_relevant_context(query, top_k=3):
    """Return up to ``top_k`` indexed Q&A pairs most similar to ``query``.

    Args:
        query: Text to match against the indexed questions.
        top_k: Maximum number of results; values <= 0 yield an empty list.

    Returns:
        A list of dicts with ``question``, ``answer`` and ``similarity`` keys,
        best match first. Entries whose similarity is not positive are skipped.
    """
    if top_k <= 0:
        # ``[-0:]`` would select every row instead of none.
        return []
    query_vector = vectorizer.transform([query])
    similarities = (query_vector * question_vectors.T).toarray()[0]
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    relevant_contexts = []
    for idx in top_indices:
        idx = int(idx)
        if similarities[idx] > 0:
            relevant_contexts.append({
                'question': questions[idx],
                # Use the cached sample instead of re-reading the whole column.
                'answer': answers[idx],
                'similarity': similarities[idx]
            })
    return relevant_contexts


def init_msg():
    """Return the placeholder message that triggers analysis of an upload."""
    return FILE_ANALYSIS_PLACEHOLDER


# -------------------- Helpers converting a PDF into Markdown text --------------------
def extract_text_from_pdf(reader: PdfReader) -> str:
    """Concatenate the text of every non-empty page, each under a page banner."""
    sections = []
    for idx, page in enumerate(reader.pages):
        text = page.extract_text() or ""
        if text:
            sections.append(f"---- Page {idx+1} ----\n" + text + "\n\n")
    return "".join(sections).strip()


def convert_pdf_to_markdown(pdf_file: str):
    """Extract text and metadata from a PDF, running OCR for image-heavy files.

    Args:
        pdf_file: Path of the PDF to read.

    Returns:
        ``(text, metadata, pdf_file)`` on success, where ``metadata`` is a dict
        with author/creator/producer/subject/title entries. If the file cannot
        be opened the result is ``(error_message, None, None)``.
    """
    try:
        reader = PdfReader(pdf_file)
    except Exception as e:
        return f"PDF 파일을 읽는 중 오류 발생: {e}", None, None

    raw_meta = reader.metadata
    metadata = {
        "author": raw_meta.author if raw_meta else None,
        "creator": raw_meta.creator if raw_meta else None,
        "producer": raw_meta.producer if raw_meta else None,
        "subject": raw_meta.subject if raw_meta else None,
        "title": raw_meta.title if raw_meta else None,
    }

    full_text = extract_text_from_pdf(reader)

    # Pages with images but very little extractable text are treated as scans.
    image_count = sum(len(page.images) for page in reader.pages)
    if image_count > 0 and len(full_text) < 1000:
        try:
            # splitext only touches the real extension (any case, and never a
            # ".pdf" that happens to appear earlier in the path).
            base, _ = os.path.splitext(pdf_file)
            out_pdf_file = base + "_ocr.pdf"
            ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
            reader_ocr = PdfReader(out_pdf_file)
            full_text = extract_text_from_pdf(reader_ocr)
        except Exception as e:
            full_text = f"OCR 처리 중 오류 발생: {e}\n\n원본 PDF 텍스트:\n\n" + full_text

    return full_text, metadata, pdf_file


# ------------------------- File analysis -------------------------
def analyze_file_content(content, file_type):
    """Return a one-line structural summary of ``content``.

    Args:
        content: Text of the file (or the report built by read_uploaded_file).
        file_type: ``'parquet'`` / ``'csv'`` for tabular data; anything else is
            classified as code or document by keyword.

    Returns:
        A short human-readable summary string.
    """
    if file_type in ['parquet', 'csv']:
        try:
            # Reports from read_uploaded_file start with a title line, not a
            # table, so read the counts they state explicitly when present.
            rows_match = re.search(r"^- Total Rows: ([\d,]+)", content, re.MULTILINE)
            cols_match = re.search(r"^- Total Columns: (\d+)", content, re.MULTILINE)
            if rows_match and cols_match:
                rows = int(rows_match.group(1).replace(",", ""))
                columns = int(cols_match.group(1))
            else:
                # Fallback: content is a bare pipe table (header, separator, rows).
                lines = content.split('\n')
                header = lines[0]
                columns = header.count('|') - 1
                rows = len(lines) - 3
            return f"📊 Dataset Structure: {columns} columns, {rows} rows"
        except Exception:
            return "❌ Failed to analyze dataset structure"

    lines = content.split('\n')
    total_lines = len(lines)

    if any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function']):
        functions = len([line for line in lines if 'def ' in line])
        classes = len([line for line in lines if 'class ' in line])
        imports = len([line for line in lines if 'import ' in line or 'from ' in line])
        return f"💻 Code Structure: {total_lines} lines (Functions: {functions}, Classes: {classes}, Imports: {imports})"

    paragraphs = content.count('\n\n') + 1
    words = len(content.split())
    return f"📝 Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"


def _uploaded_file_path(file):
    """Return the filesystem path for a path string or an object with ``.name``."""
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def _summarize_dataframe(df, title, include_numeric_stats=False):
    """Build the Markdown report (info, columns, preview, nulls) for ``df``."""
    from tabulate import tabulate

    parts = [
        f"📊 {title} File Analysis:\n\n",
        "1. Basic Information:\n",
        f"- Total Rows: {len(df):,}\n",
        f"- Total Columns: {len(df.columns)}\n",
    ]
    mem_usage = df.memory_usage(deep=True).sum() / 1024 / 1024
    parts.append(f"- Memory Usage: {mem_usage:.2f} MB\n\n")

    parts.append("2. Column Information:\n")
    # Iterating dtypes also works when column names are duplicated.
    parts.extend(f"- {col} ({dtype})\n" for col, dtype in df.dtypes.items())

    parts.append("\n3. Data Preview:\n")
    parts.append(tabulate(df.head(5), headers='keys', tablefmt='pipe', showindex=False))

    parts.append("\n\n4. Missing Values:\n")
    null_counts = df.isnull().sum()
    for col, count in null_counts[null_counts > 0].items():
        rate = count / len(df) * 100
        parts.append(f"- {col}: {count:,} ({rate:.1f}%)\n")

    if include_numeric_stats:
        numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
        if len(numeric_cols) > 0:
            parts.append("\n5. Numeric Column Statistics:\n")
            stats_df = df[numeric_cols].describe()
            parts.append(tabulate(stats_df, headers='keys', tablefmt='pipe'))

    return "".join(parts)


def _summarize_text(content):
    """Build the trailing "File Analysis" section for a plain-text or code file."""
    lines = content.split('\n')
    total_lines = len(lines)
    non_empty_lines = len([line for line in lines if line.strip()])
    is_code = any(
        keyword in content.lower()
        for keyword in ['def ', 'class ', 'import ', 'function']
    )

    analysis = "\n📝 File Analysis:\n"
    if is_code:
        functions = sum('def ' in line for line in lines)
        classes = sum('class ' in line for line in lines)
        imports = sum(
            ('import ' in line) or ('from ' in line)
            for line in lines
        )
        analysis += f"- File Type: Code\n"
        analysis += f"- Total Lines: {total_lines:,}\n"
        analysis += f"- Functions: {functions}\n"
        analysis += f"- Classes: {classes}\n"
        analysis += f"- Import Statements: {imports}\n"
    else:
        words = len(content.split())
        chars = len(content)
        analysis += f"- File Type: Text\n"
        analysis += f"- Total Lines: {total_lines:,}\n"
        analysis += f"- Non-empty Lines: {non_empty_lines:,}\n"
        analysis += f"- Word Count: {words:,}\n"
        analysis += f"- Character Count: {chars:,}\n"
    return analysis


def read_uploaded_file(file):
    """Read an uploaded file and return ``(content, file_type)``.

    Args:
        file: Either a filesystem path (what ``gr.File(type="filepath")``
            supplies) or an object exposing the path as ``.name``; ``None``
            means nothing was uploaded.

    Returns:
        ``("", "")`` when ``file`` is ``None``; otherwise the generated report
        and one of ``"parquet"``, ``"pdf"``, ``"csv"``, ``"text"``. Failures
        are reported as ``(error_message, "error")`` instead of raising.
    """
    if file is None:
        return "", ""

    try:
        path = _uploaded_file_path(file)
        file_ext = os.path.splitext(path)[1].lower()

        if file_ext == '.parquet':
            try:
                table = pq.read_table(path)
                df = table.to_pandas()
                content = _summarize_dataframe(df, "Parquet", include_numeric_stats=True)
                return content, "parquet"
            except Exception as e:
                return f"Error reading Parquet file: {str(e)}", "error"

        elif file_ext == '.pdf':
            try:
                markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(path)
                if metadata is None:
                    return f"PDF 파일 변환 오류 또는 읽기 실패.\n\n원본 메시지:\n{markdown_text}", "error"

                content = "# PDF to Markdown Conversion\n\n"
                content += "## Metadata\n"
                for k, v in metadata.items():
                    content += f"**{k.capitalize()}**: {v}\n\n"
                content += "## Extracted Text\n\n"
                content += markdown_text
                return content, "pdf"
            except Exception as e:
                return f"Error reading PDF file: {str(e)}", "error"

        elif file_ext == '.csv':
            for encoding in TEXT_ENCODINGS:
                try:
                    df = pd.read_csv(path, encoding=encoding)
                except UnicodeDecodeError:
                    continue
                return _summarize_dataframe(df, "CSV"), "csv"
            # UnicodeDecodeError needs five constructor arguments, so it cannot
            # be raised with a plain message; ValueError carries the same text.
            raise ValueError(
                f"Unable to read file with supported encodings ({', '.join(TEXT_ENCODINGS)})"
            )

        else:
            for encoding in TEXT_ENCODINGS:
                try:
                    with open(path, 'r', encoding=encoding) as f:
                        content = f.read()
                except UnicodeDecodeError:
                    continue
                return content + _summarize_text(content), "text"
            raise ValueError(
                f"Unable to read file with supported encodings ({', '.join(TEXT_ENCODINGS)})"
            )

    except Exception as e:
        return f"Error reading file: {str(e)}", "error"


# ------------------------- CSS -------------------------
CSS = """
/* (생략: 동일) */
"""


def clear_cuda_memory():
    """Release cached CUDA memory; does nothing when CUDA is unavailable."""
    if hasattr(torch.cuda, 'empty_cache') and torch.cuda.is_available():
        with torch.cuda.device('cuda'):
            torch.cuda.empty_cache()


# ------------------------- Model loading -------------------------
@spaces.GPU
def load_model():
    """Load ``MODEL_ID`` in bfloat16 with automatic device placement.

    Returns:
        The loaded model with ``config.use_cache`` set to ``False``.

    Raises:
        Exception: Any loading error is printed and re-raised.
    """
    try:
        clear_cuda_memory()
        loaded_model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        # (Important) the cache is also switched off in the model's own config.
        loaded_model.config.use_cache = False
        return loaded_model
    except Exception as e:
        print(f"모델 로드 오류: {str(e)}")
        raise


def build_prompt(conversation: list) -> str:
    """Render role/content messages as "User:" / "Assistant:" lines.

    Messages whose role is neither ``user`` nor ``assistant`` are skipped, and
    the prompt always ends with an open ``"Assistant: "`` turn.
    """
    prefixes = {"user": "User: ", "assistant": "Assistant: "}
    pieces = [
        prefixes[msg["role"]] + msg["content"] + "\n"
        for msg in conversation
        if msg["role"] in prefixes
    ]
    pieces.append("Assistant: ")
    return "".join(pieces)


# ------------------------- Streaming chat handler -------------------------
@spaces.GPU
def stream_chat(
    message: str,
    history: list,
    uploaded_file,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    penalty: float
):
    """Generate a reply and stream it to the Gradio chatbot.

    Args:
        message: The user's text, or the upload placeholder message.
        history: Chatbot history as ``[user_text, assistant_text]`` pairs.
        uploaded_file: Value of the file component (path) or ``None``.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Upper bound on generated tokens (reduced to fit the
            context window when necessary).
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling size.
        penalty: Repetition penalty; non-positive values are treated as 1.0.

    Yields:
        ``("", updated_history)`` tuples: the cleared textbox value and the
        chatbot history including the partial or final answer.
    """
    global model, current_file_context

    history = history or []

    try:
        # The upload component's change event also fires when the file is
        # removed (e.g. by the Clear button); there is nothing to analyze then.
        if not uploaded_file and message == FILE_ANALYSIS_PLACEHOLDER:
            yield "", []
            return

        if model is None:
            model = load_model()

        print(f'[User input] message: {message}')
        print(f'[History] {history}')

        # 1) Handle a freshly uploaded file.
        file_context = ""
        if uploaded_file and message == FILE_ANALYSIS_PLACEHOLDER:
            current_file_context = None
            try:
                content, file_type = read_uploaded_file(uploaded_file)
                if content:
                    file_analysis = analyze_file_content(content, file_type)
                    file_context = (
                        f"\n\n📄 파일 분석 결과:\n{file_analysis}"
                        f"\n\n파일 내용:\n```\n{content}\n```"
                    )
                    current_file_context = file_context
                    message = "업로드된 파일을 분석해주세요."
            except Exception as e:
                print(f"[파일 분석 오류] {str(e)}")
                file_context = f"\n\n❌ 파일 분석 중 오류가 발생했습니다: {str(e)}"
        elif current_file_context:
            file_context = current_file_context

        # 2) Retrieve wiki context.
        wiki_context = ""
        try:
            relevant_contexts = find_relevant_context(message)
            if relevant_contexts:
                wiki_context = "\n\n관련 위키피디아 정보:\n"
                for ctx in relevant_contexts:
                    wiki_context += (
                        f"Q: {ctx['question']}\n"
                        f"A: {ctx['answer']}\n"
                        f"유사도: {ctx['similarity']:.3f}\n\n"
                    )
        except Exception as e:
            print(f"[컨텍스트 검색 오류] {str(e)}")

        # 3) Only the most recent turns go into the prompt; the full history is
        #    still what gets sent back to the UI.
        max_history_length = 10
        recent_history = history[-max_history_length:]

        conversation = []
        for prompt, answer in recent_history:
            conversation.extend([
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer}
            ])

        # 4) Final user message: file context, then wiki context, then question.
        if file_context or wiki_context:
            final_message = file_context + wiki_context + "\n현재 질문: " + message
        else:
            final_message = message
        conversation.append({"role": "user", "content": final_message})

        # 5) Tokenize.
        input_ids_str = build_prompt(conversation)
        max_context = 8192
        max_new_tokens = int(max_new_tokens)
        tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
        input_length = tokenized_input["input_ids"].shape[1]

        # 6) Trim the oldest tokens when the prompt leaves too little room.
        if input_length > max_context - max_new_tokens:
            print(f"[경고] 입력이 너무 깁니다: {input_length} 토큰 -> 잘라냄.")
            min_generation = min(256, max_new_tokens)
            new_desired_input_length = max_context - min_generation
            tokens = tokenizer.encode(input_ids_str)
            if len(tokens) > new_desired_input_length:
                tokens = tokens[-new_desired_input_length:]
            input_ids_str = tokenizer.decode(tokens)
            tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
            input_length = tokenized_input["input_ids"].shape[1]

        print(f"[토큰 길이] {input_length}")
        inputs = tokenized_input.to("cuda")

        # 7) Cap max_new_tokens by the remaining context (always at least 1).
        remaining = max_context - input_length
        if remaining < max_new_tokens:
            print(f"[max_new_tokens 조정] {max_new_tokens} -> {remaining}")
            max_new_tokens = remaining
        max_new_tokens = max(1, max_new_tokens)

        # 8) Streamer that hands decoded text back to this generator.
        streamer = TextIteratorStreamer(
            tokenizer,
            timeout=30.0,
            skip_prompt=True,
            skip_special_tokens=True
        )

        # A pad id of 0 is valid, so test for None rather than truthiness.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # The sliders allow temperature 0 and penalty 0, both of which sampling
        # generation rejects: use greedy decoding / a neutral penalty instead.
        do_sample = temperature is not None and temperature > 0
        if penalty is None or penalty <= 0:
            penalty = 1.0

        # use_cache=False is deliberate (kept from the original configuration).
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            repetition_penalty=penalty,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=False,
        )
        if do_sample:
            generate_kwargs.update(
                temperature=temperature,
                top_p=top_p,
                top_k=int(top_k),
            )

        clear_cuda_memory()

        # 9) Run generation in a separate thread.
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        # 10) Stream the output.
        buffer = ""
        partial_message = ""
        last_yield_time = time.time()

        try:
            for new_text in streamer:
                buffer += new_text
                partial_message += new_text

                # Refresh the UI on a timer or once enough new text arrived.
                current_time = time.time()
                if (current_time - last_yield_time > 0.1) or (len(partial_message) > 20):
                    yield "", history + [[message, buffer]]
                    partial_message = ""
                    last_yield_time = current_time

            # Final flush.
            if buffer:
                yield "", history + [[message, buffer]]

            # Persist the exchange.
            chat_history.add_conversation(message, buffer)

        except Exception as e:
            print(f"[스트리밍 중 오류] {str(e)}")
            if not buffer:
                buffer = f"응답 생성 중 오류 발생: {str(e)}"
            yield "", history + [[message, buffer]]

        finally:
            # Always reap the worker and free memory, including when the
            # consumer stops iterating early.
            if thread.is_alive():
                thread.join(timeout=5.0)
            clear_cuda_memory()

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        error_message = f"오류가 발생했습니다: {str(e)}\n{error_details}"
        print(f"[Stream chat 오류] {error_message}")
        clear_cuda_memory()
        yield "", history + [[message, error_message]]


# ------------------------- Gradio UI -------------------------
def create_demo():
    """Build the Gradio Blocks UI and wire its events; returns the Blocks app."""
    with gr.Blocks(css=CSS) as demo:
        with gr.Column(elem_classes="markdown-style"):
            gr.Markdown("""
            # 🤖 RAGOndevice
            #### 📊 RAG: Upload and Analyze Files (TXT, CSV, PDF, Parquet files)
            Upload your files for data analysis and learning
            """)

        chatbot = gr.Chatbot(
            value=[],
            height=600,
            label="GiniGEN AI Assistant",
            elem_classes="chat-container"
        )

        with gr.Row(elem_classes="input-container"):
            with gr.Column(scale=1, min_width=70):
                file_upload = gr.File(
                    type="filepath",
                    elem_classes="file-upload-icon",
                    scale=1,
                    container=True,
                    interactive=True,
                    show_label=False
                )

            with gr.Column(scale=3):
                msg = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here... 💭",
                    container=False,
                    elem_classes="input-textbox",
                    scale=1
                )

            with gr.Column(scale=1, min_width=70):
                send = gr.Button(
                    "Send",
                    elem_classes="send-button custom-button",
                    scale=1
                )

            with gr.Column(scale=1, min_width=70):
                clear = gr.Button(
                    "Clear",
                    elem_classes="clear-button custom-button",
                    scale=1
                )

        # Advanced settings.
        with gr.Accordion("🎮 Advanced Settings", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    temperature = gr.Slider(
                        minimum=0, maximum=1, step=0.1, value=0.8,
                        label="Creativity Level 🎨"
                    )
                    max_new_tokens = gr.Slider(
                        minimum=128, maximum=8000, step=1, value=4000,
                        label="Maximum Token Count 📝"
                    )
                with gr.Column(scale=1):
                    top_p = gr.Slider(
                        minimum=0.0, maximum=1.0, step=0.1, value=0.8,
                        label="Diversity Control 🎯"
                    )
                    top_k = gr.Slider(
                        minimum=1, maximum=20, step=1, value=20,
                        label="Selection Range 📊"
                    )
                    penalty = gr.Slider(
                        minimum=0.0, maximum=2.0, step=0.1, value=1.0,
                        label="Repetition Penalty 🔄"
                    )

        # Example prompts.
        gr.Examples(
            examples=[
                ["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"],
                ["Please analyze this data and provide insights:\nAnnual Revenue (Million)\n2019: 1200\n2020: 980\n2021: 1450\n2022: 2100\n2023: 1890"],
                ["Please solve this math problem step by step: 'When a circle's area is twice that of its inscribed square, find the relationship between the circle's radius and the square's side length.'"],
                ["Please analyze this marketing campaign's ROI and suggest improvements:\nTotal Cost: $50,000\nReach: 1M users\nClick Rate: 2.3%\nConversion Rate: 0.8%\nAverage Purchase: $35"],
            ],
            inputs=msg
        )

        # Reset the conversation and forget the analyzed file.
        def clear_conversation():
            global current_file_context
            current_file_context = None
            return [], None, "Start a new conversation..."

        chat_inputs = [msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty]
        chat_outputs = [msg, chatbot]

        # Send a message (Enter key or Send button).
        msg.submit(
            stream_chat,
            inputs=chat_inputs,
            outputs=chat_outputs
        )
        send.click(
            stream_chat,
            inputs=chat_inputs,
            outputs=chat_outputs
        )

        # File upload: show a notice, insert the placeholder, then analyze.
        file_upload.change(
            fn=lambda: ("처리 중...", [["시스템", "파일을 분석 중입니다. 잠시만 기다려주세요..."]]),
            outputs=[msg, chatbot],
            queue=False
        ).then(
            fn=init_msg,
            outputs=msg,
            queue=False
        ).then(
            fn=stream_chat,
            inputs=chat_inputs,
            outputs=chat_outputs,
            queue=True
        )

        # Clear button.
        clear.click(
            fn=clear_conversation,
            outputs=[chatbot, file_upload, msg],
            queue=False
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()