package jlama;

import dev.langchain4j.internal.Json;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Extracts keyword strings from a free-form LLM response and keeps the ones that the supplied
 * {@link ArticleLoader} reports enough results for.
 *
 * <p>Lines starting with a Markdown code fence are removed first. Several response shapes are then
 * tried in a fixed order, and the first parser that yields at least one keyword wins:
 * <ol>
 *   <li>a JSON array (its string elements),</li>
 *   <li>a bracketed, comma-separated list such as {@code [a, b]} or {@code ["a", "b"]},</li>
 *   <li>a JSON object (its string values),</li>
 *   <li>a numbered list ({@code 1. a}, {@code 2. b}, ...),</li>
 *   <li>a bulleted list ({@code - a} or {@code * a}).</li>
 * </ol>
 */
public class KeywordResponseParser {
    /**
     * Result-count threshold: a keyword is kept only when its count is strictly greater than this
     * value, or when the count is {@code -1}. NOTE(review): because the comparison is strict, a
     * keyword with exactly one result is rejected — confirm this is intended.
     */
    private static final int MIN_RESULT_COUNT = 1;

    /** Maximum number of keywords returned by {@link #parseKeywords}. */
    private static final int MAX_KEYWORDS = 3;

    /**
     * Parses keywords out of {@code response} and filters them by result count.
     *
     * <p>Each keyword is trimmed and unwrapped from Markdown emphasis. Empty keywords and duplicates
     * are dropped. The remaining keywords are looked up via
     * {@link ArticleLoader#getKeywordResultCount}. Because the stream is lazy, lookups stop once
     * {@link #MAX_KEYWORDS} keywords have been accepted.
     *
     * @param response      raw model output; {@code null} yields an empty list
     * @param articleLoader source of per-keyword result counts
     * @return at most {@link #MAX_KEYWORDS} keywords in response order (unmodifiable)
     */
    public static List<Keyword> parseKeywords(String response, ArticleLoader articleLoader) {
        if (response == null) {
            // Guard here: stripFormattingLines would otherwise throw before parseKeywordsRaw
            // ever gets to see the null.
            return List.of();
        }
        return parseKeywordsRaw(stripFormattingLines(response)).stream()
                // Trim first so emphasis markers surrounded by whitespace are still recognised.
                .map(String::trim)
                .map(KeywordResponseParser::stripKeywordFormatting)
                .map(String::trim)
                .filter(text -> !text.isEmpty())
                .distinct()
                .map(s -> new Keyword(s, articleLoader.getKeywordResultCount(s)))
                // -1 is accepted unconditionally; presumably it means "count unknown" — confirm
                // against ArticleLoader.
                .filter(k -> k.resultCount() == -1 || k.resultCount() > MIN_RESULT_COUNT)
                .limit(MAX_KEYWORDS)
                .toList();
    }

    /** Removes every line whose trimmed text starts with a Markdown code fence ({@code ```}). */
    private static String stripFormattingLines(String response) {
        return Arrays.stream(response.split("\\R"))
                .filter(line -> !line.trim().startsWith("```"))
                .collect(Collectors.joining("\n"));
    }

    /**
     * Tries each supported response shape in priority order and returns the first non-empty result.
     * Logs to stderr and returns an empty list when no shape matches.
     */
    private static List<String> parseKeywordsRaw(String response) {
        if (response == null) {
            return new ArrayList<>();
        }
        List<String> jsonKeywords = parseJsonArrayKeywords(response);
        if (!jsonKeywords.isEmpty()) {
            return jsonKeywords;
        }
        List<String> unquotedKeywords = parseUnquotedListKeywords(response);
        if (!unquotedKeywords.isEmpty()) {
            return unquotedKeywords;
        }
        List<String> jsonObjectKeywords = parseJsonObjectKeywords(response);
        if (!jsonObjectKeywords.isEmpty()) {
            return jsonObjectKeywords;
        }
        List<String> numberedKeywords = parseNumberedKeywords(response);
        if (!numberedKeywords.isEmpty()) {
            return numberedKeywords;
        }
        List<String> bulletKeywords = parseBulletKeywords(response);
        if (!bulletKeywords.isEmpty()) {
            return bulletKeywords;
        }
        System.err.println("No keywords found in response " + response);
        return new ArrayList<>();
    }

    /**
     * Parses {@code response} as a JSON array and returns its string elements. Non-string elements
     * are skipped.
     */
    private static List<String> parseJsonArrayKeywords(String response) {
        try {
            List<?> keywordList = Json.fromJson(response, List.class);
            if (keywordList == null) {
                return new ArrayList<>();
            }
            List<String> results = new ArrayList<>();
            for (Object keyword : keywordList) {
                if (keyword instanceof String text) {
                    results.add(text);
                }
            }
            return results;
        } catch (RuntimeException ignored) {
            // Not a JSON array: this is a best-effort probe, so fall through to the next parser.
            return new ArrayList<>();
        }
    }

