52 lines
1.9 KiB
Python
52 lines
1.9 KiB
Python
import torch
|
||
from torchvision import transforms
|
||
from PIL import Image
|
||
import matplotlib.pyplot as plt
|
||
from cnn_net import ConvNet
|
||
import os
|
||
import random
|
||
import captcha_settings
|
||
import one_hot_encoding
|
||
|
||
# Pick the compute backend in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")
|
||
|
||
def predict(model, file_path):
    """Run the model on a single image file and return the decoded text.

    Args:
        model: Network that is called as ``model(X)`` on a batch holding one
            single-channel image tensor.
        file_path: Path of the image file to classify.

    Returns:
        The value produced by ``one_hot_encoding.decode`` for the model
        output.
    """
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Grayscale(),
    ])
    # Context manager: release the file handle as soon as the pixels are read
    # (the bare Image.open() call used before kept it open).
    with Image.open(file_path) as source:
        image = source
        # Grayscale on a tensor only accepts 1- or 3-channel input, so images
        # with another band count (e.g. RGBA, LA, CMYK) are converted to RGB.
        if len(source.getbands()) not in (1, 3):
            image = source.convert("RGB")
        X = trans(image)

    with torch.no_grad():
        # (1, H, W) -> (1, 1, H, W): add the batch axis without hard-coding
        # the image size; identical to reshape(1, 1, 60, 160) for 60x160 input.
        X = X.unsqueeze(0).to(device)
        pred = model(X)
    text = one_hot_encoding.decode(pred)
    return text
|
||
|
||
def main():
    """Load the trained model, classify random dataset images and plot them.

    Prompts for how many images to sample and for the grid layout
    (``ROWSxCOLS``). Each sampled image is shown in one grid cell, titled with
    the predicted text and ``yes``/``no`` depending on whether the prediction
    equals the label taken from the file name (the part after the last ``_``
    and before the first ``.``).

    Raises:
        ValueError: If a prompt answer cannot be parsed, or the dataset
            directory contains no usable files.
    """
    model = ConvNet().to(device)
    # map_location lets a checkpoint that was saved on a different device
    # (for example a GPU) be loaded on the device selected for this run.
    state_dict = torch.load("./model.pth", map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Ask how many images to sample and how to lay out the grid.
    pickup_count = int(input("输入验证集的图片长度(默认30):") or 30)
    input_rect = input("输入图片的行数和列数(默认5x6):")
    if input_rect:
        rows, cols = (int(part) for part in input_rect.lower().split("x"))
    else:
        rows, cols = 5, 6

    # Only regular, non-hidden files are candidates; os.listdir also returns
    # sub-directories and entries such as ".DS_Store".
    dataset_dir = captcha_settings.PREDICT_DATASET_PATH
    files = [
        name
        for name in os.listdir(dataset_dir)
        if not name.startswith(".")
        and os.path.isfile(os.path.join(dataset_dir, name))
    ]
    if not files:
        raise ValueError(f"no image files found in {dataset_dir!r}")

    # Never sample more images than exist or than the grid has cells for.
    shown = min(pickup_count, len(files), rows * cols)
    images_picked = random.sample(files, shown)

    # squeeze=False keeps `axes` two-dimensional even for 1xN, Nx1 or 1x1
    # grids, so the cells can always be walked in row-major order.
    _, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(10, 8), squeeze=False)
    cells = axes.ravel()
    for ax, image_name in zip(cells, images_picked):
        real_text = image_name.split(".")[0].split("_")[-1]
        file_path = os.path.join(dataset_dir, image_name)
        pred_text = predict(model, file_path)
        correct = real_text == pred_text
        ax.imshow(plt.imread(file_path))
        ax.set_title(f"{pred_text}, {'yes' if correct else 'no'}")
        ax.axis("off")
    # Hide the frames of grid cells that received no image.
    for ax in cells[len(images_picked):]:
        ax.axis("off")
    plt.show()
|
||
|
||
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|