Trying attention method

This commit is contained in:
Johnny Fernandes
2026-04-26 22:32:13 +01:00
parent a2363d882f
commit 80a314b9e9
+2 -1
View File
@@ -139,7 +139,8 @@ class ShepherdAttentionExtractor(BaseFeaturesExtractor):
dropout=0.0, batch_first=True, dropout=0.0, batch_first=True,
) )
self.transformer = nn.TransformerEncoder(encoder_layer, self.transformer = nn.TransformerEncoder(encoder_layer,
num_layers=n_layers) num_layers=n_layers,
enable_nested_tensor=False)
def forward(self, obs: torch.Tensor) -> torch.Tensor: def forward(self, obs: torch.Tensor) -> torch.Tensor:
B = obs.shape[0] B = obs.shape[0]