import torch
import torch.nn as nn


def R2D2_conv_block(in_channels, out_channels, retain_activation=True, keep_prob=1.0):
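    """Conv-BatchNorm-MaxPool block used by the R2D2 embedding network.

    When retain_activation is True, a LeakyReLU(0.1) activation is appended;
    when keep_prob < 1.0, a Dropout layer with drop probability 1 - keep_prob
    is appended as well.
    """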
    block = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.MaxPool2d(2)
    )
    if retain_activation:
        block.add_module("LeakyReLU", nn.LeakyReLU(0.1))

    if keep_prob < 1.0:
        block.add_module("Dropout", nn.Dropout(p=1 - keep_prob, inplace=False))

    return block


class R2D2Embedding(nn.Module):
    """Four-block convolutional embedding network (R2D2-style).

    The forward pass returns the flattened outputs of block3 and block4,
    concatenated along the feature dimension.
    """

    def __init__(self, x_dim=3, h1_dim=96, h2_dim=192, h3_dim=384, z_dim=512,
                 retain_last_activation=False):
        super(R2D2Embedding, self).__init__()

        self.block1 = R2D2_conv_block(x_dim, h1_dim)
        self.block2 = R2D2_conv_block(h1_dim, h2_dim)
        # Dropout is applied only in the last two blocks (keep_prob < 1.0).
        self.block3 = R2D2_conv_block(h2_dim, h3_dim, keep_prob=0.9)
        # The final block's LeakyReLU is optional and disabled by default.
        self.block4 = R2D2_conv_block(h3_dim, z_dim, retain_activation=retain_last_activation, keep_prob=0.7)

    def forward(self, x):
        b1 = self.block1(x)
        b2 = self.block2(b1)
        b3 = self.block3(b2)
        b4 = self.block4(b3)
        # Flatten the block3 and block4 feature maps and concatenate them
        # along the feature dimension.
        return torch.cat((b3.view(b3.size(0), -1), b4.view(b4.size(0), -1)), 1)
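

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): it assumes
    # 3-channel 84x84 inputs purely for illustration; any spatial size that
    # survives four 2x2 max-pools should work.
    model = R2D2Embedding()
    model.eval()  # eval mode so BatchNorm/Dropout behave deterministically
    dummy = torch.randn(2, 3, 84, 84)  # (batch, channels, height, width)
    features = model(dummy)
    # The embedding length depends on the input resolution, since it is the
    # concatenation of the flattened block3 and block4 outputs.
    print(features.shape)  # torch.Size([2, 51200]) for 84x84 inputs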