教育行業(yè)A股IPO第一股(股票代碼 003032)

全國咨詢/投訴熱線:400-618-4000

yolo算法:構造訓練樣本和設計損失函數(shù)

更新時間:2022年12月08日09時36分 來源:傳智教育 瀏覽次數(shù):

在進行模型訓練時,我們需要構造訓練樣本和設計損失函數(shù),才能利用梯度下降對網(wǎng)絡進行訓練。

訓練樣本的構建

將一幅圖片輸入到y(tǒng)olo模型中,對應的輸出是一個7x7x30張量,構建標簽label時對于原圖像中的每一個網(wǎng)格grid都需要構建一個30維的向量。對照下圖我們來構建目標向量:

1670392915921_28.png

20個對象分類的概率

對于輸入圖像中的每個對象,先找到其中心點。比如上圖中自行車,其中心點在黃色圓點位置,中心點落在黃色網(wǎng)格內(nèi),所以這個黃色網(wǎng)格對應的30維向量中,自行車的概率是1,其它對象的概率是0。所有其它48個網(wǎng)格的30維向量中,該自行車的概率都是0。這就是所謂的"中心點所在的網(wǎng)格對預測該對象負責"。狗和汽車的分類概率也是同樣的方法填寫

2個bounding box的位置

訓練樣本的bbox位置應該填寫對象真實的位置bbox,但一個對象對應了2個bounding box,該填哪一個呢?需要根據(jù)網(wǎng)絡輸出的bbox與對象實際bbox的IOU來選擇,所以要在訓練過程中動態(tài)決定到底填哪一個bbox。

2個bounding box的置信度

預測置信度的公式為:

1670393030360_29.png

利用網(wǎng)絡輸出的2個bounding box與對象真實bounding box計算出來。然后看這2個bounding box的IOU,哪個比較大,就由哪個bounding box來負責預測該對象是否存在,即該bounding box的Pr(Object)=1,同時對象真實bounding box的位置也就填入該bounding box。另一個不負責預測的bounding box的Pr(Object)=0。

上圖中自行車所在的grid對應的結果如下圖所示:

樣本標簽

損失函數(shù)

損失就是網(wǎng)絡實際輸出值與樣本標簽值之間的偏差:

損失函數(shù)

yolo給出的損失函數(shù):

損失函數(shù)

模型訓練

Yolo先使用ImageNet數(shù)據(jù)集對前20層卷積網(wǎng)絡進行預訓練,然后使用完整的網(wǎng)絡,在PASCAL VOC數(shù)據(jù)集上進行對象識別和定位的訓練。

Yolo的最后一層采用線性激活函數(shù),其它層都是Leaky ReLU。訓練中采用了drop out和數(shù)據(jù)增強(data augmentation)來防止過擬合。

模型預測

將圖片resize成448x448的大小,送入到y(tǒng)olo網(wǎng)絡中,輸出一個 7x7x30 的張量(tensor)來表示圖片中所有網(wǎng)格包含的對象(概率)以及該對象可能的2個位置(bounding box)和可信程度(置信度)。在采用NMS(Non-maximal suppression,非極大值抑制)算法選出最有可能是目標的結果。

總結:yolo模型預測速度非???,處理速度可以達到45fps,其快速版本(網(wǎng)絡較小)甚至可以達到155fps。訓練和預測可以端到端的進行,非常簡便。準確率會打折扣對于小目標和靠的很近的目標檢測效果并不好。

0 分享到:
和我們在線交談!