mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
74 lines
2.3 KiB
Python
74 lines
2.3 KiB
Python
"""
|
|
Example code of how to initialize weights for a simple CNN network.
|
|
Usually this is not needed as default initialization is usually good,
|
|
but sometimes it can be useful to initialize weights in a specific way.
|
|
This way of doing it should generalize to other network types just make
|
|
sure to specify and change the modules you wish to modify.
|
|
|
|
Video explanation: https://youtu.be/xWQ-p_o0Uik
|
|
Got any questions? Leave a comment on YouTube :)
|
|
|
|
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
|
|
* 2020-04-10 Initial coding
|
|
* 2022-12-16 Updated with more detailed comments, and checked code still functions as intended.
|
|
"""
|
|
|
|
# Imports
|
|
import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
|
|
import torch.nn.functional as F # All functions that don't have any parameters
|
|
|
|
|
|
class CNN(nn.Module):
    """Small convolutional classifier that demonstrates custom weight init.

    Layer order: conv1 -> ReLU -> max-pool -> conv2 -> ReLU -> max-pool ->
    flatten -> fc1. Both convolutions use a 3x3 kernel with stride 1 and
    padding 1, so only the two 2x2 poolings shrink the spatial size. ``fc1``
    expects ``16 * 7 * 7`` input features, i.e. feature maps that are 7x7
    after the second pooling (for example 28x28 input images).

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Number of scores produced per sample by ``fc1``.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=6,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # The same pooling layer is reused after each convolution in forward().
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv2 = nn.Conv2d(
            in_channels=6,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.fc1 = nn.Linear(16 * 7 * 7, num_classes)
        # Must run last so every layer registered above gets re-initialized.
        self.initialize_weights()

    def forward(self, x):
        """Run the network on a batch and return the ``fc1`` outputs.

        Args:
            x: Batch of inputs; the first dimension is kept as the batch
                dimension and everything else is flattened before ``fc1``.

        Returns:
            The output of ``fc1`` for every sample in the batch.
        """
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)

        return x

    def initialize_weights(self):
        """Re-initialize all Conv2d, BatchNorm2d and Linear submodules in place.

        Conv2d and Linear weights are filled with ``kaiming_uniform_`` and
        their biases with zeros. BatchNorm2d weights are set to one and biases
        to zero. Parameters that are ``None`` (layers built with ``bias=False``
        or ``affine=False``) are skipped instead of raising.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_uniform_(m.weight)

                # bias is None when the layer was created with bias=False.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                # weight and bias are None when affine=False.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
|
|
|
|
|
if __name__ == "__main__":
    # Build an example network; its constructor already runs the custom
    # weight initialization.
    model = CNN(num_classes=10, in_channels=3)

    # Print every learnable tensor so the initialized values can be inspected.
    for tensor in model.parameters():
        print(tensor)
|