什么是神經網絡---LSTM模型實例講解
LSTM的關鍵在于它的“記憶單元”,能夠選擇性地記住或者忘記信息。其核心組件包括三個門和一個記憶單元:
1. 遺忘門(Forget Gate):決定應該丟棄哪些信息。
2. 輸入門(Input Gate):決定更新哪些新的信息。
3. 輸出門(Output Gate):決定當前狀態如何影響輸出。
數學公式解釋
- 遺忘門:
ft=σ(Wf?[ht?1,xt]+bf)
遺忘門決定了上一時刻的狀態 Ct?1中,哪些信息需要保留,哪些需要丟棄。值域為 [0, 1],1表示完全保留,0表示完全丟棄。 - 輸入門:
it=σ(Wi?[ht?1,xt]+bi)
輸入門決定了當前時刻的輸入 xt - 候選記憶單元:
C~t=tanh?(WC?[ht?1,xt]+bC)
這是當前時刻的候選記憶內容。 - 更新記憶單元:
Ct=ft?Ct?1+it?C~t
記憶單元通過遺忘門和輸入門結合,更新當前的記憶狀態。 - 輸出門:
ot=σ(Wo?[ht?1,xt]+bo)
輸出門決定了當前記憶單元 Ct - 隱藏狀態更新:
ht=ot?tanh?(Ct)
?隱藏狀態通過輸出門和當前記憶單元來更新。
有同學比較疑惑LSTM中的sigmoid函數和tanh的作用是什么?下來我來為大家解惑:?
- Sigmoid函數(遺忘門、輸入門、輸出門):用于控制信息流。由于sigmoid的輸出值在 0 和 1 之間,表示“選擇”的強度。0表示完全不通過,1表示完全通過。它起到類似開關的作用。
- Tanh函數(候選記憶單元、隱藏狀態):用于將輸入值縮放到 -1 到 1 之間,確保信息在網絡中不會增長過大或過小,幫助模型捕捉數據中的正負變化。同時,tanh可以引入非線性特征,增加網絡表達能力。
也就是說Sigmoid函數作為**“開關”**,在LSTM的各個門(遺忘門、輸入門、輸出門)中使用,決定信息流的多少。Tanh函數用于將數值范圍縮放到-1到1之間,幫助控制記憶單元的值,確保信息的平衡和穩定性,并用于生成隱藏狀態。
下面讓我們來用一個例子來輔助大家對模型的理解:
為了詳細講解LSTM如何工作,我們通過一個具體的例子一步步剖析每個步驟。
例子:預測序列中的下一個數
我們有一個簡單的序列數據:1,2,3,4,5,6,7,8,9,10
我們的目標是訓練一個LSTM模型,讓它能夠根據之前的數字預測下一個數字。例如,輸入[1, 2, 3]時,模型應該輸出4。
1. 輸入表示
LSTM處理的是時間序列數據,我們可以將每個數字視為一個時間步(time step)。對于輸入序列1,2,3,我們需要在每個時間步都輸入一個數字:
- 時間步1:輸入1
- 時間步2:輸入2
- 時間步3:輸入3
在每一個時間步,LSTM會使用前一步的隱藏狀態以及當前輸入來更新它的記憶單元C和隱藏狀態h。
2. LSTM的核心機制
LSTM有三個關鍵的門:遺忘門、輸入門和輸出門。這三個門控制了信息如何在LSTM單元中流動。讓我們看一下當我們輸入1,2,3時,LSTM內部發生了什么。
時間步1:輸入1
- 遺忘門
遺忘門的作用是決定上一個時間步的信息要保留多少。因為這是第一個時間步,之前沒有信息,所以LSTM的記憶單元初始為0。假設此時遺忘門計算出的值為0.8,這意味著LSTM會保留80%的之前的記憶狀態(雖然此時沒有實際的歷史狀態)。 - 輸入門
輸入門決定新信息要多少被寫入記憶單元。假設輸入門給出的值為0.9,意味著我們會把當前輸入的信息的90%加入到記憶單元。 - 更新記憶單元
LSTM單元會計算候選記憶內容。假設此時的候選內容為1(通過激活函數計算得到),結合遺忘門和輸入門,更新記憶單元:
C1=0.8?0+0.9?1=0.9 - 輸出門
輸出門決定記憶單元如何影響隱藏狀態。假設輸出門給出的值為0.7,隱藏狀態通過以下公式計算:
h1=0.7?tanh?(0.9)≈0.63 - 這個隱藏狀態會作為下一時間步的輸入。
時間步2:輸入2
- 遺忘門
遺忘門決定如何處理前一個時間步的記憶單元。假設遺忘門的值為0.7,意味著70%的上一步記憶將被保留。 - 輸入門
假設此時輸入門的值為0.8,意味著會將當前輸入2的80%加入到記憶單元。 - 更新記憶單元
候選內容通過激活函數計算,假設此時候選內容為2。結合遺忘門和輸入門:
C2=0.7?0.9+0.8?2=2.03 - 輸出門
假設輸出門的值為0.6,隱藏狀態為:
h2=0.6?tanh?(2.03)≈0.56
時間步3:輸入3
- 遺忘門
假設遺忘門的值為0.75,保留75%的上一個記憶。 - 輸入門
假設輸入門的值為0.85,意味著會將當前輸入3的85%加入到記憶單元。 - 更新記憶單元
假設候選內容為3,更新記憶單元:
C3=0.75?2.03+0.85?3=3.5 - 輸出門
假設輸出門的值為0.65,隱藏狀態為:
h3=0.65?tanh?(3.5)≈0.58
3. LSTM如何預測
經過這三個時間步,LSTM得到了一個隱藏狀態h3,代表了模型對序列1,2,3的理解。接下來,隱藏狀態會通過一個全連接層或線性層,輸出預測值。
假設線性層的輸出是4,這意味著LSTM模型根據序列1,2,3,預測接下來是4。
總結
- 遺忘門決定了LSTM是否保留之前的記憶。
- 輸入門決定了新輸入如何影響當前的記憶單元。
- 輸出門決定了隱藏狀態如何影響輸出。
通過這些門的組合,LSTM可以有效地在長序列數據中學習到哪些信息是重要的,哪些是可以忽略的,從而解決傳統RNN的長依賴問題。
接下來會繼續深入講解LSTM模型,比如如果輸入一串中文,LSTM該如何處理怎么預測,比如各個門的權重如何更新,如何決定遺忘多少,輸入多少。
本文轉載自???人工智能訓練營???,作者:人工智能訓練營
