/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.BFloat16BufferTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.TensorCache;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;
import java.io.Closeable;
import java.io.IOError;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.ShortBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import java.util.Iterator;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Holds per-session key/value caches for a model.
 *
 * <p>Each session UUID maps to a {@link KvBuffer}. A buffer splits its storage into pages.
 * Each page covers a block of layers and a block of context positions. Pages are either taken
 * from {@link TensorCache} or memory-mapped from files in the configured working directory.
 */
public class KvBufferCache implements Closeable {
    private static final Logger logger = LoggerFactory.getLogger(KvBufferCache.class);

    /** Maximum page size in bytes for session-scoped buffers (0x800000 = 8 MiB). */
    private static final int SESSION_MAX_PAGE_BYTES = 0x800000;

    /** Maximum page size in bytes for ephemeral buffers (0x100000 = 1 MiB). */
    private static final int EPHEMERAL_MAX_PAGE_BYTES = 0x100000;

    private final ConcurrentMap<UUID, KvBuffer> kvBufferCache = new ConcurrentHashMap<UUID, KvBuffer>();
    private final AbstractModel model;

    /**
     * @param model model whose config and working dtype determine page layout and storage
     */
    public KvBufferCache(AbstractModel model) {
        this.model = model;
    }

    /**
     * Returns the buffer registered for {@code session}, creating it on first use.
     *
     * @param session session identifier
     * @return the (possibly newly created) non-ephemeral buffer for the session
     */
    public KvBuffer getKvBuffer(UUID session) {
        return this.kvBufferCache.computeIfAbsent(session, s -> new KvBuffer(s, SESSION_MAX_PAGE_BYTES, false));
    }

    /**
     * Creates a new ephemeral buffer under a random session id.
     *
     * <p>The buffer is NOT registered in this cache, so {@link #close()} will not close it; the
     * caller owns it. Ephemeral pages always come from {@link TensorCache}, never from files.
     *
     * @return a fresh ephemeral buffer
     */
    public KvBuffer getEphemeralKvBuffer() {
        return new KvBuffer(UUID.randomUUID(), EPHEMERAL_MAX_PAGE_BYTES, true);
    }

    /** Closes every registered session buffer and removes it from the cache. */
    @Override
    public void close() {
        Iterator<KvBuffer> it = this.kvBufferCache.values().iterator();
        while (it.hasNext()) {
            it.next().close();
            it.remove();
        }
    }

    /**
     * Paged key/value storage for a single session.
     *
     * <p>Pages are indexed as {@code pages[layerPage][contextPage]} and created lazily on first
     * access (or re-created if previously closed). NOTE(review): page creation is not
     * synchronized; confirm a buffer is only accessed by one thread at a time.
     */
    public class KvBuffer implements AutoCloseable {
        private final UUID session;
        private final AtomicInteger currentContextPosition = new AtomicInteger(0);
        private final KvBufferPage[][] pages;
        private final KvPageContext pageContext;
        private final boolean ephemeral;

        KvBuffer(UUID session, int maxPageSizeInBytes, boolean ephemeral) {
            this.session = session;
            this.pageContext = this.computePageSize(maxPageSizeInBytes);
            this.pages = new KvBufferPage[this.pageContext.numberOfLayerPages][this.pageContext.numberOfContextPages];
            this.ephemeral = ephemeral;
        }

        public int getCurrentContextPosition() {
            return this.currentContextPosition.get();
        }

        public void setCurrentContextPosition(int position) {
            this.currentContextPosition.set(position);
        }

        public void incrementContextPosition() {
            this.currentContextPosition.incrementAndGet();
        }

