Files
MAGAIL4AutoDrive/Env/replay_policy.py

62 lines
1.7 KiB
Python
Raw Permalink Normal View History

import numpy as np
class ReplayPolicy:
    """Strict replay policy that plays back an expert trajectory frame by frame.

    The policy itself produces no meaningful control action; callers are
    expected to query :meth:`get_target_state` and apply that state to the
    vehicle directly.
    """

    def __init__(self, expert_trajectory, vehicle_id):
        """Store the trajectory to replay.

        Args:
            expert_trajectory: Mapping with the keys ``'positions'``,
                ``'headings'``, ``'velocities'`` and ``'valid'``. Each value
                is indexed by time step; ``'valid'`` holds one truthy/falsy
                flag per step.
            vehicle_id: Identifier of the replayed vehicle (kept for
                debugging).
        """
        self.trajectory = expert_trajectory
        self.vehicle_id = vehicle_id
        self.current_step = 0

    def act(self, observation=None):
        """Return a no-op action.

        In replay mode the real state is set from the trajectory, so the
        observation is ignored and a fresh ``[0.0, 0.0]`` list is returned on
        every call.
        """
        return [0.0, 0.0]

    def get_target_state(self, step):
        """Return the recorded state for ``step``.

        Args:
            step: Time step index into the trajectory.

        Returns:
            dict | None: A dict with the keys ``'position'``, ``'heading'``
            and ``'velocity'``, or ``None`` when ``step`` lies outside the
            trajectory or the frame is flagged invalid.
        """
        valid = self.trajectory['valid']
        # Reject negative steps explicitly: Python would otherwise wrap them
        # around and silently return a frame from the END of the trajectory.
        if step < 0 or step >= len(valid):
            return None
        if not valid[step]:
            return None
        return {
            'position': self.trajectory['positions'][step],
            'heading': self.trajectory['headings'][step],
            'velocity': self.trajectory['velocities'][step],
        }

    def is_finished(self, step):
        """Tell whether the replay has nothing left to play from ``step`` on.

        Args:
            step: Current time step.

        Returns:
            bool: ``True`` when ``step`` is past the end of the trajectory or
            when the frame at ``step`` and every later frame are invalid.
            A single invalid frame followed by valid ones is NOT finished.
        """
        valid = self.trajectory['valid']
        # Past the end of the recorded trajectory.
        if step >= len(valid):
            return True
        # A negative step means "before the first frame"; clamp it so the
        # slice below covers the whole trajectory instead of only its tail.
        start = max(step, 0)
        # Finished only if no valid frame remains at or after ``start``.
        return not any(valid[start:])