The Annotated BERT, Extended Annotation Edition: You Only Understand BERT Once You Understand the Code
Previously we implemented the pretraining process of the Transformer and GPT-2 from scratch, using code comments and printed tensor shapes to make the process easier to follow. Today I will continue with BERT using the same approach.
The original Transformer is an encoder-decoder architecture, GPT is a decoder-only model, and BERT is an encoder-only model, so here we focus mainly on the left half of the Transformer.
Please read this article alongside the code:
https://github.com/AIDajiangtang/annotated-transformer/blob/master/AnnotatedBert.ipynb
0. Preparing the Training Data
0.0 Downloading the Data
The original BERT was pretrained on BooksCorpus and English Wikipedia, but those datasets are far too large for our purposes. Here we pretrain on 50,000 movie reviews from the IMDb website, provided as a CSV file with two columns: review, the movie review text, and sentiment, the sentiment label (positive or negative). We only use the review column.
Here is one review printed out:
One of the other reviewers has mentioned that after watching just 1 Oz episode you'll be hooked.
They are right, as this is exactly what happened with me.<br /><br />The first thing that struck me about Oz was its brutality and unflinching scenes of violence, which set in right from the word GO.
Trust me, this is not a show for the faint hearted or timid. This show pulls no punches with regards to drugs, sex or violence. Its is hardcore, in the classic use of the word.<br /><br />It is called OZ as that is the nickname given to the Oswald Maximum Security State Penitentary. It focuses mainly on Emerald City, an experimental section of the prison where all the cells have glass fronts and face inwards, so privacy is not high on the agenda. Em City is home to many..Aryans, Muslims, gangstas, Latinos, Christians, Italians, Irish and more....so scuffles, death stares, dodgy dealings and shady agreements are never far away.<br /><br />I would say the main appeal of the show is due to the fact that it goes where other shows wouldn't dare. Forget pretty pictures painted for mainstream audiences, forget charm, forget romance...OZ doesn't mess around. The first episode I ever saw struck me as so nasty it was surreal, I couldn't say I was ready for it, but as I watched more, I developed a taste for Oz, and got accustomed to the high levels of graphic violence. Not just violence, but injustice (crooked guards who'll be sold out for a nickel, inmates who'll kill on order and get away with it, well mannered, middle class inmates being turned into prison bitches due to their lack of street skills or prison experience) Watching Oz, you may become comfortable with what is uncomfortable viewing....thats if you can get in touch with your darker side.
ds = IMDBBertDataset(BASE_DIR.joinpath('data/imdb.csv'), ds_from=0, ds_to=1000)
To speed up training, the ds_from and ds_to parameters restrict loading to the first 1,000 reviews.
0.1 Computing the Context Length
The context length is the maximum length of the input sequence. For the Transformer and GPT-2 we set it directly as a hyperparameter; this time we derive it from the training data: read the 1,000 reviews line by line with pandas, split each review into sentences on '.', collect the length of every sentence into an array, and take the 90th-percentile value of that array.
This yields an optimal sentence length of 27. Samples longer than 27 tokens are truncated; shorter ones are padded with a special token.
As a simple example, if the array of sentence lengths is [10, 20, 30, 40, 50, 60, 70, 80, 90, 100], the 90th-percentile value is 90.
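A minimal sketch of this statistic (not the repository's exact code; the csv path data/imdb.csv and the rough whitespace split are assumptions for illustration):
import numpy as np
import pandas as pd

df = pd.read_csv('data/imdb.csv')              # assumed location of the IMDb csv
lengths = []
for review in df['review'][:1000]:             # first 1,000 reviews only
    for sentence in review.split('.'):
        n_tokens = len(sentence.split())       # rough whitespace token count
        if n_tokens > 0:
            lengths.append(n_tokens)

optimal_sentence_length = int(np.percentile(lengths, 90))
print(optimal_sentence_length)                 # 27 for this subset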
0.2 Tokenization
This time we use the basic_english tokenizer. It is a very simple and direct method: it lowercases all text, normalizes the punctuation (most punctuation marks become separate tokens, while a few such as quotes are stripped), and then splits the text into words on whitespace.
"Hello, world! This is an example sentence."
['hello', ',', 'world', '!', 'this', 'is', 'an', 'example', 'sentence', '.']
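If you want to reproduce this step, the tokenizer can be obtained from torchtext (a minimal sketch; the variable name tokenizer is illustrative):
from torchtext.data.utils import get_tokenizer

tokenizer = get_tokenizer('basic_english')
print(tokenizer("Hello, world! This is an example sentence."))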
Next, the tokens are converted into numeric ids. This requires building a vocabulary from the training data, that is, collecting every unique word that appears in it.
Counting shows that these 1,000 reviews contain 9,626 unique tokens.
The following special tokens are then prepended to the vocabulary.
CLS = '[CLS]'
PAD = '[PAD]'
SEP = '[SEP]'
MASK = '[MASK]'
UNK = '[UNK]'
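A sketch of how such a vocabulary can be built with torchtext; tokenized_sentences is assumed to be an iterable of the token lists produced by the tokenizer above, and the special-token order is chosen to be consistent with the ids that appear later in this walkthrough ([CLS]=0, [PAD]=1, [MASK]=2, [SEP]=3, [UNK]=4):
from torchtext.vocab import build_vocab_from_iterator

specials = [CLS, PAD, MASK, SEP, UNK]                       # prepended, so they get ids 0..4
vocab = build_vocab_from_iterator(tokenized_sentences, specials=specials)
vocab.set_default_index(vocab[UNK])                         # out-of-vocabulary words map to [UNK]
print(len(vocab))                                           # vocabulary size incl. special tokens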
0.3 Constructing the Training Data
BERT is an encoder-only architecture: every token computes attention with every other token, whether it comes before or after. This lets each token absorb context from both directions, which makes encoder-only models well suited to understanding tasks.
A decoder, by contrast, only attends to the tokens before it. In that sense GPT uses only the left context, but this autoregressive setup has its own advantage: it is well suited to generation tasks.
To learn bidirectional representations, BERT differs not only in model architecture but also in how the training data is constructed.
GPT predicts the next token from the current one. Suppose the training data is token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], with context_length=4, stride=4, and batch_size=2:
Input IDs: [tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8])]
Target IDs: [tensor([2, 3, 4, 5]), tensor([6, 7, 8, 9])]
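For reference, a minimal sketch of this sliding-window construction:
import torch

token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
context_length, stride = 4, 4

input_ids, target_ids = [], []
for i in range(0, len(token_ids) - context_length, stride):
    input_ids.append(torch.tensor(token_ids[i:i + context_length]))           # current tokens
    target_ids.append(torch.tensor(token_ids[i + 1:i + 1 + context_length]))  # shifted by one
print(input_ids)    # [tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8])]
print(target_ids)   # [tensor([2, 3, 4, 5]), tensor([6, 7, 8, 9])]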
BERT constructs its pretraining data in two ways: masked language modeling (MLM) and next sentence prediction (NSP).
MLM randomly replaces some tokens in a sample with [MASK] or with another token from the vocabulary. In this implementation, 15% of the tokens are selected; of those, 80% are replaced with [MASK] and 20% with a random token from the vocabulary.
NSP builds positive pairs from adjacent sentences and treats non-adjacent sentences as negative pairs, with a [SEP] separator placed between the two sentences.
BERT is not good at generation, so how does it handle downstream tasks such as question answering? It places a [CLS] token at the beginning of every sample, and classification (here, binary classification) is done on the [CLS] output.
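A minimal sketch of the MLM masking scheme described above (not the repository's exact code; the helper name mask_sentence is illustrative):
import random

def mask_sentence(tokens, vocab_words, mask_token='[MASK]', mask_ratio=0.15):
    tokens = list(tokens)
    n_to_mask = max(1, round(len(tokens) * mask_ratio))
    for idx in random.sample(range(len(tokens)), n_to_mask):   # pick 15% of positions
        if random.random() < 0.8:
            tokens[idx] = mask_token                           # 80%: replace with [MASK]
        else:
            tokens[idx] = random.choice(vocab_words)           # 20%: random vocabulary word
    return tokens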
With the method clear, we now construct the training data by iterating over the 1,000 movie reviews.
Take the first review (printed above) as an example.
Split the review into sentences on '.' and iterate over each sentence.
The first sentence:
One of the other reviewers has mentioned that after watching just 1 Oz episode you'll be hooked
The second sentence:
They are right, as this is exactly what happened with me.<br /><br />The first thing that struck me about Oz was its brutality and unflinching scenes of violence, which set in right from the word GO
The first sentence tokenized:
['one', 'of', 'the', 'other', 'reviewers', 'has', 'mentioned', 'that', 'after', 'watching', 'just', '1', 'oz', 'episode', 'you', "'", 'll', 'be', 'hooked']
The second sentence tokenized:
['they', 'are', 'right', ',', 'as', 'this', 'is', 'exactly', 'what', 'happened', 'with', 'me', '.', 'the', 'first', 'thing', 'that', 'struck', 'me', 'about', 'oz', 'was', 'its', 'brutality', 'and', 'unflinching', 'scenes', 'of', 'violence', ',', 'which', 'set', 'in', 'right', 'from', 'the', 'word', 'go']
For each sentence, randomly select 15% of the tokens to mask, prepend [CLS], pad to the context length of 27, then concatenate the two sentences with a [SEP] separator between them.
['[CLS]', 'one', 'of', 'the', 'other', 'reviewers', 'has', 'mentioned', '[MASK]', 'after', 'watching', 'just', '1', 'oz', 'episode', 'you', "'", '[MASK]', '[MASK]', 'hooked', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[SEP]', '[CLS]', 'they', 'are', 'right', ',', 'as', 'this', 'is', '[MASK]', 'what', 'happened', '[MASK]', 'me', '[MASK]', 'the', '[MASK]', 'financiers', 'that', 'struck', 'me', 'about', 'oz', 'was', 'its', 'brutality', 'and', 'unflinching']
From the masked sequence above, construct the input mask: positions that were masked are set to False, all other positions to True.
[True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, False, True, False, True, False, False, True, True, True, True, True, True, True, True, True, True]
Convert the masked sequence into token ids; this is the input X that is ultimately fed to the model.
[0, 5, 6, 7, 8, 9, 10, 11, 2, 13, 14, 15, 16, 17, 18, 19, 20, 2, 2, 23, 1, 1, 1, 1, 1, 1, 1, 3, 0, 24, 25, 26, 27, 28, 29, 30, 2, 32, 33, 2, 35, 2, 7, 2, 32940, 12, 39, 35, 40, 17, 41, 42, 43, 44, 45]
Convert the original, unmasked sequence into token ids; this is the label Y.
[0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 1, 1, 1, 1, 1, 1, 1, 3, 0, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 7, 37, 38, 12, 39, 35, 40, 17, 41, 42, 43, 44, 45]
The MLM loss is computed between the model output and the label Y.
What about the NSP loss? When building sentence pairs, the label is 1 if the two sentences are adjacent and 0 otherwise; the binary classification loss is computed from the [CLS] output.
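A minimal sketch of the NSP pair construction (the helper name make_nsp_pairs is illustrative, not the repository's exact code); sentences is assumed to be the list of sentences from one review:
import random

def make_nsp_pairs(sentences):
    pairs = []
    for i in range(len(sentences) - 1):
        pairs.append((sentences[i], sentences[i + 1], 1))   # adjacent sentences: label 1
        j = random.randrange(len(sentences))
        if j != i + 1:
            pairs.append((sentences[i], sentences[j], 0))   # non-adjacent sentence: label 0
    return pairs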
From the first 1,000 reviews this yields a DataFrame in which each row is one sample, 17,122 samples in total, each with four columns:
the input X, of shape [1,55];
the label Y, of shape [1,55];
the input mask, of shape [1,55];
the NSP classification label, 0 or 1.
55 equals the length of the two sentences (27 each) plus one [SEP] separator.
1. Pretraining
Hyperparameters:
EMB_SIZE = 64     # embedding dimension
HIDDEN_SIZE = 36  # hidden size
EPOCHS = 4        # number of training epochs
BATCH_SIZE = 12   # batch size
NUM_HEADS = 4     # number of attention heads
With BATCH_SIZE = 12, each batch contains 12 samples, so the input X has shape [12,55] and the label Y has shape [12,55].
1.0 Embeddings
Next, the token ids are converted to embeddings. In BERT every token involves three embeddings: the token embedding, which maps the token id to an embedding vector; the positional encoding; and the segment embedding, which indicates which sentence a token belongs to (0 for the first sentence, 1 for the second).
With EMB_SIZE = 64, the embedding dimension is 64: the token embedding maps the input [12,55] to [12,55,64] through an embedding layer of shape [9626,64].
9626 is the vocabulary size; the [9626,64] embedding layer can be viewed as a lookup table with 9626 index positions, each storing a 64-dimensional vector.
The positional encoding can either be learned or computed with a fixed formula; here we use the fixed (sinusoidal) formulation.
The segment embedding has the same shape as the input X: positions of the first sentence are 0, positions of the second sentence are 1.
Finally, the three embeddings are summed, and the resulting embedding of shape [12,55,64] is fed into the encoder.
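A sketch of the three embeddings being summed under the hyperparameters above; the vocabulary size of 9631 (9626 words plus 5 special tokens) and the 28/27 segment split are assumptions for illustration:
import math
import torch
import torch.nn as nn

vocab_size, emb_size, seq_len, batch = 9631, 64, 55, 12
token_emb = nn.Embedding(vocab_size, emb_size)     # lookup table [vocab_size, 64]
segment_emb = nn.Embedding(2, emb_size)            # 0 = first sentence, 1 = second

# fixed sinusoidal positional encoding, shape [55, 64]
pos = torch.arange(seq_len).unsqueeze(1).float()
div = torch.exp(torch.arange(0, emb_size, 2).float() * (-math.log(10000.0) / emb_size))
pos_enc = torch.zeros(seq_len, emb_size)
pos_enc[:, 0::2] = torch.sin(pos * div)
pos_enc[:, 1::2] = torch.cos(pos * div)

x = torch.randint(0, vocab_size, (batch, seq_len))                 # token ids [12, 55]
seg = torch.cat([torch.zeros(batch, 28, dtype=torch.long),         # first sentence segment
                 torch.ones(batch, 27, dtype=torch.long)], dim=1)  # second sentence segment

emb = token_emb(x) + pos_enc + segment_emb(seg)    # summed embeddings, [12, 55, 64]
print(emb.shape)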
1.1 Multi-Head Attention
The first operation in the encoder is multi-head attention. Unlike in the Transformer and GPT, attention for [PAD] tokens is not computed: the attention scores at [PAD] positions are set to a very large negative value, so that after softmax they become 0.
The output of multi-head attention has shape [12,55,64].
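A minimal sketch of how [PAD] positions can be excluded: the scores at padded key positions receive a large negative value before softmax (function and argument names are illustrative):
import torch
import torch.nn.functional as F

def masked_attention(q, k, v, pad_mask):
    # q, k, v: [batch, heads, seq, head_dim]; pad_mask: [batch, seq], True at [PAD] positions
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)     # [batch, heads, seq, seq]
    scores = scores.masked_fill(pad_mask[:, None, None, :], -1e9)
    weights = F.softmax(scores, dim=-1)                        # ~0 weight on [PAD] columns
    return weights @ v                                         # [batch, heads, seq, head_dim]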
1.2 MLP
Identical to the Transformer and GPT; the MLP output has shape [12,55,64].
1.3 Output
The encoder output has shape [12,55,64]. Next, the parameters are updated by computing losses against the labels.
MLM loss
The encoder output [12,55,64] is projected by a linear layer [64,9626] into a distribution over the vocabulary, [12,55,9626].
Because the loss only needs to be computed at [MASK] positions, the non-[MASK] positions in both the labels and the outputs are set to 0 with a few tricks.
Finally, the multi-class cross-entropy loss is computed against the label Y.
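A minimal sketch of this loss, assuming encoder_out is the encoder output [12,55,64], target_ids is the label Y, and not_masked is a boolean version of the input mask above (True at positions that were not masked); names are illustrative, not the repository's exact code:
import torch.nn as nn

mlm_head = nn.Linear(64, 9626)                        # project to vocabulary logits
criterion = nn.CrossEntropyLoss(ignore_index=0)       # label 0 marks positions to ignore

def mlm_loss(encoder_out, target_ids, not_masked):
    logits = mlm_head(encoder_out)                    # [12, 55, 9626]
    targets = target_ids.masked_fill(not_masked, 0)   # zero out non-[MASK] positions
    return criterion(logits.transpose(1, 2), targets) # loss only at masked positions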
NSP loss
Another linear layer [64,2] maps the output at the leading [CLS] token, [12,64], to [12,2], representing the probabilities of the positive and negative classes; the cross-entropy loss is then computed against the NSP label.
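A corresponding sketch for the NSP loss (is_next_label is the 0/1 label from the data construction step; names are illustrative):
import torch.nn as nn

nsp_head = nn.Linear(64, 2)                  # two logits: is-next vs. not-next
nsp_criterion = nn.CrossEntropyLoss()

def nsp_loss(encoder_out, is_next_label):
    cls_out = encoder_out[:, 0, :]           # output at the leading [CLS] token, [12, 64]
    logits = nsp_head(cls_out)               # [12, 2]
    return nsp_criterion(logits, is_next_label)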
2.0 Inference
The simplest use is cloze-style fill-in-the-blank: feed in a piece of text [1,55] with some tokens replaced by [MASK], and map the output at each [MASK] position through the output head to [1,9626].
Because next sentence prediction (NSP) was used during pretraining, we can also set up closed-set question answering: prepare several candidate answers for a question in advance, concatenate the question with each answer, feed each pair into BERT, and classify with the [CLS] output.
Alternatively, one can predict the start and end positions of the answer span, which requires fine-tuning on the downstream task.
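A minimal sketch of cloze-style inference with the pretrained model; the forward signature of model and the names vocab, x, and input_mask are assumptions for illustration, not the repository's exact API:
import torch

model.eval()
with torch.no_grad():
    logits = model(x, input_mask)                            # x: [1, 55] ids containing [MASK]; logits: [1, 55, vocab]
    mask_positions = (x == vocab['[MASK]']).nonzero(as_tuple=True)
    predicted_ids = logits[mask_positions].argmax(dim=-1)    # most likely token at each [MASK]
    print(vocab.lookup_tokens(predicted_ids.tolist()))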
Summary
With this, we have completed the pretraining process for both GPT-2 and BERT. To make the models follow human instructions, the pretrained models still need instruction fine-tuning, which we will cover later.
References
https://arxiv.org/pdf/1810.04805
https://github.com/coaxsoft/pytorch_bert
https://towardsdatascience.com/a-complete-guide-to-bert-with-code-9f87602e4a11
https://medium.com/data-and-beyond/complete-guide-to-building-bert-model-from-sratch-3e6562228891
https://coaxsoft.com/blog/building-bert-with-pytorch-from-scratch
This article is reposted from the WeChat official account 人工智能大講堂.
