from sentence_transformers import CrossEncoder  # type: ignore
from sentence_transformers import SentenceTransformer  # type: ignore
from transformers import AutoTokenizer  # type: ignore
from transformers import TFDistilBertForSequenceClassification  # type: ignore

from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import INTENT_MODEL_VERSION
from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
from danswer.configs.model_configs import SKIP_RERANKING


# Module-level singletons: each model is loaded lazily on first use and cached
# so subsequent calls reuse the same in-memory instance.
_TOKENIZER: AutoTokenizer | None = None
_EMBED_MODEL: SentenceTransformer | None = None
_RERANK_MODELS: list[CrossEncoder] | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: TFDistilBertForSequenceClassification | None = None


def get_default_tokenizer() -> AutoTokenizer:
    """Loads the tokenizer matching the document encoder model, caching it globally."""
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = AutoTokenizer.from_pretrained(DOCUMENT_ENCODER_MODEL)
    return _TOKENIZER


def get_default_embedding_model() -> SentenceTransformer:
    """Loads the bi-encoder used for embedding documents and queries, caching it globally."""
    global _EMBED_MODEL
    if _EMBED_MODEL is None:
        _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
        _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
    return _EMBED_MODEL


def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
    """Loads the cross-encoder ensemble used for reranking, caching it globally."""
    global _RERANK_MODELS
    if _RERANK_MODELS is None:
        _RERANK_MODELS = [
            CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
        ]
        for model in _RERANK_MODELS:
            model.max_length = CROSS_EMBED_CONTEXT_SIZE
    return _RERANK_MODELS


def get_default_intent_model_tokenizer() -> AutoTokenizer:
    """Loads the tokenizer matching the intent classification model, caching it globally."""
    global _INTENT_TOKENIZER
    if _INTENT_TOKENIZER is None:
        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(INTENT_MODEL_VERSION)
    return _INTENT_TOKENIZER


def get_default_intent_model() -> TFDistilBertForSequenceClassification:
    """Loads the DistilBERT query-intent classifier, caching it globally."""
    global _INTENT_MODEL
    if _INTENT_MODEL is None:
        _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
            INTENT_MODEL_VERSION
        )
        # Truncation is enforced by the tokenizer; this attribute just records
        # the intended maximum query length.
        _INTENT_MODEL.max_seq_length = QUERY_MAX_CONTEXT_SIZE
    return _INTENT_MODEL


def warm_up_models(
    indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
) -> None:
    """Loads and runs each model once up front so the first real request does
    not pay the cost of model initialization."""
    warm_up_str = "Danswer is amazing"
    get_default_tokenizer()(warm_up_str)
    get_default_embedding_model().encode(warm_up_str)

    if indexer_only:
        return

    if not skip_cross_encoders:
        for cross_encoder in get_default_reranking_model_ensemble():
            cross_encoder.predict((warm_up_str, warm_up_str))

    intent_tokenizer = get_default_intent_model_tokenizer()
    inputs = intent_tokenizer(
        warm_up_str, return_tensors="tf", truncation=True, padding=True
    )
    get_default_intent_model()(inputs)
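

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module's API):
    # warm up just the indexing models, then embed a sample query. The sample
    # text and the shape printout are arbitrary choices for demonstration.
    warm_up_models(indexer_only=True)
    sample_embedding = get_default_embedding_model().encode("What can Danswer do?")
    print(sample_embedding.shape)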