Skip to content

Core: Move GER EntranceLookup onto ERPlacementState. Improve usefulness of on_connect #4904

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 61 additions & 21 deletions entrance_rando.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ def remove(self, entrance: Entrance) -> None:
_coupled: bool
_usable_exits: set[Entrance]

def __init__(self, rng: random.Random, coupled: bool, usable_exits: set[Entrance]):
def __init__(self, rng: random.Random, coupled: bool, usable_exits: set[Entrance], targets: Iterable[Entrance]):
self.dead_ends = EntranceLookup.GroupLookup()
self.others = EntranceLookup.GroupLookup()
self._random = rng
self._expands_graph_cache = {}
self._coupled = coupled
self._usable_exits = usable_exits
for target in targets:
self.add(target)

def _can_expand_graph(self, entrance: Entrance) -> bool:
"""
Expand Down Expand Up @@ -121,7 +123,14 @@ def get_targets(
dead_end: bool,
preserve_group_order: bool
) -> Iterable[Entrance]:
"""
Gets available targets for the requested groups

:param groups: The groups to find targets for
:param dead_end: Whether to find dead ends. If false, finds non-dead-ends
:param preserve_group_order: Whether to preserve the group order in the returned iterable. If true, a sequence
like AAABBB is guaranteed. If false, groups can be interleaved, e.g. BAABAB.
"""
lookup = self.dead_ends if dead_end else self.others
if preserve_group_order:
for group in groups:
Expand All @@ -132,6 +141,27 @@ def get_targets(
self._random.shuffle(ret)
return ret

def find_target(self, name: str, group: int | None = None, dead_end: bool | None = None) -> Entrance | None:
"""
Finds a specific target in the lookup, if it is present.

