From 80a314b9e93de4284c80fc3aa6e42731eaa65eaa Mon Sep 17 00:00:00 2001 From: Johnny Fernandes Date: Sun, 26 Apr 2026 22:32:13 +0100 Subject: [PATCH] Trying attention method --- training/train_at.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/training/train_at.py b/training/train_at.py index cbf37f3..2289496 100644 --- a/training/train_at.py +++ b/training/train_at.py @@ -139,7 +139,8 @@ class ShepherdAttentionExtractor(BaseFeaturesExtractor): dropout=0.0, batch_first=True, ) self.transformer = nn.TransformerEncoder(encoder_layer, - num_layers=n_layers) + num_layers=n_layers, + enable_nested_tensor=False) def forward(self, obs: torch.Tensor) -> torch.Tensor: B = obs.shape[0]