譯者 | 李睿
審校 | 重樓
本文介紹了如何使用FastAPI和Redis緩存加速機器學習模型服務。FastAPI作為高性能Web框架用于構建API,Redis作為內存中的數據結構存儲系統作為緩存層。通過集成FastAPI和Redis,系統能快速響應重復請求,避免冗余計算,顯著降低延遲和CPU負載。此外還詳細闡述了實現步驟,包括加載模型、創建FastAPI端點、設置Redis緩存及測試性能提升。
你是否因為等待機器學習模型返回預測結果而耗費過長時間?很多人都有過這樣的經歷。機器學習模型在實時服務時可能會非常緩慢,尤其是那些大型且復雜的機器學習模型。另一方面,用戶希望得到即時反饋。因此,這使得延遲問題愈發凸顯。從技術層面來看,最主要的問題之一是當相同的輸入反復觸發相同的緩慢過程時,會出現冗余計算。本文將展示如何解決這個問題,因此將構建一個基于FastAPI的機器學習服務,并集成Redis緩存,以便在毫秒級的時間內迅速返回重復的預測結果。
什么是FastAPI?
FastAPI是一個基于Python的現代Web框架,用于構建 API。它使用Python的類型提示進行數據驗證,并使用Swagger UI和ReDoc自動生成交互式API文檔。FastAPI基于Starlette和Pydantic構建,支持異步編程,使其性能可與Node.js和Go相媲美。其設計有助于快速開發健壯的、生產就緒的API,使其成為將機器學習模型部署為可擴展的RESTful服務的絕佳選擇。
什么是Redis?
Redis(Remote Dictionary Server)是一個開源的內存數據結構存儲系統,其功能包括數據庫、緩存和消息代理。通過將數據存儲在內存中,Redis為讀寫操作提供了超低延遲,使其成為緩存頻繁或計算密集型任務(例如機器學習模型預測)的理想選擇。它支持各種數據結構,包括字符串、列表、集合和散列,并提供密鑰過期(TTL)等功能,以實現高效的緩存管理。
為什么要結合FastAPI和Redis?
將FastAPI與Redis集成,可以創建一個響應速度快、效率高的系統。FastAPI作為處理API請求的快速且可靠的接口,而Redis則作為緩存層可以存儲之前計算的結果。當再次接收到相同的輸入時,可以立即從Redis檢索結果,無需重新計算。這種方法降低了延遲,減輕了計算負載,并提高了應用程序的可擴展性。在分布式環境中,Redis充當可由多個FastAPI實例訪問的集中式緩存,使其非常適用于生產級機器學習部署。
接下來,深入了解如何實現一個使用Redis緩存提供機器學習模型預測的FastAPI應用程序。這種設置能夠確保針對相同輸入的重復請求能夠迅速從緩存中獲取服務,從而大幅減少計算時間,并縮短響應時間。其實現步驟如下:
- 加載預訓練模型
- 創建FastAPI預測端點
- 設置Redis緩存
- 測試和衡量性能提升
以下詳細地了解這些步驟。
步驟1:加載預訓練模型
首先,假設已經擁有一個訓練有素的機器學習模型,并準備將其投入部署。在實際應用中,大多數機器學習模型都是離線訓練的(例如scikit-learn模型,TensorFlow/Pytorch模型等),并保存到磁盤中,然后加載到服務應用程序中。在這個示例中,將創建一個簡單的scikit-learn分類器,它將在著名的Iris flower數據集上進行訓練,并使用joblib庫保存。如果已經保存了一個模型文件,可以跳過訓練步驟直接加載它。以下介紹如何訓練一個模型,然后加載它進行服務:
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import joblib
# Load example dataset and train a simple model (Iris classification)
X, y = load_iris(return_X_y=True)
# Train the model
model = RandomForestClassifier().fit(X, y)
# Save the trained model to disk
joblib.dump(model, "model.joblib")
# Load the pre-trained model from disk (using the saved file)
model = joblib.load("model.joblib")
print("Model loaded and ready to serve predictions.")
在以上的代碼中,使用了scikit-learn的內置Iris數據集訓練了一個隨機森林分類器,然后將該模型保存到一個名為model.joblib的文件中。之后,使用joblib.load方法將其重新加載。joblib庫在保存scikit-learn模型時非常常見,主要是因為它擅長處理模型內的NumPy數組。隨后,就有了一個可以預測新數據的模型對象。不過需要注意的是,在這里使用任何預訓練的模型,使用FastAPI提供服務的方式以及緩存的結果或多或少是相同的。唯一的問題是,模型應該有一個預測方法,該方法接受一些輸入并產生結果。此外,確保每次輸入相同的數據時,都能給出一致的預測結果(即模型需具備確定性)。如果不是這樣,緩存對于非確定性模型來說將會出現問題,因為它將返回不正確的結果。
步驟2:創建FastAPI預測端點
現在已經有了一個訓練好的模型,可以通過API來使用它。我們將使用FASTAPI創建一個Web服務器來處理預測請求。FASTAPI可以很容易地定義端點并將請求參數映射到Python函數參數。在這個示例中,將假設模型需要四個特征作為輸入。并將創建一個GET端點/預測,該端點/預測接受這些特征作為查詢參數并返回模型的預測。
from fastapi import FastAPI
import joblib
app = FastAPI()
# Load the trained model at startup (to avoid re-loading on every request)
model = joblib.load("model.joblib") # Ensure this file exists from the training step
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
""" Predict the Iris flower species from input measurements. """
# Prepare the features for the model as a 2D list (model expects shape [n_samples, n_features])
features = [[sepal_length, sepal_width, petal_length, petal_width]]
# Get the prediction (in the iris dataset, prediction is an integer class label 0,1,2 representing the species)
prediction = model.predict(features)[0] # Get the first (only) prediction
return {"prediction": str(prediction)}
在以上代碼中,成功創建了一個FastAPI應用程序,并在執行該文件之后啟動API服務器。FastAPI對于Python來說非???,因此它可以輕松地處理大量請求。為避免在每次請求時都重復加載模型(這一操作會顯著降低性能),在程序啟動時就將模型加載到內存中,以便隨時調用。隨后使用@app創建了一個/predict端點。GET使測試變得簡單,因為可以在URL中傳遞內容,但在實際項目中,可能會想要使用POST,特別是在發送大型或復雜的輸入(如圖像或JSON)時。
這個函數接受4個輸入參數:sepal_length、sepal_width、petal_length和petal_width, FastAPI會自動從URL中讀取它們。在函數內部,將所有輸入放入一個2D列表中(因為scikit-learn只接受二維數組作為輸入),然后調用model.predict(),它會返回一個列表。然后將其作為JSON返回,例如{ “prediction”: “...”}。
該系統現在已經能正常運行,可以使用uvicorn main:app–reload命令運行它,然后訪問 /predict 端點并獲取結果。然而,再次發送相同的輸入,它仍然會再次運行模型,這顯然不夠高效,所以下一步是添加Redis來緩存之前的結果,從而避免重復計算。
步驟3:設置Redis緩存
為了緩存模型輸出,將使用Redis。首先,確保Redis服務器正在運行。你可以在本地安裝,或者直接運行Docker容器;在默認情況下,它通常運行在端口6379上,并使用Python Redis庫與服務器通信。
所以,其思路很簡單:當請求進來時,創建一個表示輸入的唯一鍵。然后檢查該鍵是否存在于Redis中;如果那個鍵已經存在,這意味著之前已經緩存了這個,所以只返回保存的結果,不需要再次調用模型。如果沒有,則執行model.predict,獲得輸出,將其保存在Redis中,并返回預測。
現在更新FastAPI應用程序來添加這個緩存邏輯。
!pip install redis
import redis # New import to use Redis
# Connect to a local Redis server (adjust host/port if needed)
cache = redis.Redis(host="localhost", port=6379, db=0)
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
"""
Predict the species, with caching to speed up repeated predictions.
"""
# 1. Create a unique cache key from input parameters
cache_key = f"{sepal_length}:{sepal_width}:{petal_length}:{petal_width}"
# 2. Check if the result is already cached in Redis
cached_val = cache.get(cache_key)
if cached_val:
# If cache hit, decode the bytes to a string and return the cached prediction
return {"prediction": cached_val.decode("utf-8")}
# 3. If not cached, compute the prediction using the model
features = [[sepal_length, sepal_width, petal_length, petal_width]]
prediction = model.predict(features)[0]
# 4. Store the result in Redis for next time (as a string)
cache.set(cache_key, str(prediction))
# 5. Return the freshly computed prediction
return {"prediction": str(prediction)}
在以上的代碼中添加了Redis。首先,使用redis.Redis()創建了一個客戶端,它連接到Redis服務器。在默認情況下使用db=0。然后,通過連接輸入值來創建一個緩存鍵。在這里,它之所以有效,是因為輸入是簡單的數字,但對于復雜的數字,最好使用散列或JSON字符串。每個輸入的鍵必須是唯一的。因此使用了cache.get(cache_key)。如果它找到相同的鍵,它就返回這個鍵,這使其速度更快,并且不需要重新運行模型。但是如果在緩存中沒有找到,需要運行模型并獲得預測結果。最后,使用cache.set()保存在Redis中。而當相同的輸入在下次到來時,因為它已經存在,因為緩存將會很快。
步驟4:測試和衡量性能提升
現在,FastAPI應用程序正在運行并連接到Redis,現在是測試緩存如何提高響應時間的時候了。在這里,演示如何使用Python的請求庫使用相同的輸入兩次調用API,并衡量每次調用所花費的時間。此外,需要確保在運行測試代碼之前啟動FastAPI:
import requests, time
# Sample input to predict (same input will be used twice to test caching)
params = {
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
# First request (expected to be a cache miss, will run the model)
start = time.time()
response1 = requests.get("http://localhost:8000/predict", params=params)
elapsed1 = time.time() - start
print("First response:", response1.json(), f"(Time: {elapsed1:.4f} seconds)")
# Second request (same params, expected cache hit, no model computation)
start = time.time()
response2 = requests.get("http://localhost:8000/predict", params=params)
elapsed2 = time.time() - start
print("Second response:", response2.json(), f"(Time: {elapsed2:.6f}seconds)")
當運行這個命令時,應該看到第一個請求返回一個結果。然后第二個請求返回相同的結果,但明顯速度更快。例如,可能會發現第一次調用花費了幾十毫秒的時間(取決于模型的復雜性),而第二次調用可能只有幾毫秒或更少的時間。在使用輕量級模型的簡單演示中,差異可能很?。ㄒ驗槟P捅旧硭俣群芸欤?,但對于更大的模型來說,其效果非常顯著。
比較
為了更好地理解這一點,可以了解一下取得的成果:
- 無緩存:每個請求,即使是相同的請求,都會命中模型。如果模型每次預測需要100毫秒,那么10個相同的請求仍然需要約1000毫秒。
- 使用緩存:第一個請求需要全部命中(100毫秒),但接下來的9個相同的請求可能每個需要1~2毫秒(只是一個Redis查找和返回數據)。因此,這10個請求可能總共120毫秒,而不是1000毫秒,在這種情況下,速度提高了8倍。
在實際實驗中,緩存可以帶來數量級的性能提升。例如,在電子商務領域中,使用Redis意味著在微秒內返回重復請求的建議,而不必使用完整的模型服務管道重新計算它們。性能提升將取決于模型推理的成本。模型越復雜,從緩存重復調用中的收益越大。這也取決于請求模式:如果每個請求都是唯一的,緩存將無法發揮作用(沒有重復請求可以從內存中提供服務),但是許多應用程序確實會看到重疊的請求(例如,流行的搜索查詢,推薦的項目等)。
為了驗證Redis緩存是否正常存儲鍵值對可以直接對Redis緩存進行檢查。
結論
本文展示了FastAPI和Redis如何協同工作以加速機器學習模型服務。FastAPI提供了一個快速且易于構建的API層用于提供預測服務,Redis添加了一個緩存層,可以顯著減少重復計算的延遲和CPU負載。通過避免重復的模型調用,提高了響應速度,并使系統能夠使用相同的資源處理更多的請求。
原文標題:Accelerate Machine Learning Model Serving With FastAPI and Redis Caching,作者:Janvi Kumari