# 45 lines, 1.3 KiB, Python
import os

import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

import captcha_settings
import one_hot_encoding
from cnn_net import ConvNet


# Choose the compute backend in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")


def predict(model, file_path):
    """Predict the CAPTCHA text shown in a single image file.

    Args:
        model: Network called as ``model(X)`` on a one-image batch ``X`` of
            shape ``(1, 1, H, W)``. ``main`` passes a ``ConvNet`` that is
            already on ``device`` and in eval mode.
        file_path: Path of the image to read with PIL.

    Returns:
        The result of ``one_hot_encoding.decode`` on the model output
        (``main`` compares it against the ``str`` label taken from the
        file name).
    """
    trans = transforms.Compose([
        transforms.ToTensor(),
        # Collapse to one channel: the network takes (N, 1, H, W) input.
        transforms.Grayscale(),
    ])

    with torch.no_grad():
        # Context manager: release the file handle deterministically rather
        # than relying on PIL / garbage collection, since main() calls this
        # once per file in the test directory.
        with Image.open(file_path) as img:
            # Grayscale() only accepts 1- or 3-channel tensors, so RGBA, CMYK,
            # palette, etc. images are converted to RGB first. "L" and "RGB"
            # images are left untouched so their preprocessing is unchanged.
            # NOTE(review): confirm the training pipeline sees the same modes.
            if img.mode not in ("L", "RGB"):
                img = img.convert("RGB")
            X = trans(img)
        # Add the batch dimension: (1, H, W) -> (1, 1, H, W). unsqueeze keeps
        # the image size from being hard-coded here (previously 60x160);
        # the model decides which size it accepts.
        X = X.unsqueeze(0).to(device)
        pred = model(X)
        text = one_hot_encoding.decode(pred)

    return text


def main():
    """Evaluate the saved model on every image in the test dataset directory.

    Loads weights from ``./model.pth``, predicts each regular, non-hidden file
    in ``captcha_settings.TEST_DATASET_PATH`` and prints the exact-match
    accuracy. The ground-truth label comes from the file name: the text after
    the last ``_`` in the part before the first ``.`` (e.g. ``0001_AB12.png``
    -> ``AB12``).
    """
    model = ConvNet().to(device)
    # map_location lets weights saved on one device (e.g. CUDA) load on a
    # machine that only has MPS/CPU; load_state_dict then copies them into
    # the model's parameters, which already live on `device`.
    state_dict = torch.load("./model.pth", map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    test_dir = captcha_settings.TEST_DATASET_PATH
    # List the directory once so the count and the loop see the same files,
    # and skip subdirectories and hidden entries (e.g. macOS .DS_Store) that
    # PIL cannot open as images.
    with os.scandir(test_dir) as entries:
        filenames = sorted(
            entry.name
            for entry in entries
            if entry.is_file() and not entry.name.startswith(".")
        )

    total = len(filenames)
    if total == 0:
        # Avoid ZeroDivisionError in the accuracy computation below.
        print(f"No test files found in {test_dir}")
        return

    correct = 0
    for filename in tqdm(filenames):
        file_path = os.path.join(test_dir, filename)
        real_captcha = filename.split('.')[0].split('_')[-1]
        pred_captcha = predict(model, file_path)
        if pred_captcha == real_captcha:
            correct += 1

    accuracy = f"Test {total} files, accuracy: {correct / total * 100:.2f}%"
    print(accuracy)


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()