數據工程中的單元測試完全指南(上)
在數據工程領域中,經常被忽視的一項實踐是單元測試。許多人可能認為單元測試僅僅是一種軟件開發方法論,但事實遠非如此。隨著我們努力構建穩健、無錯誤的數據流水線和SQL數據模型,單元測試在數據工程中的價值變得越來越清晰。
本文帶你深入探索如何將這些成熟的軟件工程實踐應用到數據工程中。
1. 單元測試的重要性
在數據工程的背景下,采用單元測試可以確保您的數據和業務邏輯的準確性,進而產出高質量的數據,獲得您的數據分析師、科學家和決策者對數據的信任。
2. 單元測試數據流水線
數據流水線通常涉及復雜的數據抽取、轉換和加載(ETL)操作序列,出錯的可能性很大。為了對這些操作進行單元測試,我們將流水線拆分為單個組件,并對每個組件進行獨立驗證。
以一個簡單的流水線為例,該流水線從CSV文件中提取數據,通過清除空值來轉換數據,然后將其加載到數據庫中。以下是使用pandas的基于Python的示例:
import pandas as pd
from sqlalchemy import create_engine
# 加載CSV文件的函數
def load_data(file_name):
data = pd.read_csv(file_name)
return data
# 清理數據的函數
def clean_data(data):
data = data.dropna()
return data
# 將數據保存到SQL數據庫的函數
def save_data(data, db_string, table_name):
engine = create_engine(db_string)
data.to_sql(table_name, engine, if_exists='replace')
# 運行數據流水線
data = load_data('data.csv')
data = clean_data(data)
save_data(data, 'sqlite:///database.db', 'my_table')
為了對這個流水線進行單元測試,我們使用像pytest這樣的庫為每個函數編寫單獨的測試。
在這個示例中,有三個主要的函數:load_data、clean_data和save_data。我們會為每個函數編寫測試。對于load_data和save_data,需要設置一個臨時的CSV文件和SQLite數據庫,可以使用pytest庫的fixture功能來實現。
import os
import pandas as pd
import pytest
from sqlalchemy import create_engine, inspect
# 使用pytest fixture來設置臨時的CSV文件和SQLite數據庫
@pytest.fixture
def csv_file(tmp_path):
data = pd.DataFrame({
'name': ['John', 'Jane', 'Doe'],
'age': [34, None, 56] # Jane的年齡缺失
})
file_path = tmp_path / "data.csv"
data.to_csv(file_path, index=False)
return file_path
@pytest.fixture
def sqlite_db(tmp_path):
file_path = tmp_path / "database.db"
return 'sqlite:///' + str(file_path)
def test_load_data(csv_file):
data = load_data(csv_file)
assert 'name' in data.columns
assert 'age' in data.columns
assert len(data) == 3
def test_clean_data(csv_file):
data = load_data(csv_file)
data = clean_data(data)
assert data['age'].isna().sum() == 0
assert len(data) == 2 # Jane的記錄應該被刪除
def test_save_data(csv_file, sqlite_db):
data = load_data(csv_file)
data = clean_data(data)
save_data(data, sqlite_db, 'my_table')
# 檢查數據是否保存正確
engine = create_engine(sqlite_db)
inspector = inspect(engine)
tables = inspector.get_table_names()
assert 'my_table' in tables
loaded_data = pd.read_sql('my_table', engine)
assert len(loaded_data) == 2 # 只應該存在John和Doe的記錄
這里是另一個例子:假設您有一個從CSV文件中加載數據并將其中的“日期”列從字符串轉換為日期時間的流水線:
def convert_date(data, date_column):
data[date_column] = pd.to_datetime(data[date_column])
return data
為上述函數編寫的單元測試將傳入具有已知日期字符串格式的DataFrame。然后,它將驗證函數是否正確將日期轉換為日期時間對象,并且它是否適當處理無效格式。
我們為上述場景編寫一個單元測試。該測試首先使用有效日期檢查函數,斷言輸出DataFrame中的“date”列確實是datetime類型,并且值與預期相符。然后,它檢查在給出無效日期時,函數是否正確引發了ValueError。
import pandas as pd
import pytest
def test_convert_date():
# 使用有效日期進行測試
test_data = pd.DataFrame({
'date': ['2021-01-01', '2021-01-02']
})
converted_data = convert_date(test_data.copy(), 'date')
assert pd.api.types.is_datetime64_any_dtype(converted_data['date'])
assert converted_data.loc[0, 'date'] == pd.Timestamp('2021-01-01')
assert converted_data.loc[1, 'date'] == pd.Timestamp('2021-01-02')
# 使用無效日期進行測試
test_data = pd.DataFrame({
'date': ['2021-13-01'] # 這個日期是無效的,因為沒有第13個月
})
with pytest.raises(ValueError):
convert_date(test_data, 'date')
以下是最后一個例子:假設您有一個加載數據并進行聚合的流水線,計算每個地區的總銷售額:
def aggregate_sales(data):
aggregated = data.groupby('region').sales.sum().reset_index()
return aggregated
為該函數編寫的單元測試將向其傳遞具有各個地區銷售數據的DataFrame。測試將驗證函數是否正確計算每個地區的總銷售額。
我們為該函數編寫一個單元測試。在這個測試中,我們首先向aggregate_sales函數傳遞一個具有已知銷售數據的DataFrame,并檢查它是否正確聚合了銷售額。然后,向其傳遞一個沒有銷售數據的DataFrame,并檢查它是否正確將這些銷售額聚合為0。這樣可以確保函數正確處理典型情況和邊緣情況。
以下是使用pytest庫為aggregate_sales函數編寫單元測試的示例:
import pandas as pd
import pytest
def test_aggregate_sales():
# 各個地區的銷售數據
test_data = pd.DataFrame({
'region': ['North', 'North', 'South', 'South', 'East', 'East', 'West', 'West'],
'sales': [100, 200, 300, 400, 500, 600, 700, 800]
})
aggregated = aggregate_sales(test_data)
assert aggregated.loc[aggregated['region'] == 'North', 'sales'].values[0] == 300
assert aggregated.loc[aggregated['region'] == 'South', 'sales'].values[0] == 700
assert aggregated.loc[aggregated['region'] == 'East', 'sales'].values[0] == 1100
assert aggregated.loc[aggregated['region'] == 'West', 'sales'].values[0] == 1500
# 沒有銷售數據的測試
test_data = pd.DataFrame({
'region': ['North', 'South', 'East', 'West'],
'sales': [0, 0, 0, 0]
})
aggregated = aggregate_sales(test_data)
assert aggregated.loc[aggregated['region'] == 'North', 'sales'].values[0] == 0
assert aggregated.loc[aggregated['region'] == 'South', 'sales'].values[0] == 0
assert aggregated.loc[aggregated['region'] == 'East', 'sales'].values[0] == 0
assert aggregated.loc[aggregated['region'] == 'West', 'sales'].values[0] == 0