magail4autodrive: first commit
This commit is contained in:
31
Algorithm/policy.py
Normal file
31
Algorithm/policy.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from .utils import build_mlp, reparameterize, evaluate_lop_pi
|
||||
|
||||
class StateIndependentPolicy(nn.Module):
    """Policy network with a learnable log-std that is not a function of the state.

    The mean pathway is an MLP produced by ``build_mlp``; the log standard
    deviations are a single ``nn.Parameter`` of shape ``(1, action_shape[0])``
    initialised to zeros and shared by every input state.

    Attributes:
        net: Module returned by ``build_mlp`` that maps states to raw means.
        log_stds: Learnable parameter holding the log standard deviations.
        means: Output of ``net`` from the most recent ``sample`` or
            ``evaluate_log_pi`` call; ``None`` until one of them has run.
    """

    def __init__(self, state_shape, action_shape, hidden_units=(64, 64),
                 hidden_activation=nn.Tanh()):
        """Build the mean network and the state-independent log-std parameter.

        Args:
            state_shape: Indexable whose first entry is used as the MLP input
                dimension.
            action_shape: Indexable whose first entry is used as the MLP
                output dimension and as the width of ``log_stds``.
            hidden_units: Forwarded unchanged to ``build_mlp``.
            hidden_activation: Forwarded unchanged to ``build_mlp``.
        """
        super().__init__()

        state_dim = state_shape[0]
        action_dim = action_shape[0]

        # Keep the registration order (net first, then log_stds) so that
        # parameters() and state_dict() enumerate entries in the same order.
        self.net = build_mlp(
            input_dim=state_dim,
            output_dim=action_dim,
            hidden_units=hidden_units,
            hidden_activation=hidden_activation,
        )
        self.log_stds = nn.Parameter(torch.zeros(1, action_dim))

        # Cache slot for the latest means; filled by sample/evaluate_log_pi.
        self.means = None

    def forward(self, states):
        """Return ``tanh`` of the mean network's output for ``states``.

        This path does not use ``log_stds`` and does not touch ``self.means``.
        """
        raw_means = self.net(states)
        return torch.tanh(raw_means)

    def sample(self, states):
        """Draw actions for ``states`` and return them with their log-probs.

        The raw network output is stored on ``self.means`` and then passed,
        together with ``log_stds``, to ``reparameterize``.

        Returns:
            The ``(actions, log_pis)`` pair produced by ``reparameterize``.
            NOTE(review): presumably a tanh-squashed Gaussian sample and its
            log-density -- confirm against ``utils.reparameterize``.
        """
        self.means = self.net(states)
        sampled_actions, sampled_log_pis = reparameterize(
            self.means, self.log_stds
        )
        return sampled_actions, sampled_log_pis

    def evaluate_log_pi(self, states, actions):
        """Score ``actions`` under the policy evaluated at ``states``.

        The raw network output is stored on ``self.means`` and the result of
        ``evaluate_lop_pi(means, log_stds, actions)`` is returned unchanged.
        """
        self.means = self.net(states)
        log_pis = evaluate_lop_pi(self.means, self.log_stds, actions)
        return log_pis
|
||||
Reference in New Issue
Block a user