import collections
from typing import Dict, List, Set, Tuple
[docs]
class UnionFind:
    """
    Disjoint-set forest over the integers ``0 .. size - 1``.

    Lookups compress the traversed path and merges are ordered by rank, which
    keeps the trees shallow. Typical use is tracking connected components,
    e.g. in Kruskal's minimum spanning tree algorithm.

    Examples:
        >>> uf = UnionFind(10)
        >>> uf.union(1, 2)
        True
        >>> uf.union(2, 3)
        True
        >>> uf.find(1) == uf.find(3)
        True
        >>> uf.union(1, 3)
        False
    """

    def __init__(self, size: int):
        """
        Create ``size`` singleton sets.

        :param size: The number of elements in the set.
        :type size: int
        """
        # Every element starts as its own root with rank 0.
        self.rank: List[int] = [0] * size
        self.parent: List[int] = list(range(size))

    def find(self, idx: int) -> int:
        """
        Return the root of the set that contains ``idx``.

        Every element visited on the way to the root is re-pointed directly
        at the root (path compression).

        :param idx: The index to find.
        :type idx: int
        :return: The representative (root) of the set.
        :rtype: int
        """
        # First pass: walk up until an element that is its own parent.
        root = idx
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: hang every element of the walked path under the root.
        while self.parent[idx] != root:
            above = self.parent[idx]
            self.parent[idx] = root
            idx = above
        return root

    def union(self, idx1: int, idx2: int) -> bool:
        """
        Merge the sets containing ``idx1`` and ``idx2``.

        The root of lower rank is attached below the root of higher rank; on a
        tie the root of ``idx2`` goes below the root of ``idx1``, whose rank
        then grows by one.

        :param idx1: First index.
        :type idx1: int
        :param idx2: Second index.
        :type idx2: int
        :return: True if a union was performed, False if already in same set.
        :rtype: bool
        """
        keep, absorb = self.find(idx1), self.find(idx2)
        if keep == absorb:
            return False
        if self.rank[keep] < self.rank[absorb]:
            # Make ``keep`` the root of higher rank.
            keep, absorb = absorb, keep
        self.parent[absorb] = keep
        if self.rank[keep] == self.rank[absorb]:
            self.rank[keep] += 1
        return True
[docs]
def _grid_edges(height: int, width: int) -> List[Tuple[int, int, float]]:
    """Return the unit-cost horizontal and vertical grid edges in row-major order."""
    edges: List[Tuple[int, int, float]] = []
    for row_idx in range(height):
        for col_idx in range(width):
            node = row_idx * width + col_idx
            if col_idx + 1 < width:  # horizontal neighbour
                edges.append((node, node + 1, 1.0))
            if row_idx + 1 < height:  # vertical neighbour
                edges.append((node, node + width, 1.0))
    return edges


def _pairs_connected(
    n: int,
    forest: List[Tuple[int, int, float]],
    skip_idx: int,
    pair_indices: List[Tuple[int, int]],
) -> bool:
    """Tell whether every pair is still connected by ``forest`` without edge ``skip_idx``."""
    temp_uf = UnionFind(n)
    for idx, (node_u, node_v, _) in enumerate(forest):
        if idx != skip_idx:
            temp_uf.union(node_u, node_v)
    return all(temp_uf.find(src) == temp_uf.find(dst) for src, dst in pair_indices)


