CNN-based image classification on long-tailed CIFAR datasets
1. Problem description
In real-world scenarios, large-scale datasets naturally exhibit imbalanced, long-tailed distributions, where a few categories (majority categories) account for most of the data while most categories (minority categories) are under-represented [1]. Under the standard training procedure, commonly used CNNs such as ResNet [2] perform poorly on such long-tailed datasets. This project explores the effect of a novel loss function and deferred re-weighting (DRW) on CNN training with long-tailed CIFAR datasets.
2. Pre-existing work
Several re-sampling methods have previously been proposed to deal with imbalanced datasets. However, random re-sampling can lead to overfitting.
Another common approach is re-weighting, which modifies the loss function to increase the contribution of the minority classes. [3] proposed focal loss, which drastically reduces the cross-entropy loss of easy examples (typically from the majority categories) while keeping the loss of hard examples (typically from the minority categories) at roughly the cross-entropy level, as sketched below.
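As a concrete illustration, here is a minimal PyTorch sketch of the focal loss idea (not the exact code used in our experiments); `gamma` is the focusing parameter from [3], and setting `gamma = 0` recovers plain cross-entropy.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    """Focal loss [3]: down-weight easy examples by (1 - p_t)^gamma.

    logits: (N, C) raw class scores; targets: (N,) integer class labels.
    gamma = 2.0 is the value commonly reported in [3]; gamma = 0 gives CE.
    """
    ce = F.cross_entropy(logits, targets, reduction="none")  # per-sample -log(p_t)
    p_t = torch.exp(-ce)                                     # model probability of the true class
    return ((1.0 - p_t) ** gamma * ce).mean()
```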
Two-stage training is also effective: the first stage trains on the imbalanced data, and then transfer learning is conducted on a balanced subset [4].
However, experiments show that re-weighting alone may not substantially increase accuracy. A further improvement is to combine two-stage training with re-weighting.
The CNN architectures of this project are ResNet-34 (for transfer learning) and ResNet-32 (for the long-tailed experiments) [2]. The implementation for handling long-tailed datasets is based on the open-source code of [8].
3. Main approach
- Dataset: The CIFAR [5] datasets are used for training and testing. By randomly selecting 50 classes from CIFAR-100 and subsampling them, we create the long-tailed CIFAR50 dataset with an imbalance factor of 50, i.e., the most frequent class has 50 times as many training images as the rarest one (the sampling profile is sketched after this list).
- Transfer learning: We fine-tuned a ResNet-34 pre-trained on ImageNet (from PyTorch) on the balanced and long-tailed CIFAR50 datasets respectively, using SGD as the optimizer and cross-entropy as the loss (the setup is sketched after this list).
- Label-distribution-aware margin loss (LDAM): SVMs [6] aim to maximize the margin of a classifier. To deal with imbalanced data in the same spirit, [7] proposed the label-distribution-aware margin (LDAM) loss, which encourages larger margins for the minority classes (sketched after this list).
- Deferred re-weighting (DRW): As a two-stage training technique, [7] also introduces deferred re-weighting (DRW), which uses the vanilla training schedule in the first stage and then applies re-weighting in the second stage [1] (see the sketch after this list).
- Applying the tricks to long-tailed CIFAR training: We first used focal loss and DRW separately to train ResNet-32, and then combined LDAM with DRW to verify its effectiveness. Extensive experiments were conducted on the long-tailed CIFAR-10 and CIFAR-100 datasets with imbalance factors of 50 and 100.
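For the dataset construction, the sketch below shows the exponential class-size profile commonly used to build long-tailed CIFAR variants (as in [7] and the code of [8]); `n_max = 500` assumes the 500 training images per class that CIFAR-100 provides, and the function is illustrative rather than our exact script.

```python
def longtail_class_counts(num_classes=50, n_max=500, imb_factor=50):
    """Per-class sample counts following an exponential long-tailed profile.

    Class 0 keeps n_max images; the rarest class keeps n_max / imb_factor,
    so imb_factor = (largest class size) / (smallest class size).
    """
    return [int(n_max * (1.0 / imb_factor) ** (i / (num_classes - 1.0)))
            for i in range(num_classes)]

counts = longtail_class_counts()
print(counts[0], counts[-1])  # 500 ... 10
```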
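For the transfer-learning step, a minimal setup sketch follows; the learning rate, momentum, and weight decay are illustrative placeholders rather than our tuned values, and older torchvision releases use `pretrained=True` in place of the `weights` argument.

```python
import torch
import torch.nn as nn
from torchvision import models

# Load an ImageNet-pretrained ResNet-34 and swap its 1000-way head
# for a 50-way classifier matching our CIFAR50 subset.
model = models.resnet34(weights="IMAGENET1K_V1")
model.fc = nn.Linear(model.fc.in_features, 50)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                            momentum=0.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()
```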
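For LDAM, the following condensed sketch follows the formulation in [7]: each class j receives a margin proportional to n_j^(-1/4), which is subtracted from the true-class logit before a scaled cross-entropy. The defaults `max_m = 0.5` and `s = 30` are the values reported in [7]; the reference implementation [8] is structured similarly, and the optional `weight` argument is what DRW fills in during its second stage.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class LDAMLoss(nn.Module):
    """LDAM [7]: per-class margin m_j proportional to n_j^(-1/4),
    so rarer classes are pushed toward larger margins."""

    def __init__(self, cls_num_list, max_m=0.5, s=30.0, weight=None):
        super().__init__()
        m = 1.0 / np.sqrt(np.sqrt(np.asarray(cls_num_list, dtype=np.float64)))
        m = m * (max_m / m.max())       # rescale so the largest margin equals max_m
        self.m = torch.tensor(m, dtype=torch.float32)
        self.s = s                      # logit scale
        self.weight = weight            # optional per-class weights (set by DRW)

    def forward(self, logits, target):
        margins = self.m.to(logits.device)[target]            # margin of each sample's true class
        one_hot = F.one_hot(target, logits.size(1)).to(logits.dtype)
        logits_m = logits - margins.unsqueeze(1) * one_hot    # subtract margin at true-class entries
        w = self.weight.to(logits.device) if self.weight is not None else None
        return F.cross_entropy(self.s * logits_m, target, weight=w)
```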
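For DRW, the sketch below computes the second-stage class weights from the "effective number of samples" weighting used in [7]; `beta = 0.9999` and the switch at epoch 160 (out of 200) follow the CIFAR schedule in [7], but both are configuration choices rather than fixed constants. The returned weights would then be passed as the `weight` argument of the loss, e.g. `LDAMLoss(counts, weight=drw_weights(counts, epoch))`.

```python
import numpy as np
import torch

def drw_weights(cls_num_list, epoch, drw_epoch=160, beta=0.9999):
    """Deferred re-weighting [7]: uniform weights in stage 1, then
    class-balanced weights from the effective number of samples."""
    if epoch < drw_epoch:
        return None                                # stage 1: vanilla schedule
    n = np.asarray(cls_num_list, dtype=np.float64)
    eff_num = (1.0 - beta ** n) / (1.0 - beta)     # effective number per class
    w = 1.0 / eff_num                              # rarer class -> larger weight
    w = w / w.sum() * len(n)                       # normalize to mean weight 1
    return torch.tensor(w, dtype=torch.float32)
```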
4. Results
- Transfer Learning on long-tailed CIFAR50
From Figure 4 and Table 1, the fine-tuned ResNet-34 loses 16.84% test accuracy on long-tailed CIFAR50 (imbalance factor 50) compared with the balanced dataset. This indicates the difficulty of training on a long-tailed dataset.
- Ablation study
From Table 2, we conclude that LDAM-DRW is the most effective trick for improving model performance on long-tailed CIFAR. DRW by itself is also quite effective. However, focal loss performs at the same level as cross-entropy (CE) loss.
5. Discussion
- Future extensions: Researchers can explore whether these tricks (re-sampling, re-weighting, and two-stage training) conflict with one another and lead to worse performance when combined. Combining two-stage training with re-sampling is also a promising research topic.
- Problems encountered:
- As long-tailed CIFAR50 is a customized dataset, we struggled at first to load it properly. After searching online, we realized we needed to write a data loader that reads the samples listed in a JSON index file (a minimal sketch of such a loader appears at the end of this section).
- Another problem is that we initially expected focal loss to outperform cross-entropy, but our results could not confirm this. We therefore ran further experiments to tune the gamma of focal loss; however, the accuracy changed little. This is mainly because focal loss hurts the accuracy of the majority classes by decreasing their loss. To tackle this problem, we used LDAM-DRW, which significantly improves classification accuracy on long-tailed CIFAR.
- Strengths and weaknesses:
As shown in Table 2, the LDAM-DRW method significantly improves accuracy on long-tailed CIFAR in every scenario, and DRW by itself outperforms traditional one-stage training.
However, two-stage training incurs extra computational cost and training time, which is an unavoidable shortcoming of the LDAM-DRW approach.
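For reference, here is a minimal sketch of a JSON-indexed dataset; the schema (a mapping from image path to integer label) and the class name `CIFAR50Json` are illustrative assumptions, since the report does not show our actual file format.

```python
import json
from PIL import Image
from torch.utils.data import Dataset

class CIFAR50Json(Dataset):
    """Hypothetical loader for a customized CIFAR50 split whose JSON
    index is assumed to map each image path to an integer label."""

    def __init__(self, json_path, transform=None):
        with open(json_path) as f:
            index = json.load(f)          # e.g. {"img/0001.png": 3, ...}
        self.samples = list(index.items())
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        path, label = self.samples[i]
        img = Image.open(path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label
```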
6. Our work from scratch
For grading purposes, here is the list of our work from scratch:
- Visualization of the balanced and imbalanced CIFAR50 data distributions.
- Code for the DataLoader of the customized CIFAR50 dataset from a JSON file.
- Code for training and evaluation with PyTorch to fine-tune the pre-trained ResNet-34.
- Understanding the theory and effects of focal loss, LDAM, and two-stage training.
- Utilizing open-source code [8] to conduct experiments.
References:
[1] Zhang, Yongshun, et al. “Bag of tricks for long-tailed visual recognition with deep convolutional neural networks.” Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 35. No. 4. 2021.
[2] He, Kaiming, et al. “Deep residual learning for image recognition.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
[3] Lin, Tsung-Yi, et al. “Focal loss for dense object detection.” Proceedings of the IEEE international conference on computer vision. 2017.
[4] Cui, Yin, et al. “Large scale fine-grained categorization and domain-specific transfer learning.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2018.
[5] Krizhevsky, Alex. Learning Multiple Layers of Features from Tiny Images. Technical report, University of Toronto, 2009. https://www.cs.toronto.edu/~kriz/cifar.html
[6] Suykens, Johan AK, and Joos Vandewalle. “Least squares support vector machine classifiers.” Neural processing letters 9.3 (1999): 293-300.
[7] Cao, Kaidi, et al. “Learning imbalanced datasets with label-distribution-aware margin loss.” Advances in neural information processing systems 32 (2019).
[8] LDAM-DRW: official open-source implementation of [7]. https://github.com/kaidic/LDAM-DRW