Files
MAGAIL4AutoDrive/Algorithm/policy.py

31 lines
1006 B
Python
Raw Normal View History

2025-09-28 18:57:04 +08:00
import torch
import numpy as np
from torch import nn
from .utils import build_mlp, reparameterize, evaluate_lop_pi
class StateIndependentPolicy(nn.Module):
    """Policy with network-produced action means and state-independent log-stds.

    The per-dimension log standard deviations live in a single learnable
    parameter of shape ``(1, action_dim)`` that does not depend on the input
    states; only the means are computed from the states by ``self.net``.

    The most recent means produced by ``sample`` or ``evaluate_log_pi`` are
    cached on ``self.means`` (``None`` until one of them has been called).
    """

    def __init__(self, state_shape, action_shape, hidden_units=(64, 64),
                 hidden_activation=nn.Tanh()):
        """Build the mean network and the log-std parameter.

        Args:
            state_shape: Indexable whose first entry is the network input size.
            action_shape: Indexable whose first entry is the action size.
            hidden_units: Forwarded unchanged to ``build_mlp``.
            hidden_activation: Forwarded unchanged to ``build_mlp``.
        """
        super().__init__()

        state_dim = state_shape[0]
        action_dim = action_shape[0]

        # NOTE(review): build_mlp comes from .utils; presumably it returns an
        # nn.Module mapping (batch, state_dim) -> (batch, action_dim) - confirm.
        self.net = build_mlp(
            input_dim=state_dim,
            output_dim=action_dim,
            hidden_units=hidden_units,
            hidden_activation=hidden_activation,
        )
        # Log-stds start at zero and are shared by every state.
        self.log_stds = nn.Parameter(torch.zeros(1, action_dim))
        # Cache slot for the last means computed by sample / evaluate_log_pi.
        self.means = None

    def forward(self, states):
        """Return the tanh-squashed network output for ``states``.

        This path does not touch ``self.means`` or ``self.log_stds``.
        """
        raw_means = self.net(states)
        return torch.tanh(raw_means)

    def sample(self, states):
        """Compute means for ``states``, cache them, and sample via ``reparameterize``.

        Returns:
            The ``(actions, log_pis)`` pair produced by ``reparameterize``.
        """
        self.means = self.net(states)
        # NOTE(review): reparameterize is defined in .utils; presumably it
        # draws a squashed Gaussian sample and its log-probability - confirm.
        sampled_actions, sampled_log_pis = reparameterize(self.means, self.log_stds)
        return sampled_actions, sampled_log_pis

    def evaluate_log_pi(self, states, actions):
        """Compute means for ``states``, cache them, and score ``actions``.

        Returns:
            Whatever ``evaluate_lop_pi`` returns for the current means,
            log-stds and the given ``actions``.
        """
        self.means = self.net(states)
        return evaluate_lop_pi(self.means, self.log_stds, actions)