Cyril666 commited on
Commit
feae110
·
1 Parent(s): 0fdb9e6

First model version

Browse files
app.py CHANGED
@@ -25,9 +25,6 @@ def infer(filepath):
25
  )
26
  image = cv2.imread(filepath)
27
  result_polygons, result_masks, result_boxes = det_demo.run_on_opencv_image(image)
28
- patchs = [image[box[1]:box[3], box[0]:box[2], :] for box in result_boxes]
29
- patchs = [cv2.resize(patch, (128,32)) for patch in patchs]
30
- patchs = np.stack(patchs, axis=0).transpose(0,3,1,2)
31
  visual_image = det_demo.visualization(image.copy(), result_polygons, result_masks, result_boxes)
32
  cv2.imwrite('result.jpg', visual_image)
33
  return 'result.jpg'#, pd.DataFrame(result_words)
 
25
  )
26
  image = cv2.imread(filepath)
27
  result_polygons, result_masks, result_boxes = det_demo.run_on_opencv_image(image)
 
 
 
28
  visual_image = det_demo.visualization(image.copy(), result_polygons, result_masks, result_boxes)
29
  cv2.imwrite('result.jpg', visual_image)
30
  return 'result.jpg'#, pd.DataFrame(result_words)
data/charset_36.txt ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 0 a
2
+ 1 b
3
+ 2 c
4
+ 3 d
5
+ 4 e
6
+ 5 f
7
+ 6 g
8
+ 7 h
9
+ 8 i
10
+ 9 j
11
+ 10 k
12
+ 11 l
13
+ 12 m
14
+ 13 n
15
+ 14 o
16
+ 15 p
17
+ 16 q
18
+ 17 r
19
+ 18 s
20
+ 19 t
21
+ 20 u
22
+ 21 v
23
+ 22 w
24
+ 23 x
25
+ 24 y
26
+ 25 z
27
+ 26 1
28
+ 27 2
29
+ 28 3
30
+ 29 4
31
+ 30 5
32
+ 31 6
33
+ 32 7
34
+ 33 8
35
+ 34 9
36
+ 35 0
data/charset_62.txt ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 0 0
2
+ 1 1
3
+ 2 2
4
+ 3 3
5
+ 4 4
6
+ 5 5
7
+ 6 6
8
+ 7 7
9
+ 8 8
10
+ 9 9
11
+ 10 A
12
+ 11 B
13
+ 12 C
14
+ 13 D
15
+ 14 E
16
+ 15 F
17
+ 16 G
18
+ 17 H
19
+ 18 I
20
+ 19 J
21
+ 20 K
22
+ 21 L
23
+ 22 M
24
+ 23 N
25
+ 24 O
26
+ 25 P
27
+ 26 Q
28
+ 27 R
29
+ 28 S
30
+ 29 T
31
+ 30 U
32
+ 31 V
33
+ 32 W
34
+ 33 X
35
+ 34 Y
36
+ 35 Z
37
+ 36 a
38
+ 37 b
39
+ 38 c
40
+ 39 d
41
+ 40 e
42
+ 41 f
43
+ 42 g
44
+ 43 h
45
+ 44 i
46
+ 45 j
47
+ 46 k
48
+ 47 l
49
+ 48 m
50
+ 49 n
51
+ 50 o
52
+ 51 p
53
+ 52 q
54
+ 53 r
55
+ 54 s
56
+ 55 t
57
+ 56 u
58
+ 57 v
59
+ 58 w
60
+ 59 x
61
+ 60 y
62
+ 61 z
demo.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import glob
5
+ import tqdm
6
+ import torch
7
+ import PIL
8
+ import cv2
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+ from torchvision import transforms
12
+ from utils import Config, Logger, CharsetMapper
13
+
14
+ def get_model(config):
15
+ import importlib
16
+ names = config.model_name.split('.')
17
+ module_name, class_name = '.'.join(names[:-1]), names[-1]
18
+ cls = getattr(importlib.import_module(module_name), class_name)
19
+ model = cls(config)
20
+ logging.info(model)
21
+ model = model.eval()
22
+ return model
23
+
24
+ def preprocess(img, width, height):
25
+ img = cv2.resize(np.array(img), (width, height))
26
+ img = transforms.ToTensor()(img).unsqueeze(0)
27
+ mean = torch.tensor([0.485, 0.456, 0.406])
28
+ std = torch.tensor([0.229, 0.224, 0.225])
29
+ return (img-mean[...,None,None]) / std[...,None,None]
30
+
31
+ def postprocess(output, charset, model_eval):
32
+ def _get_output(last_output, model_eval):
33
+ if isinstance(last_output, (tuple, list)):
34
+ for res in last_output:
35
+ if res['name'] == model_eval: output = res
36
+ else: output = last_output
37
+ return output
38
+
39
+ def _decode(logit):
40
+ """ Greed decode """
41
+ out = F.softmax(logit, dim=2)
42
+ pt_text, pt_scores, pt_lengths = [], [], []
43
+ for o in out:
44
+ text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
45
+ text = text.split(charset.null_char)[0] # end at end-token
46
+ pt_text.append(text)
47
+ pt_scores.append(o.max(dim=1)[0])
48
+ pt_lengths.append(min(len(text) + 1, charset.max_length)) # one for end-token
49
+ return pt_text, pt_scores, pt_lengths
50
+
51
+ output = _get_output(output, model_eval)
52
+ logits, pt_lengths = output['logits'], output['pt_lengths']
53
+ pt_text, pt_scores, pt_lengths_ = _decode(logits)
54
+
55
+ return pt_text, pt_scores, pt_lengths_
56
+
57
+ def load(model, file, device=None, strict=True):
58
+ if device is None: device = 'cpu'
59
+ elif isinstance(device, int): device = torch.device('cuda', device)
60
+ assert os.path.isfile(file)
61
+ state = torch.load(file, map_location=device)
62
+ if set(state.keys()) == {'model', 'opt'}:
63
+ state = state['model']
64
+ model.load_state_dict(state, strict=strict)
65
+ return model
66
+
67
+ def main():
68
+ parser = argparse.ArgumentParser()
69
+ parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
70
+ help='path to config file')
71
+ parser.add_argument('--input', type=str, default='figs/test')
72
+ parser.add_argument('--cuda', type=int, default=-1)
73
+ parser.add_argument('--checkpoint', type=str, default='workdir/train-abinet/best-train-abinet.pth')
74
+ parser.add_argument('--model_eval', type=str, default='alignment',
75
+ choices=['alignment', 'vision', 'language'])
76
+ args = parser.parse_args()
77
+ config = Config(args.config)
78
+ if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
79
+ if args.model_eval is not None: config.model_eval = args.model_eval
80
+ config.global_phase = 'test'
81
+ config.model_vision_checkpoint, config.model_language_checkpoint = None, None
82
+ device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'
83
+
84
+ Logger.init(config.global_workdir, config.global_name, config.global_phase)
85
+ Logger.enable_file()
86
+ logging.info(config)
87
+
88
+ logging.info('Construct model.')
89
+ model = get_model(config).to(device)
90
+ model = load(model, config.model_checkpoint, device=device)
91
+ charset = CharsetMapper(filename=config.dataset_charset_path,
92
+ max_length=config.dataset_max_length + 1)
93
+
94
+ if os.path.isdir(args.input):
95
+ paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)]
96
+ else:
97
+ paths = glob.glob(os.path.expanduser(args.input))
98
+ assert paths, "The input path(s) was not found"
99
+ paths = sorted(paths)
100
+ for path in tqdm.tqdm(paths):
101
+ img = PIL.Image.open(path).convert('RGB')
102
+ img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
103
+ img = img.to(device)
104
+ res = model(img)
105
+ pt_text, _, __ = postprocess(res, charset, config.model_eval)
106
+ logging.info(f'{path}: {pt_text[0]}')
107
+
108
+ if __name__ == '__main__':
109
+ main()
demo/1.jpg DELETED
Binary file (26.9 kB)
 
demo/2.jpg DELETED
Binary file (19.1 kB)
 
demo/example1.jpg DELETED
Binary file (31.9 kB)
 
demo/example_results.jpg DELETED
Binary file (49.1 kB)
 
figs/cases.png ADDED
figs/framework.png ADDED
figs/test/CANDY.png ADDED
figs/test/ESPLANADE.png ADDED
figs/test/GLOBE.png ADDED
figs/test/KAPPA.png ADDED
figs/test/MANDARIN.png ADDED
figs/test/MEETS.png ADDED
figs/test/MONTHLY.png ADDED
figs/test/RESTROOM.png ADDED
modules/__init__.py ADDED
File without changes
modules/attention.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .transformer import PositionalEncoding
4
+
5
+ class Attention(nn.Module):
6
+ def __init__(self, in_channels=512, max_length=25, n_feature=256):
7
+ super().__init__()
8
+ self.max_length = max_length
9
+
10
+ self.f0_embedding = nn.Embedding(max_length, in_channels)
11
+ self.w0 = nn.Linear(max_length, n_feature)
12
+ self.wv = nn.Linear(in_channels, in_channels)
13
+ self.we = nn.Linear(in_channels, max_length)
14
+
15
+ self.active = nn.Tanh()
16
+ self.softmax = nn.Softmax(dim=2)
17
+
18
+ def forward(self, enc_output):
19
+ enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2)
20
+ reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device)
21
+ reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S)
22
+ reading_order_embed = self.f0_embedding(reading_order) # b,25,512
23
+
24
+ t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256
25
+ t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512
26
+
27
+ attn = self.we(t) # b,256,25
28
+ attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256
29
+ g_output = torch.bmm(attn, enc_output) # b,25,512
30
+ return g_output, attn.view(*attn.shape[:2], 8, 32)
31
+
32
+
33
+ def encoder_layer(in_c, out_c, k=3, s=2, p=1):
34
+ return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
35
+ nn.BatchNorm2d(out_c),
36
+ nn.ReLU(True))
37
+
38
+ def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
39
+ align_corners = None if mode=='nearest' else True
40
+ return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor,
41
+ mode=mode, align_corners=align_corners),
42
+ nn.Conv2d(in_c, out_c, k, s, p),
43
+ nn.BatchNorm2d(out_c),
44
+ nn.ReLU(True))
45
+
46
+
47
+ class PositionAttention(nn.Module):
48
+ def __init__(self, max_length, in_channels=512, num_channels=64,
49
+ h=8, w=32, mode='nearest', **kwargs):
50
+ super().__init__()
51
+ self.max_length = max_length
52
+ self.k_encoder = nn.Sequential(
53
+ encoder_layer(in_channels, num_channels, s=(1, 2)),
54
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
55
+ encoder_layer(num_channels, num_channels, s=(2, 2)),
56
+ encoder_layer(num_channels, num_channels, s=(2, 2))
57
+ )
58
+ self.k_decoder = nn.Sequential(
59
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
60
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
61
+ decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
62
+ decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
63
+ )
64
+
65
+ self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length)
66
+ self.project = nn.Linear(in_channels, in_channels)
67
+
68
+ def forward(self, x):
69
+ N, E, H, W = x.size()
70
+ k, v = x, x # (N, E, H, W)
71
+
72
+ # calculate key vector
73
+ features = []
74
+ for i in range(0, len(self.k_encoder)):
75
+ k = self.k_encoder[i](k)
76
+ features.append(k)
77
+ for i in range(0, len(self.k_decoder) - 1):
78
+ k = self.k_decoder[i](k)
79
+ k = k + features[len(self.k_decoder) - 2 - i]
80
+ k = self.k_decoder[-1](k)
81
+
82
+ # calculate query vector
83
+ # TODO q=f(q,k)
84
+ zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E)
85
+ q = self.pos_encoder(zeros) # (T, N, E)
86
+ q = q.permute(1, 0, 2) # (N, T, E)
87
+ q = self.project(q) # (N, T, E)
88
+
89
+ # calculate attention
90
+ attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W))
91
+ attn_scores = attn_scores / (E ** 0.5)
92
+ attn_scores = torch.softmax(attn_scores, dim=-1)
93
+
94
+ v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E)
95
+ attn_vecs = torch.bmm(attn_scores, v) # (N, T, E)
96
+
97
+ return attn_vecs, attn_scores.view(N, -1, H, W)
modules/backbone.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.model import _default_tfmer_cfg
6
+ from modules.resnet import resnet45
7
+ from modules.transformer import (PositionalEncoding,
8
+ TransformerEncoder,
9
+ TransformerEncoderLayer)
10
+
11
+
12
+ class ResTranformer(nn.Module):
13
+ def __init__(self, config):
14
+ super().__init__()
15
+ self.resnet = resnet45()
16
+
17
+ self.d_model = ifnone(config.model_vision_d_model, _default_tfmer_cfg['d_model'])
18
+ nhead = ifnone(config.model_vision_nhead, _default_tfmer_cfg['nhead'])
19
+ d_inner = ifnone(config.model_vision_d_inner, _default_tfmer_cfg['d_inner'])
20
+ dropout = ifnone(config.model_vision_dropout, _default_tfmer_cfg['dropout'])
21
+ activation = ifnone(config.model_vision_activation, _default_tfmer_cfg['activation'])
22
+ num_layers = ifnone(config.model_vision_backbone_ln, 2)
23
+
24
+ self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32)
25
+ encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead,
26
+ dim_feedforward=d_inner, dropout=dropout, activation=activation)
27
+ self.transformer = TransformerEncoder(encoder_layer, num_layers)
28
+
29
+ def forward(self, images):
30
+ feature = self.resnet(images)
31
+ n, c, h, w = feature.shape
32
+ feature = feature.view(n, c, -1).permute(2, 0, 1)
33
+ feature = self.pos_encoder(feature)
34
+ feature = self.transformer(feature)
35
+ feature = feature.permute(1, 2, 0).view(n, c, h, w)
36
+ return feature
modules/model.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from utils import CharsetMapper
5
+
6
+
7
+ _default_tfmer_cfg = dict(d_model=512, nhead=8, d_inner=2048, # 1024
8
+ dropout=0.1, activation='relu')
9
+
10
+ class Model(nn.Module):
11
+
12
+ def __init__(self, config):
13
+ super().__init__()
14
+ self.max_length = config.dataset_max_length + 1
15
+ self.charset = CharsetMapper(config.dataset_charset_path, max_length=self.max_length)
16
+
17
+ def load(self, source, device=None, strict=True):
18
+ state = torch.load(source, map_location=device)
19
+ self.load_state_dict(state['model'], strict=strict)
20
+
21
+ def _get_length(self, logit, dim=-1):
22
+ """ Greed decoder to obtain length from logit"""
23
+ out = (logit.argmax(dim=-1) == self.charset.null_label)
24
+ abn = out.any(dim)
25
+ out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
26
+ out = out + 1 # additional end token
27
+ out = torch.where(abn, out, out.new_tensor(logit.shape[1]))
28
+ return out
29
+
30
+ @staticmethod
31
+ def _get_padding_mask(length, max_length):
32
+ length = length.unsqueeze(-1)
33
+ grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
34
+ return grid >= length
35
+
36
+ @staticmethod
37
+ def _get_square_subsequent_mask(sz, device, diagonal=0, fw=True):
38
+ r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
39
+ Unmasked positions are filled with float(0.0).
40
+ """
41
+ mask = (torch.triu(torch.ones(sz, sz, device=device), diagonal=diagonal) == 1)
42
+ if fw: mask = mask.transpose(0, 1)
43
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
44
+ return mask
45
+
46
+ @staticmethod
47
+ def _get_location_mask(sz, device=None):
48
+ mask = torch.eye(sz, device=device)
49
+ mask = mask.float().masked_fill(mask == 1, float('-inf'))
50
+ return mask
modules/model_abinet.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from .model_vision import BaseVision
6
+ from .model_language import BCNLanguage
7
+ from .model_alignment import BaseAlignment
8
+
9
+
10
+ class ABINetModel(nn.Module):
11
+ def __init__(self, config):
12
+ super().__init__()
13
+ self.use_alignment = ifnone(config.model_use_alignment, True)
14
+ self.max_length = config.dataset_max_length + 1 # additional stop token
15
+ self.vision = BaseVision(config)
16
+ self.language = BCNLanguage(config)
17
+ if self.use_alignment: self.alignment = BaseAlignment(config)
18
+
19
+ def forward(self, images, *args):
20
+ v_res = self.vision(images)
21
+ v_tokens = torch.softmax(v_res['logits'], dim=-1)
22
+ v_lengths = v_res['pt_lengths'].clamp_(2, self.max_length) # TODO:move to langauge model
23
+
24
+ l_res = self.language(v_tokens, v_lengths)
25
+ if not self.use_alignment:
26
+ return l_res, v_res
27
+ l_feature, v_feature = l_res['feature'], v_res['feature']
28
+
29
+ a_res = self.alignment(l_feature, v_feature)
30
+ return a_res, l_res, v_res
modules/model_abinet_iter.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from .model_vision import BaseVision
6
+ from .model_language import BCNLanguage
7
+ from .model_alignment import BaseAlignment
8
+
9
+
10
+ class ABINetIterModel(nn.Module):
11
+ def __init__(self, config):
12
+ super().__init__()
13
+ self.iter_size = ifnone(config.model_iter_size, 1)
14
+ self.max_length = config.dataset_max_length + 1 # additional stop token
15
+ self.vision = BaseVision(config)
16
+ self.language = BCNLanguage(config)
17
+ self.alignment = BaseAlignment(config)
18
+
19
+ def forward(self, images, *args):
20
+ v_res = self.vision(images)
21
+ a_res = v_res
22
+ all_l_res, all_a_res = [], []
23
+ for _ in range(self.iter_size):
24
+ tokens = torch.softmax(a_res['logits'], dim=-1)
25
+ lengths = a_res['pt_lengths']
26
+ lengths.clamp_(2, self.max_length) # TODO:move to langauge model
27
+ l_res = self.language(tokens, lengths)
28
+ all_l_res.append(l_res)
29
+ a_res = self.alignment(l_res['feature'], v_res['feature'])
30
+ all_a_res.append(a_res)
31
+ if self.training:
32
+ return all_a_res, all_l_res, v_res
33
+ else:
34
+ return a_res, all_l_res[-1], v_res
modules/model_alignment.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.model import Model, _default_tfmer_cfg
6
+
7
+
8
+ class BaseAlignment(Model):
9
+ def __init__(self, config):
10
+ super().__init__(config)
11
+ d_model = ifnone(config.model_alignment_d_model, _default_tfmer_cfg['d_model'])
12
+
13
+ self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0)
14
+ self.max_length = config.dataset_max_length + 1 # additional stop token
15
+ self.w_att = nn.Linear(2 * d_model, d_model)
16
+ self.cls = nn.Linear(d_model, self.charset.num_classes)
17
+
18
+ def forward(self, l_feature, v_feature):
19
+ """
20
+ Args:
21
+ l_feature: (N, T, E) where T is length, N is batch size and d is dim of model
22
+ v_feature: (N, T, E) shape the same as l_feature
23
+ l_lengths: (N,)
24
+ v_lengths: (N,)
25
+ """
26
+ f = torch.cat((l_feature, v_feature), dim=2)
27
+ f_att = torch.sigmoid(self.w_att(f))
28
+ output = f_att * v_feature + (1 - f_att) * l_feature
29
+
30
+ logits = self.cls(output) # (N, T, C)
31
+ pt_lengths = self._get_length(logits)
32
+
33
+ return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight':self.loss_weight,
34
+ 'name': 'alignment'}
modules/model_language.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.model import _default_tfmer_cfg
6
+ from modules.model import Model
7
+ from modules.transformer import (PositionalEncoding,
8
+ TransformerDecoder,
9
+ TransformerDecoderLayer)
10
+
11
+
12
+ class BCNLanguage(Model):
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+ d_model = ifnone(config.model_language_d_model, _default_tfmer_cfg['d_model'])
16
+ nhead = ifnone(config.model_language_nhead, _default_tfmer_cfg['nhead'])
17
+ d_inner = ifnone(config.model_language_d_inner, _default_tfmer_cfg['d_inner'])
18
+ dropout = ifnone(config.model_language_dropout, _default_tfmer_cfg['dropout'])
19
+ activation = ifnone(config.model_language_activation, _default_tfmer_cfg['activation'])
20
+ num_layers = ifnone(config.model_language_num_layers, 4)
21
+ self.d_model = d_model
22
+ self.detach = ifnone(config.model_language_detach, True)
23
+ self.use_self_attn = ifnone(config.model_language_use_self_attn, False)
24
+ self.loss_weight = ifnone(config.model_language_loss_weight, 1.0)
25
+ self.max_length = config.dataset_max_length + 1 # additional stop token
26
+ self.debug = ifnone(config.global_debug, False)
27
+
28
+ self.proj = nn.Linear(self.charset.num_classes, d_model, False)
29
+ self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
30
+ self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
31
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout,
32
+ activation, self_attn=self.use_self_attn, debug=self.debug)
33
+ self.model = TransformerDecoder(decoder_layer, num_layers)
34
+
35
+ self.cls = nn.Linear(d_model, self.charset.num_classes)
36
+
37
+ if config.model_language_checkpoint is not None:
38
+ logging.info(f'Read language model from {config.model_language_checkpoint}.')
39
+ self.load(config.model_language_checkpoint)
40
+
41
+ def forward(self, tokens, lengths):
42
+ """
43
+ Args:
44
+ tokens: (N, T, C) where T is length, N is batch size and C is classes number
45
+ lengths: (N,)
46
+ """
47
+ if self.detach: tokens = tokens.detach()
48
+ embed = self.proj(tokens) # (N, T, E)
49
+ embed = embed.permute(1, 0, 2) # (T, N, E)
50
+ embed = self.token_encoder(embed) # (T, N, E)
51
+ padding_mask = self._get_padding_mask(lengths, self.max_length)
52
+
53
+ zeros = embed.new_zeros(*embed.shape)
54
+ qeury = self.pos_encoder(zeros)
55
+ location_mask = self._get_location_mask(self.max_length, tokens.device)
56
+ output = self.model(qeury, embed,
57
+ tgt_key_padding_mask=padding_mask,
58
+ memory_mask=location_mask,
59
+ memory_key_padding_mask=padding_mask) # (T, N, E)
60
+ output = output.permute(1, 0, 2) # (N, T, E)
61
+
62
+ logits = self.cls(output) # (N, T, C)
63
+ pt_lengths = self._get_length(logits)
64
+
65
+ res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths,
66
+ 'loss_weight':self.loss_weight, 'name': 'language'}
67
+ return res
modules/model_vision.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import torch.nn as nn
3
+ from fastai.vision import *
4
+
5
+ from modules.attention import *
6
+ from modules.backbone import ResTranformer
7
+ from modules.model import Model
8
+ from modules.resnet import resnet45
9
+
10
+
11
+ class BaseVision(Model):
12
+ def __init__(self, config):
13
+ super().__init__(config)
14
+ self.loss_weight = ifnone(config.model_vision_loss_weight, 1.0)
15
+ self.out_channels = ifnone(config.model_vision_d_model, 512)
16
+
17
+ if config.model_vision_backbone == 'transformer':
18
+ self.backbone = ResTranformer(config)
19
+ else: self.backbone = resnet45()
20
+
21
+ if config.model_vision_attention == 'position':
22
+ mode = ifnone(config.model_vision_attention_mode, 'nearest')
23
+ self.attention = PositionAttention(
24
+ max_length=config.dataset_max_length + 1, # additional stop token
25
+ mode=mode,
26
+ )
27
+ elif config.model_vision_attention == 'attention':
28
+ self.attention = Attention(
29
+ max_length=config.dataset_max_length + 1, # additional stop token
30
+ n_feature=8*32,
31
+ )
32
+ else:
33
+ raise Exception(f'{config.model_vision_attention} is not valid.')
34
+ self.cls = nn.Linear(self.out_channels, self.charset.num_classes)
35
+
36
+ if config.model_vision_checkpoint is not None:
37
+ logging.info(f'Read vision model from {config.model_vision_checkpoint}.')
38
+ self.load(config.model_vision_checkpoint)
39
+
40
+ def forward(self, images, *args):
41
+ features = self.backbone(images) # (N, E, H, W)
42
+ attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W)
43
+ logits = self.cls(attn_vecs) # (N, T, C)
44
+ pt_lengths = self._get_length(logits)
45
+
46
+ return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
47
+ 'attn_scores': attn_scores, 'loss_weight':self.loss_weight, 'name': 'vision'}
modules/resnet.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.model_zoo as model_zoo
6
+
7
+
8
+ def conv1x1(in_planes, out_planes, stride=1):
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
10
+
11
+
12
+ def conv3x3(in_planes, out_planes, stride=1):
13
+ "3x3 convolution with padding"
14
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
15
+ padding=1, bias=False)
16
+
17
+
18
+ class BasicBlock(nn.Module):
19
+ expansion = 1
20
+
21
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv1x1(inplanes, planes)
24
+ self.bn1 = nn.BatchNorm2d(planes)
25
+ self.relu = nn.ReLU(inplace=True)
26
+ self.conv2 = conv3x3(planes, planes, stride)
27
+ self.bn2 = nn.BatchNorm2d(planes)
28
+ self.downsample = downsample
29
+ self.stride = stride
30
+
31
+ def forward(self, x):
32
+ residual = x
33
+
34
+ out = self.conv1(x)
35
+ out = self.bn1(out)
36
+ out = self.relu(out)
37
+
38
+ out = self.conv2(out)
39
+ out = self.bn2(out)
40
+
41
+ if self.downsample is not None:
42
+ residual = self.downsample(x)
43
+
44
+ out += residual
45
+ out = self.relu(out)
46
+
47
+ return out
48
+
49
+
50
+ class ResNet(nn.Module):
51
+
52
+ def __init__(self, block, layers):
53
+ self.inplanes = 32
54
+ super(ResNet, self).__init__()
55
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1,
56
+ bias=False)
57
+ self.bn1 = nn.BatchNorm2d(32)
58
+ self.relu = nn.ReLU(inplace=True)
59
+
60
+ self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
61
+ self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
62
+ self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
63
+ self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
64
+ self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
65
+
66
+ for m in self.modules():
67
+ if isinstance(m, nn.Conv2d):
68
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
69
+ m.weight.data.normal_(0, math.sqrt(2. / n))
70
+ elif isinstance(m, nn.BatchNorm2d):
71
+ m.weight.data.fill_(1)
72
+ m.bias.data.zero_()
73
+
74
+ def _make_layer(self, block, planes, blocks, stride=1):
75
+ downsample = None
76
+ if stride != 1 or self.inplanes != planes * block.expansion:
77
+ downsample = nn.Sequential(
78
+ nn.Conv2d(self.inplanes, planes * block.expansion,
79
+ kernel_size=1, stride=stride, bias=False),
80
+ nn.BatchNorm2d(planes * block.expansion),
81
+ )
82
+
83
+ layers = []
84
+ layers.append(block(self.inplanes, planes, stride, downsample))
85
+ self.inplanes = planes * block.expansion
86
+ for i in range(1, blocks):
87
+ layers.append(block(self.inplanes, planes))
88
+
89
+ return nn.Sequential(*layers)
90
+
91
+ def forward(self, x):
92
+ x = self.conv1(x)
93
+ x = self.bn1(x)
94
+ x = self.relu(x)
95
+ x = self.layer1(x)
96
+ x = self.layer2(x)
97
+ x = self.layer3(x)
98
+ x = self.layer4(x)
99
+ x = self.layer5(x)
100
+ return x
101
+
102
+
103
+ def resnet45():
104
+ return ResNet(BasicBlock, [3, 4, 6, 6, 3])
modules/transformer.py ADDED
@@ -0,0 +1,901 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch 1.5.0
2
+ import copy
3
+ import math
4
+ import warnings
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import Tensor
10
+ from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList, Parameter
11
+ from torch.nn import functional as F
12
+ from torch.nn.init import constant_, xavier_uniform_
13
+
14
+
15
+ def multi_head_attention_forward(query, # type: Tensor
16
+ key, # type: Tensor
17
+ value, # type: Tensor
18
+ embed_dim_to_check, # type: int
19
+ num_heads, # type: int
20
+ in_proj_weight, # type: Tensor
21
+ in_proj_bias, # type: Tensor
22
+ bias_k, # type: Optional[Tensor]
23
+ bias_v, # type: Optional[Tensor]
24
+ add_zero_attn, # type: bool
25
+ dropout_p, # type: float
26
+ out_proj_weight, # type: Tensor
27
+ out_proj_bias, # type: Tensor
28
+ training=True, # type: bool
29
+ key_padding_mask=None, # type: Optional[Tensor]
30
+ need_weights=True, # type: bool
31
+ attn_mask=None, # type: Optional[Tensor]
32
+ use_separate_proj_weight=False, # type: bool
33
+ q_proj_weight=None, # type: Optional[Tensor]
34
+ k_proj_weight=None, # type: Optional[Tensor]
35
+ v_proj_weight=None, # type: Optional[Tensor]
36
+ static_k=None, # type: Optional[Tensor]
37
+ static_v=None # type: Optional[Tensor]
38
+ ):
39
+ # type: (...) -> Tuple[Tensor, Optional[Tensor]]
40
+ r"""
41
+ Args:
42
+ query, key, value: map a query and a set of key-value pairs to an output.
43
+ See "Attention Is All You Need" for more details.
44
+ embed_dim_to_check: total dimension of the model.
45
+ num_heads: parallel attention heads.
46
+ in_proj_weight, in_proj_bias: input projection weight and bias.
47
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
48
+ add_zero_attn: add a new batch of zeros to the key and
49
+ value sequences at dim=1.
50
+ dropout_p: probability of an element to be zeroed.
51
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
52
+ training: apply dropout if is ``True``.
53
+ key_padding_mask: if provided, specified padding elements in the key will
54
+ be ignored by the attention. This is an binary mask. When the value is True,
55
+ the corresponding value on the attention layer will be filled with -inf.
56
+ need_weights: output attn_output_weights.
57
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
58
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
59
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
60
+ and value in different forms. If false, in_proj_weight will be used, which is
61
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
62
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
63
+ static_k, static_v: static key and value used for attention operators.
64
+ Shape:
65
+ Inputs:
66
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
67
+ the embedding dimension.
68
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
69
+ the embedding dimension.
70
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
71
+ the embedding dimension.
72
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
73
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
74
+ will be unchanged. If a BoolTensor is provided, the positions with the
75
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
76
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
77
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
78
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
79
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
80
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
81
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
82
+ is provided, it will be added to the attention weight.
83
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
84
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
85
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
86
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
87
+ Outputs:
88
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
89
+ E is the embedding dimension.
90
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
91
+ L is the target sequence length, S is the source sequence length.
92
+ """
93
+ # if not torch.jit.is_scripting():
94
+ # tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
95
+ # out_proj_weight, out_proj_bias)
96
+ # if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
97
+ # return handle_torch_function(
98
+ # multi_head_attention_forward, tens_ops, query, key, value,
99
+ # embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
100
+ # bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
101
+ # out_proj_bias, training=training, key_padding_mask=key_padding_mask,
102
+ # need_weights=need_weights, attn_mask=attn_mask,
103
+ # use_separate_proj_weight=use_separate_proj_weight,
104
+ # q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
105
+ # v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
106
+ tgt_len, bsz, embed_dim = query.size()
107
+ assert embed_dim == embed_dim_to_check
108
+ assert key.size() == value.size()
109
+
110
+ head_dim = embed_dim // num_heads
111
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
112
+ scaling = float(head_dim) ** -0.5
113
+
114
+ if not use_separate_proj_weight:
115
+ if torch.equal(query, key) and torch.equal(key, value):
116
+ # self-attention
117
+ q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
118
+
119
+ elif torch.equal(key, value):
120
+ # encoder-decoder attention
121
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
122
+ _b = in_proj_bias
123
+ _start = 0
124
+ _end = embed_dim
125
+ _w = in_proj_weight[_start:_end, :]
126
+ if _b is not None:
127
+ _b = _b[_start:_end]
128
+ q = F.linear(query, _w, _b)
129
+
130
+ if key is None:
131
+ assert value is None
132
+ k = None
133
+ v = None
134
+ else:
135
+
136
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
137
+ _b = in_proj_bias
138
+ _start = embed_dim
139
+ _end = None
140
+ _w = in_proj_weight[_start:, :]
141
+ if _b is not None:
142
+ _b = _b[_start:]
143
+ k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
144
+
145
+ else:
146
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
147
+ _b = in_proj_bias
148
+ _start = 0
149
+ _end = embed_dim
150
+ _w = in_proj_weight[_start:_end, :]
151
+ if _b is not None:
152
+ _b = _b[_start:_end]
153
+ q = F.linear(query, _w, _b)
154
+
155
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
156
+ _b = in_proj_bias
157
+ _start = embed_dim
158
+ _end = embed_dim * 2
159
+ _w = in_proj_weight[_start:_end, :]
160
+ if _b is not None:
161
+ _b = _b[_start:_end]
162
+ k = F.linear(key, _w, _b)
163
+
164
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
165
+ _b = in_proj_bias
166
+ _start = embed_dim * 2
167
+ _end = None
168
+ _w = in_proj_weight[_start:, :]
169
+ if _b is not None:
170
+ _b = _b[_start:]
171
+ v = F.linear(value, _w, _b)
172
+ else:
173
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
174
+ len1, len2 = q_proj_weight_non_opt.size()
175
+ assert len1 == embed_dim and len2 == query.size(-1)
176
+
177
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
178
+ len1, len2 = k_proj_weight_non_opt.size()
179
+ assert len1 == embed_dim and len2 == key.size(-1)
180
+
181
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
182
+ len1, len2 = v_proj_weight_non_opt.size()
183
+ assert len1 == embed_dim and len2 == value.size(-1)
184
+
185
+ if in_proj_bias is not None:
186
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
187
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
188
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
189
+ else:
190
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
191
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
192
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
193
+ q = q * scaling
194
+
195
+ if attn_mask is not None:
196
+ assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
197
+ attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
198
+ 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
199
+ if attn_mask.dtype == torch.uint8:
200
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
201
+ attn_mask = attn_mask.to(torch.bool)
202
+
203
+ if attn_mask.dim() == 2:
204
+ attn_mask = attn_mask.unsqueeze(0)
205
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
206
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
207
+ elif attn_mask.dim() == 3:
208
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
209
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
210
+ else:
211
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
212
+ # attn_mask's dim is 3 now.
213
+
214
+ # # convert ByteTensor key_padding_mask to bool
215
+ # if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
216
+ # warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
217
+ # key_padding_mask = key_padding_mask.to(torch.bool)
218
+
219
+ if bias_k is not None and bias_v is not None:
220
+ if static_k is None and static_v is None:
221
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
222
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
223
+ if attn_mask is not None:
224
+ attn_mask = pad(attn_mask, (0, 1))
225
+ if key_padding_mask is not None:
226
+ key_padding_mask = pad(key_padding_mask, (0, 1))
227
+ else:
228
+ assert static_k is None, "bias cannot be added to static key."
229
+ assert static_v is None, "bias cannot be added to static value."
230
+ else:
231
+ assert bias_k is None
232
+ assert bias_v is None
233
+
234
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
235
+ if k is not None:
236
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
237
+ if v is not None:
238
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
239
+
240
+ if static_k is not None:
241
+ assert static_k.size(0) == bsz * num_heads
242
+ assert static_k.size(2) == head_dim
243
+ k = static_k
244
+
245
+ if static_v is not None:
246
+ assert static_v.size(0) == bsz * num_heads
247
+ assert static_v.size(2) == head_dim
248
+ v = static_v
249
+
250
+ src_len = k.size(1)
251
+
252
+ if key_padding_mask is not None:
253
+ assert key_padding_mask.size(0) == bsz
254
+ assert key_padding_mask.size(1) == src_len
255
+
256
+ if add_zero_attn:
257
+ src_len += 1
258
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
259
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
260
+ if attn_mask is not None:
261
+ attn_mask = pad(attn_mask, (0, 1))
262
+ if key_padding_mask is not None:
263
+ key_padding_mask = pad(key_padding_mask, (0, 1))
264
+
265
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
266
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
267
+
268
+ if attn_mask is not None:
269
+ if attn_mask.dtype == torch.bool:
270
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
271
+ else:
272
+ attn_output_weights += attn_mask
273
+
274
+
275
+ if key_padding_mask is not None:
276
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
277
+ attn_output_weights = attn_output_weights.masked_fill(
278
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
279
+ float('-inf'),
280
+ )
281
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
282
+
283
+ attn_output_weights = F.softmax(
284
+ attn_output_weights, dim=-1)
285
+ attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
286
+
287
+ attn_output = torch.bmm(attn_output_weights, v)
288
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
289
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
290
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
291
+
292
+ if need_weights:
293
+ # average attention weights over heads
294
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
295
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
296
+ else:
297
+ return attn_output, None
298
+
299
+ class MultiheadAttention(Module):
300
+ r"""Allows the model to jointly attend to information
301
+ from different representation subspaces.
302
+ See reference: Attention Is All You Need
303
+ .. math::
304
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
305
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
306
+ Args:
307
+ embed_dim: total dimension of the model.
308
+ num_heads: parallel attention heads.
309
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
310
+ bias: add bias as module parameter. Default: True.
311
+ add_bias_kv: add bias to the key and value sequences at dim=0.
312
+ add_zero_attn: add a new batch of zeros to the key and
313
+ value sequences at dim=1.
314
+ kdim: total number of features in key. Default: None.
315
+ vdim: total number of features in value. Default: None.
316
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
317
+ query, key, and value have the same number of features.
318
+ Examples::
319
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
320
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
321
+ """
322
+ # __annotations__ = {
323
+ # 'bias_k': torch._jit_internal.Optional[torch.Tensor],
324
+ # 'bias_v': torch._jit_internal.Optional[torch.Tensor],
325
+ # }
326
+ __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']
327
+
328
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
329
+ super(MultiheadAttention, self).__init__()
330
+ self.embed_dim = embed_dim
331
+ self.kdim = kdim if kdim is not None else embed_dim
332
+ self.vdim = vdim if vdim is not None else embed_dim
333
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
+
335
+ self.num_heads = num_heads
336
+ self.dropout = dropout
337
+ self.head_dim = embed_dim // num_heads
338
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
339
+
340
+ if self._qkv_same_embed_dim is False:
341
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
342
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
343
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
344
+ self.register_parameter('in_proj_weight', None)
345
+ else:
346
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
347
+ self.register_parameter('q_proj_weight', None)
348
+ self.register_parameter('k_proj_weight', None)
349
+ self.register_parameter('v_proj_weight', None)
350
+
351
+ if bias:
352
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
353
+ else:
354
+ self.register_parameter('in_proj_bias', None)
355
+ self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
356
+
357
+ if add_bias_kv:
358
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
359
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
360
+ else:
361
+ self.bias_k = self.bias_v = None
362
+
363
+ self.add_zero_attn = add_zero_attn
364
+
365
+ self._reset_parameters()
366
+
367
+ def _reset_parameters(self):
368
+ if self._qkv_same_embed_dim:
369
+ xavier_uniform_(self.in_proj_weight)
370
+ else:
371
+ xavier_uniform_(self.q_proj_weight)
372
+ xavier_uniform_(self.k_proj_weight)
373
+ xavier_uniform_(self.v_proj_weight)
374
+
375
+ if self.in_proj_bias is not None:
376
+ constant_(self.in_proj_bias, 0.)
377
+ constant_(self.out_proj.bias, 0.)
378
+ if self.bias_k is not None:
379
+ xavier_normal_(self.bias_k)
380
+ if self.bias_v is not None:
381
+ xavier_normal_(self.bias_v)
382
+
383
+ def __setstate__(self, state):
384
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
385
+ if '_qkv_same_embed_dim' not in state:
386
+ state['_qkv_same_embed_dim'] = True
387
+
388
+ super(MultiheadAttention, self).__setstate__(state)
389
+
390
+ def forward(self, query, key, value, key_padding_mask=None,
391
+ need_weights=True, attn_mask=None):
392
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
393
+ r"""
394
+ Args:
395
+ query, key, value: map a query and a set of key-value pairs to an output.
396
+ See "Attention Is All You Need" for more details.
397
+ key_padding_mask: if provided, specified padding elements in the key will
398
+ be ignored by the attention. This is an binary mask. When the value is True,
399
+ the corresponding value on the attention layer will be filled with -inf.
400
+ need_weights: output attn_output_weights.
401
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
402
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
403
+ Shape:
404
+ - Inputs:
405
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
406
+ the embedding dimension.
407
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
408
+ the embedding dimension.
409
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
410
+ the embedding dimension.
411
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
412
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
413
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
414
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
415
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
416
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
417
+ S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
418
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
419
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
420
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
421
+ is provided, it will be added to the attention weight.
422
+ - Outputs:
423
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
424
+ E is the embedding dimension.
425
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
426
+ L is the target sequence length, S is the source sequence length.
427
+ """
428
+ if not self._qkv_same_embed_dim:
429
+ return multi_head_attention_forward(
430
+ query, key, value, self.embed_dim, self.num_heads,
431
+ self.in_proj_weight, self.in_proj_bias,
432
+ self.bias_k, self.bias_v, self.add_zero_attn,
433
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
434
+ training=self.training,
435
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
436
+ attn_mask=attn_mask, use_separate_proj_weight=True,
437
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
438
+ v_proj_weight=self.v_proj_weight)
439
+ else:
440
+ return multi_head_attention_forward(
441
+ query, key, value, self.embed_dim, self.num_heads,
442
+ self.in_proj_weight, self.in_proj_bias,
443
+ self.bias_k, self.bias_v, self.add_zero_attn,
444
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
445
+ training=self.training,
446
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
447
+ attn_mask=attn_mask)
448
+
449
+
450
+ class Transformer(Module):
451
+ r"""A transformer model. User is able to modify the attributes as needed. The architecture
452
+ is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
453
+ Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
454
+ Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
455
+ Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
456
+ model with corresponding parameters.
457
+
458
+ Args:
459
+ d_model: the number of expected features in the encoder/decoder inputs (default=512).
460
+ nhead: the number of heads in the multiheadattention models (default=8).
461
+ num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
462
+ num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
463
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
464
+ dropout: the dropout value (default=0.1).
465
+ activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
466
+ custom_encoder: custom encoder (default=None).
467
+ custom_decoder: custom decoder (default=None).
468
+
469
+ Examples::
470
+ >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
471
+ >>> src = torch.rand((10, 32, 512))
472
+ >>> tgt = torch.rand((20, 32, 512))
473
+ >>> out = transformer_model(src, tgt)
474
+
475
+ Note: A full example to apply nn.Transformer module for the word language model is available in
476
+ https://github.com/pytorch/examples/tree/master/word_language_model
477
+ """
478
+
479
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
480
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
481
+ activation="relu", custom_encoder=None, custom_decoder=None):
482
+ super(Transformer, self).__init__()
483
+
484
+ if custom_encoder is not None:
485
+ self.encoder = custom_encoder
486
+ else:
487
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
488
+ encoder_norm = LayerNorm(d_model)
489
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
490
+
491
+ if custom_decoder is not None:
492
+ self.decoder = custom_decoder
493
+ else:
494
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
495
+ decoder_norm = LayerNorm(d_model)
496
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
497
+
498
+ self._reset_parameters()
499
+
500
+ self.d_model = d_model
501
+ self.nhead = nhead
502
+
503
+ def forward(self, src, tgt, src_mask=None, tgt_mask=None,
504
+ memory_mask=None, src_key_padding_mask=None,
505
+ tgt_key_padding_mask=None, memory_key_padding_mask=None):
506
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor # noqa
507
+ r"""Take in and process masked source/target sequences.
508
+
509
+ Args:
510
+ src: the sequence to the encoder (required).
511
+ tgt: the sequence to the decoder (required).
512
+ src_mask: the additive mask for the src sequence (optional).
513
+ tgt_mask: the additive mask for the tgt sequence (optional).
514
+ memory_mask: the additive mask for the encoder output (optional).
515
+ src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
516
+ tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
517
+ memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
518
+
519
+ Shape:
520
+ - src: :math:`(S, N, E)`.
521
+ - tgt: :math:`(T, N, E)`.
522
+ - src_mask: :math:`(S, S)`.
523
+ - tgt_mask: :math:`(T, T)`.
524
+ - memory_mask: :math:`(T, S)`.
525
+ - src_key_padding_mask: :math:`(N, S)`.
526
+ - tgt_key_padding_mask: :math:`(N, T)`.
527
+ - memory_key_padding_mask: :math:`(N, S)`.
528
+
529
+ Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
530
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
531
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
532
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
533
+ is provided, it will be added to the attention weight.
534
+ [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
535
+ the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
536
+ positions will be unchanged. If a BoolTensor is provided, the positions with the
537
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
538
+
539
+ - output: :math:`(T, N, E)`.
540
+
541
+ Note: Due to the multi-head attention architecture in the transformer model,
542
+ the output sequence length of a transformer is same as the input sequence
543
+ (i.e. target) length of the decode.
544
+
545
+ where S is the source sequence length, T is the target sequence length, N is the
546
+ batch size, E is the feature number
547
+
548
+ Examples:
549
+ >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
550
+ """
551
+
552
+ if src.size(1) != tgt.size(1):
553
+ raise RuntimeError("the batch number of src and tgt must be equal")
554
+
555
+ if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
556
+ raise RuntimeError("the feature number of src and tgt must be equal to d_model")
557
+
558
+ memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
559
+ output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
560
+ tgt_key_padding_mask=tgt_key_padding_mask,
561
+ memory_key_padding_mask=memory_key_padding_mask)
562
+ return output
563
+
564
+ def generate_square_subsequent_mask(self, sz):
565
+ r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
566
+ Unmasked positions are filled with float(0.0).
567
+ """
568
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
569
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
570
+ return mask
571
+
572
+ def _reset_parameters(self):
573
+ r"""Initiate parameters in the transformer model."""
574
+
575
+ for p in self.parameters():
576
+ if p.dim() > 1:
577
+ xavier_uniform_(p)
578
+
579
+
580
+ class TransformerEncoder(Module):
581
+ r"""TransformerEncoder is a stack of N encoder layers
582
+
583
+ Args:
584
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
585
+ num_layers: the number of sub-encoder-layers in the encoder (required).
586
+ norm: the layer normalization component (optional).
587
+
588
+ Examples::
589
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
590
+ >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
591
+ >>> src = torch.rand(10, 32, 512)
592
+ >>> out = transformer_encoder(src)
593
+ """
594
+ __constants__ = ['norm']
595
+
596
+ def __init__(self, encoder_layer, num_layers, norm=None):
597
+ super(TransformerEncoder, self).__init__()
598
+ self.layers = _get_clones(encoder_layer, num_layers)
599
+ self.num_layers = num_layers
600
+ self.norm = norm
601
+
602
+ def forward(self, src, mask=None, src_key_padding_mask=None):
603
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
604
+ r"""Pass the input through the encoder layers in turn.
605
+
606
+ Args:
607
+ src: the sequence to the encoder (required).
608
+ mask: the mask for the src sequence (optional).
609
+ src_key_padding_mask: the mask for the src keys per batch (optional).
610
+
611
+ Shape:
612
+ see the docs in Transformer class.
613
+ """
614
+ output = src
615
+
616
+ for i, mod in enumerate(self.layers):
617
+ output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
618
+
619
+ if self.norm is not None:
620
+ output = self.norm(output)
621
+
622
+ return output
623
+
624
+
625
+ class TransformerDecoder(Module):
626
+ r"""TransformerDecoder is a stack of N decoder layers
627
+
628
+ Args:
629
+ decoder_layer: an instance of the TransformerDecoderLayer() class (required).
630
+ num_layers: the number of sub-decoder-layers in the decoder (required).
631
+ norm: the layer normalization component (optional).
632
+
633
+ Examples::
634
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
635
+ >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
636
+ >>> memory = torch.rand(10, 32, 512)
637
+ >>> tgt = torch.rand(20, 32, 512)
638
+ >>> out = transformer_decoder(tgt, memory)
639
+ """
640
+ __constants__ = ['norm']
641
+
642
+ def __init__(self, decoder_layer, num_layers, norm=None):
643
+ super(TransformerDecoder, self).__init__()
644
+ self.layers = _get_clones(decoder_layer, num_layers)
645
+ self.num_layers = num_layers
646
+ self.norm = norm
647
+
648
+ def forward(self, tgt, memory, memory2=None, tgt_mask=None,
649
+ memory_mask=None, memory_mask2=None, tgt_key_padding_mask=None,
650
+ memory_key_padding_mask=None, memory_key_padding_mask2=None):
651
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
652
+ r"""Pass the inputs (and mask) through the decoder layer in turn.
653
+
654
+ Args:
655
+ tgt: the sequence to the decoder (required).
656
+ memory: the sequence from the last layer of the encoder (required).
657
+ tgt_mask: the mask for the tgt sequence (optional).
658
+ memory_mask: the mask for the memory sequence (optional).
659
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
660
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
+ memory2: the second memory sequence used by siamese decoder layers (optional).
+ memory_mask2: the mask for the second memory sequence (optional).
+ memory_key_padding_mask2: the mask for the second memory keys per batch (optional).
661
+
662
+ Shape:
663
+ see the docs in Transformer class.
664
+ """
665
+ output = tgt
666
+
667
+ for mod in self.layers:
668
+ output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask,
669
+ memory_mask=memory_mask, memory_mask2=memory_mask2,
670
+ tgt_key_padding_mask=tgt_key_padding_mask,
671
+ memory_key_padding_mask=memory_key_padding_mask,
672
+ memory_key_padding_mask2=memory_key_padding_mask2)
673
+
674
+ if self.norm is not None:
675
+ output = self.norm(output)
676
+
677
+ return output
678
+
679
+ class TransformerEncoderLayer(Module):
680
+ r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
681
+ This standard encoder layer is based on the paper "Attention Is All You Need".
682
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
683
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
684
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
685
+ in a different way during application.
686
+
687
+ Args:
688
+ d_model: the number of expected features in the input (required).
689
+ nhead: the number of heads in the multiheadattention models (required).
690
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
691
+ dropout: the dropout value (default=0.1).
692
+ activation: the activation function of intermediate layer, relu or gelu (default=relu).
693
+
694
+ Examples::
695
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
696
+ >>> src = torch.rand(10, 32, 512)
697
+ >>> out = encoder_layer(src)
698
+ """
699
+
700
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
701
+ activation="relu", debug=False):
702
+ super(TransformerEncoderLayer, self).__init__()
703
+ self.debug = debug
704
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
705
+ # Implementation of Feedforward model
706
+ self.linear1 = Linear(d_model, dim_feedforward)
707
+ self.dropout = Dropout(dropout)
708
+ self.linear2 = Linear(dim_feedforward, d_model)
709
+
710
+ self.norm1 = LayerNorm(d_model)
711
+ self.norm2 = LayerNorm(d_model)
712
+ self.dropout1 = Dropout(dropout)
713
+ self.dropout2 = Dropout(dropout)
714
+
715
+ self.activation = _get_activation_fn(activation)
716
+
717
+ def __setstate__(self, state):
718
+ if 'activation' not in state:
719
+ state['activation'] = F.relu
720
+ super(TransformerEncoderLayer, self).__setstate__(state)
721
+
722
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
723
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
724
+ r"""Pass the input through the encoder layer.
725
+
726
+ Args:
727
+ src: the sequence to the encoder layer (required).
728
+ src_mask: the mask for the src sequence (optional).
729
+ src_key_padding_mask: the mask for the src keys per batch (optional).
730
+
731
+ Shape:
732
+ see the docs in Transformer class.
733
+ """
734
+ src2, attn = self.self_attn(src, src, src, attn_mask=src_mask,
735
+ key_padding_mask=src_key_padding_mask)
736
+ if self.debug: self.attn = attn
737
+ src = src + self.dropout1(src2)
738
+ src = self.norm1(src)
739
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
740
+ src = src + self.dropout2(src2)
741
+ src = self.norm2(src)
742
+
743
+ return src
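A minimal usage sketch (assuming torch and this layer class are importable; the attention-weight shape follows the stock MultiheadAttention behaviour):

# Illustrative only: post-norm encoder layer; debug=True keeps the last attention map.
layer = TransformerEncoderLayer(d_model=512, nhead=8, debug=True)
src = torch.rand(10, 32, 512)      # (S, N, E)
out = layer(src)                   # (10, 32, 512)
attn = layer.attn                  # (N, S, S), attention weights averaged over heads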
744
+
745
+
746
+ class TransformerDecoderLayer(Module):
747
+ r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
748
+ This standard decoder layer is based on the paper "Attention Is All You Need".
749
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
750
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
751
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
752
+ in a different way during application.
753
+
754
+ Args:
755
+ d_model: the number of expected features in the input (required).
756
+ nhead: the number of heads in the multiheadattention models (required).
757
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
758
+ dropout: the dropout value (default=0.1).
759
+ activation: the activation function of intermediate layer, relu or gelu (default=relu).
760
+
761
+ Examples::
762
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
763
+ >>> memory = torch.rand(10, 32, 512)
764
+ >>> tgt = torch.rand(20, 32, 512)
765
+ >>> out = decoder_layer(tgt, memory)
766
+ """
767
+
768
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
769
+ activation="relu", self_attn=True, siamese=False, debug=False):
770
+ super(TransformerDecoderLayer, self).__init__()
771
+ self.has_self_attn, self.siamese = self_attn, siamese
772
+ self.debug = debug
773
+ if self.has_self_attn:
774
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
775
+ self.norm1 = LayerNorm(d_model)
776
+ self.dropout1 = Dropout(dropout)
777
+ self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
778
+ # Implementation of Feedforward model
779
+ self.linear1 = Linear(d_model, dim_feedforward)
780
+ self.dropout = Dropout(dropout)
781
+ self.linear2 = Linear(dim_feedforward, d_model)
782
+
783
+ self.norm2 = LayerNorm(d_model)
784
+ self.norm3 = LayerNorm(d_model)
785
+ self.dropout2 = Dropout(dropout)
786
+ self.dropout3 = Dropout(dropout)
787
+ if self.siamese:
788
+ self.multihead_attn2 = MultiheadAttention(d_model, nhead, dropout=dropout)
789
+
790
+ self.activation = _get_activation_fn(activation)
791
+
792
+ def __setstate__(self, state):
793
+ if 'activation' not in state:
794
+ state['activation'] = F.relu
795
+ super(TransformerDecoderLayer, self).__setstate__(state)
796
+
797
+ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
798
+ tgt_key_padding_mask=None, memory_key_padding_mask=None,
799
+ memory2=None, memory_mask2=None, memory_key_padding_mask2=None):
800
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
801
+ r"""Pass the inputs (and mask) through the decoder layer.
802
+
803
+ Args:
804
+ tgt: the sequence to the decoder layer (required).
805
+ memory: the sequence from the last layer of the encoder (required).
806
+ tgt_mask: the mask for the tgt sequence (optional).
807
+ memory_mask: the mask for the memory sequence (optional).
808
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
809
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
810
+
811
+ Shape:
812
+ see the docs in Transformer class.
813
+ """
814
+ if self.has_self_attn:
815
+ tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
816
+ key_padding_mask=tgt_key_padding_mask)
817
+ tgt = tgt + self.dropout1(tgt2)
818
+ tgt = self.norm1(tgt)
819
+ if self.debug: self.attn = attn
820
+ tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
821
+ key_padding_mask=memory_key_padding_mask)
822
+ if self.debug: self.attn2 = attn2
823
+
824
+ if self.siamese:
825
+ tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2,
826
+ key_padding_mask=memory_key_padding_mask2)
827
+ tgt = tgt + self.dropout2(tgt3)
828
+ if self.debug: self.attn3 = attn3
829
+
830
+ tgt = tgt + self.dropout2(tgt2)
831
+ tgt = self.norm2(tgt)
832
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
833
+ tgt = tgt + self.dropout3(tgt2)
834
+ tgt = self.norm3(tgt)
835
+
836
+ return tgt
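A sketch of the siamese option, which adds a second cross-attention over a second memory (tensor names below are placeholders):

# Illustrative only: decoder layer cross-attending two encoder memories.
layer = TransformerDecoderLayer(d_model=512, nhead=8, siamese=True)
tgt = torch.rand(20, 32, 512)       # (T, N, E)
mem_a = torch.rand(10, 32, 512)     # (S1, N, E)
mem_b = torch.rand(12, 32, 512)     # (S2, N, E)
out = layer(tgt, mem_a, memory2=mem_b)   # (20, 32, 512)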
837
+
838
+
839
+ def _get_clones(module, N):
840
+ return ModuleList([copy.deepcopy(module) for i in range(N)])
841
+
842
+
843
+ def _get_activation_fn(activation):
844
+ if activation == "relu":
845
+ return F.relu
846
+ elif activation == "gelu":
847
+ return F.gelu
848
+
849
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
850
+
851
+
852
+ class PositionalEncoding(nn.Module):
853
+ r"""Inject some information about the relative or absolute position of the tokens
854
+ in the sequence. The positional encodings have the same dimension as
855
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
856
+ functions of different frequencies.
857
+ .. math::
858
+ \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
859
+ \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})
860
+ where pos is the word position and i is the embedding index.
861
+ Args:
862
+ d_model: the embed dim (required).
863
+ dropout: the dropout value (default=0.1).
864
+ max_len: the max. length of the incoming sequence (default=5000).
865
+ Examples:
866
+ >>> pos_encoder = PositionalEncoding(d_model)
867
+ """
868
+
869
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
870
+ super(PositionalEncoding, self).__init__()
871
+ self.dropout = nn.Dropout(p=dropout)
872
+
873
+ pe = torch.zeros(max_len, d_model)
874
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
875
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
876
+ pe[:, 0::2] = torch.sin(position * div_term)
877
+ pe[:, 1::2] = torch.cos(position * div_term)
878
+ pe = pe.unsqueeze(0).transpose(0, 1)
879
+ self.register_buffer('pe', pe)
880
+
881
+ def forward(self, x):
882
+ r"""Inputs of forward function
883
+ Args:
884
+ x: the sequence fed to the positional encoder model (required).
885
+ Shape:
886
+ x: [sequence length, batch size, embed dim]
887
+ output: [sequence length, batch size, embed dim]
888
+ Examples:
889
+ >>> output = pos_encoder(x)
890
+ """
891
+
892
+ x = x + self.pe[:x.size(0), :]
893
+ return self.dropout(x)
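A minimal sketch: the registered buffer pe has shape (max_len, 1, d_model), so it broadcasts over the batch dimension of an (S, N, E) input:

# Illustrative only: add sinusoidal position information to a (S, N, E) tensor.
pos_encoder = PositionalEncoding(d_model=512, dropout=0.1)
x = torch.rand(20, 32, 512)
y = pos_encoder(x)                 # same shape; rows 0..19 receive positions 0..19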
894
+
895
+
896
+ if __name__ == '__main__':
897
+ transformer_model = Transformer(nhead=16, num_encoder_layers=12)
898
+ src = torch.rand((10, 32, 512))
899
+ tgt = torch.rand((20, 32, 512))
900
+ out = transformer_model(src, tgt)
901
+ print(out)
requirements.txt CHANGED
@@ -1,11 +1,17 @@
 
 
 
 
 
 
 
 
 
1
  ninja
2
  yacs
3
  cython
4
  matplotlib
5
  tqdm
6
- opencv-python
7
- torch==1.4.0
8
- torchvision==0.5.0
9
  shapely
10
  scipy
11
  networkx
 
1
+ torch==1.4.0
2
+ torchvision==0.5.0
3
+ fastai==1.0.60
4
+ LMDB
5
+ Pillow
6
+ opencv-python
7
+ tensorboardX
8
+ PyYAML
9
+ gdown
10
  ninja
11
  yacs
12
  cython
13
  matplotlib
14
  tqdm
 
 
 
15
  shapely
16
  scipy
17
  networkx
utils.py ADDED
@@ -0,0 +1,304 @@
1
+ import logging
2
+ import os
3
+ import time
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import yaml
9
+ from matplotlib import colors
10
+ from matplotlib import pyplot as plt
11
+ from torch import Tensor, nn
12
+ from torch.utils.data import ConcatDataset
13
+
14
+ class CharsetMapper(object):
15
+ """A simple class to map ids into strings.
16
+
17
+ It works only when the character set is 1:1 mapping between individual
18
+ characters and individual ids.
19
+ """
20
+
21
+ def __init__(self,
22
+ filename='',
23
+ max_length=30,
24
+ null_char=u'\u2591'):
25
+ """Creates a lookup table.
26
+
27
+ Args:
28
+ filename: Path to charset file which maps characters to ids.
29
+ max_length: The max length of ids and string.
30
+ null_char: A unicode character used to replace '<null>' character.
31
+ the default value is a light shade block '░'.
32
+ """
33
+ self.null_char = null_char
34
+ self.max_length = max_length
35
+
36
+ self.label_to_char = self._read_charset(filename)
37
+ self.char_to_label = dict(map(reversed, self.label_to_char.items()))
38
+ self.num_classes = len(self.label_to_char)
39
+
40
+ def _read_charset(self, filename):
41
+ """Reads a charset definition from a tab separated text file.
42
+
43
+ Args:
44
+ filename: a path to the charset file.
45
+
46
+ Returns:
47
+ a dictionary with keys equal to character codes and values - unicode
48
+ characters.
49
+ """
50
+ import re
51
+ pattern = re.compile(r'(\d+)\t(.+)')
52
+ charset = {}
53
+ self.null_label = 0
54
+ charset[self.null_label] = self.null_char
55
+ with open(filename, 'r') as f:
56
+ for i, line in enumerate(f):
57
+ m = pattern.match(line)
58
+ assert m, f'Incorrect charset file. line #{i}: {line}'
59
+ label = int(m.group(1)) + 1
60
+ char = m.group(2)
61
+ charset[label] = char
62
+ return charset
63
+
64
+ def trim(self, text):
65
+ assert isinstance(text, str)
66
+ return text.replace(self.null_char, '')
67
+
68
+ def get_text(self, labels, length=None, padding=True, trim=False):
69
+ """ Returns a string corresponding to a sequence of character ids.
70
+ """
71
+ length = length if length else self.max_length
72
+ labels = [l.item() if isinstance(l, Tensor) else int(l) for l in labels]
73
+ if padding:
74
+ labels = labels + [self.null_label] * (length-len(labels))
75
+ text = ''.join([self.label_to_char[label] for label in labels])
76
+ if trim: text = self.trim(text)
77
+ return text
78
+
79
+ def get_labels(self, text, length=None, padding=True, case_sensitive=False):
80
+ """ Returns the labels of the corresponding text.
81
+ """
82
+ length = length if length else self.max_length
83
+ if padding:
84
+ text = text + self.null_char * (length - len(text))
85
+ if not case_sensitive:
86
+ text = text.lower()
87
+ labels = [self.char_to_label[char] for char in text]
88
+ return labels
89
+
90
+ def pad_labels(self, labels, length=None):
91
+ length = length if length else self.max_length
92
+
93
+ return labels + [self.null_label] * (length - len(labels))
94
+
95
+ @property
96
+ def digits(self):
97
+ return '0123456789'
98
+
99
+ @property
100
+ def digit_labels(self):
101
+ return self.get_labels(self.digits, padding=False)
102
+
103
+ @property
104
+ def alphabets(self):
105
+ all_chars = list(self.char_to_label.keys())
106
+ valid_chars = []
107
+ for c in all_chars:
108
+ if c in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ':
109
+ valid_chars.append(c)
110
+ return ''.join(valid_chars)
111
+
112
+ @property
113
+ def alphabet_labels(self):
114
+ return self.get_labels(self.alphabets, padding=False)
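A round-trip sketch, assuming the data/charset_36.txt added in this commit (tab-separated id/character pairs):

# Illustrative only: text -> padded label ids -> text.
charset = CharsetMapper('data/charset_36.txt', max_length=26)
labels = charset.get_labels('hello')          # lower-cased, padded with the null label (0)
text = charset.get_text(labels, trim=True)    # -> 'hello'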
115
+
116
+
117
+ class Timer(object):
118
+ """A simple timer."""
119
+ def __init__(self):
120
+ self.data_time = 0.
121
+ self.data_diff = 0.
122
+ self.data_total_time = 0.
123
+ self.data_call = 0
124
+ self.running_time = 0.
125
+ self.running_diff = 0.
126
+ self.running_total_time = 0.
127
+ self.running_call = 0
128
+
129
+ def tic(self):
130
+ self.start_time = time.time()
131
+ self.running_time = self.start_time
132
+
133
+ def toc_data(self):
134
+ self.data_time = time.time()
135
+ self.data_diff = self.data_time - self.running_time
136
+ self.data_total_time += self.data_diff
137
+ self.data_call += 1
138
+
139
+ def toc_running(self):
140
+ self.running_time = time.time()
141
+ self.running_diff = self.running_time - self.data_time
142
+ self.running_total_time += self.running_diff
143
+ self.running_call += 1
144
+
145
+ def total_time(self):
146
+ return self.data_total_time + self.running_total_time
147
+
148
+ def average_time(self):
149
+ return self.average_data_time() + self.average_running_time()
150
+
151
+ def average_data_time(self):
152
+ return self.data_total_time / (self.data_call or 1)
153
+
154
+ def average_running_time(self):
155
+ return self.running_total_time / (self.running_call or 1)
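A usage sketch of the tic/toc split (loader and model below are placeholders):

# Illustrative only: separate data-loading time from compute time per iteration.
timer = Timer()
timer.tic()
for batch in loader:               # placeholder iterable
    timer.toc_data()               # accumulate time spent fetching the batch
    output = model(batch)          # placeholder forward pass
    timer.toc_running()            # accumulate time spent computing
print(timer.average_data_time(), timer.average_running_time())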
156
+
157
+
158
+ class Logger(object):
159
+ _handle = None
160
+ _root = None
161
+
162
+ @staticmethod
163
+ def init(output_dir, name, phase):
164
+ format = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \
165
+ '%(message)s'.format(name)
166
+ logging.basicConfig(level=logging.INFO, format=format)
167
+
168
+ try: os.makedirs(output_dir)
169
+ except: pass
170
+ config_path = os.path.join(output_dir, f'{phase}.txt')
171
+ Logger._handle = logging.FileHandler(config_path)
172
+ Logger._root = logging.getLogger()
173
+
174
+ @staticmethod
175
+ def enable_file():
176
+ if Logger._handle is None or Logger._root is None:
177
+ raise Exception('Invoke Logger.init() first!')
178
+ Logger._root.addHandler(Logger._handle)
179
+
180
+ @staticmethod
181
+ def disable_file():
182
+ if Logger._handle is None or Logger._root is None:
183
+ raise Exception('Invoke Logger.init() first!')
184
+ Logger._root.removeHandler(Logger._handle)
185
+
186
+
187
+ class Config(object):
188
+
189
+ def __init__(self, config_path, host=True):
190
+ def __dict2attr(d, prefix=''):
191
+ for k, v in d.items():
192
+ if isinstance(v, dict):
193
+ __dict2attr(v, f'{prefix}{k}_')
194
+ else:
195
+ if k == 'phase':
196
+ assert v in ['train', 'test']
197
+ if k == 'stage':
198
+ assert v in ['pretrain-vision', 'pretrain-language',
199
+ 'train-semi-super', 'train-super']
200
+ self.__setattr__(f'{prefix}{k}', v)
201
+
202
+ assert os.path.exists(config_path), '%s does not exists!' % config_path
203
+ with open(config_path) as file:
204
+ config_dict = yaml.load(file, Loader=yaml.FullLoader)
205
+ with open('configs/rec/template.yaml') as file:
206
+ default_config_dict = yaml.load(file, Loader=yaml.FullLoader)
207
+ __dict2attr(default_config_dict)
208
+ __dict2attr(config_dict)
209
+ self.global_workdir = os.path.join(self.global_workdir, self.global_name)
210
+
211
+ def __getattr__(self, item):
212
+ attr = self.__dict__.get(item)
213
+ if attr is None:
214
+ attr = dict()
215
+ prefix = f'{item}_'
216
+ for k, v in self.__dict__.items():
217
+ if k.startswith(prefix):
218
+ n = k.replace(prefix, '')
219
+ attr[n] = v
220
+ return attr if len(attr) > 0 else None
221
+ else:
222
+ return attr
223
+
224
+ def __repr__(self):
225
+ str = 'ModelConfig(\n'
226
+ for i, (k, v) in enumerate(sorted(vars(self).items())):
227
+ str += f'\t({i}): {k} = {v}\n'
228
+ str += ')'
229
+ return str
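A sketch of how the flattening works; the YAML keys below are hypothetical:

# Illustrative only: nested YAML keys become prefixed attributes, and
# __getattr__ regroups a prefix back into a dict.
#
#   model:                         # hypothetical section in the YAML file
#     name: modules.model.Model
#     checkpoint: workdir/best.pth
config = Config('configs/rec/template.yaml')
config.model_name    # -> 'modules.model.Model'
config.model         # -> {'name': ..., 'checkpoint': ...} plus other model_* defaults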
230
+
231
+ def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0):
232
+ # normalize mask
233
+ mask = (mask-mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps)
234
+ if mask.shape != image.shape:
235
+ mask = cv2.resize(mask,(image.shape[1], image.shape[0]))
236
+ # get color map
237
+ color_map = plt.get_cmap(cmap)
238
+ mask = color_map(mask)[:,:,:3]
239
+ # convert float to uint8
240
+ mask = (mask * 255).astype(dtype=np.uint8)
241
+
242
+ # set the basic color
243
+ basic_color = np.array(colors.to_rgb(color)) * 255
244
+ basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1])
245
+ basic_color = basic_color.astype(dtype=np.uint8)
246
+ # blend with basic color
247
+ blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1-color_alpha, 0)
248
+ # blend with mask
249
+ blended_img = cv2.addWeighted(blended_img, alpha, mask, 1-alpha, 0)
250
+
251
+ return blended_img
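A sketch of overlaying a score map on an image (the paths and the attn_map array are placeholders):

# Illustrative only: colour-map a (H, W) score map and blend it onto a BGR image.
image = cv2.imread('input.jpg')                      # uint8, (H, W, 3)
attn_map = np.random.rand(32, 128)                   # placeholder attention scores
overlay = blend_mask(image, attn_map, alpha=0.5, cmap='jet')
cv2.imwrite('attention.jpg', overlay)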
252
+
253
+ def onehot(label, depth, device=None):
254
+ """
255
+ Args:
256
+ label: shape (n1, n2, ..., )
257
+ depth: a scalar
258
+
259
+ Returns:
260
+ onehot: (n1, n2, ..., depth)
261
+ """
262
+ if not isinstance(label, torch.Tensor):
263
+ label = torch.tensor(label, device=device)
264
+ onehot = torch.zeros(label.size() + torch.Size([depth]), device=device)
265
+ onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1)
266
+
267
+ return onehot
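A quick worked example of the scatter-based one-hot (illustrative only):

# Two labels, depth 5 -> tensor of shape (2, 5).
onehot(torch.tensor([1, 3]), depth=5)
# tensor([[0., 1., 0., 0., 0.],
#         [0., 0., 0., 1., 0.]])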
268
+
269
+ class MyDataParallel(nn.DataParallel):
270
+
271
+ def gather(self, outputs, target_device):
272
+ r"""
273
+ Gathers tensors from different GPUs on a specified device
274
+ (-1 means the CPU).
275
+ """
276
+ def gather_map(outputs):
277
+ out = outputs[0]
278
+ if isinstance(out, (str, int, float)):
279
+ return out
280
+ if isinstance(out, list) and isinstance(out[0], str):
281
+ return [o for out in outputs for o in out]
282
+ if isinstance(out, torch.Tensor):
283
+ return torch.nn.parallel._functions.Gather.apply(target_device, self.dim, *outputs)
284
+ if out is None:
285
+ return None
286
+ if isinstance(out, dict):
287
+ if not all((len(out) == len(d) for d in outputs)):
288
+ raise ValueError('All dicts must have the same number of keys')
289
+ return type(out)(((k, gather_map([d[k] for d in outputs]))
290
+ for k in out))
291
+ return type(out)(map(gather_map, zip(*outputs)))
292
+
293
+ # Recursive function calls like this create reference cycles.
294
+ # Setting the function to None clears the refcycle.
295
+ try:
296
+ res = gather_map(outputs)
297
+ finally:
298
+ gather_map = None
299
+ return res
300
+
301
+
302
+ class MyConcatDataset(ConcatDataset):
303
+ def __getattr__(self, k):
304
+ return getattr(self.datasets[0], k)