def steiner_forest_grid(
    height: int, width: int, pairs: List[Tuple[Tuple[int, int], Tuple[int, int]]]
) -> Tuple[List[Tuple[int, int, float]], float, Set[int], Set[int], Set[int]]:
    """
    Computes an approximate Steiner forest on a grid graph.

    The algorithm is based on the primal-dual approach for the Steiner
    network problem. It iteratively pays for edges until all terminal
    pairs are connected. A Steiner forest is a collection of Steiner trees
    connecting the specified pairs of terminals.

    The algorithm works by "growing" paths from terminals. Each active
    component (a connected component containing at least one terminal
    that needs to be connected to a terminal in another component)
    contributes to the cost of the edges leaving it. When the cost paid for
    an edge equals its weight, the edge is added to the forest.

    After the growing phase, a reverse-delete step is performed to
    remove redundant edges from the forest.

    Grid cells are given as ``(row, col)`` and mapped to the node index
    ``row * width + col``. Only horizontal and vertical edges, each of cost
    ``1.0``, are generated.

    :param height: Number of grid rows.
    :type height: int
    :param width: Number of grid columns.
    :type width: int
    :param pairs: ``((row, col), (row, col))`` pairs; the first cell of each
        pair is recorded as a source, the second as a terminal.
    :type pairs: List[Tuple[Tuple[int, int], Tuple[int, int]]]
    :return: ``(F_pruned, total_cost, sources, terminals, steiner_nodes)``:
        the kept edges as ``(u, v, cost)`` in the order they were added, the
        sum of their costs, the source node indices, the terminal node
        indices, and the nodes touched by a kept edge that are neither a
        source nor a terminal.
    :rtype: Tuple[List[Tuple[int, int, float]], float, Set[int], Set[int], Set[int]]
    :raises ValueError: If a pair references a cell outside the grid, or if
        no edge is left to pay for while some pair is still disconnected.

    >>> h = 2
    >>> w = 2
    >>> pairs = [((0, 0), (1, 1))]
    >>> F_pruned, total_cost, sources, terminals, steiner_nodes = steiner_forest_grid(h, w, pairs)
    >>> sorted(F_pruned)
    [(0, 1, 1.0), (1, 3, 1.0)]
    >>> total_cost
    2.0
    >>> sources
    {0}
    >>> terminals
    {3}
    >>> steiner_nodes
    {1}
    """
    tolerance = 1e-6  # slack for float rounding when testing whether an edge is paid
    n: int = height * width
    uf: UnionFind = UnionFind(n)
    sources: Set[int] = set()
    terminals: Set[int] = set()
    pair_dict: Dict[int, List[int]] = collections.defaultdict(list)
    pair_indices: List[Tuple[int, int]] = []
    for (sx, sy), (tx, ty) in pairs:
        # Reject cells outside the grid: they would otherwise alias another
        # node's index or fail later with an unrelated IndexError.
        for row, col in ((sx, sy), (tx, ty)):
            if not (0 <= row < height and 0 <= col < width):
                raise ValueError(
                    f"Cell ({row}, {col}) lies outside the {height}x{width} grid"
                )
        source_idx = sx * width + sy
        target_idx = tx * width + ty
        sources.add(source_idx)
        terminals.add(target_idx)
        pair_dict[source_idx].append(target_idx)
        pair_dict[target_idx].append(source_idx)
        pair_indices.append((source_idx, target_idx))
    all_term: Set[int] = sources | terminals

    edges = _grid_edges(height, width)
    paid: Dict[Tuple[int, int], float] = collections.defaultdict(float)
    F: List[Tuple[int, int, float]] = []  # edges in the order they were bought

    # Growing phase: one edge is bought per iteration.
    while True:
        term_root: Dict[int, int] = {term: uf.find(term) for term in all_term}
        # A component is active while it holds a terminal whose partner sits
        # in a different component; no active component means all pairs are
        # connected.
        active_comps: Set[int] = {
            root
            for term, root in term_root.items()
            if any(term_root[partner] != root for partner in pair_dict[term])
        }
        if not active_comps:
            break

        # Edges that join two different components, at least one of them
        # active; ``num`` counts the active endpoints paying for the edge.
        frontier: List[Tuple[int, int, float, Tuple[int, int], int]] = []
        for node_u, node_v, cost in edges:
            root_u = uf.find(node_u)
            root_v = uf.find(node_v)
            if root_u == root_v:
                continue
            num = (root_u in active_comps) + (root_v in active_comps)
            if num:
                key = (min(node_u, node_v), max(node_u, node_v))
                frontier.append((node_u, node_v, cost, key, num))

        # The edge that becomes fully paid first; ties go to the earliest
        # edge in generation order.
        min_delta = float("inf")
        chosen = None
        for node_u, node_v, cost, key, num in frontier:
            delta_e = (cost - paid[key]) / num
            if delta_e < min_delta:
                min_delta = delta_e
                chosen = (node_u, node_v, cost, key)
        if chosen is None:
            raise ValueError("Graph is not connected or cannot connect pairs")

        # Every frontier edge is paid for the same duration. Clamp at the
        # edge cost so rounding can never leave an edge over-paid.
        for _, _, cost, key, num in frontier:
            paid[key] = min(paid[key] + min_delta * num, cost)

        chosen_u, chosen_v, chosen_c, chosen_key = chosen
        if paid[chosen_key] >= chosen_c - tolerance:
            F.append((chosen_u, chosen_v, chosen_c))
            uf.union(chosen_u, chosen_v)

    # Reverse delete. F is a forest (an edge is only bought between two
    # different components), so whether an edge is needed does not depend on
    # the other removals and each edge can be tested against the full F.
    redundant = {
        idx
        for idx in range(len(F) - 1, -1, -1)
        if _pairs_connected(n, F, idx, pair_indices)
    }
    F_pruned: List[Tuple[int, int, float]] = [
        edge for idx, edge in enumerate(F) if idx not in redundant
    ]

    # Start from 0.0 so the result is a float even when no edge is kept.
    total_cost: float = sum((cost for _, _, cost in F_pruned), 0.0)

    used_nodes: Set[int] = {
        node for node_u, node_v, _ in F_pruned for node in (node_u, node_v)
    }
    steiner_nodes: Set[int] = used_nodes - all_term
    return F_pruned, total_cost, sources, terminals, steiner_nodes
