"""Run a trained MnistNet checkpoint over every image in ./input and print the predicted classes."""

import os

import torch
from PIL import Image
from torchvision import transforms

from mnist_net import MnistNet

# find device: prefer CUDA, then Apple MPS, then fall back to CPU.
use_cuda = torch.cuda.is_available()
# torch.backends.mps is missing on older PyTorch builds; guard so those
# installs fall back to CPU instead of raising AttributeError.
use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Preprocessing applied to every input image: PIL image -> float tensor,
# then normalized with a single-channel mean of 0.1307 and std of 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)


def predict(device, model, img):
    """Return the index of the highest-scoring model output for one image.

    Args:
        device: torch device the model lives on; the input tensor is moved there.
        model: network called on a batch of one; its output is reduced with
            ``argmax(dim=1)``, so it is assumed to be (batch, classes).
        img: image accepted by ``transform`` (e.g. a PIL image).

    Returns:
        int: predicted class index.
    """
    # transform image and add a leading batch dimension of size 1
    batch = transform(img).unsqueeze(0).to(device)
    # Inference only: don't build an autograd graph or retain activations.
    with torch.no_grad():
        output = model(batch)
    return output.argmax(dim=1).item()


if __name__ == '__main__':
    # load model weights; map_location lets a checkpoint saved on another
    # device (e.g. CUDA) be loaded onto the selected one.
    model = MnistNet()
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On PyTorch >= 1.13 consider weights_only=True, since only
    # a state dict is needed here.
    model.load_state_dict(torch.load('mnist_cnn.pt', map_location=device))
    model.to(device)
    # set model to eval mode (affects layers such as dropout/batch-norm, if present)
    model.eval()

    folder = './input'
    predictions = []
    # sorted() gives a deterministic order; os.listdir order is arbitrary.
    for image_name in sorted(os.listdir(folder)):
        path = os.path.join(folder, image_name)
        if not os.path.isfile(path):
            continue  # skip subdirectories; Image.open cannot read them
        # Context manager closes the underlying file once we are done with it.
        with Image.open(path) as img:
            # The normalization above is single-channel, so the model presumably
            # expects grayscale input (TODO confirm against MnistNet). convert("L")
            # is a no-op for grayscale files and maps RGB/RGBA/palette to one channel.
            predictions.append(predict(device, model, img.convert("L")))
    print(predictions)