VAE fix w/ new results
@@ -7,6 +7,7 @@
|
|||||||
"model": "vae",
|
"model": "vae",
|
||||||
"latent_dim": 256,
|
"latent_dim": 256,
|
||||||
"ngf": 64,
|
"ngf": 64,
|
||||||
|
"grad_clip": 1.0,
|
||||||
"sample_interval": 10,
|
"sample_interval": 10,
|
||||||
"fid_interval": 25,
|
"fid_interval": 25,
|
||||||
"fid_n_real": 5000
|
"fid_n_real": 5000
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
{
|
{
|
||||||
"extends": "_base_phase3.json",
|
"extends": "_base_phase3.json",
|
||||||
"run_name": "p3_1_vae",
|
"run_name": "p3_1_vae",
|
||||||
"lr": 1e-3,
|
"lr": 5e-4,
|
||||||
"beta_kl": 1.0,
|
"beta_kl": 0.25,
|
||||||
"lambda_perceptual": 0.0,
|
"lambda_perceptual": 0.0,
|
||||||
"lambda_adversarial": 0.0
|
"lambda_adversarial": 0.0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
{
|
{
|
||||||
"extends": "_base_phase3.json",
|
"extends": "_base_phase3.json",
|
||||||
"run_name": "p3_2_vae_perceptual",
|
"run_name": "p3_2_vae_perceptual",
|
||||||
"lr": 1e-3,
|
"lr": 5e-4,
|
||||||
"beta_kl": 0.0001,
|
"beta_kl": 0.25,
|
||||||
"lambda_perceptual": 0.1,
|
"lambda_perceptual": 0.1,
|
||||||
"lambda_adversarial": 0.0
|
"lambda_adversarial": 0.0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
{
|
{
|
||||||
"extends": "_base_phase3.json",
|
"extends": "_base_phase3.json",
|
||||||
"run_name": "p3_3_vae_patchgan",
|
"run_name": "p3_3_vae_patchgan",
|
||||||
"lr": 1e-3,
|
"lr": 5e-4,
|
||||||
"lr_d": 1e-4,
|
"lr_d": 1e-4,
|
||||||
"beta_kl": 0.0001,
|
"beta_kl": 0.25,
|
||||||
"lambda_perceptual": 0.1,
|
"lambda_perceptual": 0.1,
|
||||||
"lambda_adversarial": 0.1,
|
"lambda_adversarial": 0.01,
|
||||||
"ndf_patch": 64
|
"ndf_patch": 64
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,5 +6,6 @@
|
|||||||
"subsample": 1.0,
|
"subsample": 1.0,
|
||||||
"sample_interval": 10,
|
"sample_interval": 10,
|
||||||
"fid_interval": 25,
|
"fid_interval": 25,
|
||||||
"fid_n_real": 5000
|
"fid_n_real": 5000,
|
||||||
|
"num_workers": 2
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,222 +11,224 @@
|
|||||||
"sample_interval": 10,
|
"sample_interval": 10,
|
||||||
"fid_interval": 25,
|
"fid_interval": 25,
|
||||||
"fid_n_real": 5000,
|
"fid_n_real": 5000,
|
||||||
|
"num_workers": 2,
|
||||||
"epochs": 100,
|
"epochs": 100,
|
||||||
"augment": "hflip",
|
"augment": "hflip",
|
||||||
"image_size": 64,
|
"image_size": 64,
|
||||||
"model": "vae",
|
"model": "vae",
|
||||||
"latent_dim": 256,
|
"latent_dim": 256,
|
||||||
"ngf": 64,
|
"ngf": 64,
|
||||||
|
"grad_clip": 1.0,
|
||||||
"run_name": "p3_1_vae",
|
"run_name": "p3_1_vae",
|
||||||
"lr": 0.001,
|
"lr": 0.0005,
|
||||||
"beta_kl": 1.0,
|
"beta_kl": 0.25,
|
||||||
"lambda_perceptual": 0.0,
|
"lambda_perceptual": 0.0,
|
||||||
"lambda_adversarial": 0.0
|
"lambda_adversarial": 0.0
|
||||||
},
|
},
|
||||||
"history": {
|
"history": {
|
||||||
"recon_loss": [
|
"recon_loss": [
|
||||||
0.23614721197603095,
|
0.07131652818180811,
|
||||||
0.23315699178821,
|
0.04839004274521373,
|
||||||
0.22991716011594504,
|
0.04452062843956499,
|
||||||
NaN,
|
0.04353479713870165,
|
||||||
0.23217070787253544,
|
0.043699693698913626,
|
||||||
0.23155480842941847,
|
0.044162471062288836,
|
||||||
0.23157141198459855,
|
0.044934689377744995,
|
||||||
0.23181156750418183,
|
0.04592526849741355,
|
||||||
0.23201335527193853,
|
0.047077489165095694,
|
||||||
0.23178868266379732,
|
0.04793272184160275,
|
||||||
0.2315022333755962,
|
0.04926021105776995,
|
||||||
0.2311908418042028,
|
0.05025381562260226,
|
||||||
0.23185610672474927,
|
0.05120742289174316,
|
||||||
0.23176095832107413,
|
0.052108900581733286,
|
||||||
0.23165411693163407,
|
0.053131219620505966,
|
||||||
0.23174296459581098,
|
0.05401940069869798,
|
||||||
0.2317636658747991,
|
0.05489145301314246,
|
||||||
0.2317118427883356,
|
0.055811726878214084,
|
||||||
0.23172695364834917,
|
0.05665480460907914,
|
||||||
0.2316696329567677,
|
0.0575029267412093,
|
||||||
0.23168399261358458,
|
0.05754986319404382,
|
||||||
0.2316194716681782,
|
0.057160187035034865,
|
||||||
0.23164867447354856,
|
0.05722308493991438,
|
||||||
0.2315481170757204,
|
0.05713338962891418,
|
||||||
0.23165068109957582,
|
0.05686132515119946,
|
||||||
0.23167062098653907,
|
0.056816555862116,
|
||||||
0.23162642907765177,
|
0.05664816373784063,
|
||||||
0.2315922882567104,
|
0.05655015064164614,
|
||||||
0.2315914996414103,
|
0.056517735943516605,
|
||||||
0.23156180984189367,
|
0.056386624335542194,
|
||||||
0.23156551628286004,
|
0.05631861151156262,
|
||||||
0.2315698005259037,
|
0.056178740154092126,
|
||||||
0.2315660522470617,
|
0.056074508314586095,
|
||||||
0.23156735001720935,
|
0.05601463455738675,
|
||||||
0.23161396435183337,
|
0.05584243320438088,
|
||||||
0.23158050178844705,
|
0.05574028127086468,
|
||||||
0.23159921089680785,
|
0.05563880511137665,
|
||||||
0.23149616745674712,
|
0.055547706926098235,
|
||||||
0.23159087484336308,
|
0.05556490144923202,
|
||||||
0.23156312872201967,
|
0.05538980011692923,
|
||||||
0.23153820200863048,
|
0.05529476007303366,
|
||||||
0.2315863819203825,
|
0.05527778912303794,
|
||||||
0.23150022140043414,
|
0.0552029303378529,
|
||||||
0.23154497337646973,
|
0.05519345425801654,
|
||||||
0.2315601774961011,
|
0.05497326165374018,
|
||||||
0.23153368950399578,
|
0.05496025659366805,
|
||||||
0.23152085642019907,
|
0.0549636375095345,
|
||||||
0.23151608884461924,
|
0.0548208407086567,
|
||||||
0.23154898990805334,
|
0.05475613919015114,
|
||||||
0.23155892872784892,
|
0.05467982544826391,
|
||||||
NaN,
|
0.05467521703332408,
|
||||||
NaN,
|
0.05451676477160719,
|
||||||
NaN,
|
0.054423171549271315,
|
||||||
0.24157701413600874,
|
0.05440536335620106,
|
||||||
NaN,
|
0.054226587956341415,
|
||||||
NaN,
|
0.05415793756643931,
|
||||||
0.24151325464630738,
|
0.0540492157650809,
|
||||||
NaN,
|
0.0539091299725776,
|
||||||
0.24154121766233036,
|
0.05381450568859139,
|
||||||
0.24155463749526912,
|
0.053790501263151824,
|
||||||
0.24158300176008135,
|
0.053688646684217654,
|
||||||
0.24158118757554609,
|
0.05361353090176216,
|
||||||
0.2415294518901242,
|
0.05348594906206569,
|
||||||
0.24156020069096842,
|
0.053407231775613934,
|
||||||
0.2415176352374574,
|
0.05329926665394734,
|
||||||
0.2415566616015047,
|
0.053199282489143886,
|
||||||
NaN,
|
0.05318549033413585,
|
||||||
0.24161437115608117,
|
0.053034571994446285,
|
||||||
0.24159398913765565,
|
0.05299898797375524,
|
||||||
0.24149432768806434,
|
0.052894837278713525,
|
||||||
0.24153172199287984,
|
0.05281204023422339,
|
||||||
0.24161516999204954,
|
0.05279633629685029,
|
||||||
0.24158193846034187,
|
0.05266677101071064,
|
||||||
0.2415451397562129,
|
0.0525879480310867,
|
||||||
0.24155487772873324,
|
0.052519615285862714,
|
||||||
0.24155297130346298,
|
0.05243872805761221,
|
||||||
NaN,
|
0.05236124007715883,
|
||||||
0.24157197961313093,
|
0.052327762763851725,
|
||||||
NaN,
|
0.05222526421117732,
|
||||||
NaN,
|
0.05212976400636964,
|
||||||
0.24158605401459923,
|
0.05209252470706263,
|
||||||
0.24156368870893094,
|
0.0520137355177321,
|
||||||
0.24159100852333582,
|
0.051939742568020635,
|
||||||
0.24153350121699846,
|
0.05186714857625656,
|
||||||
0.24153158377505776,
|
0.051828445420942754,
|
||||||
NaN,
|
0.051747049658726424,
|
||||||
0.24161708673350832,
|
0.05169421551414789,
|
||||||
0.24158515879868442,
|
0.05160299702109689,
|
||||||
NaN,
|
0.05152464638917874,
|
||||||
0.24157126235146809,
|
0.051478635081941754,
|
||||||
0.24162366709265953,
|
0.0514086935764704,
|
||||||
NaN,
|
0.05138895747403049,
|
||||||
0.2415581897665293,
|
0.051289146423785605,
|
||||||
NaN,
|
0.05126826443637793,
|
||||||
NaN,
|
0.05114383632555986,
|
||||||
0.2415400046823371,
|
0.051157466693120636,
|
||||||
NaN,
|
0.05104086854550828,
|
||||||
0.2415627600927638,
|
0.05102811497437139,
|
||||||
0.2415567432076503,
|
0.0509683930545918,
|
||||||
0.2415620140476614
|
0.05096899156068635
|
||||||
],
|
],
|
||||||
"kl_loss": [
|
"kl_loss": [
|
||||||
12.394881742504927,
|
0.7611013715847944,
|
||||||
184.775765717539,
|
0.5151619965321997,
|
||||||
127.26797539963681,
|
0.39718208143599015,
|
||||||
33346392.786626913,
|
0.3297253664360087,
|
||||||
35.72433020722153,
|
0.2878693372138545,
|
||||||
31.41954361882984,
|
0.257130523904776,
|
||||||
16.178619678203876,
|
0.23313261619490436,
|
||||||
10.234501274223001,
|
0.21345858896772066,
|
||||||
14.817130448471787,
|
0.1973419941197603,
|
||||||
9.230570034084158,
|
0.18355375779872266,
|
||||||
9.643558593896719,
|
0.1717705535583007,
|
||||||
8.47786058498244,
|
0.1612550135797415,
|
||||||
5.573643362929678,
|
0.15217030946260843,
|
||||||
2.4644629534365783,
|
0.1444861331047156,
|
||||||
1.5757666807462516,
|
0.13738160394132137,
|
||||||
0.426466258131286,
|
0.1306394640611023,
|
||||||
1.7924597560404203,
|
0.12505544035926333,
|
||||||
0.2769168242652956,
|
0.12011377760169344,
|
||||||
0.21636260826236162,
|
0.11535473169488275,
|
||||||
0.48804672485870176,
|
0.1109595281095841,
|
||||||
0.10833573165453142,
|
0.11022479862420477,
|
||||||
0.13318477837671328,
|
0.10970949868743236,
|
||||||
0.17373992544877478,
|
0.10961869400408533,
|
||||||
0.09584700099678121,
|
0.10945094661771232,
|
||||||
0.0977757986014088,
|
0.10919397698444688,
|
||||||
0.07794108981282538,
|
0.10875842463957448,
|
||||||
0.05691333960853199,
|
0.10874241859548622,
|
||||||
0.07221067506167242,
|
0.10849472851707385,
|
||||||
0.036222075203704275,
|
0.10825413890565053,
|
||||||
0.03126689469696492,
|
0.10820380448658243,
|
||||||
0.04264315036642882,
|
0.10811882353045492,
|
||||||
0.016960328184147805,
|
0.10812433239104402,
|
||||||
0.03314871971324309,
|
0.10808505038293,
|
||||||
0.014776984407789368,
|
0.10790401608006567,
|
||||||
0.011375301962312406,
|
0.10787000140955305,
|
||||||
0.013948339588828703,
|
0.10800883411151221,
|
||||||
0.01186063720120324,
|
0.10767164218247446,
|
||||||
0.0099704478863372,
|
0.10764909312765822,
|
||||||
0.00536374123289417,
|
0.10733446480435693,
|
||||||
0.009618068660179583,
|
0.10740953346348217,
|
||||||
0.00418840028031164,
|
0.10733820650822078,
|
||||||
0.004865833775052785,
|
0.10725643148279598,
|
||||||
0.005830266629345715,
|
0.10736855125834799,
|
||||||
0.0023000687699064487,
|
0.10717424525855443,
|
||||||
0.0038261460966199762,
|
0.10725876993030055,
|
||||||
0.0022056369562673136,
|
0.10694582887694366,
|
||||||
0.002220870125003987,
|
0.10713813684753373,
|
||||||
0.0024217167485139184,
|
0.10726828694853008,
|
||||||
0.001954249278483037,
|
0.10701580285134478,
|
||||||
0.0021431104709895756,
|
0.10700331553498395,
|
||||||
0.0022583500309011494,
|
0.10686330029215568,
|
||||||
0.002132287005193404,
|
0.10687567073947345,
|
||||||
56.80083886633675,
|
0.10698378102010132,
|
||||||
82.57108385134966,
|
0.10673714201483461,
|
||||||
82.55195800259582,
|
0.10696066876188812,
|
||||||
82.57428529527452,
|
0.10678731051520404,
|
||||||
82.56009972401155,
|
0.10679936213180041,
|
||||||
82.55269345666608,
|
0.10696167247290285,
|
||||||
82.57006728343474,
|
0.1067299399142846,
|
||||||
82.55670593131302,
|
0.10684242702893212,
|
||||||
82.54445134676419,
|
0.10679143969701906,
|
||||||
82.57745079301361,
|
0.10698746744957235,
|
||||||
82.57933913336859,
|
0.10674736718846183,
|
||||||
82.5570435157189,
|
0.10685917330730675,
|
||||||
82.56808758597089,
|
0.1070305180664246,
|
||||||
82.56800172267816,
|
0.10704450406388849,
|
||||||
82.56525711320404,
|
0.10696375739370656,
|
||||||
82.56189481621115,
|
0.10697911899441327,
|
||||||
82.55193622295673,
|
0.10700849091841115,
|
||||||
82.55375865382007,
|
0.10684541509383255,
|
||||||
82.56600202250685,
|
0.10709039088434134,
|
||||||
82.57064581324912,
|
0.10708275965900503,
|
||||||
82.55481151026538,
|
0.10700157213096435,
|
||||||
82.55367833324986,
|
0.10683403667221722,
|
||||||
82.56042112040724,
|
0.10696323639434627,
|
||||||
82.5616829048874,
|
0.10717970753709476,
|
||||||
82.5771528553759,
|
0.10707768420569408,
|
||||||
82.55317820035495,
|
0.10707720299052377,
|
||||||
82.57550573756552,
|
0.10703401576377387,
|
||||||
82.57334061973116,
|
0.10714245904396233,
|
||||||
82.56044387817383,
|
0.10722182246928032,
|
||||||
82.5752662593483,
|
0.10714326931052229,
|
||||||
82.56673936762361,
|
0.10708108509325573,
|
||||||
82.56828115740393,
|
0.10726493315245861,
|
||||||
82.56990289280557,
|
0.10718185655199565,
|
||||||
82.55218840052939,
|
0.10716394220407192,
|
||||||
82.56695372426611,
|
0.10727782184496903,
|
||||||
82.575043066954,
|
0.10729229825938869,
|
||||||
82.55754522991995,
|
0.10722832862510641,
|
||||||
82.56361721723508,
|
0.10727461647146787,
|
||||||
82.5628145821074,
|
0.10739002018593825,
|
||||||
82.56431990403395,
|
0.10721855878065793,
|
||||||
82.55777725806603,
|
0.10737398387784632,
|
||||||
82.5742861918914,
|
0.10721981757853785,
|
||||||
82.56361025622768,
|
0.10756766480895188,
|
||||||
82.56887233766736,
|
0.10733713450021723,
|
||||||
82.56539458902473,
|
0.10742478621884799,
|
||||||
82.55887828729091,
|
0.10721213524986027,
|
||||||
82.56073884882478,
|
0.10737172850113139,
|
||||||
82.55578186165573
|
0.10744189095293355
|
||||||
],
|
],
|
||||||
"perc_loss": [
|
"perc_loss": [
|
||||||
0.0,
|
0.0,
|
||||||
@@ -535,12 +537,12 @@
|
|||||||
0.0
|
0.0
|
||||||
],
|
],
|
||||||
"fid": {
|
"fid": {
|
||||||
"25": 315.9393615722656,
|
"25": 108.92365264892578,
|
||||||
"50": 419.273193359375,
|
"50": 93.73921203613281,
|
||||||
"75": 360.4432678222656,
|
"75": 90.11531829833984,
|
||||||
"100": 363.9911193847656
|
"100": 88.40287780761719
|
||||||
},
|
},
|
||||||
"train_time_s": 660.9630489349365
|
"train_time_s": 1526.7542352676392
|
||||||
},
|
},
|
||||||
"n_params": 10608451
|
"n_params": 10608451
|
||||||
}
|
}
|
||||||
@@ -11,324 +11,326 @@
|
|||||||
"sample_interval": 10,
|
"sample_interval": 10,
|
||||||
"fid_interval": 25,
|
"fid_interval": 25,
|
||||||
"fid_n_real": 5000,
|
"fid_n_real": 5000,
|
||||||
|
"num_workers": 2,
|
||||||
"epochs": 100,
|
"epochs": 100,
|
||||||
"augment": "hflip",
|
"augment": "hflip",
|
||||||
"image_size": 64,
|
"image_size": 64,
|
||||||
"model": "vae",
|
"model": "vae",
|
||||||
"latent_dim": 256,
|
"latent_dim": 256,
|
||||||
"ngf": 64,
|
"ngf": 64,
|
||||||
|
"grad_clip": 1.0,
|
||||||
"run_name": "p3_2_vae_perceptual",
|
"run_name": "p3_2_vae_perceptual",
|
||||||
"lr": 0.001,
|
"lr": 0.0005,
|
||||||
"beta_kl": 0.0001,
|
"beta_kl": 0.25,
|
||||||
"lambda_perceptual": 0.1,
|
"lambda_perceptual": 0.1,
|
||||||
"lambda_adversarial": 0.0
|
"lambda_adversarial": 0.0
|
||||||
},
|
},
|
||||||
"history": {
|
"history": {
|
||||||
"recon_loss": [
|
"recon_loss": [
|
||||||
NaN,
|
0.07014625494042014,
|
||||||
0.041370747770127066,
|
0.04432385861396025,
|
||||||
0.035506031866002284,
|
0.03902342810462683,
|
||||||
0.06229733923672993,
|
0.036902043435117625,
|
||||||
0.06089382184048494,
|
0.035906293887135565,
|
||||||
0.053464704654856116,
|
0.035914984559560686,
|
||||||
0.04950355895213846,
|
0.0360899488401846,
|
||||||
0.04674786329269409,
|
0.036847473583860785,
|
||||||
0.05693557829811023,
|
0.037394630213260144,
|
||||||
0.04808994451076047,
|
0.03836990671200503,
|
||||||
NaN,
|
0.039295883698022775,
|
||||||
0.05869839608701121,
|
0.040212508386526354,
|
||||||
0.05293231666820426,
|
0.04124497003757801,
|
||||||
0.05057006603918779,
|
0.04212733724305772,
|
||||||
0.04713928615117175,
|
0.04325710433638758,
|
||||||
0.04481589820427008,
|
0.04402907300963361,
|
||||||
0.04331832689543565,
|
0.045103192679647706,
|
||||||
0.041820655449524395,
|
0.04595246975525068,
|
||||||
NaN,
|
0.04698181097419598,
|
||||||
0.24315394992884407,
|
0.047825980820080154,
|
||||||
0.24305414841470555,
|
0.04772546164627768,
|
||||||
0.2429187158998261,
|
0.04763130701950982,
|
||||||
0.24289727962424612,
|
0.047709369705591954,
|
||||||
0.2377373814328104,
|
0.04753121634961194,
|
||||||
0.23880945107875726,
|
0.04741841174948674,
|
||||||
0.24047312904626894,
|
0.04729832872016053,
|
||||||
0.23915709144411942,
|
0.04722155279551561,
|
||||||
NaN,
|
0.0469712430340612,
|
||||||
0.24000247061634675,
|
0.04693082829093576,
|
||||||
0.2362435321904655,
|
0.0468663705401441,
|
||||||
0.23705266075383905,
|
0.04660685717040657,
|
||||||
NaN,
|
0.04661381538384236,
|
||||||
0.23998981039238793,
|
0.046462044382515624,
|
||||||
0.23828768376738596,
|
0.0464134263431924,
|
||||||
0.23908851302077627,
|
0.046272445438254595,
|
||||||
NaN,
|
0.04624118105882508,
|
||||||
NaN,
|
0.046066727958874315,
|
||||||
NaN,
|
0.04603223238363225,
|
||||||
NaN,
|
0.04589838094404365,
|
||||||
NaN,
|
0.0458371472489248,
|
||||||
NaN,
|
0.04577782691225537,
|
||||||
NaN,
|
0.04567820464985238,
|
||||||
NaN,
|
0.04560983914913785,
|
||||||
NaN,
|
0.045465615761076286,
|
||||||
NaN,
|
0.04543259674603613,
|
||||||
NaN,
|
0.04530587481159685,
|
||||||
NaN,
|
0.04528731873465909,
|
||||||
NaN,
|
0.04525673297098559,
|
||||||
NaN,
|
0.04507890154217553,
|
||||||
NaN,
|
0.0450239604514124,
|
||||||
NaN,
|
0.04497251241730574,
|
||||||
NaN,
|
0.044865927300774135,
|
||||||
NaN,
|
0.04475303162207715,
|
||||||
NaN,
|
0.044664276763796806,
|
||||||
NaN,
|
0.044507257990602754,
|
||||||
NaN,
|
0.0444452501515038,
|
||||||
NaN,
|
0.044273821509674065,
|
||||||
NaN,
|
0.044196275588220514,
|
||||||
NaN,
|
0.044097925846775375,
|
||||||
NaN,
|
0.04397543971864586,
|
||||||
NaN,
|
0.04393626836279773,
|
||||||
NaN,
|
0.043823353540247835,
|
||||||
NaN,
|
0.04371356347209623,
|
||||||
NaN,
|
0.043580674797169164,
|
||||||
NaN,
|
0.04348461520181507,
|
||||||
NaN,
|
0.04339473670682846,
|
||||||
NaN,
|
0.043282039009798795,
|
||||||
NaN,
|
0.04320543994888281,
|
||||||
NaN,
|
0.04314021340324583,
|
||||||
NaN,
|
0.04300904847904403,
|
||||||
NaN,
|
0.04293948887950844,
|
||||||
NaN,
|
0.04286929544730064,
|
||||||
NaN,
|
0.04278319672896312,
|
||||||
NaN,
|
0.04270272867547141,
|
||||||
NaN,
|
0.0425933551393513,
|
||||||
NaN,
|
0.04250210111276207,
|
||||||
NaN,
|
0.04243518800562263,
|
||||||
NaN,
|
0.04233357961424905,
|
||||||
NaN,
|
0.04227217231105026,
|
||||||
NaN,
|
0.042201977119677596,
|
||||||
NaN,
|
0.04207655237430436,
|
||||||
NaN,
|
0.04206944097024508,
|
||||||
NaN,
|
0.0420022471768097,
|
||||||
NaN,
|
0.041929350385808535,
|
||||||
NaN,
|
0.04179893332159417,
|
||||||
NaN,
|
0.041740718671781384,
|
||||||
NaN,
|
0.04166606991177695,
|
||||||
NaN,
|
0.04158131634959808,
|
||||||
NaN,
|
0.04155632069445828,
|
||||||
NaN,
|
0.04140628448440733,
|
||||||
NaN,
|
0.041379960253834724,
|
||||||
NaN,
|
0.041310225080093764,
|
||||||
NaN,
|
0.04123378162168794,
|
||||||
NaN,
|
0.04118641266105776,
|
||||||
NaN,
|
0.04113346862837545,
|
||||||
NaN,
|
0.041023712032116376,
|
||||||
NaN,
|
0.04102735277902112,
|
||||||
NaN,
|
0.04096395063062764,
|
||||||
NaN,
|
0.04094440599059702,
|
||||||
NaN
|
0.040864359348630294
|
||||||
],
|
],
|
||||||
"kl_loss": [
|
"kl_loss": [
|
||||||
41526536.5961082,
|
0.9578518337673612,
|
||||||
1258.7768262553418,
|
0.8413855069213443,
|
||||||
1165.5987154968784,
|
0.7201743645545764,
|
||||||
3689.321748260759,
|
0.6246620782165446,
|
||||||
1320.3167513333833,
|
0.5550231880102402,
|
||||||
928.3473894168169,
|
0.49975109450582766,
|
||||||
789.754654843583,
|
0.454662803783376,
|
||||||
693.4729607736963,
|
0.4155570190941167,
|
||||||
1017.8856519389357,
|
0.3821740793621438,
|
||||||
622.2623097998464,
|
0.352426735827556,
|
||||||
304942.2120989938,
|
0.32639839258204156,
|
||||||
750.5718396830763,
|
0.3032199253893306,
|
||||||
618.6888906364767,
|
0.2829984358360625,
|
||||||
586.4194498958751,
|
0.26505850424241817,
|
||||||
579.2793892102363,
|
0.24903864537676176,
|
||||||
525.8830437945504,
|
0.23453107457130384,
|
||||||
497.5363791050055,
|
0.22166774176761636,
|
||||||
452.93342662061383,
|
0.21033292894180006,
|
||||||
172476214612.95245,
|
0.19950743799662998,
|
||||||
1112.027792645316,
|
0.18994869887192026,
|
||||||
1109.8156276605068,
|
0.18805619294189999,
|
||||||
1109.8872808472722,
|
0.18697467955768618,
|
||||||
1112.3999198196282,
|
0.1855154659312505,
|
||||||
1163.6093267457097,
|
0.1852042447680082,
|
||||||
1195.4633280436199,
|
0.18474761941112006,
|
||||||
1232.709252512353,
|
0.1843619988514827,
|
||||||
1253.0651078183428,
|
0.18399313028551575,
|
||||||
19923.65347159622,
|
0.18352704119478536,
|
||||||
6871.588723207132,
|
0.18345636556036451,
|
||||||
6872.017739842081,
|
0.18366932483692455,
|
||||||
6872.62934758113,
|
0.18345450644946507,
|
||||||
27062.17111336472,
|
0.18317942117523944,
|
||||||
4992.576698759683,
|
0.1830864556006387,
|
||||||
5000.895510942508,
|
0.1831519056079734,
|
||||||
5011.723715236044,
|
0.18329900547734693,
|
||||||
4.196954763212122e+20,
|
0.1829878402252992,
|
||||||
NaN,
|
0.18321112291807803,
|
||||||
NaN,
|
0.18296580338197896,
|
||||||
NaN,
|
0.18312124166096377,
|
||||||
NaN,
|
0.18301436457878503,
|
||||||
NaN,
|
0.1831260528574642,
|
||||||
NaN,
|
0.18320244878657863,
|
||||||
NaN,
|
0.18332607687538505,
|
||||||
NaN,
|
0.18335427452101666,
|
||||||
NaN,
|
0.18322811696009758,
|
||||||
NaN,
|
0.18336237903334138,
|
||||||
NaN,
|
0.183573446308191,
|
||||||
NaN,
|
0.18356411601615769,
|
||||||
NaN,
|
0.18363770969912538,
|
||||||
NaN,
|
0.18377095862076834,
|
||||||
NaN,
|
0.18369814056234482,
|
||||||
NaN,
|
0.18386876185098264,
|
||||||
NaN,
|
0.18424720492245805,
|
||||||
NaN,
|
0.18420857423518458,
|
||||||
NaN,
|
0.18395239082921264,
|
||||||
NaN,
|
0.18442093539569113,
|
||||||
NaN,
|
0.1842692175036312,
|
||||||
NaN,
|
0.18475162495787328,
|
||||||
NaN,
|
0.18502490509014863,
|
||||||
NaN,
|
0.18480629139603713,
|
||||||
NaN,
|
0.1851997250993537,
|
||||||
NaN,
|
0.1850732435973791,
|
||||||
NaN,
|
0.1854018302809479,
|
||||||
NaN,
|
0.18547267151566652,
|
||||||
NaN,
|
0.18569090538936803,
|
||||||
NaN,
|
0.1859345402664099,
|
||||||
NaN,
|
0.18615943276219898,
|
||||||
NaN,
|
0.18631401260057065,
|
||||||
NaN,
|
0.1865774132948146,
|
||||||
NaN,
|
0.1867309583303256,
|
||||||
NaN,
|
0.18672093723574254,
|
||||||
NaN,
|
0.18682665680336136,
|
||||||
NaN,
|
0.18679539349853483,
|
||||||
NaN,
|
0.18683723016427115,
|
||||||
NaN,
|
0.1874776411578696,
|
||||||
NaN,
|
0.18746165714712223,
|
||||||
NaN,
|
0.1873003001141752,
|
||||||
NaN,
|
0.18763184111215112,
|
||||||
NaN,
|
0.18799093031348327,
|
||||||
NaN,
|
0.18802655252635989,
|
||||||
NaN,
|
0.18824852079662502,
|
||||||
NaN,
|
0.1882636730169129,
|
||||||
NaN,
|
0.18849152215143555,
|
||||||
NaN,
|
0.1886035389561429,
|
||||||
NaN,
|
0.188697873813729,
|
||||||
NaN,
|
0.18883028218888828,
|
||||||
NaN,
|
0.1889004738858113,
|
||||||
NaN,
|
0.18922981174073666,
|
||||||
NaN,
|
0.18946354345888153,
|
||||||
NaN,
|
0.18955816268029377,
|
||||||
NaN,
|
0.18944057507010606,
|
||||||
NaN,
|
0.18984314531852037,
|
||||||
NaN,
|
0.19001257871715432,
|
||||||
NaN,
|
0.19001863161340737,
|
||||||
NaN,
|
0.18997256346365327,
|
||||||
NaN,
|
0.19019658920856622,
|
||||||
NaN,
|
0.19018171425176483,
|
||||||
NaN,
|
0.19038618680758354,
|
||||||
NaN,
|
0.19034281325263855,
|
||||||
NaN
|
0.19046620505615178
|
||||||
],
|
],
|
||||||
"perc_loss": [
|
"perc_loss": [
|
||||||
NaN,
|
3.257610213552785,
|
||||||
3.0017659954535656,
|
2.987424520855276,
|
||||||
2.8931196978968434,
|
2.8892488026211405,
|
||||||
3.084328340159522,
|
2.8424616914529066,
|
||||||
3.1398269641093717,
|
2.817538415774321,
|
||||||
3.0839009172896032,
|
2.807587738220508,
|
||||||
3.0459561949102287,
|
2.8047564518757357,
|
||||||
3.015714469628456,
|
2.8080830421203222,
|
||||||
3.0917934143645134,
|
2.8133585376617236,
|
||||||
3.010985380054539,
|
2.821376688969441,
|
||||||
NaN,
|
2.830571848612565,
|
||||||
3.102917709411719,
|
2.840076332927769,
|
||||||
3.055045795746339,
|
2.8507282504668603,
|
||||||
3.0351512865123587,
|
2.8593530094521675,
|
||||||
3.002274121993627,
|
2.8690790253826695,
|
||||||
2.9754957128793764,
|
2.877339167982085,
|
||||||
2.9553630066733074,
|
2.8865632745954724,
|
||||||
2.936704080328982,
|
2.8936263704911256,
|
||||||
NaN,
|
2.9023511684857883,
|
||||||
3.66463458385223,
|
2.9098138376178904,
|
||||||
3.6419444985878773,
|
2.906375333794162,
|
||||||
3.6321474706005845,
|
2.9055831661591163,
|
||||||
3.6272282503608966,
|
2.9031203985214233,
|
||||||
3.781295938878997,
|
2.9002673641229286,
|
||||||
3.740793755421272,
|
2.897312533651662,
|
||||||
3.759704489993234,
|
2.8946053095352955,
|
||||||
3.7078149914741516,
|
2.8919762974111443,
|
||||||
NaN,
|
2.8902457397208257,
|
||||||
3.7647176323792872,
|
2.8877664065768576,
|
||||||
3.7441278461717133,
|
2.884983114197723,
|
||||||
3.7264530261357627,
|
2.882109224286854,
|
||||||
NaN,
|
2.8805122767758164,
|
||||||
3.7299226170931106,
|
2.8777520192994013,
|
||||||
3.7314435080585318,
|
2.876453428186922,
|
||||||
3.754668933713538,
|
2.874014714334765,
|
||||||
NaN,
|
2.8723304679251123,
|
||||||
NaN,
|
2.870547831058502,
|
||||||
NaN,
|
2.86905308564504,
|
||||||
NaN,
|
2.8669193524580736,
|
||||||
NaN,
|
2.865350999383845,
|
||||||
NaN,
|
2.8633312608441734,
|
||||||
NaN,
|
2.861651761409564,
|
||||||
NaN,
|
2.860612149421985,
|
||||||
NaN,
|
2.858717398256318,
|
||||||
NaN,
|
2.857239680412488,
|
||||||
NaN,
|
2.8558127604998074,
|
||||||
NaN,
|
2.853911985189487,
|
||||||
NaN,
|
2.852791876364977,
|
||||||
NaN,
|
2.850916815109742,
|
||||||
NaN,
|
2.8496591963319697,
|
||||||
NaN,
|
2.8490061500133614,
|
||||||
NaN,
|
2.8465628389619355,
|
||||||
NaN,
|
2.844754463587052,
|
||||||
NaN,
|
2.842957689211919,
|
||||||
NaN,
|
2.840948245464227,
|
||||||
NaN,
|
2.8397679522506194,
|
||||||
NaN,
|
2.837174654006958,
|
||||||
NaN,
|
2.835211678957328,
|
||||||
NaN,
|
2.833740895629948,
|
||||||
NaN,
|
2.8325765499701867,
|
||||||
NaN,
|
2.830295750218579,
|
||||||
NaN,
|
2.8292511431579914,
|
||||||
NaN,
|
2.8273453595291853,
|
||||||
NaN,
|
2.8257394835480256,
|
||||||
NaN,
|
2.823762384744791,
|
||||||
NaN,
|
2.822507558215378,
|
||||||
NaN,
|
2.8205288745399213,
|
||||||
NaN,
|
2.819065808230995,
|
||||||
NaN,
|
2.8176180297492914,
|
||||||
NaN,
|
2.8165967805772767,
|
||||||
NaN,
|
2.8143360721759305,
|
||||||
NaN,
|
2.8135039892971005,
|
||||||
NaN,
|
2.8120578970664587,
|
||||||
NaN,
|
2.8098325041624217,
|
||||||
NaN,
|
2.8090034366672874,
|
||||||
NaN,
|
2.8076502920215964,
|
||||||
NaN,
|
2.80601545684358,
|
||||||
NaN,
|
2.805327193859296,
|
||||||
NaN,
|
2.803745285058633,
|
||||||
NaN,
|
2.802450499473474,
|
||||||
NaN,
|
2.800580952412043,
|
||||||
NaN,
|
2.799968459157862,
|
||||||
NaN,
|
2.798532901156662,
|
||||||
NaN,
|
2.797372330967178,
|
||||||
NaN,
|
2.795650184663952,
|
||||||
NaN,
|
2.79500452371744,
|
||||||
NaN,
|
2.79342931152409,
|
||||||
NaN,
|
2.792220654650631,
|
||||||
NaN,
|
2.7911063600809145,
|
||||||
NaN,
|
2.7897617689564695,
|
||||||
NaN,
|
2.7887179505111823,
|
||||||
NaN,
|
2.7880892407180915,
|
||||||
NaN,
|
2.7868736916118197,
|
||||||
NaN,
|
2.785768971993373,
|
||||||
NaN,
|
2.784949649602939,
|
||||||
NaN,
|
2.784032260760283,
|
||||||
NaN,
|
2.783316181765662,
|
||||||
NaN,
|
2.7821382457374506,
|
||||||
NaN,
|
2.7815805586994204,
|
||||||
NaN
|
2.780150015639444
|
||||||
],
|
],
|
||||||
"adv_g_loss": [
|
"adv_g_loss": [
|
||||||
0.0,
|
0.0,
|
||||||
@@ -535,12 +537,12 @@
|
|||||||
0.0
|
0.0
|
||||||
],
|
],
|
||||||
"fid": {
|
"fid": {
|
||||||
"25": 263.1458740234375,
|
"25": 85.4538345336914,
|
||||||
"50": 598.3736572265625,
|
"50": 70.30448150634766,
|
||||||
"75": 598.3736572265625,
|
"75": 68.88232421875,
|
||||||
"100": 598.3736572265625
|
"100": 68.23878479003906
|
||||||
},
|
},
|
||||||
"train_time_s": 952.6596128940582
|
"train_time_s": 1526.8077104091644
|
||||||
},
|
},
|
||||||
"n_params": 10608451
|
"n_params": 10608451
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"created_at": "2026-05-02T13:07:06.633366+00:00",
|
||||||
|
"config_paths": [
|
||||||
|
"generator/configs/phase3/p3_1_vae.json",
|
||||||
|
"generator/configs/phase3/p3_2_vae_perceptual.json",
|
||||||
|
"generator/configs/phase3/p3_3_vae_patchgan.json"
|
||||||
|
],
|
||||||
|
"instance_id": 36014025,
|
||||||
|
"offer_id": 35956124,
|
||||||
|
"ssh_host": "ssh3.vast.ai",
|
||||||
|
"ssh_port": 14024,
|
||||||
|
"status": "completed",
|
||||||
|
"remote_workspace": "/workspace/DRL_PROJ"
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"created_at": "2026-05-02T17:39:27.858755+00:00",
|
||||||
|
"config_paths": [
|
||||||
|
"generator/configs/phase3/p3_3_vae_patchgan.json"
|
||||||
|
],
|
||||||
|
"instance_id": 36025335,
|
||||||
|
"offer_id": 19025020,
|
||||||
|
"ssh_host": "ssh6.vast.ai",
|
||||||
|
"ssh_port": 25334,
|
||||||
|
"status": "completed",
|
||||||
|
"remote_workspace": "/workspace/DRL_PROJ"
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"created_at": "2026-05-02T18:44:33.143461+00:00",
|
||||||
|
"config_paths": [
|
||||||
|
"generator/configs/phase3/p3_3_vae_patchgan.json"
|
||||||
|
],
|
||||||
|
"instance_id": 36027837,
|
||||||
|
"offer_id": 29548831,
|
||||||
|
"ssh_host": null,
|
||||||
|
"ssh_port": null,
|
||||||
|
"status": "cancelled",
|
||||||
|
"remote_workspace": "/workspace/DRL_PROJ"
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"created_at": "2026-05-02T18:51:49.983505+00:00",
|
||||||
|
"config_paths": [
|
||||||
|
"generator/configs/phase3/p3_3_vae_patchgan.json"
|
||||||
|
],
|
||||||
|
"instance_id": 36028210,
|
||||||
|
"offer_id": 35956124,
|
||||||
|
"ssh_host": "ssh1.vast.ai",
|
||||||
|
"ssh_port": 28210,
|
||||||
|
"status": "completed",
|
||||||
|
"remote_workspace": "/workspace/DRL_PROJ"
|
||||||
|
}
|
||||||
|
Before Width: | Height: | Size: 80 KiB After Width: | Height: | Size: 93 KiB |
|
Before Width: | Height: | Size: 132 KiB After Width: | Height: | Size: 221 KiB |
|
Before Width: | Height: | Size: 78 KiB After Width: | Height: | Size: 81 KiB |
|
Before Width: | Height: | Size: 120 KiB After Width: | Height: | Size: 214 KiB |
|
Before Width: | Height: | Size: 20 KiB After Width: | Height: | Size: 76 KiB |
|
Before Width: | Height: | Size: 126 KiB After Width: | Height: | Size: 213 KiB |
|
Before Width: | Height: | Size: 5.5 KiB After Width: | Height: | Size: 75 KiB |
|
Before Width: | Height: | Size: 122 KiB After Width: | Height: | Size: 213 KiB |
|
Before Width: | Height: | Size: 7.5 KiB After Width: | Height: | Size: 75 KiB |
|
Before Width: | Height: | Size: 124 KiB After Width: | Height: | Size: 212 KiB |
|
Before Width: | Height: | Size: 6.3 KiB After Width: | Height: | Size: 75 KiB |
|
Before Width: | Height: | Size: 124 KiB After Width: | Height: | Size: 211 KiB |
|
Before Width: | Height: | Size: 4.4 KiB After Width: | Height: | Size: 74 KiB |
|
Before Width: | Height: | Size: 122 KiB After Width: | Height: | Size: 211 KiB |
|
Before Width: | Height: | Size: 3.7 KiB After Width: | Height: | Size: 74 KiB |
|
Before Width: | Height: | Size: 121 KiB After Width: | Height: | Size: 210 KiB |
|
Before Width: | Height: | Size: 3.4 KiB After Width: | Height: | Size: 74 KiB |
|
Before Width: | Height: | Size: 122 KiB After Width: | Height: | Size: 210 KiB |
|
Before Width: | Height: | Size: 3.3 KiB After Width: | Height: | Size: 74 KiB |
|
Before Width: | Height: | Size: 122 KiB After Width: | Height: | Size: 210 KiB |
|
Before Width: | Height: | Size: 91 KiB After Width: | Height: | Size: 91 KiB |
|
Before Width: | Height: | Size: 215 KiB After Width: | Height: | Size: 208 KiB |
|
Before Width: | Height: | Size: 40 KiB After Width: | Height: | Size: 86 KiB |
|
Before Width: | Height: | Size: 135 KiB After Width: | Height: | Size: 205 KiB |
|
Before Width: | Height: | Size: 30 KiB After Width: | Height: | Size: 87 KiB |
|
Before Width: | Height: | Size: 129 KiB After Width: | Height: | Size: 206 KiB |
|
Before Width: | Height: | Size: 285 B After Width: | Height: | Size: 89 KiB |
|
Before Width: | Height: | Size: 119 KiB After Width: | Height: | Size: 207 KiB |
|
Before Width: | Height: | Size: 285 B After Width: | Height: | Size: 90 KiB |
|
Before Width: | Height: | Size: 119 KiB After Width: | Height: | Size: 208 KiB |
|
Before Width: | Height: | Size: 285 B After Width: | Height: | Size: 92 KiB |
|
Before Width: | Height: | Size: 119 KiB After Width: | Height: | Size: 208 KiB |
|
Before Width: | Height: | Size: 285 B After Width: | Height: | Size: 93 KiB |
|
Before Width: | Height: | Size: 119 KiB After Width: | Height: | Size: 208 KiB |
|
Before Width: | Height: | Size: 285 B After Width: | Height: | Size: 93 KiB |
|
Before Width: | Height: | Size: 119 KiB After Width: | Height: | Size: 209 KiB |
|
Before Width: | Height: | Size: 285 B After Width: | Height: | Size: 94 KiB |
|
Before Width: | Height: | Size: 119 KiB After Width: | Height: | Size: 209 KiB |
|
Before Width: | Height: | Size: 285 B After Width: | Height: | Size: 94 KiB |
|
Before Width: | Height: | Size: 119 KiB After Width: | Height: | Size: 209 KiB |
|
Before Width: | Height: | Size: 118 KiB After Width: | Height: | Size: 92 KiB |
|
Before Width: | Height: | Size: 247 KiB After Width: | Height: | Size: 210 KiB |
|
Before Width: | Height: | Size: 31 KiB After Width: | Height: | Size: 90 KiB |
|
Before Width: | Height: | Size: 141 KiB After Width: | Height: | Size: 209 KiB |
|
Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 92 KiB |
|
Before Width: | Height: | Size: 127 KiB After Width: | Height: | Size: 210 KiB |
|
Before Width: | Height: | Size: 20 KiB After Width: | Height: | Size: 94 KiB |
|
Before Width: | Height: | Size: 123 KiB After Width: | Height: | Size: 212 KiB |
|
Before Width: | Height: | Size: 18 KiB After Width: | Height: | Size: 96 KiB |
|
Before Width: | Height: | Size: 123 KiB After Width: | Height: | Size: 212 KiB |
|
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 98 KiB |
|
Before Width: | Height: | Size: 123 KiB After Width: | Height: | Size: 212 KiB |
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 99 KiB |
|
Before Width: | Height: | Size: 123 KiB After Width: | Height: | Size: 213 KiB |
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 100 KiB |
|
Before Width: | Height: | Size: 123 KiB After Width: | Height: | Size: 214 KiB |
|
Before Width: | Height: | Size: 14 KiB After Width: | Height: | Size: 101 KiB |
|
Before Width: | Height: | Size: 123 KiB After Width: | Height: | Size: 214 KiB |
|
Before Width: | Height: | Size: 14 KiB After Width: | Height: | Size: 101 KiB |
|
Before Width: | Height: | Size: 123 KiB After Width: | Height: | Size: 215 KiB |
@@ -52,7 +52,7 @@ def main(config_path, *, data_dir_override=None, output_root="generator/outputs"
|
|||||||
|
|
||||||
# Count total trainable parameters
|
# Count total trainable parameters
|
||||||
if isinstance(model, tuple):
|
if isinstance(model, tuple):
|
||||||
n_params = sum(p.numel() for p in model[0].parameters() if p.requires_grad)
|
n_params = sum(p.numel() for m in model for p in m.parameters() if p.requires_grad)
|
||||||
else:
|
else:
|
||||||
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
print(f"Trainable params: {n_params:,}")
|
print(f"Trainable params: {n_params:,}")
|
||||||
|
|||||||
@@ -22,17 +22,21 @@ def _init_weights(m):
|
|||||||
nn.init.normal_(m.weight, 0.0, 0.02)
|
nn.init.normal_(m.weight, 0.0, 0.02)
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
nn.init.zeros_(m.bias)
|
nn.init.zeros_(m.bias)
|
||||||
elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
|
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None:
|
||||||
nn.init.normal_(m.weight, 1.0, 0.02)
|
nn.init.normal_(m.weight, 1.0, 0.02)
|
||||||
nn.init.zeros_(m.bias)
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
|
||||||
|
def _norm(channels: int) -> nn.GroupNorm:
|
||||||
|
return nn.GroupNorm(8, channels)
|
||||||
|
|
||||||
|
|
||||||
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
|
def _upsample_block(in_ch: int, out_ch: int) -> nn.Sequential:
|
||||||
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
|
"""Nearest-neighbour upsample followed by a 3×3 conv — no checkerboard."""
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.Upsample(scale_factor=2, mode="nearest"),
|
nn.Upsample(scale_factor=2, mode="nearest"),
|
||||||
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
|
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
|
||||||
nn.BatchNorm2d(out_ch),
|
_norm(out_ch),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -65,7 +69,7 @@ class VAE(nn.Module):
|
|||||||
for _ in range(n_down - 1):
|
for _ in range(n_down - 1):
|
||||||
enc_layers += [
|
enc_layers += [
|
||||||
nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False),
|
nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1, bias=False),
|
||||||
nn.BatchNorm2d(ch * 2),
|
_norm(ch * 2),
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
]
|
]
|
||||||
ch *= 2
|
ch *= 2
|
||||||
@@ -98,9 +102,12 @@ class VAE(nn.Module):
|
|||||||
|
|
||||||
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
h = self.encoder(x).flatten(1)
|
h = self.encoder(x).flatten(1)
|
||||||
return self.fc_mu(h), self.fc_lv(h)
|
log_var = self.fc_lv(h).clamp(-10.0, 10.0)
|
||||||
|
return self.fc_mu(h), log_var
|
||||||
|
|
||||||
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
|
def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not self.training:
|
||||||
|
return mu
|
||||||
std = torch.exp(0.5 * log_var)
|
std = torch.exp(0.5 * log_var)
|
||||||
return mu + std * torch.randn_like(std)
|
return mu + std * torch.randn_like(std)
|
||||||
|
|
||||||
|
|||||||
@@ -20,3 +20,5 @@ class EMA:
|
|||||||
def update(self, model: nn.Module) -> None:
|
def update(self, model: nn.Module) -> None:
|
||||||
for p_ema, p in zip(self.model.parameters(), model.parameters()):
|
for p_ema, p in zip(self.model.parameters(), model.parameters()):
|
||||||
p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
|
p_ema.data.mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)
|
||||||
|
for b_ema, b in zip(self.model.buffers(), model.buffers()):
|
||||||
|
b_ema.copy_(b)
|
||||||
|
|||||||
@@ -11,14 +11,15 @@ from torchmetrics.image.fid import FrechetInceptionDistance
|
|||||||
|
|
||||||
|
|
||||||
class FIDEvaluator:
|
class FIDEvaluator:
|
||||||
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda"):
|
def __init__(self, real_dataset, n_real: int = 10_000, device: str = "cuda",
|
||||||
|
num_workers: int = 2):
|
||||||
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||||
self.n_real = n_real
|
self.n_real = n_real
|
||||||
|
|
||||||
# Cache real images as a CPU tensor ([-1, 1] range)
|
# Cache real images as a CPU tensor ([-1, 1] range)
|
||||||
imgs_list = []
|
imgs_list = []
|
||||||
loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
|
loader = DataLoader(real_dataset, batch_size=256, shuffle=False,
|
||||||
num_workers=4, drop_last=False)
|
num_workers=num_workers, drop_last=False)
|
||||||
for batch in loader:
|
for batch in loader:
|
||||||
imgs_list.append(batch.cpu())
|
imgs_list.append(batch.cpu())
|
||||||
if sum(x.size(0) for x in imgs_list) >= n_real:
|
if sum(x.size(0) for x in imgs_list) >= n_real:
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ def train_dcgan(
|
|||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
train_dataset, batch_size=batch_size, shuffle=True,
|
train_dataset, batch_size=batch_size, shuffle=True,
|
||||||
num_workers=min(4, os.cpu_count() or 1),
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
|
||||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||||
)
|
)
|
||||||
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
|
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
|
||||||
@@ -86,7 +86,8 @@ def train_dcgan(
|
|||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
samples_dir = save_dir.parent / "samples" / run_name
|
samples_dir = save_dir.parent / "samples" / run_name
|
||||||
|
|
||||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||||
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||||
|
|
||||||
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
|
history = {"g_loss": [], "d_loss": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||||
best_fid = float("inf")
|
best_fid = float("inf")
|
||||||
@@ -239,7 +240,7 @@ def train_wgan(
|
|||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
train_dataset, batch_size=batch_size, shuffle=True,
|
train_dataset, batch_size=batch_size, shuffle=True,
|
||||||
num_workers=min(4, os.cpu_count() or 1),
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
|
||||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||||
)
|
)
|
||||||
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
|
opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(beta1, beta2))
|
||||||
@@ -257,7 +258,8 @@ def train_wgan(
|
|||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
samples_dir = save_dir.parent / "samples" / run_name
|
samples_dir = save_dir.parent / "samples" / run_name
|
||||||
|
|
||||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||||
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||||
|
|
||||||
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
|
history = {"g_loss": [], "w_dist": [], "gp": [], "d_real": [], "d_fake": [], "fid": {}}
|
||||||
best_fid = float("inf")
|
best_fid = float("inf")
|
||||||
@@ -391,7 +393,6 @@ def _save_vae_samples(
|
|||||||
# Interleave real / reconstruction pairs
|
# Interleave real / reconstruction pairs
|
||||||
pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
|
pairs = torch.stack([real, recon], dim=1).flatten(0, 1)
|
||||||
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
|
save_image(pairs, samples_dir / f"epoch_{epoch:04d}_recon.png", nrow=4)
|
||||||
vae.train()
|
|
||||||
|
|
||||||
|
|
||||||
def train_vae(
|
def train_vae(
|
||||||
@@ -403,13 +404,21 @@ def train_vae(
|
|||||||
run_name: str,
|
run_name: str,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""VAE training loop covering Phase 3.1 – 3.3.
|
"""VAE training loop covering Phase 3.1 – 3.3 and Phase 5.
|
||||||
|
|
||||||
Config toggles:
|
Config toggles:
|
||||||
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
|
lambda_perceptual > 0 → VGG-16 perceptual loss (Phase 3.2+)
|
||||||
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
|
lambda_adversarial > 0 → PatchGAN hinge adversarial loss (Phase 3.3)
|
||||||
|
|
||||||
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
|
Loss: L = L_mse + λ_perc·L_vgg + λ_adv·L_adv + β_kl·L_kl
|
||||||
|
|
||||||
|
KL is computed as mean over latent dimensions (scale-invariant), so
|
||||||
|
beta_kl is comparable across different latent_dim values.
|
||||||
|
|
||||||
|
AMP is intentionally disabled for VAE training — mixed-precision float16
|
||||||
|
overflows when the KL divergence spikes, producing NaN cascades that
|
||||||
|
corrupt the model irrecoverably. All VAE + perceptual + PatchGAN
|
||||||
|
computation runs in float32.
|
||||||
"""
|
"""
|
||||||
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||||
vae = vae.to(device)
|
vae = vae.to(device)
|
||||||
@@ -425,6 +434,7 @@ def train_vae(
|
|||||||
lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
|
lambda_perceptual = cfg.get("lambda_perceptual", 0.0)
|
||||||
lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
|
lambda_adversarial = cfg.get("lambda_adversarial", 0.0)
|
||||||
lr_d = cfg.get("lr_d", 1e-4)
|
lr_d = cfg.get("lr_d", 1e-4)
|
||||||
|
grad_clip = cfg.get("grad_clip", 1.0)
|
||||||
ema_decay = cfg.get("ema_decay", 0.9999)
|
ema_decay = cfg.get("ema_decay", 0.9999)
|
||||||
sample_interval = cfg.get("sample_interval", 10)
|
sample_interval = cfg.get("sample_interval", 10)
|
||||||
fid_interval = cfg.get("fid_interval", 25)
|
fid_interval = cfg.get("fid_interval", 25)
|
||||||
@@ -435,13 +445,13 @@ def train_vae(
|
|||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
train_dataset, batch_size=batch_size, shuffle=True,
|
train_dataset, batch_size=batch_size, shuffle=True,
|
||||||
num_workers=min(4, os.cpu_count() or 1),
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
|
||||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
|
opt_vae = torch.optim.Adam(vae.parameters(), lr=lr)
|
||||||
use_amp = device.type == "cuda"
|
# AMP disabled — float16 overflows on KL spikes, causing NaN cascades
|
||||||
scaler = _GradScaler("cuda", enabled=use_amp)
|
use_amp = False
|
||||||
|
|
||||||
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
|
# KL warmup: linearly ramp beta_kl from 0 to target over first 20% of training
|
||||||
kl_warmup_epochs = max(1, epochs // 5)
|
kl_warmup_epochs = max(1, epochs // 5)
|
||||||
@@ -456,11 +466,10 @@ def train_vae(
|
|||||||
perc_fn = None
|
perc_fn = None
|
||||||
patchgan = None
|
patchgan = None
|
||||||
opt_d = None
|
opt_d = None
|
||||||
scaler_d = None
|
|
||||||
|
|
||||||
if use_perceptual:
|
if use_perceptual:
|
||||||
from src.training.perceptual import PerceptualLoss
|
from src.training.perceptual import PerceptualLoss
|
||||||
perc_fn = PerceptualLoss().to(device)
|
perc_fn = PerceptualLoss().to(device).float()
|
||||||
print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")
|
print("Perceptual loss: VGG-16 relu1_2 + relu2_2 + relu3_3")
|
||||||
|
|
||||||
if use_adversarial:
|
if use_adversarial:
|
||||||
@@ -468,15 +477,14 @@ def train_vae(
|
|||||||
patchgan = PatchGANDiscriminator(
|
patchgan = PatchGANDiscriminator(
|
||||||
ndf=cfg.get("ndf_patch", 64),
|
ndf=cfg.get("ndf_patch", 64),
|
||||||
image_size=cfg.get("image_size", 64),
|
image_size=cfg.get("image_size", 64),
|
||||||
).to(device)
|
).to(device).float()
|
||||||
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
|
opt_d = torch.optim.Adam(patchgan.parameters(), lr=lr_d, betas=(0.5, 0.999))
|
||||||
scaler_d = _GradScaler("cuda", enabled=use_amp)
|
|
||||||
sched_d = torch.optim.lr_scheduler.LambdaLR(
|
sched_d = torch.optim.lr_scheduler.LambdaLR(
|
||||||
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
opt_d, lr_lambda=lambda ep: max(0.0, 1.0 - max(ep - decay_start, 0) / max(epochs - decay_start, 1)))
|
||||||
n_d = sum(p.numel() for p in patchgan.parameters())
|
n_d = sum(p.numel() for p in patchgan.parameters())
|
||||||
print(f"PatchGAN: {n_d:,} params")
|
print(f"PatchGAN: {n_d:,} params")
|
||||||
else:
|
else:
|
||||||
hinge_d_loss = hinge_g_loss = None # satisfy linter, never called
|
hinge_d_loss = hinge_g_loss = None # never called
|
||||||
|
|
||||||
# ── Fixed seeds for consistent visualisation ──────────────────────────
|
# ── Fixed seeds for consistent visualisation ──────────────────────────
|
||||||
fixed_z = torch.randn(16, latent_dim, device=device)
|
fixed_z = torch.randn(16, latent_dim, device=device)
|
||||||
@@ -490,16 +498,19 @@ def train_vae(
|
|||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
samples_dir = save_dir.parent / "samples" / run_name
|
samples_dir = save_dir.parent / "samples" / run_name
|
||||||
|
|
||||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||||
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||||
|
|
||||||
history = {
|
history = {
|
||||||
"recon_loss": [], "kl_loss": [], "perc_loss": [],
|
"recon_loss": [], "kl_loss": [], "perc_loss": [],
|
||||||
"adv_g_loss": [], "adv_d_loss": [], "fid": {},
|
"adv_g_loss": [], "adv_d_loss": [], "fid": {},
|
||||||
}
|
}
|
||||||
best_fid = float("inf")
|
best_fid = float("inf")
|
||||||
|
nan_skipped = 0
|
||||||
print(
|
print(
|
||||||
f"Device: {device} AMP: {use_amp} Batches/epoch: {len(loader)}"
|
f"Device: {device} AMP: disabled (float32) Batches/epoch: {len(loader)}"
|
||||||
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual} λ_adv={lambda_adversarial}"
|
f" β_kl={beta_kl} (warmup {kl_warmup_epochs}ep) λ_perc={lambda_perceptual}"
|
||||||
|
f" λ_adv={lambda_adversarial}"
|
||||||
)
|
)
|
||||||
|
|
||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
@@ -513,43 +524,52 @@ def train_vae(
|
|||||||
n_batches = 0
|
n_batches = 0
|
||||||
|
|
||||||
for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
|
for real in tqdm(loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
|
||||||
real = real.to(device)
|
real = real.to(device).float()
|
||||||
|
|
||||||
# KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs
|
# KL warmup: ramp from 0 to beta_kl over kl_warmup_epochs
|
||||||
current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)
|
current_beta = beta_kl * min(1.0, epoch / kl_warmup_epochs)
|
||||||
|
|
||||||
# ── VAE forward ───────────────────────────────────────────────
|
# ── VAE forward (float32, no AMP) ────────────────────────────
|
||||||
with _autocast("cuda", enabled=use_amp):
|
recon, mu, log_var = vae(real)
|
||||||
recon, mu, log_var = vae(real)
|
mse = F.mse_loss(recon, real)
|
||||||
mse = F.mse_loss(recon, real)
|
|
||||||
kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(1).mean()
|
# KL divergence: mean over latent dims (scale-invariant w.r.t. latent_dim)
|
||||||
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
|
kl = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp())).mean()
|
||||||
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
|
|
||||||
|
perc = perc_fn(recon, real) if use_perceptual else real.new_zeros(1).squeeze()
|
||||||
|
vae_loss = mse + current_beta * kl + lambda_perceptual * perc
|
||||||
|
|
||||||
|
# ── NaN/Inf guard ────────────────────────────────────────────
|
||||||
|
if not torch.isfinite(vae_loss):
|
||||||
|
nan_skipped += 1
|
||||||
|
opt_vae.zero_grad()
|
||||||
|
continue
|
||||||
|
|
||||||
# ── PatchGAN discriminator step ───────────────────────────────
|
# ── PatchGAN discriminator step ───────────────────────────────
|
||||||
adv_d = real.new_zeros(1).squeeze()
|
adv_d = real.new_zeros(1).squeeze()
|
||||||
if use_adversarial:
|
if use_adversarial:
|
||||||
opt_d.zero_grad()
|
# Warmup: only start adversarial training after 20% of epochs
|
||||||
with _autocast("cuda", enabled=use_amp):
|
if epoch > kl_warmup_epochs:
|
||||||
|
opt_d.zero_grad()
|
||||||
d_real = patchgan(real)
|
d_real = patchgan(real)
|
||||||
d_fake = patchgan(recon.detach())
|
d_fake = patchgan(recon.detach())
|
||||||
adv_d = hinge_d_loss(d_real, d_fake)
|
adv_d = hinge_d_loss(d_real, d_fake)
|
||||||
scaler_d.scale(adv_d).backward()
|
if torch.isfinite(adv_d):
|
||||||
scaler_d.step(opt_d)
|
adv_d.backward()
|
||||||
scaler_d.update()
|
torch.nn.utils.clip_grad_norm_(patchgan.parameters(), grad_clip)
|
||||||
|
opt_d.step()
|
||||||
|
|
||||||
# ── PatchGAN generator adversarial loss ───────────────────────
|
# ── PatchGAN generator adversarial loss ───────────────────────
|
||||||
adv_g = real.new_zeros(1).squeeze()
|
adv_g = real.new_zeros(1).squeeze()
|
||||||
if use_adversarial:
|
if use_adversarial and epoch > kl_warmup_epochs:
|
||||||
with _autocast("cuda", enabled=use_amp):
|
adv_g = hinge_g_loss(patchgan(recon))
|
||||||
adv_g = hinge_g_loss(patchgan(recon))
|
vae_loss = vae_loss + lambda_adversarial * adv_g
|
||||||
vae_loss = vae_loss + lambda_adversarial * adv_g
|
|
||||||
|
|
||||||
# ── VAE backward ──────────────────────────────────────────────
|
# ── VAE backward ──────────────────────────────────────────────
|
||||||
opt_vae.zero_grad()
|
opt_vae.zero_grad()
|
||||||
scaler.scale(vae_loss).backward()
|
vae_loss.backward()
|
||||||
scaler.step(opt_vae)
|
torch.nn.utils.clip_grad_norm_(vae.parameters(), grad_clip)
|
||||||
scaler.update()
|
opt_vae.step()
|
||||||
ema.update(vae)
|
ema.update(vae)
|
||||||
|
|
||||||
recon_sum += mse.item()
|
recon_sum += mse.item()
|
||||||
@@ -559,11 +579,11 @@ def train_vae(
|
|||||||
adv_d_sum += adv_d.item()
|
adv_d_sum += adv_d.item()
|
||||||
n_batches += 1
|
n_batches += 1
|
||||||
|
|
||||||
avg_r = recon_sum / n_batches
|
avg_r = recon_sum / max(n_batches, 1)
|
||||||
avg_k = kl_sum / n_batches
|
avg_k = kl_sum / max(n_batches, 1)
|
||||||
avg_p = perc_sum / n_batches
|
avg_p = perc_sum / max(n_batches, 1)
|
||||||
avg_g = adv_g_sum / n_batches
|
avg_g = adv_g_sum / max(n_batches, 1)
|
||||||
avg_d = adv_d_sum / n_batches
|
avg_d = adv_d_sum / max(n_batches, 1)
|
||||||
history["recon_loss"].append(avg_r)
|
history["recon_loss"].append(avg_r)
|
||||||
history["kl_loss"].append(avg_k)
|
history["kl_loss"].append(avg_k)
|
||||||
history["perc_loss"].append(avg_p)
|
history["perc_loss"].append(avg_p)
|
||||||
@@ -574,6 +594,7 @@ def train_vae(
|
|||||||
f"[{epoch:03d}/{epochs}] "
|
f"[{epoch:03d}/{epochs}] "
|
||||||
f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
|
f"MSE: {avg_r:.4f} KL: {avg_k:.2f} β={current_beta:.6f} "
|
||||||
f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
|
f"Perc: {avg_p:.4f} AdvG: {avg_g:.4f} AdvD: {avg_d:.4f}"
|
||||||
|
f" (NaN skipped: {nan_skipped})"
|
||||||
)
|
)
|
||||||
|
|
||||||
if epoch % sample_interval == 0:
|
if epoch % sample_interval == 0:
|
||||||
@@ -607,6 +628,7 @@ def train_vae(
|
|||||||
if patchgan is not None:
|
if patchgan is not None:
|
||||||
torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
|
torch.save(patchgan.state_dict(), save_dir / f"{run_name}_final_patchgan.pt")
|
||||||
history["train_time_s"] = time.time() - t_start
|
history["train_time_s"] = time.time() - t_start
|
||||||
|
print(f"Total NaN-skipped batches: {nan_skipped}")
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
@@ -662,7 +684,7 @@ def train_ddpm(
|
|||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
train_dataset, batch_size=batch_size, shuffle=True,
|
train_dataset, batch_size=batch_size, shuffle=True,
|
||||||
num_workers=min(4, os.cpu_count() or 1),
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)),
|
||||||
pin_memory=(device.type == "cuda"), drop_last=True,
|
pin_memory=(device.type == "cuda"), drop_last=True,
|
||||||
)
|
)
|
||||||
opt = torch.optim.AdamW(model.parameters(), lr=lr)
|
opt = torch.optim.AdamW(model.parameters(), lr=lr)
|
||||||
@@ -679,7 +701,8 @@ def train_ddpm(
|
|||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
samples_dir = save_dir.parent / "samples" / run_name
|
samples_dir = save_dir.parent / "samples" / run_name
|
||||||
|
|
||||||
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device))
|
fid_eval = FIDEvaluator(train_dataset, n_real=fid_n_real, device=str(device),
|
||||||
|
num_workers=cfg.get("num_workers", min(4, os.cpu_count() or 1)))
|
||||||
|
|
||||||
history = {"loss": [], "fid": {}}
|
history = {"loss": [], "fid": {}}
|
||||||
best_fid = float("inf")
|
best_fid = float("inf")
|
||||||
|
|||||||