成人免费xxxxx在线视频软件_久久精品久久久_亚洲国产精品久久久_天天色天天色_亚洲人成一区_欧美一级欧美三级在线观看

從幾個“補丁”中重建完整圖像 | 構(gòu)建可擴展學習器的掩模自編碼器

開發(fā)
在這個視覺transformer系列的這一部分,我將使用PyTorch從零開始構(gòu)建掩模自編碼器視覺transformer。

到目前為止,我們已經(jīng)詳細轉(zhuǎn)換了各種重要的ViT架構(gòu)。在這個視覺transformer系列的這一部分,我將使用PyTorch從零開始構(gòu)建掩模自編碼器視覺transformer。不再拖延,讓我們直接進入主題!

掩模自編碼器

Mae是一種自監(jiān)督學習方法,這意味著它沒有預先標記的目標數(shù)據(jù),而是在訓練時利用輸入數(shù)據(jù)。這種方法主要涉及遮蔽圖像的75%的補丁。因此,在創(chuàng)建補丁(H/補丁大小,W/補丁大小)之后,其中H和W是圖像的高度和寬度,我們遮蔽75%的補丁,只使用其余的補丁并將其輸入到標準的ViT中。這里的主要目標是僅使用圖像中已知的補丁重建缺失的補丁。

輸入(75%的補丁被遮蔽) | 目標(重建缺失的像素)

MAE主要包含這三個組件:

  • 隨機遮蔽
  • 編碼器
  • 解碼器

1.隨機掩蓋

這就像選擇圖像的隨機補丁,然后掩蓋其中的3/4一樣簡單。然而,官方實現(xiàn)使用了不同但更有效的技術(shù)。


