"""
ForceAtlas2 — Force-directed graph layout with network statistics.
Based on Jacomy et al. (2014).
"""

import numpy as np
from collections import defaultdict


class ForceAtlas2:
    """Force-directed 2-D graph layout in the style of ForceAtlas2.

    Nodes repel each other proportionally to the product of their masses
    (mass = 1 + degree), edges pull their endpoints together linearly with
    distance, and gravity pulls every node toward the origin.  A global
    speed adapted from the swing/traction of the forces controls the step
    size of each iteration.
    """

    def __init__(self, gravity=1.0, scaling_ratio=2.0, strong_gravity=False,
                 edge_weight_influence=1.0, jitter_tolerance=1.0,
                 barnes_hut=True, barnes_hut_theta=1.2, n_iter=500, seed=42):
        """Store the layout parameters.

        Args:
            gravity: Gravity coefficient (pull toward the origin).
            scaling_ratio: Repulsion coefficient.
            strong_gravity: If true, gravity grows with distance from the
                origin instead of being constant.
            edge_weight_influence: Exponent applied to edge weights.
            jitter_tolerance: Multiplier on the adaptive global speed.
            barnes_hut: Stored for API compatibility; ``layout`` always
                computes exact all-pairs repulsion.
            barnes_hut_theta: Stored for API compatibility (see above).
            n_iter: Number of iterations run by ``layout``.
            seed: Seed of the random generator used for initial positions.
        """
        self.gravity = gravity
        self.scaling_ratio = scaling_ratio
        self.strong_gravity = strong_gravity
        self.edge_weight_influence = edge_weight_influence
        self.jitter_tolerance = jitter_tolerance
        # Previously accepted but silently discarded; kept so callers can
        # introspect the configuration they passed in.
        self.barnes_hut = barnes_hut
        self.barnes_hut_theta = barnes_hut_theta
        self.n_iter = n_iter
        self.rng = np.random.RandomState(seed)

    def layout(self, nodes, edges, weights=None, callback=None):
        """Compute 2-D positions for ``nodes``.

        Args:
            nodes: Sequence of hashable node identifiers.
            edges: Iterable of ``(source, target)`` pairs.  Edges with an
                endpoint that is not in ``nodes`` are ignored.
            weights: Optional sequence (list, tuple or NumPy array) indexed
                like ``edges``.  ``None`` or an empty sequence means every
                edge has weight 1.0.
            callback: Optional ``callback(iteration, n_iter)`` invoked every
                20 iterations and once more after the last iteration.

        Returns:
            Dict mapping each node to an ``(x, y)`` tuple of floats; an
            empty dict when ``nodes`` is empty.
        """
        N = len(nodes)
        if N == 0:
            return {}
        node_idx = {n: i for i, n in enumerate(nodes)}
        pos = self.rng.randn(N, 2).astype(np.float64) * 10.0

        # Explicit check: truth-testing a NumPy array raises ValueError.
        has_weights = weights is not None and len(weights) > 0

        # Build edge arrays for vectorized operations
        src_list = []
        tgt_list = []
        w_list = []
        for i, (s, t) in enumerate(edges):
            si, ti = node_idx.get(s), node_idx.get(t)
            if si is None or ti is None:
                continue
            src_list.append(si)
            tgt_list.append(ti)
            w_list.append(float(weights[i]) if has_weights else 1.0)
        src_idx = np.array(src_list, dtype=np.intp)
        tgt_idx = np.array(tgt_list, dtype=np.intp)
        edge_w = np.array(w_list, dtype=np.float64)

        if self.edge_weight_influence != 1.0:
            edge_w = edge_w ** self.edge_weight_influence

        # Mass = 1 + degree (each edge counts once for each endpoint)
        mass = (1.0
                + np.bincount(src_idx, minlength=N)
                + np.bincount(tgt_idx, minlength=N)).astype(np.float64)

        # Loop-invariant repulsion coefficient kr * m_i * m_j; the zero
        # diagonal removes self-repulsion.
        rep_coeff = self.scaling_ratio * mass[:, np.newaxis] * mass[np.newaxis, :]
        np.fill_diagonal(rep_coeff, 0.0)

        eps = 1e-9
        has_edges = src_idx.size > 0
        prev_forces = np.zeros((N, 2), dtype=np.float64)
        speed = 1.0
        speed_efficiency = 1.0

        for it in range(self.n_iter):
            forces = np.zeros((N, 2), dtype=np.float64)

            # ── Vectorized repulsion (all-pairs via broadcasting) ────
            diff = pos[:, np.newaxis, :] - pos[np.newaxis, :, :]  # N x N x 2
            dist_sq = (diff ** 2).sum(axis=2)  # N x N
            # Clamp so the diagonal and coincident nodes cannot produce
            # 0/0 = NaN; their diff is zero, so they contribute no force.
            np.maximum(dist_sq, eps * eps, out=dist_sq)
            # Magnitude kr*m_i*m_j/d along diff/d  ==  diff * coeff / d^2
            forces += (diff * (rep_coeff / dist_sq)[:, :, np.newaxis]).sum(axis=1)

            # ── Vectorized attraction (edges) ────────────────────────
            if has_edges:
                edge_diff = pos[tgt_idx] - pos[src_idx]  # E x 2
                edge_force = edge_diff * edge_w[:, np.newaxis]  # E x 2
                # add.at accumulates correctly when an index repeats
                np.add.at(forces, src_idx, edge_force)
                np.add.at(forces, tgt_idx, -edge_force)

            # ── Gravity ──────────────────────────────────────────────
            if self.strong_gravity:
                # Magnitude g * mass * distance: pull grows with distance.
                forces -= pos * (self.gravity * mass)[:, np.newaxis]
            else:
                # Magnitude g * mass: constant pull along the unit vector.
                dist_to_center = np.sqrt((pos ** 2).sum(axis=1)) + eps  # N
                forces -= pos * (self.gravity * mass / dist_to_center)[:, np.newaxis]

            # ── Adaptive speed ───────────────────────────────────────
            # Swing measures how much a node's force changed since the last
            # step (oscillation); traction measures how much of it persisted.
            swing = np.sqrt(((forces - prev_forces) ** 2).sum(axis=1))  # N
            traction = 0.5 * np.sqrt(((forces + prev_forces) ** 2).sum(axis=1))  # N
            global_swing = (swing * mass).sum()
            global_traction = (traction * mass).sum()

            # When either total is zero the ratio is undefined, so the
            # previous speed is kept.
            if global_swing > 0 and global_traction > 0:
                est = self.jitter_tolerance * global_traction / global_swing
                jt = min(np.sqrt(est), est)
                # Too much oscillation: damp the efficiency, down to a floor.
                if global_swing / global_traction > 0.5 and speed_efficiency > 0.05:
                    speed_efficiency *= 0.7
                speed = jt * speed_efficiency
            prev_forces = forces

            # ── Apply forces ─────────────────────────────────────────
            # Per-node damping keeps displacement bounded by ~sqrt(|F|).
            force_mag = np.sqrt((forces ** 2).sum(axis=1))  # N
            node_speed = speed / (1.0 + speed * np.sqrt(force_mag + 1e-9))  # N
            pos += forces * node_speed[:, np.newaxis]

            if callback and it % 20 == 0:
                callback(it, self.n_iter)

        if callback:
            callback(self.n_iter, self.n_iter)

        return {nodes[i]: (float(pos[i, 0]), float(pos[i, 1])) for i in range(N)}


