Testing VAE until it works - v1

This commit is contained in:
Johnny Fernandes
2026-05-02 13:11:56 +01:00
parent c7804d2984
commit ec8d4ae336
84 changed files with 9 additions and 1744 deletions
@@ -7,7 +7,6 @@
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"sample_interval": 10,
"fid_interval": 25,
+1 -1
View File
@@ -2,7 +2,7 @@
"extends": "_base_phase3.json",
"run_name": "p3_1_vae",
"lr": 5e-4,
"beta_kl": 0.5,
"beta_kl": 0.25,
"lambda_perceptual": 0.0,
"lambda_adversarial": 0.0
}
@@ -2,7 +2,7 @@
"extends": "_base_phase3.json",
"run_name": "p3_2_vae_perceptual",
"lr": 5e-4,
"beta_kl": 0.1,
"beta_kl": 0.25,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.0
}
@@ -3,7 +3,7 @@
"run_name": "p3_3_vae_patchgan",
"lr": 5e-4,
"lr_d": 1e-4,
"beta_kl": 0.05,
"beta_kl": 0.25,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.1,
"ndf_patch": 64
-548
View File
@@ -1,548 +0,0 @@
{
"run_name": "p3_1_vae",
"config": {
"batch_size": 64,
"ema_decay": 0.9999,
"data_dir": "cropped/generator",
"sources": [
"wiki"
],
"subsample": 1.0,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000,
"epochs": 100,
"augment": "hflip",
"image_size": 64,
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"run_name": "p3_1_vae",
"lr": 0.0005,
"beta_kl": 0.5,
"lambda_perceptual": 0.0,
"lambda_adversarial": 0.0
},
"history": {
"recon_loss": [
0.16871115132274792,
0.16150351059742463,
0.15984480073436713,
0.15590801271490562,
0.15061364529861343,
0.14644645811973983,
0.14328488393917552,
0.14094583944887176,
0.13928856181665364,
0.1377539750124909,
0.13583827684195632,
0.13372899556898662,
0.13197963861509776,
0.1295166944559568,
0.12748059913770765,
0.126394641561768,
0.12385777387226748,
0.12244280847983482,
0.121336308045265,
0.12060208647296979,
0.11928399897411339,
0.11839368250061813,
0.11770952168183449,
0.11709171382344177,
0.11688287127922234,
0.11606283619617805,
0.11564522911595483,
0.11527438517500702,
0.11461337662150717,
0.11428412470297936,
0.11390002192849787,
0.11359650807248221,
0.11298466332129434,
0.11281277696227926,
0.11266172842846976,
0.11232183446996233,
0.11213540144137338,
0.11184148606645246,
0.11143260068682015,
0.11121100015365161,
0.11103646766044135,
0.11086450471009454,
0.11068084617901562,
0.11046731698080006,
0.1101092367600172,
0.11040792095228139,
0.11019570772082378,
0.10987201248669726,
0.10948091489063878,
0.10964858110070738,
0.10938817380457862,
0.10904825658688688,
0.10926413088718541,
0.10854257479246356,
0.10841736453784327,
0.10838394499041586,
0.10795553592153084,
0.10781666870491627,
0.10770979084265538,
0.1076134722520653,
0.10732251404123938,
0.10716411035157676,
0.10698133315413426,
0.10672279318364766,
0.10645651288776316,
0.10662059826601265,
0.10623115535156849,
0.10627176197102436,
0.1060571956456217,
0.10596368315382901,
0.10585042324840513,
0.1054028309085685,
0.10517520199601467,
0.1053342710639167,
0.1052525288815427,
0.10502007727821668,
0.10490142727573203,
0.10471088643002714,
0.10447397850390173,
0.1044878327470814,
0.10447094976328887,
0.10421697295501701,
0.10414895819675209,
0.10414911504102568,
0.10366271499894623,
0.10378186422217096,
0.10345360714719336,
0.10363478323396964,
0.10342726219668348,
0.10305201177859408,
0.10310276317545491,
0.10302646558445233,
0.10291730190635237,
0.10278329038276122,
0.10258485910156344,
0.10270861135079311,
0.10251107273830308,
0.10234160087684281,
0.10214911544552216,
0.10251628613879538
],
"kl_loss": [
29.481668366326225,
26.226897622784996,
25.98807217932155,
25.807672516912476,
25.701779809772457,
25.66069327052842,
25.639777212061432,
25.62703147301307,
25.620143454299015,
25.61577053966685,
25.612662934849403,
25.60999612319164,
25.608206019442306,
25.60651744940342,
25.605323510292248,
25.60462644772652,
25.603682958162747,
25.603215930808304,
25.602825727218235,
25.602524419116158,
25.602402703374878,
25.602261029756985,
25.602163278139553,
25.602091275728664,
25.602065310518967,
25.602008819580078,
25.601970770420174,
25.601926473470833,
25.601869432335224,
25.60185879519862,
25.601852824545315,
25.60183112234132,
25.601783161489372,
25.601782627594776,
25.601758304824177,
25.601712039393238,
25.601732046176227,
25.601690267905212,
25.60164840404804,
25.601653335440872,
25.60167118626782,
25.601645954653748,
25.60163675210415,
25.601624191316784,
25.601573104532356,
25.60163724524343,
25.601609873975445,
25.601576218238243,
25.601538580706997,
25.601551618331516,
25.60154583922818,
25.6015028138446,
25.60153263858241,
25.601467629783173,
25.60146150018415,
25.601416970929527,
25.601392692989773,
25.601404438670883,
25.601370538401806,
25.60135773308257,
25.601353584191738,
25.60130478785588,
25.601299799405613,
25.60127318618644,
25.6012494054615,
25.601233796176746,
25.601193383208706,
25.601183601933666,
25.60118840698503,
25.601158150240906,
25.60114165656587,
25.601124979491928,
25.601091221866444,
25.60108805925418,
25.601056249732647,
25.601044960511036,
25.601008244049854,
25.600999204521504,
25.60097885539389,
25.60095446741479,
25.600959500695904,
25.600936054164528,
25.60091640195276,
25.600889332274086,
25.600879269787388,
25.600861647190193,
25.600824853293915,
25.6008102914207,
25.600796605786705,
25.600769034817688,
25.600744826161964,
25.600724538167317,
25.60070591298943,
25.600691750518276,
25.600666380336143,
25.60064803636991,
25.60063368642432,
25.600601587540066,
25.60058915309417,
25.600568995516525
],
"perc_loss": [
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
],
"adv_g_loss": [
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
],
"adv_d_loss": [
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
],
"fid": {
"25": 238.42819213867188,
"50": 232.70050048828125,
"75": 234.88893127441406,
"100": 236.51181030273438
},
"train_time_s": 676.644668340683
},
"n_params": 10608451
}
@@ -1,548 +0,0 @@
{
"run_name": "p3_2_vae_perceptual",
"config": {
"batch_size": 64,
"ema_decay": 0.9999,
"data_dir": "cropped/generator",
"sources": [
"wiki"
],
"subsample": 1.0,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000,
"epochs": 100,
"augment": "hflip",
"image_size": 64,
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"run_name": "p3_2_vae_perceptual",
"lr": 0.0005,
"beta_kl": 0.1,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.0
},
"history": {
"recon_loss": [
0.13868479322419208,
0.1345828948622076,
0.13401474649261716,
0.13219879449814811,
0.13071280969386426,
0.12897613197246677,
0.12651141290353912,
0.12554800317773962,
0.123991425602864,
0.12276749478446113,
0.12151926139799449,
0.1200498831896191,
0.11872616813032545,
0.11811881408923203,
0.11655244218488024,
0.11565455276932982,
0.11529083312767693,
0.11437753734425601,
0.11373461206626688,
0.1133939535317258,
0.11269663740745467,
0.11214834819428432,
0.11180534907895276,
0.1112786961448753,
0.11112714579535855,
0.11040605649224713,
0.11024710934004213,
0.11025449748222645,
0.10986682841092603,
0.10947157509433918,
0.10914939207335313,
0.1090434947425229,
0.10872587247982494,
0.10891366730897854,
0.10840102485739268,
0.10831964285009438,
0.10826414010017855,
0.10774775957449889,
0.10791046626101701,
0.10784940838686422,
0.10743295191190182,
0.10734256694459507,
0.10702427010187227,
0.10701240906412275,
0.10711385588296968,
0.10697286784585215,
0.10673481402679896,
0.10650451705814937,
0.10629599592369846,
0.1064668823288292,
0.1063920924296746,
0.10610189062789974,
0.10592550779573429,
0.10588065830942912,
0.105781758379223,
0.10560809290752961,
0.10550812136732106,
0.10535470090615444,
0.10536463093808573,
0.1051216669794586,
0.10498357508490738,
0.10464231009220976,
0.10468940513256268,
0.10468925925719942,
0.10429271149775411,
0.10437219857405393,
0.10406083403489529,
0.10395075554330634,
0.10419673752835673,
0.10405941009839885,
0.10379417274051751,
0.10373205498943472,
0.10360019166882221,
0.10355540880790123,
0.10355440188103761,
0.10325965399925525,
0.10304029177651446,
0.10311986905578364,
0.10273497300142916,
0.10302559410532315,
0.10278304515040329,
0.10263998298627189,
0.10254473253511466,
0.10245785787383206,
0.10246957698438922,
0.10233539204375866,
0.1025002559280803,
0.10214536613187729,
0.10215426669416265,
0.10214609539725332,
0.10188078407484752,
0.1020691341951362,
0.1019192597645725,
0.10151305177018173,
0.10163224848289774,
0.10181667007760614,
0.10129789227985928,
0.10133470410210454,
0.1014308016269635,
0.10122469465574647
],
"kl_loss": [
32.82580094867282,
27.63585641649034,
26.858297156472492,
26.36294916756133,
26.101821015023777,
25.925728178431847,
25.810632864634197,
25.75408285499638,
25.714011412400467,
25.691052310487144,
25.673080118293438,
25.658130865830643,
25.648127645508858,
25.640685057028747,
25.63372629116743,
25.62904858792949,
25.625450937157,
25.622334708515396,
25.61985801631569,
25.61797430168869,
25.617105920090633,
25.61665568800054,
25.615915049854507,
25.615573080176983,
25.61539109955486,
25.614891084850345,
25.614565458053196,
25.614424391689464,
25.614320644965538,
25.613855867304352,
25.613705484276142,
25.613627437852386,
25.61319611215184,
25.613402957590218,
25.612938135098187,
25.612793853140285,
25.61277205719907,
25.6125225368728,
25.612593328850902,
25.612471242236275,
25.61213506796421,
25.61220080220801,
25.612097707569088,
25.611939287593223,
25.611856093773476,
25.611754238096058,
25.61164761812259,
25.61152194096492,
25.61156750948001,
25.611467068011944,
25.61140904059777,
25.611222067449848,
25.61103208655985,
25.610888195852947,
25.610904620243954,
25.610692464388332,
25.610453536367825,
25.610431373628796,
25.610199642996502,
25.609933506729256,
25.60987384095151,
25.609577077066795,
25.6096679817917,
25.60952104666294,
25.60934520786644,
25.60918986899221,
25.60895837474073,
25.6090076030829,
25.60884045331906,
25.608697809724728,
25.608583588885445,
25.608446227179634,
25.60828152273455,
25.608113199217705,
25.60813422080798,
25.607940327407967,
25.60771349963979,
25.607653996883293,
25.607560540875816,
25.607479812752487,
25.607303941351734,
25.607178337553627,
25.60706621968848,
25.60683028310792,
25.60676119470189,
25.606651432493813,
25.606486483516857,
25.606393325023163,
25.606248602907883,
25.60615853366689,
25.605982751927826,
25.60586034334623,
25.60572910308838,
25.605554120153442,
25.605445796607906,
25.605342062110577,
25.605153943738365,
25.60500627501398,
25.60483596263788,
25.60475177031297
],
"perc_loss": [
3.4957813187542124,
3.3751721346480217,
3.341614580561972,
3.325737010209988,
3.314503056371314,
3.3061294968311605,
3.2988068002920885,
3.2951497960294414,
3.289724970475221,
3.2849769531152186,
3.2810079547075124,
3.2768430485684648,
3.2737610197474813,
3.270859044331771,
3.267263490929563,
3.2641563344205546,
3.2620158368705683,
3.259672654999627,
3.2573955410566087,
3.2559085101143928,
3.2524190697914515,
3.2508940416523533,
3.2493486078376446,
3.2472466857005386,
3.246741561808138,
3.2446710722059264,
3.2429053090576434,
3.2415418217324805,
3.2410791664042025,
3.238901428687267,
3.2372534682608056,
3.2365858203325515,
3.2345399107688513,
3.234266322392684,
3.2330421667832594,
3.2321024134627776,
3.2317300596807756,
3.2297990515700774,
3.2291462492739034,
3.229038153958117,
3.227532477969797,
3.227750709423652,
3.22642219117564,
3.225604632980803,
3.224408136983203,
3.2237918310695224,
3.223153405719333,
3.2227800172618313,
3.22208283854346,
3.2215008134515877,
3.220343905636388,
3.2200386534389267,
3.2190595391469126,
3.218352971932827,
3.2173767803061724,
3.2164535395100584,
3.2164310828233376,
3.2155216560404525,
3.2145652862695546,
3.2132638708139076,
3.213306623136895,
3.2116519161778636,
3.2117279686479487,
3.210561112460927,
3.2098504080731645,
3.2099855195762763,
3.2090730035406914,
3.2085196706983776,
3.208222340824258,
3.207890693448548,
3.206765956348843,
3.2065086430973477,
3.2055318915945854,
3.204780939297798,
3.205001257933103,
3.203499240243537,
3.20293498905296,
3.2031849953863354,
3.2017455717437286,
3.2029461646691346,
3.202106138579866,
3.200608807751256,
3.200435662880922,
3.2001758396116076,
3.199647111260993,
3.199422711490566,
3.1987570305155892,
3.198392223089169,
3.1976211967631283,
3.1978414170762415,
3.197227580425067,
3.197008974531777,
3.1967804961734347,
3.1954670217302112,
3.19552276073358,
3.1953888051530237,
3.194187126098535,
3.194483350484799,
3.1943916347291736,
3.194042898650862
],
"adv_g_loss": [
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
],
"adv_d_loss": [
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
],
"fid": {
"25": 218.6470947265625,
"50": 236.44911193847656,
"75": 235.74722290039062,
"100": 237.4191436767578
},
"train_time_s": 1509.468991279602
},
"n_params": 10608451
}
@@ -1,550 +0,0 @@
{
"run_name": "p3_3_vae_patchgan",
"config": {
"batch_size": 64,
"ema_decay": 0.9999,
"data_dir": "cropped/generator",
"sources": [
"wiki"
],
"subsample": 1.0,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000,
"epochs": 100,
"augment": "hflip",
"image_size": 64,
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"free_bits": 0.1,
"grad_clip": 1.0,
"run_name": "p3_3_vae_patchgan",
"lr": 0.0005,
"lr_d": 0.0001,
"beta_kl": 0.05,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.1,
"ndf_patch": 64
},
"history": {
"recon_loss": [
0.20664009576042494,
0.20546884704222027,
0.2026536207423251,
0.19372608764176694,
0.18886286809913114,
0.18587723256558433,
0.18252239019697547,
0.18123333547741938,
0.1826040716125415,
0.17910013230055824,
0.17871976541912454,
0.1786047644340075,
0.17894037895732456,
0.1770373120203487,
0.17791039133683229,
0.17689791439562783,
0.1768273785073533,
0.17606183192414096,
0.17561813542603427,
0.17555826488468382,
0.1764616704840436,
0.17546906631089684,
0.17435517222580746,
0.17366362592348686,
0.17288134801082122,
0.17304196351995835,
0.1726210970336046,
0.17214958993797627,
0.17092827892201579,
0.17048384440250886,
0.17009502111209762,
0.16923288472442546,
0.16926560054222742,
0.1685944145752324,
0.16913408833818558,
0.16864649152271768,
0.16715254338505942,
0.1674847087671614,
0.16627334867022994,
0.16667878182015866,
0.1673471293069868,
0.1658064564769594,
0.16568841266199055,
0.1656897024721162,
0.16533309702053028,
0.1651330927116239,
0.16490820941762027,
0.16450774570942944,
0.16505960890879998,
0.1657766858036192,
0.16384747038539657,
0.16329284715983602,
0.1630541528774123,
0.1632642790547803,
0.16329585748095798,
0.16283382962529475,
0.162109665827364,
0.16237290041186872,
0.16060760107814756,
0.15995743348557726,
0.16160479480894202,
0.16059094689722753,
0.16062985516638836,
0.1601817585591577,
0.15990230441093445,
0.15923667055928808,
0.15944513557558385,
0.15786530402226326,
0.157669540717561,
0.15771191879215404,
0.15836861143764266,
0.1579915552567213,
0.1574682067347388,
0.15650623944452685,
0.15726811556607231,
0.15652572038846138,
0.15526711217995381,
0.15566525547804996,
0.1559221432504491,
0.15413334128311557,
0.1538484821525904,
0.15617854215013674,
0.15460723718134767,
0.15452368347308573,
0.15389670648126522,
0.153534261111775,
0.15374353961047962,
0.15358849379241976,
0.15341616234081423,
0.1539706002570625,
0.15283517762381807,
0.15448490144987392,
0.15431701716704246,
0.15514606670436695,
0.15397184323041868,
0.15382100995152426,
0.15411812398168775,
0.15284074429008696,
0.15258317590396628,
0.15257134797990832
],
"kl_loss": [
68.12842368264484,
58.52319613888732,
44.18469975544856,
33.35476183279967,
29.756296320858166,
28.173175371610203,
27.34979814545721,
26.876195630456646,
26.51762535633185,
26.26355965117104,
26.173303608201508,
26.087695989853298,
26.018392094180115,
25.98833444383409,
25.93635396875887,
25.901821535876675,
25.878659419524364,
25.856926702026627,
25.831880243415508,
25.809421323303482,
25.812285415127747,
25.809016423347668,
25.803619710808125,
25.80864535845243,
25.807042011847862,
25.807511097345596,
25.80379281899868,
25.800463224068665,
25.80085773141975,
25.803321630526813,
25.80114118461935,
25.797412778577232,
25.79841448710515,
25.803915830758903,
25.800587319920204,
25.799186054457966,
25.79303413374811,
25.800722203703007,
25.78836956187191,
25.795500290699493,
25.801840765863403,
25.795673590440018,
25.788382379417744,
25.790301979097546,
25.79772145523984,
25.786431638603535,
25.788207636939156,
25.7939785117777,
25.790546046362984,
25.791608264303616,
25.782368317628517,
25.781000569335415,
25.78018570353842,
25.78992747852945,
25.791334254109962,
25.785102526346844,
25.777616117754555,
25.787795796353592,
25.77431806336101,
25.780450967641976,
25.76879140250703,
25.769127409682316,
25.77564622601892,
25.774933334089752,
25.773548411507893,
25.781807695698536,
25.767144321376442,
25.76645992556189,
25.764682867588142,
25.763574934413292,
25.764604503272945,
25.757338796925342,
25.760351731226994,
25.753257017869217,
25.756990607987102,
25.74790485903748,
25.750839449401596,
25.762463439224113,
25.747816888695088,
25.745640946249676,
25.755338065644615,
25.75259123500596,
25.73955793462248,
25.74994816739335,
25.743701678055984,
25.741836702721752,
25.746034561059414,
25.74659049205291,
25.748252192114155,
25.75794408667801,
25.740664445436916,
25.729339342850906,
25.73942155104417,
25.73681424621843,
25.74283838679648,
25.743645028171375,
25.737078760424232,
25.738885565700695,
25.730926350650623,
25.73253002329769
],
"perc_loss": [
4.284683430296743,
4.412672106017414,
4.450774787328182,
4.398514354330862,
4.374030993534968,
4.3477759218623495,
4.351976938736745,
4.363105954777481,
4.3694430824018955,
4.36106641883524,
4.347512537597591,
4.3484030238583555,
4.345626479540115,
4.3261580187031345,
4.325833334882035,
4.326021440008766,
4.3213376805313635,
4.314768468212877,
4.3192369550721255,
4.318043978805216,
4.3108146455552845,
4.308960727137378,
4.298641049963797,
4.289013148373009,
4.271883454078283,
4.273704745830634,
4.270601803420956,
4.263830899173378,
4.2647860131712045,
4.259130676077981,
4.248837271307269,
4.246620962762425,
4.243108230778295,
4.251198922467028,
4.23574418695564,
4.232880752310794,
4.219367954466078,
4.214961731026316,
4.215395676274585,
4.21049471033944,
4.21475775527139,
4.209023853143056,
4.196490774806748,
4.201774182992104,
4.191486187979707,
4.182125541389498,
4.182892669469882,
4.186337848504384,
4.185713631984515,
4.185926957008166,
4.173205633448739,
4.170292638815367,
4.171781816543677,
4.1742227994478664,
4.162097737320468,
4.153104884502215,
4.148357060220507,
4.1533002023004055,
4.14140139787625,
4.12923164652963,
4.1401490854401874,
4.1272752071038274,
4.130248739169194,
4.132051646200001,
4.123921276157738,
4.123644507338858,
4.120624431687543,
4.116347138698284,
4.098994751771291,
4.10354998223802,
4.115526872312921,
4.093757800057403,
4.108234375969976,
4.095992333359188,
4.090389885963538,
4.086705027482449,
4.077634346281362,
4.082558892730974,
4.089221684341757,
4.073336944620833,
4.064374155977852,
4.082791697775197,
4.069841699213044,
4.076490493411692,
4.064204802370479,
4.061286298128275,
4.052628819759075,
4.0491809697232695,
4.058965112409021,
4.052690645568391,
4.058227839632931,
4.052835367683672,
4.053200018711579,
4.058637539036254,
4.054576839646723,
4.049067093266381,
4.050501744971316,
4.0333083093675794,
4.031478229241493,
4.03050094944799
],
"adv_g_loss": [
1.6945906835488784,
2.2359936534084826,
2.1402350850084906,
2.315248914508738,
2.233884201202,
2.1723885705583115,
2.0820165915239572,
2.055549093959933,
2.051057412153763,
2.015316769034944,
1.9094728334179012,
1.8913572477674288,
1.897682955655723,
1.882107293062135,
1.8927111610666745,
1.8688224147782366,
1.9567883037255567,
1.8886699466538241,
1.9518562734075304,
1.9636508357814617,
1.9562227244520736,
1.9984713741598061,
1.9947010266294098,
1.931989847578936,
2.0217894085197368,
2.087464853603807,
2.0854533564132183,
2.073050996646858,
2.1703977200529003,
2.2110010142991334,
2.1600928430119133,
2.216297351460681,
2.3224761097763595,
2.304568846176705,
2.2836171977030926,
2.297615868668271,
2.3322910882182355,
2.3625818552115025,
2.38175813596632,
2.406708580911414,
2.405017257691958,
2.40933440421891,
2.466819341907389,
2.5458592870551295,
2.5856447751259704,
2.5504658403050184,
2.55090725606578,
2.635533580954513,
2.676304424420381,
2.709152910380791,
2.6734945714537406,
2.6929639437769213,
2.741068474438965,
2.788715256075574,
2.749225435858099,
2.7717263867330346,
2.8026352137581916,
2.8223156610615234,
2.7881575422918696,
2.831760157632013,
2.8619034027951393,
2.924719768202203,
2.900455035460301,
2.8933564607913675,
2.921116583240338,
2.9773028457266655,
2.9610479168402843,
2.990598867336909,
3.0113285031074133,
3.009772124708208,
3.024522620643306,
3.030661901092937,
3.0474760035673776,
3.0628549543201413,
3.096129789056941,
3.149008815869307,
3.1224647607558813,
3.2760142256052065,
3.228409260766119,
3.2076573489058733,
3.2398641476264367,
3.232522460639986,
3.2508003120748405,
3.240605200967218,
3.286187628395537,
3.2734705255581784,
3.2855510803369374,
3.299281896688999,
3.3648755433212996,
3.4358235439683638,
3.368582401520167,
3.305457336270911,
3.3189472449131503,
3.3348770279150743,
3.2767913479071398,
3.2550050493998404,
3.280222808193957,
3.2900603296410322,
3.300272215125907,
3.312584845428793
],
"adv_d_loss": [
0.4152746302131404,
0.20451608963278878,
0.23405680463959774,
0.19588809168268728,
0.23218965130802402,
0.21756976912928444,
0.24355858413136414,
0.24630528256997594,
0.2544653647953374,
0.2547152943781808,
0.28348284006182456,
0.27956568545255905,
0.2854043210291455,
0.2891966873286372,
0.27150706937297797,
0.2832239268618262,
0.26533032114752847,
0.2864285026255072,
0.26611408748878884,
0.2652207337415371,
0.2585428311752203,
0.2641343518168244,
0.26014862816112166,
0.255134222519576,
0.23690097094473675,
0.23983751258088482,
0.23426481247202963,
0.23212065616160885,
0.21858446362117925,
0.2071263620104545,
0.21544637269953376,
0.2005526166027173,
0.189934483769103,
0.21699241688873014,
0.1878608494487583,
0.2009370942783152,
0.18916509833791825,
0.18732606336219698,
0.18729545694226638,
0.17616531394549415,
0.1873442790287937,
0.18464299187892014,
0.16848225999846417,
0.1521743570661379,
0.16208869308774543,
0.15208105374382347,
0.16345988764053482,
0.15889255851347986,
0.1423531938981042,
0.14329825668858412,
0.14804775817876953,
0.13917387437680337,
0.1278887752793793,
0.13149240751488087,
0.1295400700237379,
0.11739795816202576,
0.10729667171039897,
0.11524664730422644,
0.10759468113159777,
0.1041740366918409,
0.09709937493993431,
0.09480845095175836,
0.09410861701680681,
0.09673418626428032,
0.09225290424278022,
0.08629086040533505,
0.08498162722899619,
0.08168459130833164,
0.0742182365945007,
0.07380153158775125,
0.07149151855140415,
0.06872231239437038,
0.06956076094259818,
0.06435911980274524,
0.0590689304865833,
0.05689009580697514,
0.05565431249391638,
0.05011548129961085,
0.05019014166532737,
0.04653629917317094,
0.04471593751556152,
0.04516388761460718,
0.04336228132502645,
0.04008000360455555,
0.03857759905493476,
0.037857113883663446,
0.035939875439915836,
0.035004505209035724,
0.031199621726185657,
0.02919165365712351,
0.030646794869636115,
0.02958681059575393,
0.027908083806152686,
0.027144664387083333,
0.026594226197817195,
0.026633828753529865,
0.025609989997803465,
0.023096644604164693,
0.02264232223885309,
0.02160924944964946
],
"fid": {
"25": 237.9630584716797,
"50": 250.21066284179688,
"75": 254.84861755371094,
"100": 259.5051574707031
},
"train_time_s": 2268.0758962631226
},
"n_params": 10608451
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,14 +0,0 @@
{
"created_at": "2026-05-01T23:33:40.557082+00:00",
"config_paths": [
"generator/configs/phase3/p3_1_vae.json",
"generator/configs/phase3/p3_2_vae_perceptual.json",
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 35985830,
"offer_id": 30940273,
"ssh_host": "ssh4.vast.ai",
"ssh_port": 25830,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -1,14 +0,0 @@
{
"created_at": "2026-05-02T01:21:14.070736+00:00",
"config_paths": [
"generator/configs/phase3/p3_1_vae.json",
"generator/configs/phase3/p3_2_vae_perceptual.json",
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 35989396,
"offer_id": 35960853,
"ssh_host": "ssh6.vast.ai",
"ssh_port": 29396,
"status": "cancelled",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -1,14 +0,0 @@
{
"created_at": "2026-05-02T01:28:24.565349+00:00",
"config_paths": [
"generator/configs/phase3/p3_1_vae.json",
"generator/configs/phase3/p3_2_vae_perceptual.json",
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 35989623,
"offer_id": 29302404,
"ssh_host": "ssh9.vast.ai",
"ssh_port": 29622,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
Binary file not shown.

Before

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 204 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 78 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 144 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 128 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 126 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 125 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 215 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 68 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 137 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 40 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 129 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 118 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 247 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 56 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 141 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 127 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

+6 -13
View File
@@ -408,11 +408,12 @@ def train_vae(
Config toggles:
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
free_bits > 0 → per-dimension KL free bits (prevents posterior
collapse and KL explosion)
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
KL is computed as mean over latent dimensions (scale-invariant), so
beta_kl is comparable across different latent_dim values.
AMP is intentionally disabled for VAE training — mixed-precision float16
overflows when the KL divergence spikes, producing NaN cascades that
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
@@ -432,7 +433,6 @@ def train_vae(
lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
lr_d = cfg.get("lr_d", 1e-4)
free_bits_val = cfg.get("free_bits", 0.0)
grad_clip = cfg.get("grad_clip", 1.0)
ema_decay = cfg.get("ema_decay", 0.9999)
sample_interval = cfg.get("sample_interval", 10)
@@ -441,7 +441,6 @@ def train_vae(
use_perceptual = lambda_perceptual > 0
use_adversarial = lambda_adversarial > 0
use_free_bits = free_bits_val > 0
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
@@ -509,7 +508,7 @@ def train_vae(
print(
f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
f" λ_adv={lambda_adversarial} free_bits={free_bits_val}"
f" λ_adv={lambda_adversarial}"
)
t_start = time.time()
@@ -532,14 +531,8 @@ def train_vae(
recon, mu, log_var = vae(real)
mse = F.mse_loss(recon, real)
# KL divergence with optional free bits
kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()) # (B, latent_dim)
if use_free_bits:
# Free bits: ensure each dimension contributes at least free_bits_val KL.
# Dimensions below the threshold are raised to it, preventing posterior
# collapse (dimensions that go to 0) while still penalising large KL.
kl_per_dim = torch.clamp(kl_per_dim, min=free_bits_val)
kl = kl_per_dim.sum(1).mean()
# KL divergence: mean over latent dims (scale-invariant w.r.t. latent_dim)
kl = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp())).mean()
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
vae_loss = mse + current_beta * kl + lambda_perceptual * perc