前言
自2015年Hinton大佬發表了知識蒸餾(knowledge distillation)以來,大家在模型訓練方法上除了遷移學習又多了一種選擇。
首先,ViT作為新紀元的開創者,它存在的目的毫無疑問就是被超越。它的問題在于:
1. 需要非常大的算力資源,
2. 只使用ImageNet訓練得到的準確率并未很高(top1-accuracy: 77.9%)
3. 預訓練數據集JFT-300M并未被公開,
4. 對超參數設置要求較高
DeiT
常言道:名師出高徒,又道:三人行則必有我師。對于ViT出現的一眾問題,DeiT使用知識蒸餾的方法,一方面自己對照ground truth數據進行訓練,另一方面讓RegNet作為自己的老師進行訓練,并使用了warmup, label smoothing和droppath等tricks。除去使用了knowledge distillation,在訓練中,作者還使用了數據增強,超參數調整等tricks。
知識蒸餾
簡單來說就是用teacher模型去訓練student模型,通常teacher模型更大而且已經訓練好了,student模型是我們當前需要訓練的模型。在這個過程中,teacher模型是不訓練的。
這里有兩種知識蒸餾方式:
1.
soft distillation

(軟蒸餾公式)

(軟蒸餾流程圖)
當teacher模型和student模型接收到相同的輸入圖片時,首先都進行前向傳播,這個時候因為teacher模型處在測試階段,所以通過softmax獲得一個label,但我們注意到,這個label叫soft label,因為在做softmax時,除以了一個參數
T,這個參數又叫做temperature(蒸餾溫度),然后softmax就會得到一個非常平緩的soft label。

(softmax with temperature)

(softmax with temperature)
student模型同樣也是除以一個
T,然后softmax得到一個soft prediction,我們希望student模型的soft-prediction和teacher模型的soft labels盡量接近,使用KLDivLoss進行兩者之間的差距度量,計算一個對應的損失teacher loss。
在訓練的時候,我們是可以拿的到訓練圖片的真實的ground truth(hard label)的,可以看到上面圖中student模型下面一路,就是預測結果和真實標簽之間的交叉熵損失cross entropy loss(CELoss)。
然后計算兩路的損失:KLDivLoss和CELoss,按照一個加權關系計算得到一個總損失total loss,反向傳更新參數的時候這個teacher模型是不做訓練的,只依據total loss訓練student模型。
2. hard distillation
同樣,作者也提供了硬蒸餾(hard distillation),至于孰好孰壞,目前暫無定論。

(硬蒸餾公式)
KLDivLoss(KL散度)
衡量兩個分布之間的相似程度或者說是距離
(KL散度)
Temperature(蒸餾溫度)
T 的作用在于使得整個離散概率分布的離散值變得更加接近。
如果是[1.0,20.0,400.0]直接做softmax,那結果是[0.0,0.0,1.0],可見結果完全借鑒第三個因子。而先進行處理(比如除以1000)后變為[0.001,0.02,0.4]時,在做softmax結果為[0.28,0.29,0.42],結果總綜合考慮了三部分,這顯然是更合理的結果。實際中,看我是更希望結果偏向于更大的值,還是偏向于綜合考慮來決定是否使用softmax前輸入的預處理。
Distillation in Transformer
蒸餾過程(論文圖)
先說一下,在這DeiT篇論文出來的時候,teacher model使用的是RegNet(何愷明大佬提出的一個CNN)。
ViT是使用class tokens去做分類的,相當于是一個額外的patch,這個patch去學習和別的patch之間的關系,然后連classifier,計算CELoss。在DeiT中為了做蒸餾,又額外加一個distill token,這個distill token也是去學和其他tokens之間的關系,然后連接teacher model計算KLDivLoss,那CELoss和KLDivLoss共同加權組合成一個新的loss取指導student model訓練(知識蒸餾中teacher model不訓練)。
在預測階段,class token和distill token分別產生一個結果,然后將其加權(分別0.5),再加在一起,得到最終的結果做預測。
Better Hyperparameter
參數初始化方式:truncated normal distribution(截斷標準分布)。
learning-rate:CNN中的結論:當batch size越大的時候,learning rate設置的越大。
learning rate decay:cosine,在warm-up階段lr先線性升上去,然后通過余弦方式lr降下來。
Data Augmentation
用了mixup和cutmix
效果對比
集萃感知的人工智能雷達視覺融合一體機是一款集毫米波雷達、智能視覺攝像機于一體的智能交通路側感知產品,該產品將融合毫米波雷達和攝像頭的感知優勢,通過雷達電磁調控、信號深度學習、雷視數據級融合等人工智能雷達技術,實現未來新一代智慧交通中——智能化交通信息全息采集及管理功能。