If you have ever worked on deploying a machine learning model, you know how challenging it is to ensure a smooth transition from your controlled environment to the real world. You need to consider many factors to achieve a successful implementation.
One aspect that is easily overlooked is the model’s misinterpretation of samples in the dataset. Some patterns or features may be underrepresented in the training data, or the features that represent certain classes may be inaccurate. This is where exploratory data analysis (EDA) and population analysis come into play.
EDA is a process of analyzing and summarizing data sets to understand their main characteristics and features. Population analysis goes beyond EDA in that it studies how the model interprets the features across the samples in the dataset. The goal of population analysis is to ensure that the model has learned the relevant features from the data and that the resulting model can generalize well. It includes examining the model features, identifying any biases or imbalances, and using statistical methods to make inferences about the larger population.
Let us continue using an example to understand why population analysis is crucial in machine learning.
Imagine you want to train a model for classifying animals. We can have the following problems:
- Imbalanced dataset: If 80% of your dataset consists of cat images, and the remaining 20% are images of 50 different animals, then your model would be biased towards cats and likely to classify everything it sees as a cat.
- Outliers: If there are outlier images in our cat class, let’s say toothbrush images labeled as cats, these can skew the result of our model.
- Wrong Feature Selection: If the majority of cat images in your dataset show cats sitting on a sofa, then it is very likely that your model will emphasize features of sofas instead of cats. This will cause the model to overfit to cats sitting on sofas and fail when you pass it an image of a cat walking in the street.
EDA can solve the first two issues, but to tackle the third and most challenging issue, we need population analysis. In broad terms, we would address each issue as follows:
- Imbalanced dataset: Identify underrepresented and overrepresented classes. We can then include more samples or apply data augmentation to balance the dataset.
- Outliers: Identify any outliers for each class. We can remove these outliers from our dataset.
- Wrong Feature Selection: Identify any patterns or features in the training data that are not representative of the larger population. This can be solved by either altering the images in the dataset or adapting our model to learn different features.
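The first item in the list above, spotting overrepresented and underrepresented classes, can be sketched in a few lines of Python. This is a minimal illustration, not a prescribed method: the function names, the `ratio` threshold, and the toy labels are all assumptions made for the example.

```python
from collections import Counter


def class_shares(labels):
    """Return each class's share of the dataset as a fraction between 0 and 1."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {cls: n / total for cls, n in counts.items()}


def flag_imbalance(labels, ratio=3.0):
    """Flag classes whose share is `ratio` times above or below a uniform share.

    The threshold is an arbitrary heuristic chosen for illustration.
    """
    shares = class_shares(labels)
    uniform = 1.0 / len(shares)
    over = sorted(c for c, s in shares.items() if s > ratio * uniform)
    under = sorted(c for c, s in shares.items() if s < uniform / ratio)
    return over, under


# Toy labels (made up): cats dominate the dataset.
labels = ["cat"] * 80 + ["dog"] * 12 + ["panda"] * 8
over, under = flag_imbalance(labels, ratio=2.0)
print("overrepresented:", over)
print("underrepresented:", under)
```

The flagged classes are candidates for collecting more samples or applying data augmentation.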
Population analysis can improve the interpretability of machine learning models by providing insights into the underlying patterns and characteristics of the data. You can understand where your model may be failing and devise a solution.
So, if population analysis is this useful, why not apply it to every model we have? The reason is simple: it is a highly time-consuming and cumbersome process. As you may already know or can easily imagine, reviewing thousands of samples in a dataset to identify which features the model learned from them is an arduous task. The process involves clustering the data, analyzing suspect samples, extracting feature maps from the model, comparing them with other failure cases, identifying common issues, discerning patterns, solving them, retraining the model, evaluating the results, and so on.
Let’s return to our example to clarify how we would apply population analysis to a real problem and what difficulties we would face.
- Imbalanced dataset: This is relatively easy to detect, as just checking the number of samples in each class would be enough. We can apply data augmentation to generate more images of the underrepresented classes to solve it. Overall, balancing an imbalanced dataset requires
- Outliers: For our example, outliers are images that do not belong to any of the animal classes or are incorrectly labeled. To find these outliers, we must inspect samples in our dataset and ensure they are correctly labeled.
- Wrong Feature Selection: In our example, feature selection could mean identifying the most informative features that help the model distinguish between different animal classes. For example, the model could learn that specific shapes, textures, or colors characterize certain animal classes. To identify these features, we need to interpret the model and understand why it makes its decisions in failure cases. We can use feature extraction or dimensionality reduction techniques such as PCA, LDA, or CNNs as a starting point. However, experimenting with different techniques and evaluating their performance can be time-consuming, especially for large datasets. Plus, if we want to interpret the problem, we would need to generate heat maps for each wrong prediction to understand what the actual problem is.
All these problems, especially the interpretation of the model, are time- and resource-consuming to detect, even before considering the complexity of the corresponding solutions. Moreover, we cannot ensure that the solutions we devise are error-free, as what we think is the source of the problem could actually be irrelevant.
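To make the outlier step above more concrete, here is a minimal sketch of one possible heuristic: flag samples whose embedding lies unusually far from the centroid of their labeled class. It assumes that each image has already been reduced to a feature vector (for instance by one of the techniques mentioned above). The function names, the `factor` threshold, and the toy two-dimensional embeddings are hypothetical.

```python
import math
from statistics import mean, median


def centroid(vectors):
    """Component-wise mean of a list of equal-length vectors."""
    return [mean(dim) for dim in zip(*vectors)]


def flag_outliers(embeddings, labels, factor=3.0):
    """Return indices of samples unusually far from their class centroid.

    A sample is flagged when its distance to the centroid exceeds `factor`
    times the median distance within its class (an illustrative heuristic).
    """
    by_class = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)

    flagged = []
    for indices in by_class.values():
        center = centroid([embeddings[i] for i in indices])
        dists = {i: math.dist(embeddings[i], center) for i in indices}
        typical = median(dists.values())
        flagged.extend(
            i for i, d in dists.items() if typical > 0 and d > factor * typical
        )
    return sorted(flagged)


# Toy embeddings (made up): index 4 plays the toothbrush labeled as a cat.
embeddings = [
    [1.0, 1.0], [1.1, 0.9], [0.9, 1.1], [1.0, 1.2], [8.0, 9.0],
    [5.0, 5.0], [5.1, 4.9], [4.9, 5.1],
]
labels = ["cat"] * 5 + ["dog"] * 3
suspects = flag_outliers(embeddings, labels)
print("samples to inspect:", suspects)
```

A check like this only narrows down which samples to inspect by hand. The centroid itself is pulled toward the outliers, so the threshold needs tuning per dataset.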
That’s where the Tensorleap Population Exploration feature comes in. Tensorleap enables you to visualize your dataset in an informative latent space created specifically for your model. It then clusters the dataset according to the features that the model deems important and generates automated insights into any issues.
You can see an example screen below from Tensorleap that shows the latent space embeddings of the most informative features of our animal classifier model. These embeddings are obtained across every sample in the dataset. Note also the Insight panel that indicates that the dataset includes High Loss and Overfitted clusters.
Latent Space Embedding of the Most Informative Features of the Animal Classifier Model
Tensorleap has made the task of population analysis more manageable and cost-effective while eliminating the potential for human error. Now, we can develop a solution easily.
Zooming in on one of the clusters (see the figure below) shows that the model performed much better on the training samples, represented by the yellow circles, than on the validation samples, represented by the red circles. The larger the circle, the greater the loss.
A Cluster of Training and Validation Samples
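The pattern in this figure, low training loss next to high validation loss inside one cluster, can also be expressed numerically. The sketch below is not how Tensorleap computes its insights. It is a hypothetical illustration that assumes each sample comes with a cluster id, a split name, and a loss value, and it uses made-up numbers and an arbitrary threshold.

```python
from collections import defaultdict
from statistics import mean


def loss_gap_by_cluster(samples):
    """Mean validation loss minus mean training loss, per cluster.

    `samples` is an iterable of (cluster_id, split, loss) tuples, where
    split is "train" or "val". Clusters missing either split are skipped.
    """
    losses = defaultdict(lambda: {"train": [], "val": []})
    for cluster_id, split, loss in samples:
        losses[cluster_id][split].append(loss)

    gaps = {}
    for cluster_id, by_split in losses.items():
        if by_split["train"] and by_split["val"]:
            gaps[cluster_id] = mean(by_split["val"]) - mean(by_split["train"])
    return gaps


# Toy data (made up): cluster 0 looks overfitted, cluster 1 does not.
samples = [
    (0, "train", 0.10), (0, "train", 0.12), (0, "val", 0.90), (0, "val", 1.10),
    (1, "train", 0.30), (1, "val", 0.32),
]
gaps = loss_gap_by_cluster(samples)
suspect = [c for c, g in gaps.items() if g > 0.5]  # 0.5 is an arbitrary cutoff
print("clusters with a large train/validation gap:", suspect)
```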
Placing the cursor on one of the red circles shows that the model has incorrectly predicted a cat as a panda. Looking closely at the images, we see that both animals are perched on a branch. Could it be that the model overfitted to the branch in the training sample?
In just one click, we can create a heatmap that shows us which pixels contributed to the model’s decision for both the training and validation samples. It looks like we were right: the model predicted that the cat was a panda based on the branch in the foreground.
Heatmaps Showing the Most Important Feature Contributing to the Model’s Prediction
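For comparison, one common way to build such a heatmap by hand is occlusion sensitivity: mask one region of the image at a time and record how much the model’s score drops. The sketch below is a generic illustration under stated assumptions, not Tensorleap’s implementation. The image is a small grid of numbers, and `toy_panda_score` is a hypothetical stand-in for a real model that looks only at the bottom row (the "branch").

```python
def occlusion_heatmap(image, score_fn, patch=2, fill=0.0):
    """Score drop caused by masking each patch x patch block with `fill`.

    `image` is a 2D list of floats and `score_fn` maps an image to a score.
    Larger values mark regions the score depends on more strongly.
    """
    height, width = len(image), len(image[0])
    baseline = score_fn(image)
    heatmap = [[0.0] * width for _ in range(height)]
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            rows = range(top, min(top + patch, height))
            cols = range(left, min(left + patch, width))
            occluded = [row[:] for row in image]  # copy before masking
            for r in rows:
                for c in cols:
                    occluded[r][c] = fill
            drop = baseline - score_fn(occluded)
            for r in rows:
                for c in cols:
                    heatmap[r][c] = drop
    return heatmap


def toy_panda_score(image):
    """Hypothetical 'model': the score is the mean brightness of the bottom row."""
    return sum(image[-1]) / len(image[-1])


# Toy 4x4 image (made up): an "animal" in the middle, a "branch" at the bottom.
image = [
    [0.2, 0.2, 0.2, 0.2],
    [0.2, 0.9, 0.9, 0.2],
    [0.2, 0.9, 0.9, 0.2],
    [1.0, 1.0, 1.0, 1.0],
]
heatmap = occlusion_heatmap(image, toy_panda_score, patch=1)
for row in heatmap:
    print(row)
```

With a real model, every masked patch costs one forward pass per image, which is part of why repeating this for each wrong prediction becomes expensive.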
Doing this manually would have been extremely time-consuming, but with Tensorleap, we have the information we need to resolve the issues readily available in a matter of seconds.