From 9052bb92432bd4bda8f683e14f39bfc60385aed8 Mon Sep 17 00:00:00 2001 From: lotus Date: Thu, 21 Nov 2024 22:05:36 +0800 Subject: [PATCH] first commit --- src/__pycache__/dataset.cpython-312.pyc | Bin 0 -> 4071 bytes src/__pycache__/model.cpython-312.pyc | Bin 0 -> 3355 bytes src/dataset.py | 84 +++++++++++++++++++++++ src/model.py | 67 ++++++++++++++++++ src/train.py | 87 ++++++++++++++++++++++++ 5 files changed, 238 insertions(+) create mode 100644 src/__pycache__/dataset.cpython-312.pyc create mode 100644 src/__pycache__/model.cpython-312.pyc create mode 100644 src/dataset.py create mode 100644 src/model.py create mode 100644 src/train.py diff --git a/src/__pycache__/dataset.cpython-312.pyc b/src/__pycache__/dataset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2316d8913a5fb612303934692587438b212891c3 GIT binary patch literal 4071 zcmcgvO>7&-6`tiTm&+fClt_sZE!&bMN2X&db!{j9Nz5uv?Mjj5)OL!1RlA_MOKO?r zPj^>VBrH)Ld=Q!x5rLvqdWn!;WLO6solE29l%NS9h7Qb5Y}h~z1iBfPS|I7AZ)Qn~ zW>gd{ffg^K18#)=Mua-qYn{ymbV^7<5~HH1)iO~A zWSeSBvr#t9MLC9yv#LGKM|s-kR7YBf3bb!mooQFp#UL9xk0ky&k{mL7leO$e-I7xh zt}{{3920hJkip5Y4Le?m88KZpu#0pivN1{4!VGp9S}da{vRYaP$v&2jC1vRNU|7kB zhO8T+o>LVgZW*_@ttM>tL#KpfM7)pMB)epTbyJocTP219RXvALQP zlX0$Qf;7Y&aei|ISP7YoiA3mGWZFPt8WIaqeR(v8ZZMN!7P}IPDvP<8F|A`>ja`*h z9XlHv3v-y$Wi_FBz>3yJSUWJCxX=&u?T~b8Dwc`Gvl%Hi70*iY)LY={l&;06B^ z{Sb6Y$Y@|DN;UQgP>pWZ@XJVp^g<9<~pj!Q-ACJf;Xe zhz`&#pnKlV#bc#NF;ejkt_y=&Cyd6MU1+%a2_O(Al!CUEOr{rx@-~BLQ%W_5x;AFo z>KjQpYC}zp#MIF|ggYfPLG}68DjcBnzCEp>b-*||o)$9MCcCW3C2W`B1Sc_j>HN!K z-bxQ1=3p+HlQY->32{xL8k7tD2G7Y(SDH9Se? zOe`&nB6f*lIxEe<{SmQ86n`}nQ!PXz{%BoA0m z{ha6b>uDp$^XX3Y^o%OuRcQoNX z!2KowAVJc&X1dGUWz4TGR0MrMrtD2IWHZroVV*$QsU{E*0Y|?f2kGYpz z4ux}QHosIAfIz6 zsa#iMMtoY-iPH{!dM1%jW$ez(q{Z2+_NJ`q*p-6{GA*aG+T6KtC$?pE%#j)mbE&KX zpI24U4N1|kAZH|fRxzeEat$?75NPDmW5?NSI+xXD%wLk}I)O{{aDI~2lzG@mOg*P6 zIoxiIUdqCeDcKD6Sn_Bhp=7|NvspE(jVp#B!%|nWSLBSI)v$At1`WKoK4x@Rek_oJ zPXwX4Q?JoF?aD znk*7S+UuaY3x4{ukiiA?sKZ|h7K5d}V&BsAicsk|QW(D%=q{Zvo-a)lCzg9wPFDgi z6fQhEdHSQ!hoRf!m6I3$ftbKcYopcfU@28htq+{K(|vl2MP0$htkB9&zYIM0nCKNI zX!9sC`qBJ{^Pe29MBae;!5=ZV2CBWG_e1Z6)`woY(|ZOC^n^;Y#o3ZsG?y=|9r-eN z9A<)I;iA<(7Es6TrC*nQBZaYQyLaL8?~X2-75`AVZD@sA)>kjD?X4UhEf0^D+eQoL zws>@EoT&;Oi_b4#Uej)$+hT2}_%io=;nl^^($vc7Ef)$7-0eP4?mlo&=)5cRmWAG> zgB4*vm4?g0a775y{y6)g;cPZS0M_{|fJL_sJ2xBlF7Vl2pg96mAD$Ay`V zei#b3X)i#7?M7CMPs3kL(u~X)imH#$gd_E#MlK|zEWxfRI^1jPO<-Reo6u;H9@O7d z<(QTkf$~*WMfis&$I1P-9uQiglFZfi5k?9Dn%mzBFIwg98Ep#2$aAb80QQ(=80IVF k{0eyvf*TM8H&Tb)Sww^i3?5CFg1z5?C!~g&Q literal 0 HcmV?d00001 diff --git a/src/__pycache__/model.cpython-312.pyc b/src/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7d2d5fe7bd2b260e22af98c0f7630b1f04a0bf4 GIT binary patch literal 3355 zcmeHK&2JM&6rcUx*z0`QG=V@86F|ZeiArcGMHMBjN(3TS39Txupw+UwjyGNJn%OlF zb-&AuhF*Q0N%uz9Ri!gDuDF)a;ZB!59`5Ww42MJ082|5pwlt$eW z6AZDRVDsPxmEbOd7=t<|v!-iI>CTjKC9brfzR_TRRC^7==#1u4%`g_ zP^wEP;0nIdEY8Z#(IWM-_o`f@G#dXg57`o1leFl!$%sam>hZSdk*IeJKe`sJcHU`? zdn51Mzv`XS_&4&-znS+^M=dEiAUEpLl8?q(Gj%P}cG>#BR)77T{{33js|NnryAs~+ z-FnSU;ot7ND)Eh*-MFVM5&iki?rq+8pVGuOFKf3iZO&~_rN1flDSY;hQXMUQ1jlc% zfl_lI7Km>VnB9t!4*|3oje=k6(9%uWF>dNe$S9g-STkFviN*r?0IV0IAvv7)*&;4t zU5InC8683bpGcHt!!lf1Mzm$2z(xIPR<~S3F%kc%ZCxAFQ1BDQRa0>rrZK}#>F2*h zp_DO`lA-R%8O2Q@$+WLGIX;CjjcfY&wVBUt+r+ztGp6FYx`p`jhNUYI z1yoaU93yGy5TA#2bTf%MWZ7GlWrxHYU3fXHot#!IMYSzWnO1F0pS}T=xoHQg(`n3^ zqnSC_ja6^Yt4i-}xSAM&#DC#}OivdB3pC8Hw_8&Fe-Q56>wNQC={F#!&E8*d=dM-LS8 zfy1yeSr-X)7I7FtbPX|f)^$i7qO}2shy@{O!|MvvZuahY?e+Edu#&TgKMUYlr0+@O z&_?9Ys#1;|&WrfD@9AF&Er(W>b+xklWInVR?tKzIun|76I#~{n<%Or=NMU4YWK~#8 zeP=ze9$k1e^?3AjB|M%N@aCT0LdQ}^@xgCRZd1kXKoDaY#jwZ$vthJQ1j4YJ$KbnI zABE~TsOqLE%P|TL;#G%S^V41;JpA@P4dDU0lR->nX%Uco>V3 z=v8LavLLx1DB`oD$*iTi7_>3$2bwfIM6~!xFH+CY15h(c3h`3>sCI@^@FZqQV9tL0 kdCpK2^&^me0`L3`h8}~VZT1uuxj$97v~=lLfJx89AI5oNF#rGn literal 0 HcmV?d00001 diff --git a/src/dataset.py b/src/dataset.py new file mode 100644 index 0000000..7df3a4c --- /dev/null +++ b/src/dataset.py @@ -0,0 +1,84 @@ +import os +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from PIL import Image +from sklearn.model_selection import train_test_split + +class ChestXrayDataset(Dataset): + def __init__(self, file_paths, labels, transform=None): + self.file_paths = file_paths + self.labels = labels + self.transform = transform + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + img_path = self.file_paths[idx] + image = Image.open(img_path).convert('RGB') + label = self.labels[idx] + + if self.transform: + image = self.transform(image) + + return image, label + +def prepare_data(data_dir, batch_size=32): + # 获取所有图片文件路径 + normal_dir = os.path.join(data_dir, 'normal') + pneumonia_dir = os.path.join(data_dir, 'pneumonia') + + normal_files = [os.path.join(normal_dir, f) for f in os.listdir(normal_dir) + if f.endswith(('.png', '.jpg', '.jpeg'))] + pneumonia_files = [os.path.join(pneumonia_dir, f) for f in os.listdir(pneumonia_dir) + if f.endswith(('.png', '.jpg', '.jpeg'))] + + # 合并文件路径和标签 + all_files = normal_files + pneumonia_files + labels = [0] * len(normal_files) + [1] * len(pneumonia_files) + + # 划分训练集和测试集 + train_files, test_files, train_labels, test_labels = train_test_split( + all_files, labels, test_size=0.2, random_state=42, stratify=labels + ) + + # 数据预处理和增强 + train_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomHorizontalFlip(), + transforms.RandomRotation(10), + transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)), + transforms.ColorJitter(brightness=0.2, contrast=0.2), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + test_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + # 创建数据集 + train_dataset = ChestXrayDataset(train_files, train_labels, train_transform) + test_dataset = ChestXrayDataset(test_files, test_labels, test_transform) + + # 创建数据加载器 + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=4, + pin_memory=True + ) + + test_loader = DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=4, + pin_memory=True + ) + + return train_loader, test_loader diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000..5e426a8 --- /dev/null +++ b/src/model.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MultiResRibNet(nn.Module): + def __init__(self): + super(MultiResRibNet, self).__init__() + + # 高分辨率路径 (224x224) + self.high_res_path = nn.Sequential( + nn.Conv2d(3, 32, kernel_size=3, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.ReLU() + ) + + # 低分辨率路径 (112x112) + self.low_res_path = nn.Sequential( + nn.Conv2d(3, 32, kernel_size=3, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=3, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.ReLU() + ) + + # 特征融合 + self.fusion = nn.Sequential( + nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.BatchNorm2d(64), + nn.ReLU() + ) + + # 分类器 + self.classifier = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(64, 2) + ) + + def forward(self, x): + # 高分辨率特征 + high_res = self.high_res_path(x) + + # 低分辨率特征 + low_res_input = F.interpolate(x, scale_factor=0.5) + low_res = self.low_res_path(low_res_input) + low_res = F.interpolate(low_res, size=high_res.shape[2:]) + + # 特征融合 + fused = torch.cat([high_res, low_res], dim=1) + fused = self.fusion(fused) + + # 分类 + out = self.classifier(fused) + return out diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..0755928 --- /dev/null +++ b/src/train.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from model import MultiResRibNet +from dataset import prepare_data +import os + +def train(model, train_loader, test_loader, num_epochs=50): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=0.001) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3) + + best_acc = 0.0 + + for epoch in range(num_epochs): + # 训练阶段 + model.train() + running_loss = 0.0 + correct = 0 + total = 0 + + for inputs, labels in train_loader: + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + outputs = model(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + running_loss += loss.item() + _, predicted = outputs.max(1) + total += labels.size(0) + correct += predicted.eq(labels).sum().item() + + train_loss = running_loss / len(train_loader) + train_acc = 100. * correct / total + + # 测试阶段 + model.eval() + test_loss = 0.0 + correct = 0 + total = 0 + + with torch.no_grad(): + for inputs, labels in test_loader: + inputs, labels = inputs.to(device), labels.to(device) + outputs = model(inputs) + loss = criterion(outputs, labels) + + test_loss += loss.item() + _, predicted = outputs.max(1) + total += labels.size(0) + correct += predicted.eq(labels).sum().item() + + test_loss = test_loss / len(test_loader) + test_acc = 100. * correct / total + + scheduler.step(test_loss) + + print(f'Epoch [{epoch+1}/{num_epochs}]') + print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%') + print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%') + + # 保存最佳模型 + if test_acc > best_acc: + best_acc = test_acc + torch.save(model.state_dict(), 'best_model.pth') + +if __name__ == '__main__': + # 数据路径 + data_dir = '../data/' + + # 准备数据 + train_loader, test_loader = prepare_data(data_dir, batch_size=32) + + # 创建模型 + model = MultiResRibNet() + + # 训练模型 + train(model, train_loader, test_loader) + + +# 只为测试 111 \ No newline at end of file