/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.safetensors;

import com.fasterxml.jackson.core.TreeNode;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.type.MapType;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.SafeTensorIndex;
import com.github.tjake.jlama.safetensors.TensorInfo;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.Weights;
import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q5ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.util.HttpSupport;
import com.github.tjake.jlama.util.JsonSupport;
import com.github.tjake.jlama.util.ProgressReporter;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.RandomAccessFile;
import java.io.UncheckedIOException;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.StringConcatFactory;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.nio.file.CopyOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SafeTensorSupport {
    private static final Logger logger = LoggerFactory.getLogger(SafeTensorSupport.class);
    private static final MapType metadataTypeReference = JsonSupport.om.getTypeFactory().constructMapType(Map.class, String.class, String.class);
    static String FINISHED_MARKER = ".finished";

    public static Map<String, TensorInfo> readTensorInfoMap(ByteBuffer buf, Optional<Map<String, String>> saveMetadata) {
        long MAX_HEADER_LENGTH = 0x40000000L;
        long headerLength = (buf = buf.order(ByteOrder.LITTLE_ENDIAN)).getLong();
        if (headerLength < 0L) {
            throw new IllegalArgumentException("Header length cannot be negative: " + headerLength);
        }
        if (headerLength > 0x40000000L) {
            throw new IllegalArgumentException(String.format("Header length %d exceeds the maximum allowed length %d.", headerLength, 0x40000000L));
        }
        byte[] header = new byte[Ints.checkedCast(headerLength)];
        buf.get(header);
        try {
            JsonNode rootNode = JsonSupport.om.readTree(header);
            Iterator<Map.Entry<String, JsonNode>> fields = rootNode.fields();
            HashMap<String, TensorInfo> tensorInfoMap = new HashMap<String, TensorInfo>();
            Map metadata = Collections.emptyMap();
            while (fields.hasNext()) {
                Map.Entry<String, JsonNode> field = fields.next();
                if (field.getKey().equalsIgnoreCase("__metadata__")) {
                    metadata = (Map)JsonSupport.om.treeToValue((TreeNode)field.getValue(), metadataTypeReference);
                    continue;
                }
                TensorInfo tensorInfo = JsonSupport.om.treeToValue((TreeNode)field.getValue(), TensorInfo.class);
                tensorInfoMap.put(field.getKey(), tensorInfo);
            }
            Map sortedMap = tensorInfoMap.entrySet().stream().sorted(Map.Entry.comparingByValue()).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new));
            Map finalMetadata = metadata;
            saveMetadata.ifPresent(m -> m.putAll(finalMetadata));
            return sortedMap;
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public static Weights readWeights(ByteBuffer safeBuf) {
        safeBuf = safeBuf.duplicate();
        HashMap<String, String> metadata = new HashMap<String, String>();
        Map<String, TensorInfo> tensorInfoMap = SafeTensorSupport.readTensorInfoMap(safeBuf, Optional.of(metadata));
        return new Weights(metadata, tensorInfoMap, safeBuf.slice(), Optional.empty());
    }

    public static ModelSupport.ModelType detectModel(File configFile) throws IOException {
        JsonNode rootNode = JsonSupport.om.readTree(configFile);
        if (!rootNode.has("model_type")) {
            throw new IllegalArgumentException("Config missing model_type field.");
        }
        return ModelSupport.ModelType.valueOf(rootNode.get("model_type").textValue().toUpperCase());
    }

    public static WeightLoader loadWeights(File baseDir) {
        if (Files.exists(Paths.get(baseDir.getAbsolutePath(), "model.safetensors.index.json"), new LinkOption[0])) {
            return SafeTensorIndex.loadWithWeights(baseDir.toPath());
        }
        if (Files.exists(Paths.get(baseDir.getAbsolutePath(), "model.safetensors"), new LinkOption[0])) {
            return SafeTensorIndex.loadSingleFile(baseDir.toPath(), "model.safetensors");
        }
        throw new IllegalArgumentException("No safetensor model found in: " + String.valueOf(baseDir));
    }

    public static boolean isModelLocal(Path modelRoot) {
        if (Files.exists(modelRoot.resolve("model.safetensors"), new LinkOption[0])) {
            return true;
        }
        try {
            if (Files.exists(modelRoot.resolve("model.safetensors.index.json"), new LinkOption[0])) {
                SafeTensorIndex index = JsonSupport.om.readValue(modelRoot.resolve("model.safetensors.index.json").toFile(), SafeTensorIndex.class);
                for (String file : index.weightFileMap.values()) {
                    if (Files.exists(modelRoot.resolve(file), new LinkOption[0])) continue;
                    return false;
                }
                return true;
            }
        }
        catch (IOException e) {
            logger.error("Error reading model index", e);
            return false;
        }
        return false;
    }

    public static TokenizerModel loadTokenizer(Path modelRoot) throws IOException {
        File tokenizerConfigJson;
        File tokenizerJson = modelRoot.resolve("tokenizer.json").toFile();
        Preconditions.checkArgument(tokenizerJson.exists(), "No tokenizer.json found in " + String.valueOf(modelRoot));
        JsonNode rootNode = JsonSupport.om.readTree(tokenizerJson);
        if (!rootNode.has("model")) {
            throw new IllegalArgumentException("Json missing 'model' key");
        }
        TokenizerModel model = JsonSupport.om.treeToValue((TreeNode)rootNode.get("model"), TokenizerModel.class);
        if (rootNode.has("added_tokens") && rootNode.get("added_tokens") != null) {
            List addedTokens = JsonSupport.om.convertValue(rootNode.get("added_tokens"), List.class);
            model.setAddedTokens(addedTokens);
        }
        if (rootNode.has("pre_tokenizer") && rootNode.get("pre_tokenizer") != null) {
            model.setPreTokenizer(JsonSupport.om.treeToValue((TreeNode)rootNode.get("pre_tokenizer"), TokenizerModel.PreTokenizer.class));
        }
        if (rootNode.has("normalizer") && rootNode.get("normalizer") != null) {
            model.setNormalizer(JsonSupport.om.treeToValue((TreeNode)rootNode.get("normalizer"), TokenizerModel.Normalizer.class));
        }
        if ((tokenizerConfigJson = modelRoot.resolve("tokenizer_config.json").toFile()).exists()) {
            JsonNode configNode = JsonSupport.om.readTree(tokenizerConfigJson);
            if (configNode.has("legacy")) {
                model.setLegacy(configNode.get("legacy").asBoolean());
            }
            if (configNode.has("chat_template")) {
                JsonNode chatTemplateNode = configNode.get("chat_template");
                HashMap<String, String> promptTemplates = new HashMap<String, String>();
                if (chatTemplateNode.isTextual()) {
                    promptTemplates.put("default", chatTemplateNode.asText());
                } else if (chatTemplateNode.isArray()) {
                    List chatTemplates = JsonSupport.om.convertValue(chatTemplateNode, List.class);
                    for (Map chatTemplate : chatTemplates) {
                        if (chatTemplate.containsKey("name") && chatTemplate.containsKey("template")) {
                            promptTemplates.put((String)chatTemplate.get("name"), (String)chatTemplate.get("template"));
                            continue;
                        }
                        throw new IllegalArgumentException("Invalid chat_template format");
                    }
                } else {
                    throw new IllegalArgumentException("Invalid chat_template format");
                }
                model.setPromptTemplates(promptTemplates);
            }
            if (configNode.has("eos_token")) {
                model.setEosToken(configNode.get("eos_token").asText());
            }
            if (configNode.has("bos_token")) {
                model.setBosToken(configNode.get("bos_token").asText());
            }
        }
        return model;
    }

    /*
     * Unable to fully structure code
     */
    public static Path quantizeModel(Path modelRoot, DType modelQuantization, String[] skipLayerPrefixes, String[] dropLayerPrefixes, Optional<Path> outputRoot) throws IOException {
        tmp = File.createTempFile("safe", "tensor");
        tmp.deleteOnExit();
        wl = SafeTensorSupport.loadWeights(modelRoot.toFile());
        writtenInfo = new HashMap<Object, TensorInfo>();
        raf = new RandomAccessFile(tmp, "rw");
        try {
            tensors = wl.tensorInfoMap();
lbl8:
            // 7 sources

            for (Map.Entry<String, TensorInfo> e : tensors.entrySet()) {
                drop = false;
                if (dropLayerPrefixes != null) {
                    for (String dropLayerPrefix : dropLayerPrefixes) {
                        if (!e.getKey().startsWith(dropLayerPrefix)) continue;
                        SafeTensorSupport.logger.info("Dropping layer: " + e.getKey());
                        drop = true;
                    }
                }
                if (drop) continue;
                tr = wl.load(e.getKey());
                try {
                    skipQ = false;
                    if (skipLayerPrefixes != null) {
                        for (String skipLayerPrefix : skipLayerPrefixes) {
                            if (!e.getKey().contains(skipLayerPrefix)) continue;
                            SafeTensorSupport.logger.info("Skipping quantization of layer: " + e.getKey());
                            skipQ = true;
                            break;
                        }
                    }
                    t = skipQ != false ? tr : tr.quantize(modelQuantization);
                    switch (2.$SwitchMap$com$github$tjake$jlama$safetensors$DType[t.dType().ordinal()]) {
                        case 1: 
                        case 2: 
                        case 3: {
                            writtenInfo.put(e.getKey(), t.save(raf.getChannel()));
                            ** break;
                        }
                        case 4: {
                            writtenInfo.put(e.getKey(), t.save(raf.getChannel()));
                            writtenInfo.put(e.getKey() + ".qb", ((Q4ByteBufferTensor)t).getBlockF().save(raf.getChannel()));
                            ** break;
                        }
                        case 5: {
                            writtenInfo.put(e.getKey(), t.save(raf.getChannel()));
                            writtenInfo.put(e.getKey() + ".qb", ((Q5ByteBufferTensor)t).getBlockF().save(raf.getChannel()));
                            throw new UnsupportedOperationException("TODO");
                        }
                        case 6: {
                            writtenInfo.put(e.getKey(), t.save(raf.getChannel()));
                            writtenInfo.put(e.getKey() + ".qb", ((Q8ByteBufferTensor)t).getBlockF().save(raf.getChannel()));
                            ** break;
                        }
                        default: {
                            throw new UnsupportedOperationException(String.valueOf((Object)t.dType()) + " not implemented");
                        }
                    }
                }
                finally {
                    if (tr == null) continue;
                    tr.close();
                }
            }
        }
        finally {
            raf.close();
        }
        baseDirName = modelRoot.getName(modelRoot.getNameCount() - 1).toString();
        parentPath = modelRoot.getParent();
        qPath = outputRoot.orElseGet((Supplier<Path>)LambdaMetafactory.metafactory(null, null, null, ()Ljava/lang/Object;, lambda$quantizeModel$2(java.nio.file.Path java.lang.String com.github.tjake.jlama.safetensors.DType ), ()Ljava/nio/file/Path;)((Path)parentPath, (String)baseDirName, (DType)modelQuantization));
        qDir = qPath.toFile();
        qDir.mkdirs();
        Files.copy(modelRoot.resolve("config.json"), qPath.resolve("config.json"), new CopyOption[0]);
        Files.copy(modelRoot.resolve("tokenizer.json"), qPath.resolve("tokenizer.json"), new CopyOption[0]);
        Files.copy(modelRoot.resolve("README.md"), qPath.resolve("README.md"), new CopyOption[0]);
        SafeTensorSupport.addJlamaHeader(baseDirName, qPath.resolve("README.md"));
        if (Files.exists(modelRoot.resolve("tokenizer_config.json"), new LinkOption[0])) {
            Files.copy(modelRoot.resolve("tokenizer_config.json"), qPath.resolve("tokenizer_config.json"), new CopyOption[0]);
        }
        raf = new RandomAccessFile(qPath.resolve("model.safetensors").toFile(), "rw");
        try {
            header = JsonSupport.om.writeValueAsBytes(writtenInfo);
            hsize = new byte[8];
            ByteBuffer.wrap(hsize).order(ByteOrder.LITTLE_ENDIAN).putLong(header.length);
            raf.write(hsize);
            raf.write(header);
            Files.copy(tmp.toPath(), new OutputStream(){

                @Override
                public void write(int b) throws IOException {
                    raf.write(b);
                }

                @Override
                public void write(byte[] b) throws IOException {
                    raf.write(b);
                }

                @Override
                public void write(byte[] b, int off, int len) throws IOException {
                    raf.write(b, off, len);
                }
            });
        }
        finally {
            raf.close();
        }
        return qPath;
    }

    /*
     * WARNING - void declaration
     */
    private static void addJlamaHeader(String modelName, Path readmePath) throws IOException {
        String cleanName = modelName.replaceAll("_", "/");
        String header = String.format("# Quantized Version of %s \n\nThis model is a quantized variant of the %s model, optimized for use with Jlama, a Java-based inference engine. The quantization process reduces the model's size and improves inference speed, while maintaining high accuracy for efficient deployment in production environments.\n\nFor more information on Jlama, visit the [Jlama GitHub repository](https://github.com/tjake/jlama).\n\n---\n\n", cleanName, cleanName);
        String readme = new String(Files.readAllBytes(readmePath));
        boolean startMeta = false;
        boolean finishedMeta = false;
        int linenum = 0;
        StringBuilder finalReadme = new StringBuilder();
        for (String string : readme.split("\n")) {
            void var12_12;
            if (linenum++ == 0) {
                if (string.startsWith("---")) {
                    startMeta = true;
                } else {
                    finalReadme.append(header);
                }
            } else if (startMeta && !finishedMeta && string.startsWith("---")) {
                finishedMeta = true;
                String string2 = string + "\n\n" + header;
            }
            finalReadme.append((String)var12_12).append("\n");
        }
        Files.write(readmePath, finalReadme.toString().getBytes(), new OpenOption[0]);
    }

    public static File maybeDownloadModel(String modelDir, String fullModelName, ProgressReporter progressReporter) throws IOException {
        String name;
        String owner;
        String[] parts = fullModelName.split("/");
        if (parts.length == 0 || parts.length > 2) {
            throw new IllegalArgumentException("Model must be in the form owner/name");
        }
        if (parts.length == 1) {
            owner = null;
            name = fullModelName;
        } else {
            owner = parts[0];
            name = parts[1];
        }
        return SafeTensorSupport.maybeDownloadModel(modelDir, Optional.ofNullable(owner), name, true, Optional.empty(), Optional.empty(), Optional.ofNullable(progressReporter));
    }

    public static File maybeDownloadModel(String modelDir, String fullModelName) throws IOException {
        return SafeTensorSupport.maybeDownloadModel(modelDir, fullModelName, null);
    }

    public static Path constructLocalModelPath(String modelDir, String owner, String modelName) {
        return Paths.get(modelDir, owner + "_" + modelName);
    }

    public static File maybeDownloadModel(String modelDir, Optional<String> modelOwner, String modelName, boolean downloadWeights, Optional<String> optionalBranch, Optional<String> optionalAuthHeader, Optional<ProgressReporter> optionalProgressReporter) throws IOException {
        Path localModelDir = SafeTensorSupport.constructLocalModelPath(modelDir, modelOwner.orElse("na"), modelName);
        if (Files.exists(localModelDir.resolve(FINISHED_MARKER), new LinkOption[0])) {
            return localModelDir.toFile();
        }
        String hfModel = modelOwner.map(mo -> mo + "/" + modelName).orElse(modelName);
        InputStream modelInfoStream = (InputStream)HttpSupport.getResponse((String)((Object)StringConcatFactory.makeConcatWithConstants("makeConcatWithConstants", new Object[]{"https://huggingface.co/api/models/\u0001/tree/\u0001"}, (String)hfModel, (String)optionalBranch.orElse((String)"main"))), optionalAuthHeader, Optional.empty()).left;
        String modelInfo = HttpSupport.readInputStream(modelInfoStream);
        if (modelInfo == null) {
            throw new IOException("No valid model found or trying to access a restricted model (please include correct access token)");
        }
        List<String> allFiles = SafeTensorSupport.parseFileList(modelInfo);
        if (allFiles.isEmpty()) {
            throw new IOException("No valid model found");
        }
        ArrayList<String> tensorFiles = new ArrayList<String>();
        boolean hasSafetensor = false;
        for (String currFile : allFiles) {
            String f = currFile.toLowerCase();
            if ((!f.contains("safetensor") || f.contains("consolidated")) && !f.contains("readme") && !f.equals("config.json") && !f.contains("tokenizer")) continue;
            if (f.contains("safetensor")) {
                hasSafetensor = true;
            }
            if (!downloadWeights && f.contains("safetensor")) continue;
            tensorFiles.add(currFile);
        }
        if (!hasSafetensor) {
            throw new IOException("Model is not available in safetensor format");
        }
        Files.createDirectories(localModelDir, new FileAttribute[0]);
        for (String currFile : tensorFiles) {
            HttpSupport.downloadFile(hfModel, currFile, optionalBranch, optionalAuthHeader, Optional.empty(), localModelDir.resolve(currFile), optionalProgressReporter);
        }
        Files.createFile(localModelDir.resolve(FINISHED_MARKER), new FileAttribute[0]);
        return localModelDir.toFile();
    }

    private static List<String> parseFileList(String modelInfo) throws IOException {
        ArrayList<String> fileList = new ArrayList<String>();
        ObjectMapper objectMapper = new ObjectMapper();
        JsonNode siblingsNode = objectMapper.readTree(modelInfo);
        if (siblingsNode.isArray()) {
            for (JsonNode siblingNode : siblingsNode) {
                String rFilename = siblingNode.path("path").asText();
                fileList.add(rFilename);
            }
        }
        return fileList;
    }

    private static /* synthetic */ Path lambda$quantizeModel$2(Path parentPath, String baseDirName, DType modelQuantization) {
        return Paths.get(parentPath.toString(), baseDirName + "-J" + modelQuantization.name());
    }
}

