Files
MAGAIL4AutoDrive/Env/replay_policy.py

62 lines
1.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
class ReplayPolicy:
"""
严格回放策略:根据专家轨迹数据,逐帧回放车辆状态
"""
def __init__(self, expert_trajectory, vehicle_id):
"""
Args:
expert_trajectory: 专家轨迹字典,包含 positions, headings, velocities, valid
vehicle_id: 车辆ID用于调试
"""
self.trajectory = expert_trajectory
self.vehicle_id = vehicle_id
self.current_step = 0
def act(self, observation=None):
"""
返回动作:在回放模式下返回空动作
实际状态由环境直接设置
"""
return [0.0, 0.0]
def get_target_state(self, step):
"""
获取指定时间步的目标状态
Args:
step: 时间步
Returns:
dict: 包含 position, heading, velocity 的字典,如果无效则返回 None
"""
if step >= len(self.trajectory['valid']):
return None
if not self.trajectory['valid'][step]:
return None
return {
'position': self.trajectory['positions'][step],
'heading': self.trajectory['headings'][step],
'velocity': self.trajectory['velocities'][step]
}
def is_finished(self, step):
"""
判断轨迹是否已经播放完毕
Args:
step: 当前时间步
Returns:
bool: 如果轨迹已播放完或当前步无效,返回 True
"""
# 超出轨迹长度
if step >= len(self.trajectory['valid']):
return True
# 当前步及之后都无效
return not any(self.trajectory['valid'][step:])