From d1ab4a1f1991a65775cf273956121d18cc25c3b4 Mon Sep 17 00:00:00 2001
From: lotus
Date: Thu, 17 Apr 2025 16:26:35 +0800
Subject: [PATCH] First complete run: 85% accuracy on a 500-sample budget
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore                        |   3 +
 config.py                         |  36 ++
 data_preprocessing/data_loader.py | 171 +++++++++
 data_preprocessing/data_split.py  |   0
 evaluation/evaluate.py            | 185 ++++++++++
 evaluation/visualization.py       |   0
 main.py                           | 349 ++++++++++++++++++
 models/fusion_model.py            | 116 ++++++
 models/image_models.py            |  78 ++++
 training/train.py                 | 312 ++++++++++++++++
 training/utils.py                 | 126 +++++++
 11 files changed, 1376 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 config.py
 create mode 100644 data_preprocessing/data_loader.py
 create mode 100644 data_preprocessing/data_split.py
 create mode 100644 evaluation/evaluate.py
 create mode 100644 evaluation/visualization.py
 create mode 100644 main.py
 create mode 100644 models/fusion_model.py
 create mode 100644 models/image_models.py
 create mode 100644 training/train.py
 create mode 100644 training/utils.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..fcab143
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+data/
+__pycache__/
+output/
\ No newline at end of file
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..b114e1d
--- /dev/null
+++ b/config.py
@@ -0,0 +1,36 @@
+import os
+import torch
+
+# Path configuration
+DATA_ROOT = "./data"  # data root directory containing subfolders 000, 001, etc.
+OUTPUT_DIR = "./output"
+MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "models")
+
+# Make sure the directories exist
+os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
+
+# Data configuration
+IMG_SIZE = 224  # images are resized to this size
+BATCH_SIZE = 32
+NUM_WORKERS = 4
+TRAIN_RATIO = 0.8
+VAL_RATIO = 0.2
+
+# Sample-count control
+MAX_SAMPLES_PER_CLASS = 1000  # maximum number of samples read per class
+NORMAL_CLASS = "000"  # folder name of the normal class
+ABNORMAL_CLASSES = ["001", "010", "011", "100", "101", "110", "111"]  # folder names of the abnormal classes
+
+# Model configuration
+HIDDEN_DIM = 768
+NUM_HEADS = 12
+NUM_LAYERS = 6
+DROPOUT = 0.1
+NUM_CLASSES = 2  # binary classification: normal vs. abnormal
+
+# Training configuration
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+LEARNING_RATE = 1e-4
+WEIGHT_DECAY = 1e-5
+NUM_EPOCHS = 50
+EARLY_STOPPING_PATIENCE = 10
\ No newline at end of file
diff --git a/data_preprocessing/data_loader.py b/data_preprocessing/data_loader.py
new file mode 100644
--- /dev/null
+++ b/data_preprocessing/data_loader.py
@@ -0,0 +1,171 @@
+    # Collect the normal samples
+    normal_samples = []
+    normal_folder = os.path.join(data_root, normal_class)
+    diff_files = glob.glob(os.path.join(normal_folder, "*Diff.png"))
+
+    # Cap the number of normal samples
+    if len(diff_files) > max_samples_per_class:
+        random.shuffle(diff_files)
+        diff_files = diff_files[:max_samples_per_class]
+
+    for diff_file in diff_files:
+        sample_id = os.path.basename(diff_file).replace('Diff.png', '')
+        wnb_file = os.path.join(normal_folder, f"{sample_id}Wnb.png")
+
+        # Make sure the WNB file exists
+        if os.path.exists(wnb_file):
+            normal_samples.append((diff_file, wnb_file))
+
+    # Collect the abnormal samples
+    abnormal_samples = []
+
+    # Number of samples allotted to each abnormal class
+    # With max_samples_per_class=500, 500 abnormal samples are drawn in total,
+    # split evenly across the abnormal classes
+    samples_per_abnormal_class = max_samples_per_class // len(abnormal_classes)
+
+    for abnormal_class in abnormal_classes:
+        abnormal_folder = os.path.join(data_root, abnormal_class)
+        diff_files = glob.glob(os.path.join(abnormal_folder, "*Diff.png"))
+
+        # Cap the number of samples per abnormal class
+        if len(diff_files) > samples_per_abnormal_class:
+            random.shuffle(diff_files)
+            diff_files = diff_files[:samples_per_abnormal_class]
+
+        for diff_file in diff_files:
+            sample_id = os.path.basename(diff_file).replace('Diff.png', '')
+            wnb_file = os.path.join(abnormal_folder, f"{sample_id}Wnb.png")
+
+            # Make sure the WNB file exists
+            if os.path.exists(wnb_file):
+                abnormal_samples.append((diff_file, wnb_file))
+
+    # Assemble the dataset
+    all_samples = normal_samples + abnormal_samples
+    all_labels = [0] * len(normal_samples) + [1] * len(abnormal_samples)
+
+    print(f"Normal samples collected: {len(normal_samples)}")
+    print(f"Abnormal samples collected: {len(abnormal_samples)}")
+    print(f"Total samples: {len(all_samples)}")
+
+    # Split into training and validation sets
+    indices = np.arange(len(all_samples))
+    train_indices, val_indices = train_test_split(
+        indices,
+        test_size=(1 - train_ratio),
+        stratify=all_labels,  # keep the class ratio identical in train and val
+        random_state=42
+    )
+
+    # Build the training and validation sets
+    train_samples = [all_samples[i] for i in train_indices]
+    train_labels = [all_labels[i] for i in train_indices]
+
+    val_samples = [all_samples[i] for i in val_indices]
+    val_labels = [all_labels[i] for i in val_indices]
+
+    # Create the datasets
+    train_dataset = LeukemiaDataset(train_samples, train_labels, transform)
+    val_dataset = LeukemiaDataset(val_samples, val_labels, transform)
+
+    # Create the data loaders
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers
+    )
+
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers
+    )
+
+    print(f"Training samples: {len(train_dataset)}, validation samples: {len(val_dataset)}")
+
+    return train_loader, val_loader
\ No newline at end of file
diff --git a/data_preprocessing/data_split.py b/data_preprocessing/data_split.py
new file mode 100644
index 0000000..e69de29
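The imports, the image transform, and the LeukemiaDataset class from the top of data_loader.py were garbled in this copy of the patch. The sketch below is inferred from the call sites that survive: load_data's keyword arguments in main.py, and the batch['diff_img'] / batch['wnb_img'] / batch['label'] keys consumed in training and evaluation. The RGB conversion and the ImageNet normalization constants are assumptions, not recovered code:

    import os
    import glob
    import random

    import numpy as np
    import torch
    from PIL import Image
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
    from sklearn.model_selection import train_test_split


    class LeukemiaDataset(Dataset):
        """(DIFF, WNB) scattergram image pairs with a binary label."""

        def __init__(self, samples, labels, transform=None):
            self.samples = samples    # list of (diff_path, wnb_path) tuples
            self.labels = labels      # 0 = normal, 1 = abnormal
            self.transform = transform

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            diff_path, wnb_path = self.samples[idx]
            diff_img = Image.open(diff_path).convert('RGB')
            wnb_img = Image.open(wnb_path).convert('RGB')
            if self.transform:
                diff_img = self.transform(diff_img)
                wnb_img = self.transform(wnb_img)
            return {
                'diff_img': diff_img,
                'wnb_img': wnb_img,
                'label': torch.tensor(self.labels[idx], dtype=torch.long),
            }


    def load_data(data_root, img_size=224, batch_size=32, num_workers=4,
                  train_ratio=0.8, max_samples_per_class=1000,
                  normal_class="000", abnormal_classes=None):
        # Resize and normalize; the ImageNet statistics here are assumed.
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        # ...continues with the surviving body shown in the diff above.

The sampling arithmetic behind the commit message's 500-sample run also lives here: with max_samples_per_class=500 and seven abnormal folders, samples_per_abnormal_class = 500 // 7 = 71, so at most 7 x 71 = 497 abnormal samples are drawn against up to 500 normal ones. The split is near-balanced, but only because of the integer division.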
diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py
new file mode 100644
index 0000000..682959c
--- /dev/null
+++ b/evaluation/evaluate.py
@@ -0,0 +1,185 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
+from sklearn.manifold import TSNE
+import seaborn as sns
+from training.utils import compute_metrics, plot_confusion_matrix, save_results
+
+def evaluate_model(model, data_loader, device, class_names=None):
+    """Evaluate model performance."""
+    model.eval()
+
+    all_labels = []
+    all_preds = []
+    all_probs = []
+
+    with torch.no_grad():
+        for batch in data_loader:
+            # Fetch the data and labels
+            diff_imgs = batch['diff_img'].to(device)
+            wnb_imgs = batch['wnb_img'].to(device)
+            labels = batch['label'].to(device)
+
+            # Forward pass
+            outputs = model(diff_imgs, wnb_imgs)
+
+            # Predictions and probabilities
+            probs = torch.softmax(outputs, dim=1)
+            _, preds = torch.max(outputs, 1)
+
+            # Accumulate the results
+            all_labels.extend(labels.cpu().numpy())
+            all_preds.extend(preds.cpu().numpy())
+            all_probs.extend(probs.cpu().numpy())
+
+    # Compute the evaluation metrics
+    metrics = compute_metrics(all_labels, all_preds, all_probs)
+
+    # Plot the confusion matrix
+    if class_names is None:
+        class_names = ['normal', 'abnormal']
+    plot_confusion_matrix(all_labels, all_preds, class_names)
+
+    # Print the results
+    print(f"Accuracy: {metrics['accuracy']:.4f}")
+    print(f"Precision: {metrics['precision']:.4f}")
+    print(f"Recall: {metrics['recall']:.4f}")
+    print(f"F1 score: {metrics['f1']:.4f}")
+
+    return metrics, all_labels, all_preds, np.array(all_probs)
+
+
+def plot_roc_curve(all_labels, all_probs, save_path=None):
+    """Plot the ROC curve."""
+    # For binary classification, take the probability of the positive (abnormal) class
+    pos_probs = all_probs[:, 1]
+
+    # Compute the ROC curve
+    fpr, tpr, thresholds = roc_curve(all_labels, pos_probs)
+    roc_auc = auc(fpr, tpr)
+
+    # Plot it
+    plt.figure(figsize=(8, 6))
+    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
+    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.xlabel('False positive rate')
+    plt.ylabel('True positive rate')
+    plt.title('Receiver operating characteristic (ROC) curve')
+    plt.legend(loc='lower right')
+    plt.grid(True)
+
+    if save_path:
+        plt.savefig(save_path)
+
+    plt.show()
+
+    return roc_auc
+
+
+def extract_and_visualize_features(model, data_loader, device, save_path=None):
+    """Extract features and visualize them with t-SNE."""
+    model.eval()
+
+    features_dict = {
+        'diff': [],
+        'wnb': []
+    }
+    all_labels = []
+
+    with torch.no_grad():
+        for batch in data_loader:
+            # Fetch the data and labels
+            diff_imgs = batch['diff_img'].to(device)
+            wnb_imgs = batch['wnb_img'].to(device)
+            labels = batch['label'].cpu().numpy()
+
+            # Extract the features
+            batch_features = model.extract_features(diff_imgs, wnb_imgs)
+
+            # Store features and labels
+            for modality in features_dict:
+                features_dict[modality].extend(batch_features[modality].cpu().numpy())
+            all_labels.extend(labels)
+
+    # Convert to NumPy arrays
+    all_labels = np.array(all_labels)
+
+    # Visualize each modality's features
+    plt.figure(figsize=(15, 5))
+
+    for i, (modality, features) in enumerate(features_dict.items()):
+        features = np.array(features)
+
+        # Reduce dimensionality with t-SNE
+        tsne = TSNE(n_components=2, random_state=42)
+        features_tsne = tsne.fit_transform(features)
+
+        # Plot the t-SNE embedding
+        plt.subplot(1, 3, i+1)
+        for label in np.unique(all_labels):
+            idx = all_labels == label
+            plt.scatter(features_tsne[idx, 0], features_tsne[idx, 1],
+                        label=f"class {label}", alpha=0.7)
+        plt.title(f't-SNE of {modality} features')
+        plt.xlabel('t-SNE dim 1')
+        plt.ylabel('t-SNE dim 2')
+        plt.legend()
+        plt.grid(True)
+
+    # Concatenate the two modalities' features and visualize those as well
+    combined_features = np.concatenate([features_dict['diff'], features_dict['wnb']], axis=1)
+    tsne = TSNE(n_components=2, random_state=42)
+    combined_tsne = tsne.fit_transform(combined_features)
+
+    plt.subplot(1, 3, 3)
+    for label in np.unique(all_labels):
+        idx = all_labels == label
+        plt.scatter(combined_tsne[idx, 0], combined_tsne[idx, 1],
+                    label=f"class {label}", alpha=0.7)
+    plt.title('t-SNE of fused features')
+    plt.xlabel('t-SNE dim 1')
+    plt.ylabel('t-SNE dim 2')
+    plt.legend()
+    plt.grid(True)
+
+    plt.tight_layout()
+
+    if save_path:
+        plt.savefig(save_path)
+
+    plt.show()
+
+
+def compare_models(models_results, model_names, save_path=None):
+    """Compare the performance of several models."""
+    metrics = ['accuracy', 'precision', 'recall', 'f1']
+    values = []
+
+    for result in models_results:
+        values.append([result[metric] for metric in metrics])
+
+    values = np.array(values)
+
+    # Grouped bar chart
+    plt.figure(figsize=(10, 6))
+    x = np.arange(len(metrics))
+    width = 0.8 / len(models_results)
+
+    for i, (name, vals) in enumerate(zip(model_names, values)):
+        plt.bar(x + i * width, vals, width, label=name)
+
+    plt.xlabel('Metric')
+    plt.ylabel('Score')
+    plt.title('Model performance comparison')
+    plt.xticks(x + width * (len(models_results) - 1) / 2, metrics)
+    plt.ylim(0, 1.0)
+    plt.legend()
+    plt.grid(True, axis='y')
+
+    if save_path:
+        plt.savefig(save_path)
+
+    plt.show()
\ No newline at end of file
diff --git a/evaluation/visualization.py b/evaluation/visualization.py
new file mode 100644
index 0000000..e69de29
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..7485c40
--- /dev/null
+++ b/main.py
@@ -0,0 +1,349 @@
+import os
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+import argparse
+import matplotlib.pyplot as plt
+from config import *
+
+from data_preprocessing.data_loader import load_data
+from models.image_models import VisionTransformer
+from models.fusion_model import MultiModalFusionModel, SingleModalModel
+from training.train import train_model, train_single_modal_model
+from training.utils import plot_training_curves, save_results
+from evaluation.evaluate import evaluate_model, plot_roc_curve, extract_and_visualize_features, compare_models
+
+
+def main():
+    # Parse the command-line arguments
+    parser = argparse.ArgumentParser(description='Intelligent leukemia screening system')
+    parser.add_argument('--data_root', type=str, default=DATA_ROOT, help='data root directory')
+    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='output directory')
+    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='batch size')
+    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS, help='number of training epochs')
+    parser.add_argument('--lr', type=float, default=LEARNING_RATE, help='learning rate')
+    parser.add_argument('--weight_decay', type=float, default=WEIGHT_DECAY, help='weight decay')
+    parser.add_argument('--mode', type=str, choices=['train', 'evaluate', 'compare'], default='train', help='run mode')
+    args = parser.parse_args()
+
+    # Make sure the output directory exists
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Load the data
+    print("Loading data...")
+    train_loader, val_loader = load_data(
+        data_root=args.data_root,
+        img_size=IMG_SIZE,
+        batch_size=args.batch_size,
+        num_workers=NUM_WORKERS,
+        train_ratio=TRAIN_RATIO,
+        max_samples_per_class=MAX_SAMPLES_PER_CLASS,
+        normal_class=NORMAL_CLASS,
+        abnormal_classes=ABNORMAL_CLASSES
+    )
+    print(f"Data loaded. Training batches: {len(train_loader)}, validation batches: {len(val_loader)}")
+
+    # Select the device
+    device = DEVICE
+    print(f"Using device: {device}")
+
+    if args.mode == 'train':
+        # Build the multi-modal model
+        print("Building the multi-modal fusion model...")
+        multi_modal_model = MultiModalFusionModel(
+            img_size=IMG_SIZE,
+            patch_size=16,
+            in_channels=3,
+            embed_dim=HIDDEN_DIM,
+            depth=NUM_LAYERS,
+            num_heads=NUM_HEADS,
+            dropout=DROPOUT,
+            num_classes=NUM_CLASSES
+        ).to(device)
+
+        # Define the loss function and optimizer
+        criterion = nn.CrossEntropyLoss()
+        optimizer = optim.AdamW(multi_modal_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+
+        # Train the multi-modal model
+        print("Starting training of the multi-modal fusion model...")
+        multi_modal_model, multi_modal_history = train_model(
+            model=multi_modal_model,
+            train_loader=train_loader,
+            val_loader=val_loader,
+            criterion=criterion,
+            optimizer=optimizer,
+            device=device,
+            num_epochs=args.epochs,
+            save_dir=MODEL_SAVE_DIR,
+            model_name='multi_modal'
+        )
+
+        # Plot the training history
+        plot_training_curves(
+            multi_modal_history['train_losses'],
+            multi_modal_history['val_losses'],
+            multi_modal_history['train_accs'],
+            multi_modal_history['val_accs'],
+            save_path=os.path.join(args.output_dir, 'multi_modal_training_curves.png')
+        )
+
+        # Evaluate the multi-modal model
+        print("\nEvaluating the multi-modal fusion model...")
+        multi_modal_metrics, labels, preds, probs = evaluate_model(
+            model=multi_modal_model,
+            data_loader=val_loader,
+            device=device,
+            class_names=['normal', 'abnormal']
+        )
+
+        # Plot the ROC curve
+        plot_roc_curve(
+            labels,
+            probs,
+            save_path=os.path.join(args.output_dir, 'multi_modal_roc_curve.png')
+        )
+
+        # Extract and visualize the features
+        extract_and_visualize_features(
+            model=multi_modal_model,
+            data_loader=val_loader,
+            device=device,
+            save_path=os.path.join(args.output_dir, 'feature_visualization.png')
+        )
+
+        # Save the results
+        save_results(
+            multi_modal_metrics,
+            os.path.join(args.output_dir, 'multi_modal_results.txt')
+        )
+
+    elif args.mode == 'compare':
+        print("Building the models to compare...")
+
+        # Single-modal model - DIFF
+        diff_model = SingleModalModel(
+            img_size=IMG_SIZE,
+            patch_size=16,
+            in_channels=3,
+            embed_dim=HIDDEN_DIM,
+            depth=NUM_LAYERS,
+            num_heads=NUM_HEADS,
+            dropout=DROPOUT,
+            num_classes=NUM_CLASSES
+        ).to(device)
+
+        # Single-modal model - WNB
+        wnb_model = SingleModalModel(
+            img_size=IMG_SIZE,
+            patch_size=16,
+            in_channels=3,
+            embed_dim=HIDDEN_DIM,
+            depth=NUM_LAYERS,
+            num_heads=NUM_HEADS,
+            dropout=DROPOUT,
+            num_classes=NUM_CLASSES
+        ).to(device)
+
+        # Multi-modal model
+        multi_modal_model = MultiModalFusionModel(
+            img_size=IMG_SIZE,
+            patch_size=16,
+            in_channels=3,
+            embed_dim=HIDDEN_DIM,
+            depth=NUM_LAYERS,
+            num_heads=NUM_HEADS,
+            dropout=DROPOUT,
+            num_classes=NUM_CLASSES
+        ).to(device)
+
+        # Define the loss function
+        criterion = nn.CrossEntropyLoss()
+
+        # Train the DIFF single-modal model
+        print("Training the DIFF-scattergram single-modal model...")
+        diff_optimizer = optim.AdamW(diff_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+        diff_model, diff_history = train_single_modal_model(
+            model=diff_model,
+            train_loader=train_loader,
+            val_loader=val_loader,
+            criterion=criterion,
+            optimizer=diff_optimizer,
+            device=device,
+            num_epochs=args.epochs,
+            save_dir=MODEL_SAVE_DIR,
+            model_name='diff_only',
+            modal_key='diff_img'
+        )
+
+        # Train the WNB single-modal model
+        print("Training the WNB-scattergram single-modal model...")
+        wnb_optimizer = optim.AdamW(wnb_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+        wnb_model, wnb_history = train_single_modal_model(
+            model=wnb_model,
+            train_loader=train_loader,
+            val_loader=val_loader,
+            criterion=criterion,
+            optimizer=wnb_optimizer,
+            device=device,
+            num_epochs=args.epochs,
+            save_dir=MODEL_SAVE_DIR,
+            model_name='wnb_only',
+            modal_key='wnb_img'
+        )
+
+        # Train the multi-modal model
+        print("Training the multi-modal fusion model...")
+        multi_optimizer = optim.AdamW(multi_modal_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+        multi_modal_model, multi_history = train_model(
+            model=multi_modal_model,
+            train_loader=train_loader,
+            val_loader=val_loader,
+            criterion=criterion,
+            optimizer=multi_optimizer,
+            device=device,
+            num_epochs=args.epochs,
+            save_dir=MODEL_SAVE_DIR,
+            model_name='multi_modal'
+        )
+
+        # Evaluate and compare the models
+        print("\nEvaluating the DIFF-scattergram single-modal model...")
+        diff_metrics, _, _, _ = evaluate_model(
+            model=diff_model,
+            data_loader=val_loader,
+            device=device,
+            class_names=['normal', 'abnormal']
+        )
+
+        print("\nEvaluating the WNB-scattergram single-modal model...")
+        wnb_metrics, _, _, _ = evaluate_model(
+            model=wnb_model,
+            data_loader=val_loader,
+            device=device,
+            class_names=['normal', 'abnormal']
+        )
+
+        print("\nEvaluating the multi-modal fusion model...")
+        multi_metrics, _, _, _ = evaluate_model(
+            model=multi_modal_model,
+            data_loader=val_loader,
+            device=device,
+            class_names=['normal', 'abnormal']
+        )
+
+        # Compare the models' performance
+        compare_models(
+            [diff_metrics, wnb_metrics, multi_metrics],
+            ['DIFF scattergram', 'WNB scattergram', 'multi-modal fusion'],
+            save_path=os.path.join(args.output_dir, 'model_comparison.png')
+        )
+
+        # Plot the training curves
+        plt.figure(figsize=(12, 8))
+
+        plt.subplot(2, 2, 1)
+        plt.plot(diff_history['train_losses'], label='DIFF train')
+        plt.plot(wnb_history['train_losses'], label='WNB train')
+        plt.plot(multi_history['train_losses'], label='multi-modal train')
+        plt.title('Training loss')
+        plt.xlabel('Epoch')
+        plt.ylabel('Loss')
+        plt.legend()
+        plt.grid(True)
+
+        plt.subplot(2, 2, 2)
+        plt.plot(diff_history['val_losses'], label='DIFF val')
+        plt.plot(wnb_history['val_losses'], label='WNB val')
+        plt.plot(multi_history['val_losses'], label='multi-modal val')
+        plt.title('Validation loss')
+        plt.xlabel('Epoch')
+        plt.ylabel('Loss')
+        plt.legend()
+        plt.grid(True)
+
+        plt.subplot(2, 2, 3)
+        plt.plot(diff_history['train_accs'], label='DIFF train')
+        plt.plot(wnb_history['train_accs'], label='WNB train')
+        plt.plot(multi_history['train_accs'], label='multi-modal train')
+        plt.title('Training accuracy')
+        plt.xlabel('Epoch')
+        plt.ylabel('Accuracy')
+        plt.legend()
+        plt.grid(True)
+
+        plt.subplot(2, 2, 4)
+        plt.plot(diff_history['val_accs'], label='DIFF val')
+        plt.plot(wnb_history['val_accs'], label='WNB val')
+        plt.plot(multi_history['val_accs'], label='multi-modal val')
+        plt.title('Validation accuracy')
+        plt.xlabel('Epoch')
+        plt.ylabel('Accuracy')
+        plt.legend()
+        plt.grid(True)
+
+        plt.tight_layout()
+        plt.savefig(os.path.join(args.output_dir, 'all_models_training_curves.png'))
+        plt.show()
+
+        # Save the results
+        save_results(diff_metrics, os.path.join(args.output_dir, 'diff_model_results.txt'))
+        save_results(wnb_metrics, os.path.join(args.output_dir, 'wnb_model_results.txt'))
+        save_results(multi_metrics, os.path.join(args.output_dir, 'multi_modal_results.txt'))
+
+    elif args.mode == 'evaluate':
+        # Load the pre-trained multi-modal model
+        print("Loading the pre-trained multi-modal model...")
+        model_path = os.path.join(MODEL_SAVE_DIR, 'multi_modal_best.pth')
+
+        if not os.path.exists(model_path):
+            print(f"Error: pre-trained model {model_path} not found")
+            return
+
+        multi_modal_model = MultiModalFusionModel(
+            img_size=IMG_SIZE,
+            patch_size=16,
+            in_channels=3,
+            embed_dim=HIDDEN_DIM,
+            depth=NUM_LAYERS,
+            num_heads=NUM_HEADS,
+            dropout=DROPOUT,
+            num_classes=NUM_CLASSES
+        ).to(device)
+
+        multi_modal_model.load_state_dict(torch.load(model_path))
+
+        # Evaluate the model
+        print("Evaluating the multi-modal fusion model...")
+        multi_modal_metrics, labels, preds, probs = evaluate_model(
+            model=multi_modal_model,
+            data_loader=val_loader,
+            device=device,
+            class_names=['normal', 'abnormal']
+        )
+
+        # Plot the ROC curve
+        plot_roc_curve(
+            labels,
+            probs,
+            save_path=os.path.join(args.output_dir, 'multi_modal_roc_curve.png')
+        )
+
+        # Extract and visualize the features
+        extract_and_visualize_features(
+            model=multi_modal_model,
+            data_loader=val_loader,
+            device=device,
+            save_path=os.path.join(args.output_dir, 'feature_visualization.png')
+        )
+
+        # Save the results
+        save_results(
+            multi_modal_metrics,
+            os.path.join(args.output_dir, 'multi_modal_results.txt')
+        )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
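For reference, the three entry points wired up by the argparse block above (paths follow the defaults in config.py):

    python main.py --mode train
    python main.py --mode compare   # trains and evaluates DIFF-only, WNB-only, and fusion models
    python main.py --mode evaluate  # expects output/models/multi_modal_best.pth from a prior train run

One caveat worth noting: --mode evaluate calls load_data again, and while train_test_split is pinned with random_state=42, the random.shuffle calls that subsample each class folder are not seeded, so whenever a folder holds more files than its cap, the evaluated split is drawn from a different subsample than the one trained on.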
zqcLm_Y#^~4HkAa&dUPMKqYogE5FmFG%dDezSLcn|^LeVI^1i($?%!JZo9UiJX4rmA zl3VeP|7rtJZp*#p=B|GVTx0kS;X+|YvrA(27M!nQ5v`(Db<}MYlv9cISR+D8rEMx91WDt+Y@^b|xQ`-s3vYJal@xhc#yAgZ}fE}%6`;?Z+$np|H zyI|GCm^kPtf}&m=fcKjn0N$MPv5PdxYM99c5xd?41Lm;6A`1PXVPr z@D(`L@Zv4|GJbPk=5%BMu)bh79@mZGj0Sy}k^unC*j`{$Idjf+1HTQc`M{N{D_35r z{PJS;wVAEkuG4V(V{W_rlsk$wbJKv9t-jFD>->>$N1P6|@AwrdkJx!wh?)V#MSZe-|q zsp-M`3{4+~1K?Sog%8Ew7y2lIUIZIi;Bz>A3_yL*AH+ojn9oi0J}2nAy`XOtdf(aG zD8w8;a0sOadLNM4qEj`y4a0Fhi5w|6w% z9ytKf4#~{TK?i~T5^--#0REoX3kPBQF?tN9n6k`A0NmjihWRszyh&RAO1iF-t~bdi i*ZDq%3BP&>3*I_GeEk`AFB6-k^W$^lZxcWr#{U7-`=;Lj literal 0 HcmV?d00001 diff --git a/models/__pycache__/fusion_model.cpython-38.pyc b/models/__pycache__/fusion_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9562715db92129378af9e81cee3efbb0738619d GIT binary patch literal 2934 zcmcIm>u(fQ6u)<7c4qd`(iTuO5!Wb~)mn*f8-vQDsDZ|SNHW$;hPkut;OuO-p{C{lVK|KsEwW+qx(onTZH||V2O(Ex81<4`fqW|56V?uakz~` zON-u6f~TP25Q8#eNQ@dX>pDfsl;v1{JZmVdz&gf>kzt)|4Sdy8L?MO-hjokfdm=A% z_f(nX4McbxG#xMMhyCi`T$;aJoB!bW&8dm{r@Up z%w%pwWSBD=j)<%uRm>5`Vu8>YuZ~uukmunsr4Eq^qEUwn-85a_4Na4y2}FUp<+ADc zHrSqv{1N|POObqEs9;Qm7Jv&{T8!Z#N58@*xniT#&p;e}I(Kb){`~i?&Q{}9J{x2>aSDfrshutvfFV;_= zT=?;qR7>_*;@tMpixgf7gd+onODewyG=47#kqN8Z9ucx_g?uCSK8O=_{tyo8&F2ZN zac|L)bd?j#!*M0ZM=j zLJB#WlY5}%q_GE92S{VhGJIrUDdHu;2R>*ud>}R${YnD=OVML6Y5v?~{oLpE$?>_L zuB3$6B06Bhg$tMGe*OduJwG*Bl7xaLFl2anNorfNlu2yHABXLDFA97M??bWTv0cvl zA-C+XWu=`kDM)gYgs+Bh`DbkN>Th1T6kLVj!13K2`Il=bFibeoEXQqAYSNLDe5>M^ zrpTEXz>$aTf@vO(EUz)5nI@~+KwB8sdzj~qm!_?xHvv+E=+ovvMyE{l-3b2jRNIGhgbfiI!0554A&6Jus zNfBBjJMazIe?JNgMqWgbSUE9zVrq2GUZ?>^^6I2JcHc_tv>6r8YZinLdTu+^0FQ1( z3?#}1qXL_P!D;jz47X_MZn-j{y_?rRtAF$@{FA7kIQRe49t7x%aOITLl$6wDxQEc2 zoW-BS@iv+h>muqI)Bw>;@(BlRxW$?c`nih;)cDh|A4ll$XHh(dq8(_@quRbp;`8J} zFTlW?PymWldfxaC-+=S7 literal 0 HcmV?d00001 diff --git a/models/__pycache__/image_models.cpython-312.pyc b/models/__pycache__/image_models.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef23ab212393f90bc6e146db132f23136c18f452 GIT binary patch literal 3758 zcmb_ee{2)?760D(Y~ML{NK%KyiBkflXh*wKl2&Q|*nmGe{zTe8SjiUeMpdTaw*IeDLWla#_TKk= zc4DG-n)WPzf4=YgzI)&IKJWcG6bcZazn=Zqz&kMdCr(`QRtBr*VK7ZJqH!q_b9ye# z$9N9Me2PnZVxF`R6WE+fdDDCx_R_wXzp^5xrI^GKo*X2a=M2#Voll0{Yij+x#~H^O z7^fFFeP6*D9N}bt8ISeKyxnq?s)pI0p=q5yZ^Sd2PVKh)*AA;AI*r%vC8TmOS=|k* zzl4v~NsQA-jMuoBM-%d-pVRm=QcOVgW8OSD#l?It6WMDwzNA|5!RON_bWKYb16dga zlorpH-u|HUi-}VHY-#-cJ6Em^BU9<>A3pul+n-)Kr@P#7*Tkrc5B{jf>RZtKnwusT z^bU=CbL36nov+L}lH+nj;obcl4>R{0aE~8_J&hX`34tBK<+B>xQ8MU`YJJX=6F?Hb zz|U7sz*)b$KSE?rZyB!yLpKcR)X-CknHbe=pJ~yArpu!3HM7HDa#2wdM#54QTTG+} z9Hsy>fisp2h0S!+<_*L4J(n>~ep9oBVVX(GK3g#LR6h;DO|%gjyCGpH@j=yq3r$;M z##6LJ+Ljb0o>EOyQB1U5H`3S_d$cF}R6~tt3{CAr7wAjr*&#ihQ2Wv_OPPI%v^t?pQ(=Y?~jV#iMCwiG*dTuc@^ z_ADMLL>`BQ#yjDy(}$)G%?!;waeepAA1TY>ALoTPrB#qa;?-w^2^yL``Z*X(!_y_9 zuP2R*z|v`W6Lp5Cp>58w4E?D} zwb8A028XD{^=4o)pMhqK+-}`EeQfI3%#o>BK2QucO&l6OG^xIEIL{S>p^cLl#$TA> zX1=i$?1IfLO?i>Fg9uqD--fQ-gbiv@Mq6L@qk|f1S_cOWDCh{;6pn+6@{_EZs_gg` zMa#q$g}DdqhSPSa;&5?2l~R<;9Nh}5Cfeoz`#6ej$Qm@UiJx-s4EdX|_b%`As@&&I zK5sKqjPb-_yGp1X(rOPh^#CLh2{7|ZK&bhZgqg?~>yR${D99;YxOV5-uhy-8a{i11 zOrMNjEq(Cwk0&Ne6TkU*?o#QuZ^ZGPUBWSHv^oYZPzOa)Jrr40k3--=5LpKw?sowU zCDA%{lw5ynW$(WSENhD^E{RWEM;bunX7T|RffKZ>=;p=-+>61^=*xlftL}xHH2E(`nJduy^>xoOi+JMkr{C=*VcDrcrmT)`K1u zY=ARWYgtuIlHE04ca(g{9o2fUBnEvQY5oP#)rlkjlhjSBfQ?ci?MS_}18y!OR@%aV zp2{ADEaBH22)Dh)pss567Dx$7zpiT)AOQxN8O~%aJD|p`#7WhH2(cR+Tzm zU}V!u3d;d=;Ai1?5B_8Ht^Kg{+fqDbDpqDlHz9Cq!&sTTQFFQDbx;G 
zG;72aTv1bs9AE%j%Z(K6M_W;vSJQ{BL0iHvLW7v}FSuI8x1n=v?MKXdn;XspY{nABpas``Y3Ut~cB`zSMavFD#4g#YopmL@q?+mB`*g zWG^N};OmH9h|WfTr~iKF($M1e#iM_G^}ScGtA%eJyb)R2_3~2WDDvHkbZ+1}az$F& z_3TpQxr&g;o?^JG*wR^yK0L?GzE})*7Tdat?a{T6KN84?)|yDPXD)riyWH8!RDpY2 z+olgs9ljV@3O@ouBv_FeCj8_66|t)zcFkxPj9FuG=S{I|S=@I^3~ajf7jYj0?hUyc zaStDI6$;<~NA4y`^o=@7Zj!BYyaqoPzfOCsDsCmKkzEr5h|CS}nrA_9V-+(zL^M8m z9U(HpE=g9gqQS0`UEzZ+-;wJ zYFITi`W>8j;;KcT#=ZT}$b#eK58}c$Z19AMwn2||+_f8KqmldFwG#YK;)&OwsjDJB z^N~>VN-$aoMrV#J1-IonSZR$+_fGY$gylk5p4+{6;>wPt@Y8weR@`x)Yi*B1;)x zV&MEN($)R+VcdKKn)Ri{*KrXWghCSw4^~<{UOe#K;{IpvdZONMfBsb1tuK&K4ms8p z87TD_27@aGJq(2G2i%RGe!v)nM>S;!){VLE+bxXeBr7j^oRZo0x&4c*a&oF3yzw$| zsvAsiUx-gRh1E0k6_{ccGWS9AnTO-JzmwpHr1c|m;3Kl_CfW8Od1y^&;+iLaIDLHT Y_!@!X+OEgA<{AEiG%MXDFl2)N1E@Ad>i_@% literal 0 HcmV?d00001 diff --git a/models/__pycache__/image_models.cpython-38.pyc b/models/__pycache__/image_models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ac8e52672580c5fa0ce170d1ebec73414637c2e GIT binary patch literal 2426 zcmZuy&5smC6tAl8>FMtI0CtfLNc=k4YBuV@7$jmskc)0eL=T;fou;d1c82~Kt9uZ} z*@M{?B)a0kma1twhZz3LryaJMG?=GCk3SHIt{K5R5x z0_D$(|F(Yt?;otp785G(z?XNyaKdR!{PLeB%x4sR7SqJ=jl}d#U8iv+VT@3cSblXY zVkeI8P;!!RgPZGwn}T)SHN&q#q{1zTSX+@grymfnI)U}blE>8CX&I(@E0c+kC(9YLPj`fw&3=$?j_Qq+~U=B$2Ymnops_@E)kRX7A#*= z%~N3!t(-{C2+licyYImS;~N*p*KUv3u8hCDIKKSb-9K*iu+#YF?+@-=d$9SnnD(Ch z1-&PG(HLk5*r)>Dw4iUnhpqWhFa=p7Bf1Gm8`2>O*pPuWz-Nab>GK3?W>M){Lt|({ z8{=%lgjkgm*sZZNVK4_It9Dk%RK!8v83<+Nh3s(Q*{YKFdyqCe2s&w}2m)nyl6Kk2 z)ki3OCk=E>+(wN4sDD`5%?=E~8B4|pfJ zPL&e`Q5@!Z5L_n@U*$)Z!!(Stl!wcBf6KAxpA|_bTuz`C=gXZWY>S|*EcPzQ2F&ij zmuqZNms;>^vcbZWcd|G=PK2kQ7ECE#`*Ha4MKD9!g%lnntB}$W9nwv@3U&l3-0}>- zB->;dg=CsblENs784))FOEn0176pM$r+gj;s>as6L?Vu%VS9rMI5+N6V(M_E9M`m|Nt?Eg zMNoaBU?(g)snyWba=5FhRpa$1sMS>aKI-J1ES=$S{}q^M{L`(wx4xN)-oLtD((V4` zo8#M`-@Eed_{#V9e%>7a_|XwiUZToL#I$lwg%UtnfYMZoOhzlJG7vHY zZ2*!4y0>zoI1h^KtVku^gGA=4nr2cd@ygLkz65%&%v4*{EY&DvKaB#6gmK_GnsM7= zB%i|B4&ZDRE_%g^a_~f$kX*io;f5$5kMfjKBVH3g!yYB)Fnsw7UlA3n{GC=wonpRU^`6BLhVpbjVmW3bR`QFw!0_s#6Vvx0ieY&smFsx|QlxDb;K82#N*^ULpX(vn3u+zJcZ~Fv=3=dtu7u zn^-ZTu#iWv4Id+r9n>@VESl%g;D{Ny&?!Sj!67*!08@rZn*$5~Y;bUAf~KNY-Hp8x zA6SXcw0s#2-owm_+XD3>G~oJrSrjOEGt2U*`HaL?uTdV#)*xI=Qw>iIH4Qp{e%i3u z0!$$t!dSU8K7H4p(|yiqrCKbXCS^_k-`9$(by2T4v(d+pQufhOnBT_zuhfRKN~z`7ytkO literal 0 HcmV?d00001 diff --git a/models/fusion_model.py b/models/fusion_model.py new file mode 100644 index 0000000..10d1043 --- /dev/null +++ b/models/fusion_model.py @@ -0,0 +1,116 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from models.image_models import VisionTransformer + +class MultiModalFusionModel(nn.Module): + """多模态融合模型,融合DIFF和WNB散点图的特征""" + def __init__(self, img_size=224, patch_size=16, in_channels=3, + embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2): + super().__init__() + + # DIFF散点图特征提取器 + self.diff_encoder = VisionTransformer( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + dropout=dropout + ) + + # WNB散点图特征提取器 + self.wnb_encoder = VisionTransformer( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + dropout=dropout + ) + + # 特征融合层 + self.fusion = nn.Sequential( + nn.Linear(embed_dim * 2, embed_dim), + nn.LayerNorm(embed_dim), + nn.GELU(), + nn.Dropout(dropout) + ) + + # 分类头 + self.classifier = nn.Linear(embed_dim, num_classes) + + def forward(self, diff_img, wnb_img): + """ + 前向传播 + + Args: + 
+            diff_img: DIFF scattergram [B, C, H, W]
+            wnb_img: WNB scattergram [B, C, H, W]
+
+        Returns:
+            logits: classification logits [B, num_classes]
+        """
+        # Extract the features
+        diff_features = self.diff_encoder(diff_img)  # [B, E]
+        wnb_features = self.wnb_encoder(wnb_img)  # [B, E]
+
+        # Fuse the features
+        combined_features = torch.cat([diff_features, wnb_features], dim=1)  # [B, 2*E]
+        fused_features = self.fusion(combined_features)  # [B, E]
+
+        # Classify
+        logits = self.classifier(fused_features)  # [B, num_classes]
+
+        return logits
+
+    def extract_features(self, diff_img, wnb_img):
+        """Extract each modality's features, for analysis."""
+        diff_features = self.diff_encoder(diff_img)
+        wnb_features = self.wnb_encoder(wnb_img)
+
+        return {
+            'diff': diff_features,
+            'wnb': wnb_features
+        }
+
+
+class SingleModalModel(nn.Module):
+    """Single-modal model, used for the comparison experiments."""
+    def __init__(self, img_size=224, patch_size=16, in_channels=3,
+                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
+        super().__init__()
+
+        # Image feature extractor
+        self.encoder = VisionTransformer(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=embed_dim,
+            depth=depth,
+            num_heads=num_heads,
+            dropout=dropout
+        )
+
+        # Classification head
+        self.classifier = nn.Linear(embed_dim, num_classes)
+
+    def forward(self, img):
+        """
+        Forward pass.
+
+        Args:
+            img: input image [B, C, H, W]
+
+        Returns:
+            logits: classification logits [B, num_classes]
+        """
+        # Extract the features
+        features = self.encoder(img)  # [B, E]
+
+        # Classify
+        logits = self.classifier(features)  # [B, num_classes]
+
+        return logits
\ No newline at end of file
diff --git a/models/image_models.py b/models/image_models.py
new file mode 100644
index 0000000..e17c749
--- /dev/null
+++ b/models/image_models.py
@@ -0,0 +1,78 @@
+import torch
+import torch.nn as nn
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+import torch.nn.functional as F
+import torchvision.models as models
+
+class PatchEmbedding(nn.Module):
+    """Split an image into patches and embed them."""
+    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
+        super().__init__()
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.n_patches = (img_size // patch_size) ** 2
+
+        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        # x: [B, C, H, W]
+        batch_size = x.shape[0]
+        x = self.proj(x)  # [B, E, H/P, W/P]
+        x = x.flatten(2)  # [B, E, (H/P)*(W/P)]
+        x = x.transpose(1, 2)  # [B, (H/P)*(W/P), E]
+        return x
+
+
+class VisionTransformer(nn.Module):
+    """Transformer-based image feature extractor."""
+    def __init__(self, img_size=224, patch_size=16, in_channels=3,
+                 embed_dim=768, depth=6, num_heads=12, dropout=0.1):
+        super().__init__()
+
+        # Patch embedding
+        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
+        self.n_patches = self.patch_embed.n_patches
+
+        # Position embedding
+        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # Transformer encoder
+        encoder_layer = TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=num_heads,
+            dim_feedforward=embed_dim * 4,
+            dropout=dropout,
+            activation='gelu',
+            batch_first=True
+        )
+        self.transformer = TransformerEncoder(encoder_layer, num_layers=depth)
+
+        # Layer normalization
+        self.norm = nn.LayerNorm(embed_dim)
+
+        # Initialization
+        nn.init.trunc_normal_(self.pos_embed, std=0.02)
+        nn.init.trunc_normal_(self.cls_token, std=0.02)
+
+    def forward(self, x):
+        # x: [B, C, H, W]
+        batch_size = x.shape[0]
+
+        # Patch embedding: [B, N, E]
+        x = self.patch_embed(x)
+
+        # Prepend the CLS token
+        cls_token = self.cls_token.expand(batch_size, -1, -1)  # [B, 1, E]
+        x = torch.cat([cls_token, x], dim=1)  # [B, N+1, E]
+
+        # Add the position embedding
+        x = x + self.pos_embed
+
+        # Transformer encoder
+        x = self.transformer(x)
+
+        # Take the CLS token as the whole-image feature
+        x = x[:, 0]  # [B, E]
+
+        return x
\ No newline at end of file
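A quick shape check of the two modules above, useful as a smoke test before a full run (a sketch; the sizes mirror config.py and the patch_size=16 literal passed in main.py):

    import torch
    from models.fusion_model import MultiModalFusionModel

    model = MultiModalFusionModel(img_size=224, patch_size=16, embed_dim=768,
                                  depth=6, num_heads=12, num_classes=2)
    diff = torch.randn(2, 3, 224, 224)      # [B, C, H, W]
    wnb = torch.randn(2, 3, 224, 224)
    logits = model(diff, wnb)
    assert logits.shape == (2, 2)           # [B, num_classes]
    feats = model.extract_features(diff, wnb)
    assert feats['diff'].shape == (2, 768)  # CLS feature; (224/16)^2 = 196 patches + 1 CLS token

One detail worth flagging: VisionTransformer defines self.norm but never applies it, so the encoder output is unnormalized; the conventional ViT forward would run x = self.norm(x) before slicing out the CLS token.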
z8yO>`;YHbmX=Vy0B3&p=Zz5w z#%G-N44ieFSgH4vXE<+w9iYuBM)IbYGsm=riwceohQQ*|h??js6o`=M8WcE%XoMuJ z+8g3ys~&s^Dz}vr9K&Vhf68C^-A7@(e?I-hwL8 zUoe_^@96O$4SSunDwv=SpaP^<#MDheloAhxH{;9{8WL^TpB z=R+7lZ-d|(_yvDnw4}Xn_rA0DNlSXmXrAcJ8B7y0?>zvXm9guC*hw6O1aT?unmC+d3n=AY`cpdpVdkuZ(O53u>*J**Co}zC7_}#?lqp zv0$&6IXiteZcj9%?d_AhBU>ZmIcHsb*=&8*u`1oKc6tEo+CYGZ?Q{#pN};cQ)Js;)C#*PYQkm(y6YnwpfRCidFg;n~BN-pJN=q-s0T zwVfGF*RvCUsp)#qMcCGSA%T8;xSr70#rz5P{Nc3j8xy;7wwf9Lw0~k(Wary2<*I7u zmd`GSKyONFzoj;otaeVL=mjr%P~9 zf4_9hREE@iF_f55UYGFuT37`*LJdC+{Iu}XQPQwJR6MHru}dl(zt4var9&Gd=XZS$ zQJeroVGI?mDJ5f!l>8i&1NBlc@=L_~#p5iru%uVCeGpqb)&V{cF%m|~$gUN7m9yI> z*m>2$c9E|OcPte+0y5^3MyPlMg!hQyJ+f;B6n6whJ8TJC!#1#mLZw(rrGTXu(y*2^ z%37&pOqA-9lv1OOxFmf>UuYN$r=?}AOK9mB8>Ppcm(w*s$XE=SAY|r4?-xf1dd14< zPi}|p&(Bz4v{8j~e%!q)9AG12DXR+BZwZ_LZ?aLA%d+2j7F$cRn4hyx%otmtcgl{k zfFD(ig&|7%j7sPL92A4CKU6j@_Oe;(c)y5w(aMg9q{Nq76^!8=>=a>pS-DQePF0@ak2@M1xydn4$mbn_v#b<+C+_j}c&t$Ez{9+Y0}8BY)nm&_0rv>8 zMBx?Ro4fP3KicSeEa_N1_y{ewE1z)Z-S6DH^22+7_U=X(r(QrQ1@*jUKCWC75l2Du z!n=3lkon%<|7*US=UF9~{sCVQPUVkqk7{?_nYw=G()ZyV&$Dt!0!Z*!*6!Nb^*k#c z?7AvrW!x#4HE>7#{$L0Uc|YYF2(gypp}ss;flvgH5UWN{0U1T-OCfV|K-Q|p8oHDzVLLD5OUp}xq9dA%MGqOmlOA{Uw1vijl%`Z z3D?<>a`&UZ$@khLT$qIp{NaVWk!XYK?!-^-{YB#L%s*^&-ADWgJ>5O{2yv9_PUJiH z|9tMMyu@B0FTq2qygyF|Ai2Py8IERPH8>&q`R1xZqakQ8&>UXCf8x;~y$4E1_~r|S zAX~t5(o6a1AS?Bc`bKEnk%5f`ZB|BmMuy-qM*BuW9NWOkJ!4}4`Z&Nua{y8r^qlgt zsstj{?v^9?I?D&>@X$fqotZhYg$2hX9dP0i5jTh8HS`FMxri zSyg_dM?6qBJ=ey31Ws~RS9&XQ#WYwfg7%^Cqd19T7z9^U8dzj#A>GGC@Z=EiRLWw9 z=KB-5Ao?igJ%%CxB8bfv0+l6TSzu|n+R^WT`f<#631Gzp@sK1mx~X-O>!x}pdt$8_ zb<+Z=no>`yqq~7IS+F@~hNg#RhNp+)>*rMo=EjEP${lIj&WV=*_t}~fFU{8xwm$(PW!(PCGAFPxRgf0_PmiHxbzasg05`YwXQgJCHZV z_{Mm9QnxBYc0xkr#05QMYiLjGnBRQU`RShIp`&Tnu@qSwQAPLUNK2m5iS=g4Ng@A*0JXW*cGvq5+O7=iLmz6)(IdZ5H{_}tBYPJ#Mnt}VoO8|YZ1d_=^Xhc-+FWZ_wsmu= zb#uD42h?))Z4cCPhwh<?hDd-P1)iruJoX4TCZbu+v+`ciyHqT$M}_=aTr)}-l$jAq-v zYmAG#M1-m7;a-r1;f$S28Ql}T3sPlL(-Qa3o2~@rw)#AD|M!% z&dV)xZL@9h{&dxvjCAcSwISzd&e>e}Z>h=qA7~Y3`68jvppePwZkVJ4XhL|0ehwg0 zVCrUm8*u&O0?4W-0J5qKAU^?)0U+xEARE3M6aWZZyw!LDAWM(_6F>x08AOl)B>aj1 z8K6TEBJ)E~B4Gd|6#;1xu$u+qEM)*#L@BSAZh@?fl>_r& zwgcQ#e+fMYNBh}}`J51zFiyt7gSK!L084?J7QQV2d8(EIo{%nGill=6DFQFQ5Xf>? 
z+&EQ-YnU3w!O?_hbSfNQFaPq21PJnJ;abM*cR!sDEm`q&ET4;U_>r%uaFmAcJHE%D z#TLQRYrm$_t$t)nDjc+^vmy$DP!oYzN?UQ`v=}#ncI`n0nX+T<~isn?Y-zT`otzmkya?N>^HA-1-6{>i=VXmDTFHR0PWch7RyaPbY+g19mR$4umkbBoE{@FBQ}9=| zQ)mU(Po2=JEvu(qFjxK3xkBIIJAiR94SXHKjb+x)AiRrEMmR2;p0Sre+B1j~Gb7jz zt~3iHSX)ecMrl^2k+PvhtdRy29CsOgjv~CUNVf!uF87BVWKpeyPr;luWbsL ziGPAs18Z2Ia18La<~)g0dEo-TC|*YK3W`@z96<3+6tAH;h~f~6 z*HIis(U0O!PKavgfD2g!@Z=%4X0!^a`qChT@M(A7m zFoe5_qOmh*8wxbH9N)0?r4)^~n6wMUIEr;B)}z>n0ylOVS(f}o)hYCyMsWtkSrq3$ zxSN-zC+N3OlR*&%!7AaP4-Csm?^*ghMt%#$v-k&k0#lKApx;LErzj9(ayMP@K}%pL zSaHi$alJ)<8?`x7q6dAr9hcD)%@BMRJy9N?zXNmR36CI9Cd2Rr9oLXNZxMlVK7wD+ z4#V|F%v->r7cSy}K)Ik@ur^#g7Evu&To?C8l(+Tf=-&95jJ`7>h1)r}lN0#DW&5+Q z-?m>oS$X#c@6y_`n);NcKBH;KX$*LccHt$s+M3!w2`Bs-pf((}U-dGMzfq&+e6zkw zClhNw-gI@-{F~Q$GMbI0TeHJ+$7YYEYgZ;tef-wdw{CP@gWI|utdKE(=azFc&go`# zIlF5nJROElTqPS(XM^rhB4LVh8nmRgp^u7KEgiX12!3npVi8EQ}`lNGxRCep> z`^~5n_>|b`*}AN=E#<^JPiJ=&KOw4i#if^4MDhC}M+014E=$|HqEhI;nufUzvm4@v z5}tH*=QJ4=Mc3qN8{>ns{n?u5QZ>(|Yu03IHl}JerfW7w@$(>~b4EX{k4ryPUQvFa zP8|IBjjM0mFkU;BF>KBm>@&J)UF_`K`PuX7`c?C`f2{sd^^I44T$eFy!K~F;Lqp2Y z5MTCT+m*Ht+Ov(TQ;n=N0D%)!BybR6}>VVf8N!YaXm1 zsy2Kffi76=B$Va?A=WzAKHJVwWCgNoPCcv6lVtW~z}L9bw#^ZBk^aI;B?gjp-O29l zHxK{p=qE>$2M#8;A4=+8&ya@;MA#fN%OtiX>pGL2FWg-BvrV6D`ZSn);XqRN%?$b4 z{|zNJwP|uwtmYT$7T~+)HqCDOlGGTg{toom&#FJE{`5d{%PUFKs~OFK0!KE(Of&J< zaF+gdo@}24d9qhQ{vmj>FV2c!*tzNdIH-jU^5^m$#K9KXFI!{>S1H|+`^cj`COcMf z+h~w8C2lD_39t{D~6vsAL1@ESQ+==j29%1bPwjHt~+}kZ(AcH>8}I#qY@S=B^8Uh$uNqn#TS8D#5XSoGExrZiX6<4|{v) zi;#q1CWuI5Q6dtFJ|i^0BCNk6sy`!4pAky%e@3YPgE*Qdj{Zj4Dysj@K~bBib@7OZ e*dp2``dqSGB$6+_E+z&=W1@$W7saAHz4*VpXn#fk literal 0 HcmV?d00001 diff --git a/training/__pycache__/train.cpython-38.pyc b/training/__pycache__/train.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41ece40717462206367f2b3a83851a2baa3eecca GIT binary patch literal 5360 zcmb7I`)?dq6`tqr?Ck8*YkQqI&N@j8rfK?6%OkYVmbR2f3w_bDBrM~dv9r#;a%Y?- z;Vz_d)K(&;6#@#4)2>m4ny4u+300KG|G%rIr#J0Zx|(h|Ri~AyW?I>57WIZ_o>HngFX1Ke%zFhd z?K!8^YQZadIWG^X=&6q@Gb5KU6XlT^RT$wBuQg)`=iZ|}bLahs{m^G3^MK2m#~%+n z-EN~jFG}5JC#;3cZM0FW)fd@OKM=Wkr`25y{aVWpS)(4*O?qj374S^r{sQjc?*Oj! zl@&!*LUl>$nGzGB##K-Cv=_Brk|*z14*sEMam&*ul)lc56}6W_$rzSwl+0l%%@v;F z_LA0fxW*lxhLquY-=I-=rf=e*zT)DLYWhT3!3de0B)!txq(bG1fmUL>Ds&q~2;y0d=sl8VNR;G8Cm zW+w=Q?$+xwd11;g!E~zE0745pB6Y%No!UHeJz>qd^~2A&%o9cs`dzf|u=*jPx7=gG zoNIR6P#7)WZHvTW*K&?xp>gAcFVcY@*1DY_oNFAzfJf(N3c_r4JijR%Ic^Ov z@L(gUIzuzj)Yh zHQWOOLyd%e-Qyx1&0Ky2^3`B3K+!BUp(b=kv$U*=pKjx4src!6HLs7UW4fW@_aED2 zx1xQW9C>h+Af#|h1HTWS=?VikA~T;?dI{V~+*YXd^%do%HMoE=FngH%Ut#1aY4Rr5 zVe&RgWb#odjS`uBRB~YQc5LVx&u|AOpAEHz1kaGm!;NwnB}pY|==olO7q|mMw_xZn z_Fj=2eVb?dX&AcGcRVXLbQ9X*RoX-!WAwA=n_E%H%y|-;!W9||n%q{+kxxFQ+^IaO zJf=LYgn6FeNp3Cay%BCK6!?gjx=VqpkYqz5oh;4iaGUg^gB`te{`7mldUoy2mDSgN zAU!5hFbj2zz8*^xP*auxAd2Z;qa|G>vKtF_vdG-}@tLA9{G)C&@|7_(7Rh$xCDL

-#CrZKXfY|6YY1o>r=C_!BRvnmqQ7x)dx}(~vt+CIe*Cpz^ zY3;j3MuI@DwNB-oGvL9q+KSdog{G&G z3&4DBI7%9~gFAye>lrZV1jX=PK1}k&l6t@LvU>2LULgb*@Pd~ZSMVm(oKemiOEiMz zS@M~Jv#1QL(V`NkuxG5(a_FZG7jjyQg5_32BWumNw&*E8kOxV zFU5QVNmcok1;*mA^Hq4zXOT=IHLXK5YKR(>4@>!kQ zsMy6LzdTZn_Lk?%N^czPCiu9wZM))a$0{eImY7E#Gb->DLIuqg3<AaT{mL9$aObUwOKz1_jSWVm&j|uVpi7;8-5$~U_d&wJ}{omr2PQ#V@3M&rh3a} z>E>%(8(DOShj{Vz)i-{#zp|CCiy|sGgHH>k$NT3Q{ylL*OJb2sP zLYuqooGDDnGu6EGm0A$O|7u>N9#%_3p4_4pNeUT=SRzeA?8%!MxqGCi9o!&azxbE4 zs~3K@_S0Xlzj|tI`Q*hv{Vv++fBfZY}q) z?=@H;jD=34&4_%FuA;RtncJSnqstoYP*SrD26`iKkNU#Cy#2!VDXYcS zce)}+Sp#<_P8ia;gkzHde@NeG@8l_1O zFve5J#~f}VR~tuIP7h0EgvRkDwKoAWnc!vQX4??HGx8X+jlwvBr$Lz%M#&?fPQ+Q- z_TD6)D=q+jCO`lAT_<^GuGAoNl=FG0KGD(z^0oR2J&d1;~SW#Z$= z8=$ZW(#9kgS3 znB7nCMS_P2C^ctaBKR@^F-Mfb%fuZSqQZ#J*=+=0A^0l6!vv2Ie2w7i1dkFtM({Yn z69i8Ze1o7$@GXJ^1P2MeO;98F4uMNR;cw;!)*zXfhAj{rCTJ26!LSxVo1jBbAV7GE z4%+g#z(%M@$3I59BFW~XPY!GcwWCu4>k`~Sa3{g12tG~l8G;1C5duaK5QGGafSElX z%}wkmwS0!)7(k?C=&c?0kF(>XJV8(=xCZ&K?@|S^4&wwpf*S~K1mF*X)_m}xqsWKO zAnbe8S~AjSNur$L`rK5G=xi%V5LNii3+NV{2YeJ!$#c8rsPf#dTh~3cJF0{8JjjcV zJA=cHY=8NnQ^9*2ykSzE1mQVoJs>B?W`f3DZn<6WB4&=ITF>Ex0$T+3x(%)~9 z3EfH(9or-qGn(I)jK literal 0 HcmV?d00001 diff --git a/training/__pycache__/utils.cpython-312.pyc b/training/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea34919aa9e567193411a34f4e227046d338b4e8 GIT binary patch literal 6613 zcmb^#X>c3Wd3Uv|%aY~9b`lb73t~bQht4KWcgZItG)!*m*l9Hi|)+y)x>({`qR9AXb*lTO=dWXHdnU^ zy_t{R&J2%0-NzPfhr5~JpWQ9NbQh0i5l>{lEE8jnv#lOh7l@)N+HqFrRVC1&;vDx^ zYvhO9q>u!slchFaSXSBsO0-7_`lU9FNdAyN)Y%r*`~kH&64R}s2tAr8#%qg-Z_ah2 zhNu+)`kA!VvClqiA92hh9p%Y?1V%tx*aRaY5-qfqKMrl>m}_9-!D)#x9UTevxcF7Pw zVke6L#g``l8mRbG65~+W&mr!`kcTJL@FBGbmLRCXUK0V^X6TU4hE$y8cDpI}baON! zOPUg|Ucf-36W3Bvp2$I97H|+0!1|gsxiNLeo0~6KSB%ge`)l^3H0&C&r>mFlfAGx* zQ}Sr{f$j^{tAW2(%3sjbIb;b~2&h{~3M!(g+eI-LmZJfzJ4EpZQ7Moev52A^_K6~? 
zf;m*c- zo58r*ym9kj;)ZEEACpKR_N*3;M8JNcjj-6vz9d@m?n5)D-nn}E)aD@?TCIa08pJ}ZcrU(x^_8y?CF|i@Uf+r;>VF*UJoBvdU*q= zitA=mG{Xv%UweQr$gc{f5BSmz-KC{}c>+0>?*=TKSyIqUb1?>;D9$Oj+AYiuR_g(o zkTz=8NL@*h*{oUT_tsGPXhAu3mzi@A;vtOCNSwnpz6Y7Z{E4haupsktv&+xyX#*;) zG7lWd?-SQ&5torFz2moRrtg5;IoRJv#j8WW}B7`jtbQ_I^L@sQS#|zTj|=EgQdo z+R>7BE>7;)FTN>`X~*J6;}>dMK6SPL(!TAlZBMPg;HVqfe%Y~TvS!WrruViV-#)%> zV%f)ye`xxo>EnjU#ZON;o=Ll^_a%lCDQ&dxK;NgXd((HWO7cH(T{AINP1np!L*q@m zPQk7#ZXen-vi)W|umJHWgZU!>_A^@EOOy>6TzB3Z0D89MQit+dRIrI6eTBeRPAg^Z zAf!HMi?qfc7798F`XRJN(HqabyZ7pQyJtp5Ge7_3+<}oRAHCsT9XB;U&=D8hO-<|E z(Zzt|X?9;Ze>ih0H8XxRbL92r=4S9&f)*yeE=XuVUv#^wL3k2nzfYq+h^&EWRH1TV z4ZRF$0^zCN5wPGPBuO?QTwrVBKF7XiSXo;;$-DZ(^iQ}tkq+hLqSokk2?#R-W<_9 z7J|tjLBB}q5MX9v#JBsfhUAh41elK)A?^;WAwYXmF{opRYo6!Kg8PfAN}+Dv%?a(S zF$3{@#d60w)tPpH-=s4#dy*?iu}HsoSf!-MCK=>^0@sGZ7HebP*BE066O1gJ%;)vr zh|S8DTVP|HFDbIMEX`#+scb8YiMeA*k?mzMF&vhZoJB1-6J{(WMXo5j;;lH#|Hqu= zWXCOdRZb2j*?9}>a-6L!OY{H6Sy$OL|Nopt&r(tfvP$rzwn)FxRh5)PQOqsE#C*D> z6jqH@c@{<2!tm_5RA%S{BN>Fdp5X;&e>Pw^!YdPRWCr)l9z8w##^B7sU(NpZyoZ}Z zC7eTRgg z7EpBaE=sI!j%B|ArBeyXI^RkBvhL9Qon4w3kYbQCzzX2ftfPRgs~Xt7h*xHAs_~yZwsrUa(V9Nh=^d1b_8k0Q#9L zyk&6P?rq85X}$sLPwajo$tGKp(iFe!GPJks-ZIoSRzJ-SgsI5S6)?ZnE5#~24w?}e{$u9s*_pF@ zW=D_DrQQb*fy|o2^OFYw?9f?R7d8Z*tOTy140#cNZbA8n+a;o4R|kSl1YLA6sKE29 zPgPOVsn^asZkI0{j6~rIkVPRqfZ~V0ikbjQ;Hpb?jCEXeHT7?~T-7jU9yd=|C!czH zs_L2kt(TotL(Z(8a(=h}@yoVF$yKRU?T3`&$ifh~YO^kT5FXUqEGnfI*BoZ+Jh`hj)C$D=E25VH5KrTmjl(L{`-&@Z!Hh4G3Uv5Wu!{ApkE$|@P%dyN0kd0%{F4GH4FY6KL1GI6$ znmzaS%#lNxp))f-9h^CI65y3{ALwjvoO{@;3Y=>^@QDV9P@fe&?Nu+hq7@VC&dVvvf_^VF=K$CS2b)`%fZMR=z zIKeZq`o;+`8O zm~#UsUPtg1eU=_syWV2@#3HP(vKWt`Tx+YRmbAcrx)6#6;fb3IMf64R!jE1h;BSXl zZ>&OpbKEzf=-2~v8iDlVKjN=ls_|HtkUl?VYQ7$pHml)d(-o?U!0o-W3 eoo#2YGYIF`tzcdAkGKFdR(yS zn;zQrTs2pBs!n~VI#kbB^Yuctpq{a?SRGdVOV!dhHEwWoMdRivy*h%H#ci}~v_|;| z&v9pkRrl~wKE(6T#`qpy;6-SA`7kf7Xw`k39o5QXTkuJnC|k1R`F<)qf6(fuvIgB}+dfLVj{EDl<03>$Th}%;0`jPKolTUPv_udG$#o{pI0&b&Yl6Cc@L7I% z(u+JyihGlOg9npgkj@A7nm3sUuNKv!naMP%g>j|1C>_^@C2?J1Kuq7$*wX&(PFMOp zKH9rS@D^t+&CMiJI`DzTAaBe1grMn^O=&E6p)|aO8EM36U0Qx4jS@j0SI%c*n3_Wn zq{I${#%%Rl8r$KTn*SCJ?Nf;Q7wIN$1>KjRw3Kr#80Q#teV_I$uytXDMQ`bHAI1QdL&u2%X?|3P3|SZUOXL4V2qX6(%u~t$a*1` zf_j#76pzTO2d*pgu3K;LG$g&~y04~Q*n5(5UEc6rS75DXIEnqFQS`S7g<=jinuKz~ zVJ347#};JuaqqYYwdZih50Jd5)+PhZX;l)t=VLAMwC&I;}Nu${Wxzfim zNawW+#OaflJ2&3gym_O0`_#YhoZC8ixBcU_&AY3ev)9@e&TNqziWnaZ`?G;R-)z*P zq|!`$3SR#reXtImK%*B-gbM*QAO}o?ZgR3)igiZDx^I+i=`_8h7DRp^Egpo4Crub; zWlqsA5Hdd%#7Qq5akS$q%zfa`gHU1FoS53(Y2bQ|(P@Us^m$P_llP~blUB_Wvgb$vsD zMPLI=$_f_OFPVT7!BLud#Ue-*(X5GdAKu-1^W5fJCp&A`+Q0t&)-TuYy?=iEi6x`* z#pxw${HdpojHi@5%9ZhZcdxW>Te!>rLi%#^yrSrRd&T?0wYpQQo$3%;c`wqM1}(-2ww3q z5-QWs;uG|VZJ$*Rs8#TBJfM4!Y4VOUD#+p!z55sG1G@hcjRaYjGp?@zO`wm0iehSP z7}O%VQtnx!I8cir1NB<#CU-WncaV>In$7b)Ek~si<2k(0vt%QigOPP!?2m_wp^X*RP<8^(Zprp7e8=X!rKT_L+CHWQeMtv8DT~m5ea= z-aFr3J>A{7**(A7x%7wbpYN9SEkbvT&_a#RP1k1P+ESoaDd$P2IDNNC~ZhlpR)co_7ZXBKJv-!Z;f2+As2%9 z%Xp#+p{5u05)w;^<<-w&`3KsLZo(>Fz&kri7e$VWugWe7lmuC&#ZKwVX2RxBwtB2L z=FAp2Pn*b=E}m@ZkfAcry>@-;;yd6MIk!a=5EY1Hl5uH0AEp4&^wOkJwyKsFMTnR&T}SZz zy4RG26G_dVR}qV4Vex(=^&74R&yN6}+P;UMJ5a~0ZeK^cVBNKh`2s1RIk zteFTkb|k88t{KZ#f|NFB*{Hr|%WB!&x@<1zS~(;OJEgR{d2^$C^B3(4E8X9$ZN9Z~ z@6OwubE^PN+Eca26N_27)nyLLg3cFm7@47gqOqV(JlIgagi)~JcKW9xIW#yANjp0U zNqzch)L+Ua!XQCqLlA4rj`$Ag--W2=2;s(*CkQ9;3W*<(s6xnGJ;0I4k9Q+I;y3Ee z6m_ikyoSy(RY?uinkse4(3$%E$IgjK*nRS&e~}ZEQxv|>Ljk_P^pyTuWf^#`vkmP9 z?T3dZ^<`t(q=Ma|-HzTeY0q;-rQlchy{vsf`=0iF?G+p{fvly@Us?6--hRDv^>X{n z+nt}Ub}s)8^4{%tB|E;PA0C%1kHRR;Ml*<{RZD_8HYLXe6woe*Ge;UGNG;Q~FrXc| zWYf}^4;H0`B$NibfhelP(|cQ43@}mNFUl}TD@yA)o9R(g 
diff --git a/training/train.py b/training/train.py
new file mode 100644
index 0000000..64f7802
--- /dev/null
+++ b/training/train.py
@@ -0,0 +1,312 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import os
+from tqdm import tqdm
+import time
+from training.utils import AverageMeter, EarlyStopping, compute_metrics
+
+def train_epoch(model, train_loader, criterion, optimizer, device):
+    """Train the model for one epoch; return mean loss and accuracy."""
+    model.train()
+    losses = AverageMeter()
+    acc = AverageMeter()
+
+    # Progress bar
+    pbar = tqdm(train_loader, desc='Train')
+
+    for batch in pbar:
+        # Fetch inputs and labels
+        diff_imgs = batch['diff_img'].to(device)
+        wnb_imgs = batch['wnb_img'].to(device)
+        labels = batch['label'].to(device)
+
+        # Forward pass
+        outputs = model(diff_imgs, wnb_imgs)
+
+        # Compute the loss
+        loss = criterion(outputs, labels)
+
+        # Backward pass and parameter update
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # Compute batch accuracy
+        _, preds = torch.max(outputs, 1)
+        batch_acc = (preds == labels).float().mean()
+
+        # Update running statistics
+        losses.update(loss.item(), labels.size(0))
+        acc.update(batch_acc.item(), labels.size(0))
+
+        # Refresh the progress bar
+        pbar.set_postfix({
+            'loss': losses.avg,
+            'acc': acc.avg
+        })
+
+    return losses.avg, acc.avg
+
+
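+# Illustrative smoke test for train_epoch (a minimal sketch; nothing in this
+# patch calls it). TinyFusion is a stand-in assumption for the project's real
+# dual-input model, and the 8x8 tensors are dummies; only the batch-dict
+# contract ('diff_img', 'wnb_img', 'label') is taken from the code above.
+def _demo_train_epoch():
+    class TinyFusion(nn.Module):
+        def __init__(self):
+            super().__init__()
+            # Two 3-channel 8x8 inputs, concatenated and flattened
+            self.fc = nn.Linear(2 * 3 * 8 * 8, 2)
+
+        def forward(self, diff_imgs, wnb_imgs):
+            x = torch.cat([diff_imgs, wnb_imgs], dim=1)
+            return self.fc(x.flatten(1))
+
+    model = TinyFusion()
+    batch = {
+        'diff_img': torch.randn(4, 3, 8, 8),
+        'wnb_img': torch.randn(4, 3, 8, 8),
+        'label': torch.randint(0, 2, (4,)),
+    }
+    # Any iterable of such dicts works as a loader
+    loss, acc = train_epoch(model, [batch], nn.CrossEntropyLoss(),
+                            optim.Adam(model.parameters()), torch.device('cpu'))
+    print(f'demo loss={loss:.4f}, acc={acc:.4f}')
+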
F1={val_metrics["f1"]:.4f}') + + # 检查是否为最佳验证准确率 + if val_acc > best_val_acc: + best_val_acc = val_acc + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'val_acc': val_acc, + 'val_metrics': val_metrics + }, os.path.join(save_dir, f'{model_name}_best_acc.pth')) + print(f'保存新的最佳模型,验证准确率: {val_acc:.4f}') + + # 早停检查 + early_stopping(val_loss, model) + if early_stopping.early_stop: + print(f"早停! 在第 {epoch+1} 个Epoch停止训练") + break + + # 计算总训练时间 + total_time = time.time() - start_time + print(f'训练完成! 总用时: {total_time/60:.2f} 分钟') + + # 加载最佳模型 + model.load_state_dict(torch.load(best_model_path)) + + return model, { + 'train_losses': train_losses, + 'val_losses': val_losses, + 'train_accs': train_accs, + 'val_accs': val_accs, + 'best_val_acc': best_val_acc, + 'total_time': total_time + } + + +def train_single_modal_model(model, train_loader, val_loader, criterion, optimizer, device, + num_epochs=50, save_dir='./models', model_name='single_modal', modal_key='diff_img'): + """训练单模态模型""" + # 保存最佳模型的路径 + if not os.path.exists(save_dir): + os.makedirs(save_dir) + best_model_path = os.path.join(save_dir, f'{model_name}_best.pth') + + # 初始化早停 + early_stopping = EarlyStopping(patience=10, path=best_model_path) + + # 跟踪训练历史 + train_losses = [] + val_losses = [] + train_accs = [] + val_accs = [] + best_val_acc = 0.0 + + # 开始训练 + start_time = time.time() + + for epoch in range(num_epochs): + print(f'\nEpoch {epoch+1}/{num_epochs}') + print('-' * 20) + + # 训练 + model.train() + train_loss = AverageMeter() + train_acc = AverageMeter() + + pbar = tqdm(train_loader, desc='训练') + for batch in pbar: + # 获取数据和标签 + imgs = batch[modal_key].to(device) + labels = batch['label'].to(device) + + # 前向传播 + outputs = model(imgs) + + # 计算损失 + loss = criterion(outputs, labels) + + # 反向传播和优化 + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # 计算准确率 + _, preds = torch.max(outputs, 1) + batch_acc = (preds == labels).float().mean() + + # 更新统计 + train_loss.update(loss.item(), labels.size(0)) + train_acc.update(batch_acc.item(), labels.size(0)) + + # 更新进度条 + pbar.set_postfix({ + 'loss': train_loss.avg, + 'acc': train_acc.avg + }) + + train_losses.append(train_loss.avg) + train_accs.append(train_acc.avg) + + # 验证 + model.eval() + val_loss = AverageMeter() + val_acc = AverageMeter() + + all_labels = [] + all_preds = [] + + with torch.no_grad(): + for batch in tqdm(val_loader, desc='验证'): + # 获取数据和标签 + imgs = batch[modal_key].to(device) + labels = batch['label'].to(device) + + # 前向传播 + outputs = model(imgs) + + # 计算损失 + loss = criterion(outputs, labels) + + # 计算准确率 + _, preds = torch.max(outputs, 1) + batch_acc = (preds == labels).float().mean() + + # 更新统计 + val_loss.update(loss.item(), labels.size(0)) + val_acc.update(batch_acc.item(), labels.size(0)) + + # 保存预测结果用于计算指标 + all_labels.extend(labels.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + val_losses.append(val_loss.avg) + val_accs.append(val_acc.avg) + + # 计算其他评估指标 + val_metrics = compute_metrics(all_labels, all_preds) + + # 打印当前epoch的结果 + print(f'训练损失: {train_loss.avg:.4f} 训练准确率: {train_acc.avg:.4f}') + print(f'验证损失: {val_loss.avg:.4f} 验证准确率: {val_acc.avg:.4f}') + print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}') + + # 检查是否为最佳验证准确率 + if val_acc.avg > best_val_acc: + best_val_acc = val_acc.avg + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 
+def train_single_modal_model(model, train_loader, val_loader, criterion, optimizer, device,
+                             num_epochs=50, save_dir='./models', model_name='single_modal', modal_key='diff_img'):
+    """Train a single-modality baseline; `modal_key` selects which image stream to use."""
+    # Path for the best model checkpoint
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')
+
+    # Initialize early stopping
+    early_stopping = EarlyStopping(patience=10, path=best_model_path)
+
+    # Track training history
+    train_losses = []
+    val_losses = []
+    train_accs = []
+    val_accs = []
+    best_val_acc = 0.0
+
+    # Start training
+    start_time = time.time()
+
+    for epoch in range(num_epochs):
+        print(f'\nEpoch {epoch+1}/{num_epochs}')
+        print('-' * 20)
+
+        # Training phase
+        model.train()
+        train_loss = AverageMeter()
+        train_acc = AverageMeter()
+
+        pbar = tqdm(train_loader, desc='Train')
+        for batch in pbar:
+            # Fetch inputs and labels
+            imgs = batch[modal_key].to(device)
+            labels = batch['label'].to(device)
+
+            # Forward pass
+            outputs = model(imgs)
+
+            # Compute the loss
+            loss = criterion(outputs, labels)
+
+            # Backward pass and parameter update
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # Compute batch accuracy
+            _, preds = torch.max(outputs, 1)
+            batch_acc = (preds == labels).float().mean()
+
+            # Update running statistics
+            train_loss.update(loss.item(), labels.size(0))
+            train_acc.update(batch_acc.item(), labels.size(0))
+
+            # Refresh the progress bar
+            pbar.set_postfix({
+                'loss': train_loss.avg,
+                'acc': train_acc.avg
+            })
+
+        train_losses.append(train_loss.avg)
+        train_accs.append(train_acc.avg)
+
+        # Validation phase
+        model.eval()
+        val_loss = AverageMeter()
+        val_acc = AverageMeter()
+
+        all_labels = []
+        all_preds = []
+
+        with torch.no_grad():
+            for batch in tqdm(val_loader, desc='Validate'):
+                # Fetch inputs and labels
+                imgs = batch[modal_key].to(device)
+                labels = batch['label'].to(device)
+
+                # Forward pass
+                outputs = model(imgs)
+
+                # Compute the loss
+                loss = criterion(outputs, labels)
+
+                # Compute batch accuracy
+                _, preds = torch.max(outputs, 1)
+                batch_acc = (preds == labels).float().mean()
+
+                # Update running statistics
+                val_loss.update(loss.item(), labels.size(0))
+                val_acc.update(batch_acc.item(), labels.size(0))
+
+                # Collect predictions for metric computation
+                all_labels.extend(labels.cpu().numpy())
+                all_preds.extend(preds.cpu().numpy())
+
+        val_losses.append(val_loss.avg)
+        val_accs.append(val_acc.avg)
+
+        # Compute the remaining evaluation metrics
+        val_metrics = compute_metrics(all_labels, all_preds)
+
+        # Report this epoch's results
+        print(f'Train loss: {train_loss.avg:.4f}  Train acc: {train_acc.avg:.4f}')
+        print(f'Val loss: {val_loss.avg:.4f}  Val acc: {val_acc.avg:.4f}')
+        print(f'Val metrics: precision={val_metrics["precision"]:.4f}, recall={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')
+
+        # Check for a new best validation accuracy
+        if val_acc.avg > best_val_acc:
+            best_val_acc = val_acc.avg
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'val_acc': val_acc.avg,
+                'val_metrics': val_metrics
+            }, os.path.join(save_dir, f'{model_name}_best_acc.pth'))
+            print(f'Saved new best model, val acc: {val_acc.avg:.4f}')
+
+        # Early-stopping check
+        early_stopping(val_loss.avg, model)
+        if early_stopping.early_stop:
+            print(f'Early stopping at epoch {epoch+1}')
+            break
+
+    # Total training time
+    total_time = time.time() - start_time
+    print(f'Training finished! Total time: {total_time/60:.2f} minutes')
+
+    # Load the early-stopping checkpoint (lowest validation loss)
+    model.load_state_dict(torch.load(best_model_path, map_location=device))
+
+    return model, {
+        'train_losses': train_losses,
+        'val_losses': val_losses,
+        'train_accs': train_accs,
+        'val_accs': val_accs,
+        'best_val_acc': best_val_acc,
+        'total_time': total_time
+    }
\ No newline at end of file
diff --git a/training/utils.py b/training/utils.py
new file mode 100644
index 0000000..f74c572
--- /dev/null
+++ b/training/utils.py
@@ -0,0 +1,126 @@
+import torch
+import matplotlib.pyplot as plt
+from sklearn.metrics import (accuracy_score, precision_score, recall_score,
+                             f1_score, confusion_matrix, roc_auc_score)
+import seaborn as sns
+
+class AverageMeter:
+    """Track the current value, running sum, count, and average of a metric."""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+class EarlyStopping:
+    """Stop training when validation loss stops improving, to limit overfitting."""
+    def __init__(self, patience=7, delta=0, path='checkpoint.pt'):
+        self.patience = patience
+        self.delta = delta
+        self.path = path
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.val_loss_min = float('inf')
+
+    def __call__(self, val_loss, model):
+        score = -val_loss
+
+        if self.best_score is None:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model)
+        elif score < self.best_score + self.delta:
+            self.counter += 1
+            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model)
+            self.counter = 0
+
+    def save_checkpoint(self, val_loss, model):
+        torch.save(model.state_dict(), self.path)
+        # Report the drop against the previous minimum, then update it.
+        print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
+        self.val_loss_min = val_loss
+
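+
+# Illustrative driver loop for EarlyStopping (a sketch; the real call sites
+# are in training/train.py, and run_validation is a hypothetical helper):
+#
+#   stopper = EarlyStopping(patience=5, path='checkpoint.pt')
+#   for epoch in range(num_epochs):
+#       val_loss = run_validation(model)
+#       stopper(val_loss, model)
+#       if stopper.early_stop:
+#           break
+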
+
+def plot_training_curves(train_losses, val_losses, train_accs, val_accs, save_path=None):
+    """Plot training/validation loss and accuracy curves."""
+    plt.figure(figsize=(12, 5))
+
+    plt.subplot(1, 2, 1)
+    plt.plot(train_losses, label='Train loss')
+    plt.plot(val_losses, label='Val loss')
+    plt.title('Loss curves')
+    plt.xlabel('Epoch')
+    plt.ylabel('Loss')
+    plt.legend()
+    plt.grid(True)
+
+    plt.subplot(1, 2, 2)
+    plt.plot(train_accs, label='Train accuracy')
+    plt.plot(val_accs, label='Val accuracy')
+    plt.title('Accuracy curves')
+    plt.xlabel('Epoch')
+    plt.ylabel('Accuracy')
+    plt.legend()
+    plt.grid(True)
+
+    plt.tight_layout()
+
+    if save_path:
+        plt.savefig(save_path)
+
+    plt.show()
+
+
+def plot_confusion_matrix(y_true, y_pred, class_names=None, save_path=None):
+    """Plot the confusion matrix as a heatmap."""
+    cm = confusion_matrix(y_true, y_pred)
+
+    plt.figure(figsize=(8, 6))
+    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
+                xticklabels=class_names if class_names else "auto",
+                yticklabels=class_names if class_names else "auto")
+    plt.title('Confusion matrix')
+    plt.xlabel('Predicted label')
+    plt.ylabel('True label')
+
+    if save_path:
+        plt.savefig(save_path)
+
+    plt.show()
+
+
+def compute_metrics(y_true, y_pred, y_proba=None):
+    """Compute binary-classification metrics; adds ROC AUC when probabilities are given."""
+    metrics = {
+        'accuracy': accuracy_score(y_true, y_pred),
+        'precision': precision_score(y_true, y_pred, average='binary', zero_division=0),
+        'recall': recall_score(y_true, y_pred, average='binary', zero_division=0),
+        'f1': f1_score(y_true, y_pred, average='binary', zero_division=0)
+    }
+
+    if y_proba is not None:
+        # Rows of y_proba are [p(normal), p(abnormal)]; AUC uses the positive-class score.
+        try:
+            metrics['auc'] = roc_auc_score(y_true, [p[1] for p in y_proba])
+        except ValueError:
+            # AUC is undefined when the labels contain a single class.
+            metrics['auc'] = float('nan')
+
+    return metrics
+
+
+def save_results(results, filename):
+    """Write a results dict to a text file, one `key: value` pair per line."""
+    with open(filename, 'w') as f:
+        for key, value in results.items():
+            f.write(f"{key}: {value}\n")
\ No newline at end of file