        /**
         * Chooses how many layers and how many context positions each page holds.
         *
         * <p>The goal is to maximize the number of (layer, position) key/value entries per page
         * without exceeding {@code maxPageSizeInBytes}.
         *
         * @param maxPageSizeInBytes upper bound on the size of one page, in bytes
         * @return the resulting page layout
         * @throws IllegalArgumentException if one (layer, position) entry does not fit in a page
         */
        public KvPageContext computePageSize(long maxPageSizeInBytes) {
            Config c = KvBufferCache.this.model.getConfig();
            DType workingDType = KvBufferCache.this.model.getWorkingDType();
            // Bytes for the key AND value vectors of one layer at one context position.
            long s = 2L * (long) workingDType.size() * (long) c.dctx().kvSegmentLength;
            Preconditions.checkArgument(maxPageSizeInBytes > s, "maxPageSizeInBytes must be greater than the size of a single layer");
            int N = c.dctx().numberOfLayers;
            int C = c.contextLength;
            int optimalLayersPerPage = 1;
            int optimalContextLengthPerPage = 1;
            long maxProduct = 0L;
            for (int x = N; x >= 1; --x) {
                // A page never needs more than C positions, so cap y instead of rejecting the
                // candidate. Rejecting it would fall back to 1 layer x 1 position per page
                // whenever the whole cache fits in a single page.
                long y = Math.min(maxPageSizeInBytes / ((long) x * s), (long) C);
                if (y < 1L) {
                    continue;
                }
                long product = (long) x * y;
                if (product > maxProduct) {
                    optimalLayersPerPage = x;
                    optimalContextLengthPerPage = (int) y;
                    maxProduct = product;
                }
                if (product < maxProduct) {
                    // Heuristic early exit: stop at the first layer count that packs fewer entries.
                    break;
                }
            }
            int numberOfLayerPages = (int) Math.ceil((double) N / (double) optimalLayersPerPage);
            int numberOfContextPages = (int) Math.ceil((double) C / (double) optimalContextLengthPerPage);
            // Widen before multiplying so the entry count cannot overflow int.
            long pageSize = (long) optimalLayersPerPage * (long) optimalContextLengthPerPage * s;
            if (pageSize > maxPageSizeInBytes) {
                throw new IllegalArgumentException("Calculation error: pageSize > maxPageSizeInBytes: " + pageSize + " > " + maxPageSizeInBytes);
            }
            logger.debug("Optimal page size: {} layers, {} context length, {} bytes, {} layer pages, {} length pages", optimalLayersPerPage, optimalContextLengthPerPage, pageSize, numberOfLayerPages, numberOfContextPages);
            return new KvPageContext(KvBufferCache.this, this.session, numberOfLayerPages, numberOfContextPages, optimalLayersPerPage, optimalContextLengthPerPage);
        }

        /**
         * Closes every page created so far.
         *
         * <p>Close failures are logged at debug level and otherwise ignored (best effort). Closed
         * pages are re-created on the next access.
         */
        @Override
        public void close() {
            for (KvBufferPage[] layerPages : this.pages) {
                if (layerPages == null) {
                    continue;
                }
                for (KvBufferPage page : layerPages) {
                    if (page == null) {
                        continue;
                    }
                    try {
                        page.close();
                    } catch (IOException e) {
                        logger.debug("Error closing page", e);
                    }
                }
            }
        }

        public AbstractTensor getKeyTensorForPosition(int layerIndex, int position) {
            return this.getTensorForPosition(layerIndex, position, 0);
        }

        public AbstractTensor getValTensorForPosition(int layerIndex, int position) {
            return this.getTensorForPosition(layerIndex, position, 1);
        }

        /**
         * Returns the slice of the owning page for one layer, one K/V selector and one position.
         *
         * @param index 0 selects keys, 1 selects values (second dimension of the page shape)
         */
        private AbstractTensor getTensorForPosition(int layerIndex, int position, int index) {
            int layerPageIndex = layerIndex / this.pageContext.layersPerPage;
            int contextPageIndex = position / this.pageContext.contextLengthPerPage;
            int relativeLayerIndex = layerIndex % this.pageContext.layersPerPage;
            int relativeContextIndex = position % this.pageContext.contextLengthPerPage;
            KvBufferPage page = this.getOrCreatePage(layerPageIndex, contextPageIndex);
            return page.getTensor().slice(true, relativeLayerIndex, index, relativeContextIndex);
        }

