first commit
This commit is contained in:
commit
9052bb9243
BIN
src/__pycache__/dataset.cpython-312.pyc
Normal file
BIN
src/__pycache__/dataset.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/model.cpython-312.pyc
Normal file
BIN
src/__pycache__/model.cpython-312.pyc
Normal file
Binary file not shown.
84
src/dataset.py
Normal file
84
src/dataset.py
Normal file
@ -0,0 +1,84 @@
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
class ChestXrayDataset(Dataset):
    """Dataset that loads images from disk and pairs each with its label.

    Args:
        file_paths: Sequence of image file paths; its length is the dataset
            length.
        labels: Sequence of labels, index-aligned with ``file_paths``.
        transform: Optional callable applied to the loaded RGB PIL image
            before it is returned.

    Raises:
        ValueError: If ``file_paths`` and ``labels`` differ in length.
    """

    def __init__(self, file_paths, labels, transform=None):
        # Fail fast: a mismatch would otherwise surface later as an
        # IndexError mid-epoch or as silently ignored labels.
        if len(file_paths) != len(labels):
            raise ValueError(
                f"file_paths and labels must have the same length, "
                f"got {len(file_paths)} and {len(labels)}"
            )
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        img_path = self.file_paths[idx]
        # Image.open keeps the file open lazily; convert() produces an
        # independent converted copy, so the handle can be released here.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label
