DeepSeek R1本地訓練全流程實操指南,手把手教你打通其“任督二脈”
作者 | asher
許多關于 DeepSeek R1 的復現文章,主要聚焦在“rewards的設計、訓練指標的變化、benchmark測評”這些內容,但是對于“本地訓練”這個開啟深度探索的關鍵前置步驟,卻很少有人深挖。
可能有人覺得,照著readme操作就能輕松訓練了吧?太天真啦!實際動手就會發現,和自家的環境各種水土不服,大模型不是訓不起來就是訓的太慢,問題多到讓人頭大。
為了解決本地訓練的適配性問題,今天挑選HuggingFace的開源項目open-r1,為大家帶來一場全流程實操演示,從怎么在8卡A100(40G)上跑通基于Qwen-14B的DeepSeek R1復現,到分享超實用的環境鏡像,還有滿滿踩坑經驗,再到手把手教你改造代碼適配自己的任務數據,助大家光速開啟DeepSeek R1在自定義數據上的訓練探索之旅。
一、 環境搭建不求人
1. 顯卡驅動與CUDA適配要點
open-r1明確要求cuda12.4,得先瞅瞅自己機器的顯卡驅動版本(如下圖),要是版本太老,那可就得升級才能適配適配cuda12.4,我親測,顯卡驅動版本為470以上就能正常運行,我的版本是535。
# 查看自己的顯卡版本與cuda是否適配
import torch
print(torch.cuda.is_available()) # True就可以
2. 快速搞定環境安裝
與readme里的uv相比,我還是習慣使用conda管理虛擬環境:
1. conda create -n openr1 python=3.11
2. pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
3. pip install vllm==0.7.2
4. pip install flash-attn
5. 切換到open-r1目錄執行pip install -e ".[dev]"
二、訓練踩坑大避雷
1. 導致OOM的原因有這么多
以grpo訓練為例,使用Qwen-14B在A100上訓練很容易報錯OOM,原因有多種,讓我來為大家一一分析:grpo任務可以分為兩部分:一部分是模型訓練(7卡),一部分是模型推理(1卡),OOM報錯的原因就來自這兩部分。
- 訓練報錯oom:7張A100卡無法實現14B模型的訓練。解決方法:修改recipes/accelerate_configs/zero3.yaml,開啟offload
- 推理報錯oom:如果vllm版本在0.7.3以下,很容易發生oom,需要修改recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml,調低vllm_gpu_memory_utilization參數值,14B模型可以改為0.2,7B模型可以改為0.5。
- 推理報錯oom:指定vllm推理的max_model_len太長,導致kv caceh需要占用的顯存太多。解決方法:修改recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml,調低vllm_max_model_len,注意這個參數是指prompt+模型輸出長度,不宜過短,可以調整為4k-8k。默認值是讀取基座模型config,比如Qwen-14B默認是32768。
那么如何識別自己的OOM報錯是出自訓練還是推理呢?直接看報錯的GPU卡號,因為默認是最后一張卡用于推理,如下圖既然是GPU 7 內存不足,那就推理出了問題,只需要調整上述提到的兩個參數即可。
針對Qwen-14B在8卡A100(40G)訓練對應的配置文件,我已經調教好了放在本文最后,供大家參考。
2. reward函數的形參命名有講究
在設計reward函數,有個注意:reward函數聲明的形參很重要,不是隨便起的,要求與dataset的列名是一致的。比如下面這個reawrd函數的兩個形參,completions表示模型生成的內容,ground_truth表示dataset中”ground_truth“列的值,這里的形參ground_truth就是要求與dataset列名字對齊。
import re
def reward_func(completions, ground_truth, **kwargs):
# Regular expression to capture content inside \boxed{}
matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
contents = [match.group(1) if match else "" for match in matches]
# Reward 1 if the content is the same as the ground truth, 0 otherwise
return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
三、DeepSeek R1訓練快速開啟不迷路
1. 數據先行!準備業務數據要點
離線構造業務數據集data.json,注意字段名為problem與solution,與官方給的示例數據字段名一致,這樣可以少去很多改代碼的麻煩:
{"problem": "Classify the text into neutral, negative, or positive\nText: I think the food was okay.\nSentiment:\n", "solution": "positive"}
{"problem": "Classify the text into neutral, negative, or positive\nText: I think the food was shit.\nSentiment:\n", "solution": "negative"}
2. 巧妙變身!輕松更改數據讀取方式
修改grpo.py中數據讀取方式,由讀取hub數據改為讀取離線數據:
dataset = load_dataset("json", data_files=XXX/data.json)
dataset = dataset["train"].train_test_split(test_size=0.02)
個性定制!手把手自定義reward函數 注意這里函數聲明中solution形參要與dataset的字段保持一致:
def accuracy_reward_ours(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = sol # 從數據集中讀取ground-truth
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = re.findall("<answer>(.*?)</answer>",content) # 從模型輸出文本中提取預測答案
if len(answer_parsed)>0:
answer_parsed = answer_parsed[0]
reward = float(1 if answer_parsed==gold_parsed else 0) # 判斷預測結果與真實結果是否一致
else:
reward = float(0)
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
print("Failed to parse gold solution: ", sol)
rewards.append(reward)
return rewards
3. 一鍵啟動!暢爽開啟DeepSeek R1訓練
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
--num_processes=7 src/open_r1/grpo.py \
--config recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml \
&> /workspace/user_code/Qwen2.5-14B-Instruct.log
四、能讓14B模型在A100上絲滑跑通R1的配置參數大公開
recipes/accelerate_configs/zero3.yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: "cpu"
offload_param_device: "cpu"
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml
# Model arguments
model_name_or_path: XXX/models/Qwen2.5-14B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: XXX/dataset/data.json
# Num processes is less by 1 as vLLM is using 1 GPU
num_processes: 7
# GRPO trainer config
reward_funcs:
- accuracy_ours
- format
bf16: true
use_vllm: true
vllm_device: cuda:7
vllm_gpu_memory_utilization: 0.2 # vllm版本在0.7.3以下
vllm_max_model_len: 8000
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id: Qwen-2.5-7B-Simple-RL
hub_strategy: every_save
learning_rate: 3.0e-06
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_generations: 7
num_train_epochs: 1
output_dir: XXX/Qwen-2.5-7B-Instruct-RL
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 8
push_to_hub: false
report_to: "none"
save_strategy: "steps"
save_steps: 100
save_total_limit: 2
seed: 42
warmup_ratio: 0.1