diff --git a/training/smoke_test.py b/training/smoke_test.py index 11e582c..35ce0ba 100644 --- a/training/smoke_test.py +++ b/training/smoke_test.py @@ -179,6 +179,12 @@ def main(): ) eval_env.close() + save_dir = f"runs/smoke_stage{n_sheep}" + os.makedirs(save_dir, exist_ok=True) + model.save(os.path.join(save_dir, "model")) + vn.save(os.path.join(save_dir, "vecnorm.pkl")) + print(f" Model saved to {save_dir}/") + passed = report(n_sheep, success_rate, failure_counts, args.episodes, threshold) if not passed: all_passed = False