大模型“拼好題”,45K數(shù)據(jù)撬動18%提升,數(shù)學問題拒絕死記硬背?|?MathFusion
當前數(shù)學領(lǐng)域的數(shù)據(jù)生成方法常常局限于對單個問題進行改寫或變換,好比是讓學生反復做同一道題的變種,卻忽略了數(shù)學題目之間內(nèi)在的關(guān)聯(lián)性。
為了打破這種局限,讓大模型學會“串聯(lián)”與“并聯(lián)”知識,上海AI Lab、人大高瓴等團隊聯(lián)合提出了MathFusion,通過指令融合增強大語言模型解決數(shù)學問題的能力。
僅使用45K的合成指令,MathFusion在多個基準測試中平均準確率提升了18.0個百分點,展現(xiàn)了卓越的數(shù)據(jù)效率和性能。
△越靠左上角,模型表現(xiàn)越好且數(shù)據(jù)效率越高。
核心思想:三種“融合策略”
MathFusion通過三種“融合策略”,將不同的數(shù)學問題巧妙地結(jié)合起來,生成封裝了二者關(guān)系和結(jié)構(gòu)的新問題。
- 順序融合(Sequential Fusion)將兩個問題串聯(lián)起來,前一個問題的答案作為后一個問題的某個輸入條件。這就像解決一個多步驟問題,模型需要先解出第一步,才能進行第二步,從而學會處理問題間的依賴關(guān)系。
- 并列融合(Parallel Fusion)將兩個相似的問題融合在一起,對它們的數(shù)學概念進行識別和融合,在原來問題的基礎(chǔ)上提出一道新的問題。
- 條件融合(Conditional Fusion)創(chuàng)造一個需要對兩個問題的解進行比較和選擇的問題場景。
首先從現(xiàn)有數(shù)據(jù)集(GSM8K、MATH)中識別出適合融合的問題對(主要通過embedding search),然后應用融合策略生成新問題,并利用GPT-4o-mini來生成解答。通過這三種策略,生成了一個全新的融合數(shù)據(jù)集MathFusionQA。
融合實例:不同策略的融合結(jié)果
為了更直觀地理解這三種融合策略,來看一個具體的例子:
原始問題
- 問題A:一天內(nèi),一艘船在湖中航行4次,每次最多可載12人。請問在2天內(nèi),這艘船可以運送多少人?
- 問題B:學校組織去博物館。他們租了4輛巴士來接送孩子和老師。第二輛巴士的人數(shù)是第一輛的兩倍,第三輛比第二輛少6人,第四輛比第一輛多9人。如果第一輛巴士上有12人,請問總共有多少人去了博物館?
順序融合
學校組織一次去博物館的旅行,需要運送學生和老師。首先,請計算一艘船在2天內(nèi)的載客量,這艘船每天航行4次,每次最多可載12人。然后,將這個總載客量作為第一輛巴士的人數(shù)。已知第二輛巴士的人數(shù)是第一輛的兩倍,第三輛比第二輛少6人,第四輛比第一輛多9人。請問總共有多少人去了博物館?
并列融合
一所學校組織一次到博物館的實地考察,并租用了4輛巴士和一艘船6。這艘船一天航行2次,每次載客12人。每輛巴士的人數(shù)不同:第一輛巴士有12人,…,第四輛比第一輛多9人。請計算在2天內(nèi),船和所有巴士總共可以運送多少人?
條件融合
一個社區(qū)正在組織兩種不同的郊游活動。對于湖上游覽,一艘船每天運營4次,載客量為12人,他們計劃讓這艘船服務2天。與此同時,一所學校正在安排一次有4輛巴士的博物館之旅11。第一輛巴士有12人,第二輛是第一輛的兩倍,第三輛比第二輛少6人,第四輛比第一輛多9人。考慮到這些安排,哪種交通方式的載客能力更強?
實驗結(jié)果:有效捕捉問題間深層聯(lián)系
在MathFusionQA的基礎(chǔ)上,使用三種融合策略——順序、并行和條件——對模型(DeepSeekMath-7B、Mistral-7B、Llama3-8B)進行微調(diào)。實驗得到以下發(fā)現(xiàn):
顯著提升模型性能與效率:與標準訓練方法(只在GSM8K和MATH上訓練)相比,MathFusion在多個base模型(包括DeepSeekMath-7B、Llama3-8B、Mistral-7B)上都取得了穩(wěn)定的性能提升。并且,MathFusion在大幅提升性能的同時,保持了極高的數(shù)據(jù)效率,用遠少于其他方法的數(shù)據(jù)量就達到了良好的效果。
策略之間優(yōu)勢互補:將順序融合、條件融合和并行融合三種策略結(jié)合使用,組合融合策略始終優(yōu)于每種單一融合策略。另外,基礎(chǔ)模型性能越弱,組合融合策略帶來的提升就越大。在所有基準測試中,組合融合策略在DeepSeekMath-7B上平均提升了3.1分,在Llama3-8B上提升了4.9分,在Mistral-7B上提升了7.5分。
強大的泛化與擴展能力:MathFusion不僅在in-domain測試中表現(xiàn)優(yōu)異,在更具挑戰(zhàn)性的out-of-domain基準測試中同樣超越了標準模型。
對MathFusion做進一步的分析,有以下幾點發(fā)現(xiàn):
- 融合之后的問題的指令遵循難度(IFD)更高,說明融合之后的問題對于模型來說更加困難。
- 隨著融合數(shù)據(jù)量的增加,MathFusion模型的性能呈現(xiàn)出近似對數(shù)形式的增長。
- 當把MathFusionQA數(shù)據(jù)集與DART-Math數(shù)據(jù)集結(jié)合使用時,模型的性能可以得到進一步的提升,甚至超過了單獨使用任何一個數(shù)據(jù)集時的表現(xiàn)。這表明MathFusion的“問題融合”思路與DART-Math的“挖掘難題”思路是互補的。
- 通過t-SNE可視化分析,發(fā)現(xiàn)MathFusion得到的問題在特征空間中的分布比原始問題更均勻和廣泛。
- 通過對teacher model的消融分析,證明了MathFusion帶來的提升源自于問題融合本身,而非teacher model的好壞。
總的來說,通過生成結(jié)構(gòu)更多樣、邏輯更復雜的合成問題,MathFusion有效地增強了模型捕捉問題間深層聯(lián)系的能力。
但目前MathFusion還只在GSM8K、MATH這種比較簡單的數(shù)學問題,以及short cot solution的數(shù)據(jù)集上進行了驗證,有待進一步擴展到更難的數(shù)學問題、long cot solution以及其他領(lǐng)域的數(shù)據(jù)上。