def compute_network_stats(nodes, edges, weights=None):
    """Compute degree statistics for all nodes.

    Args:
        nodes: Iterable of node identifiers to report on.
        edges: Iterable of ``(source, target)`` pairs.
        weights: Optional sequence (list, tuple or NumPy array) indexed like
            ``edges``.  ``None`` or an empty sequence means weight 1.0.

    Returns:
        Dict mapping each node to ``{'degree': int, 'weighted_degree': x}``
        where ``x`` is the summed incident weight rounded to 3 decimals.
        A self-loop counts twice.  Nodes without edges get 0 for both.
    """
    # Explicit check: truth-testing a NumPy array raises ValueError.
    has_weights = weights is not None and len(weights) > 0

    degree = defaultdict(int)
    weighted_degree = defaultdict(float)

    for i, (s, t) in enumerate(edges):
        # float() keeps NumPy scalars out of the returned statistics.
        w = float(weights[i]) if has_weights else 1.0
        for endpoint in (s, t):
            degree[endpoint] += 1
            weighted_degree[endpoint] += w

    return {
        n: {
            'degree': degree.get(n, 0),
            'weighted_degree': round(weighted_degree.get(n, 0), 3),
        }
        for n in nodes
    }


def parse_edge_csv(filepath: str):
    """Parse an edge-list CSV and guess its source/target/weight columns.

    Header names are stripped of surrounding whitespace, and the rows are
    keyed by those stripped names so the detected column names can be used
    directly to index ``rows``.

    Args:
        filepath: Path of the CSV file (read as UTF-8, BOM tolerated,
            undecodable bytes replaced).

    Returns:
        Dict with keys ``'headers'`` (stripped header names), ``'rows'``
        (list of per-row dicts), and ``'detected_source'``,
        ``'detected_target'``, ``'detected_weight'`` (a header name or
        ``None``).  If several headers match one role, the last one wins.
    """
    import csv

    # newline='' is required by the csv module so that newlines embedded in
    # quoted fields are interpreted correctly.
    with open(filepath, 'r', encoding='utf-8-sig', errors='replace',
              newline='') as f:
        reader = csv.DictReader(f)
        raw_headers = reader.fieldnames
        headers = [h.strip() for h in raw_headers] if raw_headers else []
        if raw_headers:
            # Key the rows by the stripped names; otherwise a header such as
            # " source" would be reported as "source" but absent from rows.
            reader.fieldnames = headers
        rows = list(reader)

    # Auto-detect source/target/weight columns from common header names
    aliases = {
        'detected_source': ('source', 'src', 'from', 'node1', 'node_1'),
        'detected_target': ('target', 'tgt', 'to', 'dest', 'destination',
                            'node2', 'node_2'),
        'detected_weight': ('weight', 'value', 'strength', 'w'),
    }
    detected = dict.fromkeys(aliases)
    for h in headers:
        hl = h.lower().strip()
        for role, names in aliases.items():
            if hl in names:
                detected[role] = h
                break

    return {
        'headers': headers,
        'rows': rows,
        'detected_source': detected['detected_source'],
        'detected_target': detected['detected_target'],
        'detected_weight': detected['detected_weight'],
    }


