/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.TransformerBlock;
import com.github.tjake.jlama.model.functions.ClassifyOutput;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.model.functions.PoolingLayer;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
import com.github.tjake.jlama.safetensors.prompt.Tool;
import com.github.tjake.jlama.safetensors.prompt.ToolCall;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.tensor.KvBufferCache;
import com.github.tjake.jlama.tensor.TensorShape;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import com.github.tjake.jlama.util.DebugSupport;
import com.github.tjake.jlama.util.JsonSupport;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import jdk.incubator.vector.FloatVector;
import net.jafama.FastMath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class AbstractModel
implements Generator {
    private static final Logger logger = LoggerFactory.getLogger(AbstractModel.class);
    // Maximum number of tokens handed to one batched forward pass; read from the
    // "jlama.max_batch_size" system property, defaulting to 256.
    private static final Integer MAX_BATCH_SIZE = Integer.getInteger("jlama.max_batch_size", 256);
    // Decides which components (input, blocks, output, classifier, pooling) the constructor loads.
    protected final InferenceType inferenceType;
    // Model configuration (context length, embedding length, vocabulary size, tensor cache, ...).
    protected final Config c;
    protected final WeightLoader weights;
    protected final Tokenizer tokenizer;
    // DType reported by the weight loader.
    protected final DType modelDType;
    // DType used for tensors allocated through makeTensor/makeDenseTensor.
    protected final DType workingDType;
    // DType selected for quantized working memory; resolved in the constructor.
    protected final DType workingQType;
    protected final Optional<DType> modelQType;
    // The components below are null (or Optional.empty() for the pooling layer) when the
    // inference type does not require them.
    protected EmbedInput embedInput;
    protected SampleOutput sampleOutput;
    protected ClassifyOutput classifyOutput;
    protected Optional<PoolingLayer> poolingLayer;
    protected TransformerBlock[] transformerBlocks;
    protected KvBufferCache kvBufferCache;

    /**
     * Wires up the model: records configuration, resolves the quantized working-memory type and
     * loads only the components required by {@code inferenceType}.
     *
     * @param inferenceType      selects which weights (input, blocks, output, classifier, pooling) are loaded
     * @param c                  model configuration
     * @param w                  weight loader; also supplies the model dtype
     * @param t                  tokenizer
     * @param workingMemoryDType dtype used for working tensors
     * @param workingMemoryQType requested dtype for quantized working memory (may be overridden below)
     * @param modelQType         optional quantization type of the model
     */
    protected AbstractModel(InferenceType inferenceType, Config c, WeightLoader w, Tokenizer t, DType workingMemoryDType, DType workingMemoryQType, Optional<DType> modelQType) {
        this.inferenceType = inferenceType;
        this.c = c;
        this.weights = w;
        this.tokenizer = t;
        this.modelDType = w.getModelDType();
        this.workingDType = workingMemoryDType;
        this.modelQType = modelQType;
        this.kvBufferCache = new KvBufferCache(this);

        // Unquantized F32/BF16 models keep their working memory in the model's own dtype.
        if (this.modelDType == DType.F32 && workingMemoryQType != DType.F32 && modelQType.isEmpty()) {
            workingMemoryQType = DType.F32;
        }
        if (this.modelDType == DType.BF16 && workingMemoryQType != DType.BF16 && modelQType.isEmpty()) {
            workingMemoryQType = DType.BF16;
        }

        // Q4 weights with I8 working memory are only used when both the embedding and hidden
        // lengths (in blocks of 32) are a multiple of the preferred vector lane count; otherwise
        // fall back to F32. (A second check that only looked at embeddingLength was fully
        // subsumed by this one and has been removed.)
        int lanes = FloatVector.SPECIES_PREFERRED.vectorBitSize() / 32;
        if (this.modelDType == DType.Q4 && workingMemoryQType == DType.I8 && (c.embeddingLength / 32 % lanes != 0 || c.hiddenLength / 32 % lanes != 0)) {
            workingMemoryQType = DType.F32;
        }

        // Probe whether the active tensor operations can actually produce the requested
        // quantized type. The result is kept in a local so the final field is assigned exactly
        // once; previously the fallback assignment was immediately overwritten.
        DType resolvedQType = workingMemoryQType;
        if (workingMemoryQType != workingMemoryDType) {
            // NOTE(review): the probe tensor is not closed here, matching prior behavior;
            // confirm whether quantize() may return its input before adding a close().
            AbstractTensor probe = this.makeDenseTensor(32);
            try (AbstractTensor quantized = TensorOperationsProvider.get().quantize(probe, workingMemoryQType, 0, 32)) {
                if (quantized.dType() != workingMemoryQType) {
                    logger.warn("Quantized memory type {} not supported, falling back to {}", workingMemoryQType, workingMemoryDType);
                    resolvedQType = this.workingDType;
                }
            }
        }
        this.workingQType = resolvedQType;

        logger.info("Model type = {}, Working memory type = {}, Quantized memory type = {}", this.modelDType, this.workingDType, this.workingQType);

        this.embedInput = inferenceType.isInput ? this.loadInputWeights() : null;
        this.transformerBlocks = inferenceType.isFwdPass ? this.loadTransformerBlockWeights() : null;
        this.sampleOutput = inferenceType.isOutput ? this.loadOutputWeights() : null;
        this.classifyOutput = inferenceType.isClassify ? this.loadClassifierWeights() : null;
        this.poolingLayer = inferenceType.isPooling ? Optional.ofNullable(this.loadPoolingWeights()) : Optional.empty();
    }

    /** Closes the KV buffer cache held by this model. */
    @Override
    public void close() {
        kvBufferCache.close();
    }

    /** Loads the input embedding component; invoked by the constructor when the inference type has {@code isInput} set. */
    protected abstract EmbedInput loadInputWeights();

    /** Loads the transformer blocks; invoked by the constructor when the inference type has {@code isFwdPass} set. */
    protected abstract TransformerBlock[] loadTransformerBlockWeights();

    /** Loads the output/sampling component; invoked by the constructor when the inference type has {@code isOutput} set. */
    protected abstract SampleOutput loadOutputWeights();

    /**
     * Loads the classification head. The default implementation always throws; models that
     * support classification override it.
     *
     * @throws UnsupportedOperationException always, in this base implementation
     */
    protected ClassifyOutput loadClassifierWeights() {
        throw new UnsupportedOperationException("Classification not supported by this model");
    }

    /**
     * Loads the pooling layer. The default implementation returns {@code null}, which the
     * constructor turns into an empty {@link Optional} via {@code Optional.ofNullable}.
     */
    protected PoolingLayer loadPoolingWeights() {
        return null;
    }

    /** Returns the {@link ModelSupport.ModelType} of the concrete model. */
    public abstract ModelSupport.ModelType getModelType();

    /** Returns the inference type this model was constructed with. */
    public InferenceType getInferenceType() {
        return inferenceType;
    }

    /** Returns the dtype used for working tensors. */
    public DType getWorkingDType() {
        return workingDType;
    }

    /** Returns the model configuration. */
    @Override
    public Config getConfig() {
        return c;
    }

    /** Returns the tokenizer used by this model. */
    @Override
    public Tokenizer getTokenizer() {
        return tokenizer;
    }

    /** Returns the weight loader backing this model. */
    public WeightLoader getWeights() {
        return weights;
    }

    /** Delegates to the tokenizer's prompt support, which may be absent. */
    @Override
    public Optional<PromptSupport> promptSupport() {
        return tokenizer.promptSupport();
    }

    /**
     * Obtains a tensor of the working dtype from the configuration's tensor cache.
     *
     * @param shape dimensions of the requested tensor
     * @return a tensor of the given shape
     */
    public AbstractTensor makeTensor(int ... shape) {
        return c.tensorCache.get(workingDType, TensorShape.of(shape));
    }

    /**
     * Obtains a tensor of the working dtype from the configuration's tensor cache.
     *
     * @param shape dimensions of the requested tensor
     * @return a tensor of the given shape
     */
    public AbstractTensor makeDenseTensor(int ... shape) {
        TensorShape requested = TensorShape.of(shape);
        return c.tensorCache.get(workingDType, requested);
    }

    /**
     * Obtains a tensor of the working dtype from the configuration's tensor cache.
     *
     * @param s shape of the requested tensor
     * @return a tensor of the given shape
     */
    public AbstractTensor makeDenseTensor(TensorShape s) {
        return c.tensorCache.get(workingDType, s);
    }

    /**
     * Returns a copy of {@code t} held in a cache-provided tensor of the same dtype and shape.
     * No dtype conversion is performed by this base implementation.
     *
     * @param t tensor to copy
     * @return a new tensor with the contents of {@code t}
     */
    protected AbstractTensor maybeQuantize(AbstractTensor t) {
        AbstractTensor copy = c.tensorCache.get(t.dType(), t.shape());
        int length = Ints.checkedCast(t.size());
        copy.copyFrom(t, 0, 0, length);
        return copy;
    }

    /** Single-token forward pass without a tensor reducer. */
    protected AbstractTensor forward(int token_id, int pos, KvBufferCache.KvBuffer kvbuf) {
        return forward(token_id, pos, kvbuf, Optional.empty());
    }

    /**
     * Embeds a single token and runs it through the transformer blocks.
     *
     * @param token_id      token to embed
     * @param pos           position of the token in the context
     * @param kvbuf         KV buffer for the session
     * @param tensorReducer optional reducer forwarded to the transformer blocks
     * @return the output of the last local transformer block
     */
    public AbstractTensor forward(int token_id, int pos, KvBufferCache.KvBuffer kvbuf, Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
        AbstractTensor tokenEmbedding = embedInput.inputTokenToEmbedding(token_id, pos);
        DebugSupport.debug("EMBEDDING TOKEN", token_id);
        DebugSupport.debug("TOKEN POSITION", pos);
        return forward(tokenEmbedding, pos, kvbuf, tensorReducer);
    }

    /**
     * Feeds the tokens one at a time through {@code forward}, closing every intermediate output.
     *
     * @return the output for the final token, or {@code null} when {@code token_ids} is empty
     */
    protected AbstractTensor batchForwardSlow(int[] token_ids, int startPos, KvBufferCache.KvBuffer kvbuf) {
        AbstractTensor result = null;
        int position = startPos;
        for (int tokenId : token_ids) {
            if (result != null) {
                result.close();
            }
            result = forward(tokenId, position, kvbuf);
            position++;
        }
        return result;
    }

    /** Batched forward pass without a tensor reducer. */
    public AbstractTensor batchForward(int[] token_ids, int startPos, KvBufferCache.KvBuffer kvbuf) {
        return batchForward(token_ids, startPos, kvbuf, Optional.empty());
    }

    /**
     * Runs the tokens through the model in batches of at most {@code MAX_BATCH_SIZE}.
     *
     * @param token_ids     tokens to process
     * @param startPos      context position of the first token
     * @param kvbuf         KV buffer for the session
     * @param tensorReducer optional reducer forwarded to the transformer blocks
     * @return the output of the last batch, or {@code null} when {@code token_ids} is empty;
     *         the caller is responsible for closing it
     */
    public AbstractTensor batchForward(int[] token_ids, int startPos, KvBufferCache.KvBuffer kvbuf, Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
        // A non-positive configured batch size would stall the loop (0) or make
        // copyOfRange throw (negative); treat it as a batch size of one.
        int batchSize = Math.max(1, MAX_BATCH_SIZE);
        AbstractTensor result = null;
        for (int i = 0; i < token_ids.length; i += batchSize) {
            int[] batch = Arrays.copyOfRange(token_ids, i, Math.min(token_ids.length, i + batchSize));
            // Only the last batch's output is returned; release the previous one
            // instead of dropping the reference.
            if (result != null) {
                result.close();
            }
            AbstractTensor embedding = this.embedInput.batchInputsToEmbeddings(batch, startPos + i);
            result = this.forward(embedding, startPos + i, kvbuf, tensorReducer);
            logger.debug("Batched forward pass for tokens {} to {}", i, i + batch.length);
        }
        return result;
    }

    /**
     * Passes {@code embedding} through every transformer block in the range
     * {@code [layerStart, layerEnd)} given by the config's {@code dctx()}. Each block's input,
     * including the tensor passed in, is closed once the block has produced its output.
     *
     * @return the output of the last block, or {@code embedding} itself when the range is empty
     */
    public AbstractTensor forward(AbstractTensor embedding, int startPos, KvBufferCache.KvBuffer kvbuf, Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
        AbstractTensor current = embedding;
        for (int layer = c.dctx().layerStart; layer < c.dctx().layerEnd; layer++) {
            AbstractTensor previous = current;
            // transformerBlocks is indexed relative to the first local layer.
            TransformerBlock block = transformerBlocks[layer - c.dctx().layerStart];
            current = block.forward(previous, startPos, kvbuf, tensorReducer);
            previous.close();
        }
        return current;
    }

    /**
     * Computes an embedding vector for {@code input}.
     *
     * <p>With {@code PoolingType.MODEL} the last token's output is projected through the model's
     * pooling layer and passed through tanh; the result is returned without normalization.
     * With AVG, MAX or SUM the per-token outputs are pooled element-wise and the result is
     * L2-normalized.
     *
     * @param input       text to embed; its token count must be below the context length
     * @param poolingType pooling strategy
     * @return an array of {@code embeddingLength} floats
     * @throws IllegalArgumentException      if the encoded input is not shorter than the context length
     * @throws UnsupportedOperationException if MODEL pooling is requested but no pooling layer is loaded
     */
    @Override
    public float[] embed(String input, Generator.PoolingType poolingType) {
        int[] encoded = Arrays.stream(this.tokenizer.encode(input)).mapToInt(Ints::checkedCast).toArray();
        Preconditions.checkArgument(encoded.length < this.c.contextLength);
        float[] outputEmbedding = new float[this.c.embeddingLength];
        try (KvBufferCache.KvBuffer kvmem = this.kvBufferCache.getEphemeralKvBuffer()) {
            int promptLength = encoded.length;
            float avgp = 1.0f / (float) promptLength;
            try (AbstractTensor r = this.batchForward(encoded, 0, kvmem)) {
                if (poolingType == Generator.PoolingType.MODEL) {
                    PoolingLayer pooling = this.poolingLayer.orElseThrow(() -> new UnsupportedOperationException("Pooling layer not found"));
                    AbstractTensor output = r.slice(promptLength - 1);
                    // The scratch tensor is released once its values have been copied out.
                    try (AbstractTensor pooled = this.makeDenseTensor(1, this.c.embeddingLength)) {
                        TensorOperationsProvider.get().batchDotProduct(pooled, output, pooling.getPoolingWeights(), 0, 0, this.c.embeddingLength);
                        pooling.getPoolingBias().ifPresent(bias -> TensorOperationsProvider.get().accumulate(pooled, (AbstractTensor) bias, 0, this.c.embeddingLength));
                        VectorMath.pfor(0, this.c.embeddingLength, idx -> {
                            outputEmbedding[idx] = ActivationFunction.eval(ActivationFunction.Type.TANH, pooled.get(0, idx));
                        });
                    }
                    return outputEmbedding;
                }
                for (int tokenIdx = 0; tokenIdx < promptLength; ++tokenIdx) {
                    AbstractTensor output = r.slice(tokenIdx);
                    for (int ii = 0; ii < this.c.embeddingLength; ++ii) {
                        switch (poolingType) {
                            case AVG:
                                outputEmbedding[ii] += output.get(0, ii) * avgp;
                                break;
                            case MAX:
                                // Seed the running maximum from the first token rather than from
                                // the zero-filled array, so all-negative dimensions are not
                                // clamped to 0.
                                outputEmbedding[ii] = tokenIdx == 0
                                    ? output.get(0, ii)
                                    : Math.max(outputEmbedding[ii], output.get(0, ii));
                                break;
                            case SUM:
                                outputEmbedding[ii] += output.get(0, ii);
                                break;
                            default:
                                break;
                        }
                    }
                }
            }
            VectorMath.l2normalize(outputEmbedding);
            return outputEmbedding;
        }
    }

    /**
     * Classifies {@code input}: embeds it, applies the classification weights and optional bias,
     * and returns the softmax score for each label.
     *
     * @param input       text to classify
     * @param poolingType pooling strategy passed to {@link #embed}
     * @return map from class label to softmax score
     * @throws UnsupportedOperationException if the config is not a classifier or no classifier weights are loaded
     */
    @Override
    public Map<String, Float> classify(String input, Generator.PoolingType poolingType) {
        if (!this.c.isClassifier() || this.classifyOutput == null) {
            throw new UnsupportedOperationException("Classification not supported by this model");
        }
        float[] embedding = this.embed(input, poolingType);
        FloatBufferTensor b = new FloatBufferTensor(FloatBuffer.wrap(embedding), TensorShape.of(embedding.length), false);
        int classes = this.classifyOutput.getClassificationWeights().shape().first();
        Map<String, Float> result = new HashMap<>();
        // The scores tensor comes from the tensor cache; close it once the values are copied out.
        try (AbstractTensor scores = this.makeDenseTensor(classes)) {
            TensorOperationsProvider.get().batchDotProduct(scores, b, this.classifyOutput.getClassificationWeights(), 0, 0, this.c.embeddingLength);
            this.classifyOutput.getClassificationBias().ifPresent(bias -> TensorOperationsProvider.get().accumulate(scores, (AbstractTensor) bias, 0, classes));
            VectorMath.softMax(scores, 0, classes);
            for (int i = 0; i < classes; i++) {
                // Labels are looked up by class index through the inverse view of the label map.
                String label = (String) this.c.classifcationLabels.get().inverse().get(i);
                result.put(label, Float.valueOf(scores.get(0, i)));
            }
        }
        return result;
    }

    /**
     * Applies the output layer norm and output projection to {@code output}, runs softmax over
     * the vocabulary and copies the result into a new array.
     *
     * @param output tensor to project
     * @return an array of {@code vocabularySize} softmax values
     */
    public float[] getLogits(AbstractTensor output) {
        try (AbstractTensor normalized = sampleOutput.getOutputLayerNorm().forward(output);
             AbstractTensor logits = makeDenseTensor(1, c.vocabularySize)) {
            VectorMath.pchunk(0, c.vocabularySize, (chunkStart, chunkSize) -> TensorOperationsProvider.get().dotProductChunk(logits, normalized, sampleOutput.getOutputLogitsWeights(), 0, c.embeddingLength, chunkStart, chunkSize));
            VectorMath.softMax(logits, 0, c.vocabularySize);
            float[] values = new float[c.vocabularySize];
            // Bulk-copy the tensor's backing memory, read as little-endian floats.
            logits.getMemorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(values);
            return values;
        }
    }

    /**
     * Projects {@code output} to vocabulary logits (written into {@code logits}) and picks a token.
     *
     * <p>With {@code temperature == 0} the index of the largest logit is returned. Otherwise the
     * logits are turned into unnormalized probabilities
     * {@code exp((logit - max) / temperature)} and the first index whose cumulative normalized
     * probability reaches {@code uniformSample} is returned.
     *
     * @param output        tensor to project
     * @param temperature   sampling temperature; 0 selects the arg-max
     * @param uniformSample threshold compared against the cumulative probability
     * @param logits        scratch tensor that receives the logits; overwritten in place
     * @return the selected token index, in {@code [0, vocabularySize)} when sampling
     */
    public int sample(AbstractTensor output, float temperature, float uniformSample, AbstractTensor logits) {
        try (AbstractTensor embedding = this.sampleOutput.getOutputLayerNorm().forward(output)) {
            VectorMath.pchunk(0, this.c.vocabularySize, (chunkStart, chunkSize) -> TensorOperationsProvider.get().dotProductChunk(logits, embedding, this.sampleOutput.getOutputLogitsWeights(), 0, this.c.embeddingLength, chunkStart, chunkSize));
            if (this.c.logitMultiplier != null) {
                TensorOperationsProvider.get().scale(1.0f / this.c.logitMultiplier.floatValue(), logits, 0, this.c.vocabularySize);
            }

            // Apply optional soft capping and track the arg-max in the same pass.
            int maxi = Integer.MIN_VALUE;
            double maxv = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < this.c.vocabularySize; ++i) {
                float v = logits.get(0, i);
                if (this.c.finalLogitSoftCapping != null) {
                    float cap = this.c.finalLogitSoftCapping.floatValue();
                    v /= cap;
                    v = (float) FastMath.tanh(v);
                    v *= cap;
                    logits.set(v, 0, i);
                }
                if ((double) v > maxv) {
                    maxi = i;
                    maxv = v;
                }
            }
            if (temperature == 0.0f) {
                return maxi;
            }

            // Subtracting the maximum before exponentiating keeps exp() from overflowing.
            float sum = 0.0f;
            for (int i = 0; i < this.c.vocabularySize; ++i) {
                float v = (float) FastMath.exp(((double) logits.get(0, i) - maxv) / (double) temperature);
                sum += v;
                logits.set(v, 0, i);
            }

            // Return the index whose probability mass crosses the threshold. The index is only
            // advanced by the loop header, so the returned token is the one just accumulated,
            // not its successor.
            float acc = 0.0f;
            for (int i = 0; i < this.c.vocabularySize; ++i) {
                acc += logits.get(0, i) / sum;
                if (acc >= uniformSample) {
                    return i;
                }
            }
            // Rounding can leave the cumulative sum just below the threshold.
            return this.c.vocabularySize - 1;
        }
    }

    /**
     * Whether a BOS token is prepended to encoded prompts by {@code encodePrompt} and
     * {@code generate}. This base implementation returns {@code true}.
     */
    protected boolean addBosToken() {
        return true;
    }

    /**
     * Tokenizes the prompt. When {@link #addBosToken()} is true the result starts with exactly
     * one BOS token: a leading BOS already produced by the tokenizer is dropped before the
     * model's BOS is prepended. Otherwise the tokenizer output is returned as-is.
     *
     * @param promptContext prompt to encode
     * @return the prompt token ids
     */
    public int[] encodePrompt(PromptContext promptContext) {
        long[] tokens = tokenizer.encode(promptContext.getPrompt());
        if (!addBosToken()) {
            return Arrays.stream(tokens).mapToInt(Ints::checkedCast).toArray();
        }
        // Skip a tokenizer-supplied BOS so it is not duplicated.
        boolean hasLeadingBos = tokens.length > 0 && tokens[0] == (long) c.bosToken;
        int skip = hasLeadingBos ? 1 : 0;
        int[] promptTokens = new int[1 + tokens.length - skip];
        promptTokens[0] = c.bosToken;
        int dst = 1;
        for (int src = skip; src < tokens.length; src++) {
            promptTokens[dst++] = Ints.checkedCast(tokens[src]);
        }
        return promptTokens;
    }

    /**
     * Generates a response for the prompt within the session identified by {@code sessionId}.
     *
     * <p>The prompt is encoded, run through the model in one batched pass (or token by token when
     * debugging is enabled), and tokens are then sampled one at a time until an EOS token is
     * sampled or the position reaches {@code ntokens}. Each non-special decoded token is passed
     * to {@code onTokenWithTimings} together with a milliseconds-per-token figure.
     *
     * @param sessionId          key used to obtain the KV buffer
     * @param promptContext      prompt (and optional tools) to respond to
     * @param temperature        sampling temperature passed to {@link #sample}
     * @param ntokens            upper bound on the context position; capped at the context length
     * @param onTokenWithTimings callback receiving each decoded non-special token and a timing value
     * @return the response, after {@link #postProcessResponse} has been applied
     * @throws IllegalArgumentException if the encoded prompt is not shorter than both the context
     *                                  length and {@code ntokens}
     */
    @Override
    public Generator.Response generate(UUID sessionId, PromptContext promptContext, float temperature, int ntokens, BiConsumer<String, Float> onTokenWithTimings) {
        long[] encoded = this.tokenizer.encode(promptContext.getPrompt());
        // Drop a BOS produced by the tokenizer; it is re-added below when addBosToken() is true.
        if (encoded.length > 0 && encoded[0] == (long)this.c.bosToken) {
            encoded = Arrays.copyOfRange(encoded, 1, encoded.length);
        }
        Preconditions.checkArgument(encoded.length < this.c.contextLength && encoded.length < ntokens, "Prompt exceeds max tokens");
        try (KvBufferCache.KvBuffer kvmem = this.kvBufferCache.getKvBuffer(sessionId);){
            // Continue from wherever this session's KV buffer currently is.
            int startPos = kvmem.getCurrentContextPosition();
            logger.debug("Starting at token {} for session {} with prompt {}", startPos, sessionId, promptContext.getPrompt());
            if (ntokens > this.c.contextLength) {
                ntokens = this.c.contextLength;
            }
            // Stays MAX_TOKENS unless an EOS token is sampled inside the generation loop.
            Generator.FinishReason reason = Generator.FinishReason.MAX_TOKENS;
            StringBuilder responseText = new StringBuilder();
            StringBuilder responseTextWithSpecialTokens = new StringBuilder();
            // Scratch tensor reused by every sample() call; closed on both exit paths below.
            AbstractTensor logits = this.makeDenseTensor(this.c.vocabularySize);
            try {
                long start;
                int promptLength;
                int[] promptTokens;
                if (this.addBosToken()) {
                    promptTokens = new int[1 + encoded.length];
                    promptTokens[0] = this.c.bosToken;
                    for (int i = 1; i <= encoded.length; ++i) {
                        promptTokens[i] = Ints.checkedCast(encoded[i - 1]);
                    }
                    // NOTE(review): promptLength does not count the prepended BOS token.
                    promptLength = encoded.length;
                } else {
                    promptTokens = Arrays.stream(encoded).mapToInt(Ints::checkedCast).toArray();
                    promptLength = encoded.length;
                }
                long promptStart = start = System.currentTimeMillis();
                // Debug mode processes the prompt one token at a time instead of batching.
                AbstractTensor last = DebugSupport.isDebug() ? this.batchForwardSlow(promptTokens, startPos, kvmem) : this.batchForward(promptTokens, startPos, kvmem);
                long promptBatchTime = System.currentTimeMillis() - start;
                float batchMsPerToken = Math.round((double)promptBatchTime / (double)promptLength);
                logger.debug("{} prompt tokens in {}ms | {}ms per token", promptLength, promptBatchTime, Float.valueOf(batchMsPerToken));
                float genMsPerToken = 0.0f;
                int tokensGenerated = 0;
                // The first generated token is sampled from the last row of the prompt output.
                int next = this.sample(last.slice(last.shape().first() - 1), temperature, ThreadLocalRandom.current().nextFloat(), logits);
                last.close();
                // NOTE(review): this first token is not checked against eosTokens and is not
                // counted in tokensGenerated — confirm this is intended.
                try {
                    String c = this.tokenizer.decode(next);
                    if (this.tokenizer.getModel().isSpecialToken(next)) {
                        responseTextWithSpecialTokens.append(c);
                    } else {
                        onTokenWithTimings.accept(c, Float.valueOf(batchMsPerToken));
                        responseText.append(c);
                        responseTextWithSpecialTokens.append(c);
                    }
                }
                catch (Exception e) {
                    // Decode/callback failures are logged and generation continues.
                    logger.error("Failed to decode token {}", (Object)next, (Object)e);
                }
                // Restart the clock so generation timing excludes prompt processing.
                start = System.currentTimeMillis();
                for (int i = startPos + promptTokens.length; i < ntokens; ++i) {
                    AbstractTensor output = this.forward(next, i, kvmem);
                    ++tokensGenerated;
                    next = this.sample(output, temperature, ThreadLocalRandom.current().nextFloat(), logits);
                    if (logger.isTraceEnabled()) {
                        logger.trace("Sampled token {} with temperature {}", (Object)next, (Object)Float.valueOf(temperature));
                    }
                    output.close();
                    kvmem.incrementContextPosition();
                    // An EOS token ends generation and is not appended to either response text.
                    if (this.c.eosTokens.contains(next)) {
                        reason = Generator.FinishReason.STOP_TOKEN;
                        break;
                    }
                    try {
                        String c = this.tokenizer.decode(next);
                        // Special tokens are recorded only in the "with special tokens" text and
                        // are not reported to the callback.
                        if (this.tokenizer.getModel().isSpecialToken(next)) {
                            responseTextWithSpecialTokens.append(c);
                            continue;
                        }
                        genMsPerToken = (float)(System.currentTimeMillis() - start) / (float)tokensGenerated;
                        onTokenWithTimings.accept(c, Float.valueOf(genMsPerToken));
                        responseTextWithSpecialTokens.append(c);
                        responseText.append(c);
                        continue;
                    }
                    catch (Exception e) {
                        logger.error("Failed to decode token {}", (Object)next, (Object)e);
                    }
                }
                long end = System.currentTimeMillis();
                Generator.Response response = new Generator.Response(responseText.toString(), responseTextWithSpecialTokens.toString(), reason, promptLength, tokensGenerated, promptBatchTime, end - start);
                logger.debug(String.format("\n\nelapsed: %ds, prompt %.1fms per token, gen %.1fms per token\n", TimeUnit.MILLISECONDS.toSeconds(end - promptStart), Float.valueOf(batchMsPerToken), Float.valueOf(genMsPerToken)));
                // Gives tool-call extraction a chance to rewrite the response.
                Generator.Response response2 = this.postProcessResponse(promptContext, response);
                if (logits != null) {
                    logits.close();
                }
                return response2;
            }
            catch (Throwable throwable) {
                // Mirrors try-with-resources: close logits and attach any close failure as suppressed.
                if (logits != null) {
                    try {
                        logits.close();
                    }
                    catch (Throwable throwable2) {
                        throwable.addSuppressed(throwable2);
                    }
                }
                throw throwable;
            }
        }
    }

    /**
     * Extracts tool calls from a finished response.
     *
     * <p>The response is returned unchanged unless the tokenizer's model supports tools, the
     * prompt declared tools, generation ended on a stop token, and the response text (including
     * special tokens) mentions at least one declared tool name. JSON found in the response text
     * is then parsed into {@link ToolCall}s (either a single object or an array), sorted by name,
     * de-duplicated, and given sequential zero-padded ids.
     *
     * @param promptContext prompt context that may carry tool definitions
     * @param response      the raw generation result
     * @return {@code response} itself, or a copy carrying the parsed tool calls; on a JSON
     *         parsing failure the error is logged and {@code response} is returned unchanged
     */
    protected Generator.Response postProcessResponse(PromptContext promptContext, Generator.Response response) {
        if (!this.tokenizer.getModel().hasToolSupport() || !promptContext.hasTools() || response.finishReason != Generator.FinishReason.STOP_TOKEN) {
            return response;
        }
        List<Tool> tools = promptContext.getTools().get();
        boolean foundTool = tools.stream()
            .anyMatch(tool -> response.responseTextWithSpecialTokens.contains(tool.getFunction().getName()));
        if (!foundTool) {
            return response;
        }
        try {
            List<String> jsonCalls = JsonSupport.extractJsonFromString(response.responseText);
            if (jsonCalls.isEmpty()) {
                logger.warn("Tool call detected but no tool call found in response: {}", response.responseText);
                return response;
            }
            logger.debug("Found tool calls: {}", jsonCalls);
            List<ToolCall> toolCalls = new ArrayList<>(jsonCalls.size());
            for (String jsonCall : jsonCalls) {
                if (jsonCall.startsWith("[")) {
                    // A JSON array carries several calls at once.
                    List<ToolCall> parsed = JsonSupport.om.readValue(jsonCall, new TypeReference<List<ToolCall>>() {});
                    toolCalls.addAll(parsed);
                } else {
                    toolCalls.add(JsonSupport.om.readValue(jsonCall, ToolCall.class));
                }
            }
            // Sort by name and drop duplicates so the assigned ids are deterministic.
            toolCalls = toolCalls.stream()
                .sorted(Comparator.comparing(ToolCall::getName))
                .distinct()
                .collect(Collectors.toList());
            for (int i = 0; i < toolCalls.size(); ++i) {
                toolCalls.get(i).setId(String.format("%09d", i));
            }
            return response.copyWithToolCalls(toolCalls);
        }
        catch (JsonProcessingException e) {
            logger.error("Failed to parse tool call from response: {}", response.responseText, e);
            return response;
        }
    }

    /**
     * Describes which parts of a model an instance needs. The flags are given in the order
     * input, forward pass, output, classify, pooling.
     */
    public static enum InferenceType {
        INPUT_TO_EMBEDDING(true, false, false, false, false),
        OUTPUT_TO_TOKEN(false, false, true, false, false),
        FORWARD_PASS(true, true, false, false, false),
        FULL_GENERATION(true, true, true, false, false),
        FULL_CLASSIFICATION(true, true, false, true, true),
        FULL_EMBEDDING(true, true, false, false, true);

        /** Input embedding weights are required. */
        final boolean isInput;
        /** Transformer block weights are required. */
        final boolean isFwdPass;
        /** Output/sampling weights are required. */
        final boolean isOutput;
        /** Classifier weights are required. */
        final boolean isClassify;
        /** A pooling layer is requested. */
        final boolean isPooling;

        private InferenceType(boolean input, boolean fwdPass, boolean output, boolean classify, boolean pooling) {
            isInput = input;
            isFwdPass = fwdPass;
            isOutput = output;
            isClassify = classify;
            isPooling = pooling;
        }
    }
}

