/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.safetensors;

import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.SafeTensorIndex;
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.safetensors.TensorInfo;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.Weights;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.SegmentedTensor;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.util.HttpSupport;
import com.github.tjake.jlama.util.JsonSupport;
import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link WeightLoader} that reads safetensors weights lazily from a remote model repository.
 *
 * <p>Instead of downloading whole weight files, it downloads the first
 * {@code HEADER_DOWNLOAD_BYTES} (1 MiB) of each weight file to parse its tensor header. It then
 * downloads only the byte range of each requested tensor through {@link HttpSupport#downloadFile}.
 * Downloaded ranges are stored under {@code modelRoot} as
 * {@code <weightFile>.part.<start>_<end>} files (headers as {@code <weightFile>.header}) and are
 * reused whenever such a file already exists.
 *
 * <p>This class performs no synchronization; all of its caches are plain {@link HashMap}s.
 */
public class HTTPSafeTensorLoader
implements WeightLoader {
    private static final Logger logger = LoggerFactory.getLogger(HTTPSafeTensorLoader.class);

    /** Weight file name used when there is no index file, or the index does not list a tensor. */
    private static final String DEFAULT_WEIGHT_FILE = "model.safetensors";

    /** Number of leading bytes fetched from each weight file to read its safetensors header. */
    private static final long HEADER_DOWNLOAD_BYTES = 0x100000L;

    private final Path modelRoot;
    private final String indexFile;
    private final String modelName;
    private final Optional<String> branch;
    private final Optional<String> authToken;
    private final SafeTensorIndex index;
    /** Loaded tensors keyed by name, with the file backing their mapped buffer (null for segmented wrappers). */
    private final Map<String, Pair<RandomAccessFile, AbstractTensor>> layerFiles;
    /** Tensor metadata parsed from downloaded headers, filled lazily per weight file. */
    private final Map<String, TensorInfo> dynamicTensorInfoMap;
    /** Per tensor: byte position just past its weight file's header, added to the header's data offsets. */
    private final Map<String, Integer> tensorFileOffsets;
    private final DType modelDType;

    /**
     * Creates a loader for the remote model {@code owner/modelName}.
     *
     * <p>If {@code modelRoot/model.safetensors.index.json} exists locally, it is parsed to map
     * tensor names to weight files. Otherwise a single {@code model.safetensors} file is assumed.
     *
     * @param modelRoot local directory holding the optional index file and all downloaded headers/parts
     * @param owner repository owner; combined with {@code modelName} as {@code owner/modelName}
     * @param modelName repository name
     * @param modelDType working dtype passed through to {@code Weights.loadTensorFromBuffer}
     * @param branch optional branch, forwarded to {@code HttpSupport.downloadFile}
     * @param authToken optional auth token, forwarded to {@code HttpSupport.downloadFile}
     * @throws RuntimeException wrapping the {@link IOException} if the index file exists but cannot be read
     */
    public HTTPSafeTensorLoader(Path modelRoot, String owner, String modelName, DType modelDType, Optional<String> branch, Optional<String> authToken) {
        this.modelRoot = modelRoot;
        this.modelName = owner + "/" + modelName;
        this.branch = branch;
        this.indexFile = modelRoot.resolve("model.safetensors.index.json").toString();
        this.authToken = authToken;
        File indexPath = new File(this.indexFile);
        if (!indexPath.exists()) {
            this.index = new SafeTensorIndex(Collections.emptyMap(), Map.of("model-file", DEFAULT_WEIGHT_FILE));
        } else {
            try {
                this.index = JsonSupport.om.readValue(indexPath, SafeTensorIndex.class);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        this.layerFiles = new HashMap<>();
        this.dynamicTensorInfoMap = new HashMap<>();
        this.tensorFileOffsets = new HashMap<>();
        this.modelDType = modelDType;
    }

    @Override
    public Map<String, String> metadata() {
        return this.index.metadata();
    }

    /**
     * Returns the live, internally mutated map of tensor metadata. It contains only the entries of
     * weight-file headers that have already been downloaded by earlier {@link #load} calls.
     */
    @Override
    public Map<String, TensorInfo> tensorInfoMap() {
        return this.dynamicTensorInfoMap;
    }

    /**
     * Loads (downloading if necessary) the tensor {@code name}, caching the result.
     *
     * <p>Tensors whose byte range exceeds {@link Integer#MAX_VALUE} are downloaded in row-aligned
     * chunks and returned as a {@link SegmentedTensor}.
     *
     * @throws IllegalArgumentException if both sparse flags are set, or the weight is unknown
     * @throws IllegalStateException if the tensor's header information cannot be found
     * @throws UnsupportedOperationException if a chunked (&gt; 2 GiB) tensor is not 2D
     * @throws RuntimeException wrapping any {@link IOException} from download or file access
     */
    @Override
    public AbstractTensor load(String name, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) {
        Preconditions.checkArgument(!sparseColumns || !sparseRows, "Cannot have both sparse rows and columns");
        Preconditions.checkArgument(this.index.weightFileMap.containsKey(name) || this.index.weightFileMap.size() == 1, "Unknown weight: " + name);
        if (this.layerFiles.containsKey(name)) {
            return this.layerFiles.get(name).right();
        }
        try {
            TensorInfo info = this.maybeLoadTensorInfo(name);
            Pair<TensorShape, Pair<Long, Long>> offsets = Weights.getLoadOffsets(info, dctx, sparseRows);
            // Explicit check rather than `assert`, so a missing offset fails clearly even without -ea.
            Integer headerOffset = this.tensorFileOffsets.get(name);
            if (headerOffset == null || headerOffset <= 0) {
                throw new IllegalStateException("Failed to find header offset for: " + name);
            }
            String weightFile = this.index.weightFileMap.getOrDefault(name, DEFAULT_WEIGHT_FILE);
            TensorShape shape = offsets.left;
            Pair<Long, Long> dataRange = offsets.right;
            // The range returned by getLoadOffsets is shifted by the header end to get file positions.
            long positionOffset = dataRange.left + (long) headerOffset;
            long positionLimit = dataRange.right + (long) headerOffset;
            if (positionLimit - positionOffset > Integer.MAX_VALUE) {
                return this.loadInChunks(name, weightFile, info, positionOffset, positionLimit, dctx, sparseRows, sparseColumns);
            }
            return this.downloadAndLoadTensor(name, weightFile, info, shape, positionOffset, positionLimit, dctx, sparseRows, sparseColumns);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Loads a 2D tensor too large for one mapped buffer. It downloads row-aligned chunks of at most
     * {@link Integer#MAX_VALUE} bytes, loads each as {@code name.part.N}, and wraps them in a
     * {@link SegmentedTensor} cached under {@code name}.
     */
    private AbstractTensor loadInChunks(String name, String weightFile, TensorInfo info, long positionOffset, long positionLimit, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) throws IOException {
        if (info.shape.length != 2) {
            throw new UnsupportedOperationException("Only 2D tensors supported: " + name);
        }
        long bytesPerRow = (long) info.dType.size() * info.shape[1];
        // Guards against a zero row width (ArithmeticException on %) or one that would produce a
        // zero chunk size (infinite loop below).
        if (bytesPerRow <= 0 || bytesPerRow > Integer.MAX_VALUE) {
            throw new IllegalStateException("Cannot chunk tensor " + name + " with " + bytesPerRow + " bytes per row");
        }
        // Largest multiple of the row size that still fits in an int-indexed buffer.
        long chunkSize = Integer.MAX_VALUE - Integer.MAX_VALUE % bytesPerRow;
        ArrayList<AbstractTensor> tensors = new ArrayList<>();
        long offset = positionOffset;
        int chunkNum = 0;
        while (offset < positionLimit) {
            long chunkEnd = Math.min(offset + chunkSize, positionLimit);
            int numRowsInChunk = Ints.checkedCast((chunkEnd - offset) / bytesPerRow);
            TensorShape chunkShape = TensorShape.of(numRowsInChunk, info.shape[1]);
            tensors.add(this.downloadAndLoadTensor(name + ".part." + chunkNum++, weightFile, info, chunkShape, offset, chunkEnd, dctx, sparseRows, sparseColumns));
            offset = chunkEnd;
        }
        SegmentedTensor wrapped = SegmentedTensor.wrap(tensors);
        // No file of its own: the parts above are registered (with their files) individually.
        this.layerFiles.put(name, Pair.of(null, wrapped));
        return wrapped;
    }

    /**
     * Ensures the byte range {@code [positionOffset, positionLimit)} of {@code weightFile} is present
     * locally. The range is downloaded only if the part file does not already exist. It is then
     * memory-mapped, loaded as a tensor, and the tensor plus its open file are registered under
     * {@code name}.
     *
     * @throws RuntimeException if the local part file is shorter than the requested range
     * @throws IOException if downloading, opening or mapping the part file fails
     */
    private AbstractTensor downloadAndLoadTensor(String name, String weightFile, TensorInfo info, TensorShape shape, long positionOffset, long positionLimit, DistributedContext dctx, boolean sparseRows, boolean sparseColumns) throws IOException {
        Path weightPath = this.modelRoot.resolve(weightFile + ".part." + positionOffset + "_" + positionLimit);
        if (!weightPath.toFile().exists()) {
            logger.info("Downloading file: {} for {} {}MB", weightPath, name, (positionLimit - positionOffset) / 1024L / 1024L);
            HttpSupport.downloadFile(this.modelName, weightFile, this.branch, this.authToken, Optional.of(Pair.of(positionOffset, positionLimit)), weightPath, Optional.empty());
        }
        int length = Ints.checkedCast(positionLimit - positionOffset);
        RandomAccessFile raf = new RandomAccessFile(weightPath.toFile(), "r");
        try {
            long fileLength = raf.length();
            // Must run before limit(length): on a truncated file that call would throw an
            // IllegalArgumentException and hide this diagnostic.
            if (fileLength < (long) length) {
                throw new RuntimeException("Failed to download the correct number of bytes: " + fileLength + " != " + length + " for " + String.valueOf(weightPath));
            }
            ByteBuffer buf = raf.getChannel().map(FileChannel.MapMode.READ_ONLY, 0L, fileLength).duplicate().order(ByteOrder.LITTLE_ENDIAN).position(0).limit(length);
            logger.debug("Loading tensor: {} from {} with offsets: {} {}", name, weightPath, positionOffset, positionLimit);
            AbstractTensor tensor = Weights.loadTensorFromBuffer(name, info.dType, this.modelDType, shape, buf, sparseRows, sparseColumns, dctx, this);
            this.layerFiles.put(name, Pair.of(raf, tensor));
            return tensor;
        } catch (Throwable t) {
            // Not yet registered in layerFiles, so close() would never release it: close here.
            try {
                raf.close();
            } catch (IOException closeFailure) {
                t.addSuppressed(closeFailure);
            }
            throw t;
        }
    }

    /**
     * Returns the header metadata for {@code name}. On the first request for a given weight file,
     * it downloads and parses that file's header, recording every tensor it describes.
     *
     * <p>NOTE(review): only the first {@code HEADER_DOWNLOAD_BYTES} bytes are fetched, so a
     * header larger than 1 MiB is not fully available to the parser.
     *
     * @throws IllegalStateException if the parsed header does not describe {@code name}
     */
    private TensorInfo maybeLoadTensorInfo(String name) throws IOException {
        if (this.dynamicTensorInfoMap.containsKey(name)) {
            return this.dynamicTensorInfoMap.get(name);
        }
        String weightFile = this.index.weightFileMap.getOrDefault(name, DEFAULT_WEIGHT_FILE);
        Path headerFile = this.modelRoot.resolve(weightFile + ".header");
        if (!Files.exists(headerFile)) {
            HttpSupport.downloadFile(this.modelName, weightFile, this.branch, this.authToken, Optional.of(Pair.of(0L, HEADER_DOWNLOAD_BYTES)), headerFile, Optional.empty());
        }
        try (RandomAccessFile raf = new RandomAccessFile(headerFile.toFile(), "r")) {
            MappedByteBuffer header = raf.getChannel().map(FileChannel.MapMode.READ_ONLY, 0L, Math.min(HEADER_DOWNLOAD_BYTES, raf.length()));
            Map<String, TensorInfo> parsed = SafeTensorSupport.readTensorInfoMap(header, Optional.empty());
            // The buffer position after parsing marks the end of the header; load() adds it to
            // each tensor's data offsets to obtain absolute file positions.
            int endOfHeaderPosition = header.position();
            for (Map.Entry<String, TensorInfo> e : parsed.entrySet()) {
                this.dynamicTensorInfoMap.put(e.getKey(), e.getValue());
                this.tensorFileOffsets.put(e.getKey(), endOfHeaderPosition);
            }
        }
        TensorInfo info = this.dynamicTensorInfoMap.get(name);
        if (info == null) {
            throw new IllegalStateException("Failed to load tensor info for: " + name);
        }
        return info;
    }

    @Override
    public DType getModelDType() {
        return this.modelDType;
    }

    /**
     * Closes every part file opened by this loader and clears all caches.
     *
     * <p>Every file is attempted even if an earlier close fails. The first {@link IOException}
     * (with later ones suppressed) is rethrown wrapped in a {@link RuntimeException}.
     */
    @Override
    public void close() {
        IOException failure = null;
        for (Pair<RandomAccessFile, AbstractTensor> pair : this.layerFiles.values()) {
            RandomAccessFile raf = pair.left();
            if (raf == null) {
                continue;
            }
            try {
                raf.close();
            } catch (IOException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        this.layerFiles.clear();
        this.dynamicTensorInfoMap.clear();
        this.tensorFileOffsets.clear();
        if (failure != null) {
            throw new RuntimeException(failure);
        }
    }
}

