Files

13 lines
411 B
Python
Raw Permalink Normal View History

2021-01-30 21:49:15 +01:00
from torch import nn
from efficientnet_pytorch import EfficientNet
class Net(nn.Module):
    """EfficientNet backbone with its classifier head replaced.

    The backbone is built with ``EfficientNet.from_pretrained`` and its final
    fully connected layer (``_fc``) is swapped for a new linear layer that
    outputs ``num_classes`` values.
    """

    def __init__(self, net_version, num_classes):
        """Build the backbone and attach a new classification head.

        Args:
            net_version: EfficientNet variant suffix appended to
                ``'efficientnet-'`` (e.g. ``'b0'``, ``'b3'``).
            num_classes: Number of output units of the new head.
        """
        super().__init__()
        self.backbone = EfficientNet.from_pretrained(f'efficientnet-{net_version}')

        # The width of the features feeding the head depends on the variant
        # (1280 is only correct for b0/b1), so read it from the layer being
        # replaced instead of hard-coding it.
        in_features = self.backbone._fc.in_features

        # Keep the nn.Sequential wrapper so parameter names in the state dict
        # (backbone._fc.0.weight / .bias) stay compatible with checkpoints
        # saved by earlier versions of this class.
        self.backbone._fc = nn.Sequential(
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and return its output.

        Args:
            x: Input batch passed unchanged to the backbone; presumably an
                image tensor of shape (batch, 3, H, W) — confirm with caller.

        Returns:
            Whatever the backbone returns for ``x``.
        """
        return self.backbone(x)