import torch
import torch.nn as nn
import torch.nn.functional as F

from models.image_models import VisionTransformer


class MultiModalFusionModel(nn.Module):
    """Multi-modal fusion model that fuses features from DIFF and WNB scatter plots."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()

        # Feature extractor for DIFF scatter plots
        self.diff_encoder = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout
        )

        # Feature extractor for WNB scatter plots
        self.wnb_encoder = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout
        )

        # Feature fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Classification head
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, diff_img, wnb_img):
        """
        Forward pass.

        Args:
            diff_img: DIFF scatter plot [B, C, H, W]
            wnb_img: WNB scatter plot [B, C, H, W]

        Returns:
            logits: classification logits [B, num_classes]
        """
        # Extract per-modality features
        diff_features = self.diff_encoder(diff_img)  # [B, E]
        wnb_features = self.wnb_encoder(wnb_img)     # [B, E]

        # Fuse the two feature vectors
        combined_features = torch.cat([diff_features, wnb_features], dim=1)  # [B, 2*E]
        fused_features = self.fusion(combined_features)  # [B, E]

        # Classify
        logits = self.classifier(fused_features)  # [B, num_classes]

        return logits

    def extract_features(self, diff_img, wnb_img):
        """Extract the features of each modality separately, for analysis."""
        diff_features = self.diff_encoder(diff_img)
        wnb_features = self.wnb_encoder(wnb_img)

        return {
            'diff': diff_features,
            'wnb': wnb_features
        }


class SingleModalModel(nn.Module):
    """Single-modality model, used as a baseline for comparison experiments."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()

        # Image feature extractor
        self.encoder = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout
        )

        # Classification head
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, img):
        """
        Forward pass.

        Args:
            img: input image [B, C, H, W]

        Returns:
            logits: classification logits [B, num_classes]
        """
        # Extract features
        features = self.encoder(img)  # [B, E]

        # Classify
        logits = self.classifier(features)  # [B, num_classes]

        return logits
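

# A minimal smoke-test sketch showing how these models could be exercised with
# random inputs. It assumes models.image_models.VisionTransformer is importable
# and returns pooled [B, embed_dim] features, as the shape comments above
# indicate; the batch size, shallow depth, and printed shapes below are
# illustrative choices, not part of the original file.
if __name__ == "__main__":
    batch_size = 2
    diff_img = torch.randn(batch_size, 3, 224, 224)
    wnb_img = torch.randn(batch_size, 3, 224, 224)

    # Shallow encoders keep the check fast; defaults would use depth=6.
    fusion_model = MultiModalFusionModel(depth=2)
    single_model = SingleModalModel(depth=2)

    with torch.no_grad():
        fusion_logits = fusion_model(diff_img, wnb_img)
        per_modality = fusion_model.extract_features(diff_img, wnb_img)
        single_logits = single_model(diff_img)

    print(fusion_logits.shape)         # expected: torch.Size([2, 2])
    print(per_modality['diff'].shape)  # expected: torch.Size([2, 768])
    print(single_logits.shape)         # expected: torch.Size([2, 2])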