來源:北大青鳥總部 2023年02月21日 14:45
眾所周知,隨著深度學習技術的發展,深度卷積神經網絡在圖像分類、識別以及關鍵點定位上已得到廣泛應用。目前在人體姿態、人臉識別等多個方面的關鍵點定位算法已經取得長足發展,但是應用于多變性的圖像背景以及姿態等依然面臨很大的挑戰,如服飾在類別、比例和外觀上具有多變性,其關鍵點定位精度并不高。下文將在傳統的殘差與沙漏網絡模型基礎上,介紹一種新的級聯金字塔結構卷積神經網絡,實現對困難關鍵點的定位進行精細調整。并通過實例剖析進一步幫助大家來理解。
1.沙漏網絡
沙漏網絡,正如其名,是一種形如沙漏的下采樣-上采樣結構,如下圖所示。圖中左側部分通過卷積和池化操作將特征圖降低到較低的分辨率。下采樣通過池化操作完成,同時通過另一路卷積保留下采樣前的特征圖,用于和右側上采樣部分同尺度的特征圖進行融合。當下采樣部分特征圖達到最小分辨率后,網絡經過最近鄰上采樣后與保留的同尺度特征圖進行融合,最后網絡輸出表示各個關節點在該像素出現的概率的特征集。
沙漏網絡設計的目的在于獲取不同尺度下圖像所包含信息。利用多模塊的沙漏網絡可以定位關鍵點進而來識別人體姿態特征。
2.深度殘差網絡
假定一個網絡的輸入為
理想的映射輸出為
為了獲取
利用堆疊的非線性層來擬合殘差映射
由此可以得到
因此擬合最優映射的問題轉化為擬合殘差映射函數,使得網絡模型不再是學習一個完整的輸出,而只是學習殘差
解決深度卷積網絡中,隨著網絡層數的加深,造成的梯度消失、爆炸等問題。
相比較普通網絡,深度殘差網絡引入捷徑跳過某些層的連接,再與主徑匯合,如下圖所示。這使得底層的誤差可通過捷徑向上層傳遞而解決梯度消失的問題,在不增加額外參數又不提高計算復雜度的同時增加網絡模型的訓練速度、提高訓練效果。作為簡單且實用的深層次網絡模型,深度殘差網絡在圖像分割、目標檢測等圖像處理領域內應用廣泛。
級聯金字塔結構卷積神經網絡的算法框架分為兩部分,如下圖所示:
第1部分為全局關鍵點定位網絡,使用殘差網絡作為特征提取網絡,通過特征金字塔融合多尺度特征,實現關鍵點的初步定位;
第2部分以沙漏網絡為基礎對第1級損失較大的關鍵點精細調整,進而實現對服飾關鍵點的精確定位。
在進一步解釋前,大家需要對使用殘差網絡提取的不同層的特征圖尺度形成的金字塔結構有一定了解。如下圖所示,特征金字塔結構在網絡前向卷積的過程中對每一分辨率的特征圖引入后一分辨率縮放2倍的特征圖做逐個元素自底向上相加的操作,以這種方式將卷積神經網絡中高分辨率低語義信息的底層特征圖和低分辨率高語義信息的高層特征圖進行融合,使得融合之后特征圖既包含豐富的語義信息,也包含由于不斷降采樣而丟失的底層細節信息。
詳細介紹:
1.第1級網絡
第1級網絡首先通過殘差網絡進行特征提取,C1~C5分別代表殘差網絡中卷積Conv1~Conv5產生的特征圖。比如,輸入一張大小為512×512的圖像,原始的ResNet經過5次步長為2的卷積操作達到降采樣的目的,特征圖發生5次尺度變化,最終卷積層輸出的特征圖C5的尺寸為16×16。這里,算法引入空洞卷積為了提高特征圖空間分辨率。
利用殘差網絡提取的特征圖構建特征金字塔時,因為特征圖C3~C5具有相同的尺寸,所以可不經過上采樣直接融合。融合后的結果與C2繼續融合時,先經過雙線性插值進行2倍的上采樣。每一級產生的特征圖都生成一組熱力圖,同組的每張熱力圖包含輸入圖像的一個關鍵點的坐標,和真實關鍵點坐標生成的熱力圖進行誤差計算求得損失,共同監督網絡訓練。在測試階段,第1級網絡輸出的熱力圖可以得到全部關鍵點的位置坐標。
2.第2級網絡
第2級網絡使用兩個堆疊的沙漏網絡,但與原始的沙漏網絡不同的是,第1個沙漏網絡的下采樣部分即上采樣部分的輸入是第1級金字塔結構輸出的特征圖。針對困難關鍵點,選擇第1級損失較大的關鍵點進行精細調整,僅從這部分關鍵點反向傳播損失算法。第1個沙漏網絡融合來自第1級網絡所有金字塔層的信息進行定位,第2個沙漏網絡利用前一個沙漏網絡輸出的熱力圖作為關鍵點之間的結構先驗進行定位。每個沙漏網絡都生成一組熱力圖,并與真值的誤差作為損失函數監督網絡訓練。測試階段,最后結果為2級輸出結果的綜合。
雖然第1級網絡已經能夠完成關鍵點定位任務。但是由于服飾背景、姿態等的復雜性,一些困難關鍵點依然難以實現精確定位,這里設計了第2級網絡對困難關鍵點的坐標進行精細調整。
數據集選取
這里以具有多變性的女裝服飾圖片作為對象來研究分析。實驗選取2018 FashionAI 服飾關鍵點定位數據集。此數據集是同時符合機器學習要求和服飾專業性的高質量數據集。服飾的關鍵點基于服裝設計的5大專業類別定義,分別為上衣、外套、褲子、半身裙、連身裙。在該數據集中,每種服飾具體關鍵點如下圖所示。本文案例的數據僅包含單個模特或者商品的圖像。所預測的服飾所屬的類別已知,不需要單獨進行分類。數據集包括54166個訓練樣本和9971個測試樣本。
級聯結果分析
采用上面算法,通過級聯的兩級卷積神經網絡分別實現對關鍵點的初步定位和進一步修正,其結果如下圖所示。圖中所示為包含上衣、外套類別的4張服飾圖像經過級聯網絡的關鍵點定位結果圖,每張圖片的上面一張顯示的是只經過第1級網絡的結果輸出圖像,下面一張包含第2級網絡的結果輸出圖像。圖像中的部分關鍵點經過了調整,尤其是方框圈起來的關鍵點在第2級網絡經過了比較明顯的調整,比如:
第1張圖像中的右腋窩和右袖口內關鍵點,由于被遮擋誤差較大;
第2張圖像左腋窩定位錯誤;
第3張圖像左袖口內側被遮擋定位誤差較大;
第4張圖像右腰部關鍵點被水印遮擋。
經過第2級網絡這些關鍵點都得到了進一步調整,很明顯地減小了定位誤差,使得最終輸出的定位結果更加準確.這一級聯結果對比證明了上述算法可提高關鍵點精確度的有效性。
上文通過將傳統的殘差與沙漏網絡模型進行級聯,并詳細介紹了一種新的級聯金字塔結構卷積神經網絡。為了進一步優化對關鍵點定位精度的問題,充分利用特征信息,在第1級使用殘差網絡進行特征提取網絡形成特征金字塔結構,保留了更多的圖像細節信息,實現對所有關鍵點的定位;在第2級以沙漏網絡為基礎,整合來自上一級的特征信息,利用前一級預測出來的關鍵點之間的結構先驗,對困難關鍵點即第1級損失較大的關鍵點進行精細調整,進一步提升整個網絡的定位精度。該網絡模型對具有多變性的圖像背景以及姿態等進行關鍵點定位有很好的適應性。