Bayesian Active Learning Library - BaaL (Short Notes)
I will try to briefly cover the basic details of this library.
List of Modules in BaaL
Several modules work together to handle the whole active learning process.
- Model
- ModelWrapper: Wraps a PyTorch model and adds methods to train, test, and predict on batches/samples. It also holds the criterion and the metrics.
- Dropout: Wraps the _DropoutNd module from PyTorch. Replacement for torch.nn.Dropout. Required for estimating the uncertainty of model predictions. Includes additional scaling of the outputs.
- Dropout2d: wraps _DropoutNd from pytorch. Replacement for torch.nn.Dropout2d.
- MCDropoutModule: Wraps an nn.Module and patches its 1D/2D dropout layers so that dropout can be used for uncertainty estimation.
- ConsistentDropout: Keeps the dropout consistent between batches during the prediction phase (for research purposes). Replacement for torch.nn.Dropout.
- ConsistentDropout2d: Keeps the 2D dropout consistent between batches during the prediction phase (for research purposes). Replacement for torch.nn.Dropout2d.
- MCConsistentDropoutModule: Wraps an nn.Module like MCDropoutModule does, and additionally keeps the dropout consistent.
- WeightDropLinear: Applies dropout on torch.nn.Linear weights (bias not affected).
- WeightDropConv2d: Applies dropout on torch.nn.Conv2d weights (bias not affected).
- MCDropoutConnectModule: Applies weight dropout to the layers of a PyTorch neural network model. Affects all modules matching the keys (conv, linear, lstm, gru).
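The idea behind these dropout layers can be illustrated without PyTorch. The snippet below is a minimal sketch, not BaaL's implementation: the names `dropout`, `toy_model`, and `mc_predict` are hypothetical, and the "model" is just a weighted sum. It assumes the usual inverted-dropout scaling of 1/(1-p), keeps dropout active at prediction time, and collects several stochastic forward passes for the same input.

```python
import random


def dropout(values, p, rng):
    """Inverted dropout: zero each value with probability p, scale survivors by 1/(1-p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]


def toy_model(features, weights, p, rng):
    """Hypothetical one-layer model: dropout on the inputs, then a weighted sum."""
    dropped = dropout(features, p, rng)
    return sum(w * x for w, x in zip(weights, dropped))


def mc_predict(features, weights, p, iterations, seed=0):
    """Run the model several times with dropout still active (Monte Carlo dropout)."""
    rng = random.Random(seed)
    return [toy_model(features, weights, p, rng) for _ in range(iterations)]


features = [1.0, 2.0, 3.0, 4.0]
weights = [0.5, -0.25, 0.1, 0.3]

# Each pass uses a different dropout mask, so the predictions differ.
samples = mc_predict(features, weights, p=0.5, iterations=2000)
mean = sum(samples) / len(samples)
print("distinct predictions:", len(set(samples)))
print("mean prediction:", mean)
```

Thanks to the scaling, the average over many passes stays close to the output of the model without dropout, while the spread of the individual passes is what the heuristics later turn into an uncertainty score.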
- Calibration
- DirichletCalibrator: Dirichlet calibration applied to the classifier output. Adds an extra Linear layer with n_outputs=n_classes and trains only that layer, in order to get better confidence values for the predictions.
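To make the "extra Linear layer" concrete, here is a minimal sketch of the forward pass only, assuming the common formulation of Dirichlet calibration as softmax(W · log(p) + b). The function names are hypothetical, training of W and b is left out, and this is not BaaL's code.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dirichlet_calibrate(probs, weight, bias, eps=1e-12):
    """Apply softmax(W . log(p) + b) to one vector of class probabilities."""
    log_p = [math.log(max(p, eps)) for p in probs]
    logits = [
        sum(w * lp for w, lp in zip(row, log_p)) + b
        for row, b in zip(weight, bias)
    ]
    return softmax(logits)


probs = [0.9, 0.1]
identity = [[1.0, 0.0], [0.0, 1.0]]
zero_bias = [0.0, 0.0]

# With W = I and b = 0 the layer leaves the probabilities unchanged.
unchanged = dirichlet_calibrate(probs, identity, zero_bias)

# A hand-picked W = 0.5 * I softens an overconfident prediction.
softened = dirichlet_calibrate(probs, [[0.5, 0.0], [0.0, 0.5]], zero_bias)
print(unchanged, softened)
```

In practice W and b would be fitted on held-out data while the classifier itself stays frozen, which matches the "train only that layer" description above.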
- Dataset
- SplittedDataset: Wraps a PyTorch dataset class and keeps a label marker per sample, so that the data can be separated into labeled and unlabeled sets.
- ActiveNumpyArray: Customized SplittedDataset for working with Numpy datasets (for scikit-learn use cases).
- FileDataset: Wraps a PyTorch dataset, with the additional ability to label data and to load samples from files.
- HuggingFaceDatasets: Special wrapper for working with huggingface nlp datasets.
- ActiveLearningDataset: Manages the labeled and unlabeled sets of data and keeps track of the labeling iterations.
- ActiveLearningPool: Holds the unlabeled samples (created as a PyTorch Subset using the unlabeled-data mask). ActiveLearningDataset uses it to expose the unlabeled data, called the pool, on each step.
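The labeled/unlabeled bookkeeping described above can be sketched with a plain boolean mask. `ToyActiveDataset` below is a hypothetical stand-in written for these notes, not BaaL's ActiveLearningDataset; it only shows how a label marker per sample yields a training set and a pool, and why labeling positions are interpreted relative to the current pool.

```python
class ToyActiveDataset:
    """Hypothetical stand-in: a list of samples plus one 'labeled' flag per sample."""

    def __init__(self, samples):
        self.samples = list(samples)
        self.labeled = [False] * len(self.samples)

    def labeled_indices(self):
        return [i for i, flag in enumerate(self.labeled) if flag]

    def pool_indices(self):
        return [i for i, flag in enumerate(self.labeled) if not flag]

    def __len__(self):
        # Only labeled samples count: this is the current training set.
        return len(self.labeled_indices())

    def __getitem__(self, position):
        # Index into the labeled subset, not into the full dataset.
        return self.samples[self.labeled_indices()[position]]

    @property
    def pool(self):
        """The unlabeled samples, in dataset order."""
        return [self.samples[i] for i in self.pool_indices()]

    def label(self, pool_positions):
        """Mark samples as labeled; positions are relative to the current pool."""
        snapshot = self.pool_indices()  # resolve positions before mutating the mask
        for position in pool_positions:
            self.labeled[snapshot[position]] = True


dataset = ToyActiveDataset("abcdef")
dataset.label([0, 2])   # labels "a" and "c"
dataset.label([1])      # pool is now b, d, e, f, so this labels "d"
print(len(dataset), dataset.pool)
```

Heuristics rank the pool, so they return pool-relative positions; translating those back to dataset indices is the part that is easy to get wrong when done by hand.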
- Heuristics
- AbstractHeuristic: Abstract class defining the general methods needed to estimate uncertainty scores from model predictions generated by multiple inference iterations with different dropout masks. The main idea is that if the model is uncertain about a sample, the final (class) predictions are likely to vary a lot across dropout masks. The heuristics analyze that effect in different ways. Different aggregations are then used to reduce the result to a scalar uncertainty score per sample, and the unlabeled samples are returned ranked (sorted) by these scores.
- BALD: "Bayesian Active Learning by Disagreement". The uncertainty score is the difference between two entropy values: 1) the entropy of the model's average prediction for the given sample, minus 2) the expected entropy of the predictions of the sampled models (MC sampling, i.e. different dropouts applied to the model). The score is high when the sampled models are individually confident but disagree with each other.
- BatchBALD: An improvement over BALD in terms of sampling strategy. BatchBALD additionally applies diversity sampling, so that samples which have high uncertainty (as measured by BALD) but are highly correlated with each other are not all selected.
- Variance: High variance means high uncertainty
- Entropy: Entropy of the predicted class probabilities. It is high when the probabilities are close to uniform across classes (similar in spirit to Variance).
- Margin: Difference in probability between the most certain and the second most certain class (the smaller the margin, the higher the uncertainty).
- Certainty: The probability of the most certain class (the lower the probability, the higher the uncertainty).
- Precomputed: For the case where the values passed as an argument are already the uncertainty scores.
- Random: Random ranking
- CombineHeuristics: Combines heuristics with a weighted sum and additional reduction operations. One strict requirement is that all heuristics order their scores in the same direction (reversed or not reversed); otherwise a weighted sum is meaningless.
- AbstractGPUHeuristic: The same as AbstractHeuristic, but the computations are done on the GPU. It additionally defines methods for running model inference on a dataset on the GPU.
- BALDGPUWrapper: GPU implementation of the BALD heuristic.
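The scores behind these heuristics are short formulas over the stack of MC predictions. The sketch below uses plain Python and hypothetical function names; it is an illustration of the formulas, not BaaL's implementation (which works on arrays and may use different reductions). Here `mc_probs` holds the predictions for a single sample as a list of iterations, each a list of class probabilities.

```python
import math


def entropy(probs):
    """Shannon entropy (natural log) of one probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def mean_prediction(mc_probs):
    """Average the class probabilities over the MC iterations."""
    n = len(mc_probs)
    return [sum(column) / n for column in zip(*mc_probs)]


def predictive_entropy(mc_probs):
    return entropy(mean_prediction(mc_probs))


def bald(mc_probs):
    """Entropy of the mean prediction minus the mean entropy of each prediction."""
    expected = sum(entropy(p) for p in mc_probs) / len(mc_probs)
    return predictive_entropy(mc_probs) - expected


def margin(mc_probs):
    """Top-1 minus top-2 probability of the mean prediction (low = uncertain)."""
    top, second = sorted(mean_prediction(mc_probs), reverse=True)[:2]
    return top - second


def certainty(mc_probs):
    """Highest probability of the mean prediction (low = uncertain)."""
    return max(mean_prediction(mc_probs))


def variance(mc_probs):
    """Variance over iterations, averaged over classes (one possible reduction)."""
    mean = mean_prediction(mc_probs)
    n = len(mc_probs)
    per_class = [
        sum((p[c] - mean[c]) ** 2 for p in mc_probs) / n for c in range(len(mean))
    ]
    return sum(per_class) / len(per_class)


def rank(scores, reverse=True):
    """Indices sorted by score; reverse=True puts the highest score first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=reverse)


confident = [[0.9, 0.1], [0.9, 0.1]]      # passes agree and are confident
disagreeing = [[1.0, 0.0], [0.0, 1.0]]    # passes are confident but disagree
uniform = [[0.5, 0.5], [0.5, 0.5]]        # passes agree on "don't know"

pool = [confident, disagreeing, uniform]
bald_ranking = rank([bald(s) for s in pool])
margin_ranking = rank([margin(s) for s in pool], reverse=False)
print(bald_ranking, margin_ranking)
```

Note how `uniform` and `disagreeing` have the same predictive entropy, while BALD is high only for `disagreeing`. Margin and Certainty are ranked in ascending order, which is the ordering direction that has to match when heuristics are combined.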
- ActiveLearningLoop: The main "engine" for active learning, which uses all of the objects above to orchestrate the entire workflow. It has only one method, called "step", which runs one labeling iteration: it applies uncertainty sampling to the pool and labels the selected datapoints.
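One labeling iteration can be sketched as follows. This is a toy illustration under stated assumptions: `active_learning_step` and `closeness_to_boundary` are hypothetical names, the "samples" are just predicted probabilities for a binary classifier, and the score function stands in for a real heuristic. It is not BaaL's ActiveLearningLoop.

```python
def active_learning_step(pool, labeled, score_fn, query_size):
    """Score the pool, move the `query_size` most uncertain items to `labeled`.

    Returns False when the pool is empty (nothing left to label), True otherwise.
    """
    if not pool:
        return False
    ranked = sorted(pool, key=score_fn, reverse=True)
    for item in ranked[:query_size]:
        pool.remove(item)
        labeled.append(item)
    return True


def closeness_to_boundary(probability):
    """Toy uncertainty score: highest when the probability is near 0.5."""
    return -abs(probability - 0.5)


pool = [0.05, 0.45, 0.9, 0.55, 0.2]
labeled = []
steps = 0
while active_learning_step(pool, labeled, closeness_to_boundary, query_size=2):
    steps += 1
    # In a real workflow the model would be retrained on `labeled` here,
    # and the pool would be re-scored with the updated model.

print(steps, labeled)
```

The boolean return value gives the outer loop a natural stopping condition; training and evaluation stay outside the step, so the same loop works with any model and heuristic.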
References
https://github.com/ElementAI/baal
https://jacobgil.github.io/deeplearning/activelearning#batchbald (great Blog on Active Learning)
https://baal.readthedocs.io/en/latest/ (official documentation)
https://arxiv.org/abs/2006.09916 (ElementAI paper about BaaL)
https://www.elementai.com/news/2019/element-ai-makes-its-bayesian-active-learning-library-open-source (BaaL presentation by ElementAI)
https://www.elementai.com/news/2020/road-defect-detection-using-deep-active-learning (Interesting use case of BaaL)