/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.CausalSelfAttention;
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.model.functions.FeedForward;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.KvBufferCache;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.DebugSupport;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * One transformer layer.
 *
 * <p>The layer has two sub-blocks:
 * <ul>
 *   <li>Attention: an optional pre-attention norm, self-attention, an optional post-attention norm,
 *       then a residual add.</li>
 *   <li>Feed-forward: an optional pre-FF norm, the feed-forward block, an optional post-FF norm,
 *       then a residual add.</li>
 * </ul>
 * An optional final norm is applied to the layer output. Any norm whose {@code Optional} is empty
 * is skipped.
 */
public class TransformerBlock {
    private static final Logger logger = LoggerFactory.getLogger(TransformerBlock.class);
    private final AbstractModel model;
    final int layerIndex;
    final Optional<LayerNorm> preAttentionNorm;
    final CausalSelfAttention attention;
    final Optional<LayerNorm> postAttentionNorm;
    final Optional<LayerNorm> preFFNorm;
    final FeedForward ffBlock;
    final Optional<LayerNorm> postFFNorm;
    final Optional<LayerNorm> preResponseNorm;

    /**
     * Creates a layer with a pre-attention norm and one norm between the two sub-blocks.
     *
     * <p>Note: {@code postAttentionNorm} is stored in the {@code preFFNorm} slot. It is applied to the
     * attention output after the residual add, i.e. as the input norm of the feed-forward block.
     */
    public TransformerBlock(AbstractModel model, int layerIndex, LayerNorm preAttentionNorm, CausalSelfAttention attention, LayerNorm postAttentionNorm, FeedForward ffBlock) {
        this(model, layerIndex, Optional.of(preAttentionNorm), attention, Optional.empty(), Optional.of(postAttentionNorm), ffBlock, Optional.empty(), Optional.empty());
    }

    /**
     * Creates a layer with no pre-attention norm.
     *
     * <p>Note: {@code postAttentionNorm} is stored in the {@code preFFNorm} slot. {@code postFFNorm} is
     * stored in the {@code preResponseNorm} slot, i.e. it is applied to the layer output after the
     * second residual add.
     */
    public TransformerBlock(AbstractModel model, int layerIndex, CausalSelfAttention attention, LayerNorm postAttentionNorm, FeedForward ffBlock, LayerNorm postFFNorm) {
        this(model, layerIndex, Optional.empty(), attention, Optional.empty(), Optional.of(postAttentionNorm), ffBlock, Optional.empty(), Optional.of(postFFNorm));
    }

    /**
     * Creates a layer with a pre-attention norm.
     *
     * <p>Note: {@code postAttentionNorm} is stored in the {@code preFFNorm} slot. {@code postFFNorm} is
     * stored in the {@code preResponseNorm} slot, i.e. it is applied to the layer output after the
     * second residual add.
     */
    public TransformerBlock(AbstractModel model, int layerIndex, LayerNorm preAttentionNorm, CausalSelfAttention attention, LayerNorm postAttentionNorm, FeedForward ffBlock, LayerNorm postFFNorm) {
        this(model, layerIndex, Optional.of(preAttentionNorm), attention, Optional.empty(), Optional.of(postAttentionNorm), ffBlock, Optional.empty(), Optional.of(postFFNorm));
    }

    /**
     * Creates a layer with four norms, each stored in its like-named slot.
     *
     * <p>The four norms sit before and after both attention and feed-forward. No final response norm
     * is set.
     */
    public TransformerBlock(AbstractModel model, int layerIndex, LayerNorm preAttentionNorm, CausalSelfAttention attention, LayerNorm postAttentionNorm, LayerNorm preFFNorm, FeedForward ffBlock, LayerNorm postFFNorm) {
        this(model, layerIndex, Optional.of(preAttentionNorm), attention, Optional.of(postAttentionNorm), Optional.of(preFFNorm), ffBlock, Optional.of(postFFNorm), Optional.empty());
    }

    /**
     * Canonical constructor.
     *
     * <p>Every norm slot is explicit; an empty {@code Optional} means that norm is skipped in
     * {@link #forward}.
     */
    public TransformerBlock(AbstractModel model, int layerIndex, Optional<LayerNorm> preAttentionNorm, CausalSelfAttention attention, Optional<LayerNorm> postAttentionNorm, Optional<LayerNorm> preFFNorm, FeedForward ffBlock, Optional<LayerNorm> postFFNorm, Optional<LayerNorm> preResponseNorm) {
        this.model = model;
        this.layerIndex = layerIndex;
        this.preAttentionNorm = preAttentionNorm;
        this.attention = attention;
        this.postAttentionNorm = postAttentionNorm;
        this.preFFNorm = preFFNorm;
        this.ffBlock = ffBlock;
        this.postFFNorm = postFFNorm;
        this.preResponseNorm = preResponseNorm;
    }

