/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.safetensors.prompt;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import com.github.tjake.jlama.safetensors.prompt.Tool;
import com.github.tjake.jlama.safetensors.prompt.ToolCall;
import com.github.tjake.jlama.safetensors.prompt.ToolResult;
import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel;
import com.github.tjake.jlama.util.JsonSupport;
import com.hubspot.jinjava.Jinjava;
import com.hubspot.jinjava.JinjavaConfig;
import com.hubspot.jinjava.LegacyOverrides;
import com.hubspot.jinjava.interpret.RenderResult;
import com.hubspot.jinjava.lib.fn.ELFunctionDefinition;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Renders chat-style prompts for a model from the Jinja prompt templates exposed by its
 * {@link TokenizerModel}.
 *
 * <p>All rendering goes through one shared, statically configured {@link Jinjava} instance
 * (trim/lstrip blocks enabled, snake_case property naming, JSON output indented with
 * {@link JsonSupport.JlamaPrettyPrinter}). A {@code raise_exception} function is registered for
 * templates that call it; such calls are only logged (see {@link #raiseException(String)}).
 */
public class PromptSupport {
    private static final Logger logger = LoggerFactory.getLogger(PromptSupport.class);

    private static final Jinjava jinjava = new Jinjava(
        JinjavaConfig.newBuilder()
            .withTrimBlocks(true)
            .withLstripBlocks(true)
            .withLegacyOverrides(
                LegacyOverrides.newBuilder()
                    .withParseWhitespaceControlStrictly(true)
                    .withUseTrimmingForNotesAndExpressions(true)
                    .withUseSnakeCasePropertyNaming(true)
                    .withKeepNullableLoopValues(true)
                    .build()
            )
            .withObjectMapper(
                new ObjectMapper()
                    .enable(SerializationFeature.INDENT_OUTPUT)
                    .setDefaultPrettyPrinter(JsonSupport.JlamaPrettyPrinter.INSTANCE)
            )
            .build()
    );

    static {
        // Exposes raiseException(String) to templates as raise_exception(message).
        jinjava.getGlobalContext()
            .registerFunction(
                new ELFunctionDefinition("", "raise_exception", PromptSupport.class, "raiseException", String.class)
            );
    }

    private final TokenizerModel m;

    /**
     * @param model tokenizer model supplying the prompt templates, EOS token and tool-support flag
     */
    public PromptSupport(TokenizerModel model) {
        this.m = model;
    }

    /** Returns a new, empty {@link Builder} bound to this instance's tokenizer model. */
    public Builder builder() {
        return new Builder(this.m);
    }

    /**
     * Backs the {@code raise_exception} template function. The message is logged at debug level
     * only; this method does not throw, so rendering continues.
     *
     * @param message error text supplied by the template
     */
    public static void raiseException(String message) {
        logger.debug("Prompt template error: {}", message);
    }

    /**
     * Collects conversation messages and rendering options, then renders them through the model's
     * prompt template. Holds mutable, unsynchronized state; use one builder per prompt.
     */
    public static class Builder {
        private final TokenizerModel m;
        private final List<Message> messages = new ArrayList<>(2);
        private PromptType type = PromptType.DEFAULT;
        private boolean addGenerationPrompt = true;
        private boolean stripPreamble = false;

        private Builder(TokenizerModel m) {
            this.m = m;
        }

        /**
         * Selects which named template is rendered. The template is looked up by the lower-cased
         * enum name. Defaults to {@link PromptType#DEFAULT}.
         */
        public Builder usePromptType(PromptType type) {
            this.type = type;
            return this;
        }

        /** Sets the template's {@code add_generation_prompt} flag (default {@code true}). */
        public Builder addGenerationPrompt(boolean addGenerationPrompt) {
            this.addGenerationPrompt = addGenerationPrompt;
            return this;
        }

        /** Appends a message with role {@code user}. */
        public Builder addUserMessage(String content) {
            this.messages.add(new Message(content, PromptRole.USER));
            return this;
        }

        /** Appends a tool result message (role {@code tool}) whose content is the result's JSON. */
        public Builder addToolResult(ToolResult result) {
            this.messages.add(new Message(result));
            return this;
        }

        /** Appends a tool-call message (role {@code tool_call}). */
        public Builder addToolCall(ToolCall call) {
            this.messages.add(new Message(call));
            return this;
        }

        /** Appends a message with role {@code system}. */
        public Builder addSystemMessage(String content) {
            this.messages.add(new Message(content, PromptRole.SYSTEM));
            return this;
        }

        /** Appends a message with role {@code assistant}. */
        public Builder addAssistantMessage(String content) {
            this.messages.add(new Message(content, PromptRole.ASSISTANT));
            return this;
        }

        /**
         * Requests that the "preamble" be removed from the front of the final prompt. The preamble
         * is whatever the template renders for an empty conversation with no generation prompt.
         */
        public Builder stripPreamble() {
            this.stripPreamble = true;
            return this;
        }

        /** Renders the prompt without passing any tools to the template. */
        public PromptContext build() {
            return this.build(Optional.empty());
        }

        /** Renders the prompt, exposing {@code tools} to the template as the {@code tools} variable. */
        public PromptContext build(List<Tool> tools) {
            return this.build(Optional.of(tools));
        }

        /** Renders the prompt, exposing {@code tools} to the template as the {@code tools} variable. */
        public PromptContext build(Tool... tools) {
            return this.build(Optional.of(List.of(tools)));
        }

        /**
         * Renders the collected messages through the selected template.
         *
         * @throws IllegalArgumentException if no messages were added
         * @throws UnsupportedOperationException if the model has no templates, or none for the
         *     selected {@link PromptType}
         */
        private PromptContext build(Optional<List<Tool>> optionalTools) {
            if (this.messages.isEmpty()) {
                throw new IllegalArgumentException("No messages to generate prompt");
            }
            if (this.m.promptTemplates().isEmpty()) {
                throw new UnsupportedOperationException("Prompt templates are not available for this model");
            }

            // Locale.ROOT: the default locale could otherwise alter the key (e.g. Turkish dotless i).
            String templateKey = this.type.name().toLowerCase(Locale.ROOT);
            String template = this.m.promptTemplates()
                .map(t -> (String) t.get(templateKey))
                .orElseThrow(() -> new UnsupportedOperationException("Prompt template not available for type: " + this.type));

            if (optionalTools.isPresent() && !optionalTools.get().isEmpty() && !this.m.hasToolSupport()) {
                logger.warn("This model does not support tools, but tools are specified");
            }

            String preamble = "";
            if (this.stripPreamble) {
                // An empty list keeps the "messages" variable the same type as in the real render.
                RenderResult empty = this.render(template, List.of(), false, optionalTools);
                if (empty.hasErrors()) {
                    logger.debug("Prompt template errors while rendering preamble: {}", empty.getErrors());
                }
                preamble = empty.getOutput();
            }

            List<Map<String, Object>> conversation = this.messages.stream().map(Message::toMap).toList();
            RenderResult result = this.render(template, conversation, this.addGenerationPrompt, optionalTools);
            if (result.hasErrors()) {
                logger.debug("Prompt template errors: {}", result.getErrors());
            }
            return new PromptContext(stripPrefix(result.getOutput(), preamble), optionalTools);
        }

        /** Renders {@code template} with the standard variable bindings. */
        private RenderResult render(
            String template,
            List<?> conversation,
            boolean generationPrompt,
            Optional<List<Tool>> optionalTools
        ) {
            // HashMap rather than Map.of: tolerates a null EOS token instead of throwing NPE.
            Map<String, Object> args = new HashMap<>();
            args.put("messages", conversation);
            args.put("add_generation_prompt", generationPrompt);
            args.put("eos_token", this.m.eosToken());
            // Deliberately empty so the template never emits BOS itself; presumably BOS is added
            // during tokenization — confirm.
            args.put("bos_token", "");
            optionalTools.ifPresent(tools -> args.put("tools", tools));
            return jinjava.renderForResult(template, args);
        }

        /**
         * Returns {@code text} without its leading {@code prefix}. If {@code text} does not start
         * with {@code prefix}, returns {@code text} unchanged and logs a warning, rather than
         * cutting off an unrelated span.
         */
        private static String stripPrefix(String text, String prefix) {
            if (prefix.isEmpty()) {
                return text;
            }
            if (text.startsWith(prefix)) {
                return text.substring(prefix.length());
            }
            logger.warn("Rendered prompt does not begin with the template preamble; leaving it unstripped");
            return text;
        }
    }

    /** Template-facing view of a single tool call: its name and argument map. */
    static class InnerToolCall {
        private final ToolCall call;

        private InnerToolCall(ToolCall call) {
            this.call = call;
        }

        /** Returns the tool call's parameters. */
        public Map<String, Object> arguments() {
            return this.call.getParameters();
        }

        /** Returns the name of the called tool. */
        public String name() {
            return this.call.getName();
        }
    }

    /** Wraps a {@link ToolCall} in the {@code {"function": {...}, "id": ...}} shape used by templates. */
    static class ToolCallFunction {
        private final ToolCall call;

        private ToolCallFunction(ToolCall call) {
            this.call = call;
        }

        /** Returns the nested function view (name and arguments). */
        public InnerToolCall function() {
            return new InnerToolCall(this.call);
        }

        /**
         * Converts this call to template data. Maps are insertion-ordered so that JSON output is
         * deterministic.
         *
         * @return {@code {"function": {"name", "arguments"}, "id"}}; "id" is omitted when null
         */
        public Map<String, Object> toMap() {
            Map<String, Object> function = new LinkedHashMap<>();
            function.put("name", this.call.getName());
            function.put("arguments", this.call.getParameters());

            Map<String, Object> map = new LinkedHashMap<>();
            map.put("function", function);
            // Null ids are tolerated (Map.of would throw), mirroring Message's tool_call_id handling.
            String id = this.call.getId();
            if (id != null) {
                map.put("id", id);
            }
            return map;
        }
    }

    /** One conversation entry: plain text, a tool call, or a tool result. */
    static class Message {
        private final Object content;
        private final PromptRole role;
        private final ToolCallFunction toolCall;
        private final String toolCallId;

        private Message(Object content, PromptRole role) {
            this.content = content;
            this.role = role;
            this.toolCall = null;
            this.toolCallId = null;
        }

        private Message(ToolCall toolCall) {
            this.content = null;
            this.role = PromptRole.TOOL_CALL;
            this.toolCall = new ToolCallFunction(toolCall);
            this.toolCallId = toolCall.getId();
        }

        private Message(ToolResult toolResult) {
            this.content = toolResult.toJson();
            this.role = PromptRole.TOOL;
            this.toolCall = null;
            this.toolCallId = toolResult.getToolCallId();
        }

        /** Returns the raw content; may be null for tool-call messages. */
        public Object getContent() {
            return this.content;
        }

        /**
         * Converts this message to template data. Null content is rendered as "".
         * "tool_calls" and "tool_call_id" are included only when present.
         */
        public Map<String, Object> toMap() {
            Map<String, Object> map = new LinkedHashMap<>();
            map.put("role", this.getRole());
            map.put("content", this.content == null ? "" : this.content);
            if (this.toolCall != null) {
                map.put("tool_calls", List.of(this.toolCall.toMap()));
            }
            if (this.toolCallId != null) {
                map.put("tool_call_id", this.toolCallId);
            }
            return map;
        }

        /** Returns the lower-cased role name; locale-independent so that "assistant" never changes. */
        public String getRole() {
            return this.role.name().toLowerCase(Locale.ROOT);
        }

        /** Returns a singleton list with the tool call, or null when this is not a tool-call message. */
        public List<ToolCallFunction> toolCalls() {
            return this.toolCall == null ? null : List.of(this.toolCall);
        }
    }

    private enum PromptRole {
        USER,
        SYSTEM,
        ASSISTANT,
        TOOL,
        TOOL_CALL
    }

    /**
     * Selects which of the model's named templates is rendered. Public so that callers of
     * {@link Builder#usePromptType(PromptType)} can name it.
     */
    public enum PromptType {
        DEFAULT,
        TOOL,
        RAG
    }
}

