"""Visual spot-check of a trained captcha ConvNet.

Loads weights from ``./model.pth``, samples images from
``captcha_settings.PREDICT_DATASET_PATH``, runs the model on each one and
shows them in a matplotlib grid.  Every cell is titled with the predicted
text and ``yes``/``no`` depending on whether it equals the label taken from
the file name.
"""
import math
import os
import random

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from cnn_net import ConvNet
import captcha_settings
import one_hot_encoding

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")


def predict(model, file_path):
    """Return the decoded prediction of *model* for one image file.

    Args:
        model: Callable network; it is invoked as ``model(X)`` with a tensor
            reshaped to ``(1, 1, 60, 160)`` and moved to ``device``.
        file_path: Path of the image to classify.

    Returns:
        The result of ``one_hot_encoding.decode`` applied to the model
        output (the caller compares it with the label string).

    Raises:
        RuntimeError: If the transformed image does not contain exactly
            ``60 * 160`` values, because the reshape below cannot succeed.
    """
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Grayscale(),
    ])
    with torch.no_grad():
        # Close the file handle deterministically instead of leaving it to
        # the garbage collector; the tensor is fully built inside the block.
        with Image.open(file_path) as image:
            X = trans(image).reshape(1, 1, 60, 160)
        X = X.to(device)
        pred = model(X)
    return one_hot_encoding.decode(pred)


def _parse_grid(spec, default=(5, 6)):
    """Parse a ``"<rows>x<cols>"`` string into ``(rows, cols)``.

    Empty input yields *default*.  The separator is case-insensitive and
    surrounding whitespace is ignored.

    Raises:
        ValueError: If *spec* is not two positive integers joined by ``x``.
    """
    if not spec or not spec.strip():
        return default
    parts = spec.strip().lower().split("x")
    try:
        rows, cols = (int(part) for part in parts)
    except ValueError:
        raise ValueError(
            f"grid must look like '5x6' (rows x cols), got {spec!r}"
        ) from None
    if rows < 1 or cols < 1:
        raise ValueError(f"grid rows and cols must be positive, got {spec!r}")
    return rows, cols


def _list_image_files(directory):
    """Return the names of regular, non-hidden files inside *directory*."""
    return [
        name
        for name in os.listdir(directory)
        # Skip sub-directories and dot-files such as ``.DS_Store`` which
        # cannot be opened as images.
        if not name.startswith(".")
        and os.path.isfile(os.path.join(directory, name))
    ]


def main():
    """Interactively sample test images and display the model's predictions.

    Prompts for the number of images and the grid layout, then opens a
    matplotlib window with one cell per sampled image.

    Raises:
        FileNotFoundError: If the prediction dataset directory holds no files.
        ValueError: If the entered count or grid specification is invalid.
    """
    model = ConvNet().to(device)
    # map_location lets a checkpoint saved on another device (e.g. CUDA)
    # load on whatever device was selected above.
    state_dict = torch.load("./model.pth", map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Ask how many images to sample and how to lay them out.
    pickup_count = int(input("输入验证集的图片长度(默认30):") or 30)
    n_rows, n_cols = _parse_grid(input("输入图片的行数和列数(默认5x6):"))

    dataset_dir = captcha_settings.PREDICT_DATASET_PATH
    files = _list_image_files(dataset_dir)
    if not files:
        raise FileNotFoundError(f"no image files found in {dataset_dir!r}")

    # random.sample refuses to draw more items than exist, so cap the count.
    pickup_count = min(pickup_count, len(files))
    images_picked = random.sample(files, pickup_count)

    # Grow the grid when the requested layout has too few cells.
    n_rows = max(n_rows, math.ceil(pickup_count / n_cols))

    # squeeze=False keeps ``axes`` two-dimensional even for a single row or
    # column, so it can always be walked the same way.
    _, axes = plt.subplots(
        nrows=n_rows, ncols=n_cols, figsize=(10, 8), squeeze=False
    )
    cells = list(axes.flat)  # row-major order
    for ax in cells:
        # Hide every frame up front so unused cells stay blank.
        ax.axis("off")

    for ax, image_name in zip(cells, images_picked):
        # The label is the last "_"-separated token before the first ".".
        real_text = image_name.split(".")[0].split("_")[-1]
        file_path = os.path.join(dataset_dir, image_name)
        pred_text = predict(model, file_path)
        correct = real_text == pred_text
        ax.imshow(plt.imread(file_path))
        ax.set_title(f"{pred_text}, {'yes' if correct else 'no'}")

    plt.show()


if __name__ == "__main__":
    main()