VAE refactor

This commit is contained in:
Johnny Fernandes
2026-05-02 00:32:45 +01:00
parent 1a7f67ab9c
commit c7804d2984
43 changed files with 1140 additions and 1064 deletions
+209 -207
View File
@@ -17,216 +17,218 @@
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"run_name": "p3_1_vae",
"lr": 0.001,
"beta_kl": 1.0,
"lr": 0.0005,
"beta_kl": 0.5,
"lambda_perceptual": 0.0,
"lambda_adversarial": 0.0
},
"history": {
"recon_loss": [
0.23614721197603095,
0.23315699178821,
0.22991716011594504,
NaN,
0.23217070787253544,
0.23155480842941847,
0.23157141198459855,
0.23181156750418183,
0.23201335527193853,
0.23178868266379732,
0.2315022333755962,
0.2311908418042028,
0.23185610672474927,
0.23176095832107413,
0.23165411693163407,
0.23174296459581098,
0.2317636658747991,
0.2317118427883356,
0.23172695364834917,
0.2316696329567677,
0.23168399261358458,
0.2316194716681782,
0.23164867447354856,
0.2315481170757204,
0.23165068109957582,
0.23167062098653907,
0.23162642907765177,
0.2315922882567104,
0.2315914996414103,
0.23156180984189367,
0.23156551628286004,
0.2315698005259037,
0.2315660522470617,
0.23156735001720935,
0.23161396435183337,
0.23158050178844705,
0.23159921089680785,
0.23149616745674712,
0.23159087484336308,
0.23156312872201967,
0.23153820200863048,
0.2315863819203825,
0.23150022140043414,
0.23154497337646973,
0.2315601774961011,
0.23153368950399578,
0.23152085642019907,
0.23151608884461924,
0.23154898990805334,
0.23155892872784892,
NaN,
NaN,
NaN,
0.24157701413600874,
NaN,
NaN,
0.24151325464630738,
NaN,
0.24154121766233036,
0.24155463749526912,
0.24158300176008135,
0.24158118757554609,
0.2415294518901242,
0.24156020069096842,
0.2415176352374574,
0.2415566616015047,
NaN,
0.24161437115608117,
0.24159398913765565,
0.24149432768806434,
0.24153172199287984,
0.24161516999204954,
0.24158193846034187,
0.2415451397562129,
0.24155487772873324,
0.24155297130346298,
NaN,
0.24157197961313093,
NaN,
NaN,
0.24158605401459923,
0.24156368870893094,
0.24159100852333582,
0.24153350121699846,
0.24153158377505776,
NaN,
0.24161708673350832,
0.24158515879868442,
NaN,
0.24157126235146809,
0.24162366709265953,
NaN,
0.2415581897665293,
NaN,
NaN,
0.2415400046823371,
NaN,
0.2415627600927638,
0.2415567432076503,
0.2415620140476614
0.16871115132274792,
0.16150351059742463,
0.15984480073436713,
0.15590801271490562,
0.15061364529861343,
0.14644645811973983,
0.14328488393917552,
0.14094583944887176,
0.13928856181665364,
0.1377539750124909,
0.13583827684195632,
0.13372899556898662,
0.13197963861509776,
0.1295166944559568,
0.12748059913770765,
0.126394641561768,
0.12385777387226748,
0.12244280847983482,
0.121336308045265,
0.12060208647296979,
0.11928399897411339,
0.11839368250061813,
0.11770952168183449,
0.11709171382344177,
0.11688287127922234,
0.11606283619617805,
0.11564522911595483,
0.11527438517500702,
0.11461337662150717,
0.11428412470297936,
0.11390002192849787,
0.11359650807248221,
0.11298466332129434,
0.11281277696227926,
0.11266172842846976,
0.11232183446996233,
0.11213540144137338,
0.11184148606645246,
0.11143260068682015,
0.11121100015365161,
0.11103646766044135,
0.11086450471009454,
0.11068084617901562,
0.11046731698080006,
0.1101092367600172,
0.11040792095228139,
0.11019570772082378,
0.10987201248669726,
0.10948091489063878,
0.10964858110070738,
0.10938817380457862,
0.10904825658688688,
0.10926413088718541,
0.10854257479246356,
0.10841736453784327,
0.10838394499041586,
0.10795553592153084,
0.10781666870491627,
0.10770979084265538,
0.1076134722520653,
0.10732251404123938,
0.10716411035157676,
0.10698133315413426,
0.10672279318364766,
0.10645651288776316,
0.10662059826601265,
0.10623115535156849,
0.10627176197102436,
0.1060571956456217,
0.10596368315382901,
0.10585042324840513,
0.1054028309085685,
0.10517520199601467,
0.1053342710639167,
0.1052525288815427,
0.10502007727821668,
0.10490142727573203,
0.10471088643002714,
0.10447397850390173,
0.1044878327470814,
0.10447094976328887,
0.10421697295501701,
0.10414895819675209,
0.10414911504102568,
0.10366271499894623,
0.10378186422217096,
0.10345360714719336,
0.10363478323396964,
0.10342726219668348,
0.10305201177859408,
0.10310276317545491,
0.10302646558445233,
0.10291730190635237,
0.10278329038276122,
0.10258485910156344,
0.10270861135079311,
0.10251107273830308,
0.10234160087684281,
0.10214911544552216,
0.10251628613879538
],
"kl_loss": [
12.394881742504927,
184.775765717539,
127.26797539963681,
33346392.786626913,
35.72433020722153,
31.41954361882984,
16.178619678203876,
10.234501274223001,
14.817130448471787,
9.230570034084158,
9.643558593896719,
8.47786058498244,
5.573643362929678,
2.4644629534365783,
1.5757666807462516,
0.426466258131286,
1.7924597560404203,
0.2769168242652956,
0.21636260826236162,
0.48804672485870176,
0.10833573165453142,
0.13318477837671328,
0.17373992544877478,
0.09584700099678121,
0.0977757986014088,
0.07794108981282538,
0.05691333960853199,
0.07221067506167242,
0.036222075203704275,
0.03126689469696492,
0.04264315036642882,
0.016960328184147805,
0.03314871971324309,
0.014776984407789368,
0.011375301962312406,
0.013948339588828703,
0.01186063720120324,
0.0099704478863372,
0.00536374123289417,
0.009618068660179583,
0.00418840028031164,
0.004865833775052785,
0.005830266629345715,
0.0023000687699064487,
0.0038261460966199762,
0.0022056369562673136,
0.002220870125003987,
0.0024217167485139184,
0.001954249278483037,
0.0021431104709895756,
0.0022583500309011494,
0.002132287005193404,
56.80083886633675,
82.57108385134966,
82.55195800259582,
82.57428529527452,
82.56009972401155,
82.55269345666608,
82.57006728343474,
82.55670593131302,
82.54445134676419,
82.57745079301361,
82.57933913336859,
82.5570435157189,
82.56808758597089,
82.56800172267816,
82.56525711320404,
82.56189481621115,
82.55193622295673,
82.55375865382007,
82.56600202250685,
82.57064581324912,
82.55481151026538,
82.55367833324986,
82.56042112040724,
82.5616829048874,
82.5771528553759,
82.55317820035495,
82.57550573756552,
82.57334061973116,
82.56044387817383,
82.5752662593483,
82.56673936762361,
82.56828115740393,
82.56990289280557,
82.55218840052939,
82.56695372426611,
82.575043066954,
82.55754522991995,
82.56361721723508,
82.5628145821074,
82.56431990403395,
82.55777725806603,
82.5742861918914,
82.56361025622768,
82.56887233766736,
82.56539458902473,
82.55887828729091,
82.56073884882478,
82.55578186165573
29.481668366326225,
26.226897622784996,
25.98807217932155,
25.807672516912476,
25.701779809772457,
25.66069327052842,
25.639777212061432,
25.62703147301307,
25.620143454299015,
25.61577053966685,
25.612662934849403,
25.60999612319164,
25.608206019442306,
25.60651744940342,
25.605323510292248,
25.60462644772652,
25.603682958162747,
25.603215930808304,
25.602825727218235,
25.602524419116158,
25.602402703374878,
25.602261029756985,
25.602163278139553,
25.602091275728664,
25.602065310518967,
25.602008819580078,
25.601970770420174,
25.601926473470833,
25.601869432335224,
25.60185879519862,
25.601852824545315,
25.60183112234132,
25.601783161489372,
25.601782627594776,
25.601758304824177,
25.601712039393238,
25.601732046176227,
25.601690267905212,
25.60164840404804,
25.601653335440872,
25.60167118626782,
25.601645954653748,
25.60163675210415,
25.601624191316784,
25.601573104532356,
25.60163724524343,
25.601609873975445,
25.601576218238243,
25.601538580706997,
25.601551618331516,
25.60154583922818,
25.6015028138446,
25.60153263858241,
25.601467629783173,
25.60146150018415,
25.601416970929527,
25.601392692989773,
25.601404438670883,
25.601370538401806,
25.60135773308257,
25.601353584191738,
25.60130478785588,
25.601299799405613,
25.60127318618644,
25.6012494054615,
25.601233796176746,
25.601193383208706,
25.601183601933666,
25.60118840698503,
25.601158150240906,
25.60114165656587,
25.601124979491928,
25.601091221866444,
25.60108805925418,
25.601056249732647,
25.601044960511036,
25.601008244049854,
25.600999204521504,
25.60097885539389,
25.60095446741479,
25.600959500695904,
25.600936054164528,
25.60091640195276,
25.600889332274086,
25.600879269787388,
25.600861647190193,
25.600824853293915,
25.6008102914207,
25.600796605786705,
25.600769034817688,
25.600744826161964,
25.600724538167317,
25.60070591298943,
25.600691750518276,
25.600666380336143,
25.60064803636991,
25.60063368642432,
25.600601587540066,
25.60058915309417,
25.600568995516525
],
"perc_loss": [
0.0,
@@ -535,12 +537,12 @@
0.0
],
"fid": {
"25": 315.9393615722656,
"50": 419.273193359375,
"75": 360.4432678222656,
"100": 363.9911193847656
"25": 238.42819213867188,
"50": 232.70050048828125,
"75": 234.88893127441406,
"100": 236.51181030273438
},
"train_time_s": 660.9630489349365
"train_time_s": 676.644668340683
},
"n_params": 10608451
}
+309 -307
View File
@@ -17,318 +17,320 @@
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"run_name": "p3_2_vae_perceptual",
"lr": 0.001,
"beta_kl": 0.0001,
"lr": 0.0005,
"beta_kl": 0.1,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.0
},
"history": {
"recon_loss": [
NaN,
0.041370747770127066,
0.035506031866002284,
0.06229733923672993,
0.06089382184048494,
0.053464704654856116,
0.04950355895213846,
0.04674786329269409,
0.05693557829811023,
0.04808994451076047,
NaN,
0.05869839608701121,
0.05293231666820426,
0.05057006603918779,
0.04713928615117175,
0.04481589820427008,
0.04331832689543565,
0.041820655449524395,
NaN,
0.24315394992884407,
0.24305414841470555,
0.2429187158998261,
0.24289727962424612,
0.2377373814328104,
0.23880945107875726,
0.24047312904626894,
0.23915709144411942,
NaN,
0.24000247061634675,
0.2362435321904655,
0.23705266075383905,
NaN,
0.23998981039238793,
0.23828768376738596,
0.23908851302077627,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN
0.13868479322419208,
0.1345828948622076,
0.13401474649261716,
0.13219879449814811,
0.13071280969386426,
0.12897613197246677,
0.12651141290353912,
0.12554800317773962,
0.123991425602864,
0.12276749478446113,
0.12151926139799449,
0.1200498831896191,
0.11872616813032545,
0.11811881408923203,
0.11655244218488024,
0.11565455276932982,
0.11529083312767693,
0.11437753734425601,
0.11373461206626688,
0.1133939535317258,
0.11269663740745467,
0.11214834819428432,
0.11180534907895276,
0.1112786961448753,
0.11112714579535855,
0.11040605649224713,
0.11024710934004213,
0.11025449748222645,
0.10986682841092603,
0.10947157509433918,
0.10914939207335313,
0.1090434947425229,
0.10872587247982494,
0.10891366730897854,
0.10840102485739268,
0.10831964285009438,
0.10826414010017855,
0.10774775957449889,
0.10791046626101701,
0.10784940838686422,
0.10743295191190182,
0.10734256694459507,
0.10702427010187227,
0.10701240906412275,
0.10711385588296968,
0.10697286784585215,
0.10673481402679896,
0.10650451705814937,
0.10629599592369846,
0.1064668823288292,
0.1063920924296746,
0.10610189062789974,
0.10592550779573429,
0.10588065830942912,
0.105781758379223,
0.10560809290752961,
0.10550812136732106,
0.10535470090615444,
0.10536463093808573,
0.1051216669794586,
0.10498357508490738,
0.10464231009220976,
0.10468940513256268,
0.10468925925719942,
0.10429271149775411,
0.10437219857405393,
0.10406083403489529,
0.10395075554330634,
0.10419673752835673,
0.10405941009839885,
0.10379417274051751,
0.10373205498943472,
0.10360019166882221,
0.10355540880790123,
0.10355440188103761,
0.10325965399925525,
0.10304029177651446,
0.10311986905578364,
0.10273497300142916,
0.10302559410532315,
0.10278304515040329,
0.10263998298627189,
0.10254473253511466,
0.10245785787383206,
0.10246957698438922,
0.10233539204375866,
0.1025002559280803,
0.10214536613187729,
0.10215426669416265,
0.10214609539725332,
0.10188078407484752,
0.1020691341951362,
0.1019192597645725,
0.10151305177018173,
0.10163224848289774,
0.10181667007760614,
0.10129789227985928,
0.10133470410210454,
0.1014308016269635,
0.10122469465574647
],
"kl_loss": [
41526536.5961082,
1258.7768262553418,
1165.5987154968784,
3689.321748260759,
1320.3167513333833,
928.3473894168169,
789.754654843583,
693.4729607736963,
1017.8856519389357,
622.2623097998464,
304942.2120989938,
750.5718396830763,
618.6888906364767,
586.4194498958751,
579.2793892102363,
525.8830437945504,
497.5363791050055,
452.93342662061383,
172476214612.95245,
1112.027792645316,
1109.8156276605068,
1109.8872808472722,
1112.3999198196282,
1163.6093267457097,
1195.4633280436199,
1232.709252512353,
1253.0651078183428,
19923.65347159622,
6871.588723207132,
6872.017739842081,
6872.62934758113,
27062.17111336472,
4992.576698759683,
5000.895510942508,
5011.723715236044,
4.196954763212122e+20,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN
32.82580094867282,
27.63585641649034,
26.858297156472492,
26.36294916756133,
26.101821015023777,
25.925728178431847,
25.810632864634197,
25.75408285499638,
25.714011412400467,
25.691052310487144,
25.673080118293438,
25.658130865830643,
25.648127645508858,
25.640685057028747,
25.63372629116743,
25.62904858792949,
25.625450937157,
25.622334708515396,
25.61985801631569,
25.61797430168869,
25.617105920090633,
25.61665568800054,
25.615915049854507,
25.615573080176983,
25.61539109955486,
25.614891084850345,
25.614565458053196,
25.614424391689464,
25.614320644965538,
25.613855867304352,
25.613705484276142,
25.613627437852386,
25.61319611215184,
25.613402957590218,
25.612938135098187,
25.612793853140285,
25.61277205719907,
25.6125225368728,
25.612593328850902,
25.612471242236275,
25.61213506796421,
25.61220080220801,
25.612097707569088,
25.611939287593223,
25.611856093773476,
25.611754238096058,
25.61164761812259,
25.61152194096492,
25.61156750948001,
25.611467068011944,
25.61140904059777,
25.611222067449848,
25.61103208655985,
25.610888195852947,
25.610904620243954,
25.610692464388332,
25.610453536367825,
25.610431373628796,
25.610199642996502,
25.609933506729256,
25.60987384095151,
25.609577077066795,
25.6096679817917,
25.60952104666294,
25.60934520786644,
25.60918986899221,
25.60895837474073,
25.6090076030829,
25.60884045331906,
25.608697809724728,
25.608583588885445,
25.608446227179634,
25.60828152273455,
25.608113199217705,
25.60813422080798,
25.607940327407967,
25.60771349963979,
25.607653996883293,
25.607560540875816,
25.607479812752487,
25.607303941351734,
25.607178337553627,
25.60706621968848,
25.60683028310792,
25.60676119470189,
25.606651432493813,
25.606486483516857,
25.606393325023163,
25.606248602907883,
25.60615853366689,
25.605982751927826,
25.60586034334623,
25.60572910308838,
25.605554120153442,
25.605445796607906,
25.605342062110577,
25.605153943738365,
25.60500627501398,
25.60483596263788,
25.60475177031297
],
"perc_loss": [
NaN,
3.0017659954535656,
2.8931196978968434,
3.084328340159522,
3.1398269641093717,
3.0839009172896032,
3.0459561949102287,
3.015714469628456,
3.0917934143645134,
3.010985380054539,
NaN,
3.102917709411719,
3.055045795746339,
3.0351512865123587,
3.002274121993627,
2.9754957128793764,
2.9553630066733074,
2.936704080328982,
NaN,
3.66463458385223,
3.6419444985878773,
3.6321474706005845,
3.6272282503608966,
3.781295938878997,
3.740793755421272,
3.759704489993234,
3.7078149914741516,
NaN,
3.7647176323792872,
3.7441278461717133,
3.7264530261357627,
NaN,
3.7299226170931106,
3.7314435080585318,
3.754668933713538,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN
3.4957813187542124,
3.3751721346480217,
3.341614580561972,
3.325737010209988,
3.314503056371314,
3.3061294968311605,
3.2988068002920885,
3.2951497960294414,
3.289724970475221,
3.2849769531152186,
3.2810079547075124,
3.2768430485684648,
3.2737610197474813,
3.270859044331771,
3.267263490929563,
3.2641563344205546,
3.2620158368705683,
3.259672654999627,
3.2573955410566087,
3.2559085101143928,
3.2524190697914515,
3.2508940416523533,
3.2493486078376446,
3.2472466857005386,
3.246741561808138,
3.2446710722059264,
3.2429053090576434,
3.2415418217324805,
3.2410791664042025,
3.238901428687267,
3.2372534682608056,
3.2365858203325515,
3.2345399107688513,
3.234266322392684,
3.2330421667832594,
3.2321024134627776,
3.2317300596807756,
3.2297990515700774,
3.2291462492739034,
3.229038153958117,
3.227532477969797,
3.227750709423652,
3.22642219117564,
3.225604632980803,
3.224408136983203,
3.2237918310695224,
3.223153405719333,
3.2227800172618313,
3.22208283854346,
3.2215008134515877,
3.220343905636388,
3.2200386534389267,
3.2190595391469126,
3.218352971932827,
3.2173767803061724,
3.2164535395100584,
3.2164310828233376,
3.2155216560404525,
3.2145652862695546,
3.2132638708139076,
3.213306623136895,
3.2116519161778636,
3.2117279686479487,
3.210561112460927,
3.2098504080731645,
3.2099855195762763,
3.2090730035406914,
3.2085196706983776,
3.208222340824258,
3.207890693448548,
3.206765956348843,
3.2065086430973477,
3.2055318915945854,
3.204780939297798,
3.205001257933103,
3.203499240243537,
3.20293498905296,
3.2031849953863354,
3.2017455717437286,
3.2029461646691346,
3.202106138579866,
3.200608807751256,
3.200435662880922,
3.2001758396116076,
3.199647111260993,
3.199422711490566,
3.1987570305155892,
3.198392223089169,
3.1976211967631283,
3.1978414170762415,
3.197227580425067,
3.197008974531777,
3.1967804961734347,
3.1954670217302112,
3.19552276073358,
3.1953888051530237,
3.194187126098535,
3.194483350484799,
3.1943916347291736,
3.194042898650862
],
"adv_g_loss": [
0.0,
@@ -535,12 +537,12 @@
0.0
],
"fid": {
"25": 263.1458740234375,
"50": 598.3736572265625,
"75": 598.3736572265625,
"100": 598.3736572265625
"25": 218.6470947265625,
"50": 236.44911193847656,
"75": 235.74722290039062,
"100": 237.4191436767578
},
"train_time_s": 952.6596128940582
"train_time_s": 1509.468991279602
},
"n_params": 10608451
}
File diff suppressed because it is too large Load Diff