# Listing metadata: 50 lines, 1.7 KiB, Python.
import os
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import torchvision.transforms as transforms
|
|
from PIL import Image
|
|
import one_hot_encoding as ohe
|
|
from captcha_settings import TRAIN_DATASET_PATH, TEST_DATASET_PATH, PREDICT_DATASET_PATH
|
|
|
|
class CaptchaDataset(Dataset):
    """Dataset of CAPTCHA images whose labels are encoded in their file names.

    Every regular file directly inside the given directory is one sample.
    The label text is the part of the file name before the first ``.`` and
    after the last ``_`` (for example ``0001_AB12.png`` -> ``AB12``); it is
    converted with ``one_hot_encoding.encode``.
    """

    def __init__(self, dir, transform=None):
        """Collect the sample paths found in ``dir``.

        Args:
            dir: Directory holding the image files. The name shadows the
                ``dir`` builtin but is kept so keyword callers keep working.
            transform: Optional callable applied to each opened image.
        """
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        candidates = (os.path.join(dir, name) for name in sorted(os.listdir(dir)))
        # Sub-directories cannot be opened as images, so they are skipped.
        self.train_images = [path for path in candidates if os.path.isfile(path)]
        self.transform = transform

    def __len__(self):
        """Return the number of collected sample paths."""
        return len(self.train_images)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is opened with PIL and passed through ``self.transform``
        when one was supplied; any colour conversion is left to that
        transform. The label is ``ohe.encode`` applied to the text taken
        from the file name.
        """
        image_path = self.train_images[idx]
        # basename() handles every separator the platform accepts, whereas
        # splitting on os.path.sep alone misses "/" inside Windows paths.
        image_name = os.path.basename(image_path)
        image = Image.open(image_path)
        if self.transform is not None:
            image = self.transform(image)
        label = ohe.encode(image_name.split('.')[0].split('_')[-1])
        return image, label
|
|
|
|
# Shared preprocessing for every loader below: convert the PIL image to a
# tensor first, then apply the grayscale conversion to that tensor.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Grayscale()]
)
|
|
|
|
def get_train_loader(batch_size=60):
    """Build a shuffling DataLoader over the images in TRAIN_DATASET_PATH.

    Args:
        batch_size: Number of samples per batch (default 60).

    Returns:
        A ``DataLoader`` wrapping a ``CaptchaDataset`` that uses the
        module-level ``transform``.
    """
    return DataLoader(
        CaptchaDataset(TRAIN_DATASET_PATH, transform),
        batch_size=batch_size,
        shuffle=True,
    )
|
|
|
|
def get_test_loader(batch_size=60):
    """Build a shuffling DataLoader over the images in TEST_DATASET_PATH.

    Args:
        batch_size: Number of samples per batch (default 60).

    Returns:
        A ``DataLoader`` wrapping a ``CaptchaDataset`` that uses the
        module-level ``transform``. Shuffling is enabled, so batch order
        differs between iterations.
    """
    return DataLoader(
        CaptchaDataset(TEST_DATASET_PATH, transform),
        batch_size=batch_size,
        shuffle=True,
    )
|
|
|
|
def get_predict_loader(batch_size=60):
    """Build a shuffling DataLoader over the images in PREDICT_DATASET_PATH.

    Args:
        batch_size: Number of samples per batch (default 60).

    Returns:
        A ``DataLoader`` wrapping a ``CaptchaDataset`` that uses the
        module-level ``transform``. Shuffling is enabled, so batch order
        differs between iterations.
    """
    return DataLoader(
        CaptchaDataset(PREDICT_DATASET_PATH, transform),
        batch_size=batch_size,
        shuffle=True,
    )
|
|
|
|
def main():
    """Smoke-check the training loader by printing each batch's shapes."""
    # The batch index was never used, so iterate the loader directly.
    for images, labels in get_train_loader():
        print(images.shape, labels.shape)
|
|
|
|
# Run the shape-printing smoke check only when this file is executed directly.
if __name__ == '__main__':
    main()