import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import torch.nn.functional as F
import torchvision.models as models


class PatchEmbedding(nn.Module):
    """Split an image into patches and embed each patch."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        # Non-overlapping patch projection: a conv whose kernel and stride both equal patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, C, H, W]
        batch_size = x.shape[0]
        x = self.proj(x)  # [B, E, H/P, W/P]
        x = x.flatten(2)  # [B, E, (H/P)*(W/P)]
        x = x.transpose(1, 2)  # [B, (H/P)*(W/P), E]
        return x

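# Shape example: with the defaults (img_size=224, patch_size=16, embed_dim=768),
# an input of [B, 3, 224, 224] yields (224 // 16) ** 2 = 196 patch tokens,
# i.e. PatchEmbedding outputs [B, 196, 768].
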
class VisionTransformer(nn.Module):
    """Transformer-based image feature extractor."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1):
        super().__init__()

        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.n_patches = self.patch_embed.n_patches

        # Learned position embedding (patches + CLS token) and CLS token
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Transformer encoder
        encoder_layer = TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = TransformerEncoder(encoder_layer, num_layers=depth)

        # Layer normalization
        self.norm = nn.LayerNorm(embed_dim)

        # Initialization
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        # x: [B, C, H, W]
        batch_size = x.shape[0]

        # Patch embedding: [B, N, E]
        x = self.patch_embed(x)

        # Prepend the CLS token
        cls_token = self.cls_token.expand(batch_size, -1, -1)  # [B, 1, E]
        x = torch.cat([cls_token, x], dim=1)  # [B, N+1, E]

        # Add the position embedding
        x = x + self.pos_embed

        # Transformer encoder
        x = self.transformer(x)

        # Final layer norm over the token embeddings
        x = self.norm(x)

        # Take the CLS token as the feature for the whole image
        x = x[:, 0]  # [B, E]

        return x
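

# Minimal usage sketch (illustrative only; shapes follow from the defaults above).
if __name__ == "__main__":
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768,
                              depth=6, num_heads=12)
    dummy = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
    with torch.no_grad():
        features = model(dummy)
    print(features.shape)  # expected: torch.Size([2, 768])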