def build_graph(rows, src_col, tgt_col, wt_col=None, headers=None):
    """Build graph from parsed rows with given column mappings.

    Args:
        rows: Iterable of dict-like rows (e.g. from ``csv.DictReader``).
        src_col: Key of the source-node column.
        tgt_col: Key of the target-node column.
        wt_col: Optional key of the edge-weight column.
        headers: Optional list of all column names; every column other than
            source/target/weight is collected as node metadata.

    Returns:
        ``(nodes, edges, weights, node_metadata)`` where ``nodes`` is the
        sorted list of node names, ``edges`` a list of ``(source, target)``
        tuples, ``weights`` the parallel list of float weights, and
        ``node_metadata`` maps a node to ``{column: first non-empty value
        seen on any of its edges}``.  Rows with an empty source or target
        are skipped.  Missing, unparsable or non-finite weights become 1.0.
    """
    import math

    def cell(row, col):
        # DictReader fills the cells of short rows with None, so a plain
        # .strip() on the looked-up value would raise AttributeError.
        value = row.get(col)
        return '' if value is None else str(value).strip()

    nodes_set = set()
    edges = []
    weights = []
    node_metadata = {}

    extra_cols = [h for h in (headers or []) if h not in (src_col, tgt_col, wt_col)]

    for row in rows:
        s = cell(row, src_col)
        t = cell(row, tgt_col)
        if not s or not t:
            continue

        w = 1.0
        raw_w = cell(row, wt_col) if wt_col else ''
        if raw_w:
            try:
                parsed = float(raw_w)
            except ValueError:
                parsed = 1.0
            # "nan"/"inf" parse as floats but would turn every position of
            # a layout into NaN, so they are treated like unparsable text.
            w = parsed if math.isfinite(parsed) else 1.0

        nodes_set.add(s)
        nodes_set.add(t)
        edges.append((s, t))
        weights.append(w)

        # Collect metadata: the first non-empty value per (node, column) wins
        for col in extra_cols:
            val = cell(row, col)
            if val:
                node_metadata.setdefault(s, {}).setdefault(col, val)
                node_metadata.setdefault(t, {}).setdefault(col, val)

    nodes = sorted(nodes_set)
    return nodes, edges, weights, node_metadata
