席亚东 commited on
Commit
ef2abea
1 Parent(s): 7f20926

update readme

Browse files
Files changed (4) hide show
  1. README.md +113 -0
  2. checkpoint_weight_index.json +584 -0
  3. dict.txt +0 -0
  4. inference.py +205 -0
README.md CHANGED
@@ -1,3 +1,116 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+
4
+ language: zh
5
+ inference: false
6
+ tags:
7
+ - text-generation
8
+ - dialogue-generation
9
+ - pytorch
10
+ - inference acceleration
11
+ - gpt2
12
+ - gpt3
13
  ---
14
+ # YuYan-Dialogue
15
+
16
+ YuYan is a series of Chinese language models with different size, developed by Fuxi AI lab, Netease.Inc. They are trained on a large Chinese novel dataset of high quality.
17
+
18
+ YuYan is in the same family of decoder-only models like [GPT2 and GPT-3](https://arxiv.org/abs/2005.14165). As such, it was pretrained using the self-supervised causal language modedling objective.
19
+
20
+ YuYan-Dialogue is a dialogue model by fine-tuning the YuYan-11b on a large multi-turn dialogue dataset of high quality. It has very strong conversation generation capabilities.
21
+
22
+ ## Model Inference Acceleration
23
+
24
+ As the model size increases, the model inference time increases and more computational resources are required.
25
+
26
+ Therefore, we developed our own transformer model inference acceleration framework, [EET](https://github.com/NetEase-FuXi/EET.git). More details are in [Easy and Efficient Transformer: Scalable Inference Solution For Large NLP Model](https://aclanthology.org/2022.naacl-industry.8/).
27
+
28
+ We combine our language model with the EET inference framework to provide industrial-grade inference reasoning performance.
29
+
30
+ ## How to use
31
+
32
+ Our model is trained based on the [fairseq](https://github.com/facebookresearch/fairseq). As a result, the inference and finetuning depend on it.
33
+
34
+ For inference, we modify some parts of the original fairseq codes. Mainly
35
+ > fairseq-0.12.2/fairseq/sequence_generator.py
36
+
37
+ We integrate the EET with sequence_generator. We replace the eos token to a token unlikely to be sampled to ensure the generated text length. The repetition penalty trick is also modified. You can change the penalty strength by adjusting the value of `self.ban_weight`.
38
+
39
+ Then, to keep the eos token in the final generated text, we change the line 75 `include_eos=False` to `include_eos=True` in
40
+ > fairseq-0.12.2/fairseq/data/dictionary.py
41
+
42
+ Finally, to pass in parameters in python scripts, we remove the line 67 ~ line 69 in
43
+ >fairseq-0.12.2/fairseq/dataclass/utils.py
44
+
45
+ Below are the install tutorial.
46
+
47
+ ```
48
+ # install pytorch
49
+ pip install torch==1.8.1 # install pytorch
50
+
51
+ # install fairseq
52
+ unzip fairseq-0.12.2.zip
53
+ cd fairseq-0.12.2
54
+ pip install.
55
+
56
+ # install EET
57
+ git clone https://github.com/NetEase-FuXi/EET.git
58
+ cd EET
59
+ pip install .
60
+
61
+ # install transformers (EET requirements)
62
+ pip install transformers==4.23
63
+
64
+ # make a folder, move the dictionary file and model file into it.
65
+ mkdir transformer_lm_gpt2_xxl_dialogue
66
+ mv dict.txt transformer_lm_gpt2_xxl_dialogue/
67
+ mv checkpoint_best_part_*.pt transformer_lm_gpt2_xxl_dialogue/
68
+
69
+ ```
70
+ `inference.py` is a script to provide a interface to initialize the EET object and sequence_generator. It includes some pre-process and post-process functions for text input and output. You can modify the script according to your needs.
71
+
72
+ In addition, it provide a simple object to organize the dialogue generation and dialogue history.
73
+
74
+ After the environment is ready, several lines of codes can realize the inference.
75
+
76
+ ``` python
77
+
78
+ from inference import Inference
79
+ model_path = "transformer_lm_gpt2_xxl_dialogue/checkpoint_best.pt"
80
+ data_path = "transformer_lm_gpt2_xxl_dialogue"
81
+ eet_batch_size = 10 # max inference batch size, adjust according to cuda memory, 40GB memory is necessary
82
+ inference = Inference(model_path, data_path, eet_batch_size)
83
+ dialogue_model = Dialogue(inference)
84
+ dialogue_model.get_repsonse("你好啊")
85
+ ```
86
+ ## Citation
87
+ If you find the technical report or resource is useful, please cite the following technical report in your paper.
88
+ - https://aclanthology.org/2022.naacl-industry.8/
89
+ ```
90
+ @inproceedings{li-etal-2022-easy,
91
+ title = "Easy and Efficient Transformer: Scalable Inference Solution For Large {NLP} Model",
92
+ author = "Li, Gongzheng and
93
+ Xi, Yadong and
94
+ Ding, Jingzhen and
95
+ Wang, Duan and
96
+ Luo, Ziyang and
97
+ Zhang, Rongsheng and
98
+ Liu, Bai and
99
+ Fan, Changjie and
100
+ Mao, Xiaoxi and
101
+ Zhao, Zeng",
102
+ booktitle = "Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies: Industry Track",
103
+ month = jul,
104
+ year = "2022",
105
+ address = "Hybrid: Seattle, Washington + Online",
106
+ publisher = "Association for Computational Linguistics",
107
+ url = "https://aclanthology.org/2022.naacl-industry.8",
108
+ doi = "10.18653/v1/2022.naacl-industry.8",
109
+ pages = "62--68"
110
+ }
111
+
112
+ ```
113
+ ## Contact Us
114
+ You can also contact us by email:
115
+
116
checkpoint_weight_index.json ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "decoder.version": "checkpoint_best_part_1.pt",
3
+ "decoder.embed_tokens.weight": "checkpoint_best_part_1.pt",
4
+ "decoder.embed_positions._float_tensor": "checkpoint_best_part_1.pt",
5
+ "decoder.layers.0.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
6
+ "decoder.layers.0.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
7
+ "decoder.layers.0.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
8
+ "decoder.layers.0.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
9
+ "decoder.layers.0.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
10
+ "decoder.layers.0.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
11
+ "decoder.layers.0.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
12
+ "decoder.layers.0.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
13
+ "decoder.layers.0.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
14
+ "decoder.layers.0.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
15
+ "decoder.layers.0.fc1.weight": "checkpoint_best_part_1.pt",
16
+ "decoder.layers.0.fc1.bias": "checkpoint_best_part_1.pt",
17
+ "decoder.layers.0.fc2.weight": "checkpoint_best_part_1.pt",
18
+ "decoder.layers.0.fc2.bias": "checkpoint_best_part_1.pt",
19
+ "decoder.layers.0.final_layer_norm.weight": "checkpoint_best_part_1.pt",
20
+ "decoder.layers.0.final_layer_norm.bias": "checkpoint_best_part_1.pt",
21
+ "decoder.layers.1.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
22
+ "decoder.layers.1.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
23
+ "decoder.layers.1.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
24
+ "decoder.layers.1.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
25
+ "decoder.layers.1.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
26
+ "decoder.layers.1.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
27
+ "decoder.layers.1.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
28
+ "decoder.layers.1.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
29
+ "decoder.layers.1.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
30
+ "decoder.layers.1.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
31
+ "decoder.layers.1.fc1.weight": "checkpoint_best_part_1.pt",
32
+ "decoder.layers.1.fc1.bias": "checkpoint_best_part_1.pt",
33
+ "decoder.layers.1.fc2.weight": "checkpoint_best_part_1.pt",
34
+ "decoder.layers.1.fc2.bias": "checkpoint_best_part_1.pt",
35
+ "decoder.layers.1.final_layer_norm.weight": "checkpoint_best_part_1.pt",
36
+ "decoder.layers.1.final_layer_norm.bias": "checkpoint_best_part_1.pt",
37
+ "decoder.layer_norm.weight": "checkpoint_best_part_1.pt",
38
+ "decoder.layer_norm.bias": "checkpoint_best_part_1.pt",
39
+ "decoder.output_projection.weight": "checkpoint_best_part_1.pt",
40
+ "decoder.layers.2.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
41
+ "decoder.layers.3.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
42
+ "decoder.layers.4.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
43
+ "decoder.layers.5.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
44
+ "decoder.layers.6.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
45
+ "decoder.layers.7.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
46
+ "decoder.layers.8.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
47
+ "decoder.layers.9.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
48
+ "decoder.layers.10.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
49
+ "decoder.layers.11.self_attn.k_proj.weight": "checkpoint_best_part_1.pt",
50
+ "decoder.layers.2.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
51
+ "decoder.layers.3.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
52
+ "decoder.layers.4.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
53
+ "decoder.layers.5.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
54
+ "decoder.layers.6.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
55
+ "decoder.layers.7.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
56
+ "decoder.layers.8.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
57
+ "decoder.layers.9.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
58
+ "decoder.layers.10.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
59
+ "decoder.layers.11.self_attn.k_proj.bias": "checkpoint_best_part_1.pt",
60
+ "decoder.layers.2.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
61
+ "decoder.layers.3.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
62
+ "decoder.layers.4.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
63
+ "decoder.layers.5.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
64
+ "decoder.layers.6.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
65
+ "decoder.layers.7.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
66
+ "decoder.layers.8.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
67
+ "decoder.layers.9.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
68
+ "decoder.layers.10.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
69
+ "decoder.layers.11.self_attn.v_proj.weight": "checkpoint_best_part_1.pt",
70
+ "decoder.layers.2.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
71
+ "decoder.layers.3.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
72
+ "decoder.layers.4.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
73
+ "decoder.layers.5.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
74
+ "decoder.layers.6.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
75
+ "decoder.layers.7.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
76
+ "decoder.layers.8.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
77
+ "decoder.layers.9.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
78
+ "decoder.layers.10.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
79
+ "decoder.layers.11.self_attn.v_proj.bias": "checkpoint_best_part_1.pt",
80
+ "decoder.layers.2.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
81
+ "decoder.layers.3.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
82
+ "decoder.layers.4.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
83
+ "decoder.layers.5.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
84
+ "decoder.layers.6.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
85
+ "decoder.layers.7.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
86
+ "decoder.layers.8.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
87
+ "decoder.layers.9.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
88
+ "decoder.layers.10.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
89
+ "decoder.layers.11.self_attn.q_proj.weight": "checkpoint_best_part_1.pt",
90
+ "decoder.layers.2.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
91
+ "decoder.layers.3.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
92
+ "decoder.layers.4.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
93
+ "decoder.layers.5.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
94
+ "decoder.layers.6.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
95
+ "decoder.layers.7.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
96
+ "decoder.layers.8.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
97
+ "decoder.layers.9.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
98
+ "decoder.layers.10.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
99
+ "decoder.layers.11.self_attn.q_proj.bias": "checkpoint_best_part_1.pt",
100
+ "decoder.layers.2.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
101
+ "decoder.layers.3.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
102
+ "decoder.layers.4.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
103
+ "decoder.layers.5.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
104
+ "decoder.layers.6.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
105
+ "decoder.layers.7.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
106
+ "decoder.layers.8.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
107
+ "decoder.layers.9.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
108
+ "decoder.layers.10.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
109
+ "decoder.layers.11.self_attn.out_proj.weight": "checkpoint_best_part_1.pt",
110
+ "decoder.layers.2.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
111
+ "decoder.layers.3.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
112
+ "decoder.layers.4.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
113
+ "decoder.layers.5.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
114
+ "decoder.layers.6.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
115
+ "decoder.layers.7.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
116
+ "decoder.layers.8.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
117
+ "decoder.layers.9.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
118
+ "decoder.layers.10.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
119
+ "decoder.layers.11.self_attn.out_proj.bias": "checkpoint_best_part_1.pt",
120
+ "decoder.layers.2.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
121
+ "decoder.layers.3.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
122
+ "decoder.layers.4.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
123
+ "decoder.layers.5.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
124
+ "decoder.layers.6.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
125
+ "decoder.layers.7.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
126
+ "decoder.layers.8.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
127
+ "decoder.layers.9.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
128
+ "decoder.layers.10.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
129
+ "decoder.layers.11.self_attn_layer_norm.weight": "checkpoint_best_part_1.pt",
130
+ "decoder.layers.2.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
131
+ "decoder.layers.3.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
132
+ "decoder.layers.4.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
133
+ "decoder.layers.5.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
134
+ "decoder.layers.6.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
135
+ "decoder.layers.7.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
136
+ "decoder.layers.8.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
137
+ "decoder.layers.9.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
138
+ "decoder.layers.10.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
139
+ "decoder.layers.11.self_attn_layer_norm.bias": "checkpoint_best_part_1.pt",
140
+ "decoder.layers.2.fc1.weight": "checkpoint_best_part_1.pt",
141
+ "decoder.layers.3.fc1.weight": "checkpoint_best_part_1.pt",
142
+ "decoder.layers.4.fc1.weight": "checkpoint_best_part_1.pt",
143
+ "decoder.layers.5.fc1.weight": "checkpoint_best_part_1.pt",
144
+ "decoder.layers.6.fc1.weight": "checkpoint_best_part_1.pt",
145
+ "decoder.layers.7.fc1.weight": "checkpoint_best_part_1.pt",
146
+ "decoder.layers.8.fc1.weight": "checkpoint_best_part_1.pt",
147
+ "decoder.layers.9.fc1.weight": "checkpoint_best_part_1.pt",
148
+ "decoder.layers.10.fc1.weight": "checkpoint_best_part_1.pt",
149
+ "decoder.layers.11.fc1.weight": "checkpoint_best_part_1.pt",
150
+ "decoder.layers.2.fc1.bias": "checkpoint_best_part_1.pt",
151
+ "decoder.layers.3.fc1.bias": "checkpoint_best_part_1.pt",
152
+ "decoder.layers.4.fc1.bias": "checkpoint_best_part_1.pt",
153
+ "decoder.layers.5.fc1.bias": "checkpoint_best_part_1.pt",
154
+ "decoder.layers.6.fc1.bias": "checkpoint_best_part_1.pt",
155
+ "decoder.layers.7.fc1.bias": "checkpoint_best_part_1.pt",
156
+ "decoder.layers.8.fc1.bias": "checkpoint_best_part_1.pt",
157
+ "decoder.layers.9.fc1.bias": "checkpoint_best_part_1.pt",
158
+ "decoder.layers.10.fc1.bias": "checkpoint_best_part_1.pt",
159
+ "decoder.layers.11.fc1.bias": "checkpoint_best_part_1.pt",
160
+ "decoder.layers.2.fc2.weight": "checkpoint_best_part_1.pt",
161
+ "decoder.layers.3.fc2.weight": "checkpoint_best_part_1.pt",
162
+ "decoder.layers.4.fc2.weight": "checkpoint_best_part_1.pt",
163
+ "decoder.layers.5.fc2.weight": "checkpoint_best_part_1.pt",
164
+ "decoder.layers.6.fc2.weight": "checkpoint_best_part_1.pt",
165
+ "decoder.layers.7.fc2.weight": "checkpoint_best_part_1.pt",
166
+ "decoder.layers.8.fc2.weight": "checkpoint_best_part_1.pt",
167
+ "decoder.layers.9.fc2.weight": "checkpoint_best_part_1.pt",
168
+ "decoder.layers.10.fc2.weight": "checkpoint_best_part_1.pt",
169
+ "decoder.layers.11.fc2.weight": "checkpoint_best_part_1.pt",
170
+ "decoder.layers.2.fc2.bias": "checkpoint_best_part_1.pt",
171
+ "decoder.layers.3.fc2.bias": "checkpoint_best_part_1.pt",
172
+ "decoder.layers.4.fc2.bias": "checkpoint_best_part_1.pt",
173
+ "decoder.layers.5.fc2.bias": "checkpoint_best_part_1.pt",
174
+ "decoder.layers.6.fc2.bias": "checkpoint_best_part_1.pt",
175
+ "decoder.layers.7.fc2.bias": "checkpoint_best_part_1.pt",
176
+ "decoder.layers.8.fc2.bias": "checkpoint_best_part_1.pt",
177
+ "decoder.layers.9.fc2.bias": "checkpoint_best_part_1.pt",
178
+ "decoder.layers.10.fc2.bias": "checkpoint_best_part_1.pt",
179
+ "decoder.layers.11.fc2.bias": "checkpoint_best_part_1.pt",
180
+ "decoder.layers.2.final_layer_norm.weight": "checkpoint_best_part_1.pt",
181
+ "decoder.layers.3.final_layer_norm.weight": "checkpoint_best_part_1.pt",
182
+ "decoder.layers.4.final_layer_norm.weight": "checkpoint_best_part_1.pt",
183
+ "decoder.layers.5.final_layer_norm.weight": "checkpoint_best_part_1.pt",
184
+ "decoder.layers.6.final_layer_norm.weight": "checkpoint_best_part_1.pt",
185
+ "decoder.layers.7.final_layer_norm.weight": "checkpoint_best_part_1.pt",
186
+ "decoder.layers.8.final_layer_norm.weight": "checkpoint_best_part_1.pt",
187
+ "decoder.layers.9.final_layer_norm.weight": "checkpoint_best_part_1.pt",
188
+ "decoder.layers.10.final_layer_norm.weight": "checkpoint_best_part_1.pt",
189
+ "decoder.layers.11.final_layer_norm.weight": "checkpoint_best_part_1.pt",
190
+ "decoder.layers.2.final_layer_norm.bias": "checkpoint_best_part_1.pt",
191
+ "decoder.layers.3.final_layer_norm.bias": "checkpoint_best_part_1.pt",
192
+ "decoder.layers.4.final_layer_norm.bias": "checkpoint_best_part_1.pt",
193
+ "decoder.layers.5.final_layer_norm.bias": "checkpoint_best_part_1.pt",
194
+ "decoder.layers.6.final_layer_norm.bias": "checkpoint_best_part_1.pt",
195
+ "decoder.layers.7.final_layer_norm.bias": "checkpoint_best_part_1.pt",
196
+ "decoder.layers.8.final_layer_norm.bias": "checkpoint_best_part_1.pt",
197
+ "decoder.layers.9.final_layer_norm.bias": "checkpoint_best_part_1.pt",
198
+ "decoder.layers.10.final_layer_norm.bias": "checkpoint_best_part_1.pt",
199
+ "decoder.layers.11.final_layer_norm.bias": "checkpoint_best_part_1.pt",
200
+ "decoder.layers.12.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
201
+ "decoder.layers.13.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
202
+ "decoder.layers.14.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
203
+ "decoder.layers.15.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
204
+ "decoder.layers.16.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
205
+ "decoder.layers.17.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
206
+ "decoder.layers.18.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
207
+ "decoder.layers.19.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
208
+ "decoder.layers.20.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
209
+ "decoder.layers.21.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
210
+ "decoder.layers.22.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
211
+ "decoder.layers.23.self_attn.k_proj.weight": "checkpoint_best_part_2.pt",
212
+ "decoder.layers.12.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
213
+ "decoder.layers.13.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
214
+ "decoder.layers.14.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
215
+ "decoder.layers.15.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
216
+ "decoder.layers.16.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
217
+ "decoder.layers.17.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
218
+ "decoder.layers.18.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
219
+ "decoder.layers.19.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
220
+ "decoder.layers.20.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
221
+ "decoder.layers.21.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
222
+ "decoder.layers.22.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
223
+ "decoder.layers.23.self_attn.k_proj.bias": "checkpoint_best_part_2.pt",
224
+ "decoder.layers.12.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
225
+ "decoder.layers.13.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
226
+ "decoder.layers.14.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
227
+ "decoder.layers.15.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
228
+ "decoder.layers.16.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
229
+ "decoder.layers.17.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
230
+ "decoder.layers.18.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
231
+ "decoder.layers.19.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
232
+ "decoder.layers.20.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
233
+ "decoder.layers.21.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
234
+ "decoder.layers.22.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
235
+ "decoder.layers.23.self_attn.v_proj.weight": "checkpoint_best_part_2.pt",
236
+ "decoder.layers.12.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
237
+ "decoder.layers.13.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
238
+ "decoder.layers.14.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
239
+ "decoder.layers.15.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
240
+ "decoder.layers.16.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
241
+ "decoder.layers.17.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
242
+ "decoder.layers.18.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
243
+ "decoder.layers.19.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
244
+ "decoder.layers.20.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
245
+ "decoder.layers.21.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
246
+ "decoder.layers.22.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
247
+ "decoder.layers.23.self_attn.v_proj.bias": "checkpoint_best_part_2.pt",
248
+ "decoder.layers.12.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
249
+ "decoder.layers.13.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
250
+ "decoder.layers.14.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
251
+ "decoder.layers.15.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
252
+ "decoder.layers.16.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
253
+ "decoder.layers.17.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
254
+ "decoder.layers.18.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
255
+ "decoder.layers.19.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
256
+ "decoder.layers.20.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
257
+ "decoder.layers.21.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
258
+ "decoder.layers.22.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
259
+ "decoder.layers.23.self_attn.q_proj.weight": "checkpoint_best_part_2.pt",
260
+ "decoder.layers.12.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
261
+ "decoder.layers.13.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
262
+ "decoder.layers.14.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
263
+ "decoder.layers.15.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
264
+ "decoder.layers.16.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
265
+ "decoder.layers.17.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
266
+ "decoder.layers.18.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
267
+ "decoder.layers.19.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
268
+ "decoder.layers.20.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
269
+ "decoder.layers.21.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
270
+ "decoder.layers.22.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
271
+ "decoder.layers.23.self_attn.q_proj.bias": "checkpoint_best_part_2.pt",
272
+ "decoder.layers.12.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
273
+ "decoder.layers.13.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
274
+ "decoder.layers.14.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
275
+ "decoder.layers.15.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
276
+ "decoder.layers.16.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
277
+ "decoder.layers.17.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
278
+ "decoder.layers.18.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
279
+ "decoder.layers.19.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
280
+ "decoder.layers.20.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
281
+ "decoder.layers.21.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
282
+ "decoder.layers.22.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
283
+ "decoder.layers.23.self_attn.out_proj.weight": "checkpoint_best_part_2.pt",
284
+ "decoder.layers.12.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
285
+ "decoder.layers.13.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
286
+ "decoder.layers.14.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
287
+ "decoder.layers.15.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
288
+ "decoder.layers.16.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
289
+ "decoder.layers.17.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
290
+ "decoder.layers.18.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
291
+ "decoder.layers.19.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
292
+ "decoder.layers.20.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
293
+ "decoder.layers.21.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
294
+ "decoder.layers.22.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
295
+ "decoder.layers.23.self_attn.out_proj.bias": "checkpoint_best_part_2.pt",
296
+ "decoder.layers.12.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
297
+ "decoder.layers.13.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
298
+ "decoder.layers.14.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
299
+ "decoder.layers.15.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
300
+ "decoder.layers.16.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
301
+ "decoder.layers.17.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
302
+ "decoder.layers.18.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
303
+ "decoder.layers.19.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
304
+ "decoder.layers.20.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
305
+ "decoder.layers.21.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
306
+ "decoder.layers.22.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
307
+ "decoder.layers.23.self_attn_layer_norm.weight": "checkpoint_best_part_2.pt",
308
+ "decoder.layers.12.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
309
+ "decoder.layers.13.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
310
+ "decoder.layers.14.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
311
+ "decoder.layers.15.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
312
+ "decoder.layers.16.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
313
+ "decoder.layers.17.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
314
+ "decoder.layers.18.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
315
+ "decoder.layers.19.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
316
+ "decoder.layers.20.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
317
+ "decoder.layers.21.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
318
+ "decoder.layers.22.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
319
+ "decoder.layers.23.self_attn_layer_norm.bias": "checkpoint_best_part_2.pt",
320
+ "decoder.layers.12.fc1.weight": "checkpoint_best_part_2.pt",
321
+ "decoder.layers.13.fc1.weight": "checkpoint_best_part_2.pt",
322
+ "decoder.layers.14.fc1.weight": "checkpoint_best_part_2.pt",
323
+ "decoder.layers.15.fc1.weight": "checkpoint_best_part_2.pt",
324
+ "decoder.layers.16.fc1.weight": "checkpoint_best_part_2.pt",
325
+ "decoder.layers.17.fc1.weight": "checkpoint_best_part_2.pt",
326
+ "decoder.layers.18.fc1.weight": "checkpoint_best_part_2.pt",
327
+ "decoder.layers.19.fc1.weight": "checkpoint_best_part_2.pt",
328
+ "decoder.layers.20.fc1.weight": "checkpoint_best_part_2.pt",
329
+ "decoder.layers.21.fc1.weight": "checkpoint_best_part_2.pt",
330
+ "decoder.layers.22.fc1.weight": "checkpoint_best_part_2.pt",
331
+ "decoder.layers.23.fc1.weight": "checkpoint_best_part_2.pt",
332
+ "decoder.layers.12.fc1.bias": "checkpoint_best_part_2.pt",
333
+ "decoder.layers.13.fc1.bias": "checkpoint_best_part_2.pt",
334
+ "decoder.layers.14.fc1.bias": "checkpoint_best_part_2.pt",
335
+ "decoder.layers.15.fc1.bias": "checkpoint_best_part_2.pt",
336
+ "decoder.layers.16.fc1.bias": "checkpoint_best_part_2.pt",
337
+ "decoder.layers.17.fc1.bias": "checkpoint_best_part_2.pt",
338
+ "decoder.layers.18.fc1.bias": "checkpoint_best_part_2.pt",
339
+ "decoder.layers.19.fc1.bias": "checkpoint_best_part_2.pt",
340
+ "decoder.layers.20.fc1.bias": "checkpoint_best_part_2.pt",
341
+ "decoder.layers.21.fc1.bias": "checkpoint_best_part_2.pt",
342
+ "decoder.layers.22.fc1.bias": "checkpoint_best_part_2.pt",
343
+ "decoder.layers.23.fc1.bias": "checkpoint_best_part_2.pt",
344
+ "decoder.layers.12.fc2.weight": "checkpoint_best_part_2.pt",
345
+ "decoder.layers.13.fc2.weight": "checkpoint_best_part_2.pt",
346
+ "decoder.layers.14.fc2.weight": "checkpoint_best_part_2.pt",
347
+ "decoder.layers.15.fc2.weight": "checkpoint_best_part_2.pt",
348
+ "decoder.layers.16.fc2.weight": "checkpoint_best_part_2.pt",
349
+ "decoder.layers.17.fc2.weight": "checkpoint_best_part_2.pt",
350
+ "decoder.layers.18.fc2.weight": "checkpoint_best_part_2.pt",
351
+ "decoder.layers.19.fc2.weight": "checkpoint_best_part_2.pt",
352
+ "decoder.layers.20.fc2.weight": "checkpoint_best_part_2.pt",
353
+ "decoder.layers.21.fc2.weight": "checkpoint_best_part_2.pt",
354
+ "decoder.layers.22.fc2.weight": "checkpoint_best_part_2.pt",
355
+ "decoder.layers.23.fc2.weight": "checkpoint_best_part_2.pt",
356
+ "decoder.layers.12.fc2.bias": "checkpoint_best_part_2.pt",
357
+ "decoder.layers.13.fc2.bias": "checkpoint_best_part_2.pt",
358
+ "decoder.layers.14.fc2.bias": "checkpoint_best_part_2.pt",
359
+ "decoder.layers.15.fc2.bias": "checkpoint_best_part_2.pt",
360
+ "decoder.layers.16.fc2.bias": "checkpoint_best_part_2.pt",
361
+ "decoder.layers.17.fc2.bias": "checkpoint_best_part_2.pt",
362
+ "decoder.layers.18.fc2.bias": "checkpoint_best_part_2.pt",
363
+ "decoder.layers.19.fc2.bias": "checkpoint_best_part_2.pt",
364
+ "decoder.layers.20.fc2.bias": "checkpoint_best_part_2.pt",
365
+ "decoder.layers.21.fc2.bias": "checkpoint_best_part_2.pt",
366
+ "decoder.layers.22.fc2.bias": "checkpoint_best_part_2.pt",
367
+ "decoder.layers.23.fc2.bias": "checkpoint_best_part_2.pt",
368
+ "decoder.layers.12.final_layer_norm.weight": "checkpoint_best_part_2.pt",
369
+ "decoder.layers.13.final_layer_norm.weight": "checkpoint_best_part_2.pt",
370
+ "decoder.layers.14.final_layer_norm.weight": "checkpoint_best_part_2.pt",
371
+ "decoder.layers.15.final_layer_norm.weight": "checkpoint_best_part_2.pt",
372
+ "decoder.layers.16.final_layer_norm.weight": "checkpoint_best_part_2.pt",
373
+ "decoder.layers.17.final_layer_norm.weight": "checkpoint_best_part_2.pt",
374
+ "decoder.layers.18.final_layer_norm.weight": "checkpoint_best_part_2.pt",
375
+ "decoder.layers.19.final_layer_norm.weight": "checkpoint_best_part_2.pt",
376
+ "decoder.layers.20.final_layer_norm.weight": "checkpoint_best_part_2.pt",
377
+ "decoder.layers.21.final_layer_norm.weight": "checkpoint_best_part_2.pt",
378
+ "decoder.layers.22.final_layer_norm.weight": "checkpoint_best_part_2.pt",
379
+ "decoder.layers.23.final_layer_norm.weight": "checkpoint_best_part_2.pt",
380
+ "decoder.layers.12.final_layer_norm.bias": "checkpoint_best_part_2.pt",
381
+ "decoder.layers.13.final_layer_norm.bias": "checkpoint_best_part_2.pt",
382
+ "decoder.layers.14.final_layer_norm.bias": "checkpoint_best_part_2.pt",
383
+ "decoder.layers.15.final_layer_norm.bias": "checkpoint_best_part_2.pt",
384
+ "decoder.layers.16.final_layer_norm.bias": "checkpoint_best_part_2.pt",
385
+ "decoder.layers.17.final_layer_norm.bias": "checkpoint_best_part_2.pt",
386
+ "decoder.layers.18.final_layer_norm.bias": "checkpoint_best_part_2.pt",
387
+ "decoder.layers.19.final_layer_norm.bias": "checkpoint_best_part_2.pt",
388
+ "decoder.layers.20.final_layer_norm.bias": "checkpoint_best_part_2.pt",
389
+ "decoder.layers.21.final_layer_norm.bias": "checkpoint_best_part_2.pt",
390
+ "decoder.layers.22.final_layer_norm.bias": "checkpoint_best_part_2.pt",
391
+ "decoder.layers.23.final_layer_norm.bias": "checkpoint_best_part_2.pt",
392
+ "decoder.layers.24.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
393
+ "decoder.layers.25.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
394
+ "decoder.layers.26.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
395
+ "decoder.layers.27.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
396
+ "decoder.layers.28.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
397
+ "decoder.layers.29.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
398
+ "decoder.layers.30.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
399
+ "decoder.layers.31.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
400
+ "decoder.layers.32.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
401
+ "decoder.layers.33.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
402
+ "decoder.layers.34.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
403
+ "decoder.layers.35.self_attn.k_proj.weight": "checkpoint_best_part_3.pt",
404
+ "decoder.layers.24.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
405
+ "decoder.layers.25.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
406
+ "decoder.layers.26.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
407
+ "decoder.layers.27.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
408
+ "decoder.layers.28.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
409
+ "decoder.layers.29.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
410
+ "decoder.layers.30.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
411
+ "decoder.layers.31.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
412
+ "decoder.layers.32.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
413
+ "decoder.layers.33.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
414
+ "decoder.layers.34.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
415
+ "decoder.layers.35.self_attn.k_proj.bias": "checkpoint_best_part_3.pt",
416
+ "decoder.layers.24.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
417
+ "decoder.layers.25.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
418
+ "decoder.layers.26.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
419
+ "decoder.layers.27.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
420
+ "decoder.layers.28.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
421
+ "decoder.layers.29.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
422
+ "decoder.layers.30.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
423
+ "decoder.layers.31.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
424
+ "decoder.layers.32.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
425
+ "decoder.layers.33.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
426
+ "decoder.layers.34.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
427
+ "decoder.layers.35.self_attn.v_proj.weight": "checkpoint_best_part_3.pt",
428
+ "decoder.layers.24.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
429
+ "decoder.layers.25.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
430
+ "decoder.layers.26.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
431
+ "decoder.layers.27.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
432
+ "decoder.layers.28.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
433
+ "decoder.layers.29.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
434
+ "decoder.layers.30.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
435
+ "decoder.layers.31.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
436
+ "decoder.layers.32.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
437
+ "decoder.layers.33.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
438
+ "decoder.layers.34.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
439
+ "decoder.layers.35.self_attn.v_proj.bias": "checkpoint_best_part_3.pt",
440
+ "decoder.layers.24.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
441
+ "decoder.layers.25.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
442
+ "decoder.layers.26.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
443
+ "decoder.layers.27.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
444
+ "decoder.layers.28.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
445
+ "decoder.layers.29.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
446
+ "decoder.layers.30.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
447
+ "decoder.layers.31.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
448
+ "decoder.layers.32.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
449
+ "decoder.layers.33.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
450
+ "decoder.layers.34.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
451
+ "decoder.layers.35.self_attn.q_proj.weight": "checkpoint_best_part_3.pt",
452
+ "decoder.layers.24.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
453
+ "decoder.layers.25.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
454
+ "decoder.layers.26.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
455
+ "decoder.layers.27.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
456
+ "decoder.layers.28.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
457
+ "decoder.layers.29.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
458
+ "decoder.layers.30.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
459
+ "decoder.layers.31.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
460
+ "decoder.layers.32.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
461
+ "decoder.layers.33.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
462
+ "decoder.layers.34.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
463
+ "decoder.layers.35.self_attn.q_proj.bias": "checkpoint_best_part_3.pt",
464
+ "decoder.layers.24.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
465
+ "decoder.layers.25.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
466
+ "decoder.layers.26.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
467
+ "decoder.layers.27.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
468
+ "decoder.layers.28.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
469
+ "decoder.layers.29.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
470
+ "decoder.layers.30.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
471
+ "decoder.layers.31.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
472
+ "decoder.layers.32.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
473
+ "decoder.layers.33.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
474
+ "decoder.layers.34.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
475
+ "decoder.layers.35.self_attn.out_proj.weight": "checkpoint_best_part_3.pt",
476
+ "decoder.layers.24.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
477
+ "decoder.layers.25.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
478
+ "decoder.layers.26.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
479
+ "decoder.layers.27.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
480
+ "decoder.layers.28.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
481
+ "decoder.layers.29.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
482
+ "decoder.layers.30.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
483
+ "decoder.layers.31.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
484
+ "decoder.layers.32.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
485
+ "decoder.layers.33.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
486
+ "decoder.layers.34.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
487
+ "decoder.layers.35.self_attn.out_proj.bias": "checkpoint_best_part_3.pt",
488
+ "decoder.layers.24.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
489
+ "decoder.layers.25.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
490
+ "decoder.layers.26.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
491
+ "decoder.layers.27.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
492
+ "decoder.layers.28.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
493
+ "decoder.layers.29.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
494
+ "decoder.layers.30.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
495
+ "decoder.layers.31.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
496
+ "decoder.layers.32.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
497
+ "decoder.layers.33.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
498
+ "decoder.layers.34.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
499
+ "decoder.layers.35.self_attn_layer_norm.weight": "checkpoint_best_part_3.pt",
500
+ "decoder.layers.24.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
501
+ "decoder.layers.25.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
502
+ "decoder.layers.26.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
503
+ "decoder.layers.27.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
504
+ "decoder.layers.28.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
505
+ "decoder.layers.29.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
506
+ "decoder.layers.30.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
507
+ "decoder.layers.31.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
508
+ "decoder.layers.32.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
509
+ "decoder.layers.33.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
510
+ "decoder.layers.34.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
511
+ "decoder.layers.35.self_attn_layer_norm.bias": "checkpoint_best_part_3.pt",
512
+ "decoder.layers.24.fc1.weight": "checkpoint_best_part_3.pt",
513
+ "decoder.layers.25.fc1.weight": "checkpoint_best_part_3.pt",
514
+ "decoder.layers.26.fc1.weight": "checkpoint_best_part_3.pt",
515
+ "decoder.layers.27.fc1.weight": "checkpoint_best_part_3.pt",
516
+ "decoder.layers.28.fc1.weight": "checkpoint_best_part_3.pt",
517
+ "decoder.layers.29.fc1.weight": "checkpoint_best_part_3.pt",
518
+ "decoder.layers.30.fc1.weight": "checkpoint_best_part_3.pt",
519
+ "decoder.layers.31.fc1.weight": "checkpoint_best_part_3.pt",
520
+ "decoder.layers.32.fc1.weight": "checkpoint_best_part_3.pt",
521
+ "decoder.layers.33.fc1.weight": "checkpoint_best_part_3.pt",
522
+ "decoder.layers.34.fc1.weight": "checkpoint_best_part_3.pt",
523
+ "decoder.layers.35.fc1.weight": "checkpoint_best_part_3.pt",
524
+ "decoder.layers.24.fc1.bias": "checkpoint_best_part_3.pt",
525
+ "decoder.layers.25.fc1.bias": "checkpoint_best_part_3.pt",
526
+ "decoder.layers.26.fc1.bias": "checkpoint_best_part_3.pt",
527
+ "decoder.layers.27.fc1.bias": "checkpoint_best_part_3.pt",
528
+ "decoder.layers.28.fc1.bias": "checkpoint_best_part_3.pt",
529
+ "decoder.layers.29.fc1.bias": "checkpoint_best_part_3.pt",
530
+ "decoder.layers.30.fc1.bias": "checkpoint_best_part_3.pt",
531
+ "decoder.layers.31.fc1.bias": "checkpoint_best_part_3.pt",
532
+ "decoder.layers.32.fc1.bias": "checkpoint_best_part_3.pt",
533
+ "decoder.layers.33.fc1.bias": "checkpoint_best_part_3.pt",
534
+ "decoder.layers.34.fc1.bias": "checkpoint_best_part_3.pt",
535
+ "decoder.layers.35.fc1.bias": "checkpoint_best_part_3.pt",
536
+ "decoder.layers.24.fc2.weight": "checkpoint_best_part_3.pt",
537
+ "decoder.layers.25.fc2.weight": "checkpoint_best_part_3.pt",
538
+ "decoder.layers.26.fc2.weight": "checkpoint_best_part_3.pt",
539
+ "decoder.layers.27.fc2.weight": "checkpoint_best_part_3.pt",
540
+ "decoder.layers.28.fc2.weight": "checkpoint_best_part_3.pt",
541
+ "decoder.layers.29.fc2.weight": "checkpoint_best_part_3.pt",
542
+ "decoder.layers.30.fc2.weight": "checkpoint_best_part_3.pt",
543
+ "decoder.layers.31.fc2.weight": "checkpoint_best_part_3.pt",
544
+ "decoder.layers.32.fc2.weight": "checkpoint_best_part_3.pt",
545
+ "decoder.layers.33.fc2.weight": "checkpoint_best_part_3.pt",
546
+ "decoder.layers.34.fc2.weight": "checkpoint_best_part_3.pt",
547
+ "decoder.layers.35.fc2.weight": "checkpoint_best_part_3.pt",
548
+ "decoder.layers.24.fc2.bias": "checkpoint_best_part_3.pt",
549
+ "decoder.layers.25.fc2.bias": "checkpoint_best_part_3.pt",
550
+ "decoder.layers.26.fc2.bias": "checkpoint_best_part_3.pt",
551
+ "decoder.layers.27.fc2.bias": "checkpoint_best_part_3.pt",
552
+ "decoder.layers.28.fc2.bias": "checkpoint_best_part_3.pt",
553
+ "decoder.layers.29.fc2.bias": "checkpoint_best_part_3.pt",
554
+ "decoder.layers.30.fc2.bias": "checkpoint_best_part_3.pt",
555
+ "decoder.layers.31.fc2.bias": "checkpoint_best_part_3.pt",
556
+ "decoder.layers.32.fc2.bias": "checkpoint_best_part_3.pt",
557
+ "decoder.layers.33.fc2.bias": "checkpoint_best_part_3.pt",
558
+ "decoder.layers.34.fc2.bias": "checkpoint_best_part_3.pt",
559
+ "decoder.layers.35.fc2.bias": "checkpoint_best_part_3.pt",
560
+ "decoder.layers.24.final_layer_norm.weight": "checkpoint_best_part_3.pt",
561
+ "decoder.layers.25.final_layer_norm.weight": "checkpoint_best_part_3.pt",
562
+ "decoder.layers.26.final_layer_norm.weight": "checkpoint_best_part_3.pt",
563
+ "decoder.layers.27.final_layer_norm.weight": "checkpoint_best_part_3.pt",
564
+ "decoder.layers.28.final_layer_norm.weight": "checkpoint_best_part_3.pt",
565
+ "decoder.layers.29.final_layer_norm.weight": "checkpoint_best_part_3.pt",
566
+ "decoder.layers.30.final_layer_norm.weight": "checkpoint_best_part_3.pt",
567
+ "decoder.layers.31.final_layer_norm.weight": "checkpoint_best_part_3.pt",
568
+ "decoder.layers.32.final_layer_norm.weight": "checkpoint_best_part_3.pt",
569
+ "decoder.layers.33.final_layer_norm.weight": "checkpoint_best_part_3.pt",
570
+ "decoder.layers.34.final_layer_norm.weight": "checkpoint_best_part_3.pt",
571
+ "decoder.layers.35.final_layer_norm.weight": "checkpoint_best_part_3.pt",
572
+ "decoder.layers.24.final_layer_norm.bias": "checkpoint_best_part_3.pt",
573
+ "decoder.layers.25.final_layer_norm.bias": "checkpoint_best_part_3.pt",
574
+ "decoder.layers.26.final_layer_norm.bias": "checkpoint_best_part_3.pt",
575
+ "decoder.layers.27.final_layer_norm.bias": "checkpoint_best_part_3.pt",
576
+ "decoder.layers.28.final_layer_norm.bias": "checkpoint_best_part_3.pt",
577
+ "decoder.layers.29.final_layer_norm.bias": "checkpoint_best_part_3.pt",
578
+ "decoder.layers.30.final_layer_norm.bias": "checkpoint_best_part_3.pt",
579
+ "decoder.layers.31.final_layer_norm.bias": "checkpoint_best_part_3.pt",
580
+ "decoder.layers.32.final_layer_norm.bias": "checkpoint_best_part_3.pt",
581
+ "decoder.layers.33.final_layer_norm.bias": "checkpoint_best_part_3.pt",
582
+ "decoder.layers.34.final_layer_norm.bias": "checkpoint_best_part_3.pt",
583
+ "decoder.layers.35.final_layer_norm.bias": "checkpoint_best_part_3.pt"
584
+ }
dict.txt ADDED
The diff for this file is too large to render. See raw diff
 
inference.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+
3
+ from collections import namedtuple
4
+
5
+ import math
6
+ import torch
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+ from fairseq import checkpoint_utils, options, tasks, utils
10
+
11
+ Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
12
+
13
+ def make_batches(lines, task, max_positions, encode_fn):
14
+
15
+ tokens = [task.source_dictionary.encode_line(encode_fn(line),
16
+ add_if_not_exist=False,
17
+ append_eos=False,
18
+ reverse_order=True).long()
19
+ for line in lines]
20
+ lengths = [t.numel() for t in tokens]
21
+ tokens = pad_sequence(tokens, batch_first=True,
22
+ padding_value=1).flip(dims=(1,))
23
+
24
+ return Batch(ids=torch.arange(len(tokens)),
25
+ src_tokens=tokens,
26
+ src_lengths=torch.tensor(lengths))
27
+
28
+ def encode_fn(x_str):
29
+ x_str = "</s> " + x_str
30
+ return x_str
31
+
32
+
33
+ def decode_fn(x):
34
+ x = x.replace(" ", "")
35
+ return x
36
+
37
+ def eos_token_filter(sent):
38
+ return True
39
+
40
+
41
+ def post_precess(line):
42
+
43
+ if "<" in line:
44
+ line = line.split("<")[0]
45
+ return line
46
+
47
+
48
+ class Inference(object):
49
+
50
+ def __init__(self, model_path, data_path, eet_batch_size):
51
+
52
+ parser = options.get_generation_parser(default_task="language_modeling")
53
+ args = options.parse_args_and_arch(parser)
54
+ args.data = data_path
55
+ args.path = model_path
56
+ self.args = args
57
+
58
+ # generate parameter
59
+ args.beam = 1 # don't change
60
+ args.min_len = 5
61
+ args.max_len_b = 30
62
+ args.lenpen = 1.0
63
+ args.sampling = True
64
+ # args.sampling_topp = 0.7
65
+ args.sampling_topk = 10
66
+ args.temperature = 0.8
67
+ args.no_repeat_ngram_size = 1
68
+ args.fp16 = True
69
+
70
+ # Setup task, e.g., translation
71
+ task = tasks.setup_task(args)
72
+ self.task = task
73
+ # Set dictionaries
74
+ self.src_dict = task.source_dictionary
75
+ self.tgt_dict = task.target_dictionary
76
+
77
+ use_cuda = torch.cuda.is_available() and not args.cpu
78
+ self.use_cuda = use_cuda
79
+
80
+ # Optimize ensemble for generation
81
+ state = torch.load(args.path, map_location=torch.device("cpu"))
82
+ cfg_args = eval(str(state["cfg"]))["model"]
83
+ del cfg_args["_name"]
84
+ keys_list = []
85
+ values_list = []
86
+ for key,value in cfg_args.items() :
87
+ keys_list.append(key)
88
+ values_list.append(value)
89
+ Model_args = namedtuple("Model_args", keys_list)
90
+ model_args = Model_args._make(values_list)
91
+ del state
92
+
93
+ eet_seq_len = 512 # max seqence length
94
+ eet_batch_size = eet_batch_size
95
+ data_type = torch.float16
96
+ eet_config = {"data_type":data_type,
97
+ "max_batch":eet_batch_size,
98
+ "full_seq_len":eet_seq_len}
99
+ print(model_args)
100
+ from eet.fairseq.transformer import EETTransformerDecoder
101
+ eet_model = EETTransformerDecoder.from_fairseq_pretrained(model_id_or_path = args.path,
102
+ dictionary = self.src_dict,args=model_args,
103
+ config = eet_config,
104
+ no_encoder_attn = True)
105
+ self.models = [eet_model]
106
+ # Initialize generator
107
+ self.generator = task.build_generator(self.models, args)
108
+
109
+ # Load alignment dictionary for unknown word replacement
110
+ # (None if no unknown word replacement, empty if no path to align dictionary)
111
+ self.align_dict = utils.load_align_dict(args.replace_unk)
112
+
113
+ self.max_positions = 1024
114
+ self.eos_index = self.tgt_dict.eos()
115
+ self.pad_index = self.tgt_dict.pad()
116
+
117
+ def __call__(self, inputs, append_right_eos=True):
118
+
119
+ results = []
120
+ start_id = 0
121
+
122
+ batch = make_batches(inputs, self.task, self.max_positions, encode_fn)
123
+ inputs_str = inputs
124
+
125
+ src_tokens = batch.src_tokens
126
+ src_lengths = batch.src_lengths
127
+ # a new paragraph always
128
+ if src_tokens[0][-1].item() != self.eos_index and append_right_eos:
129
+ src_tokens = torch.cat([src_tokens, src_tokens.new_ones(src_tokens.size(0), 1) * self.eos_index], dim=1)
130
+ src_lengths += 1
131
+ if self.use_cuda:
132
+ src_tokens = src_tokens.cuda()
133
+ src_lengths = src_lengths.cuda()
134
+ sample = {
135
+ 'net_input': {
136
+ 'src_tokens': src_tokens,
137
+ 'src_lengths': src_lengths,
138
+ },
139
+ }
140
+
141
+ translations = self.task.inference_step(self.generator, self.models, sample)
142
+
143
+ for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
144
+ results.append((start_id + id, src_tokens[i], hypos))
145
+
146
+ # sort output to match input order
147
+ final_results = []
148
+ for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
149
+ # Process top predictions
150
+ tmp_res = []
151
+ for hypo in hypos[:min(len(hypos), self.args.nbest)]:
152
+ hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
153
+ hypo_tokens=hypo['tokens'].int().cpu()[len(src_tokens)-1:],
154
+ src_str=None,
155
+ alignment=hypo['alignment'],
156
+ align_dict=self.align_dict,
157
+ tgt_dict=self.tgt_dict)
158
+
159
+ detok_hypo_str = decode_fn(hypo_str)
160
+ if eos_token_filter(detok_hypo_str):
161
+ detok_hypo_str = post_precess(detok_hypo_str)
162
+ score = hypo['score'] / math.log(2) # convert to base 2
163
+ tmp_res.append([detok_hypo_str, score])
164
+ final_results.append(tmp_res)
165
+ return final_results
166
+
167
+
168
+
169
+
170
+ class Dialogue(object):
171
+ def __init__(self, inference_model=None, max_dialogue_history=6):
172
+
173
+ self.inference_model = inference_model
174
+ self.max_dialogue_history = max_dialogue_history
175
+ self.dialogue_history = []
176
+
177
+ def get_repsonse(self, input_text):
178
+ self.dialogue_history.append(input_text.strip())
179
+ model_inp = ""
180
+ for idx, x in enumerate(self.dialogue_history[-self.max_dialogue_history:]):
181
+ if idx % 2 == 0:
182
+ model_inp += " <0> " + " ".join(list(x))
183
+ else:
184
+ model_inp += " <1> " + " ".join(list(x))
185
+ if idx % 2 == 0:
186
+ model_inp += " <1>"
187
+ else:
188
+ model_inp += " <0>"
189
+ # generate 5 candidates
190
+ text = self.inference_model([model_inp]*5, append_right_eos=False)
191
+ response = [x[0][0] for x in text]
192
+ # response rank according to length
193
+ response = sorted(response, key=lambda x:len(set(x)))
194
+ # overlap-score
195
+ overlap = [[len(set(x) & set(model_inp)) * len(x), x] for x in response[-4:-1]]
196
+ overlap = sorted(overlap, key=lambda x:x[0])
197
+ final_response = overlap[-2][1]
198
+ self.dialogue_history.append(final_response)
199
+ return final_response
200
+
201
+ def clear_dialogue_history(self):
202
+ self.dialogue_history = []
203
+
204
+
205
+