mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
20 lines
549 B
Python
import torch.nn as nn
|
|
from torchvision.models import vgg19
|
|
import config
|
|
|
|
|
|
class VGGLoss(nn.Module):
    """Perceptual (feature-space) MSE loss computed with a frozen VGG19.

    Both images are passed through the leading modules of torchvision's
    ImageNet-pretrained VGG19 ``features`` stack. The loss is the mean
    squared error between the two resulting feature maps. The VGG weights
    are frozen, so the loss back-propagates only into the images
    themselves, never into the feature extractor.

    NOTE(review): no ImageNet mean/std normalisation is applied here; images
    are fed to VGG exactly as received. Confirm the caller's data pipeline
    accounts for this.
    """

    def __init__(self, num_layers=35, device=None):
        """Build the frozen VGG19 feature extractor and the MSE criterion.

        Args:
            num_layers: how many leading modules of ``vgg19().features`` to
                keep. The default of 35 (the original hard-coded value)
                stops right after conv5_4, before its ReLU.
            device: device to place the extractor on. Defaults to
                ``config.DEVICE``, matching the original behaviour.
        """
        super().__init__()
        if device is None:
            device = config.DEVICE

        try:
            # torchvision >= 0.13 replaced ``pretrained=True`` with the
            # ``weights`` enum (``pretrained`` is deprecated).
            # IMAGENET1K_V1 is exactly what ``pretrained=True`` loaded.
            from torchvision.models import VGG19_Weights
        except ImportError:
            # Older torchvision: only the legacy flag is available.
            backbone = vgg19(pretrained=True)
        else:
            backbone = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

        # eval() is set once here; the plain VGG19 ``features`` stack has no
        # dropout/batch-norm, so train/eval mode does not change its output.
        self.vgg = backbone.features[:num_layers].eval().to(device)
        # Freeze the extractor: it is a fixed perceptual metric, not trained.
        self.vgg.requires_grad_(False)

        self.loss = nn.MSELoss()

    def forward(self, input, target):
        """Return the MSE between the VGG features of ``input`` and ``target``.

        ``input`` shadows the builtin but is kept so keyword callers keep
        working.

        Args:
            input: generated / predicted image batch.
            target: reference image batch, processed the same way.

        Returns:
            Scalar tensor: ``MSELoss`` (mean reduction) over the feature maps.
        """
        input_features = self.vgg(input)
        target_features = self.vgg(target)
        return self.loss(input_features, target_features)
|