# captcha/dataset.py
#
# 50 lines
# 1.7 KiB
# Python
# Raw Normal View History
#
# 2024-09-30 11:05:35 +08:00
import os
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import one_hot_encoding as ohe
from captcha_settings import TRAIN_DATASET_PATH, TEST_DATASET_PATH, PREDICT_DATASET_PATH
class CaptchaDataset(Dataset):
    """Dataset of CAPTCHA images whose labels are encoded in their file names.

    Every regular file directly inside the given directory is one sample.
    The label text is the part of the file name before the first ``.`` and
    after the last ``_`` (e.g. ``0001_AB12.png`` -> ``AB12``); it is passed
    through ``ohe.encode`` before being returned.
    """

    def __init__(self, dir, transform=None):
        """Index the files found directly inside ``dir``.

        Args:
            dir: Directory holding the image files.
            transform: Optional callable applied to each opened PIL image.
        """
        # Keep regular files only (a sub-directory cannot be opened as an
        # image) and sort the paths so that index -> sample is the same on
        # every run; the order reported by the OS is arbitrary.
        with os.scandir(dir) as entries:
            self.train_images = sorted(
                entry.path for entry in entries if entry.is_file()
            )
        self.transform = transform

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.train_images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``.

        ``image`` is the opened PIL image, passed through ``self.transform``
        when one was supplied. ``label`` is ``ohe.encode`` applied to the
        label text taken from the file name.
        """
        image_path = self.train_images[idx]
        image_name = os.path.basename(image_path)

        # No colour conversion happens here; any grayscale step is the
        # responsibility of the transform.
        image = Image.open(image_path)
        if self.transform is not None:
            image = self.transform(image)

        # Label text = name up to the first '.', then the piece after the
        # last '_'.
        label_text = image_name.split('.')[0].split('_')[-1]
        label = ohe.encode(label_text)
        return image, label
# Preprocessing handed to every dataset built in this module: turn the PIL
# image into a tensor first, then reduce it to grayscale.
transform = transforms.Compose([transforms.ToTensor(), transforms.Grayscale()])
def get_train_loader(batch_size=60):
    """Build a shuffling DataLoader over the images in TRAIN_DATASET_PATH.

    Args:
        batch_size: Number of samples per batch (default 60).

    Returns:
        A ``DataLoader`` wrapping a ``CaptchaDataset`` that uses the
        module-level ``transform``.
    """
    return DataLoader(
        CaptchaDataset(TRAIN_DATASET_PATH, transform),
        batch_size=batch_size,
        shuffle=True,
    )
def get_test_loader(batch_size=60):
    """Build a shuffling DataLoader over the images in TEST_DATASET_PATH.

    Args:
        batch_size: Number of samples per batch (default 60).

    Returns:
        A ``DataLoader`` wrapping a ``CaptchaDataset`` that uses the
        module-level ``transform``.
    """
    # NOTE(review): shuffling is enabled here just as for training; confirm
    # that a shuffled order is intended for evaluation data.
    return DataLoader(
        CaptchaDataset(TEST_DATASET_PATH, transform),
        batch_size=batch_size,
        shuffle=True,
    )
def get_predict_loader(batch_size=60):
    """Build a shuffling DataLoader over the images in PREDICT_DATASET_PATH.

    Args:
        batch_size: Number of samples per batch (default 60).

    Returns:
        A ``DataLoader`` wrapping a ``CaptchaDataset`` that uses the
        module-level ``transform``.
    """
    # NOTE(review): shuffling is enabled here just as for training; confirm
    # that a shuffled order is intended for prediction data.
    return DataLoader(
        CaptchaDataset(PREDICT_DATASET_PATH, transform),
        batch_size=batch_size,
        shuffle=True,
    )
def main():
    """Smoke-test the training loader by printing the shape of every batch."""
    # The batch index was never used, so iterate the loader directly.
    for image, label in get_train_loader():
        print(image.shape, label.shape)


# Run the smoke test only when this file is executed as a script.
if __name__ == '__main__':
    main()