import torch
import torch.nn as nn
import torch.nn.functional as F
from models.image_models import VisionTransformer


class MultiModalFusionModel(nn.Module):
    """Multi-modal fusion model that fuses features from DIFF and WNB scatter plots."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()

        # Feature extractor for DIFF scatter plots
        self.diff_encoder = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout
        )

        # Feature extractor for WNB scatter plots
        self.wnb_encoder = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout
        )

        # Feature fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Classification head
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, diff_img, wnb_img):
        """
        Forward pass.

        Args:
            diff_img: DIFF scatter plot [B, C, H, W]
            wnb_img: WNB scatter plot [B, C, H, W]

        Returns:
            logits: classification logits [B, num_classes]
        """
        # Extract per-modality features
        diff_features = self.diff_encoder(diff_img)  # [B, E]
        wnb_features = self.wnb_encoder(wnb_img)  # [B, E]

        # Fuse the two feature vectors
        combined_features = torch.cat([diff_features, wnb_features], dim=1)  # [B, 2*E]
        fused_features = self.fusion(combined_features)  # [B, E]

        # Classify
        logits = self.classifier(fused_features)  # [B, num_classes]

        return logits

    def extract_features(self, diff_img, wnb_img):
        """Extract the features of each modality separately, for analysis."""
        diff_features = self.diff_encoder(diff_img)
        wnb_features = self.wnb_encoder(wnb_img)

        return {
            'diff': diff_features,
            'wnb': wnb_features
        }


class SingleModalModel(nn.Module):
    """Single-modality model used for comparison (ablation) experiments."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()

        # Image feature extractor
        self.encoder = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout
        )

        # Classification head
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, img):
        """
        Forward pass.

        Args:
            img: input image [B, C, H, W]

        Returns:
            logits: classification logits [B, num_classes]
        """
        # Extract features
        features = self.encoder(img)  # [B, E]

        # Classify
        logits = self.classifier(features)  # [B, num_classes]

        return logits
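

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline):
# builds both models on dummy inputs and checks output shapes. It assumes
# VisionTransformer returns a pooled [B, embed_dim] feature vector, as the
# shape comments above indicate; batch size and image size are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size = 4
    diff_img = torch.randn(batch_size, 3, 224, 224)  # dummy DIFF scatter-plot batch
    wnb_img = torch.randn(batch_size, 3, 224, 224)   # dummy WNB scatter-plot batch

    fusion_model = MultiModalFusionModel(num_classes=2)
    logits = fusion_model(diff_img, wnb_img)
    print(logits.shape)  # expected: torch.Size([4, 2])

    single_model = SingleModalModel(num_classes=2)
    single_logits = single_model(diff_img)
    print(single_logits.shape)  # expected: torch.Size([4, 2])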