mnist-learning/mnist_test.py

56 lines
1.2 KiB
Python
Raw Normal View History

2023-12-29 19:53:46 +08:00
import torch
from mnist_net import MnistNet
from torchvision import transforms
from PIL import Image
import os
# find device
device = None
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_available()
if use_cuda:
device = torch.device("cuda")
elif use_mps:
device = torch.device("mps")
else:
device = torch.device("cpu")
print(f"Using device: {device}")
# Preprocessing pipeline applied to every image before it reaches the model:
#   1. ToTensor converts the PIL image to a float tensor (for 8-bit images the
#      pixel values end up scaled into [0, 1]).
#   2. Normalize subtracts a mean and divides by a std, each given as a single
#      value, i.e. single-channel statistics. These look like the commonly used
#      MNIST training-set mean/std — confirm they match the training setup
#      that produced the checkpoint.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
def predict(device, model, img):
    """Classify a single image and return the index of the top-scoring output.

    Args:
        device: torch device the model lives on; the input tensor is moved here.
        model: network to run. This function does not call ``model.eval()``;
            the caller is responsible for putting the model in eval mode.
        img: an image accepted by the module-level ``transform`` (a PIL image
            when called from the script below).

    Returns:
        int: the argmax over dim 1 of the model output for this one sample.
        NOTE(review): assumes the model returns (batch, num_classes) scores —
        confirm against MnistNet.
    """
    # Preprocess, move to the model's device, and add a leading batch
    # dimension of size 1.
    batch = transform(img).to(device).unsqueeze(0)
    # Pure inference: disable autograd so no graph is built and intermediate
    # activations are not kept around for a backward pass that never happens.
    with torch.no_grad():
        output = model(batch)
    # Highest-scoring class of the single sample, as a Python int.
    return output.argmax(dim=1).item()
if __name__ == '__main__':
    # Build the network and load the trained weights onto the selected device.
    model = MnistNet()
    # NOTE(review): torch.load unpickles the checkpoint, so only load
    # 'mnist_cnn.pt' files from a trusted source. Since only a state_dict is
    # needed, weights_only=True (torch >= 1.13) would be a safer choice.
    model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
    model.to(device)
    # Switch to inference behavior (affects layers such as dropout/batch-norm,
    # if MnistNet has any).
    model.eval()

    # Classify every regular file in ./input, in sorted filename order, and
    # print the list of predicted class indices.
    folder = './input'
    predictions = []
    for image_name in sorted(os.listdir(folder)):
        path = os.path.join(folder, image_name)
        # Subdirectories and other non-file entries cannot be opened as images.
        if not os.path.isfile(path):
            continue
        # Image.open keeps the underlying file open until the image is closed;
        # the context manager releases the handle once prediction is done.
        with Image.open(path) as img:
            predictions.append(predict(device, model, img))
    print(predictions)