Trying attention method
This commit is contained in:
@@ -139,7 +139,8 @@ class ShepherdAttentionExtractor(BaseFeaturesExtractor):
|
|||||||
dropout=0.0, batch_first=True,
|
dropout=0.0, batch_first=True,
|
||||||
)
|
)
|
||||||
self.transformer = nn.TransformerEncoder(encoder_layer,
|
self.transformer = nn.TransformerEncoder(encoder_layer,
|
||||||
num_layers=n_layers)
|
num_layers=n_layers,
|
||||||
|
enable_nested_tensor=False)
|
||||||
|
|
||||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||||
B = obs.shape[0]
|
B = obs.shape[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user