Files
Machine-Learning-Collection/ML/Pytorch/others/default_setups/CV - Image Classification/model.py
2021-05-27 10:21:14 +02:00

13 lines
411 B
Python

from torch import nn
from efficientnet_pytorch import EfficientNet
class Net(nn.Module):
    """Pretrained EfficientNet whose classification head is replaced for ``num_classes`` outputs.

    Args:
        net_version: EfficientNet variant suffix appended to ``'efficientnet-'``
            (e.g. ``'b0'`` -> ``'efficientnet-b0'``), passed to
            ``EfficientNet.from_pretrained``.
        num_classes: Number of outputs produced by the new linear head.
    """

    def __init__(self, net_version, num_classes):
        super().__init__()
        self.backbone = EfficientNet.from_pretrained('efficientnet-' + net_version)
        # Take the head's input width from the loaded model instead of hard-coding
        # 1280: that value only matches b0/b1 (b2 uses 1408, ..., b7 uses 2560), so a
        # fixed 1280 caused a shape mismatch in forward() for the larger variants.
        in_features = self.backbone._fc.in_features
        # Keep the nn.Sequential wrapper so state_dict keys ("backbone._fc.0.*")
        # remain compatible with checkpoints saved by earlier versions of this class.
        self.backbone._fc = nn.Sequential(
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and return the new head's raw outputs.

        No activation is applied after the linear head. ``x`` is presumably an
        image batch shaped (batch, 3, H, W); confirm against the data pipeline.
        """
        return self.backbone(x)