/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.math.FloatConversions;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.util.UnsafeDirectByteBuffer;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
import java.nio.ShortBuffer;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorSpecies;

public class BFloat16BufferTensor
extends AbstractTensor<ShortVector, Short> {
    private final ShortBuffer b;
    private final String name;
    private final MemorySegment segment;

    public BFloat16BufferTensor(AbstractTensor ft) {
        this(ft.shape);
        Preconditions.checkArgument(ft.dType != DType.BF16, "This should never happen, likely a bug");
        int[] cursor = new int[ft.shape.dims()];
        do {
            this.set(ft.get(cursor), cursor);
        } while (ft.iterate(cursor));
    }

    public BFloat16BufferTensor(int ... shape) {
        this(TensorShape.of(shape));
    }

    public BFloat16BufferTensor(TensorShape shape) {
        super(DType.BF16, shape, true);
        this.name = "tmp";
        this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(Ints.checkedCast(this.size() * (long)this.dType().size()), 64L).asShortBuffer();
        this.segment = MemorySegment.ofBuffer(this.b);
    }

    public BFloat16BufferTensor(String name, ShortBuffer b, TensorShape shape, boolean cacheSlices) {
        super(DType.BF16, shape, cacheSlices);
        this.name = name;
        this.b = b;
        this.segment = MemorySegment.ofBuffer(b);
    }

    @Override
    protected AbstractTensor make(TensorShape shape) {
        return new BFloat16BufferTensor(shape);
    }

    @Override
    protected AbstractTensor make(int offset, int length, TensorShape shape, boolean cacheSlices) {
        return new BFloat16BufferTensor(this.name, this.b.slice(offset, length), shape, cacheSlices);
    }

    @Override
    public float get(int ... dims) {
        Preconditions.checkArgument(dims.length <= this.shape.dims(), "Too many dimensions specified");
        Preconditions.checkArgument(dims.length == this.shape.dims(), "Must specify all dimensions");
        return FloatConversions.bFloat16ToFloat32(this.b.get(this.getOffset(dims)));
    }

    @Override
    public void set(float v, int ... dims) {
        Preconditions.checkArgument(dims.length <= this.shape.dims(), "Too many dimensions specified for tensor");
        Preconditions.checkArgument(dims.length == this.shape.dims(), "Must specify all dimensions");
        Preconditions.checkArgument(!this.b.isReadOnly(), "Can't modify a read only buffer");
        this.b.put(this.getOffset(dims), FloatConversions.float32ToBFloat16(v));
    }

    @Override
    public ShortVector getVector(VectorSpecies<Short> species, int ... voffset) {
        int offset = this.getOffset(voffset);
        return ShortVector.fromMemorySegment(species, (MemorySegment)this.segment, (long)this.getMemorySegmentOffset(offset), (ByteOrder)ByteOrder.LITTLE_ENDIAN);
    }

    @Override
    public void intoTensor(ShortVector vector, int ... aoffset) {
        Preconditions.checkArgument(!this.b.isReadOnly());
        int offset = this.getOffset(aoffset);
        vector.intoMemorySegment(this.segment, (long)this.getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
    }

    @Override
    public MemorySegment getMemorySegment() {
        return this.segment;
    }

    @Override
    public int getMemorySegmentOffset(int offset) {
        return offset * this.dType.size();
    }

    @Override
    public void copyFrom(AbstractTensor src, int srcOffset, int destOffset, int length) {
        Preconditions.checkArgument(this.dType == src.dType, "different types");
        Preconditions.checkArgument(!this.b.isReadOnly(), "Read-only");
        this.segment.asSlice((long)this.getMemorySegmentOffset(destOffset), length).copyFrom(src.getMemorySegment().asSlice((long)src.getMemorySegmentOffset(srcOffset), length));
    }

    @Override
    public void clear() {
        Preconditions.checkArgument(!this.b.isReadOnly(), "Can't clear a read-only buffer");
        this.segment.fill((byte)0);
    }

    public String toString() {
        int i;
        float[] sample = new float[Math.min(10, this.b.remaining())];
        for (int i2 = 0; i2 < sample.length; ++i2) {
            sample[i2] = FloatConversions.bFloat16ToFloat32(this.b.get(i2));
        }
        StringBuffer sb = new StringBuffer();
        for (i = 0; i < sample.length; ++i) {
            sb.append(String.format("%8.4f", Float.valueOf(sample[i])));
            if (i >= sample.length - 1) continue;
            sb.append(", ");
        }
        for (i = 0; i < sample.length; ++i) {
            sample[i] = FloatConversions.bFloat16ToFloat32(this.b.get(i + this.shape.first() / 2));
        }
        StringBuffer sb2 = new StringBuffer();
        for (int i3 = 0; i3 < sample.length; ++i3) {
            sb2.append(String.format("%8.4f", Float.valueOf(sample[i3])));
            if (i3 >= sample.length - 1) continue;
            sb2.append(", ");
        }
        return "BFloat16BufferTensor{name='" + this.name + "', shape=" + String.valueOf(this.shape) + ",\n b=" + String.valueOf(sb) + "..." + String.valueOf(sb2) + "}";
    }
}

