"""Train ``ConvNet`` with a multi-label soft-margin loss and save its weights.

When run as a script, trains for ``num_epochs`` epochs on the loader returned
by ``dataset.get_train_loader()`` and writes the final ``state_dict`` to
``./model.pth``.
"""
import torch
import torch.nn as nn
from tqdm import tqdm

from cnn_net import ConvNet
import dataset

num_epochs = 10
learning_rate = 0.001
# Prefer CUDA, then Apple's MPS backend, and fall back to CPU otherwise.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)


def main(
    *,
    epochs: int = num_epochs,
    lr: float = learning_rate,
    save_path: str = "./model.pth",
) -> None:
    """Train a fresh ``ConvNet`` and save its parameters.

    Args:
        epochs: Number of full passes over the training loader
            (defaults to the module-level ``num_epochs``).
        lr: Learning rate for the Adam optimizer
            (defaults to the module-level ``learning_rate``).
        save_path: Destination file for ``model.state_dict()``.
    """
    model = ConvNet().to(device)
    model.train()
    # NOTE(review): MultiLabelSoftMarginLoss expects targets with the same
    # shape as the model output (multi-hot float labels) — confirm that
    # dataset.get_train_loader() yields labels in that form.
    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_dataloader = dataset.get_train_loader()
    for epoch in range(epochs):
        print("Epoch:", epoch + 1)
        running_loss = 0.0
        num_batches = 0
        # tqdm takes its total from len(train_dataloader) automatically.
        pbar = tqdm(train_dataloader)
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            predict_labels = model(images)
            loss = criterion(predict_labels, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1
            pbar.set_description("loss: %.4f" % batch_loss)

        # Report the epoch mean rather than only the final batch's loss; the
        # last batch is noisy (and may be smaller), and an empty loader would
        # otherwise leave `loss` undefined.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print("mean loss:", mean_loss, "\n")

    torch.save(model.state_dict(), save_path)


if __name__ == "__main__":
    main()