import torch
import torch.nn as nn
import torch.nn.functional as F
from models.image_models import VisionTransformer


class MultiModalFusionModel(nn.Module):
    """Multimodal fusion model that fuses DIFF and WNB scatter-plot features.

    Each modality is encoded by its own ``VisionTransformer``. The two feature
    vectors are concatenated, projected back to ``embed_dim`` by a small
    fusion MLP, and classified by a linear head.

    Args:
        img_size: Input image size passed to each ``VisionTransformer``.
        patch_size: Patch size passed to each ``VisionTransformer``.
        in_channels: Number of input image channels.
        embed_dim: Encoder embedding dimension; also the fused feature size.
        depth: Number of transformer layers per encoder.
        num_heads: Number of attention heads per encoder.
        dropout: Dropout probability used in the encoders and the fusion MLP.
        num_classes: Number of output classes.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1,
                 num_classes=2):
        super().__init__()

        # Both encoders share the same hyper-parameters but NOT weights.
        vit_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout,
        )

        # NOTE: submodule creation order and attribute names are kept stable
        # because they determine state_dict keys and the RNG stream consumed
        # during weight initialisation (checkpoint / reproducibility compat).

        # Feature extractor for DIFF scatter plots.
        self.diff_encoder = VisionTransformer(**vit_kwargs)

        # Feature extractor for WNB scatter plots.
        self.wnb_encoder = VisionTransformer(**vit_kwargs)

        # Fusion layer: project concatenated [2*E] features back down to [E].
        self.fusion = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Classification head.
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, diff_img, wnb_img):
        """Run the forward pass.

        Args:
            diff_img: DIFF scatter-plot batch, [B, C, H, W].
            wnb_img: WNB scatter-plot batch, [B, C, H, W].

        Returns:
            Classification logits, [B, num_classes] (assuming the encoders
            return pooled [B, E] features).
        """
        features = self.extract_features(diff_img, wnb_img)

        # Concatenate along the last (embedding) axis. For pooled [B, E]
        # features this is identical to dim=1, but it does not silently
        # concatenate along a token axis if the encoder output has extra dims.
        combined = torch.cat([features['diff'], features['wnb']], dim=-1)  # [B, 2*E]
        fused = self.fusion(combined)  # [B, E]

        return self.classifier(fused)  # [B, num_classes]

    def extract_features(self, diff_img, wnb_img):
        """Encode each modality separately, e.g. for analysis or visualisation.

        Args:
            diff_img: DIFF scatter-plot batch, [B, C, H, W].
            wnb_img: WNB scatter-plot batch, [B, C, H, W].

        Returns:
            dict with keys ``'diff'`` and ``'wnb'`` holding the raw outputs of
            the respective encoders (expected to be [B, E] — TODO confirm
            against ``VisionTransformer``).
        """
        diff_features = self.diff_encoder(diff_img)
        wnb_features = self.wnb_encoder(wnb_img)
        return {
            'diff': diff_features,
            'wnb': wnb_features,
        }


class SingleModalModel(nn.Module):
    """Single-modality baseline model, used for comparison experiments.

    A single ``VisionTransformer`` encoder followed directly by a linear
    classification head (no fusion MLP).

    Args:
        img_size: Input image size passed to the ``VisionTransformer``.
        patch_size: Patch size passed to the ``VisionTransformer``.
        in_channels: Number of input image channels.
        embed_dim: Encoder embedding dimension.
        depth: Number of transformer layers.
        num_heads: Number of attention heads.
        dropout: Dropout probability used in the encoder.
        num_classes: Number of output classes.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1,
                 num_classes=2):
        super().__init__()

        # Image feature extractor.
        self.encoder = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout,
        )

        # Classification head.
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, img):
        """Run the forward pass.

        Args:
            img: Input image batch, [B, C, H, W].

        Returns:
            Classification logits, [B, num_classes] (assuming the encoder
            returns pooled [B, E] features).
        """
        features = self.extract_features(img)  # [B, E]
        return self.classifier(features)  # [B, num_classes]

    def extract_features(self, img):
        """Return the raw encoder output for ``img`` (analysis hook).

        Provided for parity with ``MultiModalFusionModel.extract_features`` so
        the baseline's features can be inspected the same way.
        """
        return self.encoder(img)