56 lines
1.2 KiB
Python
56 lines
1.2 KiB
Python
|
import torch
|
||
|
from mnist_net import MnistNet
|
||
|
from torchvision import transforms
|
||
|
from PIL import Image
|
||
|
|
||
|
import os
|
||
|
|
||
|
# Pick the compute device: CUDA if available, then Apple MPS, otherwise CPU.
use_cuda = torch.cuda.is_available()
# The `torch.backends.mps` module is absent from older PyTorch releases; guard
# the attribute so those installs fall back to CPU instead of raising
# AttributeError at import time.
use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")
# Preprocessing pipeline applied to every image before prediction (this only
# defines the transform; no image is loaded here). ToTensor converts a PIL
# image into a (C, H, W) float tensor (scaled to [0, 1] for 8-bit images), and
# Normalize then standardizes it with the single-channel mean/std below.
# NOTE(review): these are the commonly quoted MNIST training-set statistics;
# presumably they match how mnist_cnn.pt was trained — confirm.
_MNIST_MEAN = (0.1307,)
_MNIST_STD = (0.3081,)

_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(_MNIST_MEAN, _MNIST_STD),
]
transform = transforms.Compose(_preprocessing_steps)
def predict(device, model, img):
    """Classify a single image and return the predicted class index.

    Args:
        device: torch.device to run on; the input tensor is moved there
            before the forward pass (the model is expected to live there too).
        model: classifier whose output holds class scores along dim 1. The
            caller is expected to have put it in eval mode already.
        img: PIL image to classify.

    Returns:
        The index of the highest-scoring class, as a Python int.
    """
    # NOTE(review): assumes the model takes single-channel input, as the
    # single-channel Normalize stats in `transform` suggest — confirm.
    # Without this, RGB/RGBA/palette images would reach the network with the
    # wrong channel count. Grayscale ("L") images pass through unchanged.
    if img.mode != "L":
        img = img.convert("L")

    # Preprocess, then add a leading batch dimension of size 1.
    batch = transform(img).unsqueeze(0).to(device)

    # Inference only: disable autograd so no graph is built for backprop.
    with torch.no_grad():
        output = model(batch)

    # Batch size is 1, so argmax over the class dim yields a single element.
    return output.argmax(dim=1).item()
if __name__ == '__main__':
    # Build the network and load the trained weights onto the chosen device.
    model = MnistNet()
    # NOTE(review): torch.load unpickles the checkpoint, so only load trusted
    # files; on PyTorch versions that support it, weights_only=True restricts
    # loading to plain tensors.
    model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
    model.to(device)

    # Evaluation mode changes the behavior of layers such as dropout and batch
    # norm, if the model contains any.
    model.eval()

    # Classify every image file in the input folder, in sorted filename order.
    folder = './input'
    predictions = []

    for image_name in sorted(os.listdir(folder)):
        path = os.path.join(folder, image_name)
        # Skip hidden files (e.g. macOS .DS_Store) and subdirectories: Image.open
        # cannot read them and would abort the whole run.
        if image_name.startswith('.') or not os.path.isfile(path):
            continue
        # The context manager closes the file handle Image.open keeps open.
        with Image.open(path) as img:
            predictions.append(predict(device, model, img))

    print(predictions)