# captcha/train.py
# (46 lines, 1.1 KiB, Python; last modified 2024-09-30 11:05:35 +08:00)
import torch
import torch.nn as nn
from tqdm import tqdm
from cnn_net import ConvNet
import dataset
# Training hyper-parameters.
num_epochs = 10  # full passes over the training loader
learning_rate = 0.001  # step size handed to the Adam optimizer

# Pick the first available backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
def main():
    """Train ConvNet on the training set and save its weights.

    Runs ``num_epochs`` passes over ``dataset.get_train_loader()`` on
    ``device`` using Adam (``lr=learning_rate``) and
    ``nn.MultiLabelSoftMarginLoss``. The current batch loss is shown in a
    tqdm progress bar, and the mean loss of each epoch is printed when the
    epoch finishes. The final ``state_dict`` is written to ``./model.pth``.
    """
    model = ConvNet().to(device)
    model.train()
    # NOTE(review): assumes the loader yields label tensors with the same
    # shape as the model output (multi-hot targets) - confirm in dataset.py.
    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    train_dataloader = dataset.get_train_loader()
    for epoch in range(num_epochs):
        print("Epoch:", epoch + 1)
        total_loss = 0.0
        num_batches = 0
        # tqdm reads the total from len(train_dataloader); no enumerate needed.
        pbar = tqdm(train_dataloader)
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            predict_labels = model(images)
            loss = criterion(predict_labels, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            pbar.set_description("loss: %.4f" % batch_loss)

        if num_batches == 0:
            # An empty loader would otherwise leave `loss` unbound and crash here.
            print("loss: n/a (empty training loader)", '\n')
        else:
            # Report the epoch mean rather than the (noisy) last mini-batch loss.
            print("loss:", total_loss / num_batches, '\n')

    torch.save(model.state_dict(), "./model.pth")
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()