# captcha/test.py
#
# 45 lines
# 1.3 KiB
# Python

import os
import torch
from PIL import Image
from cnn_net import ConvNet
import one_hot_encoding
from torchvision import transforms
import captcha_settings
from tqdm import tqdm
# Pick the first available accelerator: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")
def predict(model, file_path):
    """Predict the captcha text for a single image file.

    The image is converted to a tensor, reduced to one channel with
    ``Grayscale``, reshaped to a batch of one and run through ``model``
    on ``device`` without gradient tracking.

    Args:
        model: network called as ``model(X)`` on a ``(1, 1, 60, 160)`` tensor.
        file_path: path of the image file to read.

    Returns:
        The result of ``one_hot_encoding.decode`` on the model output
        (presumably the captcha string; ``main`` compares it to a str label).
    """
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Grayscale(),
    ])
    with torch.no_grad():
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it once the pixels are converted to a tensor,
        # avoiding a handle leak when predicting over a whole directory.
        with Image.open(file_path) as img:
            X = trans(img)
        # Hard-coded shape assumes 60x160 single-channel input — TODO confirm
        # it matches the training image size in captcha_settings.
        X = X.reshape(1, 1, 60, 160).to(device)
        pred = model(X)
        text = one_hot_encoding.decode(pred)
    return text
def main():
    """Evaluate the trained model on every file in the test dataset directory.

    Loads weights from ``./model.pth`` and predicts each file in
    ``captcha_settings.TEST_DATASET_PATH``. Each prediction is compared with
    the label taken from the filename: the text before the first ``.``, after
    its last ``_``. The accuracy is then printed. If the directory is empty,
    a message is printed instead of dividing by zero.
    """
    model = ConvNet().to(device)
    # map_location lets weights saved on one device (e.g. CUDA) be loaded on
    # whichever device was selected here (CPU / MPS); without it loading
    # CUDA-saved tensors on a machine without CUDA fails.
    state_dict = torch.load("./model.pth", map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    test_dir = captcha_settings.TEST_DATASET_PATH
    # List the directory once so the reported total matches the files iterated.
    filenames = os.listdir(test_dir)
    total = len(filenames)
    if total == 0:
        print(f"No test files found in {test_dir}")
        return

    correct = 0
    for filename in tqdm(filenames):
        file_path = os.path.join(test_dir, filename)
        real_captcha = filename.split('.')[0].split('_')[-1]
        pred_captcha = predict(model, file_path)
        if pred_captcha == real_captcha:
            correct += 1
    accuracy = f"Test {total} files, accuracy: {correct / total * 100:.2f}%"
    print(accuracy)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()