/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.google.common.base.Preconditions;
import net.jafama.FastMath;

/**
 * Layer normalization with a learned per-column scale ({@code weights}) and shift ({@code bias}).
 *
 * <p>For each row of the input, the selected columns are centred on their mean, divided by
 * {@code sqrt(variance + layerNormEps)}, multiplied by {@code weights} and offset by {@code bias}.
 * The embedding length and epsilon are read from the owning model's config ({@code m.c}).
 */
public class LayerNorm {
    protected final AbstractModel m;
    private final AbstractTensor bias;
    protected final AbstractTensor weights;

    /**
     * @param m model whose config supplies {@code embeddingLength} and {@code layerNormEps}
     * @param bias additive term, read at row 0 and the column being normalized
     * @param weights multiplicative scale, read at row 0 and the column being normalized
     */
    public LayerNorm(AbstractModel m, AbstractTensor bias, AbstractTensor weights) {
        this.m = m;
        this.bias = bias;
        this.weights = weights;
    }

    /**
     * Normalizes every row of a 2-D input over its full embedding width.
     *
     * @param input tensor with exactly two dimensions whose last dimension equals the model's
     *     {@code embeddingLength}
     * @return a new tensor obtained from {@code input.copyShape()} holding the normalized values
     * @throws IllegalArgumentException if {@code input} is not 2-D or its last dimension does not
     *     equal {@code embeddingLength}
     */
    public AbstractTensor forward(AbstractTensor input) {
        Preconditions.checkArgument(
                input.shape().dims() == 2, "LayerNorm expects a 2-D input, got %s dims", input.shape().dims());
        int size = input.shape().last();
        Preconditions.checkArgument(
                size == this.m.c.embeddingLength,
                "last dimension %s does not match embeddingLength %s",
                size,
                this.m.c.embeddingLength);
        return this.forward(input, 0, this.m.c.embeddingLength);
    }

    /**
     * Normalizes columns {@code [offset, offset + length)} of every row of {@code input}.
     *
     * <p>Only the selected columns of the returned tensor are written. Every other entry keeps
     * whatever value {@code input.copyShape()} gave it.
     *
     * <p>NOTE(review): the mean and variance are divided by the model's full {@code embeddingLength},
     * not by {@code length}. The two agree when the whole row is selected, as
     * {@link #forward(AbstractTensor)} does. For a strict sub-range the statistics are not those of
     * the slice alone; confirm this is intended by any caller that passes a sub-range.
     *
     * @param input tensor indexed as (row, column); rows are iterated over {@code shape().first()}
     * @param offset first column to normalize
     * @param length number of columns to normalize
     * @return a new tensor obtained from {@code input.copyShape()}
     */
    public AbstractTensor forward(AbstractTensor input, int offset, int length) {
        int batchSize = input.shape().first();
        int limit = offset + length;
        // Divisor intentionally left as the full embedding length (see NOTE above).
        double n = this.m.c.embeddingLength;
        AbstractTensor output = input.copyShape();
        for (int b = 0; b < batchSize; ++b) {
            // Accumulate in double to reduce rounding error in the one-pass E[x^2] - E[x]^2 formula.
            double sum = 0.0;
            double sumSq = 0.0;
            for (int i = offset; i < limit; ++i) {
                double v = input.get(b, i);
                sum += v;
                sumSq += v * v;
            }
            double mean = sum / n;
            // Rounding can still leave the one-pass variance slightly negative. Without the clamp,
            // sqrt(variance + eps) becomes NaN whenever that deficit exceeds eps.
            double variance = Math.max(0.0, sumSq / n - mean * mean);
            float meanF = (float) mean;
            float invStddev = (float) (1.0 / FastMath.sqrt(variance + this.m.c.layerNormEps));
            for (int i = offset; i < limit; ++i) {
                float v = (input.get(b, i) - meanF) * invStddev * this.weights.get(0, i) + this.bias.get(0, i);
                output.set(v, b, i);
            }
        }
        return output;
    }
}

