In PyTorch, labels are generated by default from the subfolder names in the root directory. That is convenient until you want to change the labels. Let's say I'd like to train a two-stage image recognition model:

  1. Predict if an image contains any animals (labels: blank – no animals present; non-blank – animals present)
  2. Based on the prediction from step one, I take only the non-blank images and create a second model to predict the particular animal species (labels: zebra, giraffe, elephant, other (any other animal)).

To use the same images with different labels in PyTorch I can do the following:

  • Have two root folders, each with different subfolders but the same images. This is space-consuming, and if you store the data in the cloud, it can double the storage cost.
  • After the first step, create new subfolders and move the images there. This gets messy, especially if you'd like to go back to step one and retrain the model with the old labels.
  • Keep all images in one root directory and load the labels from a CSV file.

Personally, I prefer the last option. In the next part, you’ll see how to load custom labels for the PyTorch model.

DataLoader

The default DataLoader (which loads the data along with the labels) fits in two lines of code (plus imports):

from torchvision import datasets
import torch

train_set = datasets.ImageFolder("/root_folder_path")
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)

To load custom labels, we need to create a new dataset class. Fortunately, PyTorch helps here by providing an abstract Dataset class. The custom dataset should inherit from Dataset and override two methods (a minimal sketch follows the list):

  • __len__ to return the length of the custom dataset
  • __getitem__ to return the data and labels. It should support indexing, so that dataset[n] returns the data and label of the n-th image.
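
To make the indexing protocol concrete, here is a purely illustrative toy sketch that wraps a plain Python list instead of images (the class and values are made up for illustration):

from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """Toy example: wraps a plain list so it behaves like a PyTorch dataset."""
    def __init__(self, values):
        self.values = values

    def __len__(self):
        # number of samples in the dataset
        return len(self.values)

    def __getitem__(self, idx):
        # return the sample (and, in a real dataset, its label) at position idx
        return self.values[idx]

toy = ToyDataset([10, 20, 30])
print(len(toy))  # 3
print(toy[1])    # 20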

First, let's load the labels CSV. My df_labels will contain filenames in the index and a labels column:

df_labels = pd.read_csv("labels.csv", index_col=0)

The data frame with labels.
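
For illustration, df_labels is expected to look something like this (the filenames and labels below are made up, using the stage-one blank / non-blank scheme):

import pandas as pd

# Hypothetical example of the expected structure:
# the index holds the image filenames, and "labels" holds the class.
df_labels = pd.DataFrame(
    {"labels": ["blank", "non_blank", "blank"]},
    index=pd.Index(["img_001.jpg", "img_002.jpg", "img_003.jpg"], name="filename"),
)
print(df_labels.loc["img_002.jpg", "labels"])  # "non_blank"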

Then, we need to define the dataset class. It takes three arguments: the root directory (images_folder_path), a data frame with labels (file_df), and PyTorch image transformations (transform). In case file_df contains many more rows than there are images, we keep only the subset we need in self.label_files_df:

class CreateDataset(Dataset):
    def __init__(self, images_folder_path, file_df, transform):
        self.folderpath = images_folder_path
        # keep only the images that also have a label in file_df
        self.image_filenames_list = [f for f in os.listdir(self.folderpath) if f in file_df.index]
        self.label_files_df = file_df.loc[self.image_filenames_list]
        self.transform = transform

Then, we need to override the two methods. The __len__ method simply returns the length of the dataset. __getitem__ is the method that loads the data and labels: given an index, it opens and transforms one image and returns the image together with its label. We'd like to work with the same image format (PIL) as the default PyTorch loader, so I use Image.open from the Pillow library:

def __len__(self):
    return len(self.image_filenames_list)

def __getitem__(self, idx):
    filename = self.image_filenames_list[idx]
    filepath = os.path.join(self.folderpath, filename)
    
    image = Image.open(filepath)
    image = self.transform(image)
    label = self.label_files_df.loc[filename, "labels"]
    return image, label

Finally, we need to use the class:

dataset = CreateDataset(images_folder_path, df_labels, transform)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
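
The transform argument is any torchvision transform pipeline that ultimately produces a tensor so the batches can be collated. As a sketch (the resize size and normalization statistics below are assumptions, not taken from the original setup):

from torchvision import transforms

# An assumed, typical pipeline: resize, convert the PIL image to a tensor, normalize.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])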

To sum up, here is everything in one piece of code:

import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class CreateDataset(Dataset):
    def __init__(self, images_folder_path, file_df, transform):
        self.folderpath = images_folder_path
        # keep only the images that also have a label in file_df
        self.image_filenames_list = [f for f in os.listdir(self.folderpath) if f in file_df.index]
        self.label_files_df = file_df.loc[self.image_filenames_list]
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames_list)

    def __getitem__(self, idx):
        filename = self.image_filenames_list[idx]
        filepath = os.path.join(self.folderpath, filename)

        image = Image.open(filepath)
        image = self.transform(image)
        label = self.label_files_df.loc[filename, "labels"]
        return image, label


df_labels = pd.read_csv("labels.csv", index_col=0)
dataset = CreateDataset(images_folder_path, df_labels, transform)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
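
To check that everything fits together, you can iterate over the loader as a training loop would. This is just a quick sanity-check sketch: it assumes the transform returns tensors of equal size (e.g. the Resize + ToTensor pipeline above), and note that the labels come back as strings from the CSV, so they would normally be mapped to integer class ids before computing a loss:

for images, labels in data_loader:
    print(images.shape)  # e.g. torch.Size([32, 3, 224, 224]) with the assumed transform
    print(labels[:5])    # string labels straight from the CSV
    break                # one batch is enough for a quick check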

If you'd like to see the custom data loader implemented along with boto3 (for AWS EC2), take a look at my GitHub repository.