"""Captcha image dataset and ``DataLoader`` factories.

The label of every image is derived from its file name, see
``CaptchaDataset.__getitem__``.
"""
import os

from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import one_hot_encoding as ohe
from captcha_settings import TRAIN_DATASET_PATH, TEST_DATASET_PATH, PREDICT_DATASET_PATH


class CaptchaDataset(Dataset):
    """Dataset over the image files stored directly inside one directory."""

    def __init__(self, dir, transform=None):
        """Collect the paths of all regular files found in ``dir``.

        Args:
            dir: Directory to scan (not recursive). The name shadows the
                builtin but is kept because callers may pass it by keyword.
            transform: Optional callable applied to every opened image.
        """
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs and machines.
        candidates = (os.path.join(dir, entry) for entry in sorted(os.listdir(dir)))
        # Sub-directories cannot be opened as images, so only keep files.
        self.train_images = [path for path in candidates if os.path.isfile(path)]
        self.transform = transform

    def __len__(self):
        """Return the number of files collected by ``__init__``."""
        return len(self.train_images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``.

        The image is opened with PIL and passed through ``self.transform``
        when one was given; no conversion happens here otherwise. The label
        text is the part of the file name before the first ``.`` and after
        the last ``_`` (``"0001_ab3d.png"`` -> ``"ab3d"``), encoded with
        ``ohe.encode``.
        """
        image_root = self.train_images[idx]
        image_name = os.path.basename(image_root)
        image = Image.open(image_root)
        if self.transform is not None:
            image = self.transform(image)
        label_text = image_name.split('.')[0].split('_')[-1]
        label = ohe.encode(label_text)
        return image, label


# Default preprocessing shared by all loaders below: PIL image -> tensor,
# then grayscale conversion on the tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(),
])


def _make_loader(path, batch_size):
    """Build a shuffling ``DataLoader`` over the images found in ``path``."""
    dataset = CaptchaDataset(path, transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def get_train_loader(batch_size=60):
    """Return a shuffled ``DataLoader`` over ``TRAIN_DATASET_PATH``."""
    return _make_loader(TRAIN_DATASET_PATH, batch_size)


def get_test_loader(batch_size=60):
    """Return a shuffled ``DataLoader`` over ``TEST_DATASET_PATH``."""
    # NOTE(review): shuffling is kept to preserve existing behaviour;
    # confirm whether evaluation should use a fixed order instead.
    return _make_loader(TEST_DATASET_PATH, batch_size)


def get_predict_loader(batch_size=60):
    """Return a shuffled ``DataLoader`` over ``PREDICT_DATASET_PATH``."""
    # NOTE(review): shuffling is kept to preserve existing behaviour;
    # confirm whether prediction should use a fixed order instead.
    return _make_loader(PREDICT_DATASET_PATH, batch_size)


def main():
    """Iterate the training loader once and print each batch's shapes."""
    train_loader = get_train_loader()
    for image, label in train_loader:
        print(image.shape, label.shape)


if __name__ == '__main__':
    main()