diff --git a/training/smoke_test.py b/training/smoke_test.py index 7ae92dc..cd1ba94 100644 --- a/training/smoke_test.py +++ b/training/smoke_test.py @@ -98,11 +98,8 @@ def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None): vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0) if prev_model is not None: - model = PPO.load(prev_model, env=vn, - learning_rate=3e-4, n_steps=2048, batch_size=256, - n_epochs=10, gamma=0.995, gae_lambda=0.95, - clip_range=0.2, ent_coef=0.005, vf_coef=0.5, - max_grad_norm=0.5) + model = prev_model + model.set_env(vn) else: model = PPO( "MlpPolicy", vn, @@ -113,7 +110,8 @@ def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None): verbose=1, ) - model.learn(total_timesteps=steps, reset_num_timesteps=(prev_model is None)) + model.learn(total_timesteps=steps, reset_num_timesteps=(prev_model is None), + tb_log_name="ppo_smoke") return model, vn