diff --git a/training/train.py b/training/train.py index 490c457..d0c0e46 100644 --- a/training/train.py +++ b/training/train.py @@ -33,7 +33,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNorm from herding_env import HerdingEnv -COMPACT_RADIUS = HerdingEnv.DRIVE_GATE_RADIUS +COMPACT_RADIUS = 5.0 def _classify(ep_radius, ep_com_dist, n_penned, n_sheep, success):