/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.math;

import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.BiIntConsumer;
import com.github.tjake.jlama.util.PhysicalCoreExecutor;
import com.google.common.base.Preconditions;
import java.util.function.IntConsumer;
import java.util.stream.IntStream;
import net.jafama.FastMath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static math helpers used on vectors and tensors: parallel iteration utilities, in-place softmax,
 * L1/L2 normalization, cosine similarity, outer product and a (cos, sin) frequency-table builder.
 */
public class VectorMath {
    private static final Logger logger = LoggerFactory.getLogger(VectorMath.class);

    /**
     * Invokes {@code action} once for every index in {@code [start, end)} using a parallel
     * {@link IntStream}, submitted through {@link PhysicalCoreExecutor}.
     *
     * <p>NOTE(review): whether this call blocks until every index has been processed depends on
     * {@code PhysicalCoreExecutor.execute}, which is not visible here; confirm before relying on it.
     *
     * @param start  first index (inclusive)
     * @param end    last index (exclusive)
     * @param action callback receiving each index; may be invoked concurrently from several threads
     */
    public static void pfor(int start, int end, IntConsumer action) {
        PhysicalCoreExecutor.instance.get().execute(() -> IntStream.range(start, end).parallel().forEach(action));
    }

    /**
     * Splits the index range {@code [offset, offset + length)} into contiguous chunks and calls
     * {@code action.accept(chunkStart, chunkLength)} once per chunk, in parallel.
     *
     * <p>The number of chunks is {@code min(length, parallelSplitSize())}. All chunks have the same
     * size except the last one, which also absorbs the leftover {@code length % splits} elements.
     * The chunks therefore cover the whole range exactly once. When only a single split would be
     * used, the range is instead processed one element per call.
     *
     * @param offset first index of the range
     * @param length number of indices to cover; a non-positive length is a no-op
     * @param action callback receiving {@code (chunkStart, chunkLength)}; may be invoked concurrently
     */
    public static void pchunk(int offset, int length, BiIntConsumer action) {
        if (length <= 0) {
            // Nothing to cover; this also avoids dividing by zero when computing the chunk size.
            return;
        }
        int splits = Math.min(length, TensorOperationsProvider.get().parallelSplitSize());
        int chunkSize = length / splits;
        int remainder;
        if (splits == 1) {
            // A single chunk would run serially, so fan out one element per task instead.
            splits = length;
            chunkSize = 1;
            remainder = 0;
        } else {
            // Elements left over after `splits` equal chunks; they are added to the last chunk.
            // This is length % splits (NOT length % chunkSize, which can under-count and drop
            // trailing elements, e.g. length=11, splits=4).
            remainder = length - splits * chunkSize;
        }
        final int fsplits = splits;
        final int fchunkSize = chunkSize;
        final int fremainder = remainder;
        PhysicalCoreExecutor.instance.get().execute(() -> IntStream.range(0, fsplits)
                .parallel()
                .forEach(i -> action.accept(
                        offset + i * fchunkSize,
                        i == fsplits - 1 ? fchunkSize + fremainder : fchunkSize)));
    }

    /**
     * Applies softmax in place to the elements {@code x[0, offset .. offset + length)}.
     *
     * <p>The range maximum is subtracted before exponentiation to keep {@code exp} from overflowing.
     * Only elements inside the range are read or written.
     *
     * @param x      tensor whose first dimension must be 1
     * @param offset index (in the second dimension) of the first element of the range
     * @param length number of elements in the range; the element at {@code offset} is read
     *               unconditionally, so presumably {@code length >= 1} is expected
     * @throws IllegalArgumentException if {@code x.shape().first() != 1}
     */
    public static void softMax(AbstractTensor x, int offset, int length) {
        Preconditions.checkArgument(x.shape().first() == 1, "softMax expects a tensor whose first dimension is 1");
        int end = offset + length;

        // Range maximum; a NaN element never compares greater, so it does not become the max.
        float max = x.get(0, offset);
        for (int i = offset + 1; i < end; i++) {
            float v = x.get(0, i);
            if (v > max) {
                max = v;
            }
        }

        float sum = 0.0f;
        for (int i = offset; i < end; i++) {
            x.set((float) FastMath.exp(x.get(0, i) - max), 0, i);
            // Accumulate the value read back from the tensor so the normalizer matches what was
            // stored. NOTE(review): this matters only if the tensor type rounds on set; confirm.
            sum += x.get(0, i);
        }

        // Normalize only the softmax range. Starting at 0 would also rescale elements before offset.
        for (int i = offset; i < end; i++) {
            x.set(x.get(0, i) / sum, 0, i);
        }
    }

