package jlama;

import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.jlama.JlamaChatModel;
import dev.langchain4j.service.AiServices;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

/**
 * Background worker that repeatedly loads batches of article summaries, asks a local
 * Jlama-backed chat model to extract keywords for each article, records them in a
 * {@link KeywordContainer}, and then lets the model merge similar keywords through a tool.
 *
 * <p>The worker loops until its thread is interrupted. Progress is reported through the
 * supplied {@link StatusMessageSink}.
 */
public class LLMWorker extends Thread {
    /** Number of articles requested from the loader per iteration. */
    public static final int ARTICLE_BATCH_SIZE = 5;
    /** Maximum number of keyword merges applied per consolidation round. */
    private static final int MAX_CONSOLIDATION_CALL_COUNT = 3;
    private static final int MAX_TOKENS = 2048;
    public static final float TEMPERATURE = 0.3f;
    /** Maximum number of keywords sent to the model in one consolidation request. */
    private static final long MAX_CONSOLIDATED_KEYWORD_COUNT = 200;
    /** Pause before retrying after a batch failed with an {@link IOException}. */
    private static final long RETRY_DELAY_MILLIS = 1000;

    private final ArticleLoader articleLoader = new ArticleLoader();
    // Field initializers run in declaration order, so `model` exists before createService() reads it.
    private final ChatModel model = JlamaChatModel.builder()
            .modelCachePath(ModelInstaller.MODEL_CACHE_PATH)
            .modelName(ModelInstaller.MODEL_NAME)
            .temperature(TEMPERATURE)
            .maxTokens(MAX_TOKENS)
            .build();
    private final KeywordService keywordService = createService(KeywordService.class, null);
    private final KeywordConsolidationService keywordConsolidationService = createService(KeywordConsolidationService.class, new KeywordConsolidationTools());
    private final KeywordContainer keywordContainer;
    private final StatusMessageSink statusMessageSink;

    // Merges requested in the current consolidation round; reset at the start of each round.
    private int consolidationCallCount;

    /**
     * Creates the worker. The chat model and AI services are built during construction.
     *
     * @param keywordContainer  receives extracted keywords and performs merges; must not be null
     * @param statusMessageSink receives human-readable progress messages; must not be null
     * @throws NullPointerException if either argument is null
     */
    public LLMWorker(KeywordContainer keywordContainer, StatusMessageSink statusMessageSink) {
        this.keywordContainer = Objects.requireNonNull(keywordContainer, "keywordContainer");
        this.statusMessageSink = Objects.requireNonNull(statusMessageSink, "statusMessageSink");
    }

    /**
     * Builds an AI service proxy for {@code serviceClass} backed by {@link #model}.
     *
     * @param tools object whose {@code @Tool} methods the model may call, or {@code null} for none
     */
    private <T> T createService(Class<T> serviceClass, Object tools) {
        AiServices<T> aiServices = AiServices.builder(serviceClass)
                .chatModel(model)
                .systemMessage("You are a keyword service for academic papers.");
        if (tools != null) {
            aiServices.tools(tools);
        }
        return aiServices.build();
    }

    /**
     * Processes article batches until the thread is interrupted. A batch that fails with an
     * {@link IOException} is retried (same offset) after a short pause.
     */
    @Override
    public void run() {
        int start = 0;
        while (!Thread.currentThread().isInterrupted()) {
            try {
                processBatch(start);
                start += ARTICLE_BATCH_SIZE;
            } catch (IOException e) {
                e.printStackTrace();
                try {
                    Thread.sleep(RETRY_DELAY_MILLIS);
                } catch (InterruptedException interrupted) {
                    // Restore the flag so the loop condition sees it and the worker stops.
                    Thread.currentThread().interrupt();
                }
            }
        }
    }

    /** Loads one batch at offset {@code start}, records keywords for each article, then consolidates. */
    private void processBatch(int start) throws IOException {
        statusMessageSink.updateStatusMessage("Loading " + ARTICLE_BATCH_SIZE + " article summaries from arxiv.org ...");
        List<Article> articles = articleLoader.loadArticles(start, ARTICLE_BATCH_SIZE);
        for (Article article : articles) {
            List<Keyword> keywords = generateKeywords(article.title(), article.summary());
            for (Keyword keyword : keywords) {
                keywordContainer.addArticle(keyword, article);
            }
        }
        consolidateKeywords();
    }

    /** Asks the model for keywords of one article and parses its textual reply. */
    private List<Keyword> generateKeywords(String title, String summary) throws IOException {
        statusMessageSink.updateStatusMessage("LLM is analyzing \"" + title + "\"");
        String text = keywordService.getKeywords(title, summary);
        return KeywordResponseParser.parseKeywords(text, articleLoader);
    }

    /**
     * Sends up to {@link #MAX_CONSOLIDATED_KEYWORD_COUNT} keywords to the model, which may call
     * {@link KeywordConsolidationTools#consolidateKeywords} to merge similar ones. Best-effort:
     * failures are reported on stderr and do not stop the worker.
     */
    private void consolidateKeywords() {
        statusMessageSink.updateStatusMessage("LLM is consolidating keywords ...");
        consolidationCallCount = 0;
        try {
            List<String> keywordTexts = keywordContainer.getAllKeywords().stream()
                    .limit(MAX_CONSOLIDATED_KEYWORD_COUNT)
                    .map(Keyword::text)
                    .toList();
            if (!keywordTexts.isEmpty()) {
                keywordConsolidationService.consolidateKeywords(keywordTexts);
            }
        } catch (Exception e) {
            // Deliberately broad: a failed consolidation round must not kill the worker.
            System.err.println("Tool calls failed: " + e.getMessage());
            e.printStackTrace();
        }
    }

    /** Tools exposed to the consolidation service; reads and updates the enclosing worker's state. */
    private final class KeywordConsolidationTools {
        /**
         * Merges two keywords, up to {@link #MAX_CONSOLIDATION_CALL_COUNT} times per round.
         *
         * @return a short status message that is passed back to the model, so it can tell
         *         whether the merge was applied or the per-round limit was reached
         */
        @SuppressWarnings("unused")
        @Tool("Consolidate two similar keywords")
        String consolidateKeywords(@P("First keyword") String firstKeyword, @P("Second keyword") String secondKeyword) throws IOException {
            if (consolidationCallCount >= MAX_CONSOLIDATION_CALL_COUNT) {
                return "Merge limit reached; no further keywords will be consolidated in this round.";
            }
            // Count the attempt before merging, so a failing merge still uses up a slot.
            consolidationCallCount++;
            keywordContainer.merge(firstKeyword, secondKeyword);
            return "Consolidated \"" + firstKeyword + "\" and \"" + secondKeyword + "\".";
        }
    }

    /** Receives human-readable progress updates from the worker thread. */
    @FunctionalInterface
    public interface StatusMessageSink {
        /** Called with a new status message; invoked on the worker thread. */
        void updateStatusMessage(String message);
    }
}
