VAE refactor

This commit is contained in:
Johnny Fernandes
2026-05-02 00:32:45 +01:00
parent 1a7f67ab9c
commit c7804d2984
43 changed files with 1140 additions and 1064 deletions
@@ -7,6 +7,8 @@
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000
+2 -2
View File
@@ -1,8 +1,8 @@
{
"extends": "_base_phase3.json",
"run_name": "p3_1_vae",
"lr": 1e-3,
"beta_kl": 1.0,
"lr": 5e-4,
"beta_kl": 0.5,
"lambda_perceptual": 0.0,
"lambda_adversarial": 0.0
}
@@ -1,8 +1,8 @@
{
"extends": "_base_phase3.json",
"run_name": "p3_2_vae_perceptual",
"lr": 1e-3,
"beta_kl": 0.0001,
"lr": 5e-4,
"beta_kl": 0.1,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.0
}
@@ -1,9 +1,9 @@
{
"extends": "_base_phase3.json",
"run_name": "p3_3_vae_patchgan",
"lr": 1e-3,
"lr": 5e-4,
"lr_d": 1e-4,
"beta_kl": 0.0001,
"beta_kl": 0.05,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.1,
"ndf_patch": 64
+209 -207
View File
@@ -17,216 +17,218 @@
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"run_name": "p3_1_vae",
"lr": 0.001,
"beta_kl": 1.0,
"lr": 0.0005,
"beta_kl": 0.5,
"lambda_perceptual": 0.0,
"lambda_adversarial": 0.0
},
"history": {
"recon_loss": [
0.23614721197603095,
0.23315699178821,
0.22991716011594504,
NaN,
0.23217070787253544,
0.23155480842941847,
0.23157141198459855,
0.23181156750418183,
0.23201335527193853,
0.23178868266379732,
0.2315022333755962,
0.2311908418042028,
0.23185610672474927,
0.23176095832107413,
0.23165411693163407,
0.23174296459581098,
0.2317636658747991,
0.2317118427883356,
0.23172695364834917,
0.2316696329567677,
0.23168399261358458,
0.2316194716681782,
0.23164867447354856,
0.2315481170757204,
0.23165068109957582,
0.23167062098653907,
0.23162642907765177,
0.2315922882567104,
0.2315914996414103,
0.23156180984189367,
0.23156551628286004,
0.2315698005259037,
0.2315660522470617,
0.23156735001720935,
0.23161396435183337,
0.23158050178844705,
0.23159921089680785,
0.23149616745674712,
0.23159087484336308,
0.23156312872201967,
0.23153820200863048,
0.2315863819203825,
0.23150022140043414,
0.23154497337646973,
0.2315601774961011,
0.23153368950399578,
0.23152085642019907,
0.23151608884461924,
0.23154898990805334,
0.23155892872784892,
NaN,
NaN,
NaN,
0.24157701413600874,
NaN,
NaN,
0.24151325464630738,
NaN,
0.24154121766233036,
0.24155463749526912,
0.24158300176008135,
0.24158118757554609,
0.2415294518901242,
0.24156020069096842,
0.2415176352374574,
0.2415566616015047,
NaN,
0.24161437115608117,
0.24159398913765565,
0.24149432768806434,
0.24153172199287984,
0.24161516999204954,
0.24158193846034187,
0.2415451397562129,
0.24155487772873324,
0.24155297130346298,
NaN,
0.24157197961313093,
NaN,
NaN,
0.24158605401459923,
0.24156368870893094,
0.24159100852333582,
0.24153350121699846,
0.24153158377505776,
NaN,
0.24161708673350832,
0.24158515879868442,
NaN,
0.24157126235146809,
0.24162366709265953,
NaN,
0.2415581897665293,
NaN,
NaN,
0.2415400046823371,
NaN,
0.2415627600927638,
0.2415567432076503,
0.2415620140476614
0.16871115132274792,
0.16150351059742463,
0.15984480073436713,
0.15590801271490562,
0.15061364529861343,
0.14644645811973983,
0.14328488393917552,
0.14094583944887176,
0.13928856181665364,
0.1377539750124909,
0.13583827684195632,
0.13372899556898662,
0.13197963861509776,
0.1295166944559568,
0.12748059913770765,
0.126394641561768,
0.12385777387226748,
0.12244280847983482,
0.121336308045265,
0.12060208647296979,
0.11928399897411339,
0.11839368250061813,
0.11770952168183449,
0.11709171382344177,
0.11688287127922234,
0.11606283619617805,
0.11564522911595483,
0.11527438517500702,
0.11461337662150717,
0.11428412470297936,
0.11390002192849787,
0.11359650807248221,
0.11298466332129434,
0.11281277696227926,
0.11266172842846976,
0.11232183446996233,
0.11213540144137338,
0.11184148606645246,
0.11143260068682015,
0.11121100015365161,
0.11103646766044135,
0.11086450471009454,
0.11068084617901562,
0.11046731698080006,
0.1101092367600172,
0.11040792095228139,
0.11019570772082378,
0.10987201248669726,
0.10948091489063878,
0.10964858110070738,
0.10938817380457862,
0.10904825658688688,
0.10926413088718541,
0.10854257479246356,
0.10841736453784327,
0.10838394499041586,
0.10795553592153084,
0.10781666870491627,
0.10770979084265538,
0.1076134722520653,
0.10732251404123938,
0.10716411035157676,
0.10698133315413426,
0.10672279318364766,
0.10645651288776316,
0.10662059826601265,
0.10623115535156849,
0.10627176197102436,
0.1060571956456217,
0.10596368315382901,
0.10585042324840513,
0.1054028309085685,
0.10517520199601467,
0.1053342710639167,
0.1052525288815427,
0.10502007727821668,
0.10490142727573203,
0.10471088643002714,
0.10447397850390173,
0.1044878327470814,
0.10447094976328887,
0.10421697295501701,
0.10414895819675209,
0.10414911504102568,
0.10366271499894623,
0.10378186422217096,
0.10345360714719336,
0.10363478323396964,
0.10342726219668348,
0.10305201177859408,
0.10310276317545491,
0.10302646558445233,
0.10291730190635237,
0.10278329038276122,
0.10258485910156344,
0.10270861135079311,
0.10251107273830308,
0.10234160087684281,
0.10214911544552216,
0.10251628613879538
],
"kl_loss": [
12.394881742504927,
184.775765717539,
127.26797539963681,
33346392.786626913,
35.72433020722153,
31.41954361882984,
16.178619678203876,
10.234501274223001,
14.817130448471787,
9.230570034084158,
9.643558593896719,
8.47786058498244,
5.573643362929678,
2.4644629534365783,
1.5757666807462516,
0.426466258131286,
1.7924597560404203,
0.2769168242652956,
0.21636260826236162,
0.48804672485870176,
0.10833573165453142,
0.13318477837671328,
0.17373992544877478,
0.09584700099678121,
0.0977757986014088,
0.07794108981282538,
0.05691333960853199,
0.07221067506167242,
0.036222075203704275,
0.03126689469696492,
0.04264315036642882,
0.016960328184147805,
0.03314871971324309,
0.014776984407789368,
0.011375301962312406,
0.013948339588828703,
0.01186063720120324,
0.0099704478863372,
0.00536374123289417,
0.009618068660179583,
0.00418840028031164,
0.004865833775052785,
0.005830266629345715,
0.0023000687699064487,
0.0038261460966199762,
0.0022056369562673136,
0.002220870125003987,
0.0024217167485139184,
0.001954249278483037,
0.0021431104709895756,
0.0022583500309011494,
0.002132287005193404,
56.80083886633675,
82.57108385134966,
82.55195800259582,
82.57428529527452,
82.56009972401155,
82.55269345666608,
82.57006728343474,
82.55670593131302,
82.54445134676419,
82.57745079301361,
82.57933913336859,
82.5570435157189,
82.56808758597089,
82.56800172267816,
82.56525711320404,
82.56189481621115,
82.55193622295673,
82.55375865382007,
82.56600202250685,
82.57064581324912,
82.55481151026538,
82.55367833324986,
82.56042112040724,
82.5616829048874,
82.5771528553759,
82.55317820035495,
82.57550573756552,
82.57334061973116,
82.56044387817383,
82.5752662593483,
82.56673936762361,
82.56828115740393,
82.56990289280557,
82.55218840052939,
82.56695372426611,
82.575043066954,
82.55754522991995,
82.56361721723508,
82.5628145821074,
82.56431990403395,
82.55777725806603,
82.5742861918914,
82.56361025622768,
82.56887233766736,
82.56539458902473,
82.55887828729091,
82.56073884882478,
82.55578186165573
29.481668366326225,
26.226897622784996,
25.98807217932155,
25.807672516912476,
25.701779809772457,
25.66069327052842,
25.639777212061432,
25.62703147301307,
25.620143454299015,
25.61577053966685,
25.612662934849403,
25.60999612319164,
25.608206019442306,
25.60651744940342,
25.605323510292248,
25.60462644772652,
25.603682958162747,
25.603215930808304,
25.602825727218235,
25.602524419116158,
25.602402703374878,
25.602261029756985,
25.602163278139553,
25.602091275728664,
25.602065310518967,
25.602008819580078,
25.601970770420174,
25.601926473470833,
25.601869432335224,
25.60185879519862,
25.601852824545315,
25.60183112234132,
25.601783161489372,
25.601782627594776,
25.601758304824177,
25.601712039393238,
25.601732046176227,
25.601690267905212,
25.60164840404804,
25.601653335440872,
25.60167118626782,
25.601645954653748,
25.60163675210415,
25.601624191316784,
25.601573104532356,
25.60163724524343,
25.601609873975445,
25.601576218238243,
25.601538580706997,
25.601551618331516,
25.60154583922818,
25.6015028138446,
25.60153263858241,
25.601467629783173,
25.60146150018415,
25.601416970929527,
25.601392692989773,
25.601404438670883,
25.601370538401806,
25.60135773308257,
25.601353584191738,
25.60130478785588,
25.601299799405613,
25.60127318618644,
25.6012494054615,
25.601233796176746,
25.601193383208706,
25.601183601933666,
25.60118840698503,
25.601158150240906,
25.60114165656587,
25.601124979491928,
25.601091221866444,
25.60108805925418,
25.601056249732647,
25.601044960511036,
25.601008244049854,
25.600999204521504,
25.60097885539389,
25.60095446741479,
25.600959500695904,
25.600936054164528,
25.60091640195276,
25.600889332274086,
25.600879269787388,
25.600861647190193,
25.600824853293915,
25.6008102914207,
25.600796605786705,
25.600769034817688,
25.600744826161964,
25.600724538167317,
25.60070591298943,
25.600691750518276,
25.600666380336143,
25.60064803636991,
25.60063368642432,
25.600601587540066,
25.60058915309417,
25.600568995516525
],
"perc_loss": [
0.0,
@@ -535,12 +537,12 @@
0.0
],
"fid": {
"25": 315.9393615722656,
"50": 419.273193359375,
"75": 360.4432678222656,
"100": 363.9911193847656
"25": 238.42819213867188,
"50": 232.70050048828125,
"75": 234.88893127441406,
"100": 236.51181030273438
},
"train_time_s": 660.9630489349365
"train_time_s": 676.644668340683
},
"n_params": 10608451
}
+309 -307
View File
@@ -17,318 +17,320 @@
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"run_name": "p3_2_vae_perceptual",
"lr": 0.001,
"beta_kl": 0.0001,
"lr": 0.0005,
"beta_kl": 0.1,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.0
},
"history": {
"recon_loss": [
NaN,
0.041370747770127066,
0.035506031866002284,
0.06229733923672993,
0.06089382184048494,
0.053464704654856116,
0.04950355895213846,
0.04674786329269409,
0.05693557829811023,
0.04808994451076047,
NaN,
0.05869839608701121,
0.05293231666820426,
0.05057006603918779,
0.04713928615117175,
0.04481589820427008,
0.04331832689543565,
0.041820655449524395,
NaN,
0.24315394992884407,
0.24305414841470555,
0.2429187158998261,
0.24289727962424612,
0.2377373814328104,
0.23880945107875726,
0.24047312904626894,
0.23915709144411942,
NaN,
0.24000247061634675,
0.2362435321904655,
0.23705266075383905,
NaN,
0.23998981039238793,
0.23828768376738596,
0.23908851302077627,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN
0.13868479322419208,
0.1345828948622076,
0.13401474649261716,
0.13219879449814811,
0.13071280969386426,
0.12897613197246677,
0.12651141290353912,
0.12554800317773962,
0.123991425602864,
0.12276749478446113,
0.12151926139799449,
0.1200498831896191,
0.11872616813032545,
0.11811881408923203,
0.11655244218488024,
0.11565455276932982,
0.11529083312767693,
0.11437753734425601,
0.11373461206626688,
0.1133939535317258,
0.11269663740745467,
0.11214834819428432,
0.11180534907895276,
0.1112786961448753,
0.11112714579535855,
0.11040605649224713,
0.11024710934004213,
0.11025449748222645,
0.10986682841092603,
0.10947157509433918,
0.10914939207335313,
0.1090434947425229,
0.10872587247982494,
0.10891366730897854,
0.10840102485739268,
0.10831964285009438,
0.10826414010017855,
0.10774775957449889,
0.10791046626101701,
0.10784940838686422,
0.10743295191190182,
0.10734256694459507,
0.10702427010187227,
0.10701240906412275,
0.10711385588296968,
0.10697286784585215,
0.10673481402679896,
0.10650451705814937,
0.10629599592369846,
0.1064668823288292,
0.1063920924296746,
0.10610189062789974,
0.10592550779573429,
0.10588065830942912,
0.105781758379223,
0.10560809290752961,
0.10550812136732106,
0.10535470090615444,
0.10536463093808573,
0.1051216669794586,
0.10498357508490738,
0.10464231009220976,
0.10468940513256268,
0.10468925925719942,
0.10429271149775411,
0.10437219857405393,
0.10406083403489529,
0.10395075554330634,
0.10419673752835673,
0.10405941009839885,
0.10379417274051751,
0.10373205498943472,
0.10360019166882221,
0.10355540880790123,
0.10355440188103761,
0.10325965399925525,
0.10304029177651446,
0.10311986905578364,
0.10273497300142916,
0.10302559410532315,
0.10278304515040329,
0.10263998298627189,
0.10254473253511466,
0.10245785787383206,
0.10246957698438922,
0.10233539204375866,
0.1025002559280803,
0.10214536613187729,
0.10215426669416265,
0.10214609539725332,
0.10188078407484752,
0.1020691341951362,
0.1019192597645725,
0.10151305177018173,
0.10163224848289774,
0.10181667007760614,
0.10129789227985928,
0.10133470410210454,
0.1014308016269635,
0.10122469465574647
],
"kl_loss": [
41526536.5961082,
1258.7768262553418,
1165.5987154968784,
3689.321748260759,
1320.3167513333833,
928.3473894168169,
789.754654843583,
693.4729607736963,
1017.8856519389357,
622.2623097998464,
304942.2120989938,
750.5718396830763,
618.6888906364767,
586.4194498958751,
579.2793892102363,
525.8830437945504,
497.5363791050055,
452.93342662061383,
172476214612.95245,
1112.027792645316,
1109.8156276605068,
1109.8872808472722,
1112.3999198196282,
1163.6093267457097,
1195.4633280436199,
1232.709252512353,
1253.0651078183428,
19923.65347159622,
6871.588723207132,
6872.017739842081,
6872.62934758113,
27062.17111336472,
4992.576698759683,
5000.895510942508,
5011.723715236044,
4.196954763212122e+20,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN
32.82580094867282,
27.63585641649034,
26.858297156472492,
26.36294916756133,
26.101821015023777,
25.925728178431847,
25.810632864634197,
25.75408285499638,
25.714011412400467,
25.691052310487144,
25.673080118293438,
25.658130865830643,
25.648127645508858,
25.640685057028747,
25.63372629116743,
25.62904858792949,
25.625450937157,
25.622334708515396,
25.61985801631569,
25.61797430168869,
25.617105920090633,
25.61665568800054,
25.615915049854507,
25.615573080176983,
25.61539109955486,
25.614891084850345,
25.614565458053196,
25.614424391689464,
25.614320644965538,
25.613855867304352,
25.613705484276142,
25.613627437852386,
25.61319611215184,
25.613402957590218,
25.612938135098187,
25.612793853140285,
25.61277205719907,
25.6125225368728,
25.612593328850902,
25.612471242236275,
25.61213506796421,
25.61220080220801,
25.612097707569088,
25.611939287593223,
25.611856093773476,
25.611754238096058,
25.61164761812259,
25.61152194096492,
25.61156750948001,
25.611467068011944,
25.61140904059777,
25.611222067449848,
25.61103208655985,
25.610888195852947,
25.610904620243954,
25.610692464388332,
25.610453536367825,
25.610431373628796,
25.610199642996502,
25.609933506729256,
25.60987384095151,
25.609577077066795,
25.6096679817917,
25.60952104666294,
25.60934520786644,
25.60918986899221,
25.60895837474073,
25.6090076030829,
25.60884045331906,
25.608697809724728,
25.608583588885445,
25.608446227179634,
25.60828152273455,
25.608113199217705,
25.60813422080798,
25.607940327407967,
25.60771349963979,
25.607653996883293,
25.607560540875816,
25.607479812752487,
25.607303941351734,
25.607178337553627,
25.60706621968848,
25.60683028310792,
25.60676119470189,
25.606651432493813,
25.606486483516857,
25.606393325023163,
25.606248602907883,
25.60615853366689,
25.605982751927826,
25.60586034334623,
25.60572910308838,
25.605554120153442,
25.605445796607906,
25.605342062110577,
25.605153943738365,
25.60500627501398,
25.60483596263788,
25.60475177031297
],
"perc_loss": [
NaN,
3.0017659954535656,
2.8931196978968434,
3.084328340159522,
3.1398269641093717,
3.0839009172896032,
3.0459561949102287,
3.015714469628456,
3.0917934143645134,
3.010985380054539,
NaN,
3.102917709411719,
3.055045795746339,
3.0351512865123587,
3.002274121993627,
2.9754957128793764,
2.9553630066733074,
2.936704080328982,
NaN,
3.66463458385223,
3.6419444985878773,
3.6321474706005845,
3.6272282503608966,
3.781295938878997,
3.740793755421272,
3.759704489993234,
3.7078149914741516,
NaN,
3.7647176323792872,
3.7441278461717133,
3.7264530261357627,
NaN,
3.7299226170931106,
3.7314435080585318,
3.754668933713538,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN
3.4957813187542124,
3.3751721346480217,
3.341614580561972,
3.325737010209988,
3.314503056371314,
3.3061294968311605,
3.2988068002920885,
3.2951497960294414,
3.289724970475221,
3.2849769531152186,
3.2810079547075124,
3.2768430485684648,
3.2737610197474813,
3.270859044331771,
3.267263490929563,
3.2641563344205546,
3.2620158368705683,
3.259672654999627,
3.2573955410566087,
3.2559085101143928,
3.2524190697914515,
3.2508940416523533,
3.2493486078376446,
3.2472466857005386,
3.246741561808138,
3.2446710722059264,
3.2429053090576434,
3.2415418217324805,
3.2410791664042025,
3.238901428687267,
3.2372534682608056,
3.2365858203325515,
3.2345399107688513,
3.234266322392684,
3.2330421667832594,
3.2321024134627776,
3.2317300596807756,
3.2297990515700774,
3.2291462492739034,
3.229038153958117,
3.227532477969797,
3.227750709423652,
3.22642219117564,
3.225604632980803,
3.224408136983203,
3.2237918310695224,
3.223153405719333,
3.2227800172618313,
3.22208283854346,
3.2215008134515877,
3.220343905636388,
3.2200386534389267,
3.2190595391469126,
3.218352971932827,
3.2173767803061724,
3.2164535395100584,
3.2164310828233376,
3.2155216560404525,
3.2145652862695546,
3.2132638708139076,
3.213306623136895,
3.2116519161778636,
3.2117279686479487,
3.210561112460927,
3.2098504080731645,
3.2099855195762763,
3.2090730035406914,
3.2085196706983776,
3.208222340824258,
3.207890693448548,
3.206765956348843,
3.2065086430973477,
3.2055318915945854,
3.204780939297798,
3.205001257933103,
3.203499240243537,
3.20293498905296,
3.2031849953863354,
3.2017455717437286,
3.2029461646691346,
3.202106138579866,
3.200608807751256,
3.200435662880922,
3.2001758396116076,
3.199647111260993,
3.199422711490566,
3.1987570305155892,
3.198392223089169,
3.1976211967631283,
3.1978414170762415,
3.197227580425067,
3.197008974531777,
3.1967804961734347,
3.1954670217302112,
3.19552276073358,
3.1953888051530237,
3.194187126098535,
3.194483350484799,
3.1943916347291736,
3.194042898650862
],
"adv_g_loss": [
0.0,
@@ -535,12 +537,12 @@
0.0
],
"fid": {
"25": 263.1458740234375,
"50": 598.3736572265625,
"75": 598.3736572265625,
"100": 598.3736572265625
"25": 218.6470947265625,
"50": 236.44911193847656,
"75": 235.74722290039062,
"100": 237.4191436767578
},
"train_time_s": 952.6596128940582
"train_time_s": 1509.468991279602
},
"n_params": 10608451
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,14 @@
{
"created_at": "2026-05-01T23:33:40.557082+00:00",
"config_paths": [
"generator/configs/phase3/p3_1_vae.json",
"generator/configs/phase3/p3_2_vae_perceptual.json",
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 35985830,
"offer_id": 30940273,
"ssh_host": "ssh4.vast.ai",
"ssh_port": 25830,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -0,0 +1,14 @@
{
"created_at": "2026-05-02T01:21:14.070736+00:00",
"config_paths": [
"generator/configs/phase3/p3_1_vae.json",
"generator/configs/phase3/p3_2_vae_perceptual.json",
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 35989396,
"offer_id": 35960853,
"ssh_host": "ssh6.vast.ai",
"ssh_port": 29396,
"status": "cancelled",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -0,0 +1,14 @@
{
"created_at": "2026-05-02T01:28:24.565349+00:00",
"config_paths": [
"generator/configs/phase3/p3_1_vae.json",
"generator/configs/phase3/p3_2_vae_perceptual.json",
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 35989623,
"offer_id": 29302404,
"ssh_host": "ssh9.vast.ai",
"ssh_port": 29622,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
Binary file not shown.

Before

Width:  |  Height:  |  Size: 80 KiB

After

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 132 KiB

After

Width:  |  Height:  |  Size: 204 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 120 KiB

After

Width:  |  Height:  |  Size: 144 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 126 KiB

After

Width:  |  Height:  |  Size: 128 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.5 KiB

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 126 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.5 KiB

After

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 124 KiB

After

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.3 KiB

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 124 KiB

After

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.4 KiB

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.7 KiB

After

Width:  |  Height:  |  Size: 9.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 121 KiB

After

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.4 KiB

After

Width:  |  Height:  |  Size: 9.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.3 KiB

After

Width:  |  Height:  |  Size: 9.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 40 KiB

After

Width:  |  Height:  |  Size: 68 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 135 KiB

After

Width:  |  Height:  |  Size: 137 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 30 KiB

After

Width:  |  Height:  |  Size: 40 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 31 KiB

After

Width:  |  Height:  |  Size: 56 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

After

Width:  |  Height:  |  Size: 23 KiB

+2 -1
View File
@@ -98,7 +98,8 @@ class VAE(nn.Module):
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
h = self.encoder(x).flatten(1)
return self.fc_mu(h), self.fc_lv(h)
log_var = self.fc_lv(h).clamp(-10.0, 10.0)
return self.fc_mu(h), log_var
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
std = torch.exp(0.5 * log_var)
+61 -36
View File
@@ -403,13 +403,20 @@ def train_vae(
run_name: str,
device: str = "cuda",
) -> dict:
"""VAE training loop covering Phase 3.1 3.3.
"""VAE training loop covering Phase 3.1 3.3 and Phase 5.
Config toggles:
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
free_bits > 0 → per-dimension KL free bits (prevents posterior
collapse and KL explosion)
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
AMP is intentionally disabled for VAE training — mixed-precision float16
overflows when the KL divergence spikes, producing NaN cascades that
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
computation runs in float32.
"""
device = torch.device(device if torch.cuda.is_available() else "cpu")
vae = vae.to(device)
@@ -425,6 +432,8 @@ def train_vae(
lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
lr_d = cfg.get("lr_d", 1e-4)
free_bits_val = cfg.get("free_bits", 0.0)
grad_clip = cfg.get("grad_clip", 1.0)
ema_decay = cfg.get("ema_decay", 0.9999)
sample_interval = cfg.get("sample_interval", 10)
fid_interval = cfg.get("fid_interval", 25)
@@ -432,6 +441,7 @@ def train_vae(
use_perceptual = lambda_perceptual > 0
use_adversarial = lambda_adversarial > 0
use_free_bits = free_bits_val > 0
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
@@ -440,8 +450,8 @@ def train_vae(
)
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
use_amp = device.type == "cuda"
scaler = _GradScaler("cuda", enabled=use_amp)
# AMP disabled — float16 overflows on KL spikes, causing NaN cascades
use_amp = False
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
kl_warmup_epochs = max(1, epochs // 5)
@@ -456,11 +466,10 @@ def train_vae(
perc_fn = None
patchgan = None
opt_d = None
scaler_d = None
if use_perceptual:
from src.training.perceptual import PerceptualLoss
perc_fn = PerceptualLoss().to(device)
perc_fn = PerceptualLoss().to(device).float()
print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")
if use_adversarial:
@@ -468,15 +477,14 @@ def train_vae(
patchgan = PatchGANDiscriminator(
ndf=cfg.get("ndf_patch", 64),
image_size=cfg.get("image_size", 64),
).to(device)
).to(device).float()
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
scaler_d = _GradScaler("cuda", enabled=use_amp)
sched_d = torch.optim.lr_scheduler.LambdaLR(
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
n_d = sum(p.numel() for p in patchgan.parameters())
print(f"PatchGAN: {n_d:,} params")
else:
hinge_d_loss = hinge_g_loss = None # satisfy linter, never called
hinge_d_loss = hinge_g_loss = None # never called
# ── Fixed seeds for consistent visualisation ──────────────────────────
fixed_z = torch.randn(16, latent_dim, device=device)
@@ -497,9 +505,11 @@ def train_vae(
"adv_g_loss": [], "adv_d_loss": [], "fid": {},
}
best_fid = float("inf")
nan_skipped = 0
print(
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}"
f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
f" λ_adv={lambda_adversarial} free_bits={free_bits_val}"
)
t_start = time.time()
@@ -513,43 +523,56 @@ def train_vae(
n_batches = 0
for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
real = real.to(device)
real = real.to(device).float()
# KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs
current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)
# ── VAE forward ───────────────────────────────────────────────
with _autocast("cuda", enabled=use_amp):
recon, mu, log_var = vae(real)
mse = F.mse_loss(recon, real)
kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean()
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
# ── VAE forward (float32, no AMP) ────────────────────────────
recon, mu, log_var = vae(real)
mse = F.mse_loss(recon, real)
# KL divergence with optional free bits
kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()) # (B, latent_dim)
if use_free_bits:
# Free bits: ensure each dimension contributes at least free_bits_val KL.
# Dimensions below the threshold are raised to it, preventing posterior
# collapse (dimensions that go to 0) while still penalising large KL.
kl_per_dim = torch.clamp(kl_per_dim, min=free_bits_val)
kl = kl_per_dim.sum(1).mean()
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
# ── NaN/Inf guard ────────────────────────────────────────────
if not torch.isfinite(vae_loss):
nan_skipped += 1
opt_vae.zero_grad()
continue
# ── PatchGAN discriminator step ───────────────────────────────
adv_d = real.new_zeros(1).squeeze()
if use_adversarial:
opt_d.zero_grad()
with _autocast("cuda", enabled=use_amp):
d_real = patchgan(real)
d_fake = patchgan(recon.detach())
adv_d = hinge_d_loss(d_real, d_fake)
scaler_d.scale(adv_d).backward()
scaler_d.step(opt_d)
scaler_d.update()
d_real = patchgan(real)
d_fake = patchgan(recon.detach())
adv_d = hinge_d_loss(d_real, d_fake)
if torch.isfinite(adv_d):
adv_d.backward()
torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
opt_d.step()
# ── PatchGAN generator adversarial loss ───────────────────────
adv_g = real.new_zeros(1).squeeze()
if use_adversarial:
with _autocast("cuda", enabled=use_amp):
adv_g = hinge_g_loss(patchgan(recon))
vae_loss = vae_loss + lambda_adversarial * adv_g
adv_g = hinge_g_loss(patchgan(recon))
vae_loss = vae_loss + lambda_adversarial * adv_g
# ── VAE backward ──────────────────────────────────────────────
opt_vae.zero_grad()
scaler.scale(vae_loss).backward()
scaler.step(opt_vae)
scaler.update()
vae_loss.backward()
torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip)
opt_vae.step()
ema.update(vae)
recon_sum += mse.item()
@@ -559,11 +582,11 @@ def train_vae(
adv_d_sum += adv_d.item()
n_batches += 1
avg_r = recon_sum / n_batches
avg_k = kl_sum / n_batches
avg_p = perc_sum / n_batches
avg_g = adv_g_sum / n_batches
avg_d = adv_d_sum / n_batches
avg_r = recon_sum / max(n_batches, 1)
avg_k = kl_sum / max(n_batches, 1)
avg_p = perc_sum / max(n_batches, 1)
avg_g = adv_g_sum / max(n_batches, 1)
avg_d = adv_d_sum / max(n_batches, 1)
history["recon_loss"].append(avg_r)
history["kl_loss"].append(avg_k)
history["perc_loss"].append(avg_p)
@@ -574,6 +597,7 @@ def train_vae(
f"[{epoch:03d}/{epochs}] "
f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
f" (NaN skipped: {nan_skipped})"
)
if epoch % sample_interval == 0:
@@ -607,6 +631,7 @@ def train_vae(
if patchgan is not None:
torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
history["train_time_s"] = time.time() - t_start
print(f"Total NaN-skipped batches: {nan_skipped}")
return history