leukemia/models/fusion_model.py

116 lines
3.4 KiB
Python
Raw Permalink Normal View History

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.image_models import VisionTransformer
class MultiModalFusionModel(nn.Module):
    """Two-branch fusion classifier over DIFF and WNB scatter-plot images.

    Each scatter plot is encoded by its own ``VisionTransformer`` instance.
    The two feature vectors are concatenated, then projected back to
    ``embed_dim`` by a Linear -> LayerNorm -> GELU -> Dropout stack.
    A single linear layer maps the result to ``num_classes`` logits.

    Args:
        img_size: Image size, forwarded to both encoders.
        patch_size: Patch size, forwarded to both encoders.
        in_channels: Input channel count, forwarded to both encoders.
        embed_dim: Encoder feature width; also the width of the fused
            representation.
        depth: Encoder depth, forwarded to both encoders.
        num_heads: Attention head count, forwarded to both encoders.
        dropout: Dropout rate for the encoders and the fusion stack.
        num_classes: Number of output logits.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()
        # Both branches use identical hyper-parameters.
        # Each call below constructs a separate encoder instance.
        encoder_cfg = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout,
        )
        # Registration order (diff, wnb, fusion, classifier) determines the
        # parameter order and state_dict keys -- keep it stable so existing
        # checkpoints still load.
        self.diff_encoder = VisionTransformer(**encoder_cfg)
        self.wnb_encoder = VisionTransformer(**encoder_cfg)

        # Projects the concatenated [B, 2*E] features back down to [B, E].
        self.fusion = nn.Sequential(
            nn.Linear(2 * embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, diff_img, wnb_img):
        """Classify a (DIFF, WNB) scatter-plot pair.

        Args:
            diff_img: DIFF scatter-plot batch, [B, C, H, W].
            wnb_img: WNB scatter-plot batch, [B, C, H, W].

        Returns:
            Logits of shape [B, num_classes].
        """
        # Each encoder is expected to emit a pooled [B, E] vector.
        # Joining along dim 1 yields the [B, 2*E] input the fusion Linear expects.
        joint = torch.cat(
            (self.diff_encoder(diff_img), self.wnb_encoder(wnb_img)), dim=1
        )
        return self.classifier(self.fusion(joint))

    def extract_features(self, diff_img, wnb_img):
        """Return the raw per-modality encoder outputs (before fusion), for analysis.

        Returns:
            dict with key ``'diff'`` holding the ``diff_encoder`` output and
            key ``'wnb'`` holding the ``wnb_encoder`` output.
        """
        features = {}
        features['diff'] = self.diff_encoder(diff_img)
        features['wnb'] = self.wnb_encoder(wnb_img)
        return features
class SingleModalModel(nn.Module):
    """Single-branch baseline for comparison experiments.

    The model is one ``VisionTransformer`` encoder followed by a linear
    classification head. It accepts the same constructor arguments as
    ``MultiModalFusionModel``.

    Args:
        img_size: Image size, forwarded to the encoder.
        patch_size: Patch size, forwarded to the encoder.
        in_channels: Input channel count, forwarded to the encoder.
        embed_dim: Encoder feature width (input width of the head).
        depth: Encoder depth, forwarded to the encoder.
        num_heads: Attention head count, forwarded to the encoder.
        dropout: Dropout rate, forwarded to the encoder.
        num_classes: Number of output logits.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()
        encoder_cfg = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout,
        )
        # Registration order (encoder, then classifier) fixes state_dict keys.
        self.encoder = VisionTransformer(**encoder_cfg)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, img):
        """Classify a single image batch.

        Args:
            img: Input images, [B, C, H, W].

        Returns:
            Logits of shape [B, num_classes]. The encoder is expected to
            return [B, E] features.
        """
        return self.classifier(self.encoder(img))