def random_masking(x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """

        B, T, D = x.shape  
        len_keep = int(T * (1 - mask_ratio))

        # creating noise of shape (B, T) to latter generate random indices
        noise = torch.rand(B, T, device=x.device)  

        # sorting the noise, and then ids_shuffle to keep the original indexe format
        ids_shuffle = torch.argsort(noise, dim=1)  
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # gathering the first few samples
        ids_keep = ids_shuffle[:, :len_keep]
        x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([B, T], device=x.device)
        mask[:, :len_keep] = 0 

        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x, mask, ids_restore
  • 假設(shè)輸入形狀是(B, T, C)。這里我們首先創(chuàng)建一個形狀為(B, T)的隨機張量,然后將其傳遞給argsort,這將為我們提供一個排序的索引張量——例如,torch.argsort([0.3, 0.4, 0.2]) = [2, 0, 1]。
  • 我們還將ids_shuffle傳遞給另一個argsort以獲取ids_restore。這只是一個具有原始索引格式的張量。
  • 接下來,我們收集我們想要保留的標記。
  • 生成二進制掩模,并將要保留的標記標記為0,其余標記為1。
  • 最后,對掩模進行解洗牌,這里我們創(chuàng)建的ids_restore將有助于生成表示,掩模應該具有的。即哪些索引的標記被遮蔽為0或1,與原始輸入有關(guān)?

注意:與在隨機位置創(chuàng)建隨機補丁不同,官方實現(xiàn)使用了不同的技術(shù)。

為圖像生成隨機索引。就像我們在ids_shuffle中所做的那樣。然后獲取前25%的索引(int(T*(1–3/4))或int(T/4)。我們只使用前25%的隨機索引并遮蔽其余部分。

然后我們用ids_restore中原始索引的順序幫助對掩模進行重新排序(解洗牌)。因此,在收集之前,掩模的前25%為0。但記住這些是隨機索引,這就是為什么我們重新排序以獲得掩模應該在的確切索引。

2.編碼器


class MaskedAutoEncoder(nn.Module):
    def __init__(self, emb_size=1024, decoder_emb_size=512, patch_size=16, num_head=16, encoder_num_layers=24, decoder_num_layers=8, in_channels=3, img_size=224):
        super().__init__()      
        self.patch_embed = PatchEmbedding(emb_size = emb_size)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_emb_size))
        self.encoder_transformer = nn.Sequential(*[Block(emb_size, num_head) for _ in range(encoder_num_layers)])

    def encoder(self, x, mask_ratio):
        x = self.patch_embed(x)

        cls_token = x[:, :1, :]
        x = x[:, 1:, :] 

        x, mask, restore_id = random_masking(x, mask_ratio)

        x = torch.cat((cls_token, x), dim=1)

        x = self.encoder_transformer(x)

        return x, mask, restore_id

PatchEmbedding和Block是ViT模型中的標準實現(xiàn)。

我們首先獲取圖像的補丁嵌入(B, C, H, W)→(B, T, C),這里的PatchEmbedding實現(xiàn)還返回連接在嵌入張量x中的cls_token。如果你想使用timm庫獲取標準的PatchEmbed和Block,也可以這樣做,但這個實現(xiàn)效果相同。即from timm.models.vision_transformer import PatchEmbed, Block

由于我們已經(jīng)有了cls_token,我們首先想要移除它,然后將其傳遞以生成遮蔽。x:(B K C),掩模:(B T)restore_id(B T),其中K是我們保留的標記的長度,即T/4。

然后我們將cls_token連接起來并傳遞給標準的編碼器_transformer。

3.器

解碼階段涉及將輸入嵌入維度更改為decoder_embedding_size。回想一下,輸入維度是(B, K, C),其中K是T/4。因此我們將未遮蔽的補丁與遮蔽的補丁連接起來,然后將它們輸入到另一個視覺transformer模型(解碼器)中,如圖1所示。

class MaskedAutoEncoder(nn.Module):
    def __init__(self, emb_size=1024, decoder_emb_size=512, patch_size=16, num_head=16, encoder_num_layers=24, decoder_num_layers=8, in_channels=3, img_size=224):
        super().__init__()      
        self.patch_embed = PatchEmbedding(emb_size = emb_size)
        self.decoder_embed = nn.Linear(emb_size, decoder_emb_size)
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, (img_size//patch_size)**2 + 1, decoder_emb_size), requires_grad=False)
        self.decoder_pred = nn.Linear(decoder_emb_size, patch_size**2 * in_channels, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_emb_size))
        self.encoder_transformer = nn.Sequential(*[Block(emb_size, num_head) for _ in range(encoder_num_layers)])
        self.decoder_transformer = nn.Sequential(*[Block(decoder_emb_size, num_head) for _ in range(decoder_num_layers)])
        self.project = self.projection = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=patch_size**2 * in_channels, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def encoder(self, x, mask_ratio):
        x = self.patch_embed(x)

        cls_token = x[:, :1, :]
        x = x[:, 1:, :] 

        x, mask, restore_id = random_masking(x, mask_ratio)

        x = torch.cat((cls_token, x), dim=1)

        x = self.encoder_transformer(x)

        return x, mask, restore_id

    def decoder(self, x, restore_id):

        x = self.decoder_embed(x)

        mask_tokens = self.mask_token.repeat(x.shape[0], restore_id.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) 
        x_ = torch.gather(x_, dim=1, index=restore_id.unsqueeze(-1).repeat(1, 1, x.shape[2]))  
        x = torch.cat([x[:, :1, :], x_], dim=1)  

        # add pos embed
        x = x + self.decoder_pos_embed

        x = self.decoder_transformer(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x 

我們將輸入傳遞給decoder_embed。然后我們?yōu)樗形覀冋诒蔚臉擞泟?chuàng)建mask_tokens,并將其與原始輸入x連接起來,不包括其cls_token。

現(xiàn)在張量具有前K個未遮蔽的標記,其余為遮蔽的標記,但現(xiàn)在我們想要按照索引的確切順序重新排序它們。我們可以借助ids_restore來實現(xiàn)。

現(xiàn)在ids_restore具有索引,當傳遞給torch.gather時,將對輸入進行解洗牌。因此,我們在隨機遮蔽中選擇的未遮蔽標記(ids_shuffle中的前幾個隨機索引)現(xiàn)在被重新排列在它們應該在的確切順序中。稍后我們再次將cls_token與重新排序的補丁連接起來。

現(xiàn)在我們將整個輸入傳遞給標準的視覺transformer,并移除cls_token并返回張量x以計算損失。

損失函數(shù)

掩模自編碼器在遮蔽和未遮蔽的補丁上進行訓練,并學習重建圖像中的遮蔽補丁。掩模自編碼器視覺transformer中使用的損失函數(shù)是均方誤差。


class MaskedAutoEncoder(nn.Module):
    def __init__(self, emb_size=1024, decoder_emb_size=512, patch_size=16, num_head=16, encoder_num_layers=24, decoder_num_layers=8, in_channels=3, img_size=224):
        super().__init__()      
        self.patch_embed = PatchEmbedding(emb_size = emb_size)
        self.decoder_embed = nn.Linear(emb_size, decoder_emb_size)
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, (img_size//patch_size)**2 + 1, decoder_emb_size), requires_grad=False)
        self.decoder_pred = nn.Linear(decoder_emb_size, patch_size**2 * in_channels, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_emb_size))
        self.encoder_transformer = nn.Sequential(*[Block(emb_size, num_head) for _ in range(encoder_num_layers)])
        self.decoder_transformer = nn.Sequential(*[Block(decoder_emb_size, num_head) for _ in range(decoder_num_layers)])
        self.project = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=patch_size**2 * in_channels, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def random_masking(x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """

        B, T, D = x.shape  
        len_keep = int(T * (1 - mask_ratio))

        # creating noise of shape (B, T) to latter generate random indices
        noise = torch.rand(B, T, device=x.device)  

        # sorting the noise, and then ids_shuffle to keep the original indexe format
        ids_shuffle = torch.argsort(noise, dim=1)  
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # gathering the first few samples
        ids_keep = ids_shuffle[:, :len_keep]
        x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([B, T], device=x.device)
        mask[:, :len_keep] = 0 

        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x, mask, ids_restore

    def encoder(self, x, mask_ratio):
        x = self.patch_embed(x)

        cls_token = x[:, :1, :]
        x = x[:, 1:, :] 

        x, mask, restore_id = self.random_masking(x, mask_ratio)

        x = torch.cat((cls_token, x), dim=1)

        x = self.encoder_transformer(x)

        return x, mask, restore_id

    def decoder(self, x, restore_id):

        x = self.decoder_embed(x)

        mask_tokens = self.mask_token.repeat(x.shape[0], restore_id.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) 
        x_ = torch.gather(x_, dim=1, index=restore_id.unsqueeze(-1).repeat(1, 1, x.shape[2]))  
        x = torch.cat([x[:, :1, :], x_], dim=1)  

        # add pos embed
        x = x + self.decoder_pos_embed

        x = self.decoder_transformer(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, patch*patch*3]
        mask: [N, L], 0 is keep, 1 is remove, 
        """
        target = self.project(imgs)

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, img):
        mask_ratio = 0.75

        x, mask, restore_ids = self.encoder(img, mask_ratio)
        pred = self.decoder(x, restore_ids) 
        loss  = self.loss(img, pred, mask) 
        return loss, pred, mask

在未遮蔽的補丁上訓練視覺transformer模型,將未遮蔽補丁的輸出與遮蔽補丁重新排序。

在遮蔽和未遮蔽的補丁結(jié)合在一起的原始形式上訓練視覺transformer模型。

計算解碼器預測輸出的最后一個維度(B, T, decoder embed)和圖像的原始補丁嵌入(B, T, patch embedding)之間的均方誤差損失。

源碼:https://github.com/mishra-18/ML-Models/blob/main/Vission Transformers/mae.py

責任編輯:趙寧寧 來源: 小白玩轉(zhuǎn)Python
相關(guān)推薦

2021-03-22 10:52:13

人工智能深度學習自編碼器

2021-03-29 11:37:50

人工智能深度學習

2024-06-18 08:52:50

LLM算法深度學習

2022-04-02 21:46:27

深度學習編碼器圖像修復

2025-04-10 11:52:55

2021-02-20 20:57:16

深度學習編程人工智能

2024-10-21 16:47:56

2017-07-19 13:40:42

卷積自編碼器降噪

2017-07-03 07:14:49

深度學習無監(jiān)督學習稀疏編碼

2017-12-26 10:48:37

深度學習原始數(shù)據(jù)

2017-11-10 12:45:16

TensorFlowPython神經(jīng)網(wǎng)絡(luò)

2017-05-08 22:40:55

深度學習自編碼器對抗網(wǎng)絡(luò)

2022-09-13 15:26:40

機器學習算法數(shù)據(jù)

2020-04-26 11:26:02

人臉合成編碼器數(shù)據(jù)

2025-04-10 06:30:00

2021-11-02 20:44:47

數(shù)字化

2018-05-21 08:22:14

自編碼器協(xié)同過濾深度學習

2012-04-01 16:40:45

編碼器

2012-04-10 16:55:22

PowerSmart編碼器

2025-04-11 00:16:00

模態(tài)編碼器MAECLIP
點贊
收藏

51CTO技術(shù)棧公眾號

主站蜘蛛池模板: 超碰男人天堂 | 国产精品久久久久久网站 | 精品国产乱码久久久久久闺蜜 | 国产成人综合一区二区三区 | 日本a∨视频 | 中文字幕精品一区 | 亚洲一区二区免费视频 | 91手机精品视频 | 日韩在线中文 | 国产精品爱久久久久久久 | 国产精品日韩欧美一区二区 | www.97zyz.com| 97精品国产 | 天天躁日日躁狠狠躁白人 | 伊人免费观看视频 | 91视频一区 | 久久小视频 | 亚洲精品小视频在线观看 | 成人欧美一区二区三区 | 亚州成人| 中文在线а√在线8 | 天堂久久网 | 国产一级电影在线观看 | 欧美激情一区二区三区 | 欧美综合一区二区三区 | 日韩精品国产精品 | 九一在线观看 | 中文字幕日韩三级 | 在线观看国产视频 | 欧美精品在线一区二区三区 | 九九热这里只有精品6 | 亚洲精品一区中文字幕乱码 | 伊人久久大香线 | 中文字幕一区二区三区四区五区 | 亚洲成人av一区二区 | 成人网av| 亚洲免费人成在线视频观看 | 免费观看一级毛片 | 国产亚洲精品a | 国产亚洲精品一区二区三区 | 亚洲欧美高清 |