[docs]
def generate_svg(
    height: int,
    width: int,
    F_pruned: List[Tuple[int, int, float]],
    sources: Set[int],
    terminals: Set[int],
    steiner_nodes: Set[int],
    cell_size: int,
    margin: int,
    filename: str,
) -> None:
    """
    Generate an SVG visualization of a Steiner forest on a grid.

    The image shows the grid lines, one labelled circle per node (sources in
    red, terminals in green, Steiner nodes in blue, all others in black) and
    the selected edges as semi-transparent orange lines. The canvas measures
    ``width * cell_size + 2 * margin`` by ``height * cell_size + 2 * margin``
    pixels. The result is written to ``filename`` and a confirmation line is
    printed.

    :param height: Grid height.
    :type height: int
    :param width: Grid width.
    :type width: int
    :param F_pruned: List of edges in the Steiner forest as (u, v, cost) tuples.
    :type F_pruned: List[Tuple[int, int, float]]
    :param sources: Set of source node indices.
    :type sources: Set[int]
    :param terminals: Set of terminal node indices.
    :type terminals: Set[int]
    :param steiner_nodes: Set of Steiner node indices.
    :type steiner_nodes: Set[int]
    :param cell_size: Size of each grid cell in pixels.
    :type cell_size: int
    :param margin: Margin around the grid in pixels.
    :type margin: int
    :param filename: Output SVG filename.
    :type filename: str
    """
    svg_width = width * cell_size + 2 * margin
    svg_height = height * cell_size + 2 * margin
    half_cell = cell_size / 2

    # The canvas must use the pixel dimensions, not the grid cell counts;
    # otherwise everything outside a few pixels is clipped.
    parts: List[str] = [
        f'<svg width="{svg_width}" height="{svg_height}" xmlns="http://www.w3.org/2000/svg">'
    ]

    # Horizontal grid lines
    for row_idx in range(height + 1):
        y_pos = margin + row_idx * cell_size
        parts.append(
            f'<line x1="{margin}" y1="{y_pos}" x2="{svg_width - margin}" y2="{y_pos}" stroke="gray" stroke-width="1"/>'
        )
    # Vertical grid lines
    for col_idx in range(width + 1):
        x_pos = margin + col_idx * cell_size
        parts.append(
            f'<line x1="{x_pos}" y1="{margin}" x2="{x_pos}" y2="{svg_height - margin}" stroke="gray" stroke-width="1"/>'
        )

    # Nodes: a circle at the centre of every cell, labelled with its index.
    for row_idx in range(height):
        for col_idx in range(width):
            cx = margin + col_idx * cell_size + half_cell
            cy = margin + row_idx * cell_size + half_cell
            node = row_idx * width + col_idx
            # A node that is both a source and a terminal is drawn as a source.
            if node in sources:
                radius, fill = 10, "red"
            elif node in terminals:
                radius, fill = 10, "green"
            elif node in steiner_nodes:
                radius, fill = 7, "blue"
            else:
                radius, fill = 5, "black"
            parts.append(f'<circle cx="{cx}" cy="{cy}" r="{radius}" fill="{fill}"/>')
            parts.append(
                f'<text x="{cx}" y="{cy + 4}" font-size="10" text-anchor="middle">{node}</text>'
            )

    # Selected edges, drawn between the centres of their endpoint cells.
    for node_u, node_v, _cost in F_pruned:
        ui, uj = divmod(node_u, width)
        vi, vj = divmod(node_v, width)
        ux = margin + uj * cell_size + half_cell
        uy = margin + ui * cell_size + half_cell
        vx = margin + vj * cell_size + half_cell
        vy = margin + vi * cell_size + half_cell
        parts.append(
            f'<line x1="{ux}" y1="{uy}" x2="{vx}" y2="{vy}" stroke="orange" stroke-width="5" opacity="0.5"/>'
        )
    parts.append("</svg>")

    # Write to SVG file
    with open(filename, "w", encoding="utf-8") as f:
        f.write("".join(parts))
    print(f"SVG file '{filename}' generated successfully.")
[docs]
def main() -> None:
    """Solve a fixed 8x8 example, print the result and write it to an SVG file."""
    # Demands as ((row, col), (row, col)); edit to try other instances.
    demands: List[Tuple[Tuple[int, int], Tuple[int, int]]] = [
        ((0, 0), (3, 2)),
        ((0, 0), (0, 5)),
        ((4, 4), (7, 5)),
        ((4, 4), (5, 7)),
        ((0, 1), (4, 1)),
    ]
    grid_rows: int = 8
    grid_cols: int = 8

    solution = steiner_forest_grid(grid_rows, grid_cols, demands)
    forest, cost, source_nodes, terminal_nodes, extra_nodes = solution

    print(f"Total cost: {cost}")
    print(f"Edges: {forest}")

    generate_svg(
        height=grid_rows,
        width=grid_cols,
        F_pruned=forest,
        sources=source_nodes,
        terminals=terminal_nodes,
        steiner_nodes=extra_nodes,
        cell_size=50,
        margin=20,
        filename="steiner_forest.svg",
    )
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()