fffiloni committed
Commit e856606 · verified · 1 Parent(s): bd85948

Update app.py

Files changed (1):
  1. app.py +10 -10
app.py CHANGED
@@ -8,12 +8,12 @@ import gradio as gr
 # Load the model and tokenizer
 model_path = "ByteDance/Sa2VA-4B"
 
-"""
+
 from unittest.mock import patch
 from transformers.dynamic_module_utils import get_imports
 
 def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
-    if not str(filename).endswith("modeling_phi3.py"):
+    if not str(filename).endswith("flash_attention.py"):
         return get_imports(filename)
     imports = get_imports(filename)
     imports.remove("flash_attn")
@@ -21,14 +21,14 @@ def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
 
 
 with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
-"""
-model = AutoModel.from_pretrained(
-    model_path,
-    torch_dtype = torch.bfloat16,
-    low_cpu_mem_usage = False,
-    use_flash_attn = False,
-    trust_remote_code = True
-).eval().cuda()
+
+    model = AutoModel.from_pretrained(
+        model_path,
+        torch_dtype = torch.bfloat16,
+        low_cpu_mem_usage = False,
+        use_flash_attn = False,
+        trust_remote_code = True
+    ).eval().cuda()
 
 tokenizer = AutoTokenizer.from_pretrained(
     model_path,
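
After this commit, the top of app.py effectively reads as sketched below: the previously commented-out get_imports shim is active, it now targets flash_attention.py instead of modeling_phi3.py, and the Sa2VA model is loaded inside the patched context so the repository's remote code can be imported without flash_attn installed. This is a reconstruction from the diff, not the full file: the function's closing `return imports`, the tokenizer's `trust_remote_code=True`, and everything else (the Gradio app itself) fall outside the shown hunks and are filled in here as assumptions.

# Sketch of the model-loading section of app.py after commit e856606.
# The get_imports patch drops "flash_attn" from the dynamic import list of the
# remote flash_attention.py module, so Sa2VA-4B loads without flash-attention.
import os
import torch
from unittest.mock import patch
from transformers import AutoModel, AutoTokenizer
from transformers.dynamic_module_utils import get_imports

model_path = "ByteDance/Sa2VA-4B"

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    # Only rewrite the import list of the remote-code file that pulls in flash_attn.
    if not str(filename).endswith("flash_attention.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    imports.remove("flash_attn")
    return imports  # assumed: this line sits between the two diff hunks

with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    # Model loading now happens inside the patched context.
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,
        use_flash_attn=False,
        trust_remote_code=True,
    ).eval().cuda()

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,  # assumed: the remaining arguments are not visible in the diff
)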