"""
Device detection for PyTorch backends.
Prefers CUDA, then MPS (Apple Silicon), and falls back to CPU.
"""

import os


def get_best_device() -> str:
    """Return the preferred torch device string.

    The order of preference is ``"cuda"``, then ``"mps"``, then ``"cpu"``.
    When MPS is chosen, ``PYTORCH_ENABLE_MPS_FALLBACK`` is set to ``"1"`` in
    ``os.environ`` unless it is already present.
    """
    import torch

    if torch.cuda.is_available():
        return "cuda"

    # Older torch builds have no ``torch.backends.mps`` attribute at all.
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if not has_mps:
        return "cpu"

    # NOTE(review): PyTorch may read PYTORCH_ENABLE_MPS_FALLBACK only when torch
    # is first loaded; since torch is already imported at this point, this may
    # have no effect for the current process — confirm, and consider setting it
    # before the first torch import.
    os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
    return "mps"


def get_device_info() -> dict:
    """Summarise the compute device this process would run on.

    Returns:
        A dict with the keys ``device`` (as chosen by ``get_best_device``),
        ``device_name``, ``cuda_available``, ``mps_available``,
        ``supports_quantization`` and ``recommended_dtype``. When the device
        is CUDA, ``vram_gb`` is added too: the total memory of device 0 in
        GiB, rounded to one decimal. If torch cannot be imported, a fixed
        CPU-only description is returned instead.
    """
    try:
        import torch
    except ImportError:
        return {"device": "cpu", "device_name": "CPU (torch not installed)",
                "cuda_available": False, "mps_available": False,
                "supports_quantization": False, "recommended_dtype": "float32"}
    device = get_best_device()
    info = {
        "device": device,
        "cuda_available": torch.cuda.is_available(),
        "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
        "device_name": "", "supports_quantization": False, "recommended_dtype": "float32",
    }
    if device == "cuda":
        props = torch.cuda.get_device_properties(0)
        info["device_name"] = props.name
        # The properties object exposes ``total_memory`` (in bytes); there is
        # no ``total_mem`` attribute, so the old lookup raised AttributeError.
        info["vram_gb"] = round(props.total_memory / (1024**3), 1)
        info["supports_quantization"] = True
        # Native bfloat16 arithmetic needs compute capability >= 8.0 (Ampere
        # on NVIDIA); recommend float16 on older GPUs instead.
        info["recommended_dtype"] = "bfloat16" if props.major >= 8 else "float16"
    elif device == "mps":
        info["device_name"] = "Apple Silicon (MPS)"
        info["recommended_dtype"] = "float16"
    else:
        info["device_name"] = "CPU"
    return info


def get_torch_dtype(dtype_str: str):
    """Translate a dtype name into the matching ``torch.dtype``.

    Recognised names are ``"float32"``, ``"float16"`` and ``"bfloat16"``.
    Any other value falls back to ``torch.float32`` without raising.
    """
    import torch

    by_name = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    if dtype_str in by_name:
        return by_name[dtype_str]
    return torch.float32
