【TVM 教程】編寫自定義 Pass 原創
Apache TVM是一個深度的深度學習編譯框架,適用于 CPU、GPU 和各種機器學習加速芯片。更多 TVM 中文文檔可訪問 →https://tvm.hyper.ai/
作者:Jian Weng
TVM 是一個抽象出機器學習加速器異質性的框架,有時用戶希望自定義一些分析和 IR 轉換,使得 TVM 適應自己的專用硬件。本教程介紹如何在 TVM 中編寫自定義 Pass。
先決條件?
閱讀本教程前,假設讀者已經熟悉以下主題:
- 在 TVM 中編寫算法并對其進行調度,若不熟悉,請參閱示例教程如?如何在 CPU 上優化 GEMM。
- 熟悉 HalideIR 的基本結構,若不熟悉,請參閱?
HalideIR/src/ir/IR.h
?了解定義了 IR 節點的哪些屬性。 - 訪問器設計模式,若不熟悉,請參閱?Python AST 模塊?以查看 AST 訪問器的實現原理。
- Schedule 如何降低為 IRModule 類或 LLVM 模塊。若不熟悉,請參閱?
python/tvm/build_module.py
?獲取相關基礎知識。
import tvm
from tvm import te
import numpy as np
首先編寫一個簡單的向量加法,并用默認 schedule 構建。然后,使用自定義的降低 pass 而非調度原語,來直接操作 IR。
n = tvm.tir.const(128, "int32")
a = te.placeholder((n,), name="a")
b = te.placeholder((n,), name="b")
c = te.compute((n,), lambda i: a[i] + b[i], name="c")
sch = te.create_schedule(c.op)
ir = tvm.lower(sch, [a, b, c])
print(ir)
輸出結果:
@main = primfn(a_1: handle, b_1: handle, c_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {a: Buffer(a_2: Pointer(float32), float32, [128], []),
b: Buffer(b_2: Pointer(float32), float32, [128], []),
c: Buffer(c_2: Pointer(float32), float32, [128], [])}
buffer_map = {a_1: a, b_1: b, c_1: c}
preflattened_buffer_map = {a_1: a_3: Buffer(a_2, float32, [128], []), b_1: b_3: Buffer(b_2, float32, [128], []), c_1: c_3: Buffer(c_2, float32, [128], [])} {
for (i: int32, 0, 128) {
c[i] = (a[i] + b[i])
}
}
編寫 Pass?
本質上,「IR 轉換 pass」是將語句映射到新語句的函數。因此,我們要定義這個向量化函數,并逐步實現它。
TVM 為用戶提供了兩個類來分析和轉換 IR。
IR 訪問器?
可以用?tvm.tir.stmt_functor.post_order_visit(stmt, func)
?從 Halide IR 中收集信息。?func
?是一個回調函數,會在退出當前 IR 節點之前調用,即 post-order visit。然后存儲 IR 訪問的結果,因為?func
?的返回值將被忽略。
備注
必須用數組來存儲 IR 訪問的結果。值甚至是一個單變量。這主要是由于 Python-C runtime 的限制,每次遞歸都會刷新變量值,但會保留數組值。
loops = []
def find_width8(op):
"""查找范圍可以被 8 整除的所有「tir.For」節點。"""
if isinstance(op, tvm.tir.For):
if isinstance(op.extent, tvm.tir.IntImm):
if op.extent.value % 8 == 0:
loops.append(op)
IR 轉換?
轉換接口與訪問器接口略有不同。訪問器中只有一個后序回調,但轉換訪問器同時支持前序回調和后序回調。若要保留原始 IR 節點,只需返回 None。若要將當前節點更改為某個節點,使用 TVM IR maker 接口構建,并返回這個值。
備注
若調用 pre-order 函數后返回一個非 None 的值,則將跳過 post-order 函數。
def vectorize8(op):
"""Split 可以向量化 `find_width8` 中的循環。"""
if op in loops:
extent = op.extent.value
name = op.loop_var.name
lo, li = te.var(name + ".outer"), te.var(name + ".inner")
body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})
body = tvm.tir.For(li, 0, 8, tvm.tir.ForKind.VECTORIZED, body)
body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.ForKind.SERIAL, body)
return body
return None
@tvm.tir.transform.prim_func_pass(opt_level=0)
def vectorize(f, mod, ctx):
global loops
tvm.tir.stmt_functor.post_order_visit(f.body, find_width8)
if not loops:
return f
# 最后一個列表參數表示將轉換哪些類型的節點。
# 在這種情況下,只有 `For` 節點會調用 `vectorize8`
return f.with_body(tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ["tir.For"]))
對接低層(Glue to Lowering)?
到目前為止,已經完成了這個 IR 轉換 pass 的編寫。接下來將這個 pass 和 TVM 的底層 pass 對接。
在這種情況下,通過元組列表作為參數提供給?tir.add_lower_pass
,將上面編寫的 pass 注入 TVM 標準較低級的 pass。「元組」表示降級的不同階段。 TVM 中有四個階段的降級,每個階段完成后,都會調用自定義的階段。
備注
以下是每個階段完成的基本轉換:
- 階段 0 生成原始 IR 和循環級別。
- 階段 1 扁平化數組存儲。
- 階段 2 轉換循環,如展開、矢量化和線程綁定。
- 階段 3 清理工作。
因此,這個轉換 pass 適合放在第 1 階段之后。
with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, vectorize)]}):
print(tvm.lower(sch, [a, b, c]))
輸出結果:
@main = primfn(a_1: handle, b_1: handle, c_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {a: Buffer(a_2: Pointer(float32), float32, [128], []),
b: Buffer(b_2: Pointer(float32), float32, [128], []),
c: Buffer(c_2: Pointer(float32), float32, [128], [])}
buffer_map = {a_1: a, b_1: b, c_1: c}
preflattened_buffer_map = {a_1: a_3: Buffer(a_2, float32, [128], []), b_1: b_3: Buffer(b_2, float32, [128], []), c_1: c_3: Buffer(c_2, float32, [128], [])} {
for (i.outer: int32, 0, 16) {
let cse_var_1: int32 = (i.outer*8)
c[ramp(cse_var_1, 1, 8)] = (a[ramp(cse_var_1, 1, 8)] + b[ramp(cse_var_1, 1, 8)])
}
}
快速回顧?
快速回顧本教程有關編寫自定義 IR 轉換 pass:
- 用?
tvm.tir.stmt_functor.post_order_visit
?收集每個 IR 節點的信息。 - 用?
tvm.tir.stmt_functor.ir_transform
?轉換 IR 節點。 - 總結以上兩點來編寫一個 IR 轉換函數。
- 用?
tvm.transform.PassContext
?將此函數放入 TVM 降級 pass。
