mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-04-10 12:33:44 +00:00
Added bias=False
Bias term already included in the BN layers; can be set to False as it is redundant
This commit is contained in:
@@ -23,7 +23,7 @@ class block(nn.Module):
|
|||||||
super(block, self).__init__()
|
super(block, self).__init__()
|
||||||
self.expansion = 4
|
self.expansion = 4
|
||||||
self.conv1 = nn.Conv2d(
|
self.conv1 = nn.Conv2d(
|
||||||
in_channels, intermediate_channels, kernel_size=1, stride=1, padding=0
|
in_channels, intermediate_channels, kernel_size=1, stride=1, padding=0, bias=False
|
||||||
)
|
)
|
||||||
self.bn1 = nn.BatchNorm2d(intermediate_channels)
|
self.bn1 = nn.BatchNorm2d(intermediate_channels)
|
||||||
self.conv2 = nn.Conv2d(
|
self.conv2 = nn.Conv2d(
|
||||||
@@ -32,6 +32,7 @@ class block(nn.Module):
|
|||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
padding=1,
|
padding=1,
|
||||||
|
bias=False
|
||||||
)
|
)
|
||||||
self.bn2 = nn.BatchNorm2d(intermediate_channels)
|
self.bn2 = nn.BatchNorm2d(intermediate_channels)
|
||||||
self.conv3 = nn.Conv2d(
|
self.conv3 = nn.Conv2d(
|
||||||
@@ -40,6 +41,7 @@ class block(nn.Module):
|
|||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0,
|
padding=0,
|
||||||
|
bias=False
|
||||||
)
|
)
|
||||||
self.bn3 = nn.BatchNorm2d(intermediate_channels * self.expansion)
|
self.bn3 = nn.BatchNorm2d(intermediate_channels * self.expansion)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
@@ -70,7 +72,7 @@ class ResNet(nn.Module):
|
|||||||
def __init__(self, block, layers, image_channels, num_classes):
|
def __init__(self, block, layers, image_channels, num_classes):
|
||||||
super(ResNet, self).__init__()
|
super(ResNet, self).__init__()
|
||||||
self.in_channels = 64
|
self.in_channels = 64
|
||||||
self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
|
self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||||
self.bn1 = nn.BatchNorm2d(64)
|
self.bn1 = nn.BatchNorm2d(64)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
@@ -122,6 +124,7 @@ class ResNet(nn.Module):
|
|||||||
intermediate_channels * 4,
|
intermediate_channels * 4,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
|
bias=False
|
||||||
),
|
),
|
||||||
nn.BatchNorm2d(intermediate_channels * 4),
|
nn.BatchNorm2d(intermediate_channels * 4),
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user