如何對機器學習代碼進行單元測試?
目前,關于神經網絡代碼,并沒有一個特別完善的單元測試的在線教程。甚至像 OpenAI 這樣的站點,也只能靠 盯著每一行看來思考哪里錯了來尋找 bug。很明顯,大多數人沒有那樣的時間,并且也討厭這么做。所以希望這篇教程能幫助你開始穩健的測試系統。
首先來看一個簡單的例子,嘗試找出以下代碼的 bug。
看出來了嗎?網絡并沒有實際融合(stacking)。寫這段代碼時,只是復制、粘貼了 slim.conv2d(…) 這行,修改了核(kernel)大小,忘記修改實際的輸入。
這個實際上是作者一周前剛剛碰到的狀況,很尷尬,但是也是重要的一個教訓!這些 bug 很難發現,有以下原因。
- 這些代碼不會崩潰,不會拋出異常,甚至不會變慢。
- 這個網絡仍然能訓練,并且損失(loss)也會下降。
- 運行多個小時后,值回歸到很差的結果,讓人抓耳撓腮不知如何修復。
只有最終的驗證錯誤這一條線索情況下,必須回顧整個網絡架構才能找到問題所在。很明顯,你需要需要一個更好的處理方式。
比起在運行了很多天的訓練后才發現,我們如何提前預防呢?這里可以明顯注意到,層(layers)的值并沒有到達函數外的任何張量(tensors)。在有損失和優化器情況下,如果這些張量從未被優化,它們會保持默認值。
因此,只需要比較值在訓練步驟前后有沒有發生變化,我們就可以發現這種情況。
哇。只需要短短 15 行不到的代碼,就能保證至少所有創建的變量都被訓練到了。
這個測試,簡單但是卻很有用。現在問題修復了,讓我們來嘗試添加批量標準化。看你能否用眼睛看出 bug 來。
發現了嗎?這個 bug 很巧妙。在 tensorflow 中,batch_norm 的 is_training 默認值是 False,所以在訓練過程中添加這行代碼,會導致輸入無法標準化!幸虧,我們剛剛添加的那個單元測試會立即捕捉到這個問題!(3 天前,它剛剛幫助我捕捉到這個問題。)
讓我們看另外一個例子。這是我從 reddit 帖子中看來的。我們不會太深入原帖,簡單的說,發帖的人想要創建一個分類器,輸出的范圍在 0 到 1 之間。看看你能否看出哪里不對。
發現問題了嗎?這個問題很難發現,結果非常難以理解。簡單的說,因為預測只有單個輸出值,應用了 softmax 交叉熵函數后,損失就會永遠是 0 了。
最簡單的發現這個問題的測試方式,就是保證損失永遠不等于 0。
我們***個實現的測試,也能發現這種錯誤,但是要反向檢查:保證只訓練需要訓練的變量。就生成式對抗網絡(GAN)為例,一個常見的 bug 就是在優化過程中不小心忘記設置需要訓練哪個變量。這樣的代碼隨處可見。
這段代碼***的問題是,優化器默認會優化所有的變量。在像生成式對抗網絡這樣高級的結構中,這意味著遙遙無期的訓練時間。然而只需要一個簡單測試,就可以檢查到這種錯誤:
也可以對判定模型(discriminator)寫一個同類型的測試。同樣的測試,也可以應用來加強大量其他的學習算法。很多演員評判家(actor-critic)模型,有不同的網絡需要用不同的損失來優化。
這里列出一些作者推薦的測試模式。
- 確保輸入的確定性。如果發現一個詭異的失敗測試,但是卻再也無法重現,將會是很糟糕的事情。在特別需要隨機輸入的場景下,確保用了同一個隨機數種子。這樣出現了失敗后,可以再次以同樣的輸入重現它。
- 確保測試很精簡。不要用同一個單元測試檢查回歸訓練和檢查一個驗證集合。這樣做只是浪費時間。
- 確保每次測試時都重置了圖。
作為總結,這些黑盒算法仍然有大量方法來測試!花一個小時寫一個簡單的測試,可以節約成天的重新運行時間,并且大大提升你的研究能力。天才的想法,永遠不要因為一個充滿 bug 的實現而無法成為現實。
這篇文章列出的測試遠遠沒有完備,但是是一個很好的起步!如果你發現有其他的建議或者某種特定類型的測試,請在 twitter 上給我消息!我很樂意寫這篇文章的續集。
文章中所有的觀點,僅代表作者的個人經驗,并沒有 Google 的支持、贊助。
查看英文原文
https://medium.com/@keeper6928/how-to-unit-test-machine-learning-code-57cf6fd81765