34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
import torch
|
|
import numpy as np
|
|
from torch import nn
|
|
try:
|
|
from .utils import build_mlp, reparameterize, evaluate_lop_pi
|
|
except ImportError:
|
|
from utils import build_mlp, reparameterize, evaluate_lop_pi
|
|
|
|
class StateIndependentPolicy(nn.Module):
    """Policy whose action means come from an MLP over the states.

    The log standard deviations are a single learnable parameter of shape
    ``(1, action_shape[0])`` initialised to zero; they are not computed from
    the states. The most recently computed means are cached on ``self.means``
    by :meth:`sample` and :meth:`evaluate_log_pi`.
    """

    def __init__(self, state_shape, action_shape, hidden_units=(64, 64),
                 hidden_activation=None):
        """Build the mean network and the log-std parameter.

        Args:
            state_shape: Indexable whose first entry is the input dimension.
            action_shape: Indexable whose first entry is the output dimension.
            hidden_units: Hidden layer sizes forwarded to ``build_mlp``.
            hidden_activation: Activation module forwarded to ``build_mlp``.
                ``None`` (the default) means a fresh ``nn.Tanh()``.
        """
        super().__init__()

        # Create the default activation per instance. A module instance used
        # directly as a default argument is evaluated once at definition time
        # and would be shared by every policy built with the default.
        if hidden_activation is None:
            hidden_activation = nn.Tanh()

        self.net = build_mlp(
            input_dim=state_shape[0],
            output_dim=action_shape[0],
            hidden_units=hidden_units,
            hidden_activation=hidden_activation,
        )
        # One row of log-stds, created as zeros and registered as a parameter.
        self.log_stds = nn.Parameter(torch.zeros(1, action_shape[0]))
        # Cache of the last means computed by sample() / evaluate_log_pi().
        self.means = None

    def forward(self, states):
        """Return ``tanh`` of the mean network's output for ``states``."""
        return torch.tanh(self.net(states))

    def sample(self, states):
        """Return ``(actions, log_pis)`` from ``reparameterize``.

        The means for ``states`` are stored on ``self.means`` and passed to
        ``reparameterize`` together with ``self.log_stds``.
        NOTE(review): presumably this draws a stochastic action and its
        log-probability; confirm against ``utils.reparameterize``.
        """
        self.means = self.net(states)
        actions, log_pis = reparameterize(self.means, self.log_stds)
        return actions, log_pis

    def evaluate_log_pi(self, states, actions):
        """Return ``evaluate_lop_pi(means, log_stds, actions)`` for ``states``.

        The means for ``states`` are stored on ``self.means`` as a side effect.
        The helper name ``evaluate_lop_pi`` (sic) is the one this file imports.
        """
        self.means = self.net(states)
        return evaluate_lop_pi(self.means, self.log_stds, actions)