機器學習 | 從0開始大模型之模型LoRA訓練
1、LoRA是如何實現的?
在深入了解 LoRA 之前,我們先回顧一下一些基本的線性代數概念。
1.1、秩
給定矩陣中線性獨立的列(或行)的數量,稱為矩陣的秩,記為 rank(A) 。
- 矩陣的秩小于或等于列(或行)的數量,rank(A) ≤ min{m, n}
- 滿秩矩陣是所有的行或者列都獨立,rank(A) = min{m, n}
- 不滿秩矩陣是滿秩矩陣的反面是不滿秩,即 rank(A) < min(m, n),矩陣的列(或行)不是彼此線性獨立的
舉個兩個秩的例子:
不滿秩
滿秩
1.2、秩相關屬性
從上面的秩的介紹中可以看出,矩陣的秩可以被理解為它所表示的特征空間的維度,在這種情況下,特定大小的低秩矩陣比相同維度的滿秩矩陣封裝更少的特征(或更低維的特征空間)。與之相關的屬性如下:
- 矩陣的秩受其行數和列數中最小值的約束,rank(A) ≤ min{m, n};
- 兩個矩陣的乘積的秩受其各自秩的最小值的約束,給定矩陣 A 和 B,其中 rank(A) = m 且 rank(A) = n,則 rank(AB) ≤ min{m, n};
1.3、LoRA
LoRA(Low rand adaption) 是微軟研究人員提出的一種高效的微調技術,用于使大型模型適應特定任務和數據集。LoRA 的背后的主要思想是模型微調期間權重的變化也具有較低的內在維度,具體來說,如果W??代表單層的權重,ΔW??代表模型自適應過程中權重的變化,作者提出ΔW??是一個低秩矩陣,即:rank(ΔW??) << min(n,k) 。
為什么?模型有了基座以后,如果強調學習少量的特征,那么就可以大大減少參數的更新量,而ΔW??就可以實現,這樣就可以認為ΔW??是一個低秩矩陣。
實現原理ΔW??是一個更新矩陣,然后ΔW??根據秩的屬性,又可以拆分兩個低秩矩陣的乘積,即:B?? 和 A?? ,其中 r << min{n,k} 。這意味著網絡中權重 Wx = Wx + ΔWx = Wx + B??A??x,由于 r 很小,所以 B??A?? 的參數數量非常少,所以只需要更新很少的參數。
LoRA
2、peft庫
LoRA 訓練非常方便,只需要借助 https://huggingface.co/blog/zh/peft 庫,這是 huggingface 提供的,使用方法如下:
# 引入庫
from peft import get_peft_model, LoraConfig, TaskType
# 創建對應的配置
peft_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q", "v"],
lora_dropout=0.01,
bias="none"
task_type="SEQ_2_SEQ_LM",
)
# 包裝模型
model = AutoModelForSeq2SeqLM.from_pretrained(
"t5-small",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
LoraConfig 詳細參數如下:
- r:秩,即上面的r,默認為8;
- target_modules:對特定的模塊進行微調,默認為None,支持nn.Linear、nn.Embedding和nn.Conv2d;
- lora_alpha:ΔW 按 α / r 縮放,其中 α 是常數,默認為8;
- task_type:任務類型,支持包括 CAUSAL_LM、FEATURE_EXTRACTION、QUESTION_ANS、SEQ_2_SEQ_LM、SEQ_CLS 和 TOKEN_CLS 等;
- lora_dropout:Dropout 概率,默認為0,通過在訓練過程中以 dropout 概率隨機選擇要忽略的神經元來減少過度擬合的技術;
- bias:是否添加偏差,默認為 "none";
3、訓練
使用 peft 庫對SFT全量訓練修改如下:
def init_model():
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def find_all_linear_names(model):
cls = torch.nn.Linear
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
names = name.split('.')
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
return list(lora_module_names)
model = Transformer(lm_config)
ckp = f'./out/pretrain_{lm_config.dim}.pth.{batch_size}'
state_dict = torch.load(ckp, map_locatinotallow=device_type)
unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict, strict=False)
target_modules = find_all_linear_names(model)
peft_config = LoraConfig(
r=8,
target_modules=target_modules
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
print(f'LLM總參數量:{count_parameters(model) / 1e6:.3f} 百萬')
model = model.to(device_type)
return model
只需要修改模型初始化部分,其他不變,訓練過程和之前一樣,這里不再贅述。
參考
(1)https://cloud.tencent.com/developer/article/2372297
(2)http://www.bimant.com/blog/lora-deep-dive/
(3)https://blog.csdn.net/shebao3333/article/details/134523779