cheesyFishes
committed on
fix device handling
Browse files — custom_st.py (+1, −5)
custom_st.py
CHANGED
@@ -27,7 +27,6 @@ class Transformer(nn.Module):
|
|
27 |
tokenizer_args: Optional[Dict[str, Any]] = None,
|
28 |
config_args: Optional[Dict[str, Any]] = None,
|
29 |
cache_dir: Optional[str] = None,
|
30 |
-
device: str = 'cpu',
|
31 |
backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
|
32 |
**kwargs,
|
33 |
) -> None:
|
@@ -38,7 +37,6 @@ class Transformer(nn.Module):
|
|
38 |
f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
|
39 |
)
|
40 |
|
41 |
-
self.device = device
|
42 |
self.dimension = dimension
|
43 |
self.max_pixels = max_pixels
|
44 |
self.min_pixels = min_pixels
|
@@ -160,15 +158,13 @@ class Transformer(nn.Module):
|
|
160 |
def tokenize(self, texts: List[Union[str, Image.Image]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
|
161 |
processed_texts, processed_images = self._process_input(texts)
|
162 |
|
163 |
-
|
164 |
text=processed_texts,
|
165 |
images=processed_images,
|
166 |
videos=None,
|
167 |
padding=padding,
|
168 |
return_tensors='pt'
|
169 |
)
|
170 |
-
|
171 |
-
return {k: v.to(self.device) for k, v in inputs.items()}
|
172 |
|
173 |
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
174 |
"""Save the model, tokenizer and processor to the given path."""
|
|
|
27 |
tokenizer_args: Optional[Dict[str, Any]] = None,
|
28 |
config_args: Optional[Dict[str, Any]] = None,
|
29 |
cache_dir: Optional[str] = None,
|
|
|
30 |
backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
|
31 |
**kwargs,
|
32 |
) -> None:
|
|
|
37 |
f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
|
38 |
)
|
39 |
|
|
|
40 |
self.dimension = dimension
|
41 |
self.max_pixels = max_pixels
|
42 |
self.min_pixels = min_pixels
|
|
|
158 |
def tokenize(self, texts: List[Union[str, Image.Image]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
|
159 |
processed_texts, processed_images = self._process_input(texts)
|
160 |
|
161 |
+
return self.processor(
|
162 |
text=processed_texts,
|
163 |
images=processed_images,
|
164 |
videos=None,
|
165 |
padding=padding,
|
166 |
return_tensors='pt'
|
167 |
)
|
|
|
|
|
168 |
|
169 |
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
170 |
"""Save the model, tokenizer and processor to the given path."""
|