成人免费xxxxx在线视频软件_久久精品久久久_亚洲国产精品久久久_天天色天天色_亚洲人成一区_欧美一级欧美三级在线观看

Meta公司開源大數據模型SAM實戰演練 原創

發布于 2024-7-15 09:08
瀏覽
0收藏

本文首先介紹Meta公司開發的開源圖像分割模型SAM的架構,然后通過一個河流像素分割遙感任務的實戰案例展示SAM模型應用開發涉及的關鍵技術與模型優勢。

引言

當前,許多強大的開源基礎模型的發布,加上微調技術的不斷進步,已經帶來了機器學習和人工智能的新范式。整體來看,這場革命的核心在于轉換器模型(https://arxiv.org/pdf/1706.03762)。

雖然除了資金充足的公司之外,所有公司都曾經無法獲得高精度的特定領域模型;但是如今,基礎模型范式甚至能夠允許學生或獨立研究人員獲得適度的資源,以實現與最先進的專有模型相抗衡的結果。

Meta公司開源大數據模型SAM實戰演練-AI.x社區

微調技術可以極大地提高模型任務的性能

本文旨在探討Meta公司的分割一切模型(SAM:Segment Anything Model)在河流像素分割遙感任務中的應用。如果您想直接跳到有關項目代碼,那么這個項目的源文件可以在GitHub(https://github.com/geo-smart/water-surf/blob/main/book/chapters/masking_distributed.ipynb)上獲得,數據也可以在HuggingFace(https://huggingface.co/datasets/stodoran/elwha-segmentation-v1)上找到。當然,我還是建議您先閱讀一下本文。

項目需求

第一個任務是找到或創建一個合適的數據集。根據現有文獻,一個良好的SAM微調數據集將至少包含200–800張圖像。過去十年深度學習進步的一個關鍵教訓是,數據越多越好,這樣一來,更大規模的微調數據集就不會出錯。然而,基礎模型研究的一個重要目標是,允許即使是相對較小的數據集也足以實現強大的性能。

此外,我們還需要有一個HuggingFace帳戶,這可以在鏈接https://huggingface.co/join處創建。使用HuggingFace,我們可以隨時從任何設備輕松存儲和獲取數據集,這使得協作和再現性更加容易。

最后一個需求是,具備一臺帶有GPU的設備,我們可以在其上運行訓練工作流程。通過Google Colab(https://colab.research.google.com/)免費提供的Nvidia T4 GPU足夠強大,可以在12小時內對1000張圖像進行50個時期的最大SAM模型檢查點(sam-vit-huge)訓練。

為了避免在托管運行時因使用限制而影響進度,您可以安裝Google Drive并將每個模型檢查點保存在那里。或者,部署并連接到GCP虛擬機(https://console.cloud.google.com/marketplace/product/colab-marketplace-image-public/colab)以完全繞過限制。如果您以前從未使用過GCP,那么您就有資格獲得300美元的免費信貸,這足以支持對模型進行至少十幾次的訓練了。

理解SAM架構

在開始訓練之前,我們需要先來了解一下SAM模型的架構。該模型包含三個組件:一個是從稍經修改的掩碼自動編碼器(https://arxiv.org/pdf/2111.06377)得到的圖像編碼器,一個能夠處理各種提示類型的相當靈活的提示編碼器,還有一個快速輕量級的掩碼解碼器。這種設計架構背后的一個重要動機是允許在邊緣設備上(例如在瀏覽器中)進行快速、實時的分割,因為圖像嵌入只需要計算一次,并且掩碼解碼器可以在CPU上運行約50ms。

Meta公司開源大數據模型SAM實戰演練-AI.x社區

SAM的模型架構向我們展示了模型接受哪些輸入以及需要訓練模型的哪些部分(圖片來源于SAM GitHub:https://github.com/facebookresearch/segment-anything)。

理論上,圖像編碼器已經學會了嵌入圖像的最佳方式,包括識別形狀、邊緣和其他一般視覺特征等。類似地,在理論上,提示編碼器已經能夠以最優方式對提示進行編碼。掩碼解碼器是模型架構的一部分,它采用這些圖像和提示嵌入,并通過對圖像和提示嵌入式進行操作來實際創建掩碼。

因此,一種方法是在訓練期間凍結與圖像和提示編碼器相關聯的模型參數,并且僅更新掩碼解碼器權重。這種方法的優點是允許有監督和無監督的下游任務,因為控制點和邊界框提示都是自動的,并且可供人工使用。

Meta公司開源大數據模型SAM實戰演練-AI.x社區

圖中顯示了AutoSAM體系架構中使用的凍結SAM圖像編碼器和掩碼解碼器,以及過載提示編碼器(來源于AutoSAM論文:https://arxiv.org/pdf/2306.06370)。

另一種方法是使提示編碼器過載,凍結圖像編碼器和掩碼解碼器,并且只是簡單地不使用原始SAM掩碼編碼器。例如,AutoSAM體系架構使用基于Harmonic Dense Net的網絡來基于圖像本身生成提示嵌入。在本教程中,我們將介紹第一種方法,即凍結圖像和提示編碼器,只訓練掩碼解碼器,但這種替代方法的代碼可以在AutoSAM GitHub(https://github.com/talshaharabany/AutoSAM/blob/main/inference.py)和論文(https://arxiv.org/pdf/2306.06370)中找到。

配置提示

接下來的一步是確定模型在推理過程中會收到什么類型的提示,以便我們可以在訓練時提供這種類型的提示。就我個人而言,考慮到自然語言處理的不可預測/不一致性,我不建議在任何嚴肅的計算機視覺項目架構中使用文本提示。剩下的解決方案就需要依賴控制點和邊界框技術了;但是,最終的選擇還要取決于特定數據集的特定性質,盡管有關文獻中已經指出邊界框方案的表現相當一致地優于控制點方案。

造成這種情況的原因尚不完全清楚,但可能是以下任何因素之一,或者是這些因素的組合:

  • 在推理時(當真實值掩碼未知時),好的控制點比邊界框更難選擇。
  • 可能的點提示的空間比可能的邊界框提示的空間大幾個數量級,因此它沒有經過徹底的訓練。
  • 最初的SAM模型作者主要專注于模型的零樣本和少樣本(根據人工提示交互計算)功能,因此預訓練可能更多地關注邊界框。

無論如何,河流分割實際上是一種罕見的情況;在這種情況下,點提示方案實際上優于邊界框(盡管只是輕微的,即使是在非常有利的域中)。假設在河流的任何圖像中,水體將從圖像的一端延伸到另一端,任何包含的邊界框幾乎總是覆蓋圖像的大部分。因此,河流非常不同部分的邊界框提示看起來非常相似。理論上,這意味著邊界框為模型提供的信息比控制點少得多;因此,導致性能較差。

Meta公司開源大數據模型SAM實戰演練-AI.x社區

控制點、邊界框提示和疊加在兩個樣本訓練圖像上的真實分割

請注意,在上圖中,盡管兩條河流部分的真實分割掩碼完全不同,但它們各自的邊界框幾乎相同,而它們的點提示(相對而言)差異更大。

另一個需要考慮的重要因素是在推理時生成輸入提示的容易程度。如果您希望在循環執行階段有人工介入,那么請注意邊界框和控制點在推理階段都是相當瑣碎的。然而,如果您打算使用一個完全自動化的架構方案,那么回答這些問題將變得更加復雜。

無論是使用控制點還是邊界框,生成提示通常首先包括估計感興趣對象的粗略掩碼。邊界框可以只是包裹粗略掩碼的最小框,而控制點需要從粗略掩碼中采樣。這意味著,當真實值掩碼未知時,邊界框更容易獲得,因為感興趣對象的估計掩碼只需要大致匹配真實對象的相同大小和位置;而對于控制點,估計掩碼將需要更緊密地匹配對象的輪廓。

Meta公司開源大數據模型SAM實戰演練-AI.x社區

當使用估計的掩碼而不是真實值時,控制點的放置可能包括錯誤標注的點,而邊界框通常位于正確的位置

對于河流分割,如果我們可以同時使用RGB和NIR,那么我們可以使用光譜指數閾值方法來獲得我們的粗略掩模。如果我們只能使用RGB模式,我們可以將圖像轉換為HSV模式,并對特定色調、飽和度和值范圍內的所有像素設置閾值。然后,我們可以移除低于特定大小閾值的連接內容,并使用skimage.morphology子模塊中的erosion函數來確保我們的掩模中只有1個像素是朝向藍色大斑點中心的像素。

模型訓練

為了訓練我們的模型,我們需要一個包含所有訓練數據的數據加載器,我們可以在每個訓練時期對這些數據進行迭代。當我們從HuggingFace加載數據集時,它采用datasets.Dataset類的形式。如果數據集是私有的,請確保首先安裝HuggingFace CLI并使用“!huggingface-cli login”方式登錄。

from datasets import load_dataset, load_from_disk, Dataset
hf_dataset_name = "stodoran/elwha-segmentation-v1"
training_data = load_dataset(hf_dataset_name, split="train")
validation_data = load_dataset(hf_dataset_name, split="validation")

然后,我們需要編寫自己的自定義數據集類,該類不僅返回任何索引的圖像和標簽,還返回提示詞信息。下面是一個可以同時處理控制點和邊界框提示的實現。要完成初始化工作,需要一個HuggingFace datasets.Dataset實例和SAM模型的處理器實例。

from torch.utils.data import Dataset
class PromptType:
CONTROL_POINTS = "pts"
BOUNDING_BOX = "bbox"
class SAMDataset(Dataset):
def __init__(
self, 
dataset, 
processor, 
prompt_type = PromptType.CONTROL_POINTS,
num_positive = 3,
num_negative = 0,
erode = True,
multi_mask = "mean",
perturbation = 10,
image_size = (1024, 1024),
mask_size = (256, 256),
):
#將所有值賦給self
...

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
datapoint = self.dataset[idx]
input_image = cv2.resize(np.array(datapoint["image"]), self.image_size)
ground_truth_mask = cv2.resize(np.array(datapoint["label"]), self.mask_size)

if self.prompt_type == PromptType.CONTROL_POINTS:
inputs = self._getitem_ctrlpts(input_image, ground_truth_mask)
elif self.prompt_type == PromptType.BOUNDING_BOX:
inputs = self._getitem_bbox(input_image, ground_truth_mask)

inputs["ground_truth_mask"] = ground_truth_mask
return inputs

我們還必須定義SAMDataset_getitem_ctrlpts和SAMDataset_getitem_box函數,盡管如果您只計劃使用一種提示類型,那么您可以重構代碼以直接處理SAMDataset.__getitem__中的該類型,并刪除幫助類工具函數。

class SAMDataset(Dataset):
...
def _getitem_ctrlpts(self, input_image, ground_truth_mask):
# 獲取控制點提示。請參閱GitHub獲取該函數的源代碼,或將其替換為您自己的點選擇算法。
input_points, input_labels = generate_input_points(
num_positive=self.num_positive,
num_negative=self.num_negative,
mask=ground_truth_mask,
dynamic_distance=True,
erode=self.erode,
)
input_points = input_points.astype(float).tolist()
input_labels = input_labels.tolist()
input_labels = [[x] for x in input_labels]

# 為模型準備圖像和提示。
inputs = self.processor(
input_image,
input_points=input_points,
input_labels=input_labels,
return_tensors="pt"
)

#刪除處理器默認添加的批次維度。
inputs = {k: v.squeeze(0) for k, v in inputs.items()}
inputs["input_labels"] = inputs["input_labels"].squeeze(1)

return inputs

def _getitem_bbox(self, input_image, ground_truth_mask):
#獲取邊界框提示。
bbox = get_input_bbox(ground_truth_mask, perturbation=self.perturbation)

#為模型準備圖像和提示。
inputs = self.processor(input_image, input_boxes=[[bbox]], return_tensors="pt")
inputs = {k: v.squeeze(0) for k, v in inputs.items()} # 刪除處理器默認添加的批次維度。

return inputs

將所有這些功能組合到一起,我們可以創建一個函數,該函數在給定HuggingFace數據集的任一部分的情況下創建并返回PyTorch數據加載器。編寫返回數據加載器的函數,而不僅僅是用相同的代碼執行單元,這不僅是編寫靈活和可維護代碼的好方法,而且如果您計劃使用HuggingFace Accelerate(https://huggingface.co/docs/accelerate/index)來運行分布式訓練的話,這也是必要的。

from transformers import SamProcessor
from torch.utils.data import DataLoader

def get_dataloader(
hf_dataset,
model_size = "base",  # One of "base", "large", or "huge" 
batch_size = 8, 
prompt_type = PromptType.CONTROL_POINTS,
num_positive = 3,
num_negative = 0,
erode = True,
multi_mask = "mean",
perturbation = 10,
image_size = (256, 256),
mask_size = (256, 256),
):
processor = SamProcessor.from_pretrained(f"facebook/sam-vit-{model_size}")

sam_dataset = SAMDataset(
dataset=hf_dataset, 
processor=processor, 
prompt_type=prompt_type,
num_positive=num_positive,
num_negative=num_negative,
erode=erode,
multi_mask=multi_mask,
perturbation=perturbation,
image_size=image_size,
mask_size=mask_size,
)
dataloader = DataLoader(sam_dataset, batch_size=batch_size, shuffle=True)

return dataloader

在此之后,訓練只需加載模型、凍結圖像和提示編碼器,并進行所需次數的迭代訓練。

model = SamModel.from_pretrained(f"facebook/sam-vit-{model_size}")
optimizer = AdamW(model.mask_decoder.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Train only the decoder.
for name, param in model.named_parameters():
if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
param.requires_grad_(False)

以下列出的是訓練過程循環部分的基本框架代碼。請注意,為了簡潔起見,forward_pass、calculate loss、evaluate_mode和save_model_checkpoint函數被省略了,但GitHub上提供了實現。正向傳遞碼根據提示類型略有不同,損失計算也需要基于提示類型的特殊情況;當使用點提示時,SAM模型為每個單個輸入點返回一個預測掩碼,因此為了獲得可以與真實數據進行比較的單個掩碼,需要對預測掩碼進行平均,或者需要選擇最佳預測掩碼(基于SAM的預測IoU分數來識別)。

train_losses = []
validation_losses = []
epoch_loop = tqdm(total=num_epochs, position=epoch, leave=False)
batch_loop = tqdm(total=len(train_dataloader), position=0, leave=True)

while epoch < num_epochs:
epoch_losses = []

batch_loop.n = 0  #循環重置
for idx, batch in enumerate(train_dataloader):
# 正向傳遞
batch = {k: v.to(accelerator.device) for k, v in batch.items()}
outputs = forward_pass(model, batch, prompt_type)

#計算損失值
ground_truth_masks = batch["ground_truth_mask"].float()
train_loss = calculate_loss(outputs, ground_truth_masks, prompt_type, loss_fn, multi_mask="best")
epoch_losses.append(train_loss)

# 反向傳遞與優化環節
optimizer.zero_grad()
accelerator.backward(train_loss)
optimizer.step()
lr_scheduler.step()

batch_loop.set_description(f"Train Loss: {train_loss.item():.4f}")
batch_loop.update(1)

validation_loss = evaluate_model(model, validation_dataloader, accelerator.device, loss_fn)
train_losses.append(torch.mean(torch.Tensor(epoch_losses)))
validation_losses.append(validation_loss)

if validation_loss < best_loss:
save_model_checkpoint(
accelerator,
best_checkpoint_path,
model,
optimizer,
lr_scheduler,
epoch,
train_history,
validation_loss,
train_losses,
validation_losses,
loss_config,
model_descriptor=model_descriptor,
)
best_loss = validation_loss

epoch_loop.set_description(f"Best Loss: {best_loss:.4f}")
epoch_loop.update(1)
epoch += 1

微調結果分析

對于艾爾瓦河項目,最佳設置是在不到12小時的時間內使用GCP虛擬機實例,使用超過1k個分割掩碼的數據集訓練成功“sam-vit-base”模型。

與基準型SAM相比,微調顯著提高了性能,中值掩碼從不可用變為高度準確。

Meta公司開源大數據模型SAM實戰演練-AI.x社區

相對于基于默認提示詞的基準型SAM模型,微調后的SAM模型極大地提高了分割性能

需要注意的一個重要事實是,1k河流圖像的訓練數據集是不完美的,分割標簽在正確分類的像素數量上變化很大。因此,上述指標是在225幅河流圖像的像素完美數據集上計算出來的。

實驗過程中,我們觀察到的一個有趣的行為是,模型學會了從不完美的訓練數據中進行歸納。當在訓練樣本包含明顯錯誤分類的數據點上進行評估時,我們可以觀察到模型預測避免了誤差。請注意,顯示訓練樣本的頂行中的圖像包含的掩碼不會一直填充到河岸,而顯示模型預測的底行則更緊密地分割河流邊界。

Meta公司開源大數據模型SAM實戰演練-AI.x社區

即使訓練數據不完美,經微調的SAM模型也能帶來令人印象深刻的泛化效果。請注意,與訓練數據(頂行)相比,預測(底行)的錯誤分類更少,并且河流的填充程度更高。

結論

如果您已經順利完成本文中的實例內容,那么祝賀您!您已經學會了為任何下游愿景任務完全微調Meta的分割一切模型SAM所需的一切!

雖然您的微調工作流程無疑與本教程中介紹的實施方式不同,但從閱讀本教程中獲得的知識不僅會影響到您的細分項目,還會影響到未來的深度學習項目及其他項目。

最后,希望您繼續探索機器學習的世界,保持好奇心,并一如既往地快樂編程!

附錄

本文實例中使用的數據集是Elwha V1數據集(https://huggingface.co/datasets/stodoran/elwha-segmentation-v1),該數據集由華盛頓大學的GeoSMART研究實驗室(https://geo-smart.github.io/)創建,用于將微調的大型視覺變換器應用于地理空間分割任務的研究項目。本文描述的內容代表了即將發表的論文的精簡版和一個更易于實現的版本。在高水平上,Elwha V1數據集由SAM檢查點的后處理模型預測組成,該檢查點使用Buscombe等人(https://zenodo.org/records/10155783)發布并在多學科研究數據知識庫和文獻資源網站Zenodo上發布的標注正射影像的子集進行了微調。

譯者介紹

朱先忠,51CTO社區編輯,51CTO專家博客、講師,濰坊一所高校計算機教師,自由編程界老兵一枚。

原文標題:Learn Transformer Fine-Tuning and Segment Anything,作者:Stefan Todoran。

鏈接:??https://towardsdatascience.com/learn-transformer-fine-tuning-and-segment-anything-481c6c4ac802?。

?著作權歸作者所有,如需轉載,請注明出處,否則將追究法律責任
收藏
回復
舉報
回復
相關推薦
主站蜘蛛池模板: 中文字幕在线一区 | 国产精品日韩欧美一区二区三区 | 久久国产精品-国产精品 | 精产国产伦理一二三区 | 亚洲精品在线播放 | 欧美激情久久久 | 韩国毛片视频 | 中文二区 | 久久6| 亚洲精品久久久一区二区三区 | 亚洲最新网址 | 91网在线观看 | 日韩毛片免费看 | 欧美网站一区 | 免费播放一级片 | 精品综合久久 | 一区二区三区四区在线 | 亚洲二区视频 | 午夜精品久久久久久久星辰影院 | 超碰在线国产 | 91视视频在线观看入口直接观看 | 亚洲视频欧美视频 | 亚洲成人日韩 | 欧美一区二区三区在线观看 | 成人免费视频网址 | 91精品国产综合久久久亚洲 | 久久久精品久久久 | 国产成人高清视频 | 欧美日韩专区 | 成人久久18免费网站图片 | 国产精品久久久久久av公交车 | 亚洲中午字幕 | 成人自拍视频 | 中文字幕一区二区三区精彩视频 | 色视频网站在线观看 | 91av在线电影 | 国产一级片一区二区 | 日韩一区中文字幕 | 成年免费大片黄在线观看岛国 | 颜色网站在线观看 | 欧美a级网站 |