|
import streamlit as st |
|
|
|
from transformers import pipeline |
|
import torch |
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
pipe = pipeline("text-generation", model="openai-community/gpt2") |
|
|
|
|
|
|
|
st.title("GPT-2 Kampfsimulator") |
|
|
|
|
|
@st.cache_resource |
|
def load_model(): |
|
try: |
|
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") |
|
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") |
|
|
|
model.eval() |
|
return model, tokenizer |
|
except Exception as e: |
|
st.error(f"Fehler beim Laden des Modells: {str(e)}") |
|
return None, None |
|
|
|
model, tokenizer = load_model() |
|
st.write("Modell und Tokenizer erfolgreich geladen.") |
|
|
|
|
|
user_input = st.text_input( |
|
"Beschreibe den Kampf:", |
|
"Ein Schwertkämpfer trifft auf einen Bogenschützen in einer Arena." |
|
) |
|
|
|
|
|
col1, col2 = st.columns(2) |
|
with col1: |
|
temperature = st.slider("Kreativität (Temperature)", 0.1, 1.0, 0.7) |
|
max_length = st.slider("Maximale Textlänge", 50, 200, 100) |
|
with col2: |
|
num_sequences = st.slider("Anzahl der Generierungen", 1, 3, 1) |
|
|
|
|
|
if st.button("Kampf simulieren"): |
|
if model and tokenizer: |
|
try: |
|
|
|
prompt = f"In einem epischen Kampf: {user_input}\nDer Kampf beginnt:" |
|
|
|
|
|
inputs = tokenizer(prompt, return_tensors="pt", padding=True) |
|
|
|
|
|
outputs = model.generate( |
|
inputs["input_ids"], |
|
max_length=max_length, |
|
num_return_sequences=num_sequences, |
|
temperature=temperature, |
|
pad_token_id=tokenizer.eos_token_id, |
|
no_repeat_ngram_size=2, |
|
do_sample=True, |
|
top_k=50, |
|
top_p=0.95 |
|
) |
|
|
|
|
|
for idx, output in enumerate(outputs): |
|
generated_text = tokenizer.decode(output, skip_special_tokens=True) |
|
st.markdown(f"**Kampfszenario {idx + 1}:**") |
|
st.text_area( |
|
label=f"Generierter Text {idx + 1}", |
|
value=generated_text, |
|
height=150 |
|
) |
|
|
|
except Exception as e: |
|
st.error(f"Fehler bei der Textgenerierung: {str(e)}") |
|
else: |
|
st.error("Modell konnte nicht geladen werden. Bitte überprüfen Sie die Installation.") |
|
|
|
|
|
st.markdown(""" |
|
--- |
|
**Hinweise:** |
|
- Die "Kreativität" steuert, wie kreativ/zufällig die Ausgabe sein soll |
|
- Die "Maximale Textlänge" bestimmt die maximale Anzahl der generierten Token |
|
- "Anzahl der Generierungen" erstellt mehrere Varianten des Kampfes |
|
""") |