使用 ML.NET 實現圖像分類:從入門到實踐
ML.NET是微軟開發的開源機器學習框架,讓.NET開發者能夠直接在.NET應用程序中集成機器學習功能。本文將詳細介紹如何使用ML.NET實現圖像分類,包括環境搭建、數據準備、模型訓練等完整流程。
環境準備
- Visual Studio 2022
- .NET 6.0或更高版本
- 需要安裝的NuGet包:
Microsoft.ML
Microsoft.ML.Vision
Microsoft.ML.ImageAnalytics
SciSharp.TensorFlow.Redist (版本2.3.1)
圖片
項目結構
ImageClassification/
├── Program.cs
├── assets/ # 存放訓練圖片
│ ├── CD/ # 有裂縫的圖片
│ └── UD/ # 無裂縫的圖片
└── workspace/ # 存放模型文件
代碼實現
一、創建數據模型類
// 原始圖像數據類
public class ImageData
{
public string ImagePath { get; set; }
public string Label { get; set; }
}
ImageData 類是用來表示和存儲圖像的基本信息的數據結構,主要用于數據加載和預處理階段。它包含兩個關鍵屬性:
1.ImagePath
- 存儲圖像文件的完整路徑
- 用于后續加載和訪問圖像文件
- 是一個字符串類型的屬性
2.Label
- 存儲圖像的分類標簽/類別
- 表示圖像所屬的類別(如示例中的"有裂縫"或"無裂縫")
- 是一個字符串類型的屬性
主要用途:
- 數據加載:在從目錄加載圖像數據時,用于初步組織和存儲圖像信息
- 數據組織:將文件系統中的圖像與其對應的分類標簽關聯起來
// 模型輸入類
public class ModelInput
{
public byte[] Image { get; set; }
public UInt32 LabelAsKey { get; set; }
public string ImagePath { get; set; }
public string Label { get; set; }
}
3.Image (byte[] 類型)
- 存儲圖像的字節數組表示
- 這是模型訓練和預測所必需的輸入格式
- 模型需要這種類型的圖像數據來進行訓練
4.LabelAsKey (UInt32 類型)
- 是 Label 的數值表示形式
- 將分類標簽轉換為數值形式,因為機器學習模型要求輸入采用數值格式
5.ImagePath (string 類型)
- 存儲圖像的完整路徑
- 用于方便訪問原始圖像文件
6.Label (string 類型)
- 圖像所屬的類別
- 這是需要預測的目標值
- 用于訓練時的標簽信息
重要說明:
- 在實際訓練和預測中,只有 Image 和 LabelAsKey 這兩個屬性被用于模型訓練和預測
- ImagePath 和 Label 屬性主要是為了方便訪問和追蹤原始數據,不直接參與模型計算
- 這個類是連接原始圖像數據和模型訓練需求的橋梁,將各種必要的信息整合在一起
// 模型輸出類
public class ModelOutput
{
public string ImagePath { get; set; }
public string Label { get; set; }
public string PredictedLabel { get; set; }
}
7.ImagePath (string 類型)
- 存儲圖像的完整文件路徑
- 用于追蹤和引用原始圖像文件
8.Label (string 類型)
- 存儲圖像的原始/真實類別標簽
- 這是圖像實際應該屬于的類別
9.PredictedLabel (string 類型)
- 存儲模型預測的類別標簽
- 這是模型通過分析圖像后預測出的類別
重要說明:
- 在實際預測過程中,只有 PredictedLabel 是必需的,因為它包含模型的預測結果
- ImagePath 和 Label 屬性主要用于評估和驗證目的,方便比較預測結果與實際標簽的差異
二、 圖像加載工具方法
private static IEnumerable<ImageData> LoadImagesFromDirectory(string folder, bool useFolderNameAsLabel = true)
{
var files = Directory.GetFiles(folder, "*", searchOption: SearchOption.AllDirectories);
foreach (var file in files)
{
if ((Path.GetExtension(file) != ".jpg") && (Path.GetExtension(file) != ".png"))
continue;
var label = Path.GetFileName(file);
if (useFolderNameAsLabel)
label = Directory.GetParent(file).Name;
else
{
for (int index = 0; index < label.Length; index++)
{
if (!char.IsLetter(label[index]))
{
label = label.Substring(0, index);
break;
}
}
}
yield return new ImageData()
{
ImagePath = file,
Label = label
};
}
}
三、主程序實現
class Program
{
static void Main(string[] args)
{
// 初始化ML.NET環境
MLContext mlContext = new MLContext();
// 設置路徑
var projectDirectory = Path.GetFullPath(Path.Combine(AppContext.BaseDirectory, "../../../"));
var workspaceRelativePath = Path.Combine(projectDirectory, "workspace");
var assetsRelativePath = Path.Combine(projectDirectory, "assets");
// 加載數據
IEnumerable<ImageData> images = LoadImagesFromDirectory(folder: assetsRelativePath, useFolderNameAsLabel: true);
IDataView imageData = mlContext.Data.LoadFromEnumerable(images);
IDataView shuffledData = mlContext.Data.ShuffleRows(imageData);
// 數據預處理
var preprocessingPipeline = mlContext.Transforms.Conversion.MapValueToKey(
inputColumnName: "Label",
outputColumnName: "LabelAsKey")
.Append(mlContext.Transforms.LoadRawImageBytes(
outputColumnName: "Image",
imageFolder: assetsRelativePath,
inputColumnName: "ImagePath"));
IDataView preProcessedData = preprocessingPipeline
.Fit(shuffledData)
.Transform(shuffledData);
// 數據集分割
var trainSplit = mlContext.Data.TrainTestSplit(data: preProcessedData, testFraction: 0.3);
var validationTestSplit = mlContext.Data.TrainTestSplit(trainSplit.TestSet);
// 配置訓練選項
var classifierOptions = new ImageClassificationTrainer.Options()
{
FeatureColumnName = "Image",
LabelColumnName = "LabelAsKey",
ValidationSet = validationTestSplit.TrainSet,
Arch = ImageClassificationTrainer.Architecture.ResnetV2101,
MetricsCallback = (metrics) => Console.WriteLine(metrics),
TestOnTrainSet = false,
ReuseTrainSetBottleneckCachedValues = true,
ReuseValidationSetBottleneckCachedValues = true,
WorkspacePath = workspaceRelativePath
};
// 定義訓練管道
var trainingPipeline = mlContext.MulticlassClassification.Trainers
.ImageClassification(classifierOptions)
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
// 訓練模型
Console.WriteLine("*** 開始訓練模型 ***");
ITransformer trainedModel = trainingPipeline.Fit(trainSplit.TrainSet);
// 進行預測
ClassifySingleImage(mlContext, validationTestSplit.TestSet, trainedModel);
ClassifyImages(mlContext, validationTestSplit.TestSet, trainedModel);
}
}
數據預處理說明
第一步:標簽轉換
mlContext.Transforms.Conversion.MapValueToKey(
inputColumnName: "Label",
outputColumnName: "LabelAsKey")
- 將字符串類型的標簽("Label")轉換為數值類型("LabelAsKey")
- 例如:"CD" -> 0, "UD" -> 1
- 這是必需的,因為機器學習模型需要數值形式的標簽
- 輸入是 ImageData 類中的 Label 屬性
- 輸出存儲在 ModelInput 類的 LabelAsKey 屬性中
第二步:圖像加載
mlContext.Transforms.LoadRawImageBytes(
outputColumnName: "Image",
imageFolder: assetsRelativePath,
inputColumnName: "ImagePath")
- 將圖像文件轉換為字節數組格式
- `outputColumnName: "Image"`: 輸出到 ModelInput 類的 Image 屬性
- `imageFolder: assetsRelativePath`: 指定圖像文件所在的根目錄
- `inputColumnName: "ImagePath"`: 使用 ImageData 類中的 ImagePath 屬性
配置訓練選項
1.FeatureColumnName = "Image"
- 指定用作模型輸入的列名
- 這里使用"Image"列,它包含圖像的字節數組數據
2.LabelColumnName = "LabelAsKey"
- 指定要預測的目標值列名
- 使用"LabelAsKey"列,它是標簽的數值表示形式
3.ValidationSet = validationTestSplit.TrainSet
- 指定用于驗證的數據集
- 用于在訓練過程中評估模型性能
4.Arch = ImageClassificationTrainer.Architecture.ResnetV2101
- 指定使用的預訓練模型架構
- 這里使用 ResNet v2 的101層變體
- ResNet是一個預訓練模型,可以將圖像分為1000個類別
5.MetricsCallback = (metrics) => Console.WriteLine(metrics)
- 用于在訓練過程中跟蹤和顯示訓練指標
- 通過控制臺輸出訓練進度和性能指標
6.TestOnTrainSet = false
- 設置是否在訓練集上測試模型
- false表示不在訓練集上測試,避免過擬合
7.ReuseTrainSetBottleneckCachedValues = true
- 是否重用訓練集的瓶頸層計算結果
- true表示緩存并重用這些值,可以顯著減少訓練時間
- 適用于訓練數據不變但需要調整其他參數的情況
8.ReuseValidationSetBottleneckCachedValues = true
- 是否重用驗證集的瓶頸層計算結果
- 與上面類似,但作用于驗證數據集
9.WorkspacePath = workspaceRelativePath
- 指定存儲工作文件的目錄路徑
- 用于保存計算的瓶頸值和模型的.pb版本
- 便于后續重用和模型部署
這些參數的配置對模型的訓練效果和效率有重要影響,可以根據具體需求調整這些參數來優化模型性能。
定義訓練管道
- 圖像分類訓練器
mlContext.MulticlassClassification.Trainers.ImageClassification(classifierOptions)
- 使用多分類分類器進行圖像分類
- 基于之前定義的 classifierOptions 配置
- 使用遷移學習方法,基于預訓練的 ResNet 模型
- 主要功能:
- 提取圖像特征
- 訓練分類器
- 生成預測模型
- 預測標簽轉換
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"))
- 將模型輸出的數值預測結果轉換回原始標簽
- 與前面的 MapValueToKey 操作相反
- 例如:將 0 轉回 "CD",1 轉回 "UD"
- 確保最終輸出是人類可讀的標簽
整個訓練管道的工作流程:
- 接收預處理后的數據(圖像字節數組和數值標簽)
- 通過深度學習模型進行特征提取和分類
- 將數值預測結果轉換為原始標簽類別
- 輸出最終的分類結果
四、預測方法實現
private static void ClassifySingleImage(MLContext mlContext, IDataView data, ITransformer trainedModel)
{
var predictionEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel);
var image = mlContext.Data.CreateEnumerable<ModelInput>(data, reuseRowObject: true).First();
var prediction = predictionEngine.Predict(image);
Console.WriteLine($"單張圖片分類結果:");
Console.WriteLine($"圖片: {Path.GetFileName(prediction.ImagePath)}");
Console.WriteLine($"實際類別: {prediction.Label}");
Console.WriteLine($"預測類別: {prediction.PredictedLabel}");
}
private static void ClassifyImages(MLContext mlContext, IDataView data, ITransformer trainedModel)
{
IDataView predictionData = trainedModel.Transform(data);
var predictions = mlContext.Data.CreateEnumerable<ModelOutput>(predictionData, reuseRowObject: true)
.Take(10);
Console.WriteLine("\n批量圖片分類結果:");
foreach (var prediction in predictions)
{
Console.WriteLine($"圖片: {Path.GetFileName(prediction.ImagePath)}");
Console.WriteLine($"實際類別: {prediction.Label}");
Console.WriteLine($"預測類別: {prediction.PredictedLabel}\n");
}
}
五、執行
訓練速度比較慢
圖片
圖片
圖片
實際與預測都是CD。
圖片
這里會發現預測與實際是有出入的。
模型優化建議
1.增加訓練數據量: 收集更多的樣本數據可以提高模型的泛化能力。
2.數據增強:
- 對現有圖片進行旋轉、翻轉、縮放等操作
- 調整亮度、對比度
- 添加噪聲
3.調整超參數:
- 增加訓練輪數(Epoch)
- 調整學習率
- 嘗試不同的批次大小
4.使用不同的預訓練模型:
- ResNet不同版本
- Inception
- MobileNet
總結
本文詳細介紹了如何使用ML.NET實現圖像分類功能。通過使用遷移學習和預訓練模型,我們可以快速構建高質量的圖像分類應用。ML.NET提供了簡單易用的API,讓.NET開發者能夠方便地將機器學習集成到應用程序中。