first commit
This commit is contained in:
commit
9052bb9243
BIN
src/__pycache__/dataset.cpython-312.pyc
Normal file
BIN
src/__pycache__/dataset.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/model.cpython-312.pyc
Normal file
BIN
src/__pycache__/model.cpython-312.pyc
Normal file
Binary file not shown.
84
src/dataset.py
Normal file
84
src/dataset.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
|
class ChestXrayDataset(Dataset):
    """Map-style dataset that loads chest X-ray images from disk on demand.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels, parallel to ``file_paths``.
        transform: Optional callable applied to the loaded RGB PIL image.

    Raises:
        ValueError: If ``file_paths`` and ``labels`` differ in length.
    """

    def __init__(self, file_paths, labels, transform=None):
        # Fail fast: a mismatch would otherwise only surface as an IndexError
        # inside a DataLoader worker, or be silently ignored when ``labels``
        # is the longer sequence.
        if len(file_paths) != len(labels):
            raise ValueError(
                "file_paths and labels must have the same length, "
                f"got {len(file_paths)} and {len(labels)}"
            )
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        # The context manager releases the file handle deterministically;
        # convert() returns a separate image, so it stays usable afterwards.
        with Image.open(self.file_paths[idx]) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label
|
||||||
|
|
||||||
|
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _list_image_files(directory):
    """Return the sorted paths of the image files directly inside ``directory``."""
    # os.listdir() returns entries in arbitrary order; sorting keeps the seeded
    # split below reproducible across machines and file systems.
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        # Compare case-insensitively so files such as "scan.JPG" are kept.
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]


def prepare_data(data_dir, batch_size=32, num_workers=4):
    """Build the training and test data loaders for the chest X-ray images.

    Images are read from ``<data_dir>/normal`` (label 0) and
    ``<data_dir>/pneumonia`` (label 1) and split 80/20 with stratification
    and a fixed seed.

    Args:
        data_dir: Directory containing the ``normal`` and ``pneumonia`` folders.
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes used by both loaders.

    Returns:
        Tuple ``(train_loader, test_loader)``.

    Raises:
        ValueError: If either class folder contains no image files.
    """
    # Collect the image paths of both classes.
    normal_dir = os.path.join(data_dir, 'normal')
    pneumonia_dir = os.path.join(data_dir, 'pneumonia')
    normal_files = _list_image_files(normal_dir)
    pneumonia_files = _list_image_files(pneumonia_dir)

    for class_dir, class_files in ((normal_dir, normal_files),
                                   (pneumonia_dir, pneumonia_files)):
        if not class_files:
            raise ValueError(f"No image files found in {class_dir!r}")

    # Merge paths and labels (0 = normal, 1 = pneumonia).
    all_files = normal_files + pneumonia_files
    labels = [0] * len(normal_files) + [1] * len(pneumonia_files)

    # Split into training and test sets.
    train_files, test_files, train_labels, test_labels = train_test_split(
        all_files, labels, test_size=0.2, random_state=42, stratify=labels
    )

    # Preprocessing shared by both splits; augmentation is training-only.
    # NOTE(review): the mean/std values look like the usual ImageNet channel
    # statistics - confirm they suit this data.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Create the datasets.
    train_dataset = ChestXrayDataset(train_files, train_labels, train_transform)
    test_dataset = ChestXrayDataset(test_files, test_labels, test_transform)

    # Create the data loaders. Pinned memory only helps host-to-GPU copies,
    # so it is requested only when CUDA is available.
    loader_options = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': torch.cuda.is_available(),
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

    return train_loader, test_loader
|
||||||
67
src/model.py
Normal file
67
src/model.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
def _conv_bn_relu(in_channels, out_channels):
    """Return the layers of one 3x3 convolution -> batch norm -> ReLU stage."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    ]


class MultiResRibNet(nn.Module):
    """Two-branch CNN that fuses full- and half-resolution image features.

    The same input is processed once at its original size and once after
    being downscaled by a factor of two; both feature maps are concatenated,
    fused by further convolutions and classified from globally pooled features.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()

        # High-resolution branch (operates on the input as given).
        self.high_res_path = nn.Sequential(
            *_conv_bn_relu(in_channels, 32),
            *_conv_bn_relu(32, 64),
            *_conv_bn_relu(64, 128),
        )

        # Low-resolution branch (operates on the half-size input).
        self.low_res_path = nn.Sequential(
            *_conv_bn_relu(in_channels, 32),
            *_conv_bn_relu(32, 64),
            *_conv_bn_relu(64, 128),
        )

        # Feature fusion: 128 + 128 concatenated channels reduced to 64.
        self.fusion = nn.Sequential(
            *_conv_bn_relu(256, 128),
            *_conv_bn_relu(128, 64),
        )

        # Classifier: global average pooling followed by a linear layer.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of images.

        Args:
            x: Input batch laid out as (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, num_classes).
        """
        # High-resolution features.
        high_res = self.high_res_path(x)

        # Low-resolution features: downscale the input, extract features,
        # then resize back so both maps share the same spatial size.
        low_res_input = F.interpolate(x, scale_factor=0.5)
        low_res = self.low_res_path(low_res_input)
        low_res = F.interpolate(low_res, size=high_res.shape[2:])

        # Feature fusion along the channel dimension.
        fused = torch.cat([high_res, low_res], dim=1)
        fused = self.fusion(fused)

        # Classification.
        return self.classifier(fused)
|
||||||
87
src/train.py
Normal file
87
src/train.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from model import MultiResRibNet
|
||||||
|
from dataset import prepare_data
|
||||||
|
import os
|
||||||
|
|
||||||
|
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; train if ``optimizer`` is given, else evaluate.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)``; the loss is averaged per sample.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # The criterion returns a batch mean; weight it by the batch size
            # so a smaller final batch does not skew the epoch average.
            batch_count = labels.size(0)
            loss_sum += loss.item() * batch_count
            _, predicted = outputs.max(1)
            total += batch_count
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("data loader yielded no samples")
    return loss_sum / total, 100. * correct / total


def train(model, train_loader, test_loader, num_epochs=50, lr=0.001,
          save_path='best_model.pth'):
    """Train ``model`` and checkpoint the weights with the best test accuracy.

    Uses cross-entropy loss, the Adam optimizer and a ReduceLROnPlateau
    scheduler driven by the test loss. Runs on CUDA when available.

    NOTE(review): the test split is used both for learning-rate scheduling
    and for model selection, so the reported test accuracy is optimistic.

    Args:
        model: Network to train.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        num_epochs: Number of passes over the training data.
        lr: Initial learning rate for Adam.
        save_path: File the best ``state_dict`` is written to.

    Returns:
        Best test accuracy in percent, or ``None`` if no epoch was run.

    Raises:
        ValueError: If either loader yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

    # None (rather than 0.0) guarantees that the first epoch always writes a
    # checkpoint, even if its accuracy is 0%.
    best_acc = None

    for epoch in range(num_epochs):
        # Training phase.
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )

        # Evaluation phase.
        test_loss, test_acc = _run_epoch(model, test_loader, criterion, device)

        scheduler.step(test_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%')

        # Save the best model.
        if best_acc is None or test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), save_path)

    return best_acc
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Data path: resolved relative to this script rather than the current
    # working directory, so the script can be launched from anywhere.
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')

    # Prepare the data.
    train_loader, test_loader = prepare_data(data_dir, batch_size=32)

    # Create the model.
    model = MultiResRibNet()

    # Train the model.
    train(model, train_loader, test_loader)
|
||||||
|
|
||||||
|
|
||||||
|
# For testing only 111
|
||||||
Loading…
Reference in New Issue
Block a user