commit 9052bb92432bd4bda8f683e14f39bfc60385aed8 Author: lotus Date: Thu Nov 21 22:05:36 2024 +0800 first commit diff --git a/src/__pycache__/dataset.cpython-312.pyc b/src/__pycache__/dataset.cpython-312.pyc new file mode 100644 index 0000000..2316d89 Binary files /dev/null and b/src/__pycache__/dataset.cpython-312.pyc differ diff --git a/src/__pycache__/model.cpython-312.pyc b/src/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000..e7d2d5f Binary files /dev/null and b/src/__pycache__/model.cpython-312.pyc differ diff --git a/src/dataset.py b/src/dataset.py new file mode 100644 index 0000000..7df3a4c --- /dev/null +++ b/src/dataset.py @@ -0,0 +1,84 @@ +import os +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from PIL import Image +from sklearn.model_selection import train_test_split + +class ChestXrayDataset(Dataset): + def __init__(self, file_paths, labels, transform=None): + self.file_paths = file_paths + self.labels = labels + self.transform = transform + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + img_path = self.file_paths[idx] + image = Image.open(img_path).convert('RGB') + label = self.labels[idx] + + if self.transform: + image = self.transform(image) + + return image, label + +def prepare_data(data_dir, batch_size=32): + # 获取所有图片文件路径 + normal_dir = os.path.join(data_dir, 'normal') + pneumonia_dir = os.path.join(data_dir, 'pneumonia') + + normal_files = [os.path.join(normal_dir, f) for f in os.listdir(normal_dir) + if f.endswith(('.png', '.jpg', '.jpeg'))] + pneumonia_files = [os.path.join(pneumonia_dir, f) for f in os.listdir(pneumonia_dir) + if f.endswith(('.png', '.jpg', '.jpeg'))] + + # 合并文件路径和标签 + all_files = normal_files + pneumonia_files + labels = [0] * len(normal_files) + [1] * len(pneumonia_files) + + # 划分训练集和测试集 + train_files, test_files, train_labels, test_labels = train_test_split( + all_files, labels, test_size=0.2, random_state=42, stratify=labels + ) + + # 数据预处理和增强 + train_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomHorizontalFlip(), + transforms.RandomRotation(10), + transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)), + transforms.ColorJitter(brightness=0.2, contrast=0.2), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + test_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + # 创建数据集 + train_dataset = ChestXrayDataset(train_files, train_labels, train_transform) + test_dataset = ChestXrayDataset(test_files, test_labels, test_transform) + + # 创建数据加载器 + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=4, + pin_memory=True + ) + + test_loader = DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=4, + pin_memory=True + ) + + return train_loader, test_loader diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000..5e426a8 --- /dev/null +++ b/src/model.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MultiResRibNet(nn.Module): + def __init__(self): + super(MultiResRibNet, self).__init__() + + # 高分辨率路径 (224x224) + self.high_res_path = nn.Sequential( + nn.Conv2d(3, 32, kernel_size=3, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.ReLU() + ) + + # 低分辨率路径 (112x112) + self.low_res_path = nn.Sequential( + nn.Conv2d(3, 32, kernel_size=3, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.ReLU() + ) + + # 特征融合 + self.fusion = nn.Sequential( + nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.BatchNorm2d(64), + nn.ReLU() + ) + + # 分类器 + self.classifier = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(64, 2) + ) + + def forward(self, x): + # 高分辨率特征 + high_res = self.high_res_path(x) + + # 低分辨率特征 + low_res_input = F.interpolate(x, scale_factor=0.5) + low_res = self.low_res_path(low_res_input) + low_res = F.interpolate(low_res, size=high_res.shape[2:]) + + # 特征融合 + fused = torch.cat([high_res, low_res], dim=1) + fused = self.fusion(fused) + + # 分类 + out = self.classifier(fused) + return out diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..0755928 --- /dev/null +++ b/src/train.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from model import MultiResRibNet +from dataset import prepare_data +import os + +def train(model, train_loader, test_loader, num_epochs=50): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=0.001) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3) + + best_acc = 0.0 + + for epoch in range(num_epochs): + # 训练阶段 + model.train() + running_loss = 0.0 + correct = 0 + total = 0 + + for inputs, labels in train_loader: + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + outputs = model(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + running_loss += loss.item() + _, predicted = outputs.max(1) + total += labels.size(0) + correct += predicted.eq(labels).sum().item() + + train_loss = running_loss / len(train_loader) + train_acc = 100. * correct / total + + # 测试阶段 + model.eval() + test_loss = 0.0 + correct = 0 + total = 0 + + with torch.no_grad(): + for inputs, labels in test_loader: + inputs, labels = inputs.to(device), labels.to(device) + outputs = model(inputs) + loss = criterion(outputs, labels) + + test_loss += loss.item() + _, predicted = outputs.max(1) + total += labels.size(0) + correct += predicted.eq(labels).sum().item() + + test_loss = test_loss / len(test_loader) + test_acc = 100. * correct / total + + scheduler.step(test_loss) + + print(f'Epoch [{epoch+1}/{num_epochs}]') + print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%') + print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%') + + # 保存最佳模型 + if test_acc > best_acc: + best_acc = test_acc + torch.save(model.state_dict(), 'best_model.pth') + +if __name__ == '__main__': + # 数据路径 + data_dir = '../data/' + + # 准备数据 + train_loader, test_loader = prepare_data(data_dir, batch_size=32) + + # 创建模型 + model = MultiResRibNet() + + # 训练模型 + train(model, train_loader, test_loader) + + +# 只为测试 111 \ No newline at end of file