mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 11:18:01 +00:00
129 lines
3.8 KiB
Python
129 lines
3.8 KiB
Python
|
|
import torch
|
||
|
|
from torch import nn
|
||
|
|
|
||
|
|
|
||
|
|
class ConvBlock(nn.Module):
    """Conv2d -> (optional) BatchNorm2d -> (optional) activation.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        discriminator: If true the activation is ``LeakyReLU(0.2)``;
            otherwise a ``PReLU`` with one slope per output channel.
        use_act: Whether ``forward`` applies the activation. The activation
            module is constructed either way.
        use_bn: Whether to batch-normalize the convolution output. The
            convolution only carries a bias when this is false.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        discriminator=False,
        use_act=True,
        use_bn=True,
        **kwargs,
    ):
        super().__init__()
        self.use_act = use_act

        # The conv bias is dropped exactly when batch norm follows it.
        self.cnn = nn.Conv2d(in_channels, out_channels, bias=not use_bn, **kwargs)

        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()

        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        """Apply convolution, normalization and, if enabled, the activation."""
        normalized = self.bn(self.cnn(x))
        if not self.use_act:
            return normalized
        return self.act(normalized)
|
||
|
|
|
||
|
|
|
||
|
|
class UpsampleBlock(nn.Module):
    """Upsample a feature map with a 3x3 convolution and a pixel shuffle.

    Args:
        in_c: Number of input channels; the output has the same count.
        scale_factor: Integer upscale factor ``r`` handed to ``PixelShuffle``.
    """

    def __init__(self, in_c, scale_factor):
        super().__init__()
        # The conv expands channels by r**2 so the shuffle can fold them
        # back into space: in_c * r**2, H, W --> in_c, H*r, W*r
        expanded_channels = in_c * scale_factor ** 2
        self.conv = nn.Conv2d(
            in_c,
            expanded_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        """Return the upsampled, PReLU-activated feature map."""
        expanded = self.conv(x)
        shuffled = self.ps(expanded)
        return self.act(shuffled)
|
||
|
|
|
||
|
|
|
||
|
|
class ResidualBlock(nn.Module):
    """Two channel-preserving 3x3 ConvBlocks wrapped in an identity skip.

    The first ConvBlock uses its activation; the second is built with
    ``use_act=False`` so its output is added to the input unactivated.

    Args:
        in_channels: Channel count of both the input and the output.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Both convolutions share the same shape-preserving geometry.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **conv_kwargs)
        self.block2 = ConvBlock(
            in_channels,
            in_channels,
            use_act=False,
            **conv_kwargs,
        )

    def forward(self, x):
        """Return ``block2(block1(x)) + x``."""
        residual = self.block2(self.block1(x))
        return residual + x
|
||
|
|
|
||
|
|
|
||
|
|
class Generator(nn.Module):
    """Super-resolution generator that upscales its input 4x (two 2x stages).

    Args:
        in_channels: Channels of the input image; the output has the same.
        num_channels: Width of the internal feature maps.
        num_blocks: Number of residual blocks in the trunk.
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super().__init__()
        # 9x9 stem without batch norm.
        self.initial = ConvBlock(
            in_channels,
            num_channels,
            kernel_size=9,
            stride=1,
            padding=4,
            use_bn=False,
        )
        trunk = [ResidualBlock(num_channels) for _ in range(num_blocks)]
        self.residuals = nn.Sequential(*trunk)
        # Conv + batch norm without activation, applied before the long skip.
        self.convblock = ConvBlock(
            num_channels,
            num_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=False,
        )
        self.upsamples = nn.Sequential(
            UpsampleBlock(num_channels, 2),
            UpsampleBlock(num_channels, 2),
        )
        self.final = nn.Conv2d(
            num_channels,
            in_channels,
            kernel_size=9,
            stride=1,
            padding=4,
        )

    def forward(self, x):
        """Return the upscaled image, squashed to [-1, 1] by tanh."""
        stem = self.initial(x)
        # Long skip connection: the stem output is added back after the trunk.
        features = self.convblock(self.residuals(stem)) + stem
        upscaled = self.upsamples(features)
        return torch.tanh(self.final(upscaled))
|
||
|
|
|
||
|
|
|
||
|
|
class Discriminator(nn.Module):
    """Convolutional discriminator that scores an image with a single logit.

    A stack of 3x3 ConvBlocks (LeakyReLU activation, stride alternating
    1, 2, 1, 2, ...; no batch norm on the first block) is followed by
    adaptive average pooling to 6x6 and a two-layer classifier. No sigmoid is
    applied to the output.

    Args:
        in_channels: Channels of the input image.
        features: Output channel count of each ConvBlock, in order. Any
            sequence of ints is accepted; the classifier input size is
            derived from the last entry rather than assumed to be 512.
    """

    def __init__(
        self,
        in_channels=3,
        features=(64, 64, 128, 128, 256, 256, 512, 512),
    ):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,  # every second block halves H and W
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=idx != 0,
                )
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        # After the loop ``in_channels`` is the channel count leaving the
        # conv stack, so the first linear layer matches any ``features``.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(in_channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Return one unnormalized score per input sample."""
        x = self.blocks(x)
        return self.classifier(x)
|
||
|
|
|
||
|
|
def test():
    """Smoke test: run a random batch through both networks and print shapes."""
    low_resolution = 24  # 96x96 -> 24x24
    batch_shape = (5, 3, low_resolution, low_resolution)
    with torch.cuda.amp.autocast():
        lr_batch = torch.randn(batch_shape)
        generator = Generator()
        fake_hr = generator(lr_batch)
        discriminator = Discriminator()
        scores = discriminator(fake_hr)

        print(fake_hr.shape)
        print(scores.shape)
|
||
|
|
|
||
|
|
|
||
|
|
# Run the shape smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    test()
|