"""
LDA — Collapsed Gibbs Sampling.
Efficient implementation for Digital Humanities corpora.
"""

import numpy as np
import re
import os
from collections import Counter

# Lower-case English stop words, built by whitespace-splitting the literal
# below.  The text repeats some entries (e.g. "could", "the", "that"); the
# frozenset collapses those duplicates, so the repeats are harmless.
# tokenize() in this module tests membership against this set.
# NOTE(review): tokenize() only emits runs of a-z, so the entries containing an
# apostrophe ("don't", "he'll", ...) can never equal one of its tokens.
# NOTE(review): "thorough" appears next to "through" in the second half of the
# list -- possibly a typo; confirm it is meant to be filtered.
STOP_WORDS = frozenset("""
a about above after again against all am an and any are aren't as at be because
been before being below between both but by can't cannot could couldn't did
didn't do does doesn't doing don't down during each few for from further get got
had hadn't has hasn't have haven't having he he'd he'll he's her here here's
hers herself him himself his how how's i i'd i'll i'm i've if in into is isn't
it it's its itself let's me more most mustn't my myself no nor not of off on
once only or other ought our ours ourselves out over own same shan't she she'd
she'll she's should shouldn't so some such than that that's the their theirs
them themselves then there there's these they they'd they'll they're they've
this those through to too under until up upon very was wasn't we we'd we'll
we're we've were weren't what what's when when's where where's which while who
who's whom why why's with won't would wouldn't you you'd you'll you're you've
your yours yourself yourselves also among another back became become besides
came come could de did done each even every first found get give go got great
however just know last let like long look made make many may might much must
never new now old one ones only onto part per put quite rather really said say
says second see seem seemed seeming seems several shall since so still such
take tell th than that the their them then there therefore these they thing
this thorough those though three through thus together too toward two under
upon us use used using various very via want was way well went were what whatever
when where whether which while will with within without work would yet
""".split())


def tokenize(text: str, min_len: int = 3) -> list[str]:
    """Return the content-bearing tokens of *text*, in order of appearance.

    The text is lower-cased and every maximal run of the letters a-z becomes
    a candidate token.  A candidate is kept only if it has at least *min_len*
    characters and is not listed in ``STOP_WORDS``.

    Because only a-z runs are matched, digits, punctuation and non-ASCII
    letters act as separators; an apostrophe therefore splits a contraction
    into separate candidates.
    """
    kept = []
    for match in re.finditer(r'[a-z]+', text.lower()):
        word = match.group()
        if len(word) < min_len:
            continue
        if word in STOP_WORDS:
            continue
        kept.append(word)
    return kept


