diff --git a/entrance_rando.py b/entrance_rando.py index 5ed2cd764596..492fff32e32d 100644 --- a/entrance_rando.py +++ b/entrance_rando.py @@ -52,13 +52,15 @@ def remove(self, entrance: Entrance) -> None: _coupled: bool _usable_exits: set[Entrance] - def __init__(self, rng: random.Random, coupled: bool, usable_exits: set[Entrance]): + def __init__(self, rng: random.Random, coupled: bool, usable_exits: set[Entrance], targets: Iterable[Entrance]): self.dead_ends = EntranceLookup.GroupLookup() self.others = EntranceLookup.GroupLookup() self._random = rng self._expands_graph_cache = {} self._coupled = coupled self._usable_exits = usable_exits + for target in targets: + self.add(target) def _can_expand_graph(self, entrance: Entrance) -> bool: """ @@ -121,7 +123,14 @@ def get_targets( dead_end: bool, preserve_group_order: bool ) -> Iterable[Entrance]: + """ + Gets available targets for the requested groups + :param groups: The groups to find targets for + :param dead_end: Whether to find dead ends. If false, finds non-dead-ends + :param preserve_group_order: Whether to preserve the group order in the returned iterable. If true, a sequence + like AAABBB is guaranteed. If false, groups can be interleaved, e.g. BAABAB. + """ lookup = self.dead_ends if dead_end else self.others if preserve_group_order: for group in groups: @@ -132,6 +141,27 @@ def get_targets( self._random.shuffle(ret) return ret + def find_target(self, name: str, group: int | None = None, dead_end: bool | None = None) -> Entrance | None: + """ + Finds a specific target in the lookup, if it is present. + + :param name: The name of the target + :param group: The target's group. Providing this will make the lookup faster, but can be omitted if it is not + known ahead of time for some reason. + :param dead_end: Whether the target is a dead end. Providing this will make the lookup faster, but can be + omitted if this is not known ahead of time (much more likely) + """ + if dead_end is None: + return (found + if (found := self.find_target(name, group, True)) + else self.find_target(name, group, False)) + lookup = self.dead_ends if dead_end else self.others + targets_to_check = lookup if group is None else lookup[group] + for target in targets_to_check: + if target.name == name: + return target + return None + def __len__(self): return len(self.dead_ends) + len(self.others) @@ -146,15 +176,18 @@ class ERPlacementState: """The world which is having its entrances randomized""" collection_state: CollectionState """The CollectionState backing the entrance randomization logic""" + entrance_lookup: EntranceLookup + """A lookup table of all unconnected ER targets""" coupled: bool """Whether entrance randomization is operating in coupled mode""" - def __init__(self, world: World, coupled: bool): + def __init__(self, world: World, entrance_lookup: EntranceLookup, coupled: bool): self.placements = [] self.pairings = [] self.world = world self.coupled = coupled self.collection_state = world.multiworld.get_all_state(False, True) + self.entrance_lookup = entrance_lookup @property def placed_regions(self) -> set[Region]: @@ -182,6 +215,7 @@ def _connect_one_way(self, source_exit: Entrance, target_entrance: Entrance) -> self.collection_state.stale[self.world.player] = True self.placements.append(source_exit) self.pairings.append((source_exit.name, target_entrance.name)) + self.entrance_lookup.remove(target_entrance) def test_speculative_connection(self, source_exit: Entrance, target_entrance: Entrance, usable_exits: set[Entrance]) -> bool: @@ -311,7 +345,7 @@ def randomize_entrances( preserve_group_order: bool = False, er_targets: list[Entrance] | None = None, exits: list[Entrance] | None = None, - on_connect: Callable[[ERPlacementState, list[Entrance]], None] | None = None + on_connect: Callable[[ERPlacementState, list[Entrance], list[Entrance]], bool | None] | None = None ) -> ERPlacementState: """ Randomizes Entrances for a single world in the multiworld. @@ -328,14 +362,18 @@ def randomize_entrances( :param exits: The list of exits (Entrance objects with no target region) to use for randomization. Remember to be deterministic! If not provided, automatically discovers all valid exits in your world. :param on_connect: A callback function which allows specifying side effects after a placement is completed - successfully and the underlying collection state has been updated. + successfully and the underlying collection state has been updated. The arguments are + 1. The ER state + 2. The exits placed in this placement pass + 3. The entrances they were connected to. + If you use on_connect to make additional placements, you are expected to return True to inform + GER that an additional sweep is needed. """ if not world.explicit_indirect_conditions: raise EntranceRandomizationError("Entrance randomization requires explicit indirect conditions in order " + "to correctly analyze whether dead end regions can be required in logic.") start_time = time.perf_counter() - er_state = ERPlacementState(world, coupled) # similar to fill, skip validity checks on entrances if the game is beatable on minimal accessibility perform_validity_check = True @@ -351,23 +389,25 @@ def randomize_entrances( # used when membership checks are needed on the exit list, e.g. speculative sweep exits_set = set(exits) - entrance_lookup = EntranceLookup(world.random, coupled, exits_set) - for entrance in er_targets: - entrance_lookup.add(entrance) + er_state = ERPlacementState( + world, + EntranceLookup(world.random, coupled, exits_set, er_targets), + coupled + ) # place the menu region and connected start region(s) er_state.collection_state.update_reachable_regions(world.player) def do_placement(source_exit: Entrance, target_entrance: Entrance) -> None: - placed_exits, removed_entrances = er_state.connect(source_exit, target_entrance) - # remove the placed targets from consideration - for entrance in removed_entrances: - entrance_lookup.remove(entrance) + placed_exits, paired_entrances = er_state.connect(source_exit, target_entrance) # propagate new connections er_state.collection_state.update_reachable_regions(world.player) er_state.collection_state.sweep_for_advancements() if on_connect: - on_connect(er_state, placed_exits) + change = on_connect(er_state, placed_exits, paired_entrances) + if change: + er_state.collection_state.update_reachable_regions(world.player) + er_state.collection_state.sweep_for_advancements() def needs_speculative_sweep(dead_end: bool, require_new_exits: bool, placeable_exits: list[Entrance]) -> bool: # speculative sweep is expensive. We currently only do it as a last resort, if we might cap off the graph @@ -388,12 +428,12 @@ def needs_speculative_sweep(dead_end: bool, require_new_exits: bool, placeable_e # check to see if we are proposing the last placement if not coupled: # in uncoupled, this check is easy as there will only be one target. - is_last_placement = len(entrance_lookup) == 1 + is_last_placement = len(er_state.entrance_lookup) == 1 else: # a bit harder, there may be 1 or 2 targets depending on if the exit to place is one way or two way. # if it is two way, we can safely assume that one of the targets is the logical pair of the exit. desired_target_count = 2 if placeable_exits[0].randomization_type == EntranceType.TWO_WAY else 1 - is_last_placement = len(entrance_lookup) == desired_target_count + is_last_placement = len(er_state.entrance_lookup) == desired_target_count # if it's not the last placement, we need a sweep return not is_last_placement @@ -402,7 +442,7 @@ def find_pairing(dead_end: bool, require_new_exits: bool) -> bool: placeable_exits = er_state.find_placeable_exits(perform_validity_check, exits) for source_exit in placeable_exits: target_groups = target_group_lookup[source_exit.randomization_group] - for target_entrance in entrance_lookup.get_targets(target_groups, dead_end, preserve_group_order): + for target_entrance in er_state.entrance_lookup.get_targets(target_groups, dead_end, preserve_group_order): # when requiring new exits, ideally we would like to make it so that every placement increases # (or keeps the same number of) reachable exits. The goal is to continue to expand the search space # so that we do not crash. In the interest of performance and bias reduction, generally, just checking @@ -420,7 +460,7 @@ def find_pairing(dead_end: bool, require_new_exits: bool) -> bool: else: # no source exits had any valid target so this stage is deadlocked. retries may be implemented if early # deadlocking is a frequent issue. - lookup = entrance_lookup.dead_ends if dead_end else entrance_lookup.others + lookup = er_state.entrance_lookup.dead_ends if dead_end else er_state.entrance_lookup.others # if we're in a stage where we're trying to get to new regions, we could also enter this # branch in a success state (when all regions of the preferred type have been placed, but there are still @@ -466,21 +506,21 @@ def find_pairing(dead_end: bool, require_new_exits: bool) -> bool: f"All unplaced exits: {unplaced_exits}") # stage 1 - try to place all the non-dead-end entrances - while entrance_lookup.others: + while er_state.entrance_lookup.others: if not find_pairing(dead_end=False, require_new_exits=True): break # stage 2 - try to place all the dead-end entrances - while entrance_lookup.dead_ends: + while er_state.entrance_lookup.dead_ends: if not find_pairing(dead_end=True, require_new_exits=True): break # stage 3 - all the regions should be placed at this point. We now need to connect dangling edges # stage 3a - get the rest of the dead ends (e.g. second entrances into already-visited regions) # doing this before the non-dead-ends is important to ensure there are enough connections to # go around - while entrance_lookup.dead_ends: + while er_state.entrance_lookup.dead_ends: find_pairing(dead_end=True, require_new_exits=False) # stage 3b - tie all the other loose ends connecting visited regions to each other - while entrance_lookup.others: + while er_state.entrance_lookup.others: find_pairing(dead_end=False, require_new_exits=False) running_time = time.perf_counter() - start_time diff --git a/test/general/test_entrance_rando.py b/test/general/test_entrance_rando.py index 56a059ecf2dd..4db4c49f4072 100644 --- a/test/general/test_entrance_rando.py +++ b/test/general/test_entrance_rando.py @@ -68,11 +68,9 @@ def test_shuffled_targets(self): exits_set = set([ex for region in multiworld.get_regions(1) for ex in region.exits if not ex.connected_region]) - lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set) er_targets = [entrance for region in multiworld.get_regions(1) for entrance in region.entrances if not entrance.parent_region] - for entrance in er_targets: - lookup.add(entrance) + lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets) retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM], False, False) @@ -91,11 +89,9 @@ def test_ordered_targets(self): exits_set = set([ex for region in multiworld.get_regions(1) for ex in region.exits if not ex.connected_region]) - lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set) er_targets = [entrance for region in multiworld.get_regions(1) for entrance in region.entrances if not entrance.parent_region] - for entrance in er_targets: - lookup.add(entrance) + lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets) retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM], False, True) @@ -111,12 +107,10 @@ def test_selective_dead_ends(self): for ex in region.exits if not ex.connected_region and ex.name != "region20_right" and ex.name != "region21_left"]) - lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set) er_targets = [entrance for region in multiworld.get_regions(1) for entrance in region.entrances if not entrance.parent_region and entrance.name != "region20_right" and entrance.name != "region21_left"] - for entrance in er_targets: - lookup.add(entrance) + lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets) # region 20 is the bottom left corner of the grid, and therefore only has a right entrance from region 21 # and a top entrance from region 15; since we've told lookup to ignore the right entrance from region 21, # the top entrance from region 15 should be considered a dead-end @@ -128,6 +122,56 @@ def test_selective_dead_ends(self): self.assertTrue(dead_end in lookup.dead_ends) self.assertEqual(len(lookup.dead_ends), 1) + def test_find_target_by_name(self): + """Tests that find_target can find the correct target by name only""" + multiworld = generate_test_multiworld() + generate_disconnected_region_grid(multiworld, 5) + exits_set = set([ex for region in multiworld.get_regions(1) + for ex in region.exits if not ex.connected_region]) + + er_targets = [entrance for region in multiworld.get_regions(1) + for entrance in region.entrances if not entrance.parent_region] + lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets) + + target = lookup.find_target("region0_right") + self.assertEqual(target.name, "region0_right") + self.assertEqual(target.randomization_group, ERTestGroups.RIGHT) + self.assertIsNone(lookup.find_target("nonexistant")) + + def test_find_target_by_name_and_group(self): + """Tests that find_target can find the correct target by name and group""" + multiworld = generate_test_multiworld() + generate_disconnected_region_grid(multiworld, 5) + exits_set = set([ex for region in multiworld.get_regions(1) + for ex in region.exits if not ex.connected_region]) + + er_targets = [entrance for region in multiworld.get_regions(1) + for entrance in region.entrances if not entrance.parent_region] + lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets) + + target = lookup.find_target("region0_right", ERTestGroups.RIGHT) + self.assertEqual(target.name, "region0_right") + self.assertEqual(target.randomization_group, ERTestGroups.RIGHT) + # wrong group + self.assertIsNone(lookup.find_target("region0_right", ERTestGroups.LEFT)) + + def test_find_target_by_name_and_group_and_category(self): + """Tests that find_target can find the correct target by name, group, and dead-endedness""" + multiworld = generate_test_multiworld() + generate_disconnected_region_grid(multiworld, 5) + exits_set = set([ex for region in multiworld.get_regions(1) + for ex in region.exits if not ex.connected_region]) + + er_targets = [entrance for region in multiworld.get_regions(1) + for entrance in region.entrances if not entrance.parent_region] + lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets) + + target = lookup.find_target("region0_right", ERTestGroups.RIGHT, False) + self.assertEqual(target.name, "region0_right") + self.assertEqual(target.randomization_group, ERTestGroups.RIGHT) + # wrong deadendedness + self.assertIsNone(lookup.find_target("region0_right", ERTestGroups.RIGHT, True)) + class TestBakeTargetGroupLookup(unittest.TestCase): def test_lookup_generation(self): multiworld = generate_test_multiworld() @@ -264,12 +308,12 @@ def test_coupled(self): generate_disconnected_region_grid(multiworld, 5) seen_placement_count = 0 - def verify_coupled(_: ERPlacementState, placed_entrances: list[Entrance]): + def verify_coupled(_: ERPlacementState, placed_exits: list[Entrance], placed_targets: list[Entrance]): nonlocal seen_placement_count - seen_placement_count += len(placed_entrances) - self.assertEqual(2, len(placed_entrances)) - self.assertEqual(placed_entrances[0].parent_region, placed_entrances[1].connected_region) - self.assertEqual(placed_entrances[1].parent_region, placed_entrances[0].connected_region) + seen_placement_count += len(placed_exits) + self.assertEqual(2, len(placed_exits)) + self.assertEqual(placed_exits[0].parent_region, placed_exits[1].connected_region) + self.assertEqual(placed_exits[1].parent_region, placed_exits[0].connected_region) result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_lookup, on_connect=verify_coupled) @@ -312,10 +356,10 @@ def test_uncoupled(self): generate_disconnected_region_grid(multiworld, 5) seen_placement_count = 0 - def verify_uncoupled(state: ERPlacementState, placed_entrances: list[Entrance]): + def verify_uncoupled(state: ERPlacementState, placed_exits: list[Entrance], placed_targets: list[Entrance]): nonlocal seen_placement_count - seen_placement_count += len(placed_entrances) - self.assertEqual(1, len(placed_entrances)) + seen_placement_count += len(placed_exits) + self.assertEqual(1, len(placed_exits)) result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup, on_connect=verify_uncoupled)