This article was first published by Walker AI.
Multi-class classification with deep learning is a common task in both industry and research. In research settings, whether the task belongs to NLP, CV, or TTS, the data is usually plentiful and clean. In real industrial settings, however, data problems are often a major obstacle for practitioners. Common data problems include:
- Small sample size
- Missing labels
- Noisy data with many disturbances
- Imbalanced sample counts across classes
There are other problems as well, which this article does not list. To address the fourth problem above, Google published the paper "Long-Tail Learning via Logit Adjustment" in July 2020. Starting from the balanced error rate (BER), the paper derives a modification of the cross-entropy function that achieves higher average classification accuracy than the original cross entropy. This article briefly explains the paper's core derivation, implements it with the Keras deep learning framework, and finally reports results from a simple MNIST handwritten digit classification experiment. The discussion covers four parts:
- Basic concepts
- Core inference
- Code implementation
- Experimental results
1. Basic concepts
In deep-learning-based multi-class classification, getting good results often requires adjusting the data, the network architecture, the loss function, and the training hyperparameters; class-imbalanced data calls for even more adjustment. To mitigate the low accuracy caused by class imbalance, the paper "Long-Tail Learning via Logit Adjustment" adds prior knowledge of the labels to the loss function and reports state-of-the-art (SOTA) results. Before turning to the core derivation, this article briefly reviews four basic concepts: (1) long-tail distribution, (2) softmax, (3) cross entropy, and (4) BER.
1.1 Long tail distribution
If the classes in the training data are sorted from largest to smallest by sample count and the result is plotted, class-imbalanced data shows a "head" and a "tail", as shown in the figure below:
[Figure: per-class sample counts sorted from high to low, showing a head and a tail]
Classes with many samples form the "head", and classes with few samples form the "tail"; the class imbalance is pronounced.
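The sorting described above can be sketched in a few lines of NumPy. The label array below is made up for illustration and is not the article's data; `imbalance_ratio` is a hypothetical helper quantity (largest class size divided by smallest class size).

```python
import numpy as np

# Hypothetical label array for illustration: 3 classes with very different sizes.
labels = np.array([0] * 50 + [1] * 5 + [2] * 500)

# Count samples per class, then sort classes from most to least frequent.
classes, counts = np.unique(labels, return_counts=True)
order = np.argsort(-counts)
sorted_classes = classes[order]   # head classes first, tail classes last
sorted_counts = counts[order]

# Imbalance ratio: largest class size divided by smallest class size.
imbalance_ratio = sorted_counts[0] / sorted_counts[-1]
```

Plotting `sorted_counts` as a bar chart would give the head-and-tail shape the text describes.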
1.2 softmax
Because softmax normalizes its input and is easy to differentiate, it is often used as the activation function of the last layer in binary and multi-class classification networks, where it represents the network's predicted output. This article does not cover softmax in detail and only gives its general formula:
$$q(c_i) = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}$$
Here $z_i$ is the output of the previous layer for class $c_i$, $q(c_i)$ is this layer's output in the form of a probability distribution, and $\sum_{j=1}^{n} e^{z_j}$ is the sum of $e^{z_j}$ over all $n$ classes.
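As a small illustration of the softmax formula, here is a NumPy sketch. The function name `softmax` and the example inputs are chosen for this illustration; subtracting the row maximum is a common numerical-stability trick and does not change the result.

```python
import numpy as np

def softmax(z):
    """Softmax over the last axis (the class axis)."""
    z = np.asarray(z, dtype=np.float64)
    # Subtracting the max leaves the result unchanged but avoids overflow.
    shifted = z - z.max(axis=-1, keepdims=True)
    exp_z = np.exp(shifted)
    return exp_z / exp_z.sum(axis=-1, keepdims=True)

# Two example rows of logits; the second would overflow without the shift.
q = softmax([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
```

Each row of `q` sums to 1, and equal logits give a uniform distribution.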
1.3 Cross entropy
This article does not derive the cross-entropy function in detail; see the information theory literature for that. In binary and multi-class classification, the cross-entropy function and its variants are usually used as the loss function to optimize. The basic formula is:
$$H(p, q) = -\sum_{i=1}^{n} p(c_i) \log q(c_i)$$
Here $p$ is the expected distribution of the sample, usually the one-hot encoded label, and $q$ is the output of the neural network, which can be regarded as the network's prediction for the sample.
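The cross-entropy formula can be checked numerically with a short NumPy sketch. The function name `cross_entropy`, the `eps` guard against `log(0)`, and the example distributions are assumptions made for this illustration.

```python
import numpy as np

def cross_entropy(p, q, eps=1e-12):
    """H(p, q) = -sum_i p_i * log(q_i), computed over the last axis."""
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    return -(p * np.log(q + eps)).sum(axis=-1)

# One-hot label for class 1 and two hypothetical predictions.
p = np.array([0.0, 1.0, 0.0])
q_good = np.array([0.05, 0.90, 0.05])   # confident and correct
q_bad = np.array([0.60, 0.10, 0.30])    # mostly wrong

loss_good = cross_entropy(p, q_good)
loss_bad = cross_entropy(p, q_bad)
```

With a one-hot label, the loss reduces to the negative log of the probability assigned to the true class, so the better prediction gets the smaller loss.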
1.4 BER
In binary classification, BER is the mean of the prediction error rates on positive and negative samples. In multi-class classification, it is the weighted sum of the per-class error rates, which can be written as follows (see the paper):
$$\mathrm{BER}(f) = \frac{1}{L} \sum_{y \in [L]} P_{x \mid y}\left( y \notin \arg\max_{y' \in [L]} f_{y'}(x) \right)$$
Here $f$ is the whole neural network; $f_{y'}(x)$ is the network's output for label $y'$ given input $x$; $y \notin \arg\max_{y' \in [L]} f_{y'}(x)$ is the event that the network misclassifies a sample whose true label is $y$; $P_{x \mid y}$ is the form in which the error rate is computed; and $\frac{1}{L}$ is the weight of each class.
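To make the difference between BER and the plain error rate concrete, here is a minimal NumPy sketch. The function name `balanced_error_rate` and the toy labels are assumptions for this illustration, and the sketch assumes every class appears at least once in `y_true`.

```python
import numpy as np

def balanced_error_rate(y_true, y_pred, num_classes):
    """Mean of per-class error rates, i.e. equal weight 1/L for each class.

    Assumes every class in range(num_classes) appears in y_true.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    per_class_error = [
        np.mean(y_pred[y_true == c] != c) for c in range(num_classes)
    ]
    return float(np.mean(per_class_error))

# Toy imbalanced data: 90 samples of class 0, 10 samples of class 1.
y_true = np.array([0] * 90 + [1] * 10)
# A degenerate classifier that always predicts the head class.
y_pred = np.zeros(100, dtype=int)

plain_error = float(np.mean(y_pred != y_true))
ber = balanced_error_rate(y_true, y_pred, num_classes=2)
```

The always-predict-head classifier has a plain error rate of only 0.1 but a BER of 0.5, which is why BER is the more informative target under class imbalance.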
2. Core inference
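Following the paper, first fix a neural network model:
$$f^* \in \arg\min_{f} \mathrm{BER}(f)$$
That is, $f^*$ is a neural network model that minimizes BER. Optimizing this model is equivalent to
$$\arg\max_{y \in [L]} f^*_y(x) = \arg\max_{y \in [L]} P^{\mathrm{bal}}(y \mid x)$$
that is, predicting a label $y$ for the given training data $x$ while balancing the prediction across classes (multiplying each class by its weight). This balanced posterior is abbreviated as $P^{\mathrm{bal}}(y \mid x)$.
For $P^{\mathrm{bal}}(y \mid x)$, clearly $P^{\mathrm{bal}}(y \mid x) \propto P(y \mid x) / P(y)$, where $P(y)$ is the label prior and $P(y \mid x)$ is the conditional probability of the predicted label given the training data $x$. Now combine this with what training a multi-class neural network essentially does.
Write the logits output by the network as $s^*$, with $s^* : \mathcal{X} \to \mathbb{R}^L$. Because $s^*$ passes through the softmax activation layer, that is, $q(y \mid x) = e^{s^*_y(x)} / \sum_{y' \in [L]} e^{s^*_{y'}(x)}$, it is not hard to see that $P(y \mid x) \propto \exp(s^*_y(x))$. Combining this with $P^{\mathrm{bal}}(y \mid x) \propto P(y \mid x) / P(y)$, the balanced prediction can be expressed as:
$$\arg\max_{y \in [L]} P^{\mathrm{bal}}(y \mid x) = \arg\max_{y \in [L]} \frac{\exp(s^*_y(x))}{P(y)} = \arg\max_{y \in [L]} \left( s^*_y(x) - \ln P(y) \right)$$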
Based on the formula above, the paper gives two ways to optimize $P^{\mathrm{bal}}(y \mid x)$:
(1) Via $\arg\max_{y \in [L]} \exp(s^*_y(x)) / P(y)$: after the input $x$ passes through all network layers to produce the prediction, divide the prediction by the prior $P(y)$. This approach has been used in earlier work and achieved some success.
(2) Via $\arg\max_{y \in [L]} \left( s^*_y(x) - \ln P(y) \right)$: after the input $x$ passes through the network layers to produce the logits, subtract $\ln P(y)$ from them. This article follows this approach.
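The two routes pick the same label because the logarithm is monotonic. The following NumPy sketch checks this on random logits; the logits are synthetic and the prior is a hypothetical 10-class long-tailed distribution chosen for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(1000, 10))           # synthetic logits s*(x)

# Hypothetical prior: 5 head classes that are 10x more frequent than 5 tail classes.
prior = np.array([5000.0] * 5 + [500.0] * 5)
prior /= prior.sum()

# Route (1): softmax probabilities divided by the prior.
exp_s = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp_s / exp_s.sum(axis=1, keepdims=True)
pred_divide = np.argmax(probs / prior, axis=1)

# Route (2): logits minus the log prior.
pred_subtract = np.argmax(logits - np.log(prior), axis=1)

# Unadjusted prediction, for comparison.
pred_plain = np.argmax(logits, axis=1)
tail_plain = int((pred_plain >= 5).sum())
tail_adjusted = int((pred_subtract >= 5).sum())
```

Both routes agree on every sample, and the adjustment can only move predictions from head classes toward tail classes, never the other way.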
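Following the second approach, the paper directly gives a general formula, called the logit adjustment loss:
$$\ell(y, f(x)) = -\log \frac{e^{f_y(x) + \tau \cdot \log \pi_y}}{\sum_{y' \in [L]} e^{f_{y'}(x) + \tau \cdot \log \pi_{y'}}} = \log\left[ 1 + \sum_{y' \neq y} \left( \frac{\pi_{y'}}{\pi_y} \right)^{\tau} \cdot e^{f_{y'}(x) - f_y(x)} \right]$$
Compare this with the standard softmax cross entropy:
$$\ell(y, f(x)) = \log\left[ \sum_{y' \in [L]} e^{f_{y'}(x)} \right] - f_y(x) = \log\left[ 1 + \sum_{y' \neq y} e^{f_{y'}(x) - f_y(x)} \right]$$
Here $\pi_y$ is the prior of label $y$ and $\tau$ is the adjustment factor. In essence, an offset tied to the label prior is applied to each logit (that is, the value before softmax activation). Because the offset $\tau \log \pi_y$ is added to the logits during training, the unadjusted logits $f_y(x)$ used at prediction time play the role of $s^*_y(x) - \ln P(y)$.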
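As a sanity check on the two equivalent forms of the loss, here is a single-sample NumPy sketch. The function names `logit_adjusted_loss` and `pairwise_form`, as well as the example logits and prior, are assumptions for this illustration and are not taken from the paper's code.

```python
import numpy as np

def logit_adjusted_loss(logits, y, prior, tau=1.0):
    """-log softmax(logits + tau * log(prior))[y] for a single sample."""
    adjusted = logits + tau * np.log(prior)
    adjusted = adjusted - adjusted.max()              # numerical stability
    log_softmax = adjusted - np.log(np.exp(adjusted).sum())
    return -log_softmax[y]

def pairwise_form(logits, y, prior, tau=1.0):
    """log(1 + sum_{y' != y} (pi_y' / pi_y)^tau * exp(f_y' - f_y))."""
    others = np.arange(len(logits)) != y
    ratios = (prior[others] / prior[y]) ** tau
    return np.log(1.0 + np.sum(ratios * np.exp(logits[others] - logits[y])))

logits = np.array([2.0, 0.5, -1.0])
prior = np.array([0.7, 0.2, 0.1])       # class 2 is the tail class

loss_softmax_form = logit_adjusted_loss(logits, 2, prior)
loss_pairwise_form = pairwise_form(logits, 2, prior)
# tau = 0 removes the offset and recovers the standard softmax cross entropy.
loss_plain = logit_adjusted_loss(logits, 2, prior, tau=0.0)
```

For a tail-class sample, the ratios $\pi_{y'} / \pi_y$ are larger than 1, so the adjusted loss is larger than the plain cross entropy and pushes the network harder to separate rare classes.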
3. Code implementation
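The idea of the implementation is to add a prior-based offset to the logits output by the neural network. In practice, to keep things as simple and effective as possible, the adjustment factors are set to 1, including $\tau = 1$. The logit adjustment loss then simplifies to:
$$\ell(y, f(x)) = -\log \frac{e^{f_y(x) + \log \pi_y}}{\sum_{y' \in [L]} e^{f_{y'}(x) + \log \pi_{y'}}}$$
where $\pi_y$ is the prior of label $y$.
The Keras implementation is as follows: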
```python
import numpy as np
import keras.backend as K


def CE_with_prior(one_hot_label, logits, prior, tau=1.0):
    '''
    param: one_hot_label
    param: logits
    param: prior: real data distribution obtained by statistics
    param: tau: regulator, default is 1
    return: loss
    '''
    # Small constant avoids log(0) for classes with zero prior.
    log_prior = K.constant(np.log(np.asarray(prior) + 1e-8))

    # align dim: add leading axes so the prior broadcasts against the logits
    for _ in range(K.ndim(logits) - 1):
        log_prior = K.expand_dims(log_prior, 0)

    # Add the prior-based offset to the logits, then apply softmax cross entropy.
    logits = logits + tau * log_prior
    loss = K.categorical_crossentropy(one_hot_label, logits, from_logits=True)

    return loss
```
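The `prior` argument is described as the data distribution obtained by statistics. One possible way to compute it from integer labels, and to see what the "align dim" loop does, is sketched below in plain NumPy. The helper name `class_prior` and the toy labels are assumptions for this illustration.

```python
import numpy as np

def class_prior(int_labels, num_classes):
    """Empirical class frequencies from integer labels (hypothetical helper)."""
    counts = np.bincount(int_labels, minlength=num_classes).astype(np.float64)
    return counts / counts.sum()

# Toy labels: 6 samples of class 0, 3 of class 1, 1 of class 2.
int_labels = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 2])
prior = class_prior(int_labels, num_classes=3)

# Mirror the "align dim" loop: a (num_classes,) vector becomes (1, num_classes)
# so that it broadcasts against a (batch, num_classes) array of logits.
logits = np.zeros((4, 3))
log_prior = np.log(prior + 1e-8)
for _ in range(logits.ndim - 1):
    log_prior = np.expand_dims(log_prior, 0)

adjusted = logits + 1.0 * log_prior      # tau = 1.0
```

The resulting `prior` array is what would be passed as the `prior` argument of `CE_with_prior`.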
4. Experimental results
The paper "Long-Tail Learning via Logit Adjustment" compares several methods for improving classification accuracy on long-tailed data, tests them on different datasets, and reports better performance than existing methods; see the paper itself for detailed results. To quickly verify the correctness of the implementation and the effectiveness of the method, this article runs a simple MNIST handwritten digit classification experiment. The experimental setup is as follows:
|Item|Setting|
|---|---|
|Training samples|Digits 0 to 4: 5000 images per class; digits 5 to 9: 500 images per class|
|Test samples|Digits 0 to 9: 500 images per class|
|Running environment|Local CPU|
|Network structure|Convolution + max pooling + fully connected|
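For this setup, the label prior follows directly from the per-class training counts in the table. The sketch below computes it with NumPy; the variable names are chosen for this illustration, and the article does not show how its prior was actually computed.

```python
import numpy as np

# Per-class training counts from the table: digits 0-4 have 5000 images each,
# digits 5-9 have 500 images each.
counts = np.array([5000] * 5 + [500] * 5, dtype=np.float64)
prior = counts / counts.sum()

# With tau = 1, this is the offset added to each class logit.
offset = np.log(prior)

# Head classes receive a larger offset than tail classes by log(10).
head_tail_gap = offset[0] - offset[9]
```

A gap of $\ln 10 \approx 2.3$ means a tail-class logit must beat a head-class logit by about that margin during training before the adjusted loss treats the two as equally likely.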
Under this setup, a comparative experiment measures the performance of the classification network with the standard multi-class cross entropy and with the cross entropy with prior as the loss function. Both runs use the same number of epochs (60). The experimental results are as follows:
[Figure: training process plots for standard multi-class cross entropy and for cross entropy with prior]