from collections import defaultdict def compute_distances(n, adj_list): # initialize distances with -1 for all nodes dist = [-1] * n # initialize distance of root node to 0 dist[0] = 0 # DFS to compute distances stack = [0] while stack: node = stack.pop() for neighbor in adj_list[node]: if dist[neighbor] == -1: dist[neighbor] = dist[node] + 1 stack.append(neighbor) return dist def solve_case(w, e, c, w_adj_list, e_adj_list, options): # compute distances in western and eastern trees w_dist = compute_distances(w, w_adj_list) e_dist = compute_distances(e, e_adj_list) # compute distances in complete map with each option dists = [] for a, b in options: # compute distance between a and b ab_dist = w_dist[a-1] + e_dist[b-1] + 1 # compute distance between all pairs of stations using w_dist, e_dist, and ab_dist total_dist = 0 for x in range(1, w+1): for y in range(1, e+1): if x <= w and y <= e: # both on western side or both on eastern side if x != a and