Trying attention method
This commit is contained in:
@@ -139,7 +139,8 @@ class ShepherdAttentionExtractor(BaseFeaturesExtractor):
|
||||
dropout=0.0, batch_first=True,
|
||||
)
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer,
|
||||
num_layers=n_layers)
|
||||
num_layers=n_layers,
|
||||
enable_nested_tensor=False)
|
||||
|
||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||
B = obs.shape[0]
|
||||
|
||||
Reference in New Issue
Block a user