Tags: Image-Text-to-Text · sentence-transformers · Safetensors · Transformers · qwen2_vl · Qwen2-VL · conversational
cheesyFishes committed (verified)
Commit 1df3a64 · Parent 6a23f44

update handling of init args

Files changed (1): custom_st.py (+16 -5)
custom_st.py CHANGED

@@ -22,6 +22,8 @@ class Transformer(nn.Module):
         min_pixels: int = 1 * 28 * 28,
         dimension: int = 2048,
         max_seq_length: Optional[int] = None,
+        model_args: Optional[Dict[str, Any]] = None,
+        processor_args: Optional[Dict[str, Any]] = None,
         cache_dir: Optional[str] = None,
         device: str = 'cuda:0',
         **kwargs,
@@ -34,6 +36,17 @@ class Transformer(nn.Module):
         self.min_pixels = min_pixels
         self.max_seq_length = max_seq_length
 
+        # Handle args
+        model_kwargs = model_args or {}
+        model_kwargs.update(kwargs)
+
+        processor_kwargs = processor_args or {}
+        processor_kwargs.update({
+            'min_pixels': min_pixels,
+            'max_pixels': max_pixels,
+            'cache_dir': cache_dir
+        })
+
         # Initialize model
         try:
             self.model = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -42,7 +55,7 @@ class Transformer(nn.Module):
                 torch_dtype=torch.bfloat16,
                 device_map=device,
                 cache_dir=cache_dir,
-                **kwargs
+                **model_kwargs
             ).eval()
         except (ImportError, ValueError) as e:
             print(f"Flash attention not available, falling back to default attention: {e}")
@@ -51,15 +64,13 @@ class Transformer(nn.Module):
                 torch_dtype=torch.bfloat16,
                 device_map=device,
                 cache_dir=cache_dir,
-                **kwargs
+                **model_kwargs
             ).eval()
 
         # Initialize processor
         self.processor = AutoProcessor.from_pretrained(
             processor_name_or_path or model_name_or_path,
-            min_pixels=min_pixels,
-            max_pixels=max_pixels,
-            cache_dir=cache_dir
+            **processor_kwargs
         )
 
         # Set padding sides
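
For reference, a minimal usage sketch of the reworked constructor. This is an illustration under assumptions, not part of the commit: the import path custom_st, the placeholder model path, and the example kwargs (revision, use_fast) are hypothetical; only model_args, processor_args, and their merge behavior come from the diff above.

# Minimal sketch, assuming custom_st.py is importable and that the class
# also accepts model_name_or_path and max_pixels (only part of the
# signature is visible in this diff). The model path below is a placeholder.
from custom_st import Transformer

model = Transformer(
    model_name_or_path='path/or/repo-id',  # hypothetical placeholder
    min_pixels=1 * 28 * 28,
    # New in this commit: forwarded to Qwen2VLForConditionalGeneration.from_pretrained(...)
    model_args={'revision': 'main'},
    # New in this commit: forwarded to AutoProcessor.from_pretrained(...);
    # min_pixels, max_pixels, and cache_dir are layered on top afterwards.
    processor_args={'use_fast': True},
    device='cuda:0',
)

Note the precedence implied by the merge order: model_kwargs.update(kwargs) runs after model_args is copied, so a key passed both ways resolves to the bare **kwargs value, and processor_kwargs.update({...}) means min_pixels, max_pixels, and cache_dir always override anything supplied in processor_args.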