        public AbstractTensor[] getKeyTensorsUptoPosition(int layerIndex, int upperBound) {
            return this.getTensorsUptoPosition(layerIndex, 0, upperBound);
        }

        public AbstractTensor[] getValTensorsUptoPosition(int layerIndex, int upperBound) {
            return this.getTensorsUptoPosition(layerIndex, 1, upperBound);
        }

        /**
         * Returns one slice per context page, from page 0 through the page containing
         * {@code upperBound}.
         *
         * <p>Each slice presumably spans that page's whole context range for the given layer, so
         * the last element may include positions beyond {@code upperBound}. Confirm against
         * {@code AbstractTensor.slice}.
         */
        private AbstractTensor[] getTensorsUptoPosition(int layerIndex, int index, int upperBound) {
            int layerPageIndex = layerIndex / this.pageContext.layersPerPage;
            int contextPageIndex = upperBound / this.pageContext.contextLengthPerPage;
            int relativeLayerIndex = layerIndex % this.pageContext.layersPerPage;
            AbstractTensor[] tensors = new AbstractTensor[contextPageIndex + 1];
            for (int i = 0; i <= contextPageIndex; ++i) {
                tensors[i] = this.getOrCreatePage(layerPageIndex, i).getTensor().slice(true, relativeLayerIndex, index);
            }
            return tensors;
        }

        /**
         * Returns the page at the given grid coordinates, creating it if absent or closed.
         *
         * <p>The page id encodes its own coordinates. For file-backed pages the id is part of the
         * file name, so distinct pages must never share an id.
         */
        private KvBufferPage getOrCreatePage(int layerPageIndex, int contextPageIndex) {
            KvBufferPage page = this.pages[layerPageIndex][contextPageIndex];
            if (page == null || page.isClosed()) {
                page = new KvBufferPage(KvBufferCache.this, this.pageContext, "L" + layerPageIndex + "C" + contextPageIndex, this.ephemeral);
                this.pages[layerPageIndex][contextPageIndex] = page;
            }
            return page;
        }
    }

    /**
     * One page of key/value storage, shaped by {@link KvPageContext#pageShape}.
     *
     * <p>The page is backed by a memory-mapped file when the model config has a working
     * directory and the page is not ephemeral; otherwise the tensor comes from
     * {@link TensorCache}.
     */
    class KvBufferPage implements AutoCloseable {
        private final AbstractTensor tensor;
        private final KvPageContext pageCtx;
        private final String pageId;
        private final AtomicBoolean closed = new AtomicBoolean(false);
        private final RandomAccessFile raf;

