/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.model.bert.BertConfig;
import com.github.tjake.jlama.model.bert.BertModel;
import com.github.tjake.jlama.model.bert.BertTokenizer;
import com.github.tjake.jlama.model.gemma.GemmaConfig;
import com.github.tjake.jlama.model.gemma.GemmaModel;
import com.github.tjake.jlama.model.gemma.GemmaTokenizer;
import com.github.tjake.jlama.model.gemma2.Gemma2Config;
import com.github.tjake.jlama.model.gemma2.Gemma2Model;
import com.github.tjake.jlama.model.gpt2.GPT2Config;
import com.github.tjake.jlama.model.gpt2.GPT2Model;
import com.github.tjake.jlama.model.gpt2.GPT2Tokenizer;
import com.github.tjake.jlama.model.granite.GraniteConfig;
import com.github.tjake.jlama.model.granite.GraniteModel;
import com.github.tjake.jlama.model.llama.LlamaConfig;
import com.github.tjake.jlama.model.llama.LlamaModel;
import com.github.tjake.jlama.model.llama.LlamaTokenizer;
import com.github.tjake.jlama.model.mistral.MistralConfig;
import com.github.tjake.jlama.model.mistral.MistralModel;
import com.github.tjake.jlama.model.mixtral.MixtralConfig;
import com.github.tjake.jlama.model.mixtral.MixtralModel;
import com.github.tjake.jlama.model.qwen2.Qwen2Config;
import com.github.tjake.jlama.model.qwen2.Qwen2Model;
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.util.JsonSupport;
import com.github.tjake.jlama.util.PhysicalCoreExecutor;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.nio.file.Path;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static entry points for loading jlama models from a local directory containing a
 * {@code config.json}, tokenizer files and weights.
 */
public class ModelSupport {
    private static final Logger logger = LoggerFactory.getLogger(ModelSupport.class);

    /** Name of the model configuration file that must be present in the model directory. */
    private static final String CONFIG_FILE_NAME = "config.json";

    /**
     * Loads a model for text generation using default options: no working directory,
     * no model quantization override and no thread-count override.
     *
     * @param model the model directory, or a file inside it
     * @param workingMemoryType dtype passed to the model constructor as working memory type
     * @param workingQuantizationType dtype passed to the model constructor as working quantization type
     * @return the instantiated model
     * @throws IllegalArgumentException if the model location, its directory or its config.json is missing
     * @throws RuntimeException if reading the config or instantiating tokenizer/model fails
     */
    public static AbstractModel loadModel(File model, DType workingMemoryType, DType workingQuantizationType) {
        return loadModel(model, null, workingMemoryType, workingQuantizationType, Optional.empty(), Optional.empty());
    }

    /**
     * Loads a model with {@link AbstractModel.InferenceType#FULL_EMBEDDING}, using
     * {@link SafeTensorSupport#loadWeights} as the weight loader.
     *
     * @param model the model directory, or a file inside it
     * @param workingMemoryType dtype passed to the model constructor as working memory type
     * @param workingQuantizationType dtype passed to the model constructor as working quantization type
     * @return the instantiated model
     */
    public static AbstractModel loadEmbeddingModel(File model, DType workingMemoryType, DType workingQuantizationType) {
        return loadModel(AbstractModel.InferenceType.FULL_EMBEDDING, model, null, workingMemoryType, workingQuantizationType,
                Optional.empty(), Optional.empty(), Optional.empty(), SafeTensorSupport::loadWeights);
    }

    /**
     * Loads a model with {@link AbstractModel.InferenceType#FULL_CLASSIFICATION}, using
     * {@link SafeTensorSupport#loadWeights} as the weight loader.
     *
     * @param model the model directory, or a file inside it
     * @param workingMemoryType dtype passed to the model constructor as working memory type
     * @param workingQuantizationType dtype passed to the model constructor as working quantization type
     * @return the instantiated model
     */
    public static AbstractModel loadClassifierModel(File model, DType workingMemoryType, DType workingQuantizationType) {
        return loadModel(AbstractModel.InferenceType.FULL_CLASSIFICATION, model, null, workingMemoryType, workingQuantizationType,
                Optional.empty(), Optional.empty(), Optional.empty(), SafeTensorSupport::loadWeights);
    }

    /**
     * Loads a model with {@link AbstractModel.InferenceType#FULL_GENERATION}, using
     * {@link SafeTensorSupport#loadWeights} and no distributed context.
     *
     * @param model the model directory, or a file inside it
     * @param workingDirectory directory handed to {@link Config#setWorkingDirectory}; may be {@code null}
     * @param workingMemoryType dtype passed to the model constructor as working memory type
     * @param workingQuantizationType dtype passed to the model constructor as working quantization type
     * @param modelQuantization optional dtype forwarded to the model constructor
     * @param threadCount if present, forwarded to {@link PhysicalCoreExecutor#overrideThreadCount}
     * @return the instantiated model
     */
    public static AbstractModel loadModel(File model, File workingDirectory, DType workingMemoryType, DType workingQuantizationType,
            Optional<DType> modelQuantization, Optional<Integer> threadCount) {
        return loadModel(AbstractModel.InferenceType.FULL_GENERATION, model, workingDirectory, workingMemoryType, workingQuantizationType,
                modelQuantization, threadCount, Optional.empty(), SafeTensorSupport::loadWeights);
    }

