# 78 lines, 2.6 KiB, Python
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
|
import torch.nn.functional as F
|
|
import torchvision.models as models
|
|
|
|
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    The projection is a single ``nn.Conv2d`` whose kernel size and stride both
    equal the patch size, so every output position corresponds to one patch.

    Args:
        img_size: Expected input size, either one int (square image) or an
            ``(height, width)`` pair. Only used to compute ``n_patches``.
        patch_size: Patch size, either one int (square patch) or a
            ``(height, width)`` pair.
        in_channels: Number of channels of the input image.
        embed_dim: Size of the embedding produced for each patch.

    Attributes set here: ``img_size`` and ``patch_size`` (stored exactly as
    passed), ``n_patches`` (patches per image) and ``proj`` (the convolution).
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        img_h, img_w = self._as_pair(img_size)
        patch_h, patch_w = self._as_pair(patch_size)

        # Keep the caller's original values so existing readers of these
        # attributes see exactly what they passed in.
        self.img_size = img_size
        self.patch_size = patch_size
        # For a square image/patch this equals (img_size // patch_size) ** 2.
        self.n_patches = (img_h // patch_h) * (img_w // patch_w)

        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=(patch_h, patch_w),
            stride=(patch_h, patch_w),
        )

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as a ``(height, width)`` pair, duplicating a scalar."""
        try:
            first, second = value
        except TypeError:
            # Not iterable: a single number describing a square size.
            return value, value
        return first, second

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed the patches of ``x``.

        Args:
            x: Image batch of shape ``[B, C, H, W]``.

        Returns:
            Patch embeddings of shape ``[B, (H/P)*(W/P), E]``.
        """
        x = self.proj(x)       # [B, E, H/P, W/P]
        x = x.flatten(2)       # [B, E, (H/P)*(W/P)]
        return x.transpose(1, 2)  # [B, (H/P)*(W/P), E]
|
|
|
|
|
|
class VisionTransformer(nn.Module):
    """Transformer-based image feature extractor.

    The image is turned into patch embeddings by ``PatchEmbedding``, a
    learnable CLS token is prepended, learnable position embeddings are added,
    and the sequence is run through a ``TransformerEncoder``. The
    layer-normalised CLS token is returned as the feature of the whole image.

    Args:
        img_size: Input image size, forwarded to ``PatchEmbedding``.
        patch_size: Patch size, forwarded to ``PatchEmbedding``.
        in_channels: Number of input image channels.
        embed_dim: Token embedding size (``d_model`` of the encoder layers).
        depth: Number of stacked encoder layers.
        num_heads: Number of attention heads in each encoder layer.
        dropout: Dropout value passed to each encoder layer.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1):
        super().__init__()

        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.n_patches = self.patch_embed.n_patches

        # Learnable position embedding (one slot per patch plus one for the
        # CLS token) and the CLS token itself.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Transformer encoder
        encoder_layer = TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = TransformerEncoder(encoder_layer, num_layers=depth)

        # Final layer normalisation, applied to the CLS feature in forward().
        self.norm = nn.LayerNorm(embed_dim)

        # Initialisation
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        """Extract one feature vector per image.

        Args:
            x: Image batch of shape ``[B, C, H, W]``. The number of patches it
                yields must match ``self.n_patches``; otherwise adding the
                position embedding fails.

        Returns:
            Tensor of shape ``[B, E]``: the layer-normalised CLS token.
        """
        # Patch embedding: [B, N, E]
        tokens = self.patch_embed(x)

        # Prepend the CLS token: [B, 1, E] -> [B, N+1, E]
        cls_token = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([cls_token, tokens], dim=1)

        # Add the position embedding (broadcast over the batch dimension).
        tokens = tokens + self.pos_embed

        # Transformer encoder
        tokens = self.transformer(tokens)

        # Use the CLS token as the feature of the whole image. The final
        # LayerNorm was previously constructed but never applied; LayerNorm
        # acts on each token independently, so normalising only the selected
        # CLS token equals normalising the full sequence and then selecting.
        return self.norm(tokens[:, 0])  # [B, E]