基于秘密共享重構(gòu) DeepSeek DeepGEMM Kernel 的安全高效 MPC-GEMM 方案
摘要
本文針對(duì)安全多方計(jì)算(MPC)框架下通用矩陣乘法(GEMM)運(yùn)算的性能瓶頸,提出一種全新的 MPC-GEMM 實(shí)現(xiàn)方案。該方案的核心思想在于:基于加法秘密共享重構(gòu) DeepSeek DeepGEMM 的 CUDA kernel,將 MPC 協(xié)議的邏輯與 DeepGEMM 的底層優(yōu)化深度融合,消除 MPC 協(xié)議與 GPU 計(jì)算之間的“兩張皮”現(xiàn)象。方案采用 INT8/FP8 數(shù)據(jù)表示、秘密共享運(yùn)算的 kernel 級(jí)實(shí)現(xiàn)、Hopper 架構(gòu)優(yōu)化(如適用)、GPU 并行 Beaver 三元組生成以及 JIT 編譯等關(guān)鍵技術(shù)。本文將詳細(xì)闡述方案的設(shè)計(jì)原理、技術(shù)框架、實(shí)現(xiàn)細(xì)節(jié)(包括 kernel 代碼示例、算法描述、優(yōu)化策略),并從可行性、安全性、高效性等方面進(jìn)行全面深入的論證,最后與其他 MPC-GEMM 方案進(jìn)行對(duì)比。方案旨在實(shí)現(xiàn)真正意義上的安全、高效的 MPC-GEMM,為隱私保護(hù)機(jī)器學(xué)習(xí)提供強(qiáng)有力的支持。
關(guān)鍵詞: DeepGEMM, DeepSeek, MPC, GEMM, 秘密共享, CUDA, Kernel 重構(gòu), 安全計(jì)算, INT8, FP8, Hopper 架構(gòu), Beaver 三元組, JIT 編譯, 并行計(jì)算
1. 引言:MPC-GEMM 的性能挑戰(zhàn)與 DeepGEMM 的機(jī)遇
安全多方計(jì)算(MPC)使得互不信任的參與方能夠在不泄露各自私有數(shù)據(jù)的前提下進(jìn)行協(xié)同計(jì)算,是實(shí)現(xiàn)隱私保護(hù)機(jī)器學(xué)習(xí)的關(guān)鍵技術(shù)。通用矩陣乘法(GEMM)作為深度學(xué)習(xí)模型的核心運(yùn)算,其在 MPC 框架下的實(shí)現(xiàn)(MPC-GEMM)的效率直接影響著隱私保護(hù)機(jī)器學(xué)習(xí)應(yīng)用的整體性能和實(shí)用性。然而,現(xiàn)有的 MPC-GEMM 方案普遍面臨著嚴(yán)重的性能挑戰(zhàn):
- 計(jì)算開銷:MPC 協(xié)議的密碼學(xué)運(yùn)算(如秘密共享、同態(tài)加密)計(jì)算復(fù)雜度遠(yuǎn)高于明文計(jì)算。
- 通信開銷:多數(shù) MPC 協(xié)議需要在參與方之間進(jìn)行大量的交互通信,尤其是在執(zhí)行乘法運(yùn)算時(shí),通信開銷成為主要瓶頸。
- 硬件加速:如何在 MPC 的安全約束下有效利用 GPU 等硬件加速器進(jìn)行計(jì)算,是一個(gè)極具挑戰(zhàn)性的問題。
傳統(tǒng)的 MPC-GEMM 方案通常采用“兩張皮”模式:MPC 協(xié)議負(fù)責(zé)保證計(jì)算的安全性,GPU 負(fù)責(zé)提供計(jì)算加速,兩者之間通過某種安全接口(如可信執(zhí)行環(huán)境 TEE 或同態(tài)加密)進(jìn)行交互。這種模式的缺點(diǎn)在于:
- 交互開銷:MPC 協(xié)議與 GPU 計(jì)算之間存在數(shù)據(jù)轉(zhuǎn)換(如明文與密文、秘密份額與 GPU 可處理格式之間的轉(zhuǎn)換)和通信的開銷,限制了整體性能。
- GPU 利用率:GPU 計(jì)算部分通常受到 MPC 協(xié)議的制約,無法充分發(fā)揮 GPU 的并行計(jì)算能力和 DeepGEMM 等底層優(yōu)化庫的性能優(yōu)勢(shì)。
DeepSeek 最新發(fā)布的 DeepGEMM 是一個(gè)為 NVIDIA GPU 優(yōu)化的高性能 GEMM 庫。它通過 FP8 低精度計(jì)算、針對(duì) GPU 架構(gòu)的優(yōu)化、CUDA kernel 優(yōu)化以及 JIT 編譯等技術(shù),大幅提升了 GEMM 運(yùn)算的效率。雖然 DeepGEMM 并非專門為 MPC 設(shè)計(jì),但其在 kernel 級(jí)別的優(yōu)化為我們提供了一個(gè)重要的機(jī)遇:能否將 MPC 協(xié)議與 DeepGEMM 的底層優(yōu)化進(jìn)行深度融合,消除“兩張皮”現(xiàn)象,實(shí)現(xiàn)真正意義上的安全高效的 MPC-GEMM?
2. 方案原理:深度融合 MPC 與 DeepGEMM
基于 MPC 與 DeepGEMM 的深度融合,就可以嘗試構(gòu)想一種全新的 MPC-GEMM 方案:基于秘密共享重構(gòu) DeepSeek DeepGEMM kernel。該方案的核心思想是:將 MPC 協(xié)議中與 GEMM 運(yùn)算相關(guān)的計(jì)算邏輯(秘密份額的加法、乘法)直接實(shí)現(xiàn)在 DeepGEMM 的 CUDA kernel 中,讓 GPU 直接執(zhí)行一個(gè)完整的“MPC-GEMM”運(yùn)算。
方案的設(shè)計(jì)基于以下幾個(gè)關(guān)鍵原理:
1)加法秘密共享:采用加法秘密共享作為 MPC 的基礎(chǔ)安全機(jī)制。加法秘密共享具有以下優(yōu)點(diǎn):
- 簡(jiǎn)單高效:實(shí)現(xiàn)簡(jiǎn)單,只需要進(jìn)行模加運(yùn)算。
- 加法同態(tài):秘密份額的加法對(duì)應(yīng)于明文的加法,使得加法運(yùn)算可以在本地高效執(zhí)行,無需通信。
- 安全性:信息論安全,只要參與方不合謀,任何單獨(dú)的秘密份額都不會(huì)泄露關(guān)于原始數(shù)據(jù)的任何信息。
2)INT8/FP8 數(shù)據(jù)表示:為了降低計(jì)算和通信開銷,我們借鑒 DeepGEMM 對(duì)低精度計(jì)算的使用,將輸入數(shù)據(jù)(FP32/FP64/定點(diǎn)數(shù))映射到 INT8 或 FP8。
- INT8 映射:對(duì)于 INT8,我們采用偏移映射等策略,充分利用 INT8 的表示范圍,并簡(jiǎn)化秘密共享運(yùn)算。
- FP8 映射:如果采用 FP8,可以利用 DeepGEMM 自身的 FP8 支持。
3)DeepGEMM Kernel 重構(gòu):方案的核心在于對(duì) DeepGEMM 的 CUDA kernel 進(jìn)行重構(gòu)。我們將 MPC 協(xié)議的邏輯(即秘密共享下的加法和乘法)直接嵌入到 kernel 中。
- 輸入/輸出:Kernel 的輸入和輸出直接是秘密份額(INT8 或 FP8),而不是明文數(shù)據(jù)。
- 基本運(yùn)算:將 kernel 中的加法和乘法替換為 MPC 協(xié)議下的秘密共享加法和乘法(基于 Beaver 三元組)。
- 保留優(yōu)化:盡最大可能保留 DeepGEMM 原有的針對(duì) GPU 架構(gòu)的優(yōu)化技術(shù),如 tiling、loop unrolling、shared memory 利用、warp-level primitives、指令級(jí)并行等,并針對(duì)秘密共享運(yùn)算進(jìn)行適配。
- 異步計(jì)算: 盡可能利用GPU的異步計(jì)算能力。
4)Beaver 三元組乘法:為了在秘密共享下實(shí)現(xiàn)乘法,采用 Beaver 三元組乘法協(xié)議??梢栽?kernel 中實(shí)現(xiàn) Beaver 三元組乘法協(xié)議,并利用 warp-level primitives(如??__shfl_xor_sync?
?)進(jìn)行優(yōu)化。
5)GPU 并行 Beaver 三元組生成:為了提高 Beaver 三元組的生成效率,并減少預(yù)處理階段的通信開銷,我們可以利用 GPU 的并行計(jì)算能力,在 GPU 上并行生成 Beaver 三元組。
6)JIT 編譯:我們充分利用 DeepGEMM 的 JIT 編譯技術(shù)(如果 DeepGEMM 提供 JIT 編譯接口;如果沒有,我們可以自行實(shí)現(xiàn) JIT 編譯),根據(jù) GEMM 形狀、塊大小、參與方數(shù)量等參數(shù),動(dòng)態(tài)生成高度優(yōu)化的 MPC-GEMM kernel。
7)簡(jiǎn)化的 MPC 協(xié)議:由于 GPU 直接參與 MPC 協(xié)議的執(zhí)行(我們將其視為一個(gè)“半誠實(shí)”的參與方),我們可以簡(jiǎn)化 MPC 協(xié)議的設(shè)計(jì),減少通信輪數(shù)和通信量。
3. 技術(shù)框架與實(shí)現(xiàn)細(xì)節(jié)
3.1 技術(shù)框架
方案的技術(shù)框架主要由以下幾個(gè)模塊構(gòu)成:
- 秘密共享模塊:
a.負(fù)責(zé)將參與方的輸入數(shù)據(jù)(FP32、FP64 或定點(diǎn)數(shù))進(jìn)行加法秘密共享。
b.將秘密份額轉(zhuǎn)換為 INT8 或 FP8 表示(通過映射)。
c.實(shí)現(xiàn)秘密共享上的加法和乘法運(yùn)算(基于 Beaver 三元組)。
d.提供秘密份額的生成、分發(fā)、重構(gòu)等功能。
- DeepGEMM Kernel 重構(gòu)模塊:
a.負(fù)責(zé)對(duì) DeepGEMM 的 CUDA kernel 進(jìn)行重構(gòu),將秘密共享運(yùn)算(加法和乘法)嵌入到 kernel 中。
b.保留并適配 DeepGEMM 原有的 GPU 架構(gòu)優(yōu)化。
c.利用 JIT 編譯技術(shù)(或手動(dòng)實(shí)現(xiàn)),動(dòng)態(tài)生成針對(duì)特定參數(shù)(GEMM 形狀、塊大小、參與方數(shù)量等)的優(yōu)化 kernel。
- MPC 協(xié)議協(xié)調(diào)模塊:
a.負(fù)責(zé)協(xié)調(diào)各參與方和 GPU 之間的交互。
b.管理 Beaver 三元組的分發(fā)(如果采用離線生成)。
c.觸發(fā) GPU kernel 的執(zhí)行。
- GPU Beaver 三元組生成模塊:
a.利用 GPU 的并行計(jì)算能力,高效生成 Beaver 三元組。
3.2 工作流程
整個(gè) MPC-GEMM 的計(jì)算流程分為離線階段和在線階段:
- 離線階段(預(yù)處理):
- 利用 GPU 并行生成 Beaver 三元組,并將三元組的秘密份額分發(fā)給各參與方(和 GPU 線程)。
- 在線階段:
- 參與方收集各自的輸出份額。
- 將對(duì)應(yīng)位置的份額相加(模運(yùn)算,如果是 INT8;浮點(diǎn)加法,如果是 FP8),重構(gòu)出最終的 GEMM 結(jié)果。
- 如果需要,可以將結(jié)果轉(zhuǎn)換回 FP32 或 FP64 格式。
- Kernel 計(jì)算完成后,輸出結(jié)果仍然是秘密份額(INT8 或 FP8)的形式。
- GPU 將輸出份額返回給參與方。
- GPU 執(zhí)行重構(gòu)后的 DeepGEMM kernel。
- 在 kernel 內(nèi)部:
- 整個(gè)計(jì)算過程高度并行化。
- 將輸入數(shù)據(jù)(秘密份額)和 Beaver 三元組份額加載到 shared memory。
- 使用 tiling 技術(shù)將矩陣分塊。
- 對(duì)于每個(gè)塊,執(zhí)行秘密共享下的加法和乘法運(yùn)算(利用 Beaver 三元組和 warp-level primitives)。
- 利用 GPU 架構(gòu)優(yōu)化(如 tiling, loop unrolling, shared memory, warp-level primitives, 指令級(jí)并行, 異步計(jì)算等)。
- 將中間結(jié)果累加到 shared memory 或 registers 中。
- MPC 協(xié)議協(xié)調(diào)模塊根據(jù) GEMM 運(yùn)算的參數(shù)(形狀、塊大小等)和參與方數(shù)量,觸發(fā) DeepGEMM Kernel 重構(gòu)模塊生成相應(yīng)的 CUDA kernel(利用 JIT 編譯或手動(dòng)實(shí)現(xiàn))。
- 參與方將各自持有的秘密份額(INT8 或 FP8)直接作為輸入,傳遞給生成的 CUDA kernel。
- 每個(gè)參與方將自己的輸入矩陣的每個(gè)元素進(jìn)行加法秘密共享。
- 將秘密份額轉(zhuǎn)換為 INT8 或 FP8 表示。
- ?輸入準(zhǔn)備:
- Kernel 調(diào)用:
- GPU 并行計(jì)算:
- 輸出處理:
- 結(jié)果重構(gòu):?
3.3 關(guān)鍵實(shí)現(xiàn)細(xì)節(jié)
本文中所用代碼均是偽代碼,根據(jù)通義靈碼的建議生成的,只能看出大致的意思,不能直接使用。
3.3.1 數(shù)據(jù)表示
- INT8 映射 (如果采用 INT8):
我們推薦使用偏移映射。假設(shè)原始數(shù)據(jù)為 FP32,映射規(guī)則如下:
映射公式:
其中,S1、S2 是縮放因子,O1、O2 是偏移量。具體數(shù)值需要根據(jù)實(shí)際數(shù)據(jù)分布和 INT8 的表示范圍來確定。 - 對(duì)于 FP32 正數(shù) x:?
?INT8 = round(x * S1) + O1?
? - 對(duì)于 FP32 負(fù)數(shù) x:?
?INT8 = round(x * S2) + O2?
? - 將 FP32 的 NaN 映射到 INT8 的 -128。
- 將 FP32 的 +Inf 映射到 INT8 的 -127。
- 將 FP32 的 -Inf 映射到 INT8 的 -126。
- 將 FP32 的 0 映射到 INT8 的 0。
- 將 FP32 的其他正數(shù),等比例映射到 INT8 的 [1, 127] 區(qū)間。
- 將 FP32 的其他負(fù)數(shù),等比例映射到 INT8 的 [-125, -1] 區(qū)間。
- FP8 表示 (如果采用 FP8):如果采用FP8,可以直接利用DeepGEMM對(duì)FP8的支持。
3.3.2 CUDA Kernel 中的秘密共享乘法
以下是 CUDA kernel 中實(shí)現(xiàn)秘密共享乘法(基于加法秘密共享和 Beaver 三元組)的示例代碼,并加入了詳細(xì)注釋:
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
template<typename T>
__global__ void mpc_gemm_kernel(T* x_shares, T* y_shares,
T* a_shares, T* b_shares, T* c_shares,
T* z_shares,
int m, int n, int k, int num_parties) {
// 獲取線程 ID、塊 ID 以及塊維度
int tid = threadIdx.x;
int bid_x = blockIdx.x;
int bid_y = blockIdx.y;
int block_dim = blockDim.x;
// 定義 shared memory 變量 (使用雙緩沖)
__shared__ T x_shared[2][BLOCK_SIZE][BLOCK_SIZE];
__shared__ T y_shared[2][BLOCK_SIZE][BLOCK_SIZE];
__shared__ T a_shared[2][BLOCK_SIZE][BLOCK_SIZE];
__shared__ T b_shared[2][BLOCK_SIZE][BLOCK_SIZE];
__shared__ T c_shared[2][BLOCK_SIZE][BLOCK_SIZE];
// 使用 cooperative groups
cg::thread_block cta = cg::this_thread_block();
cg::grid_group grid = cg::this_grid();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(cta);
// 計(jì)算當(dāng)前線程負(fù)責(zé)的矩陣元素的坐標(biāo)
int row = bid_y * BLOCK_SIZE + tid / BLOCK_SIZE;
int col = bid_x * BLOCK_SIZE + tid % BLOCK_SIZE;
// 初始化累加器
T acc = 0;
// 循環(huán)處理矩陣塊 (tiling)
int buffer_idx = 0; // 雙緩沖索引
for (int i = 0; i < k; i += BLOCK_SIZE) {
// 將數(shù)據(jù)從全局內(nèi)存加載到 shared memory (異步加載, 如果支持)
if (grid.rank() == 0 && i + BLOCK_SIZE < k) {
//僅rank 0 的block進(jìn)行異步加載
//這里只是偽代碼,實(shí)際使用需要根據(jù)數(shù)據(jù)類型進(jìn)行調(diào)整
cudaMemcpyAsync(&x_shared[(buffer_idx+1)%2][0][0], &x_shares[(row * k) + i + BLOCK_SIZE], BLOCK_SIZE * BLOCK_SIZE * sizeof(T), cudaMemcpyDeviceToDevice);
cudaMemcpyAsync(&y_shared[(buffer_idx+1)%2][0][0], &y_shares[((i + BLOCK_SIZE) * n) + col], BLOCK_SIZE * BLOCK_SIZE * sizeof(T), cudaMemcpyDeviceToDevice);
cudaMemcpyAsync(&a_shared[(buffer_idx+1)%2][0][0], &a_shares[(row * k) + i + BLOCK_SIZE], BLOCK_SIZE * BLOCK_SIZE * sizeof(T), cudaMemcpyDeviceToDevice);
cudaMemcpyAsync(&b_shared[(buffer_idx+1)%2][0][0], &b_shares[((i + BLOCK_SIZE) * n) + col], BLOCK_SIZE * BLOCK_SIZE * sizeof(T), cudaMemcpyDeviceToDevice);
cudaMemcpyAsync(&c_shared[(buffer_idx+1)%2][0][0], &c_shares[row*n + col], BLOCK_SIZE*BLOCK_SIZE*sizeof(T), cudaMemcpyDeviceToDevice);
}
if(row < m && (i + tid % BLOCK_SIZE) < k){
x_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] = x_shares[row * k + (i + tid % BLOCK_SIZE)];
} else {
x_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] = 0;
}
if((i + tid / BLOCK_SIZE) < k && col < n) {
y_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] = y_shares[(i + tid / BLOCK_SIZE) * n + col];
} else {
y_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] = 0;
}
if(row < m && (i + tid % BLOCK_SIZE) < k){
a_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] = a_shares[row*k + (i + tid%BLOCK_SIZE)];
} else {
a_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] = 0;
}
if((i + tid / BLOCK_SIZE) < k && col < n){
b_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] = b_shares[(i + tid / BLOCK_SIZE)*n + col];
} else {
b_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] = 0;
}
```cuda
if(row < m && col < n) {
c_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] = c_shares[row*n + col];
} else {
c_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] = 0;
}
cta.sync(); // 等待所有線程加載完成, 以及異步加載完成
// 計(jì)算當(dāng)前塊的乘積 (循環(huán)展開)
#pragma unroll
for (int j = 0; j < BLOCK_SIZE; ++j) {
// 計(jì)算 d = x - a 和 e = y - b (本地計(jì)算)
T d_local = x_shared[buffer_idx][tid / BLOCK_SIZE][j] - a_shared[buffer_idx][tid / BLOCK_SIZE][j];
T e_local = y_shared[buffer_idx][j][tid % BLOCK_SIZE] - b_shared[buffer_idx][j][tid % BLOCK_SIZE];
// 使用 warp-level shuffle 指令計(jì)算 d 和 e 的全局和
T d_global = 0;
T e_global = 0;
#pragma unroll
for (int w = 0; w < warp.size(); ++w) {
d_global += warp.shfl_xor(d_local, w);
e_global += warp.shfl_xor(e_local, w);
}
// 計(jì)算 z = c + d * b + e * a + d * e (本地計(jì)算)
// 手動(dòng)進(jìn)行指令級(jí)并行
T term1 = d_local * b_shared[buffer_idx][j][tid % BLOCK_SIZE];
T term2 = e_local * a_shared[buffer_idx][tid / BLOCK_SIZE][j];
T term3 = d_global * e_global;
//根據(jù)數(shù)據(jù)類型進(jìn)行模運(yùn)算
if constexpr (std::is_same_v<T, int8_t>) {
acc = (acc + c_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] + term1 + term2 + term3) & 0xFF;
} else {
acc += c_shared[buffer_idx][tid / BLOCK_SIZE][tid % BLOCK_SIZE] + term1 + term2 + term3;
}
}
cta.sync(); // 確保所有線程完成當(dāng)前塊的計(jì)算
buffer_idx = (buffer_idx + 1) % 2;
}
// 將結(jié)果寫回全局內(nèi)存
if(row < m && col < n){
z_shares[row * n + col] = acc;
}
}
代碼解釋:
- 模板參數(shù) ?
?T?
?:使用模板參數(shù) ??T?
?,可以支持 INT8 和 FP8 兩種數(shù)據(jù)類型。 - 雙緩沖 (Double Buffering):使用兩組 shared memory 數(shù)組,實(shí)現(xiàn)計(jì)算和數(shù)據(jù)加載的流水線操作。
- 異步數(shù)據(jù)加載:在外層循環(huán)的開始處,嘗試使用 ?
?cudaMemcpyAsync?
? 異步地將下一批次的數(shù)據(jù)從全局內(nèi)存加載到 shared memory。 - Cooperative Groups:使用 Cooperative Groups 提供的 ?
?thread_block?
??、??grid_group?
?? 和 ??thread_block_tile?
? 類型來更精細(xì)地控制線程塊和 warp 級(jí)別的并行。 - Warp-level Shuffle 指令優(yōu)化:
a.使用 ??warp.shfl_xor(val, lane)?
?? 替代 ??__shfl_xor_sync(mask, val, lane)?
?。
b.循環(huán)展開 warp-level shuffle 操作。
- 指令級(jí)并行(手動(dòng)):在計(jì)算 ?
?z?
? 時(shí),將乘法和加法運(yùn)算交錯(cuò)進(jìn)行,盡可能利用 GPU 的指令級(jí)并行能力。 - 循環(huán)展開:使用 ?
?#pragma unroll?
? 指令展開內(nèi)層循環(huán)。 - 模運(yùn)算: 如果 ?
?T?
?? 是 ??int8_t?
??,則使用 ??& 0xFF?
? 進(jìn)行模 256 運(yùn)算。 - Tiling: 使用tiling技術(shù)將矩陣分塊處理。
- 并行化:
a.線程塊 (Block):不同的線程塊負(fù)責(zé)計(jì)算輸出矩陣 Z 的不同塊(tiling)。
b.線程 (Thread):線程塊內(nèi)部的線程協(xié)同計(jì)算秘密共享乘法。
3.3.3 Hopper 架構(gòu)優(yōu)化(深化)
- TMA (Tensor Memory Accelerator):通過流水線、雙緩沖和 ?
?cudaMemcpyAsync?
?,盡可能利用 TMA 的異步數(shù)據(jù)傳輸能力,隱藏內(nèi)存訪問延遲。 - Tensor Core 利用:
#include <mma.h>
usingnamespace nvcuda;
// ...
wmma::fragment<wmma::matrix_a, 16, 16, 16, int8_t, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, int8_t, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, int32_t> c_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, int32_t> acc_frag;
wmma::fill_fragment(acc_frag, 0);
for (int i = 0; i < k; i += 16) {
wmma::load_matrix_sync(a_frag, &x_shared[...], ...); // 加載數(shù)據(jù)到 fragment, 需要根據(jù)實(shí)際情況填寫參數(shù)
wmma::load_matrix_sync(b_frag, &y_shared[...], ...); // 加載數(shù)據(jù)到 fragment, 需要根據(jù)實(shí)際情況填寫參數(shù)
wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag); // 矩陣乘累加
}
//將秘密共享乘法結(jié)果加到acc_frag上
wmma::store_matrix_sync(&c_shared[...], acc_frag, ... , wmma::mem_row_major); // 存儲(chǔ)結(jié)果
```
* **FP8 計(jì)算:** 如果采用FP8, 可以直接使用DeepGEMM中針對(duì)FP8和Tensor Core的優(yōu)化。
* **數(shù)據(jù)類型轉(zhuǎn)換:** 如果 INT8 的 `wmma.mma` 指令效果不佳,可考慮將 INT8 份額轉(zhuǎn)換為 FP16 或 INT32,然后使用相應(yīng)的 `wmma.mma` 指令。但類型轉(zhuǎn)換也需在秘密共享下進(jìn)行。
a.INT8 計(jì)算:嘗試使用 ??wmma::mma_s8s8s32?
? 指令進(jìn)行 INT8 矩陣乘法:
- Shared Memory 優(yōu)化:
a.通過 tiling 技術(shù)和合理的數(shù)據(jù)訪問模式,最大程度地復(fù)用 shared memory 中的數(shù)據(jù)。
b.合理安排 shared memory 中數(shù)據(jù)的存儲(chǔ)位置,避免 bank conflict。
- Warp-level Primitives 與指令級(jí)并行:
a.充分利用??__shfl_xor_sync?
?? 或??warp.shfl_xor?
? 指令在 warp 內(nèi)部高效地進(jìn)行數(shù)據(jù)交換和規(guī)約求和。
b.在 kernel 代碼中,盡可能地將獨(dú)立的指令放在一起執(zhí)行,利用 GPU 的指令級(jí)并行能力。
3.3.4 GPU 并行 Beaver 三元組生成
算法:
- 初始化 cuRAND:在每個(gè)線程中初始化一個(gè) cuRAND 偽隨機(jī)數(shù)生成器狀態(tài)。
- 生成隨機(jī)數(shù):使用 cuRAND 庫在每個(gè)線程中并行生成三個(gè) INT8 或 FP8 類型的隨機(jī)數(shù)(a, b, c)。
- 驗(yàn)證三元組:在每個(gè)線程中驗(yàn)證生成的三元組是否滿足 Beaver 三元組的條件(?
?c == a * b?
?)。 - 秘密共享:在 kernel 中直接對(duì)驗(yàn)證通過的三元組 (a, b, c) 進(jìn)行加法秘密共享。
- 存儲(chǔ)份額:將每個(gè)參與方的三元組份額存儲(chǔ)到全局內(nèi)存中的一個(gè)數(shù)組中。
CUDA Kernel 代碼示例(INT8):
#include <curand_kernel.h>
struct BeaverTripleShares {
int8 a_share;
int8 b_share;
int8 c_share;
};
__global__ void generate_beaver_triples(BeaverTripleShares* triples, int num_triples, int num_parties) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
// 初始化 cuRAND 偽隨機(jī)數(shù)生成器
curandState_t state;
curand_init(blockIdx.x * blockDim.x + threadIdx.x, 0, 0, &state);
// 生成 Beaver 三元組并進(jìn)行秘密共享
if (tid < num_triples) {
// 1. 生成隨機(jī)數(shù)
int8 a = (int8)curand(&state);
int8 b = (int8)curand(&state);
int8 c = (int8)curand(&state);
// 2. 驗(yàn)證三元組 (注意處理溢出)
if (((int)a * (int)b & 0xFF) == (c & 0xFF)) {
// 3. 進(jìn)行秘密共享
int8 a_shares[num_parties];
int8 b_shares[num_parties];
int8 c_shares[num_parties];
int8 a_sum = 0;
int8 b_sum = 0;
int8 c_sum = 0;
for (int i = 0; i < num_parties - 1; i++) {
a_shares[i] = (int8)curand(&state);
b_shares[i] = (int8)curand(&state);
c_shares[i] = (int8)curand(&state);
a_sum += a_shares[i];
b_sum += b_shares[i];
c_sum += c_shares[i];
}
a_shares[num_parties - 1] = a - a_sum; // 加法秘密共享
b_shares[num_parties - 1] = b - b_sum;
c_shares[num_parties - 1] = c - c_sum;
a_shares[num_parties - 1] = a_shares[num_parties-1] & 0xFF;
b_shares[num_parties - 1] = b_shares[num_parties-1] & 0xFF;
c_shares[num_parties - 1] = c_shares[num_parties-1] & 0xFF;
// 4. 存儲(chǔ)秘密份額
for (int i = 0; i < num_parties; i++) {
triples[tid * num_parties + i].a_share = a_shares[i];
triples[tid * num_parties + i].b_share = b_shares[i];
triples[tid * num_parties + i].c_share = c_shares[i];
}
} else {
// 如果驗(yàn)證失敗,可以將其設(shè)置為一個(gè)特殊值(如全 0),
for (int i = 0; i < num_parties; i++) {
triples[tid * num_parties + i].a_share = 0;
triples[tid * num_parties + i].b_share = 0;
triples[tid * num_parties + i].c_share = 0;
}
}
}
}
代碼解釋:
- ?
?curand_kernel.h?
?:包含了 cuRAND 庫的函數(shù)聲明。 - ?
?BeaverTripleShares?
? 結(jié)構(gòu)體:定義了 Beaver 三元組份額的結(jié)構(gòu)。 - ?
?generate_beaver_triples?
? kernel:
a.在 kernel 中直接對(duì)驗(yàn)證通過的 Beaver 三元組 (a, b, c) 進(jìn)行加法秘密共享。
b.為每個(gè)參與方生成隨機(jī)份額。
c.最后一個(gè)參與方的份額通過總和與其他份額的差值計(jì)算得到。
- ?
?((int)a * (int)b & 0xFF)?
?:計(jì)算 a * b (mod 256)。 - ?
?(c & 0xFF)?
??:取 ??c?
? 的低 8 位。 - ?
?triples?
??:指向全局內(nèi)存中存儲(chǔ) Beaver 三元組份額的數(shù)組的指針,其大小應(yīng)為 ??num_triples * num_parties?
?。 - ?
?num_triples?
?:要生成的 Beaver 三元組的數(shù)量。 - ?
?num_parties?
?:參與方的數(shù)量。 - ?
?tid?
?:線程 ID。 - ?
?curandState_t?
?:cuRAND 偽隨機(jī)數(shù)生成器的狀態(tài)。每個(gè)線程都需要一個(gè)獨(dú)立的狀態(tài)。 - ?
?curand_init?
?:初始化偽隨機(jī)數(shù)生成器。這里使用線程 ID 作為種子,確保每個(gè)線程生成的隨機(jī)數(shù)序列不同。 - ?
?curand?
??:生成一個(gè) 32 位無符號(hào)整數(shù)隨機(jī)數(shù)。我們將其強(qiáng)制轉(zhuǎn)換為 ??int8?
?。 - 驗(yàn)證三元組:
- 秘密共享:
- 存儲(chǔ)份額:將每個(gè)參與方的三元組份額存儲(chǔ)到 ?
?triples?
? 數(shù)組中。
使用方法:
- 在 GPU 上分配足夠大的內(nèi)存來存儲(chǔ) Beaver 三元組的所有份額 (?
?BeaverTripleShares* triples?
?)。 - 調(diào)用 ?
?generate_beaver_triples?
? kernel,生成 Beaver 三元組并進(jìn)行秘密共享。 - 在 MPC-GEMM kernel 中,每個(gè)線程根據(jù)其線程 ID 和參與方 ID 從 ?
?triples?
? 數(shù)組中獲取相應(yīng)的 Beaver 三元組份額。
優(yōu)化:
- 可以通過增加線程塊和線程數(shù)量來進(jìn)一步提高 Beaver 三元組生成的并行度。
- 可以使用更高效的隨機(jī)數(shù)生成器(如 Philox 算法)來提高隨機(jī)數(shù)生成的速度和質(zhì)量。
- 可以將 Beaver 三元組的生成、驗(yàn)證和秘密共享融合到一個(gè) kernel 中,減少數(shù)據(jù)傳輸開銷。
3.3.5 JIT 編譯優(yōu)化
JIT 編譯技術(shù)允許我們?cè)谶\(yùn)行時(shí)根據(jù)具體的參數(shù)動(dòng)態(tài)生成優(yōu)化的 CUDA kernel 代碼。在 MPC-GEMM 中,我們可以利用 JIT 編譯進(jìn)行以下優(yōu)化:
- 代碼特化:
- GEMM 參數(shù):根據(jù) GEMM 運(yùn)算的形狀(M, N, K)、塊大小(BLOCK_SIZE)、數(shù)據(jù)類型(INT8 或 FP8)等參數(shù),生成專門針對(duì)這些參數(shù)優(yōu)化的 kernel 代碼。例如,可以根據(jù) M、N、K 的大小選擇最合適的 tiling 策略和 shared memory 使用方式。
- MPC 參數(shù):根據(jù)參與方數(shù)量、秘密共享方案(加法秘密共享)等參數(shù),生成相應(yīng)的 kernel 代碼。例如,如果參與方數(shù)量較少,可以使用更激進(jìn)的 warp-level shuffle 優(yōu)化。
- Hopper 架構(gòu)特性:根據(jù)目標(biāo) GPU 的計(jì)算能力(Compute Capability),啟用或禁用某些 Hopper 架構(gòu)特有的優(yōu)化(如 TMA)。
- 常量折疊:
- Beaver 三元組內(nèi)聯(lián):如果 Beaver 三元組是在預(yù)處理階段生成的,并且在 kernel 執(zhí)行期間不會(huì)改變,可以將三元組的份額直接作為編譯時(shí)常量?jī)?nèi)聯(lián)到 kernel 代碼中,減少運(yùn)行時(shí)內(nèi)存訪問。
- 其他常量:將參與方數(shù)量、塊大小、GEMM 形狀等參數(shù)也作為編譯時(shí)常量?jī)?nèi)聯(lián)到 kernel 代碼中,允許編譯器進(jìn)行更多的優(yōu)化(如常量傳播、死代碼消除等)。
- 循環(huán)展開:
- 根據(jù) GEMM 形狀和塊大小,對(duì) kernel 中的循環(huán)進(jìn)行部分或完全展開,減少循環(huán)控制開銷,并增加指令級(jí)并行度。
- 特別是對(duì)于秘密共享乘法協(xié)議中的內(nèi)層循環(huán),可以進(jìn)行更激進(jìn)的展開。
- 指令級(jí)并行:
- JIT 編譯器可以分析 kernel 代碼中的數(shù)據(jù)依賴關(guān)系,盡可能地將獨(dú)立的指令放在一起執(zhí)行,利用 GPU 的指令級(jí)并行能力。
- 我們可以手動(dòng)調(diào)整 kernel 代碼中的指令順序,以幫助編譯器更好地進(jìn)行指令級(jí)并行優(yōu)化。
- 自動(dòng)調(diào)整block size和grid size: 可以根據(jù)矩陣規(guī)模、數(shù)據(jù)類型等,自動(dòng)調(diào)整kernel的block size和grid size,以充分利用GPU資源。
實(shí)現(xiàn)方式:
- NVRTC (NVIDIA Runtime Compilation):NVRTC 是 NVIDIA 提供的一個(gè)運(yùn)行時(shí)編譯庫,可以在程序運(yùn)行時(shí)將 CUDA C++ 代碼編譯為 PTX 匯編代碼,然后加載到 GPU 中執(zhí)行。
- NVCC (NVIDIA CUDA Compiler):NVCC 是 NVIDIA 的 CUDA 編譯器,也可以用于 JIT 編譯。可以在編譯時(shí)使用 ?
?-D?
?? 選項(xiàng)定義宏,然后在 kernel 代碼中使用 ??#ifdef?
? 等預(yù)處理指令來根據(jù)不同的宏定義生成不同的代碼。
示例:
假設(shè)我們要根據(jù)參與方數(shù)量 ??n?
? 進(jìn)行代碼特化。我們可以在 kernel 代碼中使用如下預(yù)處理指令:
#if N_PARTIES == 2
// 針對(duì) 2 個(gè)參與方的優(yōu)化代碼
int8 d_global = __shfl_xor_sync(0xFFFFFFFF, d_local, 1);
```C++
int8 e_global = __shfl_xor_sync(0xFFFFFFFF, e_local, 1);
#elif N_PARTIES == 3
// 針對(duì) 3 個(gè)參與方的優(yōu)化代碼
int8 d_global = 0;
int8 e_global = 0;
for (int w = 0; w < warp.size(); ++w) {
d_global += warp.shfl_xor(d_local, w);
e_global += warp.shfl_xor(e_local, w);
}
#else
// 通用代碼
#endif
在編譯時(shí),通過 ??-D?
?? 選項(xiàng)指定 ??N_PARTIES?
? 的值,NVCC 或 NVRTC 就會(huì)生成針對(duì)特定參與方數(shù)量的優(yōu)化 kernel 代碼。
4. 方案論證
4.1 可行性論證
- DeepGEMM Kernel 可修改性:DeepGEMM 的 CUDA kernel 本質(zhì)上是 C/C++ 代碼,可以進(jìn)行修改和擴(kuò)展。
- 秘密共享運(yùn)算可實(shí)現(xiàn)性:加法秘密共享和基于 Beaver 三元組的乘法協(xié)議都可以在 INT8 或 FP8 數(shù)據(jù)類型上高效實(shí)現(xiàn)。
- GPU 并行計(jì)算可行性:CUDA 編程模型支持細(xì)粒度的并行計(jì)算,可以充分利用 GPU 的并行計(jì)算能力。
- JIT 編譯可行性:JIT 編譯技術(shù)已經(jīng)廣泛應(yīng)用,NVRTC 和 NVCC 都提供了 JIT 編譯功能。
- Hopper 架構(gòu)優(yōu)化可行性:Hopper 架構(gòu)的特性(TMA、Tensor Core、Shared Memory、Warp-level Primitives)都可以在 CUDA 編程中加以利用。
4.2 安全性論證
本方案的安全性基于以下幾個(gè)方面:
- 秘密共享的安全性:采用的加法秘密共享方案是信息論安全的,只要參與方不合謀,任何單獨(dú)的秘密份額都不會(huì)泄露關(guān)于原始數(shù)據(jù)的任何信息。
- Beaver 三元組乘法協(xié)議的安全性:Beaver 三元組乘法協(xié)議在半誠實(shí)模型下是安全的。只要 Beaver 三元組是獨(dú)立于輸入數(shù)據(jù)生成的,并且參與方誠實(shí)地執(zhí)行協(xié)議,攻擊者就無法從公開的中間值(d 和 e)中推斷出關(guān)于秘密輸入(x 和 y)的任何信息。
- GPU 計(jì)算的安全性:
- GPU 始終只接觸到秘密份額,無法獲得任何關(guān)于明文數(shù)據(jù)的信息。
- 重構(gòu)后的 DeepGEMM kernel 只執(zhí)行秘密共享運(yùn)算,不包含任何可能泄露敏感信息的操作(如直接訪問內(nèi)存地址、向外部發(fā)送數(shù)據(jù)等)。
- 即使攻擊者控制了 GPU,也只能獲得秘密份額,無法恢復(fù)出原始數(shù)據(jù)。
- JIT 編譯的安全性:
- JIT 編譯器生成的 kernel 代碼只包含必要的秘密共享運(yùn)算和優(yōu)化邏輯,不包含任何惡意代碼。
- 可以對(duì) JIT 編譯器生成的代碼進(jìn)行靜態(tài)分析和安全審計(jì)。
- 抵御側(cè)信道攻擊:
- 雖然 GPU 內(nèi)部的計(jì)算對(duì)參與方透明,但仍然需要考慮側(cè)信道攻擊(如時(shí)間攻擊、功耗攻擊)。
- 可以采用掩碼(masking)技術(shù)來防御側(cè)信道攻擊。具體來說,可以將秘密份額與一個(gè)隨機(jī)數(shù)進(jìn)行運(yùn)算(如異或),然后在掩碼后的份額上進(jìn)行計(jì)算,最后再去除掩碼。
- 可以對(duì) kernel 代碼進(jìn)行隨機(jī)化,使得每次執(zhí)行的指令順序和內(nèi)存訪問模式都不同,增加側(cè)信道攻擊的難度。
4.3 高效性論證
相比于傳統(tǒng)的“兩張皮”MPC-GEMM 方案,本方案具有以下優(yōu)勢(shì):
- 消除交互開銷:將 MPC 協(xié)議邏輯直接嵌入到 DeepGEMM kernel 中,徹底消除了 MPC 協(xié)議與 GPU 計(jì)算之間的所有交互開銷(如數(shù)據(jù)格式轉(zhuǎn)換、安全通道傳輸?shù)龋_@是本方案相對(duì)于傳統(tǒng)方案最大的優(yōu)勢(shì)所在。
- 充分利用 DeepGEMM 優(yōu)化:GPU 直接執(zhí)行 MPC-GEMM 運(yùn)算,可以充分利用 DeepGEMM 原有的針對(duì) GPU 架構(gòu)的各種優(yōu)化(tiling、loop unrolling、shared memory 利用、TMA、Tensor Core、warp-level primitives、指令級(jí)并行等)。
- 低精度計(jì)算:使用 INT8 或 FP8 數(shù)據(jù)類型,相比于 FP32 或 FP64,可以顯著減少計(jì)算量和通信量。
- GPU 并行 Beaver 三元組生成:利用 GPU 并行生成 Beaver 三元組,大幅減少了預(yù)處理階段的開銷。
- 簡(jiǎn)化的 MPC 協(xié)議:將 GPU 視為“半誠實(shí)”參與方,可以簡(jiǎn)化 MPC 協(xié)議的設(shè)計(jì),減少通信輪數(shù)。
- JIT 編譯優(yōu)化:通過 JIT 編譯,可以針對(duì)具體的 GEMM 參數(shù)和 MPC 參數(shù)生成高度定制化的 kernel 代碼,進(jìn)一步提升性能。
- 高度并行化: 秘密共享的加法、乘法,Beaver三元組的生成都可以在GPU上高度并行。
量化分析(舉例):
假設(shè)一個(gè) MPC-GEMM 運(yùn)算涉及兩個(gè)矩陣 A 和 B 的乘法,矩陣大小為 1024x1024,參與方數(shù)量為 3。
傳統(tǒng)“兩張皮”方案:
整個(gè)過程中,數(shù)據(jù)至少需要在網(wǎng)絡(luò)上傳輸 3 次(輸入 2 次,輸出 1 次),并且涉及到多次數(shù)據(jù)格式轉(zhuǎn)換。
- 參與方之間需要通過網(wǎng)絡(luò)傳輸秘密份額(FP32 或 FP64)。
- 需要將秘密份額轉(zhuǎn)換為 GPU 可處理的格式(如加密)。
- GPU 執(zhí)行 GEMM 計(jì)算。
- 將計(jì)算結(jié)果(加密或編碼)傳輸回參與方。
- 參與方進(jìn)行解密或解碼,并重構(gòu)結(jié)果。
本方案:
整個(gè)過程中,數(shù)據(jù)只需要在網(wǎng)絡(luò)上傳輸 2 次(輸入和輸出),并且都是 INT8 類型,數(shù)據(jù)量大大減少。GPU 內(nèi)部的計(jì)算高度優(yōu)化,且無需與 MPC 協(xié)議進(jìn)行交互。
- 參與方將輸入數(shù)據(jù)進(jìn)行秘密共享,并映射到 INT8。
- 參與方將 INT8 秘密份額直接發(fā)送給 GPU(通過 MPI 等)。
- GPU 執(zhí)行重構(gòu)后的 DeepGEMM kernel,直接在 INT8 秘密份額上進(jìn)行計(jì)算。
- GPU 將計(jì)算結(jié)果(INT8 秘密份額)返回給參與方。
- 參與方重構(gòu)結(jié)果。
因此,我們可以預(yù)期,本方案的性能將比傳統(tǒng)方案有數(shù)量級(jí)的提升。
4.4 與其他方案的對(duì)比
方案 | 優(yōu)點(diǎn) | 缺點(diǎn) |
本方案 | 1. 深度融合 MPC 與 GPU 計(jì)算,消除了交互開銷。 2. 充分利用 DeepGEMM 的優(yōu)化和 GPU 架構(gòu)特性。 3. 采用 INT8/FP8 低精度計(jì)算。 4. GPU 并行 Beaver 三元組生成。 5. JIT 編譯優(yōu)化,kernel 代碼高度定制化。 6. 安全性基于信息論安全的秘密共享。 | 1. 需要對(duì) DeepGEMM kernel 進(jìn)行深度重構(gòu),開發(fā)難度較高。 2. 安全性依賴于 GPU 不泄露秘密份額(半誠實(shí)模型)。 3. 目前主要支持加法秘密共享和 Beaver 三元組乘法,對(duì)其他 MPC 協(xié)議的支持需要進(jìn)一步研究。 |
傳統(tǒng)“兩張皮”MPC-GEMM 方案 | 1. MPC 協(xié)議與 GPU 計(jì)算分離,模塊化程度高,易于實(shí)現(xiàn)和維護(hù)。 2. 可以使用現(xiàn)有的 MPC 框架和 GPU 加速庫。 | 1. 存在 MPC 協(xié)議與 GPU 計(jì)算之間的交互開銷(數(shù)據(jù)轉(zhuǎn)換、通信)。 2. GPU 計(jì)算部分受到 MPC 協(xié)議的制約,無法充分發(fā)揮 GPU 的性能和 DeepGEMM 的優(yōu)化。 |
基于 TEE 的 MPC-GEMM 方案 | 1. TEE 提供了一個(gè)可信的執(zhí)行環(huán)境,可以保護(hù)計(jì)算過程的安全性。 2. 可以利用 TEE 內(nèi)部的 GPU 進(jìn)行加速計(jì)算。 | 1. 安全性依賴于 TEE 硬件的安全性假設(shè)(存在側(cè)信道攻擊等風(fēng)險(xiǎn))。 2. TEE 的性能通常低于原生 GPU 計(jì)算。 3. TEE 的可用資源(內(nèi)存、計(jì)算能力)有限。 4. 不同廠商的 TEE 實(shí)現(xiàn)存在差異,可移植性較差。 |
基于同態(tài)加密的 MPC-GEMM 方案 | 1. 安全性基于數(shù)學(xué)難題(如格密碼),安全性高。 2. 可以在密文上直接進(jìn)行計(jì)算,無需解密。 | 1. 計(jì)算開銷非常大,通常比明文計(jì)算慢幾個(gè)數(shù)量級(jí),難以應(yīng)用于大規(guī)模矩陣運(yùn)算。 2. 通信開銷也很大,因?yàn)槊芪耐ǔ1让魑拇蠛芏唷?3. 支持的運(yùn)算類型有限,通常只支持加法和乘法同態(tài),難以支持復(fù)雜的非線性運(yùn)算。 4. 需要針對(duì)同態(tài)加密的特性對(duì)算法和 kernel 進(jìn)行重新設(shè)計(jì)。 |
對(duì)比總結(jié):
- 性能:本方案 > 基于 TEE 的方案 > 傳統(tǒng)“兩張皮”方案 > 基于同態(tài)加密的方案
- 安全性:本方案 ≈ 基于同態(tài)加密的方案 > 傳統(tǒng)“兩張皮”方案 > 基于 TEE 的方案
- 開發(fā)難度:基于同態(tài)加密的方案 > 本方案 > 基于 TEE 的方案 > 傳統(tǒng)“兩張皮”方案
- 硬件依賴:基于 TEE 的方案 > 本方案 > 傳統(tǒng)“兩張皮”方案 ≈ 基于同態(tài)加密的方案
- 靈活性:傳統(tǒng)“兩張皮”方案 > 本方案 > 基于 TEE 的方案 > 基于同態(tài)加密的方案
5. 總結(jié)
本文提出了一種基于秘密共享重構(gòu) DeepGEMM kernel 的 MPC-GEMM 方案。該方案通過將 MPC 協(xié)議邏輯直接嵌入到 DeepGEMM kernel 中,實(shí)現(xiàn)了 MPC 與 GPU 計(jì)算的深度融合,徹底消除了傳統(tǒng)方案中的“兩張皮”問題。方案充分利用了 DeepGEMM 的優(yōu)化技術(shù)、Hopper 架構(gòu)特性、INT8/FP8 低精度計(jì)算、GPU 并行 Beaver 三元組生成以及 JIT 編譯等關(guān)鍵技術(shù),在保證計(jì)算安全性的前提下,最大程度地發(fā)揮了 GPU 的計(jì)算能力。
相比于傳統(tǒng)的 MPC-GEMM 方案,理論上本方案在性能上具有顯著優(yōu)勢(shì),同時(shí)在安全性方面也達(dá)到了較高的水平。本方案為構(gòu)建高效安全的 MPC-GEMM 提供了一條全新的技術(shù)路線,是對(duì) MPC 與 GPU 加速深度融合的一次探索性設(shè)想。
參考鏈接:??https://github.com/deepseek-ai/DeepGEMM??
本文轉(zhuǎn)載自??上堵吟??,作者:上堵吟
