import os

import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

import captcha_settings
import one_hot_encoding
from cnn_net import ConvNet

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

# Built once at import time rather than on every predict() call.
_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale()
])


def predict(model, file_path):
    """Predict the captcha text contained in a single image file.

    Args:
        model: network whose output is decoded by ``one_hot_encoding.decode``;
            expected to already live on ``device`` and be in eval mode.
        file_path: path to the captcha image.

    Returns:
        The result of ``one_hot_encoding.decode`` on the model output
        (compared against a ``str`` label in ``main``).
    """
    # Context manager closes the file handle instead of leaking one per image.
    with Image.open(file_path) as img:
        # Grayscale only accepts 1- or 3-channel tensors; normalize RGBA /
        # palette / other modes to RGB. "L" and "RGB" pass through unchanged.
        if img.mode not in ("L", "RGB"):
            img = img.convert("RGB")
        # Transform inside the `with`: PIL loads pixel data lazily.
        X = _TRANSFORM(img)

    with torch.no_grad():
        # (C, H, W) -> (1, C, H, W): add the batch dimension the model expects.
        X = X.unsqueeze(0).to(device)
        pred = model(X)
        text = one_hot_encoding.decode(pred)
    return text


def main():
    """Evaluate ./model.pth on every file in TEST_DATASET_PATH and print accuracy.

    The ground-truth label is taken from the filename: the text before the
    first '.', then the part after the last '_' (e.g. ``123_AB12.png`` ->
    ``AB12``).
    """
    model = ConvNet().to(device)
    # map_location lets a checkpoint saved on one device (e.g. CUDA) load on
    # whichever device was selected above (mps / cpu).
    model.load_state_dict(torch.load("./model.pth", map_location=device, weights_only=True))
    model.eval()

    dataset_dir = captcha_settings.TEST_DATASET_PATH
    # List the directory once so `total` always matches the files iterated.
    # Skip hidden entries (e.g. macOS .DS_Store) and subdirectories, which
    # Image.open cannot read.
    filenames = [
        name for name in os.listdir(dataset_dir)
        if not name.startswith('.') and os.path.isfile(os.path.join(dataset_dir, name))
    ]
    total = len(filenames)
    if total == 0:
        # Avoid ZeroDivisionError when computing accuracy on an empty set.
        print(f"No test files found in {dataset_dir}")
        return

    correct = 0
    for filename in tqdm(filenames):
        file_path = os.path.join(dataset_dir, filename)
        real_captcha = filename.split('.')[0].split('_')[-1]
        pred_captcha = predict(model, file_path)
        if pred_captcha == real_captcha:
            correct += 1

    accuracy = f"Test {total} files, accuracy: {correct / total * 100:.2f}%"
    print(accuracy)


if __name__ == '__main__':
    main()