cheesyFishes
commited on
add backend handling
Browse files- custom_st.py +6 -0
custom_st.py
CHANGED
@@ -26,9 +26,15 @@ class Transformer(nn.Module):
|
|
26 |
processor_args: Optional[Dict[str, Any]] = None,
|
27 |
cache_dir: Optional[str] = None,
|
28 |
device: str = 'cuda:0',
|
|
|
29 |
**kwargs,
|
30 |
) -> None:
|
31 |
super(Transformer, self).__init__()
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
self.device = device
|
34 |
self.dimension = dimension
|
|
|
26 |
processor_args: Optional[Dict[str, Any]] = None,
|
27 |
cache_dir: Optional[str] = None,
|
28 |
device: str = 'cuda:0',
|
29 |
+
backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
|
30 |
**kwargs,
|
31 |
) -> None:
|
32 |
super(Transformer, self).__init__()
|
33 |
+
|
34 |
+
if backend != 'torch':
|
35 |
+
raise ValueError(
|
36 |
+
f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
|
37 |
+
)
|
38 |
|
39 |
self.device = device
|
40 |
self.dimension = dimension
|