leukemia/models/fusion_model.py

116 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.image_models import VisionTransformer
class MultiModalFusionModel(nn.Module):
    """Multi-modal classifier that fuses DIFF and WNB scatter-plot features.

    Each scatter plot is encoded by its own ``VisionTransformer``. The two
    encoders are built with identical hyper-parameters but are separate
    modules, so they do not share weights. The two feature vectors are
    concatenated along dim 1. The result is projected back to ``embed_dim``
    by a Linear -> LayerNorm -> GELU -> Dropout stack and passed to a linear
    classification head.

    Args:
        img_size: Forwarded to each ``VisionTransformer``.
        patch_size: Forwarded to each ``VisionTransformer``.
        in_channels: Forwarded to each ``VisionTransformer``.
        embed_dim: Encoder feature width. The fusion layer takes
            ``2 * embed_dim`` inputs and produces ``embed_dim`` outputs.
        depth: Forwarded to each ``VisionTransformer``.
        num_heads: Forwarded to each ``VisionTransformer``.
        dropout: Forwarded to each encoder and also used in the fusion stack.
        num_classes: Number of output logits.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()

        # Both modalities use the same encoder configuration.
        encoder_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout,
        )
        # Construction order (diff, wnb, fusion, classifier) is kept fixed so
        # parameter initialisation and state_dict ordering stay stable.
        self.diff_encoder = VisionTransformer(**encoder_kwargs)  # DIFF scatter plot
        self.wnb_encoder = VisionTransformer(**encoder_kwargs)   # WNB scatter plot

        # Project the concatenated [diff | wnb] features back to embed_dim.
        self.fusion = nn.Sequential(
            nn.Linear(2 * embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Linear classification head on top of the fused representation.
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, diff_img, wnb_img):
        """Classify a pair of scatter plots.

        Args:
            diff_img: DIFF scatter-plot batch, documented as ``[B, C, H, W]``.
            wnb_img: WNB scatter-plot batch, documented as ``[B, C, H, W]``.

        Returns:
            Classification logits, documented as ``[B, num_classes]``. This
            assumes each encoder returns ``[B, E]``.
        """
        encoded = [self.diff_encoder(diff_img), self.wnb_encoder(wnb_img)]
        joint = torch.cat(encoded, dim=1)  # [B, 2*E] when encoders give [B, E]
        return self.classifier(self.fusion(joint))

    def extract_features(self, diff_img, wnb_img):
        """Return each modality's raw encoder output, for analysis.

        Returns:
            A dict mapping ``'diff'`` and ``'wnb'`` to the output of the
            corresponding encoder. No fusion is applied.
        """
        return {
            'diff': self.diff_encoder(diff_img),
            'wnb': self.wnb_encoder(wnb_img),
        }
class SingleModalModel(nn.Module):
    """Single-modality baseline used for comparison experiments.

    A single ``VisionTransformer`` encoder feeds a linear classification
    head directly, without the fusion stack used by the multi-modal model.

    Args:
        img_size: Forwarded to the ``VisionTransformer``.
        patch_size: Forwarded to the ``VisionTransformer``.
        in_channels: Forwarded to the ``VisionTransformer``.
        embed_dim: Encoder feature width, which is also the classifier's
            input size.
        depth: Forwarded to the ``VisionTransformer``.
        num_heads: Forwarded to the ``VisionTransformer``.
        dropout: Forwarded to the ``VisionTransformer``.
        num_classes: Number of output logits.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()
        # Image feature extractor.
        self.encoder = VisionTransformer(
            img_size=img_size, patch_size=patch_size, in_channels=in_channels,
            embed_dim=embed_dim, depth=depth, num_heads=num_heads, dropout=dropout,
        )
        # Linear head mapping encoder features to class logits.
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, img):
        """Classify a batch of images.

        Args:
            img: Input image batch, documented as ``[B, C, H, W]``.

        Returns:
            Classification logits, documented as ``[B, num_classes]``. This
            assumes the encoder returns ``[B, E]``.
        """
        return self.classifier(self.encoder(img))