        /**
         * @param owner cache whose model supplies config and working dtype
         * @param pageCtx page layout; also supplies the session id used in the file name
         * @param pageId identifier unique within the session; used in the backing file name
         * @param ephemeral when true, never use a backing file
         * @throws UnsupportedOperationException if file-backed and the working dtype is not F32/BF16
         * @throws IOError if the backing file cannot be created, sized or mapped
         */
        KvBufferPage(KvBufferCache owner, KvPageContext pageCtx, String pageId, boolean ephemeral) {
            this.pageCtx = pageCtx;
            this.pageId = pageId;
            DType workingDType = owner.model.getWorkingDType();
            if (owner.model.getConfig().workingDirectory().isEmpty() || ephemeral) {
                this.raf = null;
                this.tensor = TensorCache.instance.get(workingDType, pageCtx.pageShape);
            } else {
                // Reject unsupported dtypes before touching the filesystem.
                if (workingDType != DType.F32 && workingDType != DType.BF16) {
                    throw new UnsupportedOperationException("Only F32/BF16 is supported for now");
                }
                RandomAccessFile file = null;
                boolean mapped = false;
                try {
                    file = new RandomAccessFile(Paths.get(owner.model.getConfig().workingDirectory().get().toString(), pageCtx.session.toString() + "-" + pageId + ".page").toFile(), "rw");
                    long bytes = pageCtx.pageShape.size() * (long) workingDType.size();
                    logger.debug("Allocating page {} with {} bytes {}", pageId, bytes, file.length());
                    if (file.length() != bytes) {
                        file.setLength(bytes);
                    }
                    if (workingDType == DType.F32) {
                        FloatBuffer fb = file.getChannel().map(FileChannel.MapMode.READ_WRITE, 0L, bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
                        this.tensor = new FloatBufferTensor(fb, pageCtx.pageShape, true);
                    } else {
                        ShortBuffer sb = file.getChannel().map(FileChannel.MapMode.READ_WRITE, 0L, bytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer();
                        this.tensor = new BFloat16BufferTensor("kvmem", sb, pageCtx.pageShape, true);
                    }
                    mapped = true;
                } catch (IOException e) {
                    throw new IOError(e);
                } finally {
                    // Do not leak the file handle if sizing or mapping failed.
                    if (!mapped && file != null) {
                        try {
                            file.close();
                        } catch (IOException suppressed) {
                            logger.debug("Error closing page file after failed allocation", suppressed);
                        }
                    }
                }
                this.raf = file;
            }
        }

        public AbstractTensor getTensor() {
            assert (!this.closed.get()) : "Page is closed";
            return this.tensor;
        }

        public boolean isClosed() {
            return this.closed.get();
        }

        /** Closes the backing file (if any) and the tensor; later calls are no-ops. */
        @Override
        public void close() throws IOException {
            if (this.closed.compareAndSet(false, true)) {
                try {
                    if (this.raf != null) {
                        this.raf.close();
                    }
                } finally {
                    // Release the tensor even if closing the file failed.
                    this.tensor.close();
                }
            }
        }
    }

    /**
     * Page layout for one buffer.
     *
     * <p>Records how many pages exist along the layer and context axes, how much each page covers,
     * and the tensor shape {@code [layersPerPage, 2, contextLengthPerPage, kvLength]} of one page.
     * The shape is sparse over the kv columns when this node owns only a segment of them.
     */
    class KvPageContext {
        public final int numberOfLayerPages;
        public final int numberOfContextPages;
        private final int layersPerPage;
        private final int contextLengthPerPage;
        private final UUID session;
        public final TensorShape pageShape;

        /**
         * @throws IllegalArgumentException if any page count or per-page extent is below 1
         */
        public KvPageContext(KvBufferCache owner, UUID session, int numberOfLayerPages, int numberOfContextPages, int layersPerPage, int contextLengthPerPage) {
            if (numberOfLayerPages < 1) {
                throw new IllegalArgumentException("numberOfLayerPages must be >= 1");
            }
            if (numberOfContextPages < 1) {
                throw new IllegalArgumentException("numberOfContextPages must be >= 1");
            }
            if (layersPerPage < 1) {
                throw new IllegalArgumentException("layersPerPage must be >= 1");
            }
            if (contextLengthPerPage < 1) {
                throw new IllegalArgumentException("contextLengthPerPage must be >= 1");
            }
            this.session = session;
            this.numberOfLayerPages = numberOfLayerPages;
            this.numberOfContextPages = numberOfContextPages;
            this.layersPerPage = layersPerPage;
            this.contextLengthPerPage = contextLengthPerPage;
            Config c = owner.model.getConfig();
            DistributedContext dctx = c.dctx();
            // Second dimension has extent 2: index 0 = keys, index 1 = values.
            int[] rawShape = new int[] {layersPerPage, 2, contextLengthPerPage, c.kvLength};
            if (c.kvLength != dctx.kvSegmentLength) {
                Pair<Integer, Integer> kvOffset = Pair.of(dctx.kvSegmentStart, dctx.kvSegmentEnd);
                this.pageShape = TensorShape.sparseColumn(rawShape, kvOffset);
            } else {
                this.pageShape = TensorShape.of(rawShape);
            }
        }
    }
}

