VAE fix w/ new results

This commit is contained in:
Johnny Fernandes
2026-05-02 00:32:45 +01:00
parent 1a7f67ab9c
commit bac52bc15e
90 changed files with 1197 additions and 1106 deletions
@@ -7,6 +7,7 @@
"model": "vae", "model": "vae",
"latent_dim": 256, "latent_dim": 256,
"ngf": 64, "ngf": 64,
"grad_clip": 1.0,
"sample_interval": 10, "sample_interval": 10,
"fid_interval": 25, "fid_interval": 25,
"fid_n_real": 5000 "fid_n_real": 5000
+2 -2
View File
@@ -1,8 +1,8 @@
{ {
"extends": "_base_phase3.json", "extends": "_base_phase3.json",
"run_name": "p3_1_vae", "run_name": "p3_1_vae",
"lr": 1e-3, "lr": 5e-4,
"beta_kl": 1.0, "beta_kl": 0.25,
"lambda_perceptual": 0.0, "lambda_perceptual": 0.0,
"lambda_adversarial": 0.0 "lambda_adversarial": 0.0
} }
@@ -1,8 +1,8 @@
{ {
"extends": "_base_phase3.json", "extends": "_base_phase3.json",
"run_name": "p3_2_vae_perceptual", "run_name": "p3_2_vae_perceptual",
"lr": 1e-3, "lr": 5e-4,
"beta_kl": 0.0001, "beta_kl": 0.25,
"lambda_perceptual": 0.1, "lambda_perceptual": 0.1,
"lambda_adversarial": 0.0 "lambda_adversarial": 0.0
} }
@@ -1,10 +1,10 @@
{ {
"extends": "_base_phase3.json", "extends": "_base_phase3.json",
"run_name": "p3_3_vae_patchgan", "run_name": "p3_3_vae_patchgan",
"lr": 1e-3, "lr": 5e-4,
"lr_d": 1e-4, "lr_d": 1e-4,
"beta_kl": 0.0001, "beta_kl": 0.25,
"lambda_perceptual": 0.1, "lambda_perceptual": 0.1,
"lambda_adversarial": 0.1, "lambda_adversarial": 0.01,
"ndf_patch": 64 "ndf_patch": 64
} }
+2 -1
View File
@@ -6,5 +6,6 @@
"subsample": 1.0, "subsample": 1.0,
"sample_interval": 10, "sample_interval": 10,
"fid_interval": 25, "fid_interval": 25,
"fid_n_real": 5000 "fid_n_real": 5000,
"num_workers": 2
} }
+209 -207
View File
@@ -11,222 +11,224 @@
"sample_interval": 10, "sample_interval": 10,
"fid_interval": 25, "fid_interval": 25,
"fid_n_real": 5000, "fid_n_real": 5000,
"num_workers": 2,
"epochs": 100, "epochs": 100,
"augment": "hflip", "augment": "hflip",
"image_size": 64, "image_size": 64,
"model": "vae", "model": "vae",
"latent_dim": 256, "latent_dim": 256,
"ngf": 64, "ngf": 64,
"grad_clip": 1.0,
"run_name": "p3_1_vae", "run_name": "p3_1_vae",
"lr": 0.001, "lr": 0.0005,
"beta_kl": 1.0, "beta_kl": 0.25,
"lambda_perceptual": 0.0, "lambda_perceptual": 0.0,
"lambda_adversarial": 0.0 "lambda_adversarial": 0.0
}, },
"history": { "history": {
"recon_loss": [ "recon_loss": [
0.23614721197603095, 0.07131652818180811,
0.23315699178821, 0.04839004274521373,
0.22991716011594504, 0.04452062843956499,
NaN, 0.04353479713870165,
0.23217070787253544, 0.043699693698913626,
0.23155480842941847, 0.044162471062288836,
0.23157141198459855, 0.044934689377744995,
0.23181156750418183, 0.04592526849741355,
0.23201335527193853, 0.047077489165095694,
0.23178868266379732, 0.04793272184160275,
0.2315022333755962, 0.04926021105776995,
0.2311908418042028, 0.05025381562260226,
0.23185610672474927, 0.05120742289174316,
0.23176095832107413, 0.052108900581733286,
0.23165411693163407, 0.053131219620505966,
0.23174296459581098, 0.05401940069869798,
0.2317636658747991, 0.05489145301314246,
0.2317118427883356, 0.055811726878214084,
0.23172695364834917, 0.05665480460907914,
0.2316696329567677, 0.0575029267412093,
0.23168399261358458, 0.05754986319404382,
0.2316194716681782, 0.057160187035034865,
0.23164867447354856, 0.05722308493991438,
0.2315481170757204, 0.05713338962891418,
0.23165068109957582, 0.05686132515119946,
0.23167062098653907, 0.056816555862116,
0.23162642907765177, 0.05664816373784063,
0.2315922882567104, 0.05655015064164614,
0.2315914996414103, 0.056517735943516605,
0.23156180984189367, 0.056386624335542194,
0.23156551628286004, 0.05631861151156262,
0.2315698005259037, 0.056178740154092126,
0.2315660522470617, 0.056074508314586095,
0.23156735001720935, 0.05601463455738675,
0.23161396435183337, 0.05584243320438088,
0.23158050178844705, 0.05574028127086468,
0.23159921089680785, 0.05563880511137665,
0.23149616745674712, 0.055547706926098235,
0.23159087484336308, 0.05556490144923202,
0.23156312872201967, 0.05538980011692923,
0.23153820200863048, 0.05529476007303366,
0.2315863819203825, 0.05527778912303794,
0.23150022140043414, 0.0552029303378529,
0.23154497337646973, 0.05519345425801654,
0.2315601774961011, 0.05497326165374018,
0.23153368950399578, 0.05496025659366805,
0.23152085642019907, 0.0549636375095345,
0.23151608884461924, 0.0548208407086567,
0.23154898990805334, 0.05475613919015114,
0.23155892872784892, 0.05467982544826391,
NaN, 0.05467521703332408,
NaN, 0.05451676477160719,
NaN, 0.054423171549271315,
0.24157701413600874, 0.05440536335620106,
NaN, 0.054226587956341415,
NaN, 0.05415793756643931,
0.24151325464630738, 0.0540492157650809,
NaN, 0.0539091299725776,
0.24154121766233036, 0.05381450568859139,
0.24155463749526912, 0.053790501263151824,
0.24158300176008135, 0.053688646684217654,
0.24158118757554609, 0.05361353090176216,
0.2415294518901242, 0.05348594906206569,
0.24156020069096842, 0.053407231775613934,
0.2415176352374574, 0.05329926665394734,
0.2415566616015047, 0.053199282489143886,
NaN, 0.05318549033413585,
0.24161437115608117, 0.053034571994446285,
0.24159398913765565, 0.05299898797375524,
0.24149432768806434, 0.052894837278713525,
0.24153172199287984, 0.05281204023422339,
0.24161516999204954, 0.05279633629685029,
0.24158193846034187, 0.05266677101071064,
0.2415451397562129, 0.0525879480310867,
0.24155487772873324, 0.052519615285862714,
0.24155297130346298, 0.05243872805761221,
NaN, 0.05236124007715883,
0.24157197961313093, 0.052327762763851725,
NaN, 0.05222526421117732,
NaN, 0.05212976400636964,
0.24158605401459923, 0.05209252470706263,
0.24156368870893094, 0.0520137355177321,
0.24159100852333582, 0.051939742568020635,
0.24153350121699846, 0.05186714857625656,
0.24153158377505776, 0.051828445420942754,
NaN, 0.051747049658726424,
0.24161708673350832, 0.05169421551414789,
0.24158515879868442, 0.05160299702109689,
NaN, 0.05152464638917874,
0.24157126235146809, 0.051478635081941754,
0.24162366709265953, 0.0514086935764704,
NaN, 0.05138895747403049,
0.2415581897665293, 0.051289146423785605,
NaN, 0.05126826443637793,
NaN, 0.05114383632555986,
0.2415400046823371, 0.051157466693120636,
NaN, 0.05104086854550828,
0.2415627600927638, 0.05102811497437139,
0.2415567432076503, 0.0509683930545918,
0.2415620140476614 0.05096899156068635
], ],
"kl_loss": [ "kl_loss": [
12.394881742504927, 0.7611013715847944,
184.775765717539, 0.5151619965321997,
127.26797539963681, 0.39718208143599015,
33346392.786626913, 0.3297253664360087,
35.72433020722153, 0.2878693372138545,
31.41954361882984, 0.257130523904776,
16.178619678203876, 0.23313261619490436,
10.234501274223001, 0.21345858896772066,
14.817130448471787, 0.1973419941197603,
9.230570034084158, 0.18355375779872266,
9.643558593896719, 0.1717705535583007,
8.47786058498244, 0.1612550135797415,
5.573643362929678, 0.15217030946260843,
2.4644629534365783, 0.1444861331047156,
1.5757666807462516, 0.13738160394132137,
0.426466258131286, 0.1306394640611023,
1.7924597560404203, 0.12505544035926333,
0.2769168242652956, 0.12011377760169344,
0.21636260826236162, 0.11535473169488275,
0.48804672485870176, 0.1109595281095841,
0.10833573165453142, 0.11022479862420477,
0.13318477837671328, 0.10970949868743236,
0.17373992544877478, 0.10961869400408533,
0.09584700099678121, 0.10945094661771232,
0.0977757986014088, 0.10919397698444688,
0.07794108981282538, 0.10875842463957448,
0.05691333960853199, 0.10874241859548622,
0.07221067506167242, 0.10849472851707385,
0.036222075203704275, 0.10825413890565053,
0.03126689469696492, 0.10820380448658243,
0.04264315036642882, 0.10811882353045492,
0.016960328184147805, 0.10812433239104402,
0.03314871971324309, 0.10808505038293,
0.014776984407789368, 0.10790401608006567,
0.011375301962312406, 0.10787000140955305,
0.013948339588828703, 0.10800883411151221,
0.01186063720120324, 0.10767164218247446,
0.0099704478863372, 0.10764909312765822,
0.00536374123289417, 0.10733446480435693,
0.009618068660179583, 0.10740953346348217,
0.00418840028031164, 0.10733820650822078,
0.004865833775052785, 0.10725643148279598,
0.005830266629345715, 0.10736855125834799,
0.0023000687699064487, 0.10717424525855443,
0.0038261460966199762, 0.10725876993030055,
0.0022056369562673136, 0.10694582887694366,
0.002220870125003987, 0.10713813684753373,
0.0024217167485139184, 0.10726828694853008,
0.001954249278483037, 0.10701580285134478,
0.0021431104709895756, 0.10700331553498395,
0.0022583500309011494, 0.10686330029215568,
0.002132287005193404, 0.10687567073947345,
56.80083886633675, 0.10698378102010132,
82.57108385134966, 0.10673714201483461,
82.55195800259582, 0.10696066876188812,
82.57428529527452, 0.10678731051520404,
82.56009972401155, 0.10679936213180041,
82.55269345666608, 0.10696167247290285,
82.57006728343474, 0.1067299399142846,
82.55670593131302, 0.10684242702893212,
82.54445134676419, 0.10679143969701906,
82.57745079301361, 0.10698746744957235,
82.57933913336859, 0.10674736718846183,
82.5570435157189, 0.10685917330730675,
82.56808758597089, 0.1070305180664246,
82.56800172267816, 0.10704450406388849,
82.56525711320404, 0.10696375739370656,
82.56189481621115, 0.10697911899441327,
82.55193622295673, 0.10700849091841115,
82.55375865382007, 0.10684541509383255,
82.56600202250685, 0.10709039088434134,
82.57064581324912, 0.10708275965900503,
82.55481151026538, 0.10700157213096435,
82.55367833324986, 0.10683403667221722,
82.56042112040724, 0.10696323639434627,
82.5616829048874, 0.10717970753709476,
82.5771528553759, 0.10707768420569408,
82.55317820035495, 0.10707720299052377,
82.57550573756552, 0.10703401576377387,
82.57334061973116, 0.10714245904396233,
82.56044387817383, 0.10722182246928032,
82.5752662593483, 0.10714326931052229,
82.56673936762361, 0.10708108509325573,
82.56828115740393, 0.10726493315245861,
82.56990289280557, 0.10718185655199565,
82.55218840052939, 0.10716394220407192,
82.56695372426611, 0.10727782184496903,
82.575043066954, 0.10729229825938869,
82.55754522991995, 0.10722832862510641,
82.56361721723508, 0.10727461647146787,
82.5628145821074, 0.10739002018593825,
82.56431990403395, 0.10721855878065793,
82.55777725806603, 0.10737398387784632,
82.5742861918914, 0.10721981757853785,
82.56361025622768, 0.10756766480895188,
82.56887233766736, 0.10733713450021723,
82.56539458902473, 0.10742478621884799,
82.55887828729091, 0.10721213524986027,
82.56073884882478, 0.10737172850113139,
82.55578186165573 0.10744189095293355
], ],
"perc_loss": [ "perc_loss": [
0.0, 0.0,
@@ -535,12 +537,12 @@
0.0 0.0
], ],
"fid": { "fid": {
"25": 315.9393615722656, "25": 108.92365264892578,
"50": 419.273193359375, "50": 93.73921203613281,
"75": 360.4432678222656, "75": 90.11531829833984,
"100": 363.9911193847656 "100": 88.40287780761719
}, },
"train_time_s": 660.9630489349365 "train_time_s": 1526.7542352676392
}, },
"n_params": 10608451 "n_params": 10608451
} }
+309 -307
View File
@@ -11,324 +11,326 @@
"sample_interval": 10, "sample_interval": 10,
"fid_interval": 25, "fid_interval": 25,
"fid_n_real": 5000, "fid_n_real": 5000,
"num_workers": 2,
"epochs": 100, "epochs": 100,
"augment": "hflip", "augment": "hflip",
"image_size": 64, "image_size": 64,
"model": "vae", "model": "vae",
"latent_dim": 256, "latent_dim": 256,
"ngf": 64, "ngf": 64,
"grad_clip": 1.0,
"run_name": "p3_2_vae_perceptual", "run_name": "p3_2_vae_perceptual",
"lr": 0.001, "lr": 0.0005,
"beta_kl": 0.0001, "beta_kl": 0.25,
"lambda_perceptual": 0.1, "lambda_perceptual": 0.1,
"lambda_adversarial": 0.0 "lambda_adversarial": 0.0
}, },
"history": { "history": {
"recon_loss": [ "recon_loss": [
NaN, 0.07014625494042014,
0.041370747770127066, 0.04432385861396025,
0.035506031866002284, 0.03902342810462683,
0.06229733923672993, 0.036902043435117625,
0.06089382184048494, 0.035906293887135565,
0.053464704654856116, 0.035914984559560686,
0.04950355895213846, 0.0360899488401846,
0.04674786329269409, 0.036847473583860785,
0.05693557829811023, 0.037394630213260144,
0.04808994451076047, 0.03836990671200503,
NaN, 0.039295883698022775,
0.05869839608701121, 0.040212508386526354,
0.05293231666820426, 0.04124497003757801,
0.05057006603918779, 0.04212733724305772,
0.04713928615117175, 0.04325710433638758,
0.04481589820427008, 0.04402907300963361,
0.04331832689543565, 0.045103192679647706,
0.041820655449524395, 0.04595246975525068,
NaN, 0.04698181097419598,
0.24315394992884407, 0.047825980820080154,
0.24305414841470555, 0.04772546164627768,
0.2429187158998261, 0.04763130701950982,
0.24289727962424612, 0.047709369705591954,
0.2377373814328104, 0.04753121634961194,
0.23880945107875726, 0.04741841174948674,
0.24047312904626894, 0.04729832872016053,
0.23915709144411942, 0.04722155279551561,
NaN, 0.0469712430340612,
0.24000247061634675, 0.04693082829093576,
0.2362435321904655, 0.0468663705401441,
0.23705266075383905, 0.04660685717040657,
NaN, 0.04661381538384236,
0.23998981039238793, 0.046462044382515624,
0.23828768376738596, 0.0464134263431924,
0.23908851302077627, 0.046272445438254595,
NaN, 0.04624118105882508,
NaN, 0.046066727958874315,
NaN, 0.04603223238363225,
NaN, 0.04589838094404365,
NaN, 0.0458371472489248,
NaN, 0.04577782691225537,
NaN, 0.04567820464985238,
NaN, 0.04560983914913785,
NaN, 0.045465615761076286,
NaN, 0.04543259674603613,
NaN, 0.04530587481159685,
NaN, 0.04528731873465909,
NaN, 0.04525673297098559,
NaN, 0.04507890154217553,
NaN, 0.0450239604514124,
NaN, 0.04497251241730574,
NaN, 0.044865927300774135,
NaN, 0.04475303162207715,
NaN, 0.044664276763796806,
NaN, 0.044507257990602754,
NaN, 0.0444452501515038,
NaN, 0.044273821509674065,
NaN, 0.044196275588220514,
NaN, 0.044097925846775375,
NaN, 0.04397543971864586,
NaN, 0.04393626836279773,
NaN, 0.043823353540247835,
NaN, 0.04371356347209623,
NaN, 0.043580674797169164,
NaN, 0.04348461520181507,
NaN, 0.04339473670682846,
NaN, 0.043282039009798795,
NaN, 0.04320543994888281,
NaN, 0.04314021340324583,
NaN, 0.04300904847904403,
NaN, 0.04293948887950844,
NaN, 0.04286929544730064,
NaN, 0.04278319672896312,
NaN, 0.04270272867547141,
NaN, 0.0425933551393513,
NaN, 0.04250210111276207,
NaN, 0.04243518800562263,
NaN, 0.04233357961424905,
NaN, 0.04227217231105026,
NaN, 0.042201977119677596,
NaN, 0.04207655237430436,
NaN, 0.04206944097024508,
NaN, 0.0420022471768097,
NaN, 0.041929350385808535,
NaN, 0.04179893332159417,
NaN, 0.041740718671781384,
NaN, 0.04166606991177695,
NaN, 0.04158131634959808,
NaN, 0.04155632069445828,
NaN, 0.04140628448440733,
NaN, 0.041379960253834724,
NaN, 0.041310225080093764,
NaN, 0.04123378162168794,
NaN, 0.04118641266105776,
NaN, 0.04113346862837545,
NaN, 0.041023712032116376,
NaN, 0.04102735277902112,
NaN, 0.04096395063062764,
NaN, 0.04094440599059702,
NaN 0.040864359348630294
], ],
"kl_loss": [ "kl_loss": [
41526536.5961082, 0.9578518337673612,
1258.7768262553418, 0.8413855069213443,
1165.5987154968784, 0.7201743645545764,
3689.321748260759, 0.6246620782165446,
1320.3167513333833, 0.5550231880102402,
928.3473894168169, 0.49975109450582766,
789.754654843583, 0.454662803783376,
693.4729607736963, 0.4155570190941167,
1017.8856519389357, 0.3821740793621438,
622.2623097998464, 0.352426735827556,
304942.2120989938, 0.32639839258204156,
750.5718396830763, 0.3032199253893306,
618.6888906364767, 0.2829984358360625,
586.4194498958751, 0.26505850424241817,
579.2793892102363, 0.24903864537676176,
525.8830437945504, 0.23453107457130384,
497.5363791050055, 0.22166774176761636,
452.93342662061383, 0.21033292894180006,
172476214612.95245, 0.19950743799662998,
1112.027792645316, 0.18994869887192026,
1109.8156276605068, 0.18805619294189999,
1109.8872808472722, 0.18697467955768618,
1112.3999198196282, 0.1855154659312505,
1163.6093267457097, 0.1852042447680082,
1195.4633280436199, 0.18474761941112006,
1232.709252512353, 0.1843619988514827,
1253.0651078183428, 0.18399313028551575,
19923.65347159622, 0.18352704119478536,
6871.588723207132, 0.18345636556036451,
6872.017739842081, 0.18366932483692455,
6872.62934758113, 0.18345450644946507,
27062.17111336472, 0.18317942117523944,
4992.576698759683, 0.1830864556006387,
5000.895510942508, 0.1831519056079734,
5011.723715236044, 0.18329900547734693,
4.196954763212122e+20, 0.1829878402252992,
NaN, 0.18321112291807803,
NaN, 0.18296580338197896,
NaN, 0.18312124166096377,
NaN, 0.18301436457878503,
NaN, 0.1831260528574642,
NaN, 0.18320244878657863,
NaN, 0.18332607687538505,
NaN, 0.18335427452101666,
NaN, 0.18322811696009758,
NaN, 0.18336237903334138,
NaN, 0.183573446308191,
NaN, 0.18356411601615769,
NaN, 0.18363770969912538,
NaN, 0.18377095862076834,
NaN, 0.18369814056234482,
NaN, 0.18386876185098264,
NaN, 0.18424720492245805,
NaN, 0.18420857423518458,
NaN, 0.18395239082921264,
NaN, 0.18442093539569113,
NaN, 0.1842692175036312,
NaN, 0.18475162495787328,
NaN, 0.18502490509014863,
NaN, 0.18480629139603713,
NaN, 0.1851997250993537,
NaN, 0.1850732435973791,
NaN, 0.1854018302809479,
NaN, 0.18547267151566652,
NaN, 0.18569090538936803,
NaN, 0.1859345402664099,
NaN, 0.18615943276219898,
NaN, 0.18631401260057065,
NaN, 0.1865774132948146,
NaN, 0.1867309583303256,
NaN, 0.18672093723574254,
NaN, 0.18682665680336136,
NaN, 0.18679539349853483,
NaN, 0.18683723016427115,
NaN, 0.1874776411578696,
NaN, 0.18746165714712223,
NaN, 0.1873003001141752,
NaN, 0.18763184111215112,
NaN, 0.18799093031348327,
NaN, 0.18802655252635989,
NaN, 0.18824852079662502,
NaN, 0.1882636730169129,
NaN, 0.18849152215143555,
NaN, 0.1886035389561429,
NaN, 0.188697873813729,
NaN, 0.18883028218888828,
NaN, 0.1889004738858113,
NaN, 0.18922981174073666,
NaN, 0.18946354345888153,
NaN, 0.18955816268029377,
NaN, 0.18944057507010606,
NaN, 0.18984314531852037,
NaN, 0.19001257871715432,
NaN, 0.19001863161340737,
NaN, 0.18997256346365327,
NaN, 0.19019658920856622,
NaN, 0.19018171425176483,
NaN, 0.19038618680758354,
NaN, 0.19034281325263855,
NaN 0.19046620505615178
], ],
"perc_loss": [ "perc_loss": [
NaN, 3.257610213552785,
3.0017659954535656, 2.987424520855276,
2.8931196978968434, 2.8892488026211405,
3.084328340159522, 2.8424616914529066,
3.1398269641093717, 2.817538415774321,
3.0839009172896032, 2.807587738220508,
3.0459561949102287, 2.8047564518757357,
3.015714469628456, 2.8080830421203222,
3.0917934143645134, 2.8133585376617236,
3.010985380054539, 2.821376688969441,
NaN, 2.830571848612565,
3.102917709411719, 2.840076332927769,
3.055045795746339, 2.8507282504668603,
3.0351512865123587, 2.8593530094521675,
3.002274121993627, 2.8690790253826695,
2.9754957128793764, 2.877339167982085,
2.9553630066733074, 2.8865632745954724,
2.936704080328982, 2.8936263704911256,
NaN, 2.9023511684857883,
3.66463458385223, 2.9098138376178904,
3.6419444985878773, 2.906375333794162,
3.6321474706005845, 2.9055831661591163,
3.6272282503608966, 2.9031203985214233,
3.781295938878997, 2.9002673641229286,
3.740793755421272, 2.897312533651662,
3.759704489993234, 2.8946053095352955,
3.7078149914741516, 2.8919762974111443,
NaN, 2.8902457397208257,
3.7647176323792872, 2.8877664065768576,
3.7441278461717133, 2.884983114197723,
3.7264530261357627, 2.882109224286854,
NaN, 2.8805122767758164,
3.7299226170931106, 2.8777520192994013,
3.7314435080585318, 2.876453428186922,
3.754668933713538, 2.874014714334765,
NaN, 2.8723304679251123,
NaN, 2.870547831058502,
NaN, 2.86905308564504,
NaN, 2.8669193524580736,
NaN, 2.865350999383845,
NaN, 2.8633312608441734,
NaN, 2.861651761409564,
NaN, 2.860612149421985,
NaN, 2.858717398256318,
NaN, 2.857239680412488,
NaN, 2.8558127604998074,
NaN, 2.853911985189487,
NaN, 2.852791876364977,
NaN, 2.850916815109742,
NaN, 2.8496591963319697,
NaN, 2.8490061500133614,
NaN, 2.8465628389619355,
NaN, 2.844754463587052,
NaN, 2.842957689211919,
NaN, 2.840948245464227,
NaN, 2.8397679522506194,
NaN, 2.837174654006958,
NaN, 2.835211678957328,
NaN, 2.833740895629948,
NaN, 2.8325765499701867,
NaN, 2.830295750218579,
NaN, 2.8292511431579914,
NaN, 2.8273453595291853,
NaN, 2.8257394835480256,
NaN, 2.823762384744791,
NaN, 2.822507558215378,
NaN, 2.8205288745399213,
NaN, 2.819065808230995,
NaN, 2.8176180297492914,
NaN, 2.8165967805772767,
NaN, 2.8143360721759305,
NaN, 2.8135039892971005,
NaN, 2.8120578970664587,
NaN, 2.8098325041624217,
NaN, 2.8090034366672874,
NaN, 2.8076502920215964,
NaN, 2.80601545684358,
NaN, 2.805327193859296,
NaN, 2.803745285058633,
NaN, 2.802450499473474,
NaN, 2.800580952412043,
NaN, 2.799968459157862,
NaN, 2.798532901156662,
NaN, 2.797372330967178,
NaN, 2.795650184663952,
NaN, 2.79500452371744,
NaN, 2.79342931152409,
NaN, 2.792220654650631,
NaN, 2.7911063600809145,
NaN, 2.7897617689564695,
NaN, 2.7887179505111823,
NaN, 2.7880892407180915,
NaN, 2.7868736916118197,
NaN, 2.785768971993373,
NaN, 2.784949649602939,
NaN, 2.784032260760283,
NaN, 2.783316181765662,
NaN, 2.7821382457374506,
NaN, 2.7815805586994204,
NaN 2.780150015639444
], ],
"adv_g_loss": [ "adv_g_loss": [
0.0, 0.0,
@@ -535,12 +537,12 @@
0.0 0.0
], ],
"fid": { "fid": {
"25": 263.1458740234375, "25": 85.4538345336914,
"50": 598.3736572265625, "50": 70.30448150634766,
"75": 598.3736572265625, "75": 68.88232421875,
"100": 598.3736572265625 "100": 68.23878479003906
}, },
"train_time_s": 952.6596128940582 "train_time_s": 1526.8077104091644
}, },
"n_params": 10608451 "n_params": 10608451
} }
File diff suppressed because it is too large Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,14 @@
{
"created_at": "2026-05-02T13:07:06.633366+00:00",
"config_paths": [
"generator/configs/phase3/p3_1_vae.json",
"generator/configs/phase3/p3_2_vae_perceptual.json",
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 36014025,
"offer_id": 35956124,
"ssh_host": "ssh3.vast.ai",
"ssh_port": 14024,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -0,0 +1,12 @@
{
"created_at": "2026-05-02T17:39:27.858755+00:00",
"config_paths": [
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 36025335,
"offer_id": 19025020,
"ssh_host": "ssh6.vast.ai",
"ssh_port": 25334,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -0,0 +1,12 @@
{
"created_at": "2026-05-02T18:44:33.143461+00:00",
"config_paths": [
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 36027837,
"offer_id": 29548831,
"ssh_host": null,
"ssh_port": null,
"status": "cancelled",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -0,0 +1,12 @@
{
"created_at": "2026-05-02T18:51:49.983505+00:00",
"config_paths": [
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 36028210,
"offer_id": 35956124,
"ssh_host": "ssh1.vast.ai",
"ssh_port": 28210,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
Binary file not shown.

Before

Width:  |  Height:  |  Size: 80 KiB

After

Width:  |  Height:  |  Size: 93 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 132 KiB

After

Width:  |  Height:  |  Size: 221 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 78 KiB

After

Width:  |  Height:  |  Size: 81 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 120 KiB

After

Width:  |  Height:  |  Size: 214 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

After

Width:  |  Height:  |  Size: 76 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 126 KiB

After

Width:  |  Height:  |  Size: 213 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.5 KiB

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 213 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.5 KiB

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 124 KiB

After

Width:  |  Height:  |  Size: 212 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.3 KiB

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 124 KiB

After

Width:  |  Height:  |  Size: 211 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.4 KiB

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 211 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.7 KiB

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 121 KiB

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.4 KiB

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.3 KiB

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 91 KiB

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 215 KiB

After

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 40 KiB

After

Width:  |  Height:  |  Size: 86 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 135 KiB

After

Width:  |  Height:  |  Size: 205 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 30 KiB

After

Width:  |  Height:  |  Size: 87 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 129 KiB

After

Width:  |  Height:  |  Size: 206 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 89 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 207 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 90 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 92 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 93 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 93 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 209 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 209 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 209 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 118 KiB

After

Width:  |  Height:  |  Size: 92 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 247 KiB

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 31 KiB

After

Width:  |  Height:  |  Size: 90 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 141 KiB

After

Width:  |  Height:  |  Size: 209 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

After

Width:  |  Height:  |  Size: 92 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 127 KiB

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

After

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 212 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 18 KiB

After

Width:  |  Height:  |  Size: 96 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 212 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 212 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 99 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 213 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 214 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 214 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 215 KiB

+1 -1
View File
@@ -52,7 +52,7 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
# Count total trainable parameters # Count total trainable parameters
if isinstance(model, tuple): if isinstance(model, tuple):
n_params = sum(p.numel() for p in model[0].parameters() if p.requires_grad) n_params = sum(p.numel() for m in model for p in m.parameters() if p.requires_grad)
else: else:
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable params: {n_params:,}") print(f"Trainable params: {n_params:,}")
+11 -4
View File
@@ -22,17 +22,21 @@ def _init_weights(m):
nn.init.normal_(m.weight, 0.0, 0.02) nn.init.normal_(m.weight, 0.0, 0.02)
if m.bias is not None: if m.bias is not None:
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d) and m.weight is not None: elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None:
nn.init.normal_(m.weight, 1.0, 0.02) nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
def _norm(channels: int) -> nn.GroupNorm:
return nn.GroupNorm(8, channels)
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential: def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard.""" """Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
return nn.Sequential( return nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"), nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
nn.BatchNorm2d(out_ch), _norm(out_ch),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
@@ -65,7 +69,7 @@ class VAE(nn.Module):
for _ in range(n_down - 1): for _ in range(n_down - 1):
enc_layers += [ enc_layers += [
nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False), nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ch * 2), _norm(ch * 2),
nn.LeakyReLU(0.2, inplace=True), nn.LeakyReLU(0.2, inplace=True),
] ]
ch *= 2 ch *= 2
@@ -98,9 +102,12 @@ class VAE(nn.Module):
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
h = self.encoder(x).flatten(1) h = self.encoder(x).flatten(1)
return self.fc_mu(h), self.fc_lv(h) log_var = self.fc_lv(h).clamp(-10.0, 10.0)
return self.fc_mu(h), log_var
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
if not self.training:
return mu
std = torch.exp(0.5 * log_var) std = torch.exp(0.5 * log_var)
return mu + std * torch.randn_like(std) return mu + std * torch.randn_like(std)
+2
View File
@@ -20,3 +20,5 @@ class EMA:
def update(self, model: nn.Module) -> None: def update(self, model: nn.Module) -> None:
for p_ema, p in zip(self.model.parameters(), model.parameters()): for p_ema, p in zip(self.model.parameters(), model.parameters()):
p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay) p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
for b_ema, b in zip(self.model.buffers(), model.buffers()):
b_ema.copy_(b)
+3 -2
View File
@@ -11,14 +11,15 @@ from torchmetrics.image.fid import FrechetInceptionDistance
class FIDEvaluator: class FIDEvaluator:
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda"): def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda",
num_workers: int = 2):
self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.n_real = n_real self.n_real = n_real
# Cache real images as a CPU tensor ([-1, 1] range) # Cache real images as a CPU tensor ([-1, 1] range)
imgs_list = [] imgs_list = []
loader = DataLoader(real_dataset, batch_size=256, shuffle=False, loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
num_workers=4, drop_last=False) num_workers=num_workers, drop_last=False)
for batch in loader: for batch in loader:
imgs_list.append(batch.cpu()) imgs_list.append(batch.cpu())
if sum(x.size(0) for x in imgs_list) >= n_real: if sum(x.size(0) for x in imgs_list) >= n_real:
+60 -37
View File
@@ -66,7 +66,7 @@ def train_dcgan(
loader = DataLoader( loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1), num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True, pin_memory=(device.type == "cuda"), drop_last=True,
) )
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2)) opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
@@ -86,7 +86,8 @@ def train_dcgan(
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}} history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
best_fid = float("inf") best_fid = float("inf")
@@ -239,7 +240,7 @@ def train_wgan(
loader = DataLoader( loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1), num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True, pin_memory=(device.type == "cuda"), drop_last=True,
) )
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2)) opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
@@ -257,7 +258,8 @@ def train_wgan(
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}} history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
best_fid = float("inf") best_fid = float("inf")
@@ -391,7 +393,6 @@ def _save_vae_samples(
# Interleave real / reconstruction pairs # Interleave real / reconstruction pairs
pairs = torch.stack([real, recon], dim=1).flatten(0, 1) pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4) save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
vae.train()
def train_vae( def train_vae(
@@ -403,13 +404,21 @@ def train_vae(
run_name: str, run_name: str,
device: str = "cuda", device: str = "cuda",
) -> dict: ) -> dict:
"""VAE training loop covering Phase 3.1 3.3. """VAE training loop covering Phase 3.1 3.3 and Phase 5.
Config toggles: Config toggles:
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+) lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3) lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
KL is computed as mean over latent dimensions (scale-invariant), so
beta_kl is comparable across different latent_dim values.
AMP is intentionally disabled for VAE training — mixed-precision float16
overflows when the KL divergence spikes, producing NaN cascades that
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
computation runs in float32.
""" """
device = torch.device(device if torch.cuda.is_available() else "cpu") device = torch.device(device if torch.cuda.is_available() else "cpu")
vae = vae.to(device) vae = vae.to(device)
@@ -425,6 +434,7 @@ def train_vae(
lambda_perceptual = cfg.get("lambda_perceptual", 0.0) lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
lambda_adversarial = cfg.get("lambda_adversarial", 0.0) lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
lr_d = cfg.get("lr_d", 1e-4) lr_d = cfg.get("lr_d", 1e-4)
grad_clip = cfg.get("grad_clip", 1.0)
ema_decay = cfg.get("ema_decay", 0.9999) ema_decay = cfg.get("ema_decay", 0.9999)
sample_interval = cfg.get("sample_interval", 10) sample_interval = cfg.get("sample_interval", 10)
fid_interval = cfg.get("fid_interval", 25) fid_interval = cfg.get("fid_interval", 25)
@@ -435,13 +445,13 @@ def train_vae(
loader = DataLoader( loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1), num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True, pin_memory=(device.type == "cuda"), drop_last=True,
) )
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr) opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
use_amp = device.type == "cuda" # AMP disabled — float16 overflows on KL spikes, causing NaN cascades
scaler = _GradScaler("cuda", enabled=use_amp) use_amp = False
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training # KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
kl_warmup_epochs = max(1, epochs // 5) kl_warmup_epochs = max(1, epochs // 5)
@@ -456,11 +466,10 @@ def train_vae(
perc_fn = None perc_fn = None
patchgan = None patchgan = None
opt_d = None opt_d = None
scaler_d = None
if use_perceptual: if use_perceptual:
from src.training.perceptual import PerceptualLoss from src.training.perceptual import PerceptualLoss
perc_fn = PerceptualLoss().to(device) perc_fn = PerceptualLoss().to(device).float()
print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3") print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")
if use_adversarial: if use_adversarial:
@@ -468,15 +477,14 @@ def train_vae(
patchgan = PatchGANDiscriminator( patchgan = PatchGANDiscriminator(
ndf=cfg.get("ndf_patch", 64), ndf=cfg.get("ndf_patch", 64),
image_size=cfg.get("image_size", 64), image_size=cfg.get("image_size", 64),
).to(device) ).to(device).float()
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999)) opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
scaler_d = _GradScaler("cuda", enabled=use_amp)
sched_d = torch.optim.lr_scheduler.LambdaLR( sched_d = torch.optim.lr_scheduler.LambdaLR(
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1))) opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
n_d = sum(p.numel() for p in patchgan.parameters()) n_d = sum(p.numel() for p in patchgan.parameters())
print(f"PatchGAN: {n_d:,} params") print(f"PatchGAN: {n_d:,} params")
else: else:
hinge_d_loss = hinge_g_loss = None # satisfy linter, never called hinge_d_loss = hinge_g_loss = None # never called
# ── Fixed seeds for consistent visualisation ────────────────────────── # ── Fixed seeds for consistent visualisation ──────────────────────────
fixed_z = torch.randn(16, latent_dim, device=device) fixed_z = torch.randn(16, latent_dim, device=device)
@@ -490,16 +498,19 @@ def train_vae(
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = { history = {
"recon_loss": [], "kl_loss": [], "perc_loss": [], "recon_loss": [], "kl_loss": [], "perc_loss": [],
"adv_g_loss": [], "adv_d_loss": [], "fid": {}, "adv_g_loss": [], "adv_d_loss": [], "fid": {},
} }
best_fid = float("inf") best_fid = float("inf")
nan_skipped = 0
print( print(
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}" f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}" f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
f" λ_adv={lambda_adversarial}"
) )
t_start = time.time() t_start = time.time()
@@ -513,43 +524,52 @@ def train_vae(
n_batches = 0 n_batches = 0
for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False): for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
real = real.to(device) real = real.to(device).float()
# KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs # KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs
current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs) current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)
# ── VAE forward ─────────────────────────────────────────────── # ── VAE forward (float32, no AMP) ────────────────────────────
with _autocast("cuda", enabled=use_amp):
recon, mu, log_var = vae(real) recon, mu, log_var = vae(real)
mse = F.mse_loss(recon, real) mse = F.mse_loss(recon, real)
kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean()
# KL divergence: mean over latent dims (scale-invariant w.r.t. latent_dim)
kl = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp())).mean()
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze() perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
vae_loss = mse + current_beta * kl + lambda_perceptual * perc vae_loss = mse + current_beta * kl + lambda_perceptual * perc
# ── NaN/Inf guard ────────────────────────────────────────────
if not torch.isfinite(vae_loss):
nan_skipped += 1
opt_vae.zero_grad()
continue
# ── PatchGAN discriminator step ─────────────────────────────── # ── PatchGAN discriminator step ───────────────────────────────
adv_d = real.new_zeros(1).squeeze() adv_d = real.new_zeros(1).squeeze()
if use_adversarial: if use_adversarial:
# Warmup: only start adversarial training after 20% of epochs
if epoch > kl_warmup_epochs:
opt_d.zero_grad() opt_d.zero_grad()
with _autocast("cuda", enabled=use_amp):
d_real = patchgan(real) d_real = patchgan(real)
d_fake = patchgan(recon.detach()) d_fake = patchgan(recon.detach())
adv_d = hinge_d_loss(d_real, d_fake) adv_d = hinge_d_loss(d_real, d_fake)
scaler_d.scale(adv_d).backward() if torch.isfinite(adv_d):
scaler_d.step(opt_d) adv_d.backward()
scaler_d.update() torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
opt_d.step()
# ── PatchGAN generator adversarial loss ─────────────────────── # ── PatchGAN generator adversarial loss ───────────────────────
adv_g = real.new_zeros(1).squeeze() adv_g = real.new_zeros(1).squeeze()
if use_adversarial: if use_adversarial and epoch > kl_warmup_epochs:
with _autocast("cuda", enabled=use_amp):
adv_g = hinge_g_loss(patchgan(recon)) adv_g = hinge_g_loss(patchgan(recon))
vae_loss = vae_loss + lambda_adversarial * adv_g vae_loss = vae_loss + lambda_adversarial * adv_g
# ── VAE backward ────────────────────────────────────────────── # ── VAE backward ──────────────────────────────────────────────
opt_vae.zero_grad() opt_vae.zero_grad()
scaler.scale(vae_loss).backward() vae_loss.backward()
scaler.step(opt_vae) torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip)
scaler.update() opt_vae.step()
ema.update(vae) ema.update(vae)
recon_sum += mse.item() recon_sum += mse.item()
@@ -559,11 +579,11 @@ def train_vae(
adv_d_sum += adv_d.item() adv_d_sum += adv_d.item()
n_batches += 1 n_batches += 1
avg_r = recon_sum / n_batches avg_r = recon_sum / max(n_batches, 1)
avg_k = kl_sum / n_batches avg_k = kl_sum / max(n_batches, 1)
avg_p = perc_sum / n_batches avg_p = perc_sum / max(n_batches, 1)
avg_g = adv_g_sum / n_batches avg_g = adv_g_sum / max(n_batches, 1)
avg_d = adv_d_sum / n_batches avg_d = adv_d_sum / max(n_batches, 1)
history["recon_loss"].append(avg_r) history["recon_loss"].append(avg_r)
history["kl_loss"].append(avg_k) history["kl_loss"].append(avg_k)
history["perc_loss"].append(avg_p) history["perc_loss"].append(avg_p)
@@ -574,6 +594,7 @@ def train_vae(
f"[{epoch:03d}/{epochs}] " f"[{epoch:03d}/{epochs}] "
f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} " f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}" f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
f" (NaN skipped: {nan_skipped})"
) )
if epoch % sample_interval == 0: if epoch % sample_interval == 0:
@@ -607,6 +628,7 @@ def train_vae(
if patchgan is not None: if patchgan is not None:
torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt") torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
history["train_time_s"] = time.time() - t_start history["train_time_s"] = time.time() - t_start
print(f"Total NaN-skipped batches: {nan_skipped}")
return history return history
@@ -662,7 +684,7 @@ def train_ddpm(
loader = DataLoader( loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1), num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True, pin_memory=(device.type == "cuda"), drop_last=True,
) )
opt = torch.optim.AdamW(model.parameters(), lr=lr) opt = torch.optim.AdamW(model.parameters(), lr=lr)
@@ -679,7 +701,8 @@ def train_ddpm(
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device)) fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"loss": [], "fid": {}} history = {"loss": [], "fid": {}}
best_fid = float("inf") best_fid = float("inf")