csaybar committed
Commit 1182f33 · verified · 1 Parent(s): e9ad50a

Upload 4 files

.gitattributes CHANGED
@@ -63,3 +63,5 @@ Mamba_Medium_SR/example_data.safetensor filter=lfs diff=lfs merge=lfs -text
  Mamba_Medium_SR/model.safetensor filter=lfs diff=lfs merge=lfs -text
  Mamba_Large_SR/example_data.safetensor filter=lfs diff=lfs merge=lfs -text
  Mamba_Large_SR/model.safetensor filter=lfs diff=lfs merge=lfs -text
+ Swin_Light_SR/example_data.safetensor filter=lfs diff=lfs merge=lfs -text
+ Swin_Light_SR/model.safetensor filter=lfs diff=lfs merge=lfs -text
Swin_Light_SR/example_data.safetensor ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7709dd46aabc069c2005f39ce830cb5659306a9c11221307557c56f6ed6cf65
+ size 13631584
Swin_Light_SR/load.py ADDED
@@ -0,0 +1,1409 @@
1
+ # -----------------------------------------------------------------------------------
2
+ # Swin2SR: SwinV2 Transformer for Compressed
3
+ # Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345
4
+ # Written by Conde and Choi et al.
5
+ # -----------------------------------------------------------------------------------
6
+
7
+ import math
8
+ import pathlib
9
+ import safetensors.torch
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint as checkpoint
16
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
17
+
18
+
19
+ class Mlp(nn.Module):
20
+ def __init__(
21
+ self,
22
+ in_features,
23
+ hidden_features=None,
24
+ out_features=None,
25
+ act_layer=nn.GELU,
26
+ drop=0.0,
27
+ ):
28
+ super().__init__()
29
+ out_features = out_features or in_features
30
+ hidden_features = hidden_features or in_features
31
+ self.fc1 = nn.Linear(in_features, hidden_features)
32
+ self.act = act_layer()
33
+ self.fc2 = nn.Linear(hidden_features, out_features)
34
+ self.drop = nn.Dropout(drop)
35
+
36
+ def forward(self, x):
37
+ x = self.fc1(x)
38
+ x = self.act(x)
39
+ x = self.drop(x)
40
+ x = self.fc2(x)
41
+ x = self.drop(x)
42
+ return x
43
+
44
+
45
+ def window_partition(x, window_size):
46
+ """
47
+ Args:
48
+ x: (B, H, W, C)
49
+ window_size (int): window size
50
+ Returns:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ """
53
+ B, H, W, C = x.shape
54
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
55
+ windows = (
56
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
57
+ )
58
+ return windows
59
+
60
+
61
+ def window_reverse(windows, window_size, H, W):
62
+ """
63
+ Args:
64
+ windows: (num_windows*B, window_size, window_size, C)
65
+ window_size (int): Window size
66
+ H (int): Height of image
67
+ W (int): Width of image
68
+ Returns:
69
+ x: (B, H, W, C)
70
+ """
71
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
72
+ x = windows.view(
73
+ B, H // window_size, W // window_size, window_size, window_size, -1
74
+ )
75
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
76
+ return x
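+
+ # Note: window_partition and window_reverse are exact inverses when H and W are
+ # multiples of window_size. For example, x of shape (2, 64, 64, 96) with
+ # window_size=8 partitions into (2*8*8, 8, 8, 96) = (128, 8, 8, 96) windows, and
+ # window_reverse(windows, 8, 64, 64) recovers the original (2, 64, 64, 96) tensor.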
77
+
78
+
79
+ class WindowAttention(nn.Module):
80
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
81
+ It supports both shifted and non-shifted windows.
82
+ Args:
83
+ dim (int): Number of input channels.
84
+ window_size (tuple[int]): The height and width of the window.
85
+ num_heads (int): Number of attention heads.
86
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
87
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
88
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
89
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
90
+ """
91
+
92
+ def __init__(
93
+ self,
94
+ dim,
95
+ window_size,
96
+ num_heads,
97
+ qkv_bias=True,
98
+ attn_drop=0.0,
99
+ proj_drop=0.0,
100
+ pretrained_window_size=[0, 0],
101
+ ):
102
+ super().__init__()
103
+ self.dim = dim
104
+ self.window_size = window_size # Wh, Ww
105
+ self.pretrained_window_size = pretrained_window_size
106
+ self.num_heads = num_heads
107
+
108
+ self.logit_scale = nn.Parameter(
109
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
110
+ )
111
+
112
+ # mlp to generate continuous relative position bias
113
+ self.cpb_mlp = nn.Sequential(
114
+ nn.Linear(2, 512, bias=True),
115
+ nn.ReLU(inplace=True),
116
+ nn.Linear(512, num_heads, bias=False),
117
+ )
118
+
119
+ # get relative_coords_table
120
+ relative_coords_h = torch.arange(
121
+ -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32
122
+ )
123
+ relative_coords_w = torch.arange(
124
+ -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32
125
+ )
126
+ relative_coords_table = (
127
+ torch.stack(
128
+ torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij")
129
+ )
130
+ .permute(1, 2, 0)
131
+ .contiguous()
132
+ .unsqueeze(0)
133
+ ) # 1, 2*Wh-1, 2*Ww-1, 2
134
+ if pretrained_window_size[0] > 0:
135
+ relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
136
+ relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
137
+ else:
138
+ relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
139
+ relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
140
+ relative_coords_table *= 8 # normalize to -8, 8
141
+ relative_coords_table = (
142
+ torch.sign(relative_coords_table)
143
+ * torch.log2(torch.abs(relative_coords_table) + 1.0)
144
+ / np.log2(8)
145
+ )
146
+
147
+ self.register_buffer("relative_coords_table", relative_coords_table)
148
+
149
+ # get pair-wise relative position index for each token inside the window
150
+ coords_h = torch.arange(self.window_size[0])
151
+ coords_w = torch.arange(self.window_size[1])
152
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
153
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
154
+ relative_coords = (
155
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
156
+ ) # 2, Wh*Ww, Wh*Ww
157
+ relative_coords = relative_coords.permute(
158
+ 1, 2, 0
159
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
160
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
161
+ relative_coords[:, :, 1] += self.window_size[1] - 1
162
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
163
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
164
+ self.register_buffer("relative_position_index", relative_position_index)
165
+
166
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
167
+ if qkv_bias:
168
+ self.q_bias = nn.Parameter(torch.zeros(dim))
169
+ self.v_bias = nn.Parameter(torch.zeros(dim))
170
+ else:
171
+ self.q_bias = None
172
+ self.v_bias = None
173
+ self.attn_drop = nn.Dropout(attn_drop)
174
+ self.proj = nn.Linear(dim, dim)
175
+ self.proj_drop = nn.Dropout(proj_drop)
176
+ self.softmax = nn.Softmax(dim=-1)
177
+
178
+ def forward(self, x, mask=None):
179
+ """
180
+ Args:
181
+ x: input features with shape of (num_windows*B, N, C)
182
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
183
+ """
184
+ B_, N, C = x.shape
185
+ qkv_bias = None
186
+ if self.q_bias is not None:
187
+ qkv_bias = torch.cat(
188
+ (
189
+ self.q_bias,
190
+ torch.zeros_like(self.v_bias, requires_grad=False),
191
+ self.v_bias,
192
+ )
193
+ )
194
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
195
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
196
+ q, k, v = (
197
+ qkv[0],
198
+ qkv[1],
199
+ qkv[2],
200
+ ) # make torchscript happy (cannot use tensor as tuple)
201
+
202
+ # cosine attention
203
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
204
+ logit_scale = torch.clamp(
205
+ self.logit_scale,
206
+ max=torch.log(torch.tensor(1.0 / 0.01)).to(self.logit_scale.device),
207
+ ).exp()
208
+ attn = attn * logit_scale
209
+
210
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(
211
+ -1, self.num_heads
212
+ )
213
+ relative_position_bias = relative_position_bias_table[
214
+ self.relative_position_index.view(-1)
215
+ ].view(
216
+ self.window_size[0] * self.window_size[1],
217
+ self.window_size[0] * self.window_size[1],
218
+ -1,
219
+ ) # Wh*Ww,Wh*Ww,nH
220
+ relative_position_bias = relative_position_bias.permute(
221
+ 2, 0, 1
222
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
223
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
224
+ attn = attn + relative_position_bias.unsqueeze(0)
225
+
226
+ if mask is not None:
227
+ nW = mask.shape[0]
228
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
229
+ 1
230
+ ).unsqueeze(0)
231
+ attn = attn.view(-1, self.num_heads, N, N)
232
+ attn = self.softmax(attn)
233
+ else:
234
+ attn = self.softmax(attn)
235
+
236
+ attn = self.attn_drop(attn)
237
+
238
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
239
+ x = self.proj(x)
240
+ x = self.proj_drop(x)
241
+ return x
242
+
243
+ def extra_repr(self) -> str:
244
+ return (
245
+ f"dim={self.dim}, window_size={self.window_size}, "
246
+ f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
247
+ )
248
+
249
+ def flops(self, N):
250
+ # calculate flops for 1 window with token length of N
251
+ flops = 0
252
+ # qkv = self.qkv(x)
253
+ flops += N * self.dim * 3 * self.dim
254
+ # attn = (q @ k.transpose(-2, -1))
255
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
256
+ # x = (attn @ v)
257
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
258
+ # x = self.proj(x)
259
+ flops += N * self.dim * self.dim
260
+ return flops
261
+
262
+
263
+ class SwinTransformerBlock(nn.Module):
264
+ r"""Swin Transformer Block.
265
+ Args:
266
+ dim (int): Number of input channels.
267
+ input_resolution (tuple[int]): Input resolution.
268
+ num_heads (int): Number of attention heads.
269
+ window_size (int): Window size.
270
+ shift_size (int): Shift size for SW-MSA.
271
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
272
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
273
+ drop (float, optional): Dropout rate. Default: 0.0
274
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
275
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
276
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
277
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
278
+ pretrained_window_size (int): Window size in pre-training.
279
+ """
280
+
281
+ def __init__(
282
+ self,
283
+ dim,
284
+ input_resolution,
285
+ num_heads,
286
+ window_size=7,
287
+ shift_size=0,
288
+ mlp_ratio=4.0,
289
+ qkv_bias=True,
290
+ drop=0.0,
291
+ attn_drop=0.0,
292
+ drop_path=0.0,
293
+ act_layer=nn.GELU,
294
+ norm_layer=nn.LayerNorm,
295
+ pretrained_window_size=0,
296
+ ):
297
+ super().__init__()
298
+ self.dim = dim
299
+ self.input_resolution = input_resolution
300
+ self.num_heads = num_heads
301
+ self.window_size = window_size
302
+ self.shift_size = shift_size
303
+ self.mlp_ratio = mlp_ratio
304
+ if min(self.input_resolution) <= self.window_size:
305
+ # if window size is larger than input resolution, we don't partition windows
306
+ self.shift_size = 0
307
+ self.window_size = min(self.input_resolution)
308
+ assert (
309
+ 0 <= self.shift_size < self.window_size
310
+ ), "shift_size must in 0-window_size"
311
+
312
+ self.norm1 = norm_layer(dim)
313
+ self.attn = WindowAttention(
314
+ dim,
315
+ window_size=to_2tuple(self.window_size),
316
+ num_heads=num_heads,
317
+ qkv_bias=qkv_bias,
318
+ attn_drop=attn_drop,
319
+ proj_drop=drop,
320
+ pretrained_window_size=to_2tuple(pretrained_window_size),
321
+ )
322
+
323
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
324
+ self.norm2 = norm_layer(dim)
325
+ mlp_hidden_dim = int(dim * mlp_ratio)
326
+ self.mlp = Mlp(
327
+ in_features=dim,
328
+ hidden_features=mlp_hidden_dim,
329
+ act_layer=act_layer,
330
+ drop=drop,
331
+ )
332
+
333
+ if self.shift_size > 0:
334
+ attn_mask = self.calculate_mask(self.input_resolution)
335
+ else:
336
+ attn_mask = None
337
+
338
+ self.register_buffer("attn_mask", attn_mask)
339
+
340
+ def calculate_mask(self, x_size):
341
+ # calculate attention mask for SW-MSA
342
+ H, W = x_size
343
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
344
+ h_slices = (
345
+ slice(0, -self.window_size),
346
+ slice(-self.window_size, -self.shift_size),
347
+ slice(-self.shift_size, None),
348
+ )
349
+ w_slices = (
350
+ slice(0, -self.window_size),
351
+ slice(-self.window_size, -self.shift_size),
352
+ slice(-self.shift_size, None),
353
+ )
354
+ cnt = 0
355
+ for h in h_slices:
356
+ for w in w_slices:
357
+ img_mask[:, h, w, :] = cnt
358
+ cnt += 1
359
+
360
+ mask_windows = window_partition(
361
+ img_mask, self.window_size
362
+ ) # nW, window_size, window_size, 1
363
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
364
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
365
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
366
+ attn_mask == 0, float(0.0)
367
+ )
368
+
369
+ return attn_mask
370
+
371
+ def forward(self, x, x_size):
372
+ H, W = x_size
373
+ B, L, C = x.shape
374
+ # assert L == H * W, "input feature has wrong size"
375
+
376
+ shortcut = x
377
+ x = x.view(B, H, W, C)
378
+
379
+ # cyclic shift
380
+ if self.shift_size > 0:
381
+ shifted_x = torch.roll(
382
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
383
+ )
384
+ else:
385
+ shifted_x = x
386
+
387
+ # partition windows
388
+ x_windows = window_partition(
389
+ shifted_x, self.window_size
390
+ ) # nW*B, window_size, window_size, C
391
+ x_windows = x_windows.view(
392
+ -1, self.window_size * self.window_size, C
393
+ ) # nW*B, window_size*window_size, C
394
+
395
+ # W-MSA/SW-MSA (recompute the attention mask when the input size differs from the training resolution)
396
+ if self.input_resolution == x_size:
397
+ attn_windows = self.attn(
398
+ x_windows, mask=self.attn_mask
399
+ ) # nW*B, window_size*window_size, C
400
+ else:
401
+ attn_windows = self.attn(
402
+ x_windows, mask=self.calculate_mask(x_size).to(x.device)
403
+ )
404
+
405
+ # merge windows
406
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
407
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
408
+
409
+ # reverse cyclic shift
410
+ if self.shift_size > 0:
411
+ x = torch.roll(
412
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
413
+ )
414
+ else:
415
+ x = shifted_x
416
+ x = x.view(B, H * W, C)
417
+ x = shortcut + self.drop_path(self.norm1(x))
418
+
419
+ # FFN
420
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
421
+
422
+ return x
423
+
424
+ def extra_repr(self) -> str:
425
+ return (
426
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
427
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
428
+ )
429
+
430
+ def flops(self):
431
+ flops = 0
432
+ H, W = self.input_resolution
433
+ # norm1
434
+ flops += self.dim * H * W
435
+ # W-MSA/SW-MSA
436
+ nW = H * W / self.window_size / self.window_size
437
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
438
+ # mlp
439
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
440
+ # norm2
441
+ flops += self.dim * H * W
442
+ return flops
443
+
444
+
445
+ class PatchMerging(nn.Module):
446
+ r"""Patch Merging Layer.
447
+ Args:
448
+ input_resolution (tuple[int]): Resolution of input feature.
449
+ dim (int): Number of input channels.
450
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
451
+ """
452
+
453
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
454
+ super().__init__()
455
+ self.input_resolution = input_resolution
456
+ self.dim = dim
457
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
458
+ self.norm = norm_layer(2 * dim)
459
+
460
+ def forward(self, x):
461
+ """
462
+ x: B, H*W, C
463
+ """
464
+ H, W = self.input_resolution
465
+ B, L, C = x.shape
466
+ assert L == H * W, "input feature has wrong size"
467
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
468
+
469
+ x = x.view(B, H, W, C)
470
+
471
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
472
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
473
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
474
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
475
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
476
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
477
+
478
+ x = self.reduction(x)
479
+ x = self.norm(x)
480
+
481
+ return x
482
+
483
+ def extra_repr(self) -> str:
484
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
485
+
486
+ def flops(self):
487
+ H, W = self.input_resolution
488
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
489
+ flops += H * W * self.dim // 2
490
+ return flops
491
+
492
+
493
+ class BasicLayer(nn.Module):
494
+ """A basic Swin Transformer layer for one stage.
495
+ Args:
496
+ dim (int): Number of input channels.
497
+ input_resolution (tuple[int]): Input resolution.
498
+ depth (int): Number of blocks.
499
+ num_heads (int): Number of attention heads.
500
+ window_size (int): Local window size.
501
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
502
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
503
+ drop (float, optional): Dropout rate. Default: 0.0
504
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
505
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
506
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
507
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
508
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
509
+ pretrained_window_size (int): Local window size in pre-training.
510
+ """
511
+
512
+ def __init__(
513
+ self,
514
+ dim,
515
+ input_resolution,
516
+ depth,
517
+ num_heads,
518
+ window_size,
519
+ mlp_ratio=4.0,
520
+ qkv_bias=True,
521
+ drop=0.0,
522
+ attn_drop=0.0,
523
+ drop_path=0.0,
524
+ norm_layer=nn.LayerNorm,
525
+ downsample=None,
526
+ use_checkpoint=False,
527
+ pretrained_window_size=0,
528
+ ):
529
+ super().__init__()
530
+ self.dim = dim
531
+ self.input_resolution = input_resolution
532
+ self.depth = depth
533
+ self.use_checkpoint = use_checkpoint
534
+
535
+ # build blocks
536
+ self.blocks = nn.ModuleList(
537
+ [
538
+ SwinTransformerBlock(
539
+ dim=dim,
540
+ input_resolution=input_resolution,
541
+ num_heads=num_heads,
542
+ window_size=window_size,
543
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
544
+ mlp_ratio=mlp_ratio,
545
+ qkv_bias=qkv_bias,
546
+ drop=drop,
547
+ attn_drop=attn_drop,
548
+ drop_path=(
549
+ drop_path[i] if isinstance(drop_path, list) else drop_path
550
+ ),
551
+ norm_layer=norm_layer,
552
+ pretrained_window_size=pretrained_window_size,
553
+ )
554
+ for i in range(depth)
555
+ ]
556
+ )
557
+
558
+ # patch merging layer
559
+ if downsample is not None:
560
+ self.downsample = downsample(
561
+ input_resolution, dim=dim, norm_layer=norm_layer
562
+ )
563
+ else:
564
+ self.downsample = None
565
+
566
+ def forward(self, x, x_size):
567
+ for blk in self.blocks:
568
+ if self.use_checkpoint:
569
+ x = checkpoint.checkpoint(blk, x, x_size)
570
+ else:
571
+ x = blk(x, x_size)
572
+ if self.downsample is not None:
573
+ x = self.downsample(x)
574
+ return x
575
+
576
+ def extra_repr(self) -> str:
577
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
578
+
579
+ def flops(self):
580
+ flops = 0
581
+ for blk in self.blocks:
582
+ flops += blk.flops()
583
+ if self.downsample is not None:
584
+ flops += self.downsample.flops()
585
+ return flops
586
+
587
+ def _init_respostnorm(self):
588
+ for blk in self.blocks:
589
+ nn.init.constant_(blk.norm1.bias, 0)
590
+ nn.init.constant_(blk.norm1.weight, 0)
591
+ nn.init.constant_(blk.norm2.bias, 0)
592
+ nn.init.constant_(blk.norm2.weight, 0)
593
+
594
+
595
+ class PatchEmbed(nn.Module):
596
+ r"""Image to Patch Embedding
597
+ Args:
598
+ img_size (int): Image size. Default: 224.
599
+ patch_size (int): Patch token size. Default: 4.
600
+ in_chans (int): Number of input image channels. Default: 3.
601
+ embed_dim (int): Number of linear projection output channels. Default: 96.
602
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
603
+ """
604
+
605
+ def __init__(
606
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
607
+ ):
608
+ super().__init__()
609
+ img_size = to_2tuple(img_size)
610
+ patch_size = to_2tuple(patch_size)
611
+ patches_resolution = [
612
+ img_size[0] // patch_size[0],
613
+ img_size[1] // patch_size[1],
614
+ ]
615
+ self.img_size = img_size
616
+ self.patch_size = patch_size
617
+ self.patches_resolution = patches_resolution
618
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
619
+
620
+ self.in_chans = in_chans
621
+ self.embed_dim = embed_dim
622
+
623
+ self.proj = nn.Conv2d(
624
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
625
+ )
626
+ if norm_layer is not None:
627
+ self.norm = norm_layer(embed_dim)
628
+ else:
629
+ self.norm = None
630
+
631
+ def forward(self, x):
632
+ B, C, H, W = x.shape
633
+ # FIXME look at relaxing size constraints
634
+ # assert H == self.img_size[0] and W == self.img_size[1],
635
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
636
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
637
+ if self.norm is not None:
638
+ x = self.norm(x)
639
+ return x
640
+
641
+ def flops(self):
642
+ Ho, Wo = self.patches_resolution
643
+ flops = (
644
+ Ho
645
+ * Wo
646
+ * self.embed_dim
647
+ * self.in_chans
648
+ * (self.patch_size[0] * self.patch_size[1])
649
+ )
650
+ if self.norm is not None:
651
+ flops += Ho * Wo * self.embed_dim
652
+ return flops
653
+
654
+
655
+ class RSTB(nn.Module):
656
+ """Residual Swin Transformer Block (RSTB).
657
+
658
+ Args:
659
+ dim (int): Number of input channels.
660
+ input_resolution (tuple[int]): Input resolution.
661
+ depth (int): Number of blocks.
662
+ num_heads (int): Number of attention heads.
663
+ window_size (int): Local window size.
664
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
665
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
666
+ drop (float, optional): Dropout rate. Default: 0.0
667
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
668
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
669
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
670
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
671
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
672
+ img_size: Input image size.
673
+ patch_size: Patch size.
674
+ resi_connection: The convolutional block before residual connection.
675
+ """
676
+
677
+ def __init__(
678
+ self,
679
+ dim,
680
+ input_resolution,
681
+ depth,
682
+ num_heads,
683
+ window_size,
684
+ mlp_ratio=4.0,
685
+ qkv_bias=True,
686
+ drop=0.0,
687
+ attn_drop=0.0,
688
+ drop_path=0.0,
689
+ norm_layer=nn.LayerNorm,
690
+ downsample=None,
691
+ use_checkpoint=False,
692
+ img_size=224,
693
+ patch_size=4,
694
+ resi_connection="1conv",
695
+ ):
696
+ super(RSTB, self).__init__()
697
+
698
+ self.dim = dim
699
+ self.input_resolution = input_resolution
700
+
701
+ self.residual_group = BasicLayer(
702
+ dim=dim,
703
+ input_resolution=input_resolution,
704
+ depth=depth,
705
+ num_heads=num_heads,
706
+ window_size=window_size,
707
+ mlp_ratio=mlp_ratio,
708
+ qkv_bias=qkv_bias,
709
+ drop=drop,
710
+ attn_drop=attn_drop,
711
+ drop_path=drop_path,
712
+ norm_layer=norm_layer,
713
+ downsample=downsample,
714
+ use_checkpoint=use_checkpoint,
715
+ )
716
+
717
+ if resi_connection == "1conv":
718
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
719
+ elif resi_connection == "3conv":
720
+ # to save parameters and memory
721
+ self.conv = nn.Sequential(
722
+ nn.Conv2d(dim, dim // 4, 3, 1, 1),
723
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
724
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
725
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
726
+ nn.Conv2d(dim // 4, dim, 3, 1, 1),
727
+ )
728
+
729
+ self.patch_embed = PatchEmbed(
730
+ img_size=img_size,
731
+ patch_size=patch_size,
732
+ in_chans=dim,
733
+ embed_dim=dim,
734
+ norm_layer=None,
735
+ )
736
+
737
+ self.patch_unembed = PatchUnEmbed(
738
+ img_size=img_size,
739
+ patch_size=patch_size,
740
+ in_chans=dim,
741
+ embed_dim=dim,
742
+ norm_layer=None,
743
+ )
744
+
745
+ def forward(self, x, x_size):
746
+ return (
747
+ self.patch_embed(
748
+ self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
749
+ )
750
+ + x
751
+ )
752
+
753
+ def flops(self):
754
+ flops = 0
755
+ flops += self.residual_group.flops()
756
+ H, W = self.input_resolution
757
+ flops += H * W * self.dim * self.dim * 9
758
+ flops += self.patch_embed.flops()
759
+ flops += self.patch_unembed.flops()
760
+
761
+ return flops
762
+
763
+
764
+ class PatchUnEmbed(nn.Module):
765
+ r"""Image to Patch Unembedding
766
+
767
+ Args:
768
+ img_size (int): Image size. Default: 224.
769
+ patch_size (int): Patch token size. Default: 4.
770
+ in_chans (int): Number of input image channels. Default: 3.
771
+ embed_dim (int): Number of linear projection output channels. Default: 96.
772
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
773
+ """
774
+
775
+ def __init__(
776
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
777
+ ):
778
+ super().__init__()
779
+ img_size = to_2tuple(img_size)
780
+ patch_size = to_2tuple(patch_size)
781
+ patches_resolution = [
782
+ img_size[0] // patch_size[0],
783
+ img_size[1] // patch_size[1],
784
+ ]
785
+ self.img_size = img_size
786
+ self.patch_size = patch_size
787
+ self.patches_resolution = patches_resolution
788
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
789
+
790
+ self.in_chans = in_chans
791
+ self.embed_dim = embed_dim
792
+
793
+ def forward(self, x, x_size):
794
+ B, HW, C = x.shape
795
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B C Ph Pw
796
+ return x
797
+
798
+ def flops(self):
799
+ flops = 0
800
+ return flops
801
+
802
+
803
+ class Upsample(nn.Sequential):
804
+ """Upsample module.
805
+
806
+ Args:
807
+ scale (int): Scale factor. Supported scales: 2^n and 3.
808
+ num_feat (int): Channel number of intermediate features.
809
+ """
810
+
811
+ def __init__(self, scale, num_feat):
812
+ m = []
813
+ if (scale & (scale - 1)) == 0: # scale = 2^n
814
+ for _ in range(int(math.log(scale, 2))):
815
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
816
+ m.append(nn.PixelShuffle(2))
817
+ elif scale == 3:
818
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
819
+ m.append(nn.PixelShuffle(3))
820
+ else:
821
+ raise ValueError(
822
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
823
+ )
824
+ super(Upsample, self).__init__(*m)
825
+
826
+
827
+ class Upsample_hf(nn.Sequential):
828
+ """Upsample module.
829
+
830
+ Args:
831
+ scale (int): Scale factor. Supported scales: 2^n and 3.
832
+ num_feat (int): Channel number of intermediate features.
833
+ """
834
+
835
+ def __init__(self, scale, num_feat):
836
+ m = []
837
+ if (scale & (scale - 1)) == 0: # scale = 2^n
838
+ for _ in range(int(math.log(scale, 2))):
839
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
840
+ m.append(nn.PixelShuffle(2))
841
+ elif scale == 3:
842
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
843
+ m.append(nn.PixelShuffle(3))
844
+ else:
845
+ raise ValueError(
846
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
847
+ )
848
+ super(Upsample_hf, self).__init__(*m)
849
+
850
+
851
+ class UpsampleOneStep(nn.Sequential):
852
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
853
+ Used in lightweight SR to save parameters.
854
+
855
+ Args:
856
+ scale (int): Scale factor. Supported scales: 2^n and 3.
857
+ num_feat (int): Channel number of intermediate features.
858
+
859
+ """
860
+
861
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
862
+ self.num_feat = num_feat
863
+ self.input_resolution = input_resolution
864
+ m = []
865
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
866
+ m.append(nn.PixelShuffle(scale))
867
+ super(UpsampleOneStep, self).__init__(*m)
868
+
869
+ def flops(self):
870
+ H, W = self.input_resolution
871
+ flops = H * W * self.num_feat * 3 * 9
872
+ return flops
873
+
874
+
875
+ class Swin2SR(nn.Module):
876
+ r"""Swin2SR
877
+ A PyTorch implementation of `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
878
+
879
+ Args:
880
+ img_size (int | tuple(int)): Input image size. Default 64
881
+ patch_size (int | tuple(int)): Patch size. Default: 1
882
+ in_chans (int): Number of input image channels. Default: 3
883
+ embed_dim (int): Patch embedding dimension. Default: 96
884
+ depths (tuple(int)): Depth of each Swin Transformer layer.
885
+ num_heads (tuple(int)): Number of attention heads in different layers.
886
+ window_size (int): Window size. Default: 7
887
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
888
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
889
+ drop_rate (float): Dropout rate. Default: 0
890
+ attn_drop_rate (float): Attention dropout rate. Default: 0
891
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
892
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
893
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
894
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
895
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
896
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
897
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
898
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
899
+ """
900
+
901
+ def __init__(
902
+ self,
903
+ img_size=64,
904
+ patch_size=1,
905
+ in_channels=3,
906
+ out_channels=3,
907
+ embed_dim=96,
908
+ depths=[6, 6, 6, 6],
909
+ num_heads=[6, 6, 6, 6],
910
+ window_size=7,
911
+ mlp_ratio=4.0,
912
+ qkv_bias=True,
913
+ drop_rate=0.0,
914
+ attn_drop_rate=0.0,
915
+ drop_path_rate=0.1,
916
+ norm_layer=nn.LayerNorm,
917
+ ape=False,
918
+ patch_norm=True,
919
+ use_checkpoint=False,
920
+ upscale=2,
921
+ upsampler="",
922
+ resi_connection="1conv",
923
+ **kwargs,
924
+ ):
925
+ super(Swin2SR, self).__init__()
926
+ num_in_ch = in_channels
927
+ num_out_ch = out_channels
928
+ num_feat = 64
929
+ self.upscale = upscale
930
+ self.upsampler = upsampler
931
+ self.window_size = window_size
932
+
933
+ #####################################################################################################
934
+ ################################### 1, shallow feature extraction ###################################
935
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
936
+
937
+ #####################################################################################################
938
+ ################################### 2, deep feature extraction ######################################
939
+ self.num_layers = len(depths)
940
+ self.embed_dim = embed_dim
941
+ self.ape = ape
942
+ self.patch_norm = patch_norm
943
+ self.num_features = embed_dim
944
+ self.mlp_ratio = mlp_ratio
945
+
946
+ # split image into non-overlapping patches
947
+ self.patch_embed = PatchEmbed(
948
+ img_size=img_size,
949
+ patch_size=patch_size,
950
+ in_chans=embed_dim,
951
+ embed_dim=embed_dim,
952
+ norm_layer=norm_layer if self.patch_norm else None,
953
+ )
954
+ num_patches = self.patch_embed.num_patches
955
+ patches_resolution = self.patch_embed.patches_resolution
956
+ self.patches_resolution = patches_resolution
957
+
958
+ # merge non-overlapping patches into image
959
+ self.patch_unembed = PatchUnEmbed(
960
+ img_size=img_size,
961
+ patch_size=patch_size,
962
+ in_chans=embed_dim,
963
+ embed_dim=embed_dim,
964
+ norm_layer=norm_layer if self.patch_norm else None,
965
+ )
966
+
967
+ # absolute position embedding
968
+ if self.ape:
969
+ self.absolute_pos_embed = nn.Parameter(
970
+ torch.zeros(1, num_patches, embed_dim)
971
+ )
972
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
973
+
974
+ self.pos_drop = nn.Dropout(p=drop_rate)
975
+
976
+ # stochastic depth
977
+ dpr = [
978
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
979
+ ] # stochastic depth decay rule
980
+
981
+ # build Residual Swin Transformer blocks (RSTB)
982
+ self.layers = nn.ModuleList()
983
+ for i_layer in range(self.num_layers):
984
+ layer = RSTB(
985
+ dim=embed_dim,
986
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
987
+ depth=depths[i_layer],
988
+ num_heads=num_heads[i_layer],
989
+ window_size=window_size,
990
+ mlp_ratio=self.mlp_ratio,
991
+ qkv_bias=qkv_bias,
992
+ drop=drop_rate,
993
+ attn_drop=attn_drop_rate,
994
+ drop_path=dpr[
995
+ sum(depths[:i_layer]) : sum(depths[: i_layer + 1])
996
+ ], # no impact on SR results
997
+ norm_layer=norm_layer,
998
+ downsample=None,
999
+ use_checkpoint=use_checkpoint,
1000
+ img_size=img_size,
1001
+ patch_size=patch_size,
1002
+ resi_connection=resi_connection,
1003
+ )
1004
+ self.layers.append(layer)
1005
+
1006
+ if self.upsampler == "pixelshuffle_hf":
1007
+ self.layers_hf = nn.ModuleList()
1008
+ for i_layer in range(self.num_layers):
1009
+ layer = RSTB(
1010
+ dim=embed_dim,
1011
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1012
+ depth=depths[i_layer],
1013
+ num_heads=num_heads[i_layer],
1014
+ window_size=window_size,
1015
+ mlp_ratio=self.mlp_ratio,
1016
+ qkv_bias=qkv_bias,
1017
+ drop=drop_rate,
1018
+ attn_drop=attn_drop_rate,
1019
+ drop_path=dpr[
1020
+ sum(depths[:i_layer]) : sum(depths[: i_layer + 1])
1021
+ ], # no impact on SR results
1022
+ norm_layer=norm_layer,
1023
+ downsample=None,
1024
+ use_checkpoint=use_checkpoint,
1025
+ img_size=img_size,
1026
+ patch_size=patch_size,
1027
+ resi_connection=resi_connection,
1028
+ )
1029
+ self.layers_hf.append(layer)
1030
+
1031
+ self.norm = norm_layer(self.num_features)
1032
+
1033
+ # build the last conv layer in deep feature extraction
1034
+ if resi_connection == "1conv":
1035
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1036
+ elif resi_connection == "3conv":
1037
+ # to save parameters and memory
1038
+ self.conv_after_body = nn.Sequential(
1039
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
1040
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1041
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
1042
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1043
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
1044
+ )
1045
+
1046
+ #####################################################################################################
1047
+ ################################ 3, high quality image reconstruction ################################
1048
+ if self.upsampler == "pixelshuffle":
1049
+ # for classical SR
1050
+ self.conv_before_upsample = nn.Sequential(
1051
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1052
+ )
1053
+ self.upsample = Upsample(upscale, num_feat)
1054
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1055
+ elif self.upsampler == "pixelshuffle_aux":
1056
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
1057
+ self.conv_before_upsample = nn.Sequential(
1058
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1059
+ )
1060
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1061
+ self.conv_after_aux = nn.Sequential(
1062
+ nn.Conv2d(3, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1063
+ )
1064
+ self.upsample = Upsample(upscale, num_feat)
1065
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1066
+
1067
+ elif self.upsampler == "pixelshuffle_hf":
1068
+ self.conv_before_upsample = nn.Sequential(
1069
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1070
+ )
1071
+ self.upsample = Upsample(upscale, num_feat)
1072
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
1073
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1074
+ self.conv_first_hf = nn.Sequential(
1075
+ nn.Conv2d(num_feat, embed_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)
1076
+ )
1077
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1078
+ self.conv_before_upsample_hf = nn.Sequential(
1079
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1080
+ )
1081
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1082
+
1083
+ elif self.upsampler == "pixelshuffledirect":
1084
+ # for lightweight SR (to save parameters)
1085
+ self.upsample = UpsampleOneStep(
1086
+ upscale,
1087
+ embed_dim,
1088
+ num_out_ch,
1089
+ (patches_resolution[0], patches_resolution[1]),
1090
+ )
1091
+ elif self.upsampler == "nearest+conv":
1092
+ # for real-world SR (less artifacts)
1093
+ assert self.upscale == 4, "only support x4 now."
1094
+ self.conv_before_upsample = nn.Sequential(
1095
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1096
+ )
1097
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1098
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1099
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1100
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1101
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1102
+ else:
1103
+ # for image denoising and JPEG compression artifact reduction
1104
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
1105
+
1106
+ self.apply(self._init_weights)
1107
+
1108
+ def _init_weights(self, m):
1109
+ if isinstance(m, nn.Linear):
1110
+ trunc_normal_(m.weight, std=0.02)
1111
+ if isinstance(m, nn.Linear) and m.bias is not None:
1112
+ nn.init.constant_(m.bias, 0)
1113
+ elif isinstance(m, nn.LayerNorm):
1114
+ nn.init.constant_(m.bias, 0)
1115
+ nn.init.constant_(m.weight, 1.0)
1116
+
1117
+ @torch.jit.ignore
1118
+ def no_weight_decay(self):
1119
+ return {"absolute_pos_embed"}
1120
+
1121
+ @torch.jit.ignore
1122
+ def no_weight_decay_keywords(self):
1123
+ return {"relative_position_bias_table"}
1124
+
1125
+ def check_image_size(self, x):
1126
+ _, _, h, w = x.size()
1127
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
1128
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
1129
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
1130
+ return x
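+
+ # check_image_size reflect-pads H and W up to the next multiple of window_size;
+ # forward() below crops the output back to H*upscale x W*upscale, so arbitrary
+ # input sizes are supported.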
1131
+
1132
+ def forward_features(self, x):
1133
+ x_size = (x.shape[2], x.shape[3])
1134
+ x = self.patch_embed(x)
1135
+ if self.ape:
1136
+ x = x + self.absolute_pos_embed
1137
+ x = self.pos_drop(x)
1138
+
1139
+ for layer in self.layers:
1140
+ x = layer(x, x_size)
1141
+
1142
+ x = self.norm(x) # B L C
1143
+ x = self.patch_unembed(x, x_size)
1144
+
1145
+ return x
1146
+
1147
+ def forward_features_hf(self, x):
1148
+ x_size = (x.shape[2], x.shape[3])
1149
+ x = self.patch_embed(x)
1150
+ if self.ape:
1151
+ x = x + self.absolute_pos_embed
1152
+ x = self.pos_drop(x)
1153
+
1154
+ for layer in self.layers_hf:
1155
+ x = layer(x, x_size)
1156
+
1157
+ x = self.norm(x) # B L C
1158
+ x = self.patch_unembed(x, x_size)
1159
+
1160
+ return x
1161
+
1162
+ def forward(self, x):
1163
+ H, W = x.shape[2:]
1164
+ x = self.check_image_size(x)
1165
+
1166
+ if self.upsampler == "pixelshuffle":
1167
+ # for classical SR
1168
+ x = self.conv_first(x)
1169
+ x = self.conv_after_body(self.forward_features(x)) + x
1170
+ x = self.conv_before_upsample(x)
1171
+ x = self.conv_last(self.upsample(x))
1172
+ elif self.upsampler == "pixelshuffle_aux":
1173
+ bicubic = F.interpolate(
1174
+ x,
1175
+ size=(H * self.upscale, W * self.upscale),
1176
+ mode="bicubic",
1177
+ align_corners=False,
1178
+ )
1179
+ bicubic = self.conv_bicubic(bicubic)
1180
+ x = self.conv_first(x)
1181
+ x = self.conv_after_body(self.forward_features(x)) + x
1182
+ x = self.conv_before_upsample(x)
1183
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
1184
+ x = self.conv_after_aux(aux)
1185
+ x = (
1186
+ self.upsample(x)[:, :, : H * self.upscale, : W * self.upscale]
1187
+ + bicubic[:, :, : H * self.upscale, : W * self.upscale]
1188
+ )
1189
+ x = self.conv_last(x)
1190
+ elif self.upsampler == "pixelshuffle_hf":
1191
+ # for classical SR with HF
1192
+ x = self.conv_first(x)
1193
+ x = self.conv_after_body(self.forward_features(x)) + x
1194
+ x_before = self.conv_before_upsample(x)
1195
+ x_out = self.conv_last(self.upsample(x_before))
1196
+
1197
+ x_hf = self.conv_first_hf(x_before)
1198
+ x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
1199
+ x_hf = self.conv_before_upsample_hf(x_hf)
1200
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
1201
+ x = x_out + x_hf
1202
+
1203
+ elif self.upsampler == "pixelshuffledirect":
1204
+ # for lightweight SR
1205
+ x = self.conv_first(x)
1206
+ x = self.conv_after_body(self.forward_features(x)) + x
1207
+ x = self.upsample(x)
1208
+ elif self.upsampler == "nearest+conv":
1209
+ # for real-world SR
1210
+ x = self.conv_first(x)
1211
+ x = self.conv_after_body(self.forward_features(x)) + x
1212
+ x = self.conv_before_upsample(x)
1213
+ x = self.lrelu(
1214
+ self.conv_up1(
1215
+ torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
1216
+ )
1217
+ )
1218
+ x = self.lrelu(
1219
+ self.conv_up2(
1220
+ torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
1221
+ )
1222
+ )
1223
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
1224
+ else:
1225
+ # for image denoising and JPEG compression artifact reduction
1226
+ x_first = self.conv_first(x)
1227
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
1228
+ x = x + self.conv_last(res)
1229
+
1230
+ if self.upsampler == "pixelshuffle_aux":
1231
+ return x[:, :, : H * self.upscale, : W * self.upscale], aux
1232
+
1233
+ elif self.upsampler == "pixelshuffle_hf":
1234
+ return (
1235
+ x_out[:, :, : H * self.upscale, : W * self.upscale],
1236
+ x[:, :, : H * self.upscale, : W * self.upscale],
1237
+ x_hf[:, :, : H * self.upscale, : W * self.upscale],
1238
+ )
1239
+
1240
+ else:
1241
+ return x[:, :, : H * self.upscale, : W * self.upscale]
1242
+
1243
+ def flops(self):
1244
+ flops = 0
1245
+ H, W = self.patches_resolution
1246
+ flops += H * W * 3 * self.embed_dim * 9
1247
+ flops += self.patch_embed.flops()
1248
+ for i, layer in enumerate(self.layers):
1249
+ flops += layer.flops()
1250
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
1251
+ flops += self.upsample.flops()
1252
+ return flops
1253
+
1254
+ def butterworth_filter(shape: tuple[int, int], cutoff: int, order: int) -> torch.Tensor:
1255
+ """
1256
+ Creates a Butterworth low-pass filter.
1257
+
1258
+ Args:
1259
+ shape: (rows, cols) of the filter.
1260
+ cutoff: Cutoff frequency.
1261
+ order: Order of the Butterworth filter.
1262
+
1263
+ Returns:
1264
+ torch.Tensor: Normalized Butterworth filter.
1265
+ """
1266
+ rows, cols = shape
1267
+ crow, ccol = rows // 2, cols // 2
1268
+ filter = torch.zeros((rows, cols), dtype=torch.float32)
1269
+ for u in range(rows):
1270
+ for v in range(cols):
1271
+ distance = ((u - crow) ** 2 + (v - ccol) ** 2) ** 0.5
1272
+ filter[u, v] = 1 / (1 + (distance / cutoff) ** (2 * order))
1273
+ filter /= filter.sum()
1274
+ return filter
1275
+
1276
+
1277
+ class CNNHardConstraint(nn.Module):
1278
+ """
1279
+ Applies a convolutional hard constraint using a fixed Butterworth low-pass filter to split low- and high-frequency content.
1280
+
1281
+ Args:
1282
1284
+ scale_factor: Scaling factor used to determine kernel size and cutoff frequency.
1285
+ in_channels: Number of input channels.
1286
+ out_channels: List of channels to be processed (default is [0, 1, 2, 3, 4, 5]).
1287
+ """
1288
+
1289
+ def __init__(
1290
+ self,
1291
+ scale_factor: int,
1292
+ in_channels: int,
1293
+ out_channels: list = [0, 1, 2, 3, 4, 5],
1294
+ ):
1295
+ super().__init__()
1296
+
1297
+ self.in_channels = in_channels
1298
+
1299
+ # Estimate the kernel according to the scale
1300
+ kernel_size = scale_factor * 3 + 1
1301
+ cutoff = scale_factor * 2
1302
+
1303
+ # Define the convolution layer with multiple input and output channels
1304
+ self.conv = nn.Conv2d(
1305
+ in_channels=in_channels,
1306
+ out_channels=len(out_channels),
1307
+ kernel_size=kernel_size,
1308
+ padding=kernel_size // 2,
1309
+ bias=False,
1310
+ groups=in_channels,
1311
+ )
1312
+
1313
+ # Remove the gradient for the filter weights
1314
+ self.conv.weight.requires_grad = False
1315
+
1316
+ # Initialize the filter kernel based on the filter method
1317
+ # hyperparameters["order"] = 6
1318
+ weight_data = butterworth_filter((kernel_size, kernel_size), cutoff, 6)
1319
+
1320
+ # Apply the same filter to all input channels
1321
+ self.conv.weight.data = (
1322
+ weight_data.unsqueeze(0).unsqueeze(0).repeat(in_channels, 1, 1, 1)
1323
+ )
1324
+ self.out_channels = out_channels
1325
+
1326
+ def forward(self, lr: torch.Tensor, sr: torch.Tensor) -> torch.Tensor:
1327
+ """
1328
+ Applies the filter constraint on the super-resolution image.
1329
+
1330
+ Args:
1331
+ lr: Low-resolution input tensor.
1332
+ sr: Super-resolution output tensor.
1333
+
1334
+ Returns:
1335
+ torch.Tensor: The resulting hybrid image after applying the constraint.
1336
+ """
1337
+ # Select the LR channels to be constrained
1338
+ lr = lr[:, self.out_channels]
1339
+
1340
+ # Upsample the LR image to the size of SR
1341
+ lr_up = F.interpolate(lr, size=sr.shape[-2:], mode="bicubic", antialias=True)
1342
+
1343
+ # Apply the convolutional filter to both LR and SR images
1344
+ lr_filtered = self.conv(lr_up)
1345
+ sr_filtered = self.conv(sr)
1346
+
1347
+ # Combine low-pass and high-pass components
1348
+ hybrid_image = lr_filtered + (sr - sr_filtered)
1349
+
1350
+ return hybrid_image
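+
+ # In short: hybrid = lowpass(upsample(lr)) + (sr - lowpass(sr)); the output keeps the
+ # low-frequency content of the bicubically upsampled LR input and takes only its
+ # high-frequency detail from the network prediction.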
1351
+
1352
+
1353
+ class HardConstraintModel(torch.nn.Module):
1354
+ def __init__(self) -> None:
1355
+ super().__init__()
1356
+ params = {
1357
+ "img_size": (128, 128),
1358
+ "in_channels": 4,
1359
+ "out_channels": 4,
1360
+ "embed_dim": 72,
1361
+ "depths": [4, 4, 4, 4],
1362
+ "num_heads": [4, 4, 4, 4],
1363
+ "window_size": 4,
1364
+ "mlp_ratio": 2.0,
1365
+ "upscale": 4,
1366
+ "resi_connection": "1conv",
1367
+ "upsampler": "pixelshuffledirect",
1368
+ }
1369
+ self.sr_model = Swin2SR(**params)
1370
+ self.hard_constraint = CNNHardConstraint(8, 4, [0, 1, 2, 3])
1371
+
1372
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1373
+ sr = self.sr_model(x)
1374
+ return self.hard_constraint(x, sr)
1375
+
1376
+
1377
+
1378
+
1379
+ # MLSTAC API -----------------------------------------------------------------------
1380
+ def example_data(path: pathlib.Path, *args, **kwargs) -> torch.Tensor:
1381
+ data_f = safetensors.torch.load_file(path / "example_data.safetensor")
1382
+ return data_f["example_data"][[3, 2, 1, 7], 128:384, 128:384][None]
1383
+
1384
+ def trainable_model(path, *args, **kwargs):
1385
+ trainable_f = path / "model.safetensor"
1386
+
1387
+ # Load model parameters
1388
+ weights = safetensors.torch.load_file(trainable_f)
1389
+
1390
+ # Load model
1391
+ srmodel = HardConstraintModel()
1392
+ srmodel.load_state_dict(weights)
1393
+
1394
+ return srmodel
1395
+
1396
+ def compiled_model(path, *args, **kwargs):
1397
+ trainable_f = path / "model.safetensor"
1398
+
1399
+ # Load model parameters
1400
+ weights = safetensors.torch.load_file(trainable_f)
1401
+
1402
+ # Load model
1403
+ srmodel = HardConstraintModel()
1404
+ srmodel.load_state_dict(weights)
1405
+ srmodel.eval()
1406
+
1407
+ for param in srmodel.parameters():
1408
+ param.requires_grad = False
1409
+ return srmodel
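
A minimal usage sketch for the MLSTAC API above, assuming the three Swin_Light_SR files have been downloaded into a local folder and that load.py is importable from there (the download step is not shown, and dtype/device handling may need adjusting):

    import pathlib
    import torch
    import load  # Swin_Light_SR/load.py

    path = pathlib.Path("Swin_Light_SR")   # assumed local copy of this folder
    model = load.compiled_model(path)      # frozen HardConstraintModel in eval mode
    x = load.example_data(path).float()    # (1, 4, 256, 256) crop of the example scene
    with torch.no_grad():
        sr = model(x)                      # (1, 4, 1024, 1024): x4 super-resolution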
Swin_Light_SR/mlm.json ADDED
@@ -0,0 +1,162 @@
1
+ {
2
+ "type": "Feature",
3
+ "stac_version": "1.1.0",
4
+ "stac_extensions": [
5
+ "https://stac-extensions.github.io/mlm/v1.4.0/schema.json"
6
+ ],
7
+ "id": "SwinSRv1 model",
8
+ "geometry": {
9
+ "type": "Polygon",
10
+ "coordinates": [
11
+ [
12
+ [
13
+ -180.0,
14
+ -90.0
15
+ ],
16
+ [
17
+ -180.0,
18
+ 90.0
19
+ ],
20
+ [
21
+ 180.0,
22
+ 90.0
23
+ ],
24
+ [
25
+ 180.0,
26
+ -90.0
27
+ ],
28
+ [
29
+ -180.0,
30
+ -90.0
31
+ ]
32
+ ]
33
+ ]
34
+ },
35
+ "bbox": [
36
+ -180,
37
+ -90,
38
+ 180,
39
+ 90
40
+ ],
41
+ "properties": {
42
+ "start_datetime": "1900-01-01T00:00:00Z",
43
+ "end_datetime": "9999-01-01T00:00:00Z",
44
+ "description": "A SwinSRv1 model trained on the SEN2NAIPv2 dataset.",
45
+ "forward_backward_pass": {
46
+ "32": 68.051456,
47
+ "64": 262.496768,
48
+ "128": 1040.278016,
49
+ "256": 4151.403008,
50
+ "512": 16595.902976
51
+ },
52
+ "dependencies": [
53
+ "torch",
54
+ "safetensors.torch",
55
+ "timm",
56
+ "einops"
57
+ ],
58
+ "mlm:framework": "pytorch",
59
+ "mlm:framework_version": "2.1.2+cu121",
60
+ "file:size": 3142752,
61
+ "mlm:memory_size": 1,
62
+ "mlm:accelerator": "cuda",
63
+ "mlm:accelerator_constrained": false,
64
+ "mlm:accelerator_summary": "Unknown",
65
+ "mlm:name": "Swin_Light_SR",
66
+ "mlm:architecture": "SwinSRv1",
67
+ "mlm:tasks": [
68
+ "super-resolution"
69
+ ],
70
+ "mlm:input": [
71
+ {
72
+ "name": "4 Band Sentinel-2 10m bands",
73
+ "bands": [
74
+ "B04",
75
+ "B03",
76
+ "B02",
77
+ "B08"
78
+ ],
79
+ "input": {
80
+ "shape": [
81
+ -1,
82
+ 4,
83
+ 128,
84
+ 128
85
+ ],
86
+ "dim_order": [
87
+ "batch",
88
+ "channel",
89
+ "height",
90
+ "width"
91
+ ],
92
+ "data_type": "float16"
93
+ },
94
+ "pre_processing_function": null
95
+ }
96
+ ],
97
+ "mlm:output": [
98
+ {
99
+ "name": "super-resolution",
100
+ "tasks": [
101
+ "super-resolution"
102
+ ],
103
+ "result": {
104
+ "shape": [
105
+ -1,
106
+ 4,
107
+ 512,
108
+ 512
109
+ ],
110
+ "dim_order": [
111
+ "batch",
112
+ "channel",
113
+ "height",
114
+ "width"
115
+ ],
116
+ "data_type": "float16"
117
+ },
118
+ "classification:classes": [],
119
+ "post_processing_function": null
120
+ }
121
+ ],
122
+ "mlm:total_parameters": 1036888,
123
+ "mlm:pretrained": true,
124
+ "datetime": null
125
+ },
126
+ "links": [],
127
+ "assets": {
128
+ "trainable": {
129
+ "href": "https://huggingface.co/tacofoundation/mlstac/resolve/main/Swin_Light_SR/model.safetensor",
130
+ "type": "application/octet-stream; application=safetensor",
131
+ "title": "Pytorch weights checkpoint",
132
+ "description": "A SwinSRv1 model trained on the SEN2NAIPv2 dataset.",
133
+ "mlm:artifact_type": "safetensor.torch.save_file",
134
+ "roles": [
135
+ "mlm:model",
136
+ "mlm:weights",
137
+ "data"
138
+ ]
139
+ },
140
+ "source_code": {
141
+ "href": "https://huggingface.co/tacofoundation/mlstac/resolve/main/Swin_Light_SR/load.py",
142
+ "type": "text/x-python",
143
+ "title": "Model load script",
144
+ "description": "Source code to run the model.",
145
+ "roles": [
146
+ "mlm:source_code",
147
+ "code"
148
+ ]
149
+ },
150
+ "example_data": {
151
+ "href": "https://huggingface.co/tacofoundation/mlstac/resolve/main/Swin_Light_SR/example_data.safetensor",
152
+ "type": "application/octet-stream; application=safetensors",
153
+ "title": "Example Sentinel-2 image",
154
+ "description": "Example Sentinel-2 image for model inference.",
155
+ "roles": [
156
+ "mlm:example_data",
157
+ "data"
158
+ ]
159
+ }
160
+ },
161
+ "collection": "ml-model"
162
+ }
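
The asset hrefs above resolve to files in the tacofoundation/mlstac repository on the Hugging Face Hub; a sketch of fetching them, assuming huggingface_hub is available (any HTTP client that follows the hrefs would work equally well):

    from huggingface_hub import hf_hub_download

    # Each call returns a local cached path to the corresponding asset.
    weights = hf_hub_download("tacofoundation/mlstac", "Swin_Light_SR/model.safetensor")
    script = hf_hub_download("tacofoundation/mlstac", "Swin_Light_SR/load.py")
    sample = hf_hub_download("tacofoundation/mlstac", "Swin_Light_SR/example_data.safetensor")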
Swin_Light_SR/model.safetensor ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24cec2469c397eea0a729cae01601e390c8acd52339adbd5118498ac20298f20
+ size 12626800