Sheep training flock of 10
This commit is contained in:
@@ -204,6 +204,15 @@ while robot.step(timestep) != -1:
|
|||||||
fx += math.cos(wander_angle) * 0.5
|
fx += math.cos(wander_angle) * 0.5
|
||||||
fy += math.sin(wander_angle) * 0.5
|
fy += math.sin(wander_angle) * 0.5
|
||||||
|
|
||||||
|
# Hard-stop clamp: within 0.5 m of a wall, zero any force component that
|
||||||
|
# would push further into it. Prevents the flee force from pinning a sheep
|
||||||
|
# against the boundary when the dog approaches from outside.
|
||||||
|
HS = 0.5
|
||||||
|
if x < X_MIN + HS and fx < 0: fx = 0.0
|
||||||
|
if x > X_MAX - HS and fx > 0: fx = 0.0
|
||||||
|
if y < Y_MIN + HS and fy < 0: fy = 0.0
|
||||||
|
if y > Y_MAX - HS and fy > 0: fy = 0.0
|
||||||
|
|
||||||
heading = math.atan2(fy, fx)
|
heading = math.atan2(fy, fx)
|
||||||
mag = math.hypot(fx, fy)
|
mag = math.hypot(fx, fy)
|
||||||
speed = max(WANDER_SPEED, min(FLEE_SPEED, mag * 3.0))
|
speed = max(WANDER_SPEED, min(FLEE_SPEED, mag * 3.0))
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class HerdingEnv(gym.Env):
|
|||||||
# -----------------------------------------------------------------------
|
# -----------------------------------------------------------------------
|
||||||
# World constants — must match Webots world file
|
# World constants — must match Webots world file
|
||||||
# -----------------------------------------------------------------------
|
# -----------------------------------------------------------------------
|
||||||
MAX_SHEEP = 5
|
MAX_SHEEP = 10
|
||||||
FIELD = 15.0 # half-size; positions ∈ [-FIELD, FIELD]
|
FIELD = 15.0 # half-size; positions ∈ [-FIELD, FIELD]
|
||||||
PEN_X = (10.0, 13.0)
|
PEN_X = (10.0, 13.0)
|
||||||
PEN_Y = (-15.0, -8.0)
|
PEN_Y = (-15.0, -8.0)
|
||||||
@@ -344,6 +344,14 @@ class HerdingEnv(gym.Env):
|
|||||||
if pos[1] < -F + m: fy += ((-F + m - pos[1]) / m) * 6.0
|
if pos[1] < -F + m: fy += ((-F + m - pos[1]) / m) * 6.0
|
||||||
if pos[1] > F - m: fy -= ((pos[1] - (F - m)) / m) * 6.0
|
if pos[1] > F - m: fy -= ((pos[1] - (F - m)) / m) * 6.0
|
||||||
|
|
||||||
|
# Hard-stop clamp: mirrors sheep.py — zero any force driving further
|
||||||
|
# into the wall within 0.5 m so the flee force cannot pin the sheep.
|
||||||
|
HS = 0.5
|
||||||
|
if pos[0] < -F + HS and fx < 0: fx = 0.0
|
||||||
|
if pos[0] > F - HS and fx > 0: fx = 0.0
|
||||||
|
if pos[1] < -F + HS and fy < 0: fy = 0.0
|
||||||
|
if pos[1] > F - HS and fy > 0: fy = 0.0
|
||||||
|
|
||||||
# Wander — suppressed while fleeing
|
# Wander — suppressed while fleeing
|
||||||
if not fleeing:
|
if not fleeing:
|
||||||
if self.np_random.random() < 0.02:
|
if self.np_random.random() < 0.02:
|
||||||
|
|||||||
+95
-60
@@ -3,13 +3,17 @@ PPO training script for the herding task.
|
|||||||
|
|
||||||
Usage examples
|
Usage examples
|
||||||
--------------
|
--------------
|
||||||
# Start fresh with curriculum (1 → 5 sheep):
|
# Proper 5-sheep curriculum, 1 M steps per stage:
|
||||||
python train.py --curriculum
|
python train.py --curriculum --steps-per-stage 1000000 --total-steps 5000000
|
||||||
|
|
||||||
# Resume from checkpoint, skip directly to 3 sheep:
|
# Success-rate curriculum (advances when 70 % success over 100 episodes):
|
||||||
python train.py --resume runs/ppo_herding/ckpt_200000_steps.zip --n-sheep 3
|
python train.py --curriculum --threshold 0.70
|
||||||
|
|
||||||
# Quick smoke-test (no curriculum, single env):
|
# Resume from checkpoint at stage 3:
|
||||||
|
python train.py --resume runs/ppo_herding/ckpt_3000000_steps.zip --n-sheep 3 \
|
||||||
|
--curriculum --steps-per-stage 1000000 --total-steps 5000000
|
||||||
|
|
||||||
|
# Quick smoke-test:
|
||||||
python train.py --n-envs 1 --total-steps 50000
|
python train.py --n-envs 1 --total-steps 50000
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -35,39 +39,64 @@ from herding_env import HerdingEnv
|
|||||||
|
|
||||||
class CurriculumCallback(BaseCallback):
|
class CurriculumCallback(BaseCallback):
|
||||||
"""
|
"""
|
||||||
Advances the curriculum (number of active sheep) when the rolling mean
|
Advances n_sheep on both training and eval envs.
|
||||||
episode success rate exceeds a threshold.
|
|
||||||
|
|
||||||
Success = episode terminated (all sheep penned) rather than truncated.
|
Two modes (mutually exclusive):
|
||||||
|
steps_per_stage — advance every N environment steps regardless of
|
||||||
|
success rate (recommended for reliability).
|
||||||
|
threshold — advance when rolling success rate exceeds this value
|
||||||
|
(requires the policy to actually reach the threshold).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
THRESHOLD = 0.75 # success rate to graduate
|
def __init__(self, start_sheep: int, max_sheep: int,
|
||||||
WINDOW = 100 # episodes to average over
|
eval_env=None,
|
||||||
MIN_EPISODES = 50 # don't graduate before seeing this many episodes
|
steps_per_stage: int = None,
|
||||||
|
threshold: float = 0.75,
|
||||||
def __init__(self, start_sheep: int, max_sheep: int, verbose: int = 1):
|
window: int = 100,
|
||||||
|
min_episodes: int = 50,
|
||||||
|
verbose: int = 1):
|
||||||
super().__init__(verbose)
|
super().__init__(verbose)
|
||||||
self.max_sheep = max_sheep
|
self.max_sheep = max_sheep
|
||||||
self._successes = []
|
self.eval_env = eval_env
|
||||||
self._cur_sheep = start_sheep
|
self.steps_per_stage = steps_per_stage
|
||||||
|
self.threshold = threshold
|
||||||
|
self.window = window
|
||||||
|
self.min_episodes = min_episodes
|
||||||
|
self._cur_sheep = start_sheep
|
||||||
|
self._successes = []
|
||||||
|
self._stage_start = 0
|
||||||
|
|
||||||
|
def _advance(self):
|
||||||
|
self._cur_sheep += 1
|
||||||
|
self.training_env.env_method("set_n_sheep", self._cur_sheep)
|
||||||
|
if self.eval_env is not None:
|
||||||
|
self.eval_env.env_method("set_n_sheep", self._cur_sheep)
|
||||||
|
self._stage_start = self.num_timesteps
|
||||||
|
self._successes.clear()
|
||||||
|
if self.verbose:
|
||||||
|
print(f"\n[Curriculum] → {self._cur_sheep} sheep "
|
||||||
|
f"at step {self.num_timesteps:,}\n")
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
for info, done in zip(self.locals["infos"], self.locals["dones"]):
|
if self._cur_sheep >= self.max_sheep:
|
||||||
if done:
|
return True
|
||||||
truncated = info.get("TimeLimit.truncated", False)
|
|
||||||
self._successes.append(0 if truncated else 1)
|
|
||||||
if len(self._successes) > self.WINDOW:
|
|
||||||
self._successes.pop(0)
|
|
||||||
|
|
||||||
if (self._cur_sheep < self.max_sheep
|
if self.steps_per_stage is not None:
|
||||||
and len(self._successes) >= self.MIN_EPISODES
|
# Time-based: advance every steps_per_stage env steps
|
||||||
and np.mean(self._successes) >= self.THRESHOLD):
|
if self.num_timesteps - self._stage_start >= self.steps_per_stage:
|
||||||
self._cur_sheep += 1
|
self._advance()
|
||||||
self.training_env.env_method("set_n_sheep", self._cur_sheep)
|
else:
|
||||||
self._successes.clear()
|
# Success-rate based
|
||||||
if self.verbose:
|
for info, done in zip(self.locals["infos"], self.locals["dones"]):
|
||||||
print(f"\n[Curriculum] Advanced to {self._cur_sheep} sheep "
|
if done:
|
||||||
f"at step {self.num_timesteps}\n")
|
truncated = info.get("TimeLimit.truncated", False)
|
||||||
|
self._successes.append(0 if truncated else 1)
|
||||||
|
if len(self._successes) > self.window:
|
||||||
|
self._successes.pop(0)
|
||||||
|
|
||||||
|
if (len(self._successes) >= self.min_episodes
|
||||||
|
and np.mean(self._successes) >= self.threshold):
|
||||||
|
self._advance()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -90,36 +119,35 @@ def make_env(n_sheep: int, seed: int, max_steps: int):
|
|||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
p = argparse.ArgumentParser()
|
p = argparse.ArgumentParser()
|
||||||
p.add_argument("--n-sheep", type=int, default=1,
|
p.add_argument("--n-sheep", type=int, default=1,
|
||||||
help="Starting number of sheep (or fixed count if no curriculum)")
|
help="Starting sheep count")
|
||||||
p.add_argument("--max-sheep", type=int, default=5,
|
p.add_argument("--max-sheep", type=int, default=5,
|
||||||
help="Maximum sheep for curriculum (ignored without --curriculum)")
|
help="Final sheep count for curriculum")
|
||||||
p.add_argument("--n-envs", type=int, default=8,
|
p.add_argument("--n-envs", type=int, default=8,
|
||||||
help="Number of parallel environments")
|
help="Parallel training environments")
|
||||||
p.add_argument("--total-steps", type=int, default=5_000_000,
|
p.add_argument("--total-steps", type=int, default=5_000_000)
|
||||||
help="Total environment steps to train for")
|
p.add_argument("--max-steps", type=int, default=2000,
|
||||||
p.add_argument("--max-steps", type=int, default=2000,
|
help="Episode step limit")
|
||||||
help="Episode step limit inside each env")
|
p.add_argument("--curriculum", action="store_true",
|
||||||
p.add_argument("--curriculum", action="store_true",
|
help="Enable curriculum advancement")
|
||||||
help="Enable automatic curriculum advancement")
|
p.add_argument("--steps-per-stage", type=int, default=None,
|
||||||
p.add_argument("--resume", type=str, default=None,
|
help="Advance curriculum every N steps (overrides --threshold)")
|
||||||
help="Path to a .zip checkpoint to resume training from")
|
p.add_argument("--threshold", type=float, default=0.75,
|
||||||
p.add_argument("--run-dir", type=str, default="runs/ppo_herding",
|
help="Success-rate threshold to advance (used without --steps-per-stage)")
|
||||||
help="Output directory for checkpoints and logs")
|
p.add_argument("--resume", type=str, default=None,
|
||||||
p.add_argument("--save-freq", type=int, default=100_000,
|
help="Checkpoint .zip to resume from")
|
||||||
help="Checkpoint every N steps (per-env, not total)")
|
p.add_argument("--run-dir", type=str, default="runs/ppo_herding")
|
||||||
p.add_argument("--eval-freq", type=int, default=50_000,
|
p.add_argument("--save-freq", type=int, default=100_000)
|
||||||
help="Evaluate every N steps")
|
p.add_argument("--eval-freq", type=int, default=50_000)
|
||||||
p.add_argument("--eval-eps", type=int, default=20,
|
p.add_argument("--eval-eps", type=int, default=20)
|
||||||
help="Episodes per evaluation run")
|
|
||||||
return p.parse_args()
|
return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
os.makedirs(args.run_dir, exist_ok=True)
|
os.makedirs(args.run_dir, exist_ok=True)
|
||||||
ckpt_dir = os.path.join(args.run_dir, "checkpoints")
|
ckpt_dir = os.path.join(args.run_dir, "checkpoints")
|
||||||
best_dir = os.path.join(args.run_dir, "best_model")
|
best_dir = os.path.join(args.run_dir, "best_model")
|
||||||
norm_path = os.path.join(args.run_dir, "vecnorm.pkl")
|
norm_path = os.path.join(args.run_dir, "vecnorm.pkl")
|
||||||
os.makedirs(ckpt_dir, exist_ok=True)
|
os.makedirs(ckpt_dir, exist_ok=True)
|
||||||
|
|
||||||
@@ -130,13 +158,13 @@ def main():
|
|||||||
])
|
])
|
||||||
if args.resume and os.path.exists(norm_path):
|
if args.resume and os.path.exists(norm_path):
|
||||||
train_env = VecNormalize.load(norm_path, train_env)
|
train_env = VecNormalize.load(norm_path, train_env)
|
||||||
train_env.training = True
|
train_env.training = True
|
||||||
train_env.norm_reward = True
|
train_env.norm_reward = True
|
||||||
else:
|
else:
|
||||||
train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True,
|
train_env = VecNormalize(train_env, norm_obs=True, norm_reward=True,
|
||||||
clip_obs=10.0)
|
clip_obs=10.0)
|
||||||
|
|
||||||
# Eval env (no reward normalisation, deterministic)
|
# Eval env — starts at same difficulty, advances with curriculum callback
|
||||||
eval_env = SubprocVecEnv([
|
eval_env = SubprocVecEnv([
|
||||||
make_env(args.n_sheep, seed=1000 + i, max_steps=args.max_steps)
|
make_env(args.n_sheep, seed=1000 + i, max_steps=args.max_steps)
|
||||||
for i in range(2)
|
for i in range(2)
|
||||||
@@ -161,9 +189,17 @@ def main():
|
|||||||
verbose=1,
|
verbose=1,
|
||||||
)
|
)
|
||||||
callbacks = [checkpoint_cb, eval_cb]
|
callbacks = [checkpoint_cb, eval_cb]
|
||||||
|
|
||||||
if args.curriculum:
|
if args.curriculum:
|
||||||
callbacks.append(CurriculumCallback(start_sheep=args.n_sheep,
|
cur_cb = CurriculumCallback(
|
||||||
max_sheep=args.max_sheep))
|
start_sheep=args.n_sheep,
|
||||||
|
max_sheep=args.max_sheep,
|
||||||
|
eval_env=eval_env,
|
||||||
|
steps_per_stage=args.steps_per_stage,
|
||||||
|
threshold=args.threshold,
|
||||||
|
)
|
||||||
|
callbacks.append(cur_cb)
|
||||||
|
|
||||||
callback_list = CallbackList(callbacks)
|
callback_list = CallbackList(callbacks)
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
@@ -201,7 +237,6 @@ def main():
|
|||||||
tb_log_name="ppo",
|
tb_log_name="ppo",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save final artefacts
|
|
||||||
model.save(os.path.join(args.run_dir, "final_model"))
|
model.save(os.path.join(args.run_dir, "final_model"))
|
||||||
train_env.save(norm_path)
|
train_env.save(norm_path)
|
||||||
print(f"\nTraining complete. Artefacts saved to {args.run_dir}/")
|
print(f"\nTraining complete. Artefacts saved to {args.run_dir}/")
|
||||||
|
|||||||
Reference in New Issue
Block a user