|
||||
|
||||
def _list_image_files(directory, extensions=('.png', '.jpg', '.jpeg')):
    """Return the sorted paths of image files directly inside ``directory``.

    The extension match is case-insensitive, and the result is sorted because
    ``os.listdir`` returns entries in arbitrary order.
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(extensions)
    )


def prepare_data(data_dir, batch_size=32, num_workers=4, test_size=0.2,
                 random_state=42, pin_memory=None):
    """Build the train and test data loaders for the two-class image folder.

    Images are read from ``<data_dir>/normal`` (label 0) and
    ``<data_dir>/pneumonia`` (label 1), then split with stratification.

    Args:
        data_dir: Directory containing the ``normal`` and ``pneumonia``
            sub-directories.
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes per loader.
        test_size: Fraction of the samples held out for the test loader.
        random_state: Seed for the train/test split.
        pin_memory: Whether loaders pin host memory. ``None`` (default)
            enables it only when CUDA is available.

    Returns:
        Tuple ``(train_loader, test_loader)``.

    Raises:
        FileNotFoundError: If a class directory does not exist
            (raised by ``os.listdir``).
        ValueError: If a class directory contains no supported image files.
    """
    # Collect the image paths of each class.
    normal_dir = os.path.join(data_dir, 'normal')
    pneumonia_dir = os.path.join(data_dir, 'pneumonia')

    normal_files = _list_image_files(normal_dir)
    pneumonia_files = _list_image_files(pneumonia_dir)

    for class_dir, class_files in ((normal_dir, normal_files),
                                   (pneumonia_dir, pneumonia_files)):
        if not class_files:
            raise ValueError(f"no .png/.jpg/.jpeg images found in {class_dir!r}")

    # Merge file paths and labels (0 = normal, 1 = pneumonia).
    all_files = normal_files + pneumonia_files
    labels = [0] * len(normal_files) + [1] * len(pneumonia_files)

    # Stratified train/test split; the sorted input keeps it reproducible
    # for a fixed random_state.
    train_files, test_files, train_labels, test_labels = train_test_split(
        all_files, labels, test_size=test_size, random_state=random_state,
        stratify=labels
    )

    # Pre-processing; augmentation is applied to the training set only.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Build the datasets.
    train_dataset = ChestXrayDataset(train_files, train_labels, train_transform)
    test_dataset = ChestXrayDataset(test_files, test_labels, test_transform)

    # Pinned memory only helps host-to-GPU copies, so skip it without CUDA.
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    # Build the data loaders; only the training loader shuffles.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, test_loader
|
||||
67
src/model.py
Normal file
67
src/model.py
Normal file
@ -0,0 +1,67 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class MultiResRibNet(nn.Module):
    """Two-branch CNN that fuses full- and half-resolution features.

    The input is processed at its original resolution and at half
    resolution; the half-resolution features are resized back, concatenated
    with the full-resolution ones, fused and classified.

    Args:
        num_classes: Number of output logits (default 2).
        in_channels: Number of input image channels (default 3).
    """

    def __init__(self, num_classes: int = 2, in_channels: int = 3) -> None:
        super().__init__()

        # Full-resolution branch (e.g. 224x224 for a 224x224 input).
        self.high_res_path = self._conv_stack((in_channels, 32, 64, 128))

        # Half-resolution branch (e.g. 112x112 for a 224x224 input).
        self.low_res_path = self._conv_stack((in_channels, 32, 64, 128))

        # Feature fusion over the concatenated 128 + 128 channels.
        self.fusion = self._conv_stack((256, 128, 64))

        # Classifier: global average pooling followed by a linear layer.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, num_classes)
        )

    @staticmethod
    def _conv_stack(channels) -> nn.Sequential:
        """Build consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

        ``channels`` lists the channel widths; stage ``i`` maps
        ``channels[i]`` to ``channels[i + 1]``. Spatial size is preserved
        (kernel 3, padding 1).
        """
        layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        # Full-resolution features.
        high_res = self.high_res_path(x)

        # Half-resolution features, resized back to the full-resolution
        # spatial size so both branches can be concatenated.
        low_res_input = F.interpolate(x, scale_factor=0.5)
        low_res = self.low_res_path(low_res_input)
        low_res = F.interpolate(low_res, size=high_res.shape[2:])

        # Feature fusion along the channel dimension.
        fused = torch.cat([high_res, low_res], dim=1)
        fused = self.fusion(fused)

        # Classification.
        out = self.classifier(fused)
        return out
|
||||
87
src/train.py
Normal file
87
src/train.py
Normal file
@ -0,0 +1,87 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from model import MultiResRibNet
|
||||
from dataset import prepare_data
|
||||
import os
|
||||
|
||||
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_pct)``.

    The model is trained when ``optimizer`` is given and only evaluated
    (no gradients) otherwise. The loss is averaged per sample so a smaller
    final batch is not over-weighted.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_count = labels.size(0)
            # criterion returns the batch mean; scale back to a sum.
            loss_sum += loss.item() * batch_count
            _, predicted = outputs.max(1)
            total += batch_count
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError('data loader yielded no samples')

    return loss_sum / total, 100. * correct / total


def train(model, train_loader, test_loader, num_epochs=50, lr=0.001,
          save_path='best_model.pth'):
    """Train ``model`` and checkpoint the weights with the best test accuracy.

    Uses cross-entropy loss, Adam, and a ``ReduceLROnPlateau`` scheduler
    driven by the test loss. Runs on CUDA when available, otherwise on CPU.

    Args:
        model: Network returning class logits.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        num_epochs: Number of epochs to run.
        lr: Initial learning rate for Adam.
        save_path: File the best ``state_dict`` is written to.

    Returns:
        The best test accuracy (in percent) reached during training.

    Raises:
        ValueError: If either loader yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

    best_acc = 0.0

    for epoch in range(num_epochs):
        # Training phase.
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )

        # Evaluation phase.
        # NOTE(review): the test split also drives the LR schedule and
        # checkpoint selection, so the reported test accuracy is not a
        # clean held-out estimate. Confirm whether a separate validation
        # split is wanted.
        test_loss, test_acc = _run_epoch(model, test_loader, criterion, device)

        scheduler.step(test_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%')

        # Save the best model so far.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), save_path)

    return best_acc
|
||||
|
||||
if __name__ == '__main__':
    # Dataset directory, resolved relative to this file rather than the
    # current working directory so the script can be launched from anywhere.
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')

    # Prepare the data loaders.
    train_loader, test_loader = prepare_data(data_dir, batch_size=32)

    # Create the model.
    model = MultiResRibNet()

    # Train the model.
    train(model, train_loader, test_loader)
|
||||
|
||||
|
||||
# For testing only 111
|
||||
Loading…
Reference in New Issue
Block a user