DeepSeek如何用MTP逆天改命?
DeepSeek-V3 的 Multi-Token Prediction 到底在做什么?這個問題在大模型面試中經常被問到,屬于 DeepSeek 的高頻面試題。
所以這篇文章我們就來看看,如果你在面試現場被問到這個問題,應該如何作答?
1.面試官心理分析
首先老規矩,我們還是來分析一下面試官的心理,面試官問這個問題,它其實主要是想考察你 3 個方面:
- 第一,為什么要做 MTP?你是否知道這個算法背后的動機?
- 第二,之前的工作 MTP 是怎么做的?DeepSeek 肯定不是這個方法的首創,那之前的研究,前因后果你是否清楚呢?
- 第三,DeepSeek 的 MTP 是怎么做的,它的設計相比之前的,有什么不同之處?
好,了解了面試官的心理之后,接下來我們就沿著面試官的心理預期,來回答一下這道題目!
2.面試題解析
首先第一個問題:為什么要做 MTP?
我們都知道,當前主流的大模型都是 decoder-only 的架構,每生成一個 token,都要頻繁的跟訪存交互,加載 KV-Cache,再完成前向計算。
那對于這樣的訪存密集型任務,通常會因為訪存效率而形成推理的瓶頸,針對這種 token-by-token 生成效率的瓶頸,業界有很多方法來優化,比如減少存儲空間,減少訪存次數等等。
那 MTP 也是優化訓練和推理效率的方法之一,它的核心動機是:通過解碼階段的優化,將 next 1-token 的生成,轉變成 multi-token 的生成,以提升訓練和推理的性能。
對于訓練階段,一次生成多個后續 token,可以一次學習多個位置的 label,這樣可以增加樣本的利用效率,提高訓練速度;而在推理階段,通過一次生成多個 token,可以實現成倍的解碼加速,來提升推理性能。
好,到這里我們就回答了第一個問題:為什么要用 MTP?接著我們再來看看,DeepSeek 之前的 MTP 都是如何做的?業界經過了哪些探索?
其實最早做 MTP 方法的是 Google 在 18 年發表的這篇論文《Blockwise Parallel Decoding for Deep Autoregressive Models》。
其思想很簡單,我們看這張圖:
可以看到,logits 上接了多個輸出頭,這樣訓練的時候可以同時預測出多個未來的 token,也就是分別預測下個 token,再下個 token,再再下個 token,以此類推。
好,理解了網絡細節,我們再看并行解碼過程就很好理解了,整個推理過程看這張圖:
可以看到,解碼過程主要分成三步:
階段 1:predict,利用 k 個 Head 一次生成 k 個 token,每個 Head 生成一個 token。
階段 2:verify,將原始的序列和生成的 k 個 token 拼接,組成 sequence_input 和 label 的 Pair 對。
Pair<sequence_input, label>
大家看圖中的 verify 階段,黑框里是 sequence_input,箭頭指向的是要驗證的 label。
我們將組裝的 k 個 Pair 對組成一個 batch,一次性發給 Head1 做校驗,檢查 Head1 生成的 token 是否跟 label 一致。
然后是階段 3:accept,選擇 Head1 預估結果與 label 一致的最長的 k 個 token,作為可接受的結果。
最優情況下,所有輔助 Head 預測結果跟 Head1 完全一樣,也就是相當于一個 step 正確解碼出了多個 token,這可以極大的提升解碼效率。
實際上在 24 年,meta 也發表過一篇大模型 MTP 的工作,這是當時的論文,其結構跟 Google 那篇差別不大,這里我們就不再單獨贅述。
感興趣的同學可以去看看這篇論文《Better & Faster Large Language Models via Multi-token Prediction》。
好,了解了 MTP 在業界的發展,我們再來看看,DeepSeek 是怎么做 MTP 的?
這里直接說改進,DeepSeek 的 MTP 設計,看這張圖:
實際上它在論文實現上保留了序列推理的 causal chain,也就是存在從一個 head 連接到后繼 head 的箭頭。其他的思路跟 google 那篇論文差不多。
另外在訓練的時候,同樣采用的是 teacher forcing 的思想,也就是 input 會輸入真實的 token,而在實際預測解碼的階段,采用的是 free running 的思想,也就是直接用上一個 step 解碼的輸出,來作為下一個 step 的輸入。
本文轉載自???丁師兄大模型??,作者: 丁師兄
