Chris4K committed
Commit f7b948e · verified · 1 Parent(s): d72e108

Update app.py

Files changed (1): app.py +3 -2
app.py CHANGED

@@ -5,7 +5,8 @@ import gradio as gr
 # Load Llama 3.2 model
 model_name = "meta-llama/Llama-3.2-3B-Instruct" # Replace with the exact model path
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
+#model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map=None, torch_dtype=torch.float32)
 
 # Helper function to process long contexts
 MAX_TOKENS = 100000 # Replace with the max token limit of the Llama model
@@ -131,7 +132,7 @@ def chat_with_model(user_input, chat_history=[]):
     print("prompt: ------------------------------------- \n"+prompt)
     input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=4096).to("cuda")
     tokenizer.pad_token = tokenizer.eos_token
-    attention_mask = torch.ones_like(input_ids).to("cuda")
+    attention_mask = torch.ones_like(input_ids).to("cpu")
     outputs = model.generate(input_ids, attention_mask=attention_mask,
                              max_new_tokens=1200, do_sample=True,
                              top_k=50, temperature=0.7)
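
Note: after this commit the model weights are loaded on the CPU in float32, but input_ids is still moved to "cuda" while attention_mask is sent to "cpu", so model.generate would likely raise a device-mismatch error at runtime. Below is a minimal sketch of device-consistent loading and generation, assuming the same model name and generation settings as the code above; the device/dtype selection logic and the example prompt are illustrative additions, not part of this commit.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "meta-llama/Llama-3.2-3B-Instruct"

# Pick one device and use it everywhere; fall back to CPU when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)

prompt = "Hello, how are you?"  # illustrative prompt, not from the commit
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(device)

# input_ids and attention_mask now share the model's device.
outputs = model.generate(**inputs,
                         max_new_tokens=1200, do_sample=True,
                         top_k=50, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Using the tokenizer's __call__ (rather than encode) returns both input_ids and attention_mask in one BatchEncoding, so a single .to(device) keeps every tensor on the model's device.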