/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;

/**
 * Immutable description of a tensor's dimensions, optionally restricted to a contiguous "sparse"
 * window of rows (the second-to-last dimension) and/or of columns (the last dimension).
 *
 * <p>The full logical dimensions are always kept. When a window is present, {@link #size()} counts
 * only the elements inside it, and {@link #getOffset(int...)} maps logical indices to offsets
 * relative to the start of the stored window (see that method for per-rank details).
 *
 * <p>The dimension array is copied on construction and never handed out, so instances cannot be
 * mutated after creation.
 */
public final class TensorShape {
    /** Shared 1x1 shape. */
    public static final TensorShape one = TensorShape.of(1, 1);

    /** Full logical dimensions, outermost first; a private copy owned by this instance. */
    private final int[] tshape;
    /** Number of stored elements: product of all dimensions with the sparse windows applied. */
    private final long capacity;
    /** Optional {@code (offset, length)} window over the last dimension. */
    private final Optional<Pair<Integer, Integer>> sparseColumnRange;
    /** Optional {@code (offset, length)} window over the second-to-last dimension. */
    private final Optional<Pair<Integer, Integer>> sparseRowRange;
    private final boolean isSparse;
    // Resolved window bounds. Without a window the offset is 0 and the length is the full dimension.
    private final int sparseColumnOffset;
    private final int sparseColumnLength;
    private final int sparseRowOffset;
    private final int sparseRowLength;

    /**
     * Creates a dense shape.
     *
     * @param shape the dimensions, outermost first; a single dimension {@code n} is promoted to
     *     {@code [1, n]} so every shape has at least two dimensions
     * @return a new dense shape
     * @throws IllegalArgumentException if {@code shape} is empty
     */
    public static TensorShape of(int ... shape) {
        if (shape.length == 1) {
            shape = new int[]{1, shape[0]};
        }
        return new TensorShape(shape, Optional.empty(), Optional.empty());
    }

    /**
     * Creates a shape whose last dimension is restricted to a stored window.
     *
     * @param shape the full logical dimensions (at least two)
     * @param sparseOffset {@code (offset, length)} of the stored column window
     * @return a new column-sparse shape
     */
    public static TensorShape sparseColumn(int[] shape, Pair<Integer, Integer> sparseOffset) {
        return new TensorShape(shape, Optional.empty(), Optional.of(sparseOffset));
    }

    /**
     * Creates a shape whose second-to-last dimension is restricted to a stored window.
     *
     * @param shape the full logical dimensions (at least two)
     * @param sparseOffset {@code (offset, length)} of the stored row window
     * @return a new row-sparse shape
     */
    public static TensorShape sparseRow(int[] shape, Pair<Integer, Integer> sparseOffset) {
        return new TensorShape(shape, Optional.of(sparseOffset), Optional.empty());
    }

    private TensorShape(int[] shape, Optional<Pair<Integer, Integer>> sparseRowRange, Optional<Pair<Integer, Integer>> sparseColumnRange) {
        Preconditions.checkArgument(shape.length > 1, "Shape must have at least two dimensions, even if first is 1 (to represent a vector)");
        // Defensive copy: factory callers may keep (and later mutate) the array they passed in,
        // which would otherwise silently desynchronize tshape from capacity and the sparse bounds.
        this.tshape = shape.clone();
        this.sparseColumnRange = sparseColumnRange;
        this.sparseRowRange = sparseRowRange;
        this.isSparse = sparseColumnRange.isPresent() || sparseRowRange.isPresent();

        int rows = this.tshape[this.tshape.length - 2];
        int cols = this.tshape[this.tshape.length - 1];
        this.sparseColumnOffset = sparseColumnRange.map(Pair::left).orElse(0);
        this.sparseColumnLength = sparseColumnRange.map(Pair::right).orElse(cols);
        this.sparseRowOffset = sparseRowRange.map(Pair::left).orElse(0);
        this.sparseRowLength = sparseRowRange.map(Pair::right).orElse(rows);

        // Leading dimensions count in full; the last two are replaced by their window lengths.
        long c = 1L;
        for (int i = 0; i < this.tshape.length - 2; ++i) {
            c *= this.tshape[i];
        }
        this.capacity = c * this.sparseRowLength * this.sparseColumnLength;
    }