    /**
     * Fully parameterised model loader.
     *
     * <p>Resolves the model directory, locates {@code config.json}, detects the {@link ModelType},
     * deserialises its config, optionally attaches a distributed context, and then reflectively
     * constructs the tokenizer and model classes registered for that type.
     *
     * @param inferenceType inference mode passed to the model constructor
     * @param model the model directory, or a file inside it (in which case its parent is used)
     * @param workingDirectory directory handed to {@link Config#setWorkingDirectory}; may be {@code null}
     * @param workingMemoryType dtype passed to the model constructor as working memory type
     * @param workingQuantizationType dtype passed to the model constructor as working quantization type
     * @param modelQuantization optional dtype forwarded to the model constructor
     * @param threadCount if present, forwarded to {@link PhysicalCoreExecutor#overrideThreadCount}
     * @param distributedContextLoader if present, applied to the parsed config and the result stored on it
     * @param weightLoaderSupplier produces the {@link WeightLoader} for the resolved model directory
     * @return the instantiated model
     * @throws NullPointerException if {@code model} is {@code null}
     * @throws IllegalArgumentException if the model location, its directory or its config.json is missing
     * @throws RuntimeException wrapping any IO or reflection failure while loading
     */
    public static AbstractModel loadModel(AbstractModel.InferenceType inferenceType, File model, File workingDirectory,
            DType workingMemoryType, DType workingQuantizationType, Optional<DType> modelQuantization,
            Optional<Integer> threadCount, Optional<Function<Config, DistributedContext>> distributedContextLoader,
            Function<File, WeightLoader> weightLoaderSupplier) {
        Objects.requireNonNull(model, "model");
        if (!model.exists()) {
            throw new IllegalArgumentException("Model location does not exist: " + model);
        }
        File baseDir = resolveBaseDir(model);
        if (!baseDir.isDirectory()) {
            throw new IllegalArgumentException("Model directory does not exist: " + baseDir);
        }
        File configFile = findConfigFile(baseDir);

        try {
            threadCount.ifPresent(PhysicalCoreExecutor::overrideThreadCount);
            ModelType modelType = SafeTensorSupport.detectModel(configFile);
            logger.debug("Loading {} model from {}", modelType, baseDir);

            Config c = JsonSupport.om.readValue(configFile, modelType.configClass);
            distributedContextLoader.ifPresent(loader -> c.setDistributedContext(loader.apply(c)));
            c.setWorkingDirectory(workingDirectory);

            Tokenizer t = modelType.tokenizerClass.getConstructor(Path.class).newInstance(baseDir.toPath());
            WeightLoader wl = weightLoaderSupplier.apply(baseDir);

            return modelType.modelClass
                    .getConstructor(AbstractModel.InferenceType.class, Config.class, WeightLoader.class, Tokenizer.class,
                            DType.class, DType.class, Optional.class)
                    .newInstance(inferenceType, c, wl, t, workingMemoryType, workingQuantizationType, modelQuantization);
        } catch (IOException | ReflectiveOperationException e) {
            // Keep the original exception as the cause; add which model failed for easier diagnosis.
            throw new RuntimeException("Failed to load model from " + baseDir, e);
        }
    }

    /**
     * Returns the directory holding the model files: {@code model} itself if it is not a regular
     * file, otherwise its parent. The parent is taken from the absolute path so that a bare
     * relative file name (which has no parent component) does not yield {@code null}.
     */
    private static File resolveBaseDir(File model) {
        return model.isFile() ? model.getAbsoluteFile().getParentFile() : model;
    }

    /**
     * Scans {@code baseDir} for an entry named exactly {@value #CONFIG_FILE_NAME}.
     *
     * @throws IllegalArgumentException if the directory cannot be listed or has no such entry
     */
    private static File findConfigFile(File baseDir) {
        File[] entries = baseDir.listFiles();
        if (entries == null) {
            // listFiles() returns null on an I/O error or when the path is not a listable directory.
            throw new IllegalArgumentException("Unable to list model directory: " + baseDir);
        }
        for (File f : entries) {
            if (f.getName().equals(CONFIG_FILE_NAME)) {
                return f;
            }
        }
        throw new IllegalArgumentException("config.json in model directory does not exist: " + baseDir);
    }

    /**
     * Supported model families, each mapping to the model, config and tokenizer classes that
     * {@link #loadModel} instantiates reflectively.
     */
    public enum ModelType {
        GEMMA(GemmaModel.class, GemmaConfig.class, GemmaTokenizer.class),
        GEMMA2(Gemma2Model.class, Gemma2Config.class, GemmaTokenizer.class),
        MISTRAL(MistralModel.class, MistralConfig.class, LlamaTokenizer.class),
        GRANITE(GraniteModel.class, GraniteConfig.class, GPT2Tokenizer.class),
        MIXTRAL(MixtralModel.class, MixtralConfig.class, LlamaTokenizer.class),
        LLAMA(LlamaModel.class, LlamaConfig.class, LlamaTokenizer.class),
        GPT2(GPT2Model.class, GPT2Config.class, GPT2Tokenizer.class),
        BERT(BertModel.class, BertConfig.class, BertTokenizer.class),
        QWEN2(Qwen2Model.class, Qwen2Config.class, LlamaTokenizer.class);

        /** Model implementation; must expose the 7-argument constructor used by {@link #loadModel}. */
        public final Class<? extends AbstractModel> modelClass;
        /** Config class that {@code config.json} is deserialised into. */
        public final Class<? extends Config> configClass;
        /** Tokenizer implementation; must expose a constructor taking the model directory {@link Path}. */
        public final Class<? extends Tokenizer> tokenizerClass;

        ModelType(Class<? extends AbstractModel> modelClass, Class<? extends Config> configClass,
                Class<? extends Tokenizer> tokenizerClass) {
            this.modelClass = modelClass;
            this.configClass = configClass;
            this.tokenizerClass = tokenizerClass;
        }
    }
}

