Tom Aarsen committed
Commit 40df588
1 Parent(s): 837f25a

Add **kwargs to Transformer to avoid issues when ST adds new arguments

Files changed (1)
  1. custom_st.py +9 -0
custom_st.py CHANGED
@@ -1,4 +1,5 @@
 import json
+import logging
 import os
 from io import BytesIO
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -7,6 +8,8 @@ import torch
 from torch import nn
 from transformers import AutoConfig, AutoModel, AutoTokenizer
 
+logger = logging.getLogger(__name__)
+
 
 class Transformer(nn.Module):
     """Huggingface AutoModel to generate token embeddings.
@@ -40,6 +43,7 @@ class Transformer(nn.Module):
         cache_dir: str = None,
         do_lower_case: bool = False,
         tokenizer_name_or_path: str = None,
+        **kwargs,
     ) -> None:
         super().__init__()
         self.config_keys = ["max_seq_length", "do_lower_case"]
@@ -51,6 +55,11 @@
         if config_args is None:
             config_args = {}
 
+        if kwargs.get("backend", "torch") != "torch":
+            logger.warning(
+                f'"jinaai/jina-embeddings-v3" is currently not compatible with the {kwargs["backend"]} backend. '
+                'Continuing with the "torch" backend.'
+            )
 
         self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
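For context, here is a minimal self-contained sketch (not from the repository) of the forward-compatibility pattern this commit applies: a custom module whose __init__ accepts **kwargs keeps working when Sentence Transformers starts forwarding new keyword arguments such as backend, whereas a strict signature raises a TypeError. The class names StrictModule and ForwardCompatibleModule are hypothetical; only the backend argument and the warning behavior come from the diff above.

import logging

logger = logging.getLogger(__name__)


class StrictModule:
    """Breaks when the calling library starts passing new keyword arguments."""

    def __init__(self, model_name_or_path: str) -> None:
        self.model_name_or_path = model_name_or_path


class ForwardCompatibleModule:
    """Absorbs unknown keyword arguments, as custom_st.py now does."""

    def __init__(self, model_name_or_path: str, **kwargs) -> None:
        self.model_name_or_path = model_name_or_path
        # Warn instead of failing if an unsupported backend is requested.
        if kwargs.get("backend", "torch") != "torch":
            logger.warning(
                f'{kwargs["backend"]} backend not supported; continuing with "torch".'
            )


# Simulates the library instantiating a custom module with a newly
# introduced argument such as backend="onnx":
try:
    StrictModule("jinaai/jina-embeddings-v3", backend="onnx")
except TypeError as exc:
    print(f"StrictModule fails: {exc}")

ForwardCompatibleModule("jinaai/jina-embeddings-v3", backend="onnx")  # warns, but works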