:param name: The name of the target
:param group: The target's group. Providing this will make the lookup faster, but can be omitted if it is not
known ahead of time for some reason.
:param dead_end: Whether the target is a dead end. Providing this will make the lookup faster, but can be
omitted if this is not known ahead of time (much more likely)
"""
if dead_end is None:
return (found
if (found := self.find_target(name, group, True))
else self.find_target(name, group, False))
lookup = self.dead_ends if dead_end else self.others
targets_to_check = lookup if group is None else lookup[group]
for target in targets_to_check:
if target.name == name:
return target
return None

def __len__(self):
return len(self.dead_ends) + len(self.others)

Expand All @@ -146,15 +176,18 @@ class ERPlacementState:
"""The world which is having its entrances randomized"""
collection_state: CollectionState
"""The CollectionState backing the entrance randomization logic"""
entrance_lookup: EntranceLookup
"""A lookup table of all unconnected ER targets"""
coupled: bool
"""Whether entrance randomization is operating in coupled mode"""

def __init__(self, world: World, coupled: bool):
def __init__(self, world: World, entrance_lookup: EntranceLookup, coupled: bool):
self.placements = []
self.pairings = []
self.world = world
self.coupled = coupled
self.collection_state = world.multiworld.get_all_state(False, True)
self.entrance_lookup = entrance_lookup

@property
def placed_regions(self) -> set[Region]:
Expand Down Expand Up @@ -182,6 +215,7 @@ def _connect_one_way(self, source_exit: Entrance, target_entrance: Entrance) ->
self.collection_state.stale[self.world.player] = True
self.placements.append(source_exit)
self.pairings.append((source_exit.name, target_entrance.name))
self.entrance_lookup.remove(target_entrance)

def test_speculative_connection(self, source_exit: Entrance, target_entrance: Entrance,
usable_exits: set[Entrance]) -> bool:
Expand Down Expand Up @@ -311,7 +345,7 @@ def randomize_entrances(
preserve_group_order: bool = False,
er_targets: list[Entrance] | None = None,
exits: list[Entrance] | None = None,
on_connect: Callable[[ERPlacementState, list[Entrance]], None] | None = None
on_connect: Callable[[ERPlacementState, list[Entrance], list[Entrance]], bool | None] | None = None
) -> ERPlacementState:
"""
Randomizes Entrances for a single world in the multiworld.
Expand All @@ -328,14 +362,18 @@ def randomize_entrances(
:param exits: The list of exits (Entrance objects with no target region) to use for randomization.
Remember to be deterministic! If not provided, automatically discovers all valid exits in your world.
:param on_connect: A callback function which allows specifying side effects after a placement is completed
successfully and the underlying collection state has been updated.
successfully and the underlying collection state has been updated. The arguments are
1. The ER state
2. The exits placed in this placement pass
3. The entrances they were connected to.
If you use on_connect to make additional placements, you are expected to return True to inform
GER that an additional sweep is needed.
"""
if not world.explicit_indirect_conditions:
raise EntranceRandomizationError("Entrance randomization requires explicit indirect conditions in order "
+ "to correctly analyze whether dead end regions can be required in logic.")

start_time = time.perf_counter()
er_state = ERPlacementState(world, coupled)
# similar to fill, skip validity checks on entrances if the game is beatable on minimal accessibility
perform_validity_check = True

Expand All @@ -351,23 +389,25 @@ def randomize_entrances(

# used when membership checks are needed on the exit list, e.g. speculative sweep
exits_set = set(exits)
entrance_lookup = EntranceLookup(world.random, coupled, exits_set)
for entrance in er_targets:
entrance_lookup.add(entrance)

er_state = ERPlacementState(
world,
EntranceLookup(world.random, coupled, exits_set, er_targets),
coupled
)
# place the menu region and connected start region(s)
er_state.collection_state.update_reachable_regions(world.player)

def do_placement(source_exit: Entrance, target_entrance: Entrance) -> None:
placed_exits, removed_entrances = er_state.connect(source_exit, target_entrance)
# remove the placed targets from consideration
for entrance in removed_entrances:
entrance_lookup.remove(entrance)
placed_exits, paired_entrances = er_state.connect(source_exit, target_entrance)
# propagate new connections
er_state.collection_state.update_reachable_regions(world.player)
er_state.collection_state.sweep_for_advancements()
if on_connect:
on_connect(er_state, placed_exits)
change = on_connect(er_state, placed_exits, paired_entrances)
if change:
er_state.collection_state.update_reachable_regions(world.player)
er_state.collection_state.sweep_for_advancements()

def needs_speculative_sweep(dead_end: bool, require_new_exits: bool, placeable_exits: list[Entrance]) -> bool:
# speculative sweep is expensive. We currently only do it as a last resort, if we might cap off the graph
Expand All @@ -388,12 +428,12 @@ def needs_speculative_sweep(dead_end: bool, require_new_exits: bool, placeable_e
# check to see if we are proposing the last placement
if not coupled:
# in uncoupled, this check is easy as there will only be one target.
is_last_placement = len(entrance_lookup) == 1
is_last_placement = len(er_state.entrance_lookup) == 1
else:
# a bit harder, there may be 1 or 2 targets depending on if the exit to place is one way or two way.
# if it is two way, we can safely assume that one of the targets is the logical pair of the exit.
desired_target_count = 2 if placeable_exits[0].randomization_type == EntranceType.TWO_WAY else 1
is_last_placement = len(entrance_lookup) == desired_target_count
is_last_placement = len(er_state.entrance_lookup) == desired_target_count
# if it's not the last placement, we need a sweep
return not is_last_placement

Expand All @@ -402,7 +442,7 @@ def find_pairing(dead_end: bool, require_new_exits: bool) -> bool:
placeable_exits = er_state.find_placeable_exits(perform_validity_check, exits)
for source_exit in placeable_exits:
target_groups = target_group_lookup[source_exit.randomization_group]
for target_entrance in entrance_lookup.get_targets(target_groups, dead_end, preserve_group_order):
for target_entrance in er_state.entrance_lookup.get_targets(target_groups, dead_end, preserve_group_order):
# when requiring new exits, ideally we would like to make it so that every placement increases
# (or keeps the same number of) reachable exits. The goal is to continue to expand the search space
# so that we do not crash. In the interest of performance and bias reduction, generally, just checking
Expand All @@ -420,7 +460,7 @@ def find_pairing(dead_end: bool, require_new_exits: bool) -> bool:
else:
# no source exits had any valid target so this stage is deadlocked. retries may be implemented if early
# deadlocking is a frequent issue.
lookup = entrance_lookup.dead_ends if dead_end else entrance_lookup.others
lookup = er_state.entrance_lookup.dead_ends if dead_end else er_state.entrance_lookup.others

# if we're in a stage where we're trying to get to new regions, we could also enter this
# branch in a success state (when all regions of the preferred type have been placed, but there are still
Expand Down Expand Up @@ -466,21 +506,21 @@ def find_pairing(dead_end: bool, require_new_exits: bool) -> bool:
f"All unplaced exits: {unplaced_exits}")

# stage 1 - try to place all the non-dead-end entrances
while entrance_lookup.others:
while er_state.entrance_lookup.others:
if not find_pairing(dead_end=False, require_new_exits=True):
break
# stage 2 - try to place all the dead-end entrances
while entrance_lookup.dead_ends:
while er_state.entrance_lookup.dead_ends:
if not find_pairing(dead_end=True, require_new_exits=True):
break
# stage 3 - all the regions should be placed at this point. We now need to connect dangling edges
# stage 3a - get the rest of the dead ends (e.g. second entrances into already-visited regions)
# doing this before the non-dead-ends is important to ensure there are enough connections to
# go around
while entrance_lookup.dead_ends:
while er_state.entrance_lookup.dead_ends:
find_pairing(dead_end=True, require_new_exits=False)
# stage 3b - tie all the other loose ends connecting visited regions to each other
while entrance_lookup.others:
while er_state.entrance_lookup.others:
find_pairing(dead_end=False, require_new_exits=False)

running_time = time.perf_counter() - start_time
Expand Down
78 changes: 61 additions & 17 deletions test/general/test_entrance_rando.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,9 @@ def test_shuffled_targets(self):
exits_set = set([ex for region in multiworld.get_regions(1)
for ex in region.exits if not ex.connected_region])

lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set)
er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region]
for entrance in er_targets:
lookup.add(entrance)
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)

retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM],
False, False)
Expand All @@ -91,11 +89,9 @@ def test_ordered_targets(self):
exits_set = set([ex for region in multiworld.get_regions(1)
for ex in region.exits if not ex.connected_region])

lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set)
er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region]
for entrance in er_targets:
lookup.add(entrance)
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)

retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM],
False, True)
Expand All @@ -111,12 +107,10 @@ def test_selective_dead_ends(self):
for ex in region.exits if not ex.connected_region
and ex.name != "region20_right" and ex.name != "region21_left"])

lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set)
er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region and
entrance.name != "region20_right" and entrance.name != "region21_left"]
for entrance in er_targets:
lookup.add(entrance)
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)
# region 20 is the bottom left corner of the grid, and therefore only has a right entrance from region 21
# and a top entrance from region 15; since we've told lookup to ignore the right entrance from region 21,
# the top entrance from region 15 should be considered a dead-end
Expand All @@ -128,6 +122,56 @@ def test_selective_dead_ends(self):
self.assertTrue(dead_end in lookup.dead_ends)
self.assertEqual(len(lookup.dead_ends), 1)

def test_find_target_by_name(self):
"""Tests that find_target can find the correct target by name only"""
multiworld = generate_test_multiworld()
generate_disconnected_region_grid(multiworld, 5)
exits_set = set([ex for region in multiworld.get_regions(1)
for ex in region.exits if not ex.connected_region])

er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region]
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)

target = lookup.find_target("region0_right")
self.assertEqual(target.name, "region0_right")
self.assertEqual(target.randomization_group, ERTestGroups.RIGHT)
self.assertIsNone(lookup.find_target("nonexistant"))

def test_find_target_by_name_and_group(self):
"""Tests that find_target can find the correct target by name and group"""
multiworld = generate_test_multiworld()
generate_disconnected_region_grid(multiworld, 5)
exits_set = set([ex for region in multiworld.get_regions(1)
for ex in region.exits if not ex.connected_region])

er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region]
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)

target = lookup.find_target("region0_right", ERTestGroups.RIGHT)
self.assertEqual(target.name, "region0_right")
self.assertEqual(target.randomization_group, ERTestGroups.RIGHT)
# wrong group
self.assertIsNone(lookup.find_target("region0_right", ERTestGroups.LEFT))

def test_find_target_by_name_and_group_and_category(self):
"""Tests that find_target can find the correct target by name, group, and dead-endedness"""
multiworld = generate_test_multiworld()
generate_disconnected_region_grid(multiworld, 5)
exits_set = set([ex for region in multiworld.get_regions(1)
for ex in region.exits if not ex.connected_region])

er_targets = [entrance for region in multiworld.get_regions(1)
for entrance in region.entrances if not entrance.parent_region]
lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True, usable_exits=exits_set, targets=er_targets)

target = lookup.find_target("region0_right", ERTestGroups.RIGHT, False)
self.assertEqual(target.name, "region0_right")
self.assertEqual(target.randomization_group, ERTestGroups.RIGHT)
# wrong deadendedness
self.assertIsNone(lookup.find_target("region0_right", ERTestGroups.RIGHT, True))

class TestBakeTargetGroupLookup(unittest.TestCase):
def test_lookup_generation(self):
multiworld = generate_test_multiworld()
Expand Down Expand Up @@ -264,12 +308,12 @@ def test_coupled(self):
generate_disconnected_region_grid(multiworld, 5)
seen_placement_count = 0

def verify_coupled(_: ERPlacementState, placed_entrances: list[Entrance]):
def verify_coupled(_: ERPlacementState, placed_exits: list[Entrance], placed_targets: list[Entrance]):
nonlocal seen_placement_count
seen_placement_count += len(placed_entrances)
self.assertEqual(2, len(placed_entrances))
self.assertEqual(placed_entrances[0].parent_region, placed_entrances[1].connected_region)
self.assertEqual(placed_entrances[1].parent_region, placed_entrances[0].connected_region)
seen_placement_count += len(placed_exits)
self.assertEqual(2, len(placed_exits))
self.assertEqual(placed_exits[0].parent_region, placed_exits[1].connected_region)
self.assertEqual(placed_exits[1].parent_region, placed_exits[0].connected_region)

result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_lookup,
on_connect=verify_coupled)
Expand Down Expand Up @@ -312,10 +356,10 @@ def test_uncoupled(self):
generate_disconnected_region_grid(multiworld, 5)
seen_placement_count = 0

def verify_uncoupled(state: ERPlacementState, placed_entrances: list[Entrance]):
def verify_uncoupled(state: ERPlacementState, placed_exits: list[Entrance], placed_targets: list[Entrance]):
nonlocal seen_placement_count
seen_placement_count += len(placed_entrances)
self.assertEqual(1, len(placed_entrances))
seen_placement_count += len(placed_exits)
self.assertEqual(1, len(placed_exits))

result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup,
on_connect=verify_uncoupled)
Expand Down
Loading