/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.model.gpt2;

import com.github.tjake.jlama.safetensors.tokenizer.BPETokenizer;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import net.fellbaum.jemoji.EmojiManager;

/**
 * {@link BPETokenizer} specialisation that works on a byte-to-code-point remapped
 * alphabet.
 *
 * <p>Two static lookup tables drive the class:
 * <ul>
 *   <li>{@code alteredBytes} maps every byte value outside the ranges 33..126,
 *       161..172 and 174..255 to a substitute code point, assigned sequentially
 *       starting at 256. (This appears to mirror the byte-to-unicode table of the
 *       GPT-2 reference encoder; confirm against the model's vocabulary.)</li>
 *   <li>{@code codePointsToByteStrings} maps each emoji code point found in the
 *       scanned range to a string holding one character per UTF-8 byte of that
 *       emoji, so that byte-wise decoded text can be folded back into the emoji.</li>
 * </ul>
 *
 * <p>Both tables are filled once in the static initializer and only read afterwards.
 */
public class GPT2Tokenizer extends BPETokenizer {
    /** First substitute code point handed out to a remapped byte value. */
    private static final int REMAP_BASE_CODE_POINT = 256;
    /** First code point (inclusive) probed when building the emoji table. */
    private static final int EMOJI_SCAN_FIRST = 9000;
    /** Last code point (inclusive) probed when building the emoji table. */
    private static final int EMOJI_SCAN_LAST = 128512;

    /** Emoji code point <-> string with one char per UTF-8 byte of the emoji. */
    private static final BiMap<Integer, String> codePointsToByteStrings = HashBiMap.create();
    /** Byte value <-> substitute code point, for byte values that are not kept as-is. */
    private static final BiMap<Integer, Integer> alteredBytes = HashBiMap.create();

    static {
        // Byte values inside the three ranges below keep their own code point;
        // every other value gets the next free code point starting at REMAP_BASE_CODE_POINT.
        int nextCodePoint = REMAP_BASE_CODE_POINT;
        for (int value = 0; value < 256; ++value) {
            boolean keptAsIs = (value >= 33 && value <= 126)
                    || (value >= 161 && value <= 172)
                    || (value >= 174 && value <= 255);
            if (!keptAsIs) {
                alteredBytes.put(value, nextCodePoint++);
            }
        }

        // For each emoji in the scanned range, remember what it looks like when each
        // of its UTF-8 bytes is rendered as a separate character.
        for (int codePoint = EMOJI_SCAN_FIRST; codePoint <= EMOJI_SCAN_LAST; ++codePoint) {
            String candidate = Character.toString(codePoint);
            if (!EmojiManager.isEmoji(candidate)) {
                continue;
            }
            StringBuilder byteString = new StringBuilder();
            for (byte utf8Byte : candidate.getBytes(StandardCharsets.UTF_8)) {
                byteString.appendCodePoint(Byte.toUnsignedInt(utf8Byte));
            }
            codePointsToByteStrings.put(codePoint, byteString.toString());
        }
    }

    /**
     * Creates a tokenizer by delegating to the {@link BPETokenizer} constructor.
     *
     * @param modelPath path handed unchanged to the superclass constructor
     */
    public GPT2Tokenizer(Path modelPath) {
        super(modelPath);
    }

    /**
     * Replaces every code point of {@code sentence} that has an entry in the
     * byte remapping table with its substitute; all other code points pass through.
     *
     * @param sentence text to rewrite
     * @return the rewritten text
     */
    @Override
    protected String preProcess(String sentence) {
        return sentence.codePoints()
                .map(c -> alteredBytes.getOrDefault(c, c))
                .mapToObj(Character::toString)
                .collect(Collectors.joining());
    }

    /**
     * Resolves a single byte to a token id.
     *
     * <p>The unsigned byte value is first passed through the remapping table, the
     * resulting code point is turned into a one-code-point string, and that string is
     * looked up in the vocabulary.
     *
     * @param c the byte to encode
     * @return the vocabulary id of the remapped character, or the remapped code point
     *         itself when the vocabulary has no entry for it
     */
    @Override
    protected long encodeCharacterAsToken(byte c) {
        int unsigned = Byte.toUnsignedInt(c);
        int mapped = alteredBytes.getOrDefault(unsigned, unsigned);
        Long vocabId = (Long) this.model.vocabLookup.get(Character.toString(mapped));
        if (vocabId == null) {
            // No vocabulary entry: fall back to the remapped code point as the id.
            return mapped;
        }
        return vocabId;
    }

    /**
     * Always reports that the token has no single-character decoding.
     *
     * @param id token id (unused)
     * @return {@link Optional#empty()}
     */
    @Override
    protected Optional<Character> maybeDecodeTokenAsCharacter(long id) {
        return Optional.empty();
    }

    /**
     * Looks up the vocabulary string for {@code id} and maps every substitute code
     * point in it back to the byte value it stands for.
     *
     * @param id token id to decode
     * @return the vocabulary string with the byte remapping undone
     * @throws NullPointerException if the vocabulary lookup yields {@code null} for {@code id}
     */
    @Override
    public String decode(long id) {
        String piece = (String) this.model.vocabLookup.inverse().get(id);
        BiMap<Integer, Integer> codePointToByte = alteredBytes.inverse();
        return piece.codePoints()
                .map(c -> codePointToByte.getOrDefault(c, c))
                .mapToObj(Character::toString)
                .collect(Collectors.joining());
    }

    /**
     * Replaces every occurrence of a known emoji byte string in {@code s} with the
     * emoji it encodes.
     *
     * @param s decoded text
     * @return {@code s} with emoji byte strings collapsed into their emoji
     */
    @Override
    protected String postProcess(String s) {
        String result = s;
        for (Map.Entry<Integer, String> entry : codePointsToByteStrings.entrySet()) {
            // String.replace returns the receiver unchanged when the target is absent,
            // so no separate contains() check is needed.
            result = result.replace(entry.getValue(), Character.toString(entry.getKey()));
        }
        return result;
    }
}

