Source code for pytsc

import numpy as np

from pytsc.backends.cityflow import CITYFLOW_MODULES
from pytsc.backends.sumo import SUMO_MODULES
from pytsc.common import ACTION_SPACES, OBSERVATION_SPACES, REWARD_FUNCTIONS
from pytsc.common.actions import CentralizedActionSpace
from pytsc.common.utils import validate_input_against_allowed

SUPPORTED_SIMULATOR_BACKENDS = ("sumo", "cityflow")
REWARD_METRICS = ("queue", "pressure")  # TODO: Implement `pressure`
SIMULATOR_MODULES = {
    "cityflow": CITYFLOW_MODULES,
    "sumo": SUMO_MODULES,
}


[docs]class TrafficSignalNetwork: def __init__(self, scenario, simulator_backend, **kwargs): self.scenario = scenario self.simulator_backend = simulator_backend self.disrupted = kwargs.get("disrupted", False) self.domain_class = kwargs.get("domain_class", None) assert ( self.simulator_backend in SUPPORTED_SIMULATOR_BACKENDS ), f"Simulator backend {self.simulator_backend} not supported." if self.disrupted: self.config = SIMULATOR_MODULES[simulator_backend]["disrupted_config"]( scenario, **kwargs ) else: self.config = SIMULATOR_MODULES[simulator_backend]["config"]( scenario, **kwargs ) self._validate_config() self.parsed_network = SIMULATOR_MODULES[simulator_backend]["network_parser"]( self.config ) self.simulator = SIMULATOR_MODULES[simulator_backend]["simulator"]( self.parsed_network ) self.simulator.start_simulator() self._init_traffic_signals() self._init_parsers() self._set_n_agents() self._init_counters() def _init_counters(self): self.hour_count = 0 self.episode_count = 0 @property def episode_limit(self): return int( self.config.simulator["episode_limit"] / self.config.simulator["delta_time"] ) @property def episode_over(self): if self.simulator.sim_step > 0: return self.simulator.sim_step % self.config.simulator["episode_limit"] == 0 else: return False def _validate_config(self): validate_input_against_allowed( self.config.signal["action_space"], ACTION_SPACES.keys() ) validate_input_against_allowed( self.config.signal["observation_space"], OBSERVATION_SPACES.keys() ) def _init_parsers(self): self.action_space = ACTION_SPACES[self.config.signal["action_space"]]( self.config, self.traffic_signals ) if self.config.network["control_scheme"] == "centralized": self.action_space = CentralizedActionSpace(self.action_space) self.observation_space = OBSERVATION_SPACES[ self.config.signal["observation_space"] ]( self.config, self.parsed_network, self.traffic_signals, self.simulator_backend, ) self.metrics = SIMULATOR_MODULES[self.simulator_backend]["metrics_parser"]( self.parsed_network, self.simulator, self.traffic_signals, ) self.reward_function = REWARD_FUNCTIONS[self.config.signal["reward_function"]]( self.metrics, self.traffic_signals ) def _init_traffic_signals(self): parsed_traffic_signals = self.parsed_network.traffic_signals self.traffic_signals = {} for ts_id, signal_config in parsed_traffic_signals.items(): self.traffic_signals[ts_id] = SIMULATOR_MODULES[self.simulator_backend][ "traffic_signal" ](ts_id, signal_config, self.simulator) self.traffic_signals[ts_id].update_stats(self.simulator.step_measurements) def _set_n_agents(self): self.n_agents = ( len(self.traffic_signals) if self.config.network["control_scheme"] == "decentralized" else 1 # centralized ) def _update_ts_stats(self): for ts_id in self.traffic_signals.keys(): self.traffic_signals[ts_id].update_stats(self.simulator.step_measurements)
[docs] def get_action_mask(self): return self.action_space.get_mask()
[docs] def get_action_size(self): return self.action_space.get_size()
[docs] def get_observations(self): if self.config.network["control_scheme"] == "decentralized": return self.observation_space.get_observations() else: return [np.concatenate(self.observation_space.get_observations()).tolist()]
[docs] def get_observation_size(self): if self.config.network["control_scheme"] == "decentralized": return self.observation_space.get_size() else: n_a = self.parsed_network.adjacency_matrix.shape[0] return self.observation_space.get_size() * n_a
[docs] def get_state(self): return self.observation_space.get_state()
[docs] def get_state_size(self): return self.observation_space.get_state_size()
[docs] def get_reward(self): return self.reward_function.get_global_reward()
[docs] def get_rewards(self): if self.config.network["control_scheme"] == "decentralized": return self.reward_function.get_local_reward() else: return [self.reward_function.get_global_reward()]
[docs] def get_env_info(self): stats = self.metrics.get_step_stats() stats.update({"episode_count": self.episode_count}) stats.update({"episode_limit": self.episode_limit}) if self.disrupted: stats.update({"n_domains": len(self.config.domain_classes)}) stats.update({"domain_class": self.config.current_domain_class}) return stats
[docs] def get_env_stats(self): stats = self.get_env_info() for v in self.simulator.step_measurements.values(): stats.update(v) return stats
[docs] def restart(self, reset=True): if self.episode_over: self.episode_count += 1 # self.observation_space.reset_dropped_lanes() if self.simulator.is_terminated: self.hour_count += 1 self.simulator.close_simulator() if reset: self.simulator.start_simulator() if self.domain_class is not None: self.config.set_domain_class(self.domain_class) self._init_traffic_signals() self._init_parsers()
[docs] def step(self, actions): self.action_space.apply(actions) self.simulator.simulator_step(n_steps=self.config.simulator["delta_time"]) self._update_ts_stats() return self.get_reward(), self.episode_over, self.get_env_info()