    /**
     * Parses {@code response} as a JSON object and returns its string values. Non-string values
     * are skipped.
     */
    private static List<String> parseJsonObjectKeywords(String response) {
        try {
            Map<?, ?> jsonObject = Json.fromJson(response, Map.class);
            if (jsonObject == null || jsonObject.isEmpty()) {
                return new ArrayList<>();
            }
            List<String> results = new ArrayList<>();
            for (Object value : jsonObject.values()) {
                if (value instanceof String text) {
                    results.add(text);
                }
            }
            return results;
        } catch (RuntimeException ignored) {
            // Not a JSON object: this is a best-effort probe, so fall through to the next parser.
            return new ArrayList<>();
        }
    }

    /**
     * Parses a numbered list. Parsing starts at the first line beginning with {@code 1.} and
     * continues until the first non-blank line that does not start with {@code <digits>.}.
     * Trailing commas are removed and empty items are skipped.
     */
    private static List<String> parseNumberedKeywords(String response) {
        List<String> results = new ArrayList<>();
        String[] lines = response.split("\\R");
        int startIndex = -1;
        for (int i = 0; i < lines.length; i++) {
            if (lines[i].trim().startsWith("1.")) {
                startIndex = i;
                break;
            }
        }
        if (startIndex == -1) {
            return results;
        }
        for (int i = startIndex; i < lines.length; i++) {
            String line = lines[i].trim();
            if (line.isEmpty()) {
                continue;
            }
            int dot = line.indexOf('.');
            // Require a numeric prefix so prose after the list (e.g. "Hope this helps. Bye")
            // ends the list instead of being read as an item.
            if (dot <= 0 || !isAllDigits(line.substring(0, dot))) {
                break;
            }
            String keyword = stripTrailingComma(line.substring(dot + 1).trim());
            if (!keyword.isEmpty()) {
                results.add(keyword);
            }
        }
        return results;
    }

    /**
     * Parses the first {@code [...]} span as a comma-separated list. Items may be bare words or
     * wrapped in single or double quotes; the quotes are removed.
     */
    private static List<String> parseUnquotedListKeywords(String response) {
        int start = response.indexOf('[');
        if (start < 0) {
            return new ArrayList<>();
        }
        int end = response.indexOf(']', start);
        if (end < 0) {
            return new ArrayList<>();
        }
        String body = response.substring(start + 1, end);
        List<String> results = new ArrayList<>();
        for (String part : body.split(",")) {
            // A quoted list embedded in prose (e.g. `Keywords: ["a", "b"]`) is not valid JSON as a
            // whole, so it lands here; strip the quotes so they do not leak into the keyword.
            String keyword = stripMatchingQuotes(part.trim());
            if (!keyword.isEmpty()) {
                results.add(keyword);
            }
        }
        return results;
    }

    /**
     * Collects items from lines that start with {@code "- "} or {@code "* "} after trimming.
     * Trailing commas are removed.
     */
    private static List<String> parseBulletKeywords(String response) {
        List<String> results = new ArrayList<>();
        String[] lines = response.split("\\R");
        for (String line : lines) {
            String trimmedLine = line.trim();
            if (trimmedLine.startsWith("- ") || trimmedLine.startsWith("* ")) {
                String keyword = stripTrailingComma(trimmedLine.substring(2).trim());
                if (!keyword.isEmpty()) {
                    results.add(keyword);
                }
            }
        }
        return results;
    }

    /** Removes Markdown bold ({@code **x**}) or italic ({@code *x*}) markers wrapping the value. */
    private static String stripKeywordFormatting(String value) {
        if (value.startsWith("**") && value.endsWith("**") && value.length() > 4) {
            return value.substring(2, value.length() - 2).trim();
        } else if (value.startsWith("*") && value.endsWith("*") && value.length() > 2) {
            return value.substring(1, value.length() - 1).trim();
        } else {
            return value;
        }
    }

    /** Removes a single trailing comma, then trims the result. */
    private static String stripTrailingComma(String value) {
        if (value.endsWith(",")) {
            return value.substring(0, value.length() - 1).trim();
        }
        return value;
    }

    /** Removes one pair of matching surrounding single or double quotes, then trims the result. */
    private static String stripMatchingQuotes(String value) {
        if (value.length() >= 2) {
            char first = value.charAt(0);
            char last = value.charAt(value.length() - 1);
            if (first == last && (first == '"' || first == '\'')) {
                return value.substring(1, value.length() - 1).trim();
            }
        }
        return value;
    }

    /** Returns {@code true} when {@code value} is non-empty and consists only of digits. */
    private static boolean isAllDigits(String value) {
        return !value.isEmpty() && value.chars().allMatch(Character::isDigit);
    }
}
