From 665bc28b366454575aad3affc5ef70a5f7746c6a Mon Sep 17 00:00:00 2001 From: TaurusXin Date: Mon, 30 Sep 2024 11:10:57 +0800 Subject: [PATCH] fix: wrong path --- predict.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/predict.py b/predict.py index f9e287e..42e3ae9 100644 --- a/predict.py +++ b/predict.py @@ -29,8 +29,9 @@ def main(): model.eval() # random pickup some test images - pickup_count = 30 - pickup_rect = [5, 6] + pickup_count = int(input("输入验证集的图片长度(默认30):") or 30) + input_rect = input("输入图片的行数和列数(默认5x6):") + pickup_rect = [int(i) for i in input_rect.split("x")] if input_rect else [5, 6] files = os.listdir(captcha_settings.PREDICT_DATASET_PATH) images_picked = random.sample(files, pickup_count) @@ -38,7 +39,7 @@ def main(): fig, axes = plt.subplots(nrows=pickup_rect[0], ncols=pickup_rect[1], figsize=(10, 8)) for i, image_name in enumerate(images_picked): real_text = image_name.split(".")[0].split("_")[-1] - file_path = os.path.join(captcha_settings.TEST_DATASET_PATH, image_name) + file_path = os.path.join(captcha_settings.PREDICT_DATASET_PATH, image_name) pred_text = predict(model, file_path) correct = real_text == pred_text axes[i//pickup_rect[1], i%pickup_rect[1]].imshow(plt.imread(file_path))