VAE refactor

This commit is contained in:
Johnny Fernandes
2026-05-02 00:32:45 +01:00
parent 1a7f67ab9c
commit c7804d2984
43 changed files with 1140 additions and 1064 deletions
@@ -7,6 +7,8 @@
"model": "vae", "model": "vae",
"latent_dim": 256, "latent_dim": 256,
"ngf": 64, "ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"sample_interval": 10, "sample_interval": 10,
"fid_interval": 25, "fid_interval": 25,
"fid_n_real": 5000 "fid_n_real": 5000
+2 -2
View File
@@ -1,8 +1,8 @@
{ {
"extends": "_base_phase3.json", "extends": "_base_phase3.json",
"run_name": "p3_1_vae", "run_name": "p3_1_vae",
"lr": 1e-3, "lr": 5e-4,
"beta_kl": 1.0, "beta_kl": 0.5,
"lambda_perceptual": 0.0, "lambda_perceptual": 0.0,
"lambda_adversarial": 0.0 "lambda_adversarial": 0.0
} }
@@ -1,8 +1,8 @@
{ {
"extends": "_base_phase3.json", "extends": "_base_phase3.json",
"run_name": "p3_2_vae_perceptual", "run_name": "p3_2_vae_perceptual",
"lr": 1e-3, "lr": 5e-4,
"beta_kl": 0.0001, "beta_kl": 0.1,
"lambda_perceptual": 0.1, "lambda_perceptual": 0.1,
"lambda_adversarial": 0.0 "lambda_adversarial": 0.0
} }
@@ -1,9 +1,9 @@
{ {
"extends": "_base_phase3.json", "extends": "_base_phase3.json",
"run_name": "p3_3_vae_patchgan", "run_name": "p3_3_vae_patchgan",
"lr": 1e-3, "lr": 5e-4,
"lr_d": 1e-4, "lr_d": 1e-4,
"beta_kl": 0.0001, "beta_kl": 0.05,
"lambda_perceptual": 0.1, "lambda_perceptual": 0.1,
"lambda_adversarial": 0.1, "lambda_adversarial": 0.1,
"ndf_patch": 64 "ndf_patch": 64
+209 -207
View File
@@ -17,216 +17,218 @@
"model": "vae", "model": "vae",
"latent_dim": 256, "latent_dim": 256,
"ngf": 64, "ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"run_name": "p3_1_vae", "run_name": "p3_1_vae",
"lr": 0.001, "lr": 0.0005,
"beta_kl": 1.0, "beta_kl": 0.5,
"lambda_perceptual": 0.0, "lambda_perceptual": 0.0,
"lambda_adversarial": 0.0 "lambda_adversarial": 0.0
}, },
"history": { "history": {
"recon_loss": [ "recon_loss": [
0.23614721197603095, 0.16871115132274792,
0.23315699178821, 0.16150351059742463,
0.22991716011594504, 0.15984480073436713,
NaN, 0.15590801271490562,
0.23217070787253544, 0.15061364529861343,
0.23155480842941847, 0.14644645811973983,
0.23157141198459855, 0.14328488393917552,
0.23181156750418183, 0.14094583944887176,
0.23201335527193853, 0.13928856181665364,
0.23178868266379732, 0.1377539750124909,
0.2315022333755962, 0.13583827684195632,
0.2311908418042028, 0.13372899556898662,
0.23185610672474927, 0.13197963861509776,
0.23176095832107413, 0.1295166944559568,
0.23165411693163407, 0.12748059913770765,
0.23174296459581098, 0.126394641561768,
0.2317636658747991, 0.12385777387226748,
0.2317118427883356, 0.12244280847983482,
0.23172695364834917, 0.121336308045265,
0.2316696329567677, 0.12060208647296979,
0.23168399261358458, 0.11928399897411339,
0.2316194716681782, 0.11839368250061813,
0.23164867447354856, 0.11770952168183449,
0.2315481170757204, 0.11709171382344177,
0.23165068109957582, 0.11688287127922234,
0.23167062098653907, 0.11606283619617805,
0.23162642907765177, 0.11564522911595483,
0.2315922882567104, 0.11527438517500702,
0.2315914996414103, 0.11461337662150717,
0.23156180984189367, 0.11428412470297936,
0.23156551628286004, 0.11390002192849787,
0.2315698005259037, 0.11359650807248221,
0.2315660522470617, 0.11298466332129434,
0.23156735001720935, 0.11281277696227926,
0.23161396435183337, 0.11266172842846976,
0.23158050178844705, 0.11232183446996233,
0.23159921089680785, 0.11213540144137338,
0.23149616745674712, 0.11184148606645246,
0.23159087484336308, 0.11143260068682015,
0.23156312872201967, 0.11121100015365161,
0.23153820200863048, 0.11103646766044135,
0.2315863819203825, 0.11086450471009454,
0.23150022140043414, 0.11068084617901562,
0.23154497337646973, 0.11046731698080006,
0.2315601774961011, 0.1101092367600172,
0.23153368950399578, 0.11040792095228139,
0.23152085642019907, 0.11019570772082378,
0.23151608884461924, 0.10987201248669726,
0.23154898990805334, 0.10948091489063878,
0.23155892872784892, 0.10964858110070738,
NaN, 0.10938817380457862,
NaN, 0.10904825658688688,
NaN, 0.10926413088718541,
0.24157701413600874, 0.10854257479246356,
NaN, 0.10841736453784327,
NaN, 0.10838394499041586,
0.24151325464630738, 0.10795553592153084,
NaN, 0.10781666870491627,
0.24154121766233036, 0.10770979084265538,
0.24155463749526912, 0.1076134722520653,
0.24158300176008135, 0.10732251404123938,
0.24158118757554609, 0.10716411035157676,
0.2415294518901242, 0.10698133315413426,
0.24156020069096842, 0.10672279318364766,
0.2415176352374574, 0.10645651288776316,
0.2415566616015047, 0.10662059826601265,
NaN, 0.10623115535156849,
0.24161437115608117, 0.10627176197102436,
0.24159398913765565, 0.1060571956456217,
0.24149432768806434, 0.10596368315382901,
0.24153172199287984, 0.10585042324840513,
0.24161516999204954, 0.1054028309085685,
0.24158193846034187, 0.10517520199601467,
0.2415451397562129, 0.1053342710639167,
0.24155487772873324, 0.1052525288815427,
0.24155297130346298, 0.10502007727821668,
NaN, 0.10490142727573203,
0.24157197961313093, 0.10471088643002714,
NaN, 0.10447397850390173,
NaN, 0.1044878327470814,
0.24158605401459923, 0.10447094976328887,
0.24156368870893094, 0.10421697295501701,
0.24159100852333582, 0.10414895819675209,
0.24153350121699846, 0.10414911504102568,
0.24153158377505776, 0.10366271499894623,
NaN, 0.10378186422217096,
0.24161708673350832, 0.10345360714719336,
0.24158515879868442, 0.10363478323396964,
NaN, 0.10342726219668348,
0.24157126235146809, 0.10305201177859408,
0.24162366709265953, 0.10310276317545491,
NaN, 0.10302646558445233,
0.2415581897665293, 0.10291730190635237,
NaN, 0.10278329038276122,
NaN, 0.10258485910156344,
0.2415400046823371, 0.10270861135079311,
NaN, 0.10251107273830308,
0.2415627600927638, 0.10234160087684281,
0.2415567432076503, 0.10214911544552216,
0.2415620140476614 0.10251628613879538
], ],
"kl_loss": [ "kl_loss": [
12.394881742504927, 29.481668366326225,
184.775765717539, 26.226897622784996,
127.26797539963681, 25.98807217932155,
33346392.786626913, 25.807672516912476,
35.72433020722153, 25.701779809772457,
31.41954361882984, 25.66069327052842,
16.178619678203876, 25.639777212061432,
10.234501274223001, 25.62703147301307,
14.817130448471787, 25.620143454299015,
9.230570034084158, 25.61577053966685,
9.643558593896719, 25.612662934849403,
8.47786058498244, 25.60999612319164,
5.573643362929678, 25.608206019442306,
2.4644629534365783, 25.60651744940342,
1.5757666807462516, 25.605323510292248,
0.426466258131286, 25.60462644772652,
1.7924597560404203, 25.603682958162747,
0.2769168242652956, 25.603215930808304,
0.21636260826236162, 25.602825727218235,
0.48804672485870176, 25.602524419116158,
0.10833573165453142, 25.602402703374878,
0.13318477837671328, 25.602261029756985,
0.17373992544877478, 25.602163278139553,
0.09584700099678121, 25.602091275728664,
0.0977757986014088, 25.602065310518967,
0.07794108981282538, 25.602008819580078,
0.05691333960853199, 25.601970770420174,
0.07221067506167242, 25.601926473470833,
0.036222075203704275, 25.601869432335224,
0.03126689469696492, 25.60185879519862,
0.04264315036642882, 25.601852824545315,
0.016960328184147805, 25.60183112234132,
0.03314871971324309, 25.601783161489372,
0.014776984407789368, 25.601782627594776,
0.011375301962312406, 25.601758304824177,
0.013948339588828703, 25.601712039393238,
0.01186063720120324, 25.601732046176227,
0.0099704478863372, 25.601690267905212,
0.00536374123289417, 25.60164840404804,
0.009618068660179583, 25.601653335440872,
0.00418840028031164, 25.60167118626782,
0.004865833775052785, 25.601645954653748,
0.005830266629345715, 25.60163675210415,
0.0023000687699064487, 25.601624191316784,
0.0038261460966199762, 25.601573104532356,
0.0022056369562673136, 25.60163724524343,
0.002220870125003987, 25.601609873975445,
0.0024217167485139184, 25.601576218238243,
0.001954249278483037, 25.601538580706997,
0.0021431104709895756, 25.601551618331516,
0.0022583500309011494, 25.60154583922818,
0.002132287005193404, 25.6015028138446,
56.80083886633675, 25.60153263858241,
82.57108385134966, 25.601467629783173,
82.55195800259582, 25.60146150018415,
82.57428529527452, 25.601416970929527,
82.56009972401155, 25.601392692989773,
82.55269345666608, 25.601404438670883,
82.57006728343474, 25.601370538401806,
82.55670593131302, 25.60135773308257,
82.54445134676419, 25.601353584191738,
82.57745079301361, 25.60130478785588,
82.57933913336859, 25.601299799405613,
82.5570435157189, 25.60127318618644,
82.56808758597089, 25.6012494054615,
82.56800172267816, 25.601233796176746,
82.56525711320404, 25.601193383208706,
82.56189481621115, 25.601183601933666,
82.55193622295673, 25.60118840698503,
82.55375865382007, 25.601158150240906,
82.56600202250685, 25.60114165656587,
82.57064581324912, 25.601124979491928,
82.55481151026538, 25.601091221866444,
82.55367833324986, 25.60108805925418,
82.56042112040724, 25.601056249732647,
82.5616829048874, 25.601044960511036,
82.5771528553759, 25.601008244049854,
82.55317820035495, 25.600999204521504,
82.57550573756552, 25.60097885539389,
82.57334061973116, 25.60095446741479,
82.56044387817383, 25.600959500695904,
82.5752662593483, 25.600936054164528,
82.56673936762361, 25.60091640195276,
82.56828115740393, 25.600889332274086,
82.56990289280557, 25.600879269787388,
82.55218840052939, 25.600861647190193,
82.56695372426611, 25.600824853293915,
82.575043066954, 25.6008102914207,
82.55754522991995, 25.600796605786705,
82.56361721723508, 25.600769034817688,
82.5628145821074, 25.600744826161964,
82.56431990403395, 25.600724538167317,
82.55777725806603, 25.60070591298943,
82.5742861918914, 25.600691750518276,
82.56361025622768, 25.600666380336143,
82.56887233766736, 25.60064803636991,
82.56539458902473, 25.60063368642432,
82.55887828729091, 25.600601587540066,
82.56073884882478, 25.60058915309417,
82.55578186165573 25.600568995516525
], ],
"perc_loss": [ "perc_loss": [
0.0, 0.0,
@@ -535,12 +537,12 @@
0.0 0.0
], ],
"fid": { "fid": {
"25": 315.9393615722656, "25": 238.42819213867188,
"50": 419.273193359375, "50": 232.70050048828125,
"75": 360.4432678222656, "75": 234.88893127441406,
"100": 363.9911193847656 "100": 236.51181030273438
}, },
"train_time_s": 660.9630489349365 "train_time_s": 676.644668340683
}, },
"n_params": 10608451 "n_params": 10608451
} }
+309 -307
View File
@@ -17,318 +17,320 @@
"model": "vae", "model": "vae",
"latent_dim": 256, "latent_dim": 256,
"ngf": 64, "ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"run_name": "p3_2_vae_perceptual", "run_name": "p3_2_vae_perceptual",
"lr": 0.001, "lr": 0.0005,
"beta_kl": 0.0001, "beta_kl": 0.1,
"lambda_perceptual": 0.1, "lambda_perceptual": 0.1,
"lambda_adversarial": 0.0 "lambda_adversarial": 0.0
}, },
"history": { "history": {
"recon_loss": [ "recon_loss": [
NaN, 0.13868479322419208,
0.041370747770127066, 0.1345828948622076,
0.035506031866002284, 0.13401474649261716,
0.06229733923672993, 0.13219879449814811,
0.06089382184048494, 0.13071280969386426,
0.053464704654856116, 0.12897613197246677,
0.04950355895213846, 0.12651141290353912,
0.04674786329269409, 0.12554800317773962,
0.05693557829811023, 0.123991425602864,
0.04808994451076047, 0.12276749478446113,
NaN, 0.12151926139799449,
0.05869839608701121, 0.1200498831896191,
0.05293231666820426, 0.11872616813032545,
0.05057006603918779, 0.11811881408923203,
0.04713928615117175, 0.11655244218488024,
0.04481589820427008, 0.11565455276932982,
0.04331832689543565, 0.11529083312767693,
0.041820655449524395, 0.11437753734425601,
NaN, 0.11373461206626688,
0.24315394992884407, 0.1133939535317258,
0.24305414841470555, 0.11269663740745467,
0.2429187158998261, 0.11214834819428432,
0.24289727962424612, 0.11180534907895276,
0.2377373814328104, 0.1112786961448753,
0.23880945107875726, 0.11112714579535855,
0.24047312904626894, 0.11040605649224713,
0.23915709144411942, 0.11024710934004213,
NaN, 0.11025449748222645,
0.24000247061634675, 0.10986682841092603,
0.2362435321904655, 0.10947157509433918,
0.23705266075383905, 0.10914939207335313,
NaN, 0.1090434947425229,
0.23998981039238793, 0.10872587247982494,
0.23828768376738596, 0.10891366730897854,
0.23908851302077627, 0.10840102485739268,
NaN, 0.10831964285009438,
NaN, 0.10826414010017855,
NaN, 0.10774775957449889,
NaN, 0.10791046626101701,
NaN, 0.10784940838686422,
NaN, 0.10743295191190182,
NaN, 0.10734256694459507,
NaN, 0.10702427010187227,
NaN, 0.10701240906412275,
NaN, 0.10711385588296968,
NaN, 0.10697286784585215,
NaN, 0.10673481402679896,
NaN, 0.10650451705814937,
NaN, 0.10629599592369846,
NaN, 0.1064668823288292,
NaN, 0.1063920924296746,
NaN, 0.10610189062789974,
NaN, 0.10592550779573429,
NaN, 0.10588065830942912,
NaN, 0.105781758379223,
NaN, 0.10560809290752961,
NaN, 0.10550812136732106,
NaN, 0.10535470090615444,
NaN, 0.10536463093808573,
NaN, 0.1051216669794586,
NaN, 0.10498357508490738,
NaN, 0.10464231009220976,
NaN, 0.10468940513256268,
NaN, 0.10468925925719942,
NaN, 0.10429271149775411,
NaN, 0.10437219857405393,
NaN, 0.10406083403489529,
NaN, 0.10395075554330634,
NaN, 0.10419673752835673,
NaN, 0.10405941009839885,
NaN, 0.10379417274051751,
NaN, 0.10373205498943472,
NaN, 0.10360019166882221,
NaN, 0.10355540880790123,
NaN, 0.10355440188103761,
NaN, 0.10325965399925525,
NaN, 0.10304029177651446,
NaN, 0.10311986905578364,
NaN, 0.10273497300142916,
NaN, 0.10302559410532315,
NaN, 0.10278304515040329,
NaN, 0.10263998298627189,
NaN, 0.10254473253511466,
NaN, 0.10245785787383206,
NaN, 0.10246957698438922,
NaN, 0.10233539204375866,
NaN, 0.1025002559280803,
NaN, 0.10214536613187729,
NaN, 0.10215426669416265,
NaN, 0.10214609539725332,
NaN, 0.10188078407484752,
NaN, 0.1020691341951362,
NaN, 0.1019192597645725,
NaN, 0.10151305177018173,
NaN, 0.10163224848289774,
NaN, 0.10181667007760614,
NaN, 0.10129789227985928,
NaN, 0.10133470410210454,
NaN, 0.1014308016269635,
NaN 0.10122469465574647
], ],
"kl_loss": [ "kl_loss": [
41526536.5961082, 32.82580094867282,
1258.7768262553418, 27.63585641649034,
1165.5987154968784, 26.858297156472492,
3689.321748260759, 26.36294916756133,
1320.3167513333833, 26.101821015023777,
928.3473894168169, 25.925728178431847,
789.754654843583, 25.810632864634197,
693.4729607736963, 25.75408285499638,
1017.8856519389357, 25.714011412400467,
622.2623097998464, 25.691052310487144,
304942.2120989938, 25.673080118293438,
750.5718396830763, 25.658130865830643,
618.6888906364767, 25.648127645508858,
586.4194498958751, 25.640685057028747,
579.2793892102363, 25.63372629116743,
525.8830437945504, 25.62904858792949,
497.5363791050055, 25.625450937157,
452.93342662061383, 25.622334708515396,
172476214612.95245, 25.61985801631569,
1112.027792645316, 25.61797430168869,
1109.8156276605068, 25.617105920090633,
1109.8872808472722, 25.61665568800054,
1112.3999198196282, 25.615915049854507,
1163.6093267457097, 25.615573080176983,
1195.4633280436199, 25.61539109955486,
1232.709252512353, 25.614891084850345,
1253.0651078183428, 25.614565458053196,
19923.65347159622, 25.614424391689464,
6871.588723207132, 25.614320644965538,
6872.017739842081, 25.613855867304352,
6872.62934758113, 25.613705484276142,
27062.17111336472, 25.613627437852386,
4992.576698759683, 25.61319611215184,
5000.895510942508, 25.613402957590218,
5011.723715236044, 25.612938135098187,
4.196954763212122e+20, 25.612793853140285,
NaN, 25.61277205719907,
NaN, 25.6125225368728,
NaN, 25.612593328850902,
NaN, 25.612471242236275,
NaN, 25.61213506796421,
NaN, 25.61220080220801,
NaN, 25.612097707569088,
NaN, 25.611939287593223,
NaN, 25.611856093773476,
NaN, 25.611754238096058,
NaN, 25.61164761812259,
NaN, 25.61152194096492,
NaN, 25.61156750948001,
NaN, 25.611467068011944,
NaN, 25.61140904059777,
NaN, 25.611222067449848,
NaN, 25.61103208655985,
NaN, 25.610888195852947,
NaN, 25.610904620243954,
NaN, 25.610692464388332,
NaN, 25.610453536367825,
NaN, 25.610431373628796,
NaN, 25.610199642996502,
NaN, 25.609933506729256,
NaN, 25.60987384095151,
NaN, 25.609577077066795,
NaN, 25.6096679817917,
NaN, 25.60952104666294,
NaN, 25.60934520786644,
NaN, 25.60918986899221,
NaN, 25.60895837474073,
NaN, 25.6090076030829,
NaN, 25.60884045331906,
NaN, 25.608697809724728,
NaN, 25.608583588885445,
NaN, 25.608446227179634,
NaN, 25.60828152273455,
NaN, 25.608113199217705,
NaN, 25.60813422080798,
NaN, 25.607940327407967,
NaN, 25.60771349963979,
NaN, 25.607653996883293,
NaN, 25.607560540875816,
NaN, 25.607479812752487,
NaN, 25.607303941351734,
NaN, 25.607178337553627,
NaN, 25.60706621968848,
NaN, 25.60683028310792,
NaN, 25.60676119470189,
NaN, 25.606651432493813,
NaN, 25.606486483516857,
NaN, 25.606393325023163,
NaN, 25.606248602907883,
NaN, 25.60615853366689,
NaN, 25.605982751927826,
NaN, 25.60586034334623,
NaN, 25.60572910308838,
NaN, 25.605554120153442,
NaN, 25.605445796607906,
NaN, 25.605342062110577,
NaN, 25.605153943738365,
NaN, 25.60500627501398,
NaN, 25.60483596263788,
NaN 25.60475177031297
], ],
"perc_loss": [ "perc_loss": [
NaN, 3.4957813187542124,
3.0017659954535656, 3.3751721346480217,
2.8931196978968434, 3.341614580561972,
3.084328340159522, 3.325737010209988,
3.1398269641093717, 3.314503056371314,
3.0839009172896032, 3.3061294968311605,
3.0459561949102287, 3.2988068002920885,
3.015714469628456, 3.2951497960294414,
3.0917934143645134, 3.289724970475221,
3.010985380054539, 3.2849769531152186,
NaN, 3.2810079547075124,
3.102917709411719, 3.2768430485684648,
3.055045795746339, 3.2737610197474813,
3.0351512865123587, 3.270859044331771,
3.002274121993627, 3.267263490929563,
2.9754957128793764, 3.2641563344205546,
2.9553630066733074, 3.2620158368705683,
2.936704080328982, 3.259672654999627,
NaN, 3.2573955410566087,
3.66463458385223, 3.2559085101143928,
3.6419444985878773, 3.2524190697914515,
3.6321474706005845, 3.2508940416523533,
3.6272282503608966, 3.2493486078376446,
3.781295938878997, 3.2472466857005386,
3.740793755421272, 3.246741561808138,
3.759704489993234, 3.2446710722059264,
3.7078149914741516, 3.2429053090576434,
NaN, 3.2415418217324805,
3.7647176323792872, 3.2410791664042025,
3.7441278461717133, 3.238901428687267,
3.7264530261357627, 3.2372534682608056,
NaN, 3.2365858203325515,
3.7299226170931106, 3.2345399107688513,
3.7314435080585318, 3.234266322392684,
3.754668933713538, 3.2330421667832594,
NaN, 3.2321024134627776,
NaN, 3.2317300596807756,
NaN, 3.2297990515700774,
NaN, 3.2291462492739034,
NaN, 3.229038153958117,
NaN, 3.227532477969797,
NaN, 3.227750709423652,
NaN, 3.22642219117564,
NaN, 3.225604632980803,
NaN, 3.224408136983203,
NaN, 3.2237918310695224,
NaN, 3.223153405719333,
NaN, 3.2227800172618313,
NaN, 3.22208283854346,
NaN, 3.2215008134515877,
NaN, 3.220343905636388,
NaN, 3.2200386534389267,
NaN, 3.2190595391469126,
NaN, 3.218352971932827,
NaN, 3.2173767803061724,
NaN, 3.2164535395100584,
NaN, 3.2164310828233376,
NaN, 3.2155216560404525,
NaN, 3.2145652862695546,
NaN, 3.2132638708139076,
NaN, 3.213306623136895,
NaN, 3.2116519161778636,
NaN, 3.2117279686479487,
NaN, 3.210561112460927,
NaN, 3.2098504080731645,
NaN, 3.2099855195762763,
NaN, 3.2090730035406914,
NaN, 3.2085196706983776,
NaN, 3.208222340824258,
NaN, 3.207890693448548,
NaN, 3.206765956348843,
NaN, 3.2065086430973477,
NaN, 3.2055318915945854,
NaN, 3.204780939297798,
NaN, 3.205001257933103,
NaN, 3.203499240243537,
NaN, 3.20293498905296,
NaN, 3.2031849953863354,
NaN, 3.2017455717437286,
NaN, 3.2029461646691346,
NaN, 3.202106138579866,
NaN, 3.200608807751256,
NaN, 3.200435662880922,
NaN, 3.2001758396116076,
NaN, 3.199647111260993,
NaN, 3.199422711490566,
NaN, 3.1987570305155892,
NaN, 3.198392223089169,
NaN, 3.1976211967631283,
NaN, 3.1978414170762415,
NaN, 3.197227580425067,
NaN, 3.197008974531777,
NaN, 3.1967804961734347,
NaN, 3.1954670217302112,
NaN, 3.19552276073358,
NaN, 3.1953888051530237,
NaN, 3.194187126098535,
NaN, 3.194483350484799,
NaN, 3.1943916347291736,
NaN 3.194042898650862
], ],
"adv_g_loss": [ "adv_g_loss": [
0.0, 0.0,
@@ -535,12 +537,12 @@
0.0 0.0
], ],
"fid": { "fid": {
"25": 263.1458740234375, "25": 218.6470947265625,
"50": 598.3736572265625, "50": 236.44911193847656,
"75": 598.3736572265625, "75": 235.74722290039062,
"100": 598.3736572265625 "100": 237.4191436767578
}, },
"train_time_s": 952.6596128940582 "train_time_s": 1509.468991279602
}, },
"n_params": 10608451 "n_params": 10608451
} }
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,14 @@
{
"created_at": "2026-05-01T23:33:40.557082+00:00",
"config_paths": [
"generator/configs/phase3/p3_1_vae.json",
"generator/configs/phase3/p3_2_vae_perceptual.json",
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 35985830,
"offer_id": 30940273,
"ssh_host": "ssh4.vast.ai",
"ssh_port": 25830,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -0,0 +1,14 @@
{
"created_at": "2026-05-02T01:21:14.070736+00:00",
"config_paths": [
"generator/configs/phase3/p3_1_vae.json",
"generator/configs/phase3/p3_2_vae_perceptual.json",
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 35989396,
"offer_id": 35960853,
"ssh_host": "ssh6.vast.ai",
"ssh_port": 29396,
"status": "cancelled",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -0,0 +1,14 @@
{
"created_at": "2026-05-02T01:28:24.565349+00:00",
"config_paths": [
"generator/configs/phase3/p3_1_vae.json",
"generator/configs/phase3/p3_2_vae_perceptual.json",
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 35989623,
"offer_id": 29302404,
"ssh_host": "ssh9.vast.ai",
"ssh_port": 29622,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
Binary file not shown.

Before

Width:  |  Height:  |  Size: 80 KiB

After

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 132 KiB

After

Width:  |  Height:  |  Size: 204 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 120 KiB

After

Width:  |  Height:  |  Size: 144 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 126 KiB

After

Width:  |  Height:  |  Size: 128 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.5 KiB

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 126 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.5 KiB

After

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 124 KiB

After

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.3 KiB

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 124 KiB

After

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.4 KiB

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.7 KiB

After

Width:  |  Height:  |  Size: 9.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 121 KiB

After

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.4 KiB

After

Width:  |  Height:  |  Size: 9.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.3 KiB

After

Width:  |  Height:  |  Size: 9.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 40 KiB

After

Width:  |  Height:  |  Size: 68 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 135 KiB

After

Width:  |  Height:  |  Size: 137 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 30 KiB

After

Width:  |  Height:  |  Size: 40 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 31 KiB

After

Width:  |  Height:  |  Size: 56 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

After

Width:  |  Height:  |  Size: 23 KiB

+2 -1
View File
@@ -98,7 +98,8 @@ class VAE(nn.Module):
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
h = self.encoder(x).flatten(1) h = self.encoder(x).flatten(1)
return self.fc_mu(h), self.fc_lv(h) log_var = self.fc_lv(h).clamp(-10.0, 10.0)
return self.fc_mu(h), log_var
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
std = torch.exp(0.5 * log_var) std = torch.exp(0.5 * log_var)
+52 -27
View File
@@ -403,13 +403,20 @@ def train_vae(
run_name: str, run_name: str,
device: str = "cuda", device: str = "cuda",
) -> dict: ) -> dict:
"""VAE training loop covering Phase 3.1–3.3. """VAE training loop covering Phase 3.1–3.3 and Phase 5.
Config toggles: Config toggles:
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+) lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3) lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
free_bits > 0 → per-dimension KL free bits (prevents posterior
collapse and KL explosion)
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
AMP is intentionally disabled for VAE training — mixed-precision float16
overflows when the KL divergence spikes, producing NaN cascades that
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
computation runs in float32.
""" """
device = torch.device(device if torch.cuda.is_available() else "cpu") device = torch.device(device if torch.cuda.is_available() else "cpu")
vae = vae.to(device) vae = vae.to(device)
@@ -425,6 +432,8 @@ def train_vae(
lambda_perceptual = cfg.get("lambda_perceptual", 0.0) lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
lambda_adversarial = cfg.get("lambda_adversarial", 0.0) lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
lr_d = cfg.get("lr_d", 1e-4) lr_d = cfg.get("lr_d", 1e-4)
free_bits_val = cfg.get("free_bits", 0.0)
grad_clip = cfg.get("grad_clip", 1.0)
ema_decay = cfg.get("ema_decay", 0.9999) ema_decay = cfg.get("ema_decay", 0.9999)
sample_interval = cfg.get("sample_interval", 10) sample_interval = cfg.get("sample_interval", 10)
fid_interval = cfg.get("fid_interval", 25) fid_interval = cfg.get("fid_interval", 25)
@@ -432,6 +441,7 @@ def train_vae(
use_perceptual = lambda_perceptual > 0 use_perceptual = lambda_perceptual > 0
use_adversarial = lambda_adversarial > 0 use_adversarial = lambda_adversarial > 0
use_free_bits = free_bits_val > 0
loader = DataLoader( loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, train_dataset, batch_size=batch_size, shuffle=True,
@@ -440,8 +450,8 @@ def train_vae(
) )
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr) opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
use_amp = device.type == "cuda" # AMP disabled — float16 overflows on KL spikes, causing NaN cascades
scaler = _GradScaler("cuda", enabled=use_amp) use_amp = False
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training # KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
kl_warmup_epochs = max(1, epochs // 5) kl_warmup_epochs = max(1, epochs // 5)
@@ -456,11 +466,10 @@ def train_vae(
perc_fn = None perc_fn = None
patchgan = None patchgan = None
opt_d = None opt_d = None
scaler_d = None
if use_perceptual: if use_perceptual:
from src.training.perceptual import PerceptualLoss from src.training.perceptual import PerceptualLoss
perc_fn = PerceptualLoss().to(device) perc_fn = PerceptualLoss().to(device).float()
print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3") print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")
if use_adversarial: if use_adversarial:
@@ -468,15 +477,14 @@ def train_vae(
patchgan = PatchGANDiscriminator( patchgan = PatchGANDiscriminator(
ndf=cfg.get("ndf_patch", 64), ndf=cfg.get("ndf_patch", 64),
image_size=cfg.get("image_size", 64), image_size=cfg.get("image_size", 64),
).to(device) ).to(device).float()
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999)) opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
scaler_d = _GradScaler("cuda", enabled=use_amp)
sched_d = torch.optim.lr_scheduler.LambdaLR( sched_d = torch.optim.lr_scheduler.LambdaLR(
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1))) opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
n_d = sum(p.numel() for p in patchgan.parameters()) n_d = sum(p.numel() for p in patchgan.parameters())
print(f"PatchGAN: {n_d:,} params") print(f"PatchGAN: {n_d:,} params")
else: else:
hinge_d_loss = hinge_g_loss = None # satisfy linter, never called hinge_d_loss = hinge_g_loss = None # never called
# ── Fixed seeds for consistent visualisation ────────────────────────── # ── Fixed seeds for consistent visualisation ──────────────────────────
fixed_z = torch.randn(16, latent_dim, device=device) fixed_z = torch.randn(16, latent_dim, device=device)
@@ -497,9 +505,11 @@ def train_vae(
"adv_g_loss": [], "adv_d_loss": [], "fid": {}, "adv_g_loss": [], "adv_d_loss": [], "fid": {},
} }
best_fid = float("inf") best_fid = float("inf")
nan_skipped = 0
print( print(
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}" f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}" f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
f" λ_adv={lambda_adversarial} free_bits={free_bits_val}"
) )
t_start = time.time() t_start = time.time()
@@ -513,43 +523,56 @@ def train_vae(
n_batches = 0 n_batches = 0
for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False): for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
real = real.to(device) real = real.to(device).float()
# KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs # KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs
current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs) current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)
# ── VAE forward ─────────────────────────────────────────────── # ── VAE forward (float32, no AMP) ────────────────────────────
with _autocast("cuda", enabled=use_amp):
recon, mu, log_var = vae(real) recon, mu, log_var = vae(real)
mse = F.mse_loss(recon, real) mse = F.mse_loss(recon, real)
kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean()
# KL divergence with optional free bits
kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()) # (B, latent_dim)
if use_free_bits:
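# Free bits: floor every per-dimension KL term at free_bits_val.
# A term already below the threshold is replaced by the constant threshold,
# so it contributes no gradient and the KL penalty stops pushing that
# dimension further towards the prior (intended to counter posterior
# collapse, i.e. dimensions whose KL goes to 0). Terms above the threshold
# are unchanged, so large KL is still penalised as usual.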
kl_per_dim = torch.clamp(kl_per_dim, min=free_bits_val)
kl = kl_per_dim.sum(1).mean()
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze() perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
vae_loss = mse + current_beta * kl + lambda_perceptual * perc vae_loss = mse + current_beta * kl + lambda_perceptual * perc
# ── NaN/Inf guard ────────────────────────────────────────────
if not torch.isfinite(vae_loss):
nan_skipped += 1
opt_vae.zero_grad()
continue
# ── PatchGAN discriminator step ─────────────────────────────── # ── PatchGAN discriminator step ───────────────────────────────
adv_d = real.new_zeros(1).squeeze() adv_d = real.new_zeros(1).squeeze()
if use_adversarial: if use_adversarial:
opt_d.zero_grad() opt_d.zero_grad()
with _autocast("cuda", enabled=use_amp):
d_real = patchgan(real) d_real = patchgan(real)
d_fake = patchgan(recon.detach()) d_fake = patchgan(recon.detach())
adv_d = hinge_d_loss(d_real, d_fake) adv_d = hinge_d_loss(d_real, d_fake)
scaler_d.scale(adv_d).backward() if torch.isfinite(adv_d):
scaler_d.step(opt_d) adv_d.backward()
scaler_d.update() torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
opt_d.step()
# ── PatchGAN generator adversarial loss ─────────────────────── # ── PatchGAN generator adversarial loss ───────────────────────
adv_g = real.new_zeros(1).squeeze() adv_g = real.new_zeros(1).squeeze()
if use_adversarial: if use_adversarial:
with _autocast("cuda", enabled=use_amp):
adv_g = hinge_g_loss(patchgan(recon)) adv_g = hinge_g_loss(patchgan(recon))
vae_loss = vae_loss + lambda_adversarial * adv_g vae_loss = vae_loss + lambda_adversarial * adv_g
# ── VAE backward ────────────────────────────────────────────── # ── VAE backward ──────────────────────────────────────────────
opt_vae.zero_grad() opt_vae.zero_grad()
scaler.scale(vae_loss).backward() vae_loss.backward()
scaler.step(opt_vae) torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip)
scaler.update() opt_vae.step()
ema.update(vae) ema.update(vae)
recon_sum += mse.item() recon_sum += mse.item()
@@ -559,11 +582,11 @@ def train_vae(
adv_d_sum += adv_d.item() adv_d_sum += adv_d.item()
n_batches += 1 n_batches += 1
avg_r = recon_sum / n_batches avg_r = recon_sum / max(n_batches, 1)
avg_k = kl_sum / n_batches avg_k = kl_sum / max(n_batches, 1)
avg_p = perc_sum / n_batches avg_p = perc_sum / max(n_batches, 1)
avg_g = adv_g_sum / n_batches avg_g = adv_g_sum / max(n_batches, 1)
avg_d = adv_d_sum / n_batches avg_d = adv_d_sum / max(n_batches, 1)
history["recon_loss"].append(avg_r) history["recon_loss"].append(avg_r)
history["kl_loss"].append(avg_k) history["kl_loss"].append(avg_k)
history["perc_loss"].append(avg_p) history["perc_loss"].append(avg_p)
@@ -574,6 +597,7 @@ def train_vae(
f"[{epoch:03d}/{epochs}] " f"[{epoch:03d}/{epochs}] "
f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} " f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}" f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
f" (NaN skipped: {nan_skipped})"
) )
if epoch % sample_interval == 0: if epoch % sample_interval == 0:
@@ -607,6 +631,7 @@ def train_vae(
if patchgan is not None: if patchgan is not None:
torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt") torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
history["train_time_s"] = time.time() - t_start history["train_time_s"] = time.time() - t_start
print(f"Total NaN-skipped batches: {nan_skipped}")
return history return history