VAE fix w/ new results

This commit is contained in:
Johnny Fernandes
2026-05-02 00:32:45 +01:00
parent 1a7f67ab9c
commit bac52bc15e
90 changed files with 1197 additions and 1106 deletions
@@ -7,6 +7,7 @@
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"grad_clip": 1.0,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000
+2 -2
View File
@@ -1,8 +1,8 @@
{
"extends": "_base_phase3.json",
"run_name": "p3_1_vae",
"lr": 1e-3,
"beta_kl": 1.0,
"lr": 5e-4,
"beta_kl": 0.25,
"lambda_perceptual": 0.0,
"lambda_adversarial": 0.0
}
@@ -1,8 +1,8 @@
{
"extends": "_base_phase3.json",
"run_name": "p3_2_vae_perceptual",
"lr": 1e-3,
"beta_kl": 0.0001,
"lr": 5e-4,
"beta_kl": 0.25,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.0
}
@@ -1,10 +1,10 @@
{
"extends": "_base_phase3.json",
"run_name": "p3_3_vae_patchgan",
"lr": 1e-3,
"lr": 5e-4,
"lr_d": 1e-4,
"beta_kl": 0.0001,
"beta_kl": 0.25,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.1,
"lambda_adversarial": 0.01,
"ndf_patch": 64
}
+2 -1
View File
@@ -6,5 +6,6 @@
"subsample": 1.0,
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000
"fid_n_real": 5000,
"num_workers": 2
}
+209 -207
View File
@@ -11,222 +11,224 @@
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000,
"num_workers": 2,
"epochs": 100,
"augment": "hflip",
"image_size": 64,
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"grad_clip": 1.0,
"run_name": "p3_1_vae",
"lr": 0.001,
"beta_kl": 1.0,
"lr": 0.0005,
"beta_kl": 0.25,
"lambda_perceptual": 0.0,
"lambda_adversarial": 0.0
},
"history": {
"recon_loss": [
0.23614721197603095,
0.23315699178821,
0.22991716011594504,
NaN,
0.23217070787253544,
0.23155480842941847,
0.23157141198459855,
0.23181156750418183,
0.23201335527193853,
0.23178868266379732,
0.2315022333755962,
0.2311908418042028,
0.23185610672474927,
0.23176095832107413,
0.23165411693163407,
0.23174296459581098,
0.2317636658747991,
0.2317118427883356,
0.23172695364834917,
0.2316696329567677,
0.23168399261358458,
0.2316194716681782,
0.23164867447354856,
0.2315481170757204,
0.23165068109957582,
0.23167062098653907,
0.23162642907765177,
0.2315922882567104,
0.2315914996414103,
0.23156180984189367,
0.23156551628286004,
0.2315698005259037,
0.2315660522470617,
0.23156735001720935,
0.23161396435183337,
0.23158050178844705,
0.23159921089680785,
0.23149616745674712,
0.23159087484336308,
0.23156312872201967,
0.23153820200863048,
0.2315863819203825,
0.23150022140043414,
0.23154497337646973,
0.2315601774961011,
0.23153368950399578,
0.23152085642019907,
0.23151608884461924,
0.23154898990805334,
0.23155892872784892,
NaN,
NaN,
NaN,
0.24157701413600874,
NaN,
NaN,
0.24151325464630738,
NaN,
0.24154121766233036,
0.24155463749526912,
0.24158300176008135,
0.24158118757554609,
0.2415294518901242,
0.24156020069096842,
0.2415176352374574,
0.2415566616015047,
NaN,
0.24161437115608117,
0.24159398913765565,
0.24149432768806434,
0.24153172199287984,
0.24161516999204954,
0.24158193846034187,
0.2415451397562129,
0.24155487772873324,
0.24155297130346298,
NaN,
0.24157197961313093,
NaN,
NaN,
0.24158605401459923,
0.24156368870893094,
0.24159100852333582,
0.24153350121699846,
0.24153158377505776,
NaN,
0.24161708673350832,
0.24158515879868442,
NaN,
0.24157126235146809,
0.24162366709265953,
NaN,
0.2415581897665293,
NaN,
NaN,
0.2415400046823371,
NaN,
0.2415627600927638,
0.2415567432076503,
0.2415620140476614
0.07131652818180811,
0.04839004274521373,
0.04452062843956499,
0.04353479713870165,
0.043699693698913626,
0.044162471062288836,
0.044934689377744995,
0.04592526849741355,
0.047077489165095694,
0.04793272184160275,
0.04926021105776995,
0.05025381562260226,
0.05120742289174316,
0.052108900581733286,
0.053131219620505966,
0.05401940069869798,
0.05489145301314246,
0.055811726878214084,
0.05665480460907914,
0.0575029267412093,
0.05754986319404382,
0.057160187035034865,
0.05722308493991438,
0.05713338962891418,
0.05686132515119946,
0.056816555862116,
0.05664816373784063,
0.05655015064164614,
0.056517735943516605,
0.056386624335542194,
0.05631861151156262,
0.056178740154092126,
0.056074508314586095,
0.05601463455738675,
0.05584243320438088,
0.05574028127086468,
0.05563880511137665,
0.055547706926098235,
0.05556490144923202,
0.05538980011692923,
0.05529476007303366,
0.05527778912303794,
0.0552029303378529,
0.05519345425801654,
0.05497326165374018,
0.05496025659366805,
0.0549636375095345,
0.0548208407086567,
0.05475613919015114,
0.05467982544826391,
0.05467521703332408,
0.05451676477160719,
0.054423171549271315,
0.05440536335620106,
0.054226587956341415,
0.05415793756643931,
0.0540492157650809,
0.0539091299725776,
0.05381450568859139,
0.053790501263151824,
0.053688646684217654,
0.05361353090176216,
0.05348594906206569,
0.053407231775613934,
0.05329926665394734,
0.053199282489143886,
0.05318549033413585,
0.053034571994446285,
0.05299898797375524,
0.052894837278713525,
0.05281204023422339,
0.05279633629685029,
0.05266677101071064,
0.0525879480310867,
0.052519615285862714,
0.05243872805761221,
0.05236124007715883,
0.052327762763851725,
0.05222526421117732,
0.05212976400636964,
0.05209252470706263,
0.0520137355177321,
0.051939742568020635,
0.05186714857625656,
0.051828445420942754,
0.051747049658726424,
0.05169421551414789,
0.05160299702109689,
0.05152464638917874,
0.051478635081941754,
0.0514086935764704,
0.05138895747403049,
0.051289146423785605,
0.05126826443637793,
0.05114383632555986,
0.051157466693120636,
0.05104086854550828,
0.05102811497437139,
0.0509683930545918,
0.05096899156068635
],
"kl_loss": [
12.394881742504927,
184.775765717539,
127.26797539963681,
33346392.786626913,
35.72433020722153,
31.41954361882984,
16.178619678203876,
10.234501274223001,
14.817130448471787,
9.230570034084158,
9.643558593896719,
8.47786058498244,
5.573643362929678,
2.4644629534365783,
1.5757666807462516,
0.426466258131286,
1.7924597560404203,
0.2769168242652956,
0.21636260826236162,
0.48804672485870176,
0.10833573165453142,
0.13318477837671328,
0.17373992544877478,
0.09584700099678121,
0.0977757986014088,
0.07794108981282538,
0.05691333960853199,
0.07221067506167242,
0.036222075203704275,
0.03126689469696492,
0.04264315036642882,
0.016960328184147805,
0.03314871971324309,
0.014776984407789368,
0.011375301962312406,
0.013948339588828703,
0.01186063720120324,
0.0099704478863372,
0.00536374123289417,
0.009618068660179583,
0.00418840028031164,
0.004865833775052785,
0.005830266629345715,
0.0023000687699064487,
0.0038261460966199762,
0.0022056369562673136,
0.002220870125003987,
0.0024217167485139184,
0.001954249278483037,
0.0021431104709895756,
0.0022583500309011494,
0.002132287005193404,
56.80083886633675,
82.57108385134966,
82.55195800259582,
82.57428529527452,
82.56009972401155,
82.55269345666608,
82.57006728343474,
82.55670593131302,
82.54445134676419,
82.57745079301361,
82.57933913336859,
82.5570435157189,
82.56808758597089,
82.56800172267816,
82.56525711320404,
82.56189481621115,
82.55193622295673,
82.55375865382007,
82.56600202250685,
82.57064581324912,
82.55481151026538,
82.55367833324986,
82.56042112040724,
82.5616829048874,
82.5771528553759,
82.55317820035495,
82.57550573756552,
82.57334061973116,
82.56044387817383,
82.5752662593483,
82.56673936762361,
82.56828115740393,
82.56990289280557,
82.55218840052939,
82.56695372426611,
82.575043066954,
82.55754522991995,
82.56361721723508,
82.5628145821074,
82.56431990403395,
82.55777725806603,
82.5742861918914,
82.56361025622768,
82.56887233766736,
82.56539458902473,
82.55887828729091,
82.56073884882478,
82.55578186165573
0.7611013715847944,
0.5151619965321997,
0.39718208143599015,
0.3297253664360087,
0.2878693372138545,
0.257130523904776,
0.23313261619490436,
0.21345858896772066,
0.1973419941197603,
0.18355375779872266,
0.1717705535583007,
0.1612550135797415,
0.15217030946260843,
0.1444861331047156,
0.13738160394132137,
0.1306394640611023,
0.12505544035926333,
0.12011377760169344,
0.11535473169488275,
0.1109595281095841,
0.11022479862420477,
0.10970949868743236,
0.10961869400408533,
0.10945094661771232,
0.10919397698444688,
0.10875842463957448,
0.10874241859548622,
0.10849472851707385,
0.10825413890565053,
0.10820380448658243,
0.10811882353045492,
0.10812433239104402,
0.10808505038293,
0.10790401608006567,
0.10787000140955305,
0.10800883411151221,
0.10767164218247446,
0.10764909312765822,
0.10733446480435693,
0.10740953346348217,
0.10733820650822078,
0.10725643148279598,
0.10736855125834799,
0.10717424525855443,
0.10725876993030055,
0.10694582887694366,
0.10713813684753373,
0.10726828694853008,
0.10701580285134478,
0.10700331553498395,
0.10686330029215568,
0.10687567073947345,
0.10698378102010132,
0.10673714201483461,
0.10696066876188812,
0.10678731051520404,
0.10679936213180041,
0.10696167247290285,
0.1067299399142846,
0.10684242702893212,
0.10679143969701906,
0.10698746744957235,
0.10674736718846183,
0.10685917330730675,
0.1070305180664246,
0.10704450406388849,
0.10696375739370656,
0.10697911899441327,
0.10700849091841115,
0.10684541509383255,
0.10709039088434134,
0.10708275965900503,
0.10700157213096435,
0.10683403667221722,
0.10696323639434627,
0.10717970753709476,
0.10707768420569408,
0.10707720299052377,
0.10703401576377387,
0.10714245904396233,
0.10722182246928032,
0.10714326931052229,
0.10708108509325573,
0.10726493315245861,
0.10718185655199565,
0.10716394220407192,
0.10727782184496903,
0.10729229825938869,
0.10722832862510641,
0.10727461647146787,
0.10739002018593825,
0.10721855878065793,
0.10737398387784632,
0.10721981757853785,
0.10756766480895188,
0.10733713450021723,
0.10742478621884799,
0.10721213524986027,
0.10737172850113139,
0.10744189095293355
],
"perc_loss": [
0.0,
@@ -535,12 +537,12 @@
0.0
],
"fid": {
"25": 315.9393615722656,
"50": 419.273193359375,
"75": 360.4432678222656,
"100": 363.9911193847656
"25": 108.92365264892578,
"50": 93.73921203613281,
"75": 90.11531829833984,
"100": 88.40287780761719
},
"train_time_s": 660.9630489349365
"train_time_s": 1526.7542352676392
},
"n_params": 10608451
}
+309 -307
View File
@@ -11,324 +11,326 @@
"sample_interval": 10,
"fid_interval": 25,
"fid_n_real": 5000,
"num_workers": 2,
"epochs": 100,
"augment": "hflip",
"image_size": 64,
"model": "vae",
"latent_dim": 256,
"ngf": 64,
"grad_clip": 1.0,
"run_name": "p3_2_vae_perceptual",
"lr": 0.001,
"beta_kl": 0.0001,
"lr": 0.0005,
"beta_kl": 0.25,
"lambda_perceptual": 0.1,
"lambda_adversarial": 0.0
},
"history": {
"recon_loss": [
NaN,
0.041370747770127066,
0.035506031866002284,
0.06229733923672993,
0.06089382184048494,
0.053464704654856116,
0.04950355895213846,
0.04674786329269409,
0.05693557829811023,
0.04808994451076047,
NaN,
0.05869839608701121,
0.05293231666820426,
0.05057006603918779,
0.04713928615117175,
0.04481589820427008,
0.04331832689543565,
0.041820655449524395,
NaN,
0.24315394992884407,
0.24305414841470555,
0.2429187158998261,
0.24289727962424612,
0.2377373814328104,
0.23880945107875726,
0.24047312904626894,
0.23915709144411942,
NaN,
0.24000247061634675,
0.2362435321904655,
0.23705266075383905,
NaN,
0.23998981039238793,
0.23828768376738596,
0.23908851302077627,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN
0.07014625494042014,
0.04432385861396025,
0.03902342810462683,
0.036902043435117625,
0.035906293887135565,
0.035914984559560686,
0.0360899488401846,
0.036847473583860785,
0.037394630213260144,
0.03836990671200503,
0.039295883698022775,
0.040212508386526354,
0.04124497003757801,
0.04212733724305772,
0.04325710433638758,
0.04402907300963361,
0.045103192679647706,
0.04595246975525068,
0.04698181097419598,
0.047825980820080154,
0.04772546164627768,
0.04763130701950982,
0.047709369705591954,
0.04753121634961194,
0.04741841174948674,
0.04729832872016053,
0.04722155279551561,
0.0469712430340612,
0.04693082829093576,
0.0468663705401441,
0.04660685717040657,
0.04661381538384236,
0.046462044382515624,
0.0464134263431924,
0.046272445438254595,
0.04624118105882508,
0.046066727958874315,
0.04603223238363225,
0.04589838094404365,
0.0458371472489248,
0.04577782691225537,
0.04567820464985238,
0.04560983914913785,
0.045465615761076286,
0.04543259674603613,
0.04530587481159685,
0.04528731873465909,
0.04525673297098559,
0.04507890154217553,
0.0450239604514124,
0.04497251241730574,
0.044865927300774135,
0.04475303162207715,
0.044664276763796806,
0.044507257990602754,
0.0444452501515038,
0.044273821509674065,
0.044196275588220514,
0.044097925846775375,
0.04397543971864586,
0.04393626836279773,
0.043823353540247835,
0.04371356347209623,
0.043580674797169164,
0.04348461520181507,
0.04339473670682846,
0.043282039009798795,
0.04320543994888281,
0.04314021340324583,
0.04300904847904403,
0.04293948887950844,
0.04286929544730064,
0.04278319672896312,
0.04270272867547141,
0.0425933551393513,
0.04250210111276207,
0.04243518800562263,
0.04233357961424905,
0.04227217231105026,
0.042201977119677596,
0.04207655237430436,
0.04206944097024508,
0.0420022471768097,
0.041929350385808535,
0.04179893332159417,
0.041740718671781384,
0.04166606991177695,
0.04158131634959808,
0.04155632069445828,
0.04140628448440733,
0.041379960253834724,
0.041310225080093764,
0.04123378162168794,
0.04118641266105776,
0.04113346862837545,
0.041023712032116376,
0.04102735277902112,
0.04096395063062764,
0.04094440599059702,
0.040864359348630294
],
"kl_loss": [
41526536.5961082,
1258.7768262553418,
1165.5987154968784,
3689.321748260759,
1320.3167513333833,
928.3473894168169,
789.754654843583,
693.4729607736963,
1017.8856519389357,
622.2623097998464,
304942.2120989938,
750.5718396830763,
618.6888906364767,
586.4194498958751,
579.2793892102363,
525.8830437945504,
497.5363791050055,
452.93342662061383,
172476214612.95245,
1112.027792645316,
1109.8156276605068,
1109.8872808472722,
1112.3999198196282,
1163.6093267457097,
1195.4633280436199,
1232.709252512353,
1253.0651078183428,
19923.65347159622,
6871.588723207132,
6872.017739842081,
6872.62934758113,
27062.17111336472,
4992.576698759683,
5000.895510942508,
5011.723715236044,
4.196954763212122e+20,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN
0.9578518337673612,
0.8413855069213443,
0.7201743645545764,
0.6246620782165446,
0.5550231880102402,
0.49975109450582766,
0.454662803783376,
0.4155570190941167,
0.3821740793621438,
0.352426735827556,
0.32639839258204156,
0.3032199253893306,
0.2829984358360625,
0.26505850424241817,
0.24903864537676176,
0.23453107457130384,
0.22166774176761636,
0.21033292894180006,
0.19950743799662998,
0.18994869887192026,
0.18805619294189999,
0.18697467955768618,
0.1855154659312505,
0.1852042447680082,
0.18474761941112006,
0.1843619988514827,
0.18399313028551575,
0.18352704119478536,
0.18345636556036451,
0.18366932483692455,
0.18345450644946507,
0.18317942117523944,
0.1830864556006387,
0.1831519056079734,
0.18329900547734693,
0.1829878402252992,
0.18321112291807803,
0.18296580338197896,
0.18312124166096377,
0.18301436457878503,
0.1831260528574642,
0.18320244878657863,
0.18332607687538505,
0.18335427452101666,
0.18322811696009758,
0.18336237903334138,
0.183573446308191,
0.18356411601615769,
0.18363770969912538,
0.18377095862076834,
0.18369814056234482,
0.18386876185098264,
0.18424720492245805,
0.18420857423518458,
0.18395239082921264,
0.18442093539569113,
0.1842692175036312,
0.18475162495787328,
0.18502490509014863,
0.18480629139603713,
0.1851997250993537,
0.1850732435973791,
0.1854018302809479,
0.18547267151566652,
0.18569090538936803,
0.1859345402664099,
0.18615943276219898,
0.18631401260057065,
0.1865774132948146,
0.1867309583303256,
0.18672093723574254,
0.18682665680336136,
0.18679539349853483,
0.18683723016427115,
0.1874776411578696,
0.18746165714712223,
0.1873003001141752,
0.18763184111215112,
0.18799093031348327,
0.18802655252635989,
0.18824852079662502,
0.1882636730169129,
0.18849152215143555,
0.1886035389561429,
0.188697873813729,
0.18883028218888828,
0.1889004738858113,
0.18922981174073666,
0.18946354345888153,
0.18955816268029377,
0.18944057507010606,
0.18984314531852037,
0.19001257871715432,
0.19001863161340737,
0.18997256346365327,
0.19019658920856622,
0.19018171425176483,
0.19038618680758354,
0.19034281325263855,
0.19046620505615178
],
"perc_loss": [
NaN,
3.0017659954535656,
2.8931196978968434,
3.084328340159522,
3.1398269641093717,
3.0839009172896032,
3.0459561949102287,
3.015714469628456,
3.0917934143645134,
3.010985380054539,
NaN,
3.102917709411719,
3.055045795746339,
3.0351512865123587,
3.002274121993627,
2.9754957128793764,
2.9553630066733074,
2.936704080328982,
NaN,
3.66463458385223,
3.6419444985878773,
3.6321474706005845,
3.6272282503608966,
3.781295938878997,
3.740793755421272,
3.759704489993234,
3.7078149914741516,
NaN,
3.7647176323792872,
3.7441278461717133,
3.7264530261357627,
NaN,
3.7299226170931106,
3.7314435080585318,
3.754668933713538,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN,
NaN
3.257610213552785,
2.987424520855276,
2.8892488026211405,
2.8424616914529066,
2.817538415774321,
2.807587738220508,
2.8047564518757357,
2.8080830421203222,
2.8133585376617236,
2.821376688969441,
2.830571848612565,
2.840076332927769,
2.8507282504668603,
2.8593530094521675,
2.8690790253826695,
2.877339167982085,
2.8865632745954724,
2.8936263704911256,
2.9023511684857883,
2.9098138376178904,
2.906375333794162,
2.9055831661591163,
2.9031203985214233,
2.9002673641229286,
2.897312533651662,
2.8946053095352955,
2.8919762974111443,
2.8902457397208257,
2.8877664065768576,
2.884983114197723,
2.882109224286854,
2.8805122767758164,
2.8777520192994013,
2.876453428186922,
2.874014714334765,
2.8723304679251123,
2.870547831058502,
2.86905308564504,
2.8669193524580736,
2.865350999383845,
2.8633312608441734,
2.861651761409564,
2.860612149421985,
2.858717398256318,
2.857239680412488,
2.8558127604998074,
2.853911985189487,
2.852791876364977,
2.850916815109742,
2.8496591963319697,
2.8490061500133614,
2.8465628389619355,
2.844754463587052,
2.842957689211919,
2.840948245464227,
2.8397679522506194,
2.837174654006958,
2.835211678957328,
2.833740895629948,
2.8325765499701867,
2.830295750218579,
2.8292511431579914,
2.8273453595291853,
2.8257394835480256,
2.823762384744791,
2.822507558215378,
2.8205288745399213,
2.819065808230995,
2.8176180297492914,
2.8165967805772767,
2.8143360721759305,
2.8135039892971005,
2.8120578970664587,
2.8098325041624217,
2.8090034366672874,
2.8076502920215964,
2.80601545684358,
2.805327193859296,
2.803745285058633,
2.802450499473474,
2.800580952412043,
2.799968459157862,
2.798532901156662,
2.797372330967178,
2.795650184663952,
2.79500452371744,
2.79342931152409,
2.792220654650631,
2.7911063600809145,
2.7897617689564695,
2.7887179505111823,
2.7880892407180915,
2.7868736916118197,
2.785768971993373,
2.784949649602939,
2.784032260760283,
2.783316181765662,
2.7821382457374506,
2.7815805586994204,
2.780150015639444
],
"adv_g_loss": [
0.0,
@@ -535,12 +537,12 @@
0.0
],
"fid": {
"25": 263.1458740234375,
"50": 598.3736572265625,
"75": 598.3736572265625,
"100": 598.3736572265625
"25": 85.4538345336914,
"50": 70.30448150634766,
"75": 68.88232421875,
"100": 68.23878479003906
},
"train_time_s": 952.6596128940582
"train_time_s": 1526.8077104091644
},
"n_params": 10608451
}
File diff suppressed because it is too large Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,14 @@
{
"created_at": "2026-05-02T13:07:06.633366+00:00",
"config_paths": [
"generator/configs/phase3/p3_1_vae.json",
"generator/configs/phase3/p3_2_vae_perceptual.json",
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 36014025,
"offer_id": 35956124,
"ssh_host": "ssh3.vast.ai",
"ssh_port": 14024,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -0,0 +1,12 @@
{
"created_at": "2026-05-02T17:39:27.858755+00:00",
"config_paths": [
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 36025335,
"offer_id": 19025020,
"ssh_host": "ssh6.vast.ai",
"ssh_port": 25334,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -0,0 +1,12 @@
{
"created_at": "2026-05-02T18:44:33.143461+00:00",
"config_paths": [
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 36027837,
"offer_id": 29548831,
"ssh_host": null,
"ssh_port": null,
"status": "cancelled",
"remote_workspace": "/workspace/DRL_PROJ"
}
@@ -0,0 +1,12 @@
{
"created_at": "2026-05-02T18:51:49.983505+00:00",
"config_paths": [
"generator/configs/phase3/p3_3_vae_patchgan.json"
],
"instance_id": 36028210,
"offer_id": 35956124,
"ssh_host": "ssh1.vast.ai",
"ssh_port": 28210,
"status": "completed",
"remote_workspace": "/workspace/DRL_PROJ"
}
Binary file not shown.

Before

Width:  |  Height:  |  Size: 80 KiB

After

Width:  |  Height:  |  Size: 93 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 132 KiB

After

Width:  |  Height:  |  Size: 221 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 78 KiB

After

Width:  |  Height:  |  Size: 81 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 120 KiB

After

Width:  |  Height:  |  Size: 214 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

After

Width:  |  Height:  |  Size: 76 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 126 KiB

After

Width:  |  Height:  |  Size: 213 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.5 KiB

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 213 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.5 KiB

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 124 KiB

After

Width:  |  Height:  |  Size: 212 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.3 KiB

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 124 KiB

After

Width:  |  Height:  |  Size: 211 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.4 KiB

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 211 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.7 KiB

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 121 KiB

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.4 KiB

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.3 KiB

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 122 KiB

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 91 KiB

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 215 KiB

After

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 40 KiB

After

Width:  |  Height:  |  Size: 86 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 135 KiB

After

Width:  |  Height:  |  Size: 205 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 30 KiB

After

Width:  |  Height:  |  Size: 87 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 129 KiB

After

Width:  |  Height:  |  Size: 206 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 89 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 207 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 90 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 92 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 93 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 93 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 209 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 209 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 285 B

After

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 209 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 118 KiB

After

Width:  |  Height:  |  Size: 92 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 247 KiB

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 31 KiB

After

Width:  |  Height:  |  Size: 90 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 141 KiB

After

Width:  |  Height:  |  Size: 209 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

After

Width:  |  Height:  |  Size: 92 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 127 KiB

After

Width:  |  Height:  |  Size: 210 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

After

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 212 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 18 KiB

After

Width:  |  Height:  |  Size: 96 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 212 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 212 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 99 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 213 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 214 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 214 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 215 KiB

+1 -1
View File
@@ -52,7 +52,7 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
# Count total trainable parameters
if isinstance(model, tuple):
n_params = sum(p.numel() for p in model[0].parameters() if p.requires_grad)
n_params = sum(p.numel() for m in model for p in m.parameters() if p.requires_grad)
else:
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable params: {n_params:,}")
+11 -4
View File
@@ -22,17 +22,21 @@ def _init_weights(m):
nn.init.normal_(m.weight, 0.0, 0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None:
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.zeros_(m.bias)
def _norm(channels: int) -> nn.GroupNorm:
return nn.GroupNorm(8, channels)
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
return nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
nn.BatchNorm2d(out_ch),
_norm(out_ch),
nn.ReLU(inplace=True),
)
@@ -65,7 +69,7 @@ class VAE(nn.Module):
for _ in range(n_down - 1):
enc_layers += [
nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ch * 2),
_norm(ch * 2),
nn.LeakyReLU(0.2, inplace=True),
]
ch *= 2
@@ -98,9 +102,12 @@ class VAE(nn.Module):
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
h = self.encoder(x).flatten(1)
return self.fc_mu(h), self.fc_lv(h)
log_var = self.fc_lv(h).clamp(-10.0, 10.0)
return self.fc_mu(h), log_var
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
if not self.training:
return mu
std = torch.exp(0.5 * log_var)
return mu + std * torch.randn_like(std)
+2
View File
@@ -20,3 +20,5 @@ class EMA:
def update(self, model: nn.Module) -> None:
for p_ema, p in zip(self.model.parameters(), model.parameters()):
p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
for b_ema, b in zip(self.model.buffers(), model.buffers()):
b_ema.copy_(b)
+3 -2
View File
@@ -11,14 +11,15 @@ from torchmetrics.image.fid import FrechetInceptionDistance
class FIDEvaluator:
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda"):
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda",
num_workers: int = 2):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.n_real = n_real
# Cache real images as a CPU tensor ([-1, 1] range)
imgs_list = []
loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
num_workers=4, drop_last=False)
num_workers=num_workers, drop_last=False)
for batch in loader:
imgs_list.append(batch.cpu())
if sum(x.size(0) for x in imgs_list) >= n_real:
+67 -44
View File
@@ -66,7 +66,7 @@ def train_dcgan(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
@@ -86,7 +86,8 @@ def train_dcgan(
save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
best_fid = float("inf")
@@ -239,7 +240,7 @@ def train_wgan(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
@@ -257,7 +258,8 @@ def train_wgan(
save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
best_fid = float("inf")
@@ -391,7 +393,6 @@ def _save_vae_samples(
# Interleave real / reconstruction pairs
pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
vae.train()
def train_vae(
@@ -403,13 +404,21 @@ def train_vae(
run_name: str,
device: str = "cuda",
) -> dict:
"""VAE training loop covering Phase 3.1 3.3.
"""VAE training loop covering Phase 3.1 3.3 and Phase 5.
Config toggles:
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
KL is computed as mean over latent dimensions (scale-invariant), so
beta_kl is comparable across different latent_dim values.
AMP is intentionally disabled for VAE training — mixed-precision float16
overflows when the KL divergence spikes, producing NaN cascades that
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
computation runs in float32.
"""
device = torch.device(device if torch.cuda.is_available() else "cpu")
vae = vae.to(device)
@@ -425,6 +434,7 @@ def train_vae(
lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
lr_d = cfg.get("lr_d", 1e-4)
grad_clip = cfg.get("grad_clip", 1.0)
ema_decay = cfg.get("ema_decay", 0.9999)
sample_interval = cfg.get("sample_interval", 10)
fid_interval = cfg.get("fid_interval", 25)
@@ -435,13 +445,13 @@ def train_vae(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
use_amp = device.type == "cuda"
scaler = _GradScaler("cuda", enabled=use_amp)
# AMP disabled — float16 overflows on KL spikes, causing NaN cascades
use_amp = False
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
kl_warmup_epochs = max(1, epochs // 5)
@@ -456,11 +466,10 @@ def train_vae(
perc_fn = None
patchgan = None
opt_d = None
scaler_d = None
if use_perceptual:
from src.training.perceptual import PerceptualLoss
perc_fn = PerceptualLoss().to(device)
perc_fn = PerceptualLoss().to(device).float()
print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")
if use_adversarial:
@@ -468,15 +477,14 @@ def train_vae(
patchgan = PatchGANDiscriminator(
ndf=cfg.get("ndf_patch", 64),
image_size=cfg.get("image_size", 64),
).to(device)
).to(device).float()
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
scaler_d = _GradScaler("cuda", enabled=use_amp)
sched_d = torch.optim.lr_scheduler.LambdaLR(
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
n_d = sum(p.numel() for p in patchgan.parameters())
print(f"PatchGAN: {n_d:,} params")
else:
hinge_d_loss = hinge_g_loss = None # satisfy linter, never called
hinge_d_loss = hinge_g_loss = None # never called
# ── Fixed seeds for consistent visualisation ──────────────────────────
fixed_z = torch.randn(16, latent_dim, device=device)
@@ -490,16 +498,19 @@ def train_vae(
save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {
"recon_loss": [], "kl_loss": [], "perc_loss": [],
"adv_g_loss": [], "adv_d_loss": [], "fid": {},
}
best_fid = float("inf")
nan_skipped = 0
print(
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}"
f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
f" λ_adv={lambda_adversarial}"
)
t_start = time.time()
@@ -513,43 +524,52 @@ def train_vae(
n_batches = 0
for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
real = real.to(device)
real = real.to(device).float()
# KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs
current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)
# ── VAE forward ───────────────────────────────────────────────
with _autocast("cuda", enabled=use_amp):
recon, mu, log_var = vae(real)
mse = F.mse_loss(recon, real)
kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean()
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
# ── VAE forward (float32, no AMP) ────────────────────────────
recon, mu, log_var = vae(real)
mse = F.mse_loss(recon, real)
# KL divergence: mean over latent dims (scale-invariant w.r.t. latent_dim)
kl = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp())).mean()
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
# ── NaN/Inf guard ────────────────────────────────────────────
if not torch.isfinite(vae_loss):
nan_skipped += 1
opt_vae.zero_grad()
continue
# ── PatchGAN discriminator step ───────────────────────────────
adv_d = real.new_zeros(1).squeeze()
if use_adversarial:
opt_d.zero_grad()
with _autocast("cuda", enabled=use_amp):
# Warmup: only start adversarial training after 20% of epochs
if epoch > kl_warmup_epochs:
opt_d.zero_grad()
d_real = patchgan(real)
d_fake = patchgan(recon.detach())
adv_d = hinge_d_loss(d_real, d_fake)
scaler_d.scale(adv_d).backward()
scaler_d.step(opt_d)
scaler_d.update()
if torch.isfinite(adv_d):
adv_d.backward()
torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
opt_d.step()
# ── PatchGAN generator adversarial loss ───────────────────────
adv_g = real.new_zeros(1).squeeze()
if use_adversarial:
with _autocast("cuda", enabled=use_amp):
adv_g = hinge_g_loss(patchgan(recon))
vae_loss = vae_loss + lambda_adversarial * adv_g
if use_adversarial and epoch > kl_warmup_epochs:
adv_g = hinge_g_loss(patchgan(recon))
vae_loss = vae_loss + lambda_adversarial * adv_g
# ── VAE backward ──────────────────────────────────────────────
opt_vae.zero_grad()
scaler.scale(vae_loss).backward()
scaler.step(opt_vae)
scaler.update()
vae_loss.backward()
torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip)
opt_vae.step()
ema.update(vae)
recon_sum += mse.item()
@@ -559,11 +579,11 @@ def train_vae(
adv_d_sum += adv_d.item()
n_batches += 1
avg_r = recon_sum / n_batches
avg_k = kl_sum / n_batches
avg_p = perc_sum / n_batches
avg_g = adv_g_sum / n_batches
avg_d = adv_d_sum / n_batches
avg_r = recon_sum / max(n_batches, 1)
avg_k = kl_sum / max(n_batches, 1)
avg_p = perc_sum / max(n_batches, 1)
avg_g = adv_g_sum / max(n_batches, 1)
avg_d = adv_d_sum / max(n_batches, 1)
history["recon_loss"].append(avg_r)
history["kl_loss"].append(avg_k)
history["perc_loss"].append(avg_p)
@@ -574,6 +594,7 @@ def train_vae(
f"[{epoch:03d}/{epochs}] "
f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
f" (NaN skipped: {nan_skipped})"
)
if epoch % sample_interval == 0:
@@ -607,6 +628,7 @@ def train_vae(
if patchgan is not None:
torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
history["train_time_s"] = time.time() - t_start
print(f"Total NaN-skipped batches: {nan_skipped}")
return history
@@ -662,7 +684,7 @@ def train_ddpm(
loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True,
num_workers=min(4, os.cpu_count() or 1),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
pin_memory=(device.type == "cuda"), drop_last=True,
)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
@@ -679,7 +701,8 @@ def train_ddpm(
save_dir.mkdir(parents=True, exist_ok=True)
samples_dir = save_dir.parent / "samples" / run_name
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
history = {"loss": [], "fid": {}}
best_fid = float("inf")