    /** Runs this layer with no tensor reducer; equivalent to {@code forward(embedding, position, kvBuffer, Optional.empty())}. */
    public AbstractTensor forward(AbstractTensor embedding, int position, KvBufferCache.KvBuffer kvBuffer) {
        return this.forward(embedding, position, kvBuffer, Optional.empty());
    }

    /**
     * Runs this layer on {@code embedding}.
     *
     * <p>Data flow (each norm is skipped when absent):
     * <pre>
     *   h   = postAttentionNorm(attention(quantize(preAttentionNorm(x))))  [* residualMultiplier] + x
     *   out = postFFNorm(ffBlock(quantize(preFFNorm(h))))                  [* residualMultiplier] + h
     *   return preResponseNorm(out)
     * </pre>
     *
     * <p>Intermediate tensors created here are closed once each before returning. This method's
     * cleanup never closes {@code embedding}; presumably the caller owns it.
     *
     * @param embedding     layer input; it is also read as the residual for the attention sub-block
     * @param position      forwarded to {@link CausalSelfAttention#forward}
     * @param kvBuffer      forwarded to {@link CausalSelfAttention#forward}
     * @param tensorReducer optional callback forwarded to both attention and the feed-forward block
     * @return the layer output
     */
    public AbstractTensor forward(AbstractTensor embedding, int position, KvBufferCache.KvBuffer kvBuffer, Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
        DebugSupport.debug("input_emb", embedding, this.layerIndex);

        // ---- attention sub-block ----
        AbstractTensor lnemb = this.preAttentionNorm.map(ln -> ln.forward(embedding)).orElse(embedding);
        DebugSupport.debug("ln_emb", lnemb, this.layerIndex);
        AbstractTensor postAttention;
        try (AbstractTensor qlnemb = this.model.maybeQuantize(lnemb)) {
            postAttention = this.attention.forward(qlnemb, position, kvBuffer, tensorReducer);
        }
        DebugSupport.debug("post_attn", postAttention, this.layerIndex);
        // maybeApplyNorm either returns postAttention itself, or closes it and returns a new tensor.
        // Either way, only lnattn remains to be released.
        AbstractTensor lnattn = this.maybeApplyNorm(postAttention, this.postAttentionNorm);
        DebugSupport.debug("post_attn_norm", lnattn, this.layerIndex);
        this.addResidual(lnattn, embedding);

        // ---- feed-forward sub-block ----
        AbstractTensor lnpreFF = this.preFFNorm.map(ln -> ln.forward(lnattn)).orElse(lnattn);
        DebugSupport.debug("pre_ff_norm", lnpreFF, this.layerIndex);
        AbstractTensor postFF;
        try (AbstractTensor qlnemb2 = this.model.maybeQuantize(lnpreFF)) {
            postFF = this.ffBlock.forward(qlnemb2, tensorReducer);
            DebugSupport.debug("post_ff", postFF, this.layerIndex);
        }
        AbstractTensor lnpostFF = this.maybeApplyNorm(postFF, this.postFFNorm);
        this.addResidual(lnpostFF, lnattn);
        DebugSupport.debug("post_ff_res", lnpostFF, this.layerIndex);

        // Release temporaries exactly once each. Without a pre-FF norm, lnpreFF IS lnattn. The
        // previous code closed lnattn a second time in that case; AutoCloseable.close is not
        // guaranteed to be idempotent, so a double release is unsafe.
        if (lnemb != embedding) {
            lnemb.close();
        }
        lnattn.close();
        if (lnpreFF != lnattn) {
            lnpreFF.close();
        }

        return this.maybeApplyNorm(lnpostFF, this.preResponseNorm);
    }

    /**
     * Adds {@code residual} into {@code branch}.
     *
     * <p>When the model config sets a residual multiplier, {@code branch} is first scaled by it. Both
     * operations go through {@link TensorOperationsProvider} over {@code [0, embeddingLength)} and
     * are assumed to update {@code branch} in place (TODO confirm against TensorOperations).
     */
    private void addResidual(AbstractTensor branch, AbstractTensor residual) {
        if (this.model.c.residualMultiplier != null) {
            TensorOperationsProvider.get().scale(this.model.c.residualMultiplier.floatValue(), branch, 0, this.model.c.embeddingLength);
        }
        TensorOperationsProvider.get().accumulate(branch, residual, 0, this.model.c.embeddingLength);
    }

    /**
     * Applies {@code norm} to {@code tensor} if present; otherwise returns {@code tensor} unchanged.
     *
     * <p>When the norm yields a different tensor, the input is closed because it is no longer
     * needed. If the norm hands back the same instance, that instance is returned open.
     */
    private AbstractTensor maybeApplyNorm(AbstractTensor tensor, Optional<LayerNorm> norm) {
        return norm.map(ln -> {
            AbstractTensor normed = ln.forward(tensor);
            // Never close the tensor we are about to return (guards against an in-place norm).
            if (normed != tensor) {
                tensor.close();
            }
            return normed;
        }).orElse(tensor);
    }
}

