# CNN training script (trains ConvNet and saves model weights).
|
# Third-party dependencies.
import torch
import torch.nn as nn
from tqdm import tqdm

# Project-local modules.
from cnn_net import ConvNet
import dataset
|
||
|
|
||
|
# Training hyperparameters.
num_epochs = 10        # full passes over the training set
learning_rate = 0.001  # Adam step size

# Select the best available accelerator, falling back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
|
||
|
|
||
|
def main(save_path="./model.pth"):
    """Train ConvNet on the project's training set and save its weights.

    Args:
        save_path: Destination for the trained ``state_dict`` (default
            preserves the original hard-coded ``./model.pth``).
    """
    model = ConvNet().to(device)
    model.train()  # enable training-mode behavior (dropout, batch norm)

    # Multi-label objective: each output unit is an independent binary target.
    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    train_dataloader = dataset.get_train_loader()

    for epoch in range(num_epochs):
        print("Epoch:", epoch + 1)
        # Iterate the loader directly — the original enumerate() index was unused.
        pbar = tqdm(train_dataloader, total=len(train_dataloader))
        last_loss = None  # guards the epoch-end print if the loader is empty
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            predict_labels = model(images)
            loss = criterion(predict_labels, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            last_loss = loss.item()
            pbar.set_description("loss: %.4f" % last_loss)

        # Original printed loss.item() after the loop, which would raise
        # NameError on an empty dataloader; print only when a batch ran.
        if last_loss is not None:
            print("loss:", last_loss, '\n')

    # NOTE(review): indentation was lost in extraction; the save is assumed to
    # run once after all epochs, matching the conventional script layout.
    torch.save(model.state_dict(), save_path)
|
||
|
|
||
|
# Entry point: run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|