import imageio
import matplotlib.pyplot as plt
import numpy as np
from pytsc.common.utils import EnvLogger
from pytsc.controllers.evaluate import Evaluate
class ObservationEvaluator(Evaluate):
    """Evaluate a controller while recording each signal's observation matrix.

    Runs the simulation through the ``Evaluate`` machinery and, at every step,
    renders each traffic signal's observation as an image. One GIF per signal
    is written at the end of :meth:`run`.
    """

    def __init__(
        self,
        scenario,
        simulator_backend,
        controller,
        add_env_args=None,
        add_controller_args=None,
        **kwargs,
    ):
        """Build the evaluator and cache observation-layout information.

        Args:
            scenario: Scenario name, forwarded to ``Evaluate``.
            simulator_backend: Simulator backend name (e.g. ``"sumo"``),
                forwarded to ``Evaluate``.
            controller: Controller name, forwarded to ``Evaluate``.
            add_env_args: Extra environment arguments; ``None`` means none.
            add_controller_args: Extra controller arguments; ``None`` means
                none.
            **kwargs: Forwarded unchanged to ``Evaluate``.
        """
        # Fresh dicts per call: a shared mutable default would carry any
        # mutation made downstream into every later instance.
        super().__init__(
            scenario,
            simulator_backend,
            controller,
            {} if add_env_args is None else add_env_args,
            {} if add_controller_args is None else add_controller_args,
            **kwargs,
        )
        self.obs_info = self.network.observation_space.get_observation_info()
        # Observation image layout: one row per controlled lane, 10 columns.
        # NOTE(review): the meaning of the 10 columns is not visible here.
        self.mat_size = (self.obs_info["max_n_controlled_lanes"], 10)

    def run(self, hours, output_folder=None):
        """Run the controller for ``hours`` and save one GIF per signal.

        Args:
            hours: Simulated time to run, in hours (converted to steps using
                ``self.delta_time``).
            output_folder: Destination folder, resolved through
                ``self._create_output_folder``.
        """
        output_folder = self._create_output_folder(output_folder)
        steps = int(hours * 3600 / self.delta_time)
        ts_ids = list(self.network.traffic_signals.keys())
        frames_dict = {ts_id: [] for ts_id in ts_ids}
        # One figure per signal: each GIF then shows only its own signal, and
        # every frame is rendered after that signal's image has been updated.
        figures = {ts_id: plt.subplots(figsize=(5, 5)) for ts_id in ts_ids}
        try:
            for step in range(steps):
                actions = self._get_actions()
                _, done, _ = self.network.step(actions)
                obs = self.network.get_observations()
                # NOTE(review): assumes ``obs`` is ordered like
                # ``self.network.traffic_signals`` — confirm.
                for idx, ts_id in enumerate(self.network.traffic_signals):
                    ts_obs = np.asarray(obs[idx])
                    # The trailing 10 entries are excluded from the matrix;
                    # presumably non-spatial features — confirm.
                    pos_mats = ts_obs[:-10].reshape(*self.mat_size)
                    fig, ax = figures[ts_id]
                    ax.clear()
                    ax.imshow(pos_mats, cmap="viridis")
                    ax.set_title(f"Signal {ts_id} at step {step}")
                    ax.axis("off")
                    frames_dict[ts_id].append(plt_to_image(fig))
                if self.network.simulator.is_terminated:
                    self._init_network()
                    self._init_controllers()
                if done:
                    self.network.restart()
        finally:
            # Release the figures even if the simulation raises.
            for fig, _ in figures.values():
                plt.close(fig)
        for ts_id, frames in frames_dict.items():
            if not frames:
                # Nothing was recorded (e.g. zero steps); skip empty GIFs.
                continue
            gif_path = (
                f"{output_folder}/{self.scenario}_{ts_id}_obs_matrix_animation.gif"
            )
            imageio.mimsave(gif_path, frames, fps=10)
            print(f"Saved GIF for {ts_id} at {gif_path}")
def plt_to_image(fig):
    """Render a Matplotlib figure and return its pixels as an RGB array.

    Args:
        fig: Figure whose canvas provides ``draw()`` and ``buffer_rgba()``
            (Agg-based canvases do).

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with shape
        ``(height, width, 3)``. The array is an independent copy, so it stays
        valid after the figure is redrawn.
    """
    fig.canvas.draw()
    # buffer_rgba() is already shaped (rows, cols, 4) in physical pixels, so
    # this also works on HiDPI canvases where get_width_height() reports
    # logical pixels; it replaces tostring_rgb(), which newer Matplotlib
    # removed. The buffer aliases the renderer's memory, hence the copy.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return np.array(rgba[..., :3], dtype=np.uint8)
if __name__ == "__main__":
scenario = "random_grid_singles"
output_folder = "/home/rohitbokade/repos/pytsc"
add_env_args = {
"signal": {"observation_space": "position_matrix", "obs_dropout_prob": 0.5}
}
obs_evaluator = ObservationEvaluator(
scenario=scenario,
simulator_backend="sumo",
controller="sotl",
add_env_args=add_env_args,
)
obs_evaluator.run(hours=0.1, output_folder=output_folder)