62 lines
1.7 KiB
Python
62 lines
1.7 KiB
Python
|
|
import numpy as np
|
|||
|
|
|
|||
|
|
class ReplayPolicy:
    """Strict replay policy: replay a vehicle's state frame by frame from expert data.

    The policy produces no real control signal: ``act`` always returns a zero
    action. The per-step target state is exposed through
    ``get_target_state``; per the original design notes, the environment is
    expected to set the actual vehicle state directly.
    """

    def __init__(self, expert_trajectory, vehicle_id):
        """Store the expert trajectory to replay.

        Args:
            expert_trajectory: Expert trajectory dict containing the keys
                ``'positions'``, ``'headings'``, ``'velocities'`` and
                ``'valid'``, each indexed by time step. ``'valid'`` holds
                per-step truthy flags marking which steps carry data.
            vehicle_id: Vehicle ID (used for debugging).
        """
        self.trajectory = expert_trajectory
        self.vehicle_id = vehicle_id
        # Kept for external callers; no method in this class reads or
        # advances it.
        self.current_step = 0

    def act(self, observation=None):
        """Return a zero action.

        In replay mode the actual state is set by the environment, so the
        action is a placeholder.

        Args:
            observation: Ignored; accepted for interface compatibility.

        Returns:
            list: A fresh ``[0.0, 0.0]`` list on every call.
        """
        return [0.0, 0.0]

    def get_target_state(self, step):
        """Get the target state for the given time step.

        Args:
            step: Time step index.

        Returns:
            dict | None: ``{'position', 'heading', 'velocity'}`` taken from the
            trajectory at ``step``, or ``None`` if ``step`` is outside
            ``[0, len(valid))`` or the step is marked invalid.
        """
        valid = self.trajectory['valid']

        # Reject negative steps explicitly: Python sequence / NumPy indexing
        # would otherwise count from the end and silently return the state of
        # a different (late) time step.
        if step < 0 or step >= len(valid):
            return None

        if not valid[step]:
            return None

        return {
            'position': self.trajectory['positions'][step],
            'heading': self.trajectory['headings'][step],
            'velocity': self.trajectory['velocities'][step],
        }

    def is_finished(self, step):
        """Decide whether the trajectory has finished playing.

        Args:
            step: Current time step.

        Returns:
            bool: ``True`` if ``step`` is past the end of the trajectory or if
            no step from ``step`` onward is marked valid; ``False`` otherwise.
        """
        valid = self.trajectory['valid']

        # Past the end of the trajectory.
        if step >= len(valid):
            return True

        # A negative step lies before the start of the trajectory; clamp it to
        # 0 so the slice does not wrap around and inspect only the tail.
        start = max(step, 0)

        # Finished only when this step and every later step are invalid.
        return not any(valid[start:])