Sheep training flock of 10 fix?
This commit is contained in:
@@ -98,11 +98,8 @@ def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None):
|
||||
vn = VecNormalize(train_env, norm_obs=True, norm_reward=True, clip_obs=10.0)
|
||||
|
||||
if prev_model is not None:
|
||||
model = PPO.load(prev_model, env=vn,
|
||||
learning_rate=3e-4, n_steps=2048, batch_size=256,
|
||||
n_epochs=10, gamma=0.995, gae_lambda=0.95,
|
||||
clip_range=0.2, ent_coef=0.005, vf_coef=0.5,
|
||||
max_grad_norm=0.5)
|
||||
model = prev_model
|
||||
model.set_env(vn)
|
||||
else:
|
||||
model = PPO(
|
||||
"MlpPolicy", vn,
|
||||
@@ -113,7 +110,8 @@ def train_stage(n_sheep, steps, n_envs=4, prev_model=None, prev_vecnorm=None):
|
||||
verbose=1,
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=steps, reset_num_timesteps=(prev_model is None))
|
||||
model.learn(total_timesteps=steps, reset_num_timesteps=(prev_model is None),
|
||||
tb_log_name="ppo_smoke")
|
||||
return model, vn
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user