116 lines
3.4 KiB
Python
116 lines
3.4 KiB
Python
|
|
import torch
|
|||
|
|
import torch.nn as nn
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
from models.image_models import VisionTransformer
|
|||
|
|
|
|||
|
|
class MultiModalFusionModel(nn.Module):
    """Multi-modal fusion model combining DIFF and WNB scatter-plot features.

    Each modality is encoded by its own ``VisionTransformer`` (weights are not
    shared). The two feature vectors are concatenated, projected back to
    ``embed_dim`` by a small fusion MLP, and classified by a linear head.

    Args:
        img_size: Input image side length, forwarded to both encoders.
        patch_size: Patch side length, forwarded to both encoders.
        in_channels: Number of input image channels, forwarded to both encoders.
        embed_dim: Feature width produced by each encoder and by the fusion
            layer.
        depth: Encoder depth, forwarded to both encoders.
        num_heads: Number of attention heads, forwarded to both encoders.
        dropout: Dropout probability used by the encoders and the fusion layer.
        num_classes: Number of output classes of the classification head.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()

        # A single configuration is used for both encoders so the two branches
        # cannot silently drift apart when a hyper-parameter is changed.
        encoder_config = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout,
        )

        # Feature extractor for the DIFF scatter plot.
        self.diff_encoder = VisionTransformer(**encoder_config)

        # Feature extractor for the WNB scatter plot.
        self.wnb_encoder = VisionTransformer(**encoder_config)

        # Fusion layer: maps the concatenated [B, 2*E] features back to [B, E].
        self.fusion = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Classification head.
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, diff_img: torch.Tensor, wnb_img: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Args:
            diff_img: DIFF scatter plot batch, [B, C, H, W].
            wnb_img: WNB scatter plot batch, [B, C, H, W].

        Returns:
            Classification logits, [B, num_classes].
        """
        # Reuse extract_features so the features exposed for analysis are
        # exactly the ones the classifier consumes.
        features = self.extract_features(diff_img, wnb_img)

        # Concatenate along the feature axis: [B, E] + [B, E] -> [B, 2*E].
        combined_features = torch.cat([features['diff'], features['wnb']], dim=1)
        fused_features = self.fusion(combined_features)  # [B, E]

        return self.classifier(fused_features)  # [B, num_classes]

    def extract_features(self, diff_img: torch.Tensor, wnb_img: torch.Tensor) -> dict:
        """Extract the per-modality features, e.g. for analysis.

        Args:
            diff_img: DIFF scatter plot batch, [B, C, H, W].
            wnb_img: WNB scatter plot batch, [B, C, H, W].

        Returns:
            Dict with keys ``'diff'`` and ``'wnb'`` holding the output of the
            respective encoder (expected to be [B, E] each).
        """
        return {
            'diff': self.diff_encoder(diff_img),
            'wnb': self.wnb_encoder(wnb_img),
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SingleModalModel(nn.Module):
    """Single-modality baseline model, used for comparison experiments.

    One ``VisionTransformer`` encodes the input image and a linear head maps
    the resulting features to class logits.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()

        # Image feature extractor; every hyper-parameter is passed by keyword.
        vit_kwargs = {
            'img_size': img_size,
            'patch_size': patch_size,
            'in_channels': in_channels,
            'embed_dim': embed_dim,
            'depth': depth,
            'num_heads': num_heads,
            'dropout': dropout,
        }
        self.encoder = VisionTransformer(**vit_kwargs)

        # Classification head on top of the encoder features.
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, img):
        """Run the forward pass.

        Args:
            img: Input image batch, [B, C, H, W].

        Returns:
            Classification logits, [B, num_classes].
        """
        # Encode to [B, E], then classify to [B, num_classes].
        return self.classifier(self.encoder(img))
|