    /**
     * Scales {@code x} in place so that the sum of the absolute values of its elements becomes 1.
     *
     * <p>If every element is zero, the divisor is zero and all elements become NaN.
     *
     * @param x array to normalize in place
     */
    public static void l1normalize(float[] x) {
        float sum = 0.0f;
        for (float v : x) {
            sum += FastMath.abs(v);
        }
        for (int i = 0; i < x.length; i++) {
            x[i] /= sum;
        }
    }

    /**
     * Scales the elements {@code x[0, 0 .. x.shape().last())} in place to unit Euclidean length.
     *
     * <p>Only row 0 is read and written. NOTE(review): this presumably expects a single-row
     * tensor; confirm with callers. A zero vector yields NaN elements.
     *
     * @param x tensor whose first row is normalized in place
     */
    public static void l2normalize(AbstractTensor x) {
        int n = x.shape().last();
        float sumOfSquares = 0.0f;
        for (int i = 0; i < n; i++) {
            float v = x.get(0, i);
            sumOfSquares += v * v;
        }
        double magnitude = FastMath.sqrt(sumOfSquares);
        for (int i = 0; i < n; i++) {
            x.set((float) (x.get(0, i) / magnitude), 0, i);
        }
    }

    /**
     * Scales {@code x} in place to unit Euclidean (L2) length.
     *
     * <p>A zero vector yields NaN elements.
     *
     * @param x array to normalize in place
     */
    public static void l2normalize(float[] x) {
        float sumOfSquares = 0.0f;
        for (float v : x) {
            sumOfSquares += v * v;
        }
        double magnitude = FastMath.sqrt(sumOfSquares);
        for (int i = 0; i < x.length; i++) {
            x[i] = (float) (x[i] / magnitude);
        }
    }

    /**
     * Computes the cosine similarity {@code dot(a, b) / (|a| * |b|)}.
     *
     * <p>Iterates over {@code a.length} elements. {@code b} must be at least that long, and any
     * extra elements of {@code b} are ignored. The result is NaN if either vector is all zeros.
     *
     * @param a first vector
     * @param b second vector
     * @return the cosine of the angle between {@code a} and {@code b}
     */
    public static float cosineSimilarity(float[] a, float[] b) {
        float dot = 0.0f;
        float aSquares = 0.0f;
        float bSquares = 0.0f;
        for (int i = 0; i < a.length; i++) {
            float ai = a[i];
            float bi = b[i];
            dot += ai * bi;
            aSquares += ai * ai;
            bSquares += bi * bi;
        }
        return (float) (dot / (FastMath.sqrt(aSquares) * FastMath.sqrt(bSquares)));
    }

    /**
     * Computes the outer product of {@code xs} and {@code ys}, flattened in row-major order.
     *
     * @param xs row vector of length {@code n}
     * @param ys column vector of length {@code m}
     * @return an array of length {@code n * m} where {@code result[i * m + j] == xs[i] * ys[j]}
     */
    public static float[] outerProduct(float[] xs, float[] ys) {
        int n = xs.length;
        int m = ys.length;
        float[] result = new float[n * m];
        for (int i = 0; i < n; i++) {
            int rowStart = i * m;
            for (int j = 0; j < m; j++) {
                result[rowStart + j] = xs[i] * ys[j];
            }
        }
        return result;
    }

    /**
     * Builds a table of {@code (cos, sin)} pairs of the angles {@code p * freq[j]}, for positions
     * {@code p} in {@code [0, end)} and {@code j} in {@code [0, dim / 2)}, where
     * {@code freq[j] = 1 / theta^(2j / dim) / scaling_factor}.
     *
     * <p>This is presumably the rotary position embedding table; confirm with callers.
     *
     * @param dim            dimension; {@code dim / 2} (rounded down) frequencies are produced
     * @param end            number of positions
     * @param theta          frequency base
     * @param scaling_factor divisor applied to every frequency
     * @return array of length {@code end * (dim / 2)}; entry {@code p * (dim / 2) + j} holds
     *         {@code {cos(p * freq[j]), sin(p * freq[j])}}
     */
    public static float[][] precomputeFreqsCis(int dim, int end, double theta, double scaling_factor) {
        int half = dim / 2;
        float[] freqs = new float[half];
        for (int i = 0; i < half; i++) {
            float exponent = (float) (2 * i) / (float) dim;
            freqs[i] = (float) (1.0 / FastMath.pow(theta, exponent) / scaling_factor);
        }
        float[] positions = new float[end];
        for (int p = 0; p < end; p++) {
            positions[p] = p;
        }
        float[] angles = outerProduct(positions, freqs);
        float[][] table = new float[angles.length][];
        for (int k = 0; k < angles.length; k++) {
            table[k] = new float[] {(float) FastMath.cos(angles[k]), (float) FastMath.sin(angles[k])};
        }
        return table;
    }
}