class LDA:
    """Latent Dirichlet Allocation fitted with collapsed Gibbs sampling.

    Parameters:
        n_topics: number of topics (stored as ``K``).
        alpha: constant added to every document-topic count.
        beta: constant added to every topic-word count.
        n_iter: number of full Gibbs sweeps performed by ``fit``.
        seed: seed for the model's private ``numpy.random.RandomState``.
        min_df: minimum number of documents a word must occur in to enter
            the vocabulary.  Defaults to 2, the previously fixed cutoff.

    Attributes populated by ``fit``:
        word2id / id2word: vocabulary mappings (words sorted alphabetically).
        V, D: vocabulary size and number of documents.
        docs: per-document int32 arrays of in-vocabulary word ids.
        z: per-document int32 arrays of the current topic of each token.
        n_dk: (D, K) token counts per document and topic.
        n_kw: (K, V) token counts per topic and word.
        n_k: (K,) token counts per topic.
    """

    def __init__(self, n_topics=10, alpha=0.1, beta=0.01, n_iter=300, seed=42,
                 min_df=2):
        self.K = n_topics
        self.alpha = alpha
        self.beta = beta
        self.n_iter = n_iter
        self.min_df = min_df
        self.rng = np.random.RandomState(seed)

    def fit(self, documents, callback=None):
        """Run ``n_iter`` Gibbs sweeps over *documents* and return ``self``.

        *documents* is an iterable of token iterables.  *callback*, when
        given, is called as ``callback(it, n_iter)`` after every sweep whose
        zero-based index ``it`` is a multiple of 10, and once more as
        ``callback(n_iter, n_iter)`` when sampling has finished.
        """
        # The corpus is walked twice (vocabulary, then counts) and its length
        # is needed, so materialise it once; this lets callers pass iterators.
        documents = [list(doc) for doc in documents]
        self._build_vocab(documents)
        self._init_counts(documents)
        for it in range(self.n_iter):
            self._gibbs_sweep()
            if callback and it % 10 == 0:
                callback(it, self.n_iter)
        if callback:
            callback(self.n_iter, self.n_iter)
        return self

    def top_words(self, n=15):
        """Return, for each topic, its *n* most probable ``(word, prob)`` pairs.

        The result is a list of ``K`` lists ordered by decreasing probability;
        a list is shorter than *n* when the vocabulary has fewer than *n* words.
        """
        phi = self._topic_word_dist()
        result = []
        for k in range(self.K):
            idx = np.argsort(phi[k])[::-1][:n]
            result.append([(self.id2word[i], float(phi[k, i])) for i in idx])
        return result

    def doc_topic_dist(self):
        """Return the (D, K) matrix of smoothed, row-normalised topic shares."""
        theta = (self.n_dk.astype(np.float64) + self.alpha)
        theta /= theta.sum(axis=1, keepdims=True)
        return theta

    def _build_vocab(self, documents):
        """Build the vocabulary from words found in at least ``min_df`` documents."""
        doc_freq = Counter()
        for doc in documents:
            # set() so that a word counts once per document, however often
            # it is repeated inside that document.
            doc_freq.update(set(doc))
        vocab = sorted(w for w, c in doc_freq.items() if c >= self.min_df)
        self.word2id = {w: i for i, w in enumerate(vocab)}
        self.id2word = dict(enumerate(vocab))
        self.V = len(vocab)

    def _init_counts(self, documents):
        """Encode documents as id arrays and draw a random topic for every token."""
        K, V = self.K, self.V
        D = len(documents)
        self.D = D
        word2id = self.word2id
        # Out-of-vocabulary words are dropped, so a document may become empty.
        self.docs = [
            np.array([word2id[w] for w in doc if w in word2id], dtype=np.int32)
            for doc in documents
        ]
        self.z = []
        self.n_dk = np.zeros((D, K), dtype=np.int32)
        self.n_kw = np.zeros((K, V), dtype=np.int32)
        for d, doc_ids in enumerate(self.docs):
            topics = self.rng.randint(0, K, size=len(doc_ids)).astype(np.int32)
            self.z.append(topics)
            self.n_dk[d] = np.bincount(topics, minlength=K)
            # add.at accumulates repeated (topic, word) pairs correctly,
            # which plain fancy-index "+=" would not.
            np.add.at(self.n_kw, (topics, doc_ids), 1)
        # Every token is counted exactly once in n_kw, so topic totals are
        # its row sums.
        self.n_k = self.n_kw.sum(axis=1).astype(np.int32)

    def _gibbs_sweep(self):
        """Resample the topic of every token once, updating counts in place."""
        K = self.K
        alpha, beta = self.alpha, self.beta
        beta_V = beta * self.V
        n_dk, n_kw, n_k = self.n_dk, self.n_kw, self.n_k
        choice = self.rng.choice
        for d in range(self.D):
            doc_ids = self.docs[d]
            topics = self.z[d]
            n_d = n_dk[d]  # view: in-place updates reach self.n_dk
            for i in range(len(doc_ids)):
                w = doc_ids[i]
                k_old = topics[i]
                # Take the current token out of the counts before computing
                # the distribution it is resampled from.
                n_d[k_old] -= 1
                n_kw[k_old, w] -= 1
                n_k[k_old] -= 1
                p = (n_d + alpha) * ((n_kw[:, w] + beta) / (n_k + beta_V))
                p /= p.sum()
                k_new = choice(K, p=p)
                topics[i] = k_new
                n_d[k_new] += 1
                n_kw[k_new, w] += 1
                n_k[k_new] += 1

    def _topic_word_dist(self):
        """Return the (K, V) matrix of smoothed, row-normalised word shares."""
        phi = (self.n_kw.astype(np.float64) + self.beta)
        phi /= phi.sum(axis=1, keepdims=True)
        return phi


def load_corpus(folder_path):
    """Load the ``.txt`` files of a folder as a tokenised corpus.

    Directory entries are visited in sorted name order.  Only regular files
    whose name ends in ``.txt`` (case-insensitive) are read; they are decoded
    as UTF-8 with undecodable bytes replaced.  Files that yield no tokens are
    left out of all three result lists.

    Returns:
        ``(filenames, tokenised_docs, raw_texts)`` -- three parallel lists.
    """
    filenames, docs, raw_texts = [], [], []
    for fn in sorted(os.listdir(folder_path)):
        if not fn.lower().endswith('.txt'):
            continue
        path = os.path.join(folder_path, fn)
        if not os.path.isfile(path):
            # A directory that happens to be named "*.txt" would make
            # open() raise; skip anything that is not a regular file.
            continue
        with open(path, 'r', encoding='utf-8', errors='replace') as f:
            text = f.read()
        tokens = tokenize(text)
        if tokens:
            filenames.append(fn)
            docs.append(tokens)
            raw_texts.append(text)
    return filenames, docs, raw_texts
