diff --git a/training/train_at.py b/training/train_at.py index cbf37f3..2289496 100644 --- a/training/train_at.py +++ b/training/train_at.py @@ -139,7 +139,8 @@ class ShepherdAttentionExtractor(BaseFeaturesExtractor): dropout=0.0, batch_first=True, ) self.transformer = nn.TransformerEncoder(encoder_layer, - num_layers=n_layers) + num_layers=n_layers, + enable_nested_tensor=False) def forward(self, obs: torch.Tensor) -> torch.Tensor: B = obs.shape[0]