import torch
import numpy as np
from torch import nn

try:
    from .utils import build_mlp, reparameterize, evaluate_lop_pi
except ImportError:
    # Fallback for running this file outside of its package.
    from utils import build_mlp, reparameterize, evaluate_lop_pi


class StateIndependentPolicy(nn.Module):
    """Policy whose log-std parameter does not depend on the input states.

    The module owns two learnable pieces:

    * ``net`` -- an MLP produced by ``build_mlp`` that maps a state vector of
      size ``state_shape[0]`` to ``action_shape[0]`` outputs.
    * ``log_stds`` -- an ``nn.Parameter`` of shape ``(1, action_shape[0])``
      initialised to zeros.

    The most recent raw output of ``net`` computed by ``sample`` or
    ``evaluate_log_pi`` is cached on ``self.means``.
    """

    # NOTE(review): the default ``nn.Tanh()`` is created once, when the class
    # body is evaluated, so every policy built with the default receives the
    # same module object. Tanh holds no parameters or buffers; what build_mlp
    # does with the instance is not visible here -- confirm if that matters.
    def __init__(self, state_shape, action_shape, hidden_units=(64, 64),
                 hidden_activation=nn.Tanh()):
        """Build the network and the log-std parameter.

        Args:
            state_shape: Indexable; element 0 is used as the MLP input size.
            action_shape: Indexable; element 0 is used as the MLP output size
                and as the width of ``log_stds``.
            hidden_units: Forwarded to ``build_mlp`` as ``hidden_units``.
            hidden_activation: Forwarded to ``build_mlp`` as
                ``hidden_activation``.
        """
        super().__init__()

        state_dim = state_shape[0]
        action_dim = action_shape[0]

        self.net = build_mlp(
            input_dim=state_dim,
            output_dim=action_dim,
            hidden_units=hidden_units,
            hidden_activation=hidden_activation,
        )
        # A single row, so the same values apply to any batch of states.
        self.log_stds = nn.Parameter(torch.zeros(1, action_dim))
        # Populated by sample() / evaluate_log_pi(); forward() leaves it alone.
        self.means = None

    def forward(self, states):
        """Return ``tanh`` of the network output for ``states``."""
        raw_output = self.net(states)
        return torch.tanh(raw_output)

    def sample(self, states):
        """Draw actions for ``states`` through ``reparameterize``.

        Stores the network output on ``self.means`` and passes it, together
        with ``self.log_stds``, to ``reparameterize``.

        Returns:
            The ``(actions, log_pis)`` pair produced by ``reparameterize``.
        """
        net_output = self.net(states)
        self.means = net_output
        sampled_actions, sampled_log_pis = reparameterize(net_output, self.log_stds)
        return sampled_actions, sampled_log_pis

    def evaluate_log_pi(self, states, actions):
        """Evaluate ``actions`` at ``states`` through ``evaluate_lop_pi``.

        Stores the network output on ``self.means`` and returns whatever the
        imported helper ``evaluate_lop_pi`` (spelled as in ``utils``) returns
        for that output, ``self.log_stds`` and ``actions``.
        """
        net_output = self.net(states)
        self.means = net_output
        return evaluate_lop_pi(net_output, self.log_stds, actions)