diff --git a/training/train.py b/training/train.py index 7e549c4..dc19371 100644 --- a/training/train.py +++ b/training/train.py @@ -380,8 +380,8 @@ DEFAULT_CONFIG = { "W_COMPLETE": 100.0, "W_STEP_COST": 0.02, "W_COMPACT": 0.0, - "W_WALL_TOUCH": 0.15, - "WALL_TOUCH_BUFFER": 0.8, + "W_WALL_TOUCH": 0.04, + "WALL_TOUCH_BUFFER": 0.3, "ALIGN_SHAPE": "standoff", "ALIGN_GATED": True, "ENTRY_AWARE": False, @@ -408,11 +408,15 @@ def parse_args(): def main(): args = parse_args() - # Load config + # Load config: --config overrides, else auto-load config.json if present cfg = dict(DEFAULT_CONFIG) - if args.config: - with open(args.config) as f: + config_path = args.config + if config_path is None and os.path.exists("config.json"): + config_path = "config.json" + if config_path: + with open(config_path) as f: cfg.update(json.load(f)) + print(f"Config loaded from {config_path}") rcfg = {k: v for k, v in cfg.items() if hasattr(HerdingEnv, k)}