captcha/cnn_net.py

84 lines
2.7 KiB
Python
Raw Normal View History

2024-09-30 11:05:35 +08:00
import torch.nn as nn
import captcha_settings
class ConvNet(nn.Module):
    """CNN mapping an image batch to a flat vector of raw scores per sample.

    Four convolution stages (3x3 conv with padding 1 -> ReLU -> 2x2 max-pool)
    take the channel count ``in_channels -> 64 -> 128 -> 256 -> 512``. Each
    max-pool floor-halves the spatial size, so the last feature map is
    ``512 x (H // 16) x (W // 16)``. A fully connected head (Flatten -> Linear
    -> Dropout(0.5) -> ReLU -> Linear) then produces ``num_outputs`` values per
    sample. No activation is applied after the final Linear.

    ``ConvNet()`` with no arguments builds exactly the original network:
    - single-channel input,
    - a flattened feature size of 15360,
    - a 4096-unit hidden layer,
    - ``captcha_settings.MAX_CAPTCHA * captcha_settings.ALL_CHAR_SET_LEN``
      outputs.

    Submodule names (``layer1`` .. ``layer5``) and ``nn.Sequential`` indices
    are unchanged, so previously saved ``state_dict`` checkpoints still load.

    Args:
        in_channels: Number of input image channels (default 1).
        hidden_features: Width of the hidden fully connected layer
            (default 4096).
        num_outputs: Size of the output vector. ``None`` (default) uses
            ``captcha_settings.MAX_CAPTCHA * captcha_settings.ALL_CHAR_SET_LEN``.
        image_height: Input height in pixels. Must be given together with
            ``image_width``. When both are given, the flattened feature size
            is derived from them instead of using the legacy constant 15360.
        image_width: Input width in pixels (see ``image_height``).

    Raises:
        ValueError: If only one of ``image_height`` / ``image_width`` is
            given, or if the image is too small to survive four 2x pooling
            stages.
    """

    # Output channels of the last convolution stage.
    _FINAL_CHANNELS = 512
    # Four MaxPool2d(2) stages: each spatial dim ends up floor-divided by 2**4.
    _DOWNSAMPLE = 16
    # Flattened size hard-coded by the original network. It is used when no
    # image size is given. 15360 == 512 * 30, i.e. (H // 16) * (W // 16) == 30
    # (for example a 60 x 160 input).
    # NOTE(review): consider deriving this from the captcha_settings image
    # dimensions once they are confirmed to match the training images.
    _LEGACY_FLAT_FEATURES = 15360

    def __init__(
        self,
        *,
        in_channels=1,
        hidden_features=4096,
        num_outputs=None,
        image_height=None,
        image_width=None,
    ):
        super().__init__()
        # Validate the image size first so bad arguments fail before any
        # (potentially large) parameters are allocated.
        flat_features = self._flat_features(image_height, image_width)

        self.layer1 = self._conv_stage(in_channels, 64)
        self.layer2 = self._conv_stage(64, 128)
        self.layer3 = self._conv_stage(128, 256)
        self.layer4 = self._conv_stage(256, self._FINAL_CHANNELS)

        if num_outputs is None:
            num_outputs = (
                captcha_settings.MAX_CAPTCHA * captcha_settings.ALL_CHAR_SET_LEN
            )
        # Indices 0..4 must stay in this order: checkpoints reference
        # layer5.1.* and layer5.4.*.
        self.layer5 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=flat_features, out_features=hidden_features),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(hidden_features, num_outputs),
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Return one 3x3 conv (padding 1) -> ReLU -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    @classmethod
    def _flat_features(cls, image_height, image_width):
        """Return the flattened feature size fed into the fully connected head."""
        if image_height is None and image_width is None:
            return cls._LEGACY_FLAT_FEATURES
        if image_height is None or image_width is None:
            raise ValueError("image_height and image_width must be given together")
        # Repeated floor-halving by MaxPool2d(2) equals one floor-division by 16.
        out_h = image_height // cls._DOWNSAMPLE
        out_w = image_width // cls._DOWNSAMPLE
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"image {image_height}x{image_width} is too small for four 2x2 "
                f"max-pool stages (each side must be at least {cls._DOWNSAMPLE} px)"
            )
        return cls._FINAL_CHANNELS * out_h * out_w

    def forward(self, x):
        """Return raw scores of shape ``(batch, num_outputs)`` for input ``x``.

        ``x`` is expected to be a ``(batch, in_channels, H, W)`` tensor. Its
        final ``512 x (H // 16) x (W // 16)`` feature map must flatten to the
        size the head was built for.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return self.layer5(x)
# Alternative CNN (commented out, not used): three conv -> BatchNorm -> Dropout
# -> ReLU -> max-pool stages followed by a 1024-unit fully connected head.
# class ConvNet(nn.Module):
#     def __init__(self):
#         super(ConvNet, self).__init__()
#         self.layer1 = nn.Sequential(
#             nn.Conv2d(1, 32, kernel_size=3, padding=1),
#             nn.BatchNorm2d(32),
#             nn.Dropout(0.5),  # drop 50% of the neurons
#             nn.ReLU(),
#             nn.MaxPool2d(2))
#         self.layer2 = nn.Sequential(
#             nn.Conv2d(32, 64, kernel_size=3, padding=1),
#             nn.BatchNorm2d(64),
#             nn.Dropout(0.5),  # drop 50% of the neurons
#             nn.ReLU(),
#             nn.MaxPool2d(2))
#         self.layer3 = nn.Sequential(
#             nn.Conv2d(64, 64, kernel_size=3, padding=1),
#             nn.BatchNorm2d(64),
#             nn.Dropout(0.5),  # drop 50% of the neurons
#             nn.ReLU(),
#             nn.MaxPool2d(2))
#         self.fc = nn.Sequential(
#             nn.Linear((captcha_settings.IMAGE_WIDTH//8)*(captcha_settings.IMAGE_HEIGHT//8)*64, 1024),
#             nn.Dropout(0.5),  # drop 50% of the neurons
#             nn.ReLU())
#         self.rfc = nn.Sequential(
#             nn.Linear(1024, captcha_settings.MAX_CAPTCHA*captcha_settings.ALL_CHAR_SET_LEN),
#         )
#     def forward(self, x):
#         out = self.layer1(x)
#         out = self.layer2(out)
#         out = self.layer3(out)
#         out = out.view(out.size(0), -1)
#         out = self.fc(out)
#         out = self.rfc(out)
#         return out