    /** Returns {@code true} if either a row or a column window is present. */
    public final boolean isSparse() {
        return this.isSparse;
    }

    /** Returns the number of dimensions (always at least two). */
    public int dims() {
        return this.tshape.length;
    }

    /**
     * Returns the full logical size of dimension {@code i}.
     *
     * @throws IllegalArgumentException if {@code i} is not a valid dimension index
     */
    public int dim(int i) {
        Preconditions.checkArgument(i >= 0 && i < this.tshape.length, "dimension %s out of range for %s dims", i, this.tshape.length);
        return this.tshape[i];
    }

    /**
     * Maps logical indices to an element offset within the stored (possibly sparse) data.
     *
     * <ul>
     *   <li>1 index {@code r}: the same value as {@code getOffset(r, 0)}.
     *   <li>2 indices {@code (r, c)}: row-major offset inside the row/column window.
     *   <li>3 indices: the row offset is subtracted from the outermost index and {@code tshape[1]}
     *       is used as the middle stride.
     *   <li>0 or 4+ indices: row-major over the logical dims with the last stride set to the
     *       column-window width; the row window is not applied.
     * </ul>
     *
     * <p>NOTE(review): for row-sparse shapes with 3+ dimensions the 3-index and general branches
     * do not agree with {@link #size()}, which applies the row window to dimension
     * {@code dims() - 2}. Left unchanged because callers may rely on the current arithmetic.
     *
     * @param pdims one or more logical indices, outermost first
     * @return offset relative to the start of the stored window
     * @throws IllegalArgumentException if no index is supplied
     */
    public final int getOffset(int ... pdims) {
        switch (pdims.length) {
            case 1: {
                return this.sparseColumnLength * (pdims[0] - this.sparseRowOffset) - this.sparseColumnOffset;
            }
            case 2: {
                // Most common case: (row, column) of a 2-D shape.
                return this.sparseColumnLength * (pdims[0] - this.sparseRowOffset) + pdims[1] - this.sparseColumnOffset;
            }
            case 3: {
                return this.sparseColumnLength * this.tshape[1] * (pdims[0] - this.sparseRowOffset) + this.sparseColumnLength * pdims[1] + pdims[2] - this.sparseColumnOffset;
            }
        }
        Preconditions.checkArgument(pdims.length > 0, "At least one index is required");
        int totalOffset = 0;
        // Every index except the last is scaled by the product of the inner dimensions,
        // with the innermost dimension replaced by the column-window width.
        for (int d = 0; d < pdims.length - 1; ++d) {
            int offset = this.sparseColumnLength;
            for (int i = this.tshape.length - 2; i > d; --i) {
                offset *= this.tshape[i];
            }
            totalOffset += pdims[d] * offset;
        }
        return totalOffset + pdims[pdims.length - 1] - this.sparseColumnOffset;
    }

    /** Width of the stored column window (the full last dimension when not column-sparse). */
    public int sparseColumnLength() {
        return this.sparseColumnLength;
    }

    /** First stored column (0 when not column-sparse). */
    public int sparseColumnOffset() {
        return this.sparseColumnOffset;
    }

    /** Height of the stored row window (the full second-to-last dimension when not row-sparse). */
    public int sparseRowLength() {
        return this.sparseRowLength;
    }

    /** First stored row (0 when not row-sparse). */
    public int sparseRowOffset() {
        return this.sparseRowOffset;
    }

    /**
     * Returns a shape whose last dimension (and column window, if any) is multiplied by
     * {@code scale}, truncating each result toward zero.
     *
     * <p>NOTE(review): a row window is not carried over to the result; confirm callers expect a
     * dense-row result here.
     */
    public TensorShape scaleLastDim(float scale) {
        int[] copy = Arrays.copyOf(this.tshape, this.tshape.length);
        int n = copy.length - 1;
        copy[n] = (int) ((float) copy[n] * scale);
        if (!this.sparseColumnRange.isPresent()) {
            return TensorShape.of(copy);
        }
        int scaledOffset = (int) ((float) this.sparseColumnOffset * scale);
        int scaledLength = (int) ((float) this.sparseColumnLength * scale);
        return TensorShape.sparseColumn(copy, Pair.of(scaledOffset, scaledLength));
    }

    /**
     * Returns a copy of this shape with dimension {@code dim} set to {@code value}.
     *
     * <p>A column window keeps its offset. Its width is unchanged unless {@code dim} is the last
     * dimension, in which case the width becomes {@code value}. A row window is not carried over.
     *
     * @throws IllegalArgumentException if {@code dim} is not a valid dimension index
     */
    public TensorShape setDimValue(int dim, int value) {
        Preconditions.checkArgument(dim >= 0 && dim < this.tshape.length);
        int[] copy = Arrays.copyOf(this.tshape, this.tshape.length);
        copy[dim] = value;
        if (!this.sparseColumnRange.isPresent()) {
            return TensorShape.of(copy);
        }
        // Changing an unrelated dimension must not widen the column window to the full last
        // dimension (which would also overrun it whenever the window offset is non-zero).
        int newSparseLength = dim == copy.length - 1 ? value : this.sparseColumnLength;
        return TensorShape.sparseColumn(copy, Pair.of(this.sparseColumnOffset, newSparseLength));
    }

    /** Returns the outermost dimension. */
    public int first() {
        return this.tshape[0];
    }

    /** Returns the innermost (last) dimension. */
    public int last() {
        return this.tshape[this.tshape.length - 1];
    }

    /** Returns the number of stored elements, i.e. the element count inside the sparse windows. */
    public long size() {
        return this.capacity;
    }

    /**
     * Restricts the last dimension of a dense shape to the window {@code [offset, offset + length)}.
     *
     * @throws IllegalArgumentException if this shape is already sparse
     */
    public TensorShape sparsifyColumns(int offset, int length) {
        Preconditions.checkArgument(!this.isSparse, "Cannot sparsify a sparse tensor");
        return new TensorShape(this.tshape, Optional.empty(), Optional.of(Pair.of(offset, length)));
    }

    /**
     * Drops the {@code numDims} outermost dimensions. A result with a single remaining dimension
     * is expressed as {@code [1, last]}.
     *
     * @throws IllegalArgumentException if {@code numDims} is negative or not smaller than
     *     {@link #dims()}
     */
    public TensorShape slice(int numDims) {
        Preconditions.checkArgument(numDims >= 0, "numDims must be non-negative");
        Preconditions.checkArgument(numDims < this.tshape.length, "Too many dimensions specified for tensor");
        int newLength = this.tshape.length - numDims;
        if (newLength == 1) {
            // The row window described the second-to-last dimension, which has just been sliced
            // away; keeping it would make the single-row result report many rows' worth of size.
            return new TensorShape(new int[]{1, this.tshape[this.tshape.length - 1]}, Optional.empty(), this.sparseColumnRange);
        }
        return new TensorShape(Arrays.copyOfRange(this.tshape, numDims, this.tshape.length), this.sparseRowRange, this.sparseColumnRange);
    }

    /** Two shapes are equal when their dimensions and both sparse windows match. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        TensorShape that = (TensorShape) o;
        // The row window must participate: it changes size(), so ignoring it would make
        // shapes with different capacities compare equal.
        return Arrays.equals(this.tshape, that.tshape)
                && Objects.equals(this.sparseRowRange, that.sparseRowRange)
                && Objects.equals(this.sparseColumnRange, that.sparseColumnRange);
    }

    @Override
    public int hashCode() {
        int result = Objects.hash(this.sparseRowRange, this.sparseColumnRange);
        return 31 * result + Arrays.hashCode(this.tshape);
    }

    @Override
    public String toString() {
        return "TensorShape{tshape=" + Arrays.toString(this.tshape) + ", capacity=" + this.capacity + ", sparseRange=" + String.valueOf(this.sparseColumnRange) + ", sparseRowRange=" + String.valueOf(this.sparseRowRange) + "}";
    }
}

