import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each one."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        # A Conv2d with kernel_size == stride == patch_size projects each
        # patch to an embed_dim-dimensional vector in a single pass.
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, C, H, W]
        x = self.proj(x)       # [B, E, H/P, W/P]
        x = x.flatten(2)       # [B, E, (H/P)*(W/P)]
        x = x.transpose(1, 2)  # [B, (H/P)*(W/P), E]
        return x


class VisionTransformer(nn.Module):
    """Transformer-based image feature extractor."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1):
        super().__init__()
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.n_patches = self.patch_embed.n_patches

        # Learnable position embedding (+1 slot for the CLS token) and CLS token
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Transformer encoder
        encoder_layer = TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = TransformerEncoder(encoder_layer, num_layers=depth)

        # Final layer normalization
        self.norm = nn.LayerNorm(embed_dim)

        # Initialization
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        # x: [B, C, H, W]
        batch_size = x.shape[0]

        # Patch embedding: [B, N, E]
        x = self.patch_embed(x)

        # Prepend the CLS token
        cls_token = self.cls_token.expand(batch_size, -1, -1)  # [B, 1, E]
        x = torch.cat([cls_token, x], dim=1)                   # [B, N+1, E]

        # Add the position embedding
        x = x + self.pos_embed

        # Transformer encoder
        x = self.transformer(x)

        # Normalize, then take the CLS token as the image-level feature
        x = self.norm(x)
        x = x[:, 0]  # [B, E]
        return x
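
# --- Usage sketch (illustrative, not part of the model definition) ---
# A minimal smoke test, assuming a batch of random 224x224 RGB images; the
# hyperparameters below are just the class defaults, not tuned values.
if __name__ == "__main__":
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768,
                              depth=6, num_heads=12)
    images = torch.randn(2, 3, 224, 224)  # [B, C, H, W]
    features = model(images)
    print(features.shape)                 # torch.Size([2, 768])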