captcha/predict.py

52 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from cnn_net import ConvNet
import os
import random
import captcha_settings
import one_hot_encoding
# Pick the compute backend in order of preference: CUDA first, then the
# MPS backend, falling back to the CPU when neither is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")
def predict(model, file_path):
    """Run the captcha model on a single image file and decode its output.

    Args:
        model: Network invoked as ``model(X)`` with a ``(1, 1, 60, 160)``
            tensor that has been moved to the module-level ``device``.
        file_path: Path of the image file, opened with ``PIL.Image.open``.

    Returns:
        The value produced by ``one_hot_encoding.decode`` for the model
        output (``main`` compares it with the label taken from the file
        name, so it is presumably a string -- confirm against
        ``one_hot_encoding``).

    Raises:
        RuntimeError: If the converted image does not contain exactly
            ``1 * 60 * 160`` values, so the reshape below cannot succeed.
    """
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Grayscale(),
    ])
    with torch.no_grad():
        # Pillow opens files lazily and keeps the handle around; the context
        # manager releases it deterministically once the pixel data has been
        # converted to a tensor.
        with Image.open(file_path) as image:
            # Add the batch dimension; the fixed 60x160 size also acts as a
            # sanity check on the input image.
            X = trans(image).reshape(1, 1, 60, 160)
        X = X.to(device)
        pred = model(X)
        return one_hot_encoding.decode(pred)
def _parse_grid(text, default=(5, 6)):
    """Parse a ``ROWSxCOLS`` string (e.g. ``"5x6"``) into ``(rows, cols)``.

    Blank input yields ``default``. Raises ``ValueError`` when the text does
    not contain at least two positive integers separated by ``x``.
    """
    text = text.strip()
    if not text:
        return default
    parts = [int(part) for part in text.lower().split("x")]
    if len(parts) < 2 or parts[0] < 1 or parts[1] < 1:
        raise ValueError(f"expected ROWSxCOLS with positive integers, got {text!r}")
    return parts[0], parts[1]


def main():
    """Show a grid of randomly chosen test captchas with their predictions.

    Loads the weights from ``./model.pth``, asks on stdin how many images to
    sample and which grid shape to use, then plots every sampled image with
    the predicted text and a yes/no marker telling whether it matches the
    label embedded in the file name (the part after the last ``_`` and
    before the first ``.``).

    Raises:
        ValueError: If the count or grid input cannot be parsed, or the
            requested count is negative.
    """
    model = ConvNet().to(device)
    # map_location lets a checkpoint saved on another device (e.g. CUDA)
    # load on whatever device was selected for this run.
    state = torch.load("./model.pth", map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.eval()

    # Randomly pick some test images.
    pickup_count = int(input("输入验证集的图片长度默认30") or 30)
    rows, cols = _parse_grid(input("输入图片的行数和列数默认5x6"))

    files = os.listdir(captcha_settings.PREDICT_DATASET_PATH)
    if pickup_count > len(files):
        # random.sample refuses to draw more items than the population holds.
        print(f"Only {len(files)} images available; showing all of them")
        pickup_count = len(files)
    images_picked = random.sample(files, pickup_count)

    # Grow the grid when it is too small to hold every sampled image.
    rows = max(rows, (pickup_count + cols - 1) // cols)

    # squeeze=False keeps `axes` two-dimensional even for a single row/column.
    _fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(10, 8), squeeze=False)
    cells = list(axes.flat)  # row-major, i.e. the same order as (i // cols, i % cols)

    # Show each image with its predicted text and whether it is correct.
    for ax, image_name in zip(cells, images_picked):
        real_text = image_name.split(".")[0].split("_")[-1]
        file_path = os.path.join(captcha_settings.PREDICT_DATASET_PATH, image_name)
        pred_text = predict(model, file_path)
        correct = real_text == pred_text
        ax.imshow(plt.imread(file_path))
        ax.set_title(f"{pred_text}, {'yes' if correct else 'no'}")
        ax.axis('off')

    # Hide the frames of grid cells that did not receive an image.
    for ax in cells[len(images_picked):]:
        ax.axis('off')

    plt.show()


if __name__ == "__main__":
    main()