# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import glob
import numpy as np

# Define paths to the training and testing datasets.
# Each path is a class-per-folder image root consumed by
# list_image_files_by_class below (one subfolder per class label).
train_data_path = '/media/skydata/alpha0012/workspace/EdgeFLite/dataset_hub/pill/train_images'
test_data_path = '/media/skydata/alpha0012/workspace/EdgeFLite/dataset_hub/pill/test_images'

def list_image_files_by_class(directory):
    """
    Collect image file paths and their class indices from a class-per-folder layout.

    Each immediate subdirectory of ``directory`` is treated as one class.
    Classes are indexed alphabetically, so the label mapping is deterministic
    across runs and machines.

    Args:
        directory (str): Path to the dataset root containing one folder per class.

    Returns:
        list: ``[image_path, class_index]`` pairs for every ``.jpg`` file found
        recursively inside each class folder, in sorted (deterministic) order.
    """
    # Only directories define classes: skip stray files (e.g. .DS_Store) that
    # os.listdir would otherwise include and that would shift every class index.
    class_labels = sorted(
        entry
        for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )
    # Map class names to indices; sorted order above makes this stable.
    class_to_idx = {class_name: idx for idx, class_name in enumerate(class_labels)}

    image_dataset = []  # Accumulates [path, class_index] pairs

    for class_name in class_labels:
        class_folder = os.path.join(directory, class_name)
        # Recursive search for JPGs; glob order is filesystem-dependent,
        # so sort for a reproducible dataset ordering.
        image_files = sorted(
            glob.glob(os.path.join(class_folder, '**', '*.jpg'), recursive=True)
        )

        for image_file in image_files:
            image_dataset.append([image_file, class_to_idx[class_name]])

    return image_dataset

if __name__ == "__main__":
    # Scan both dataset splits and report how many images each contains.
    train_count = len(list_image_files_by_class(train_data_path))
    test_count = len(list_image_files_by_class(test_data_path))

    print(f"Training dataset size: {train_count}")
    print(f"Testing dataset size: {test_count}")