# Source code for flatland.utils.env_edit_utils

from numpy.random import RandomState

import flatland.envs.observations as obs
import flatland.envs.rail_generators as rg
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.line_generators import BaseLineGen
from flatland.envs.rail_env import RailEnv
from flatland.envs.timetable_utils import Line
from flatland.utils import editor


# Start and end all agents at the same place
class SchedGen2(BaseLineGen):
    """Line generator giving every agent the same start, direction and target.

    The three constructor arguments are stored unchanged and replicated once
    per agent when :meth:`generate` is called.
    """

    def __init__(self, rcStart, rcEnd, iDir):
        # Kept verbatim; they are only read back in generate().
        self.rcStart, self.rcEnd, self.iDir = rcStart, rcEnd, iDir

    def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict = None, num_resets: int = None,
                 np_random: RandomState = None) -> Line:
        """Build a ``Line`` in which all ``num_agents`` agents are identical.

        ``rail``, ``hints``, ``num_resets`` and ``np_random`` are accepted but
        not used by this generator.

        Returns:
            A ``Line`` whose position, direction and target lists each hold
            ``num_agents`` copies of the stored value, with every speed 1.0.
        """

        def repeated(value):
            # One list entry per agent, all referring to the same value.
            return [value] * num_agents

        return Line(
            agent_positions=repeated(self.rcStart),
            agent_directions=repeated(self.iDir),
            agent_targets=repeated(self.rcEnd),
            agent_speeds=repeated(1.0),
        )
# cycle through lists of start, end and direction
class SchedGen3(BaseLineGen):
    """Line generator that cycles through lists of starts, targets and directions.

    Agent ``i`` receives element ``i % len(lst)`` of each list, so the lists
    may be shorter than the number of agents and need not share a length.
    """

    def __init__(self, lrcStarts, lrcTargs, liDirs):
        # The lists are stored by reference, not copied.
        self.lrcStarts = lrcStarts
        self.lrcTargs = lrcTargs
        self.liDirs = liDirs

    def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict = None, num_resets: int = None,
                 np_random: RandomState = None) -> Line:
        """Build a ``Line`` by wrapping around each stored list.

        ``rail``, ``hints``, ``num_resets`` and ``np_random`` are accepted but
        not used by this generator.

        Returns:
            A ``Line`` with ``num_agents`` positions, directions and targets
            and every speed set to 1.0.

        Raises:
            ZeroDivisionError: if any stored list is empty while
                ``num_agents`` is greater than zero (modulo by zero length).
        """

        def cycled(values):
            # Wrap around with modulo so short lists are reused from the start.
            return [values[idx % len(values)] for idx in range(num_agents)]

        return Line(
            agent_positions=cycled(self.lrcStarts),
            agent_directions=cycled(self.liDirs),
            agent_targets=cycled(self.lrcTargs),
            agent_speeds=[1.0] * num_agents,
        )
def makeEnv(nAg=2, width=20, height=10, oSG=None):
    """Create an empty-rail ``RailEnv`` together with an ``EditorModel`` for it.

    Args:
        nAg: number of agents passed to ``RailEnv``.
        width: grid width passed to ``RailEnv``.
        height: grid height passed to ``RailEnv``.
        oSG: line generator handed to ``RailEnv`` as ``line_generator``.
            NOTE(review): the default ``None`` is forwarded as-is; confirm
            how ``RailEnv`` treats a missing line generator.

    Returns:
        Tuple ``(env, envModel)``.
    """
    empty_rails = rg.empty_rail_generator()
    tree_obs = obs.TreeObsForRailEnv(max_depth=1)
    env = RailEnv(
        width=width,
        height=height,
        rail_generator=empty_rails,
        number_of_agents=nAg,
        line_generator=oSG,
        obs_builder_object=tree_obs,
    )
    # The editor model is attached before the environment is reset.
    envModel = editor.EditorModel(env)
    env.reset()
    return env, envModel
def makeEnv2(nAg=2, shape=(20, 10), llrcPaths=None, lrcStarts=None, lrcTargs=None, liDirs=None,
             remove_agents_at_target=True):
    """Create a ``RailEnv`` with rails laid along the given paths, plus its editor model.

    Args:
        nAg: number of agents passed to ``RailEnv``.
        shape: ``shape[0]`` is used as the width and ``shape[1]`` as the height.
        llrcPaths: list of paths, each a list of cells; every path is passed
            through ``envModel.interpolate_path`` and the result is applied
            with ``envModel.mod_rail_cell_seq``. ``None`` means no paths.
        lrcStarts: list of agent start cells given to ``SchedGen3``.
        lrcTargs: list of agent target cells given to ``SchedGen3``.
        liDirs: list of initial agent directions given to ``SchedGen3``.
        remove_agents_at_target: forwarded unchanged to ``RailEnv``.

    Returns:
        Tuple ``(env, envModel)``.

    The list parameters default to ``None`` rather than ``[]`` so that no
    list object is shared between calls (``SchedGen3`` keeps a reference to
    the lists it is given). Omitting an argument still behaves as an empty list.
    """
    # Replace the None sentinels with a fresh list for each call.
    llrcPaths = [] if llrcPaths is None else llrcPaths
    lrcStarts = [] if lrcStarts is None else lrcStarts
    lrcTargs = [] if lrcTargs is None else lrcTargs
    liDirs = [] if liDirs is None else liDirs

    oSG = SchedGen3(lrcStarts, lrcTargs, liDirs)
    env = RailEnv(
        width=shape[0],
        height=shape[1],
        rail_generator=rg.empty_rail_generator(),
        number_of_agents=nAg,
        line_generator=oSG,
        obs_builder_object=obs.TreeObsForRailEnv(max_depth=1),
        remove_agents_at_target=remove_agents_at_target,
        record_steps=True,
    )
    # The editor model is attached before the environment is reset.
    envModel = editor.EditorModel(env)
    env.reset()

    # Lay rails along each path after the reset.
    for lrcPath in llrcPaths:
        envModel.mod_rail_cell_seq(envModel.interpolate_path(lrcPath))

    return env, envModel
# Named layouts for test environments. Each entry holds keyword arguments for
# makeEnv2: rail paths, agent start cells, agent target cells and initial
# directions. Cells are (row, col) tuples.
ddEnvSpecs = {
    # Opposing stations with a single alternative path.
    "single_alternative": {
        "llrcPaths": [
            # Straight track across the top.
            [(1, 0), (1, 15)],
            # Alternative loop below it.
            [(1, 4), (1, 6), (3, 6), (3, 12), (1, 12), (1, 14)],
        ],
        "lrcStarts": [(1, 3), (1, 14)],
        "lrcTargs": [(1, 14), (1, 3)],
        "liDirs": [1, 3],
    },

    # Single spur, so one agent needs to wait.
    "single_spur": {
        "llrcPaths": [
            [(1, 0), (1, 15)],
            [(4, 0), (4, 6), (1, 6), (1, 8)],
        ],
        "lrcStarts": [(1, 3), (1, 14)],
        "lrcTargs": [(1, 14), (4, 2)],
        "liDirs": [1, 3],
    },

    # A spur merging into a longer track; both agents share one target.
    "merging_spurs": {
        "llrcPaths": [
            [(1, 0), (1, 15), (7, 15), (7, 0)],
            [(4, 0), (4, 6), (1, 6), (1, 8)],
            # NOTE: a further path was sketched in the original but left
            # disabled: [((1,14), (1,16), (7,16), )]
        ],
        "lrcStarts": [(1, 2), (4, 2)],
        "lrcTargs": [(7, 3)],
        "liDirs": [1],
    },

    # Concentric loops.
    "concentric_loops": {
        "llrcPaths": [
            [(1, 1), (1, 5), (8, 5), (8, 1), (1, 1), (1, 3)],
            [(1, 3), (1, 10), (8, 10), (8, 3)],
        ],
        "lrcStarts": [(1, 3)],
        "lrcTargs": [(2, 1)],
        "liDirs": [1],
    },

    # One big loop with two smaller alternative loops.
    "loop_with_loops": {
        "llrcPaths": [
            # Big outer loop: rows 1 and 8, columns 1 and 15.
            [(1, 1), (1, 15), (8, 15), (8, 1), (1, 1), (1, 3)],
            # Alternative 1.
            [(1, 3), (1, 5), (3, 5), (3, 10), (1, 10), (1, 12)],
            # Alternative 2.
            [(8, 3), (8, 5), (6, 5), (6, 10), (8, 10), (8, 12)],
        ],
        # (row, col) of the agent start cells.
        "lrcStarts": [(1, 3), (8, 3)],
        # (row, col) of the agent targets.
        "lrcTargs": [(8, 2), (1, 2)],
        # Initial directions.
        "liDirs": [1, 1],
    },
}
def makeTestEnv(sName="single_alternative", nAg=2, remove_agents_at_target=True):
    """Build one of the named test environments listed in ``ddEnvSpecs``.

    Args:
        sName: key into ``ddEnvSpecs`` selecting the layout.
        nAg: number of agents, forwarded to ``makeEnv2``.
        remove_agents_at_target: forwarded to ``makeEnv2``.

    Returns:
        Whatever ``makeEnv2`` returns for the selected specification.

    Raises:
        KeyError: if ``sName`` is not a key of ``ddEnvSpecs``.
    """
    # The spec's entries are passed to makeEnv2 as keyword arguments.
    return makeEnv2(nAg=nAg, remove_agents_at_target=remove_agents_at_target, **ddEnvSpecs[sName])
def getAgentState(env):
    """Return a snapshot of each agent's position and direction.

    Args:
        env: object exposing an iterable ``agents`` attribute whose items
            have ``position`` and ``direction`` attributes.

    Returns:
        Dict mapping the agent's index in ``env.agents`` to a tuple made of
        the unpacked ``position`` followed by ``direction``.

    NOTE(review): ``position`` is unpacked with ``*``, so an agent whose
    position is ``None`` raises ``TypeError``; confirm callers only use this
    while every agent has a position.
    """
    return {
        agent_index: (*agent.position, agent.direction)
        for agent_index, agent in enumerate(env.agents)
    }