csaybar committed on
Commit 88a0ceb · verified · 1 Parent(s): 2f6d0c4

Upload 7 files
.gitattributes CHANGED
@@ -90,3 +90,8 @@ SR_S2_BestModel/auxiliar_refsrx2.safetensor filter=lfs diff=lfs merge=lfs -text
  SR_S2_BestModel/auxiliar_sr.safetensor filter=lfs diff=lfs merge=lfs -text
  SR_S2_BestModel/example_data.safetensor filter=lfs diff=lfs merge=lfs -text
  SR_S2_BestModel/model.safetensor filter=lfs diff=lfs merge=lfs -text
+ SR_S2_FastModel/auxiliar_refsrx2.jit filter=lfs diff=lfs merge=lfs -text
+ SR_S2_FastModel/auxiliar_sr.jit filter=lfs diff=lfs merge=lfs -text
+ SR_S2_FastModel/example_data.safetensor filter=lfs diff=lfs merge=lfs -text
+ SR_S2_FastModel/model.jit filter=lfs diff=lfs merge=lfs -text
+ SR_S2_FastModel/model.safetensor filter=lfs diff=lfs merge=lfs -text
SR_S2_FastModel/auxiliar_refsrx2.jit ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6a700df91e9eba64d0039f4ccd476c47bbf81e6a3e9749c13236aa0ef2eaec62
+ size 2532385
SR_S2_FastModel/auxiliar_sr.jit ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d357606c29caffbee3316afd2badf2a2b2c289dd23a050d9d346201faf5c7a01
+ size 2561891
SR_S2_FastModel/example_data.safetensor ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7709dd46aabc069c2005f39ce830cb5659306a9c11221307557c56f6ed6cf65
+ size 13631584
SR_S2_FastModel/load.py ADDED
@@ -0,0 +1,624 @@
+ # Adapted from: https://github.com/hongyuanyu/SPAN
+ # The original author deserves all the credit; only minor
+ # modifications were made to make it work with this codebase.
+
+
+ from collections import OrderedDict
+ from typing import List, Optional, Union
+
+ import safetensors.numpy
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ import safetensors.torch
+ import pathlib
+
+
+ def _make_pair(value: int) -> tuple:
+     """
+     Converts a single integer into a tuple of the same integer repeated twice.
+
+     Args:
+         value (int): Integer value to be converted.
+
+     Returns:
+         tuple: Tuple containing the integer repeated twice.
+     """
+     if isinstance(value, int):
+         value = (value,) * 2
+     return value
+
+
+ def conv_layer(
+     in_channels: int, out_channels: int, kernel_size: int, bias: bool = True
+ ) -> nn.Conv2d:
+     """
+     Creates a 2D convolutional layer with adaptive padding.
+
+     Args:
+         in_channels (int): Number of input channels.
+         out_channels (int): Number of output channels.
+         kernel_size (int): Size of the convolution kernel.
+         bias (bool, optional): Whether to include a bias term. Defaults to True.
+
+     Returns:
+         nn.Conv2d: 2D convolutional layer with calculated padding.
+     """
+     kernel_size = _make_pair(kernel_size)
+     padding = (int((kernel_size[0] - 1) / 2), int((kernel_size[1] - 1) / 2))
+     return nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
+
+
+ def activation(
+     act_type: str, inplace: bool = True, neg_slope: float = 0.05, n_prelu: int = 1
+ ) -> nn.Module:
+     """
+     Returns an activation layer based on the specified type.
+
+     Args:
+         act_type (str): Type of activation ('relu', 'lrelu', 'prelu').
+         inplace (bool, optional): If True, performs the operation in-place. Defaults to True.
+         neg_slope (float, optional): Negative slope for 'lrelu' and 'prelu'. Defaults to 0.05.
+         n_prelu (int, optional): Number of parameters for 'prelu'. Defaults to 1.
+
+     Returns:
+         nn.Module: Activation layer.
+     """
+     act_type = act_type.lower()
+     if act_type == "relu":
+         layer = nn.ReLU(inplace)
+     elif act_type == "lrelu":
+         layer = nn.LeakyReLU(neg_slope, inplace)
+     elif act_type == "prelu":
+         layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+     else:
+         raise NotImplementedError(
+             "activation layer [{:s}] is not found".format(act_type)
+         )
+     return layer
+
+
+ def sequential(*args) -> nn.Sequential:
+     """
+     Constructs a sequential container for the provided modules.
+
+     Args:
+         args: Modules in order of execution.
+
+     Returns:
+         nn.Sequential: A Sequential container.
+     """
+     if len(args) == 1:
+         if isinstance(args[0], OrderedDict):
+             raise NotImplementedError("sequential does not support OrderedDict input.")
+         return args[0]
+     modules = []
+     for module in args:
+         if isinstance(module, nn.Sequential):
+             for submodule in module.children():
+                 modules.append(submodule)
+         elif isinstance(module, nn.Module):
+             modules.append(module)
+     return nn.Sequential(*modules)
+
+
+ def pixelshuffle_block(
+     in_channels: int, out_channels: int, upscale_factor: int = 2, kernel_size: int = 3
+ ) -> nn.Sequential:
+     """
+     Creates an upsampling block using pixel shuffle.
+
+     Args:
+         in_channels (int): Number of input channels.
+         out_channels (int): Number of output channels.
+         upscale_factor (int, optional): Factor by which to upscale. Defaults to 2.
+         kernel_size (int, optional): Size of the convolution kernel. Defaults to 3.
+
+     Returns:
+         nn.Sequential: Sequential block for upsampling.
+     """
+     conv = conv_layer(in_channels, out_channels * (upscale_factor**2), kernel_size)
+     pixel_shuffle = nn.PixelShuffle(upscale_factor)
+     return sequential(conv, pixel_shuffle)
+
+
+ class Conv3XC(nn.Module):
+     def __init__(
+         self,
+         c_in: int,
+         c_out: int,
+         gain1: int = 1,
+         s: int = 1,
+         bias: bool = True,
+         relu: bool = False,
+         train_mode: bool = True,
+     ):
+         """
+         Custom 3-stage convolutional block with optional ReLU activation and train/evaluation mode support.
+
+         Args:
+             c_in (int): Number of input channels.
+             c_out (int): Number of output channels.
+             gain1 (int, optional): Gain multiplier for intermediate layers. Defaults to 1.
+             s (int, optional): Stride value for the convolutions. Defaults to 1.
+             bias (bool, optional): Whether to include a bias term in the convolutions. Defaults to True.
+             relu (bool, optional): If True, apply a LeakyReLU activation after the convolution. Defaults to False.
+             train_mode (bool, optional): If True, use training mode with learnable parameters. Defaults to True.
+         """
+         super(Conv3XC, self).__init__()
+         self.train_mode = train_mode
+         self.weight_concat = None
+         self.bias_concat = None
+         self.update_params_flag = False
+         self.stride = s
+         self.has_relu = relu
+         gain = gain1
+
+         self.sk = nn.Conv2d(
+             in_channels=c_in,
+             out_channels=c_out,
+             kernel_size=1,
+             padding=0,
+             stride=s,
+             bias=bias,
+         )
+         self.conv = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=c_in,
+                 out_channels=c_in * gain,
+                 kernel_size=1,
+                 padding=0,
+                 bias=bias,
+             ),
+             nn.Conv2d(
+                 in_channels=c_in * gain,
+                 out_channels=c_out * gain,
+                 kernel_size=3,
+                 stride=s,
+                 padding=0,
+                 bias=bias,
+             ),
+             nn.Conv2d(
+                 in_channels=c_out * gain,
+                 out_channels=c_out,
+                 kernel_size=1,
+                 padding=0,
+                 bias=bias,
+             ),
+         )
+
+         self.eval_conv = nn.Conv2d(
+             in_channels=c_in,
+             out_channels=c_out,
+             kernel_size=3,
+             padding=1,
+             stride=s,
+             bias=bias,
+         )
+         self.eval_conv.weight.requires_grad = False
+         self.eval_conv.bias.requires_grad = False
+         if not self.train_mode:
+             self.update_params()
+
+     def update_params(self):
+         """
+         Updates the parameters for evaluation mode by combining weights from the convolution layers.
+         """
+         w1 = self.conv[0].weight.data.clone().detach()
+         b1 = self.conv[0].bias.data.clone().detach()
+         w2 = self.conv[1].weight.data.clone().detach()
+         b2 = self.conv[1].bias.data.clone().detach()
+         w3 = self.conv[2].weight.data.clone().detach()
+         b3 = self.conv[2].bias.data.clone().detach()
+
+         w = (
+             F.conv2d(w1.flip(2, 3).permute(1, 0, 2, 3), w2, padding=2, stride=1)
+             .flip(2, 3)
+             .permute(1, 0, 2, 3)
+         )
+         b = (w2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) + b2
+
+         self.weight_concat = (
+             F.conv2d(w.flip(2, 3).permute(1, 0, 2, 3), w3, padding=0, stride=1)
+             .flip(2, 3)
+             .permute(1, 0, 2, 3)
+         )
+         self.bias_concat = (w3 * b.reshape(1, -1, 1, 1)).sum((1, 2, 3)) + b3
+
+         sk_w = self.sk.weight.data.clone().detach()
+         sk_b = self.sk.bias.data.clone().detach()
+         target_kernel_size = 3
+
+         H_pixels_to_pad = (target_kernel_size - 1) // 2
+         W_pixels_to_pad = (target_kernel_size - 1) // 2
+         sk_w = F.pad(
+             sk_w, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad]
+         )
+
+         self.weight_concat = self.weight_concat + sk_w
+         self.bias_concat = self.bias_concat + sk_b
+
+         self.eval_conv.weight.data = self.weight_concat
+         self.eval_conv.bias.data = self.bias_concat
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the convolution block.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             torch.Tensor: Output tensor after convolution and optional activation.
+         """
+         if self.train_mode:
+             pad = 1
+             x_pad = F.pad(x, (pad, pad, pad, pad), "constant", 0)
+             out = self.conv(x_pad) + self.sk(x)
+         else:
+             self.update_params()
+             out = self.eval_conv(x)
+
+         if self.has_relu:
+             out = F.leaky_relu(out, negative_slope=0.05)
+         return out
+
+
+ class SPAB(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         mid_channels: Optional[int] = None,
+         out_channels: Optional[int] = None,
+         train_mode: bool = True,
+         bias: bool = False,
+     ):
+         """
+         Self-parameterized attention block (SPAB) with multiple convolution layers.
+
+         Args:
+             in_channels (int): Number of input channels.
+             mid_channels (Optional[int], optional): Number of middle channels. Defaults to in_channels.
+             out_channels (Optional[int], optional): Number of output channels. Defaults to in_channels.
+             train_mode (bool, optional): Indicates if the block is in training mode. Defaults to True.
+             bias (bool, optional): Include bias in convolutions. Defaults to False.
+         """
+         super(SPAB, self).__init__()
+         if mid_channels is None:
+             mid_channels = in_channels
+         if out_channels is None:
+             out_channels = in_channels
+
+         self.in_channels = in_channels
+         self.c1_r = Conv3XC(
+             in_channels, mid_channels, gain1=2, s=1, train_mode=train_mode
+         )
+         self.c2_r = Conv3XC(
+             mid_channels, mid_channels, gain1=2, s=1, train_mode=train_mode
+         )
+         self.c3_r = Conv3XC(
+             mid_channels, out_channels, gain1=2, s=1, train_mode=train_mode
+         )
+         self.act1 = torch.nn.SiLU(inplace=True)
+         self.act2 = activation("lrelu", neg_slope=0.1, inplace=True)
+
+     def forward(self, x: torch.Tensor) -> tuple:
+         """
+         Forward pass of the SPAB block.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             tuple: (Output tensor, intermediate tensor, attention map).
+         """
+         out1 = self.c1_r(x)
+         out1_act = self.act1(out1)
+
+         out2 = self.c2_r(out1_act)
+         out2_act = self.act1(out2)
+
+         out3 = self.c3_r(out2_act)
+
+         sim_att = torch.sigmoid(out3) - 0.5
+         out = (out3 + x) * sim_att
+
+         return out, out1, sim_att
+
+
+ class CNNSR(nn.Module):
+     """
+     Swift Parameter-free Attention Network (SPAN) for efficient super-resolution
+     with deeper layers and channel attention.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         feature_channels: int = 48,
+         upscale: int = 4,
+         bias: bool = True,
+         train_mode: bool = True,
+         num_blocks: int = 10,
+         **kwargs,
+     ):
+         """
+         Initializes the CNNSR model.
+
+         Args:
+             in_channels (int): Number of input channels.
+             out_channels (int): Number of output channels.
+             feature_channels (int, optional): Number of feature channels. Defaults to 48.
+             upscale (int, optional): Upscaling factor. Defaults to 4.
+             bias (bool, optional): Whether to include a bias term. Defaults to True.
+             train_mode (bool, optional): If True, the model is in training mode. Defaults to True.
+             num_blocks (int, optional): Number of attention blocks in the network. Defaults to 10.
+         """
+         super(CNNSR, self).__init__()
+
+         # Initial Convolution
+         self.conv_1 = Conv3XC(
+             in_channels, feature_channels, gain1=2, s=1, train_mode=train_mode
+         )
+
+         # Deeper Blocks
+         self.blocks = nn.ModuleList(
+             [
+                 SPAB(feature_channels, bias=bias, train_mode=train_mode)
+                 for _ in range(num_blocks)
+             ]
+         )
+
+         # Convolution after attention blocks
+         self.conv_cat = conv_layer(
+             feature_channels * 4, feature_channels, kernel_size=1, bias=True
+         )
+         self.conv_2 = Conv3XC(
+             feature_channels, feature_channels, gain1=2, s=1, train_mode=train_mode
+         )
+
+         # Upsampling
+         self.upsampler = pixelshuffle_block(
+             feature_channels, out_channels, upscale_factor=upscale
+         )
+
+     def forward(
+         self, x: torch.Tensor, save_attentions: Optional[List[int]] = None
+     ) -> Union[torch.Tensor, tuple]:
+         """
+         Forward pass of the CNNSR model.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+             save_attentions (Optional[List[int]], optional): List of block indices from which to save attention maps.
+
+         Returns:
+             torch.Tensor: Super-resolved output.
+             tuple: If save_attentions is specified, returns (output tensor, attention maps).
+         """
+         # Initial Convolution
+         out_feature = self.conv_1(x)
+
+         # Pass through all blocks, accumulating attention outputs
+         attentions = []
+         for index, block in enumerate(self.blocks):
+             out, out2, att = block(out_feature)
+
+             # Save the first residual block output
+             if index == 0:
+                 out_b1 = out
+
+             # Save the last residual block output
+             if index == len(self.blocks) - 1:
+                 out_blast = out2
+
+             # Save attention if needed
+             if save_attentions is not None and index in save_attentions:
+                 attentions.append(att)
+
+         # Final Convolution and concatenation
+         out_bn = self.conv_2(out)
+         out = self.conv_cat(torch.cat([out_feature, out_bn, out_b1, out_blast], 1))
+
+         # Upsample
+         output = self.upsampler(out)
+
+         if save_attentions is not None:
+             return output, attentions
+         return output
+
+
+ def butterworth_filter(shape: tuple[int, int], cutoff: int, order: int) -> torch.Tensor:
+     """
+     Creates a Butterworth low-pass filter.
+
+     Args:
+         shape: (rows, cols) of the filter.
+         cutoff: Cutoff frequency.
+         order: Order of the Butterworth filter.
+
+     Returns:
+         torch.Tensor: Normalized Butterworth filter.
+     """
+     rows, cols = shape
+     crow, ccol = rows // 2, cols // 2
+     filter = torch.zeros((rows, cols), dtype=torch.float32)
+     for u in range(rows):
+         for v in range(cols):
+             distance = ((u - crow) ** 2 + (v - ccol) ** 2) ** 0.5
+             filter[u, v] = 1 / (1 + (distance / cutoff) ** (2 * order))
+     filter /= filter.sum()
+     return filter
+
+
+ class CNNHardConstraint(nn.Module):
+     """
+     Applies a convolutional hard constraint: the low-frequency content of the output is
+     taken from the upsampled low-resolution input, while the high-frequency content comes
+     from the super-resolved prediction. The low-pass filter is a fixed Butterworth kernel
+     (order 6) whose size and cutoff frequency are derived from the scale factor.
+
+     Args:
+         scale_factor: Scaling factor used to determine kernel size and cutoff frequency.
+         in_channels: Number of input channels.
+         out_channels: List of channel indices to be processed (default is [0, 1, 2, 3, 4, 5]).
+     """
+
+     def __init__(
+         self,
+         scale_factor: int,
+         in_channels: int,
+         out_channels: list = [0, 1, 2, 3, 4, 5],
+     ):
+         super().__init__()
+
+         self.in_channels = in_channels
+
+         # Estimate the kernel according to the scale
+         kernel_size = scale_factor * 3 + 1
+         cutoff = scale_factor * 2
+
+         # Define the convolution layer with multiple input and output channels
+         self.conv = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=len(out_channels),
+             kernel_size=kernel_size,
+             padding=kernel_size // 2,
+             bias=False,
+             groups=in_channels,
+         )
+
+         # Remove the gradient for the filter weights
+         self.conv.weight.requires_grad = False
+
+         # Initialize the filter kernel (Butterworth low-pass, order 6)
+         weight_data = butterworth_filter((kernel_size, kernel_size), cutoff, 6)
+
+         # Apply the same filter to all input channels
+         self.conv.weight.data = (
+             weight_data.unsqueeze(0).unsqueeze(0).repeat(in_channels, 1, 1, 1)
+         )
+         self.out_channels = out_channels
+
+     def forward(self, lr: torch.Tensor, sr: torch.Tensor) -> torch.Tensor:
+         """
+         Applies the filter constraint on the super-resolution image.
+
+         Args:
+             lr: Low-resolution input tensor.
+             sr: Super-resolution output tensor.
+
+         Returns:
+             torch.Tensor: The resulting hybrid image after applying the constraint.
+         """
+         # Select the channels to constrain
+         lr = lr[:, self.out_channels]
+
+         # Upsample the LR image to the size of SR
+         lr_up = F.interpolate(lr, size=sr.shape[-2:], mode="bicubic", antialias=True)
+
+         # Apply the convolutional filter to both LR and SR images
+         lr_filtered = self.conv(lr_up)
+         sr_filtered = self.conv(sr)
+
+         # Combine low-pass and high-pass components
+         hybrid_image = lr_filtered + (sr - sr_filtered)
+
+         return hybrid_image
+
+
+ class HardConstraintModel(torch.nn.Module):
+     def __init__(self, sr_model_rgbn, sr_model_rswir):
+         super().__init__()
+         params = {
+             "in_channels": 10,
+             "out_channels": 6,
+             "feature_channels": 24,
+             "upscale": 1,
+             "bias": True,
+             "train_mode": True,
+             "num_blocks": 6,
+         }
+         self.sr_model = CNNSR(**params)
+         self.hard_constraint = CNNHardConstraint(2, 6, [0, 1, 2, 3, 4, 5])
+
+         # Load the auxiliary models and freeze their parameters
+         self.sr_model_rgbn = torch.jit.load(sr_model_rgbn)
+         self.sr_model_rgbn.eval()
+         for param in self.sr_model_rgbn.parameters():
+             param.requires_grad = False
+
+         self.sr_model_rswir = torch.jit.load(sr_model_rswir)
+         self.sr_model_rswir.eval()
+         for param in self.sr_model_rswir.parameters():
+             param.requires_grad = False
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Band selection
+         bands_20m = [3, 4, 5, 7, 8, 9]
+         bands_10m = [2, 1, 0, 6]  # WARNING: the RGBN SR model expects the bands in R, G, B, NIR order
+
+         # Run the reference SR model on the RSWIR bands (from 20m to 10m)
+         allbands10m = self.sr_model_rswir(x)
+
+         # Upsample the RSWIR bands to 2.5m
+         rsiwr_10m = allbands10m[:, bands_20m]
+         rsiwr_2dot5m_billinear = torch.nn.functional.interpolate(
+             rsiwr_10m, scale_factor=4, mode="bilinear", antialias=True
+         )
+
+         # Run SR on the RGBN bands (from 10m to 2.5m)
+         rgbn_2dot5m = self.sr_model_rgbn(x[:, bands_10m])
+
+         # Reorder the bands from RGBNIR to BGRNIR
+         rgbn_2dot5m = rgbn_2dot5m[:, [2, 1, 0, 3]]
+
+         # Run the x4 fusion model on the RSWIR bands (10m to 2.5m)
+         input_data = torch.cat([rsiwr_2dot5m_billinear, rgbn_2dot5m], dim=1)
+         rsiwr_2dot5m_sr = self.hard_constraint(rsiwr_2dot5m_billinear, self.sr_model(input_data))
+
+         # Order the channels back
+         results = torch.stack(
+             [
+                 rgbn_2dot5m[:, 0],
+                 rgbn_2dot5m[:, 1],
+                 rgbn_2dot5m[:, 2],
+                 rsiwr_2dot5m_sr[:, 0],
+                 rsiwr_2dot5m_sr[:, 1],
+                 rsiwr_2dot5m_sr[:, 2],
+                 rgbn_2dot5m[:, 3],
+                 rsiwr_2dot5m_sr[:, 3],
+                 rsiwr_2dot5m_sr[:, 4],
+                 rsiwr_2dot5m_sr[:, 5],
+             ],
+             dim=1,
+         )
+
+         return results
+
+
+ # MLSTAC API -----------------------------------------------------------------------
+ def example_data(path: pathlib.Path, *args, **kwargs):
+     data_file = path / "example_data.safetensor"
+     # Select only the 10 m and 20 m bands:
+     # B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12
+     bands = [1, 2, 3, 4, 5, 6, 7, 8, 11, 12]
+     return safetensors.torch.load_file(data_file)["example_data"][bands, 128:384, 128:384][None]
+
+ def trainable_model(path, *args, **kwargs):
+     # From 10m to 2.5m (RGBN)
+     sr_model_rgbn = path / "auxiliar_sr.jit"
+
+     # From 20m to 10m (RSWIR)
+     sr_model_rswir = path / "auxiliar_refsrx2.jit"
+
+     # Load the fusion model parameters (10m to 2.5m, RSWIR)
+     weights = safetensors.torch.load_file(path / "model.safetensor")
+
+     # Load the model
+     srmodel = HardConstraintModel(sr_model_rgbn=sr_model_rgbn, sr_model_rswir=sr_model_rswir)
+     srmodel.sr_model.load_state_dict(weights)
+
+     return srmodel
SR_S2_FastModel/mlm.json ADDED
@@ -0,0 +1,202 @@
+ {
+     "type": "Feature",
+     "stac_version": "1.1.0",
+     "stac_extensions": [
+         "https://stac-extensions.github.io/mlm/v1.4.0/schema.json"
+     ],
+     "id": "SPAN model",
+     "geometry": {
+         "type": "Polygon",
+         "coordinates": [
+             [
+                 [
+                     -180.0,
+                     -90.0
+                 ],
+                 [
+                     -180.0,
+                     90.0
+                 ],
+                 [
+                     180.0,
+                     90.0
+                 ],
+                 [
+                     180.0,
+                     -90.0
+                 ],
+                 [
+                     -180.0,
+                     -90.0
+                 ]
+             ]
+         ]
+     },
+     "bbox": [
+         -180,
+         -90,
+         180,
+         90
+     ],
+     "properties": {
+         "start_datetime": "1900-01-01T00:00:00Z",
+         "end_datetime": "9999-01-01T00:00:00Z",
+         "description": "A Swift Parameter-free Attention Network (SPAN). The model was trained using the CloudSEN12+ dataset.",
+         "forward_backward_pass": {
+             "32": 24.593536,
+             "64": 96.375936,
+             "128": 381.5968,
+             "256": 1518.662784,
+             "512": 6059.291776
+         },
+         "dependencies": [
+             "torch",
+             "safetensors.torch"
+         ],
+         "mlm:framework": "pytorch",
+         "mlm:framework_version": "2.1.2+cu121",
+         "file:size": 1861672,
+         "mlm:memory_size": 1,
+         "mlm:accelerator": "cuda",
+         "mlm:accelerator_constrained": false,
+         "mlm:accelerator_summary": "Unknown",
+         "mlm:name": "CNN_Light_F4",
+         "mlm:architecture": "SPAN",
+         "mlm:tasks": [
+             "super-resolution"
+         ],
+         "mlm:input": [
+             {
+                 "name": "Sentinel-2 10m converted 2.5m and 20m bands converted to 10m",
+                 "bands": [
+                     "B02",
+                     "B03",
+                     "B04",
+                     "B05",
+                     "B06",
+                     "B07",
+                     "B08",
+                     "B8A",
+                     "B11",
+                     "B12"
+                 ],
+                 "input": {
+                     "shape": [
+                         -1,
+                         10,
+                         128,
+                         128
+                     ],
+                     "dim_order": [
+                         "batch",
+                         "channel",
+                         "height",
+                         "width"
+                     ],
+                     "data_type": "float16"
+                 },
+                 "pre_processing_function": null
+             }
+         ],
+         "mlm:output": [
+             {
+                 "name": "super-resolution",
+                 "tasks": [
+                     "super-resolution"
+                 ],
+                 "result": {
+                     "shape": [
+                         -1,
+                         10,
+                         512,
+                         512
+                     ],
+                     "dim_order": [
+                         "batch",
+                         "channel",
+                         "height",
+                         "width"
+                     ],
+                     "data_type": "float16"
+                 },
+                 "classification:classes": [],
+                 "post_processing_function": null
+             }
+         ],
+         "mlm:total_parameters": 465418,
+         "mlm:pretrained": true,
+         "datetime": null
+     },
+     "links": [],
+     "assets": {
+         "auxiliar_sr": {
+             "href": "https://huggingface.co/tacofoundation/mlstac/resolve/main/CNN_Light_F4/auxiliar_sr.jit",
+             "type": "application/octet-stream; application=safetensor",
+             "title": "Torchscript model",
+             "description": "A Swift Parameter-free Attention Network (SPAN). The model was trained using the CloudSEN12+ dataset. The model can convert RGBN bands from 10m to 2.5m resolution.",
+             "mlm:artifact_type": "torch.jit.save",
+             "roles": [
+                 "mlm:model",
+                 "mlm:weights",
+                 "data"
+             ]
+         },
+         "auxiliar_refsrx2": {
+             "href": "https://huggingface.co/tacofoundation/mlstac/resolve/main/CNN_Light_F4/auxiliar_refsrx2.jit",
+             "type": "application/octet-stream; application=safetensor",
+             "title": "Torchscript model",
+             "description": "A Swift Parameter-free Attention Network (SPAN). The model was trained using the CloudSEN12+ dataset. The model can convert RSWIR bands from 20m to 10m resolution.",
+             "mlm:artifact_type": "torch.jit.save",
+             "roles": [
+                 "mlm:model",
+                 "mlm:weights",
+                 "data"
+             ]
+         },
+         "trainable": {
+             "href": "https://huggingface.co/tacofoundation/mlstac/resolve/main/CNN_Light_F4/model.safetensor",
+             "type": "application/octet-stream; application=safetensor",
+             "title": "Pytorch weights checkpoint",
+             "description": "A Swift Parameter-free Attention Network (SPAN). The model was trained using the CloudSEN12+ dataset.",
+             "mlm:artifact_type": "safetensor.torch.save_file",
+             "roles": [
+                 "mlm:model",
+                 "mlm:weights",
+                 "data"
+             ]
+         },
+         "compile": {
+             "href": "https://huggingface.co/tacofoundation/mlstac/resolve/main/CNN_Light_F4/model.jit",
+             "type": "application/octet-stream; application=pytorch",
+             "title": "Torchscript model",
+             "description": "A Swift Parameter-free Attention Network (SPAN). The model was trained using the CloudSEN12+ dataset.",
+             "mlm:artifact_type": "torch.jit.save",
+             "roles": [
+                 "mlm:model",
+                 "mlm:weights",
+                 "data"
+             ]
+         },
+         "source_code": {
+             "href": "https://huggingface.co/tacofoundation/mlstac/resolve/main/CNN_Light_F4/load.py",
+             "type": "text/x-python",
+             "title": "Model load script",
+             "description": "Source code to run the model.",
+             "roles": [
+                 "mlm:source_code",
+                 "code"
+             ]
+         },
+         "example_data": {
+             "href": "https://huggingface.co/tacofoundation/mlstac/resolve/main/CNN_Light_F4/example_data.safetensor",
+             "type": "application/octet-stream; application=safetensors",
+             "title": "Example Sentinel-2 image",
+             "description": "Example Sentinel-2 image for model inference.",
+             "roles": [
+                 "mlm:example_data",
+                 "data"
+             ]
+         }
+     },
+     "collection": "ml-model"
+ }
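The mlm:input and mlm:output entries above encode the overall scale factor: a 128x128 input tile maps to a 512x512 output, i.e. 4x, with the 10 m and 20 m bands delivered on a 2.5 m grid. A small sketch, assuming a local copy of this mlm.json under the path shown:

import json

with open("SR_S2_FastModel/mlm.json") as f:  # assumed local path
    item = json.load(f)

inp = item["properties"]["mlm:input"][0]["input"]["shape"]    # [-1, 10, 128, 128]
out = item["properties"]["mlm:output"][0]["result"]["shape"]  # [-1, 10, 512, 512]
print(out[-1] // inp[-1])                                     # 4 -> overall x4 upscaling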
SR_S2_FastModel/model.jit ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3305cb8a07b1e2fdd4e05c25f4a2de883ba71f56f408510699e598c10cd86568
+ size 7604602
SR_S2_FastModel/model.safetensor ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f45c99af1d08408cb622ba02d1aeadaab2ca964f08991feeb4fd7ffd5197f59d
+ size 2285088