/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai.llama3.Model;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import net.yacy.ai.llama3.Model.Pair;
import net.yacy.ai.llama3.Model.Vocabulary;

/**
 * Pair-merge (BPE-style) tokenizer backed by a {@link Vocabulary}, a table of
 * merge rules and a set of named special tokens.
 *
 * <p>Text is first turned into UTF-8 bytes, each byte is replaced by the code
 * point assigned to it in {@link #BYTE_ENCODER}, the result is split into
 * chunks by the configured regular expression, and every chunk is reduced by
 * repeatedly merging the adjacent token pair with the lowest merge index.
 */
public class Tokenizer {
    private final Pattern compiledPattern;
    private final Vocabulary vocabulary;
    private final Map<Pair<Integer, Integer>, Integer> merges;
    private final Map<String, Integer> specialTokens;
    private final int[] tokenTypes;
    /** Byte value (0..255) to the code point that represents it in token strings. */
    static final Map<Integer, Integer> BYTE_ENCODER = Tokenizer.bytesToUnicode();
    /** Inverse of {@link #BYTE_ENCODER}: code point back to byte value. */
    static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));

    /**
     * Returns the source of the chunk-splitting regular expression.
     *
     * @return the pattern string, or {@code null} when the tokenizer was built without one
     */
    public String regexPattern() {
        if (this.compiledPattern == null) {
            return null;
        }
        return this.compiledPattern.pattern();
    }

    /**
     * Returns the special-token table (token text to token index).
     *
     * @return the internal map; it is the tokenizer's own copy of the map passed to the constructor
     */
    public Map<String, Integer> getSpecialTokens() {
        return this.specialTokens;
    }

    /**
     * Tells whether a token index belongs to one of the special tokens.
     *
     * @param tokenIndex token index to look up
     * @return {@code true} if some special token maps to {@code tokenIndex}
     */
    public boolean isSpecialToken(int tokenIndex) {
        return this.specialTokens.containsValue(tokenIndex);
    }

    /**
     * Returns the type code stored for a token.
     *
     * @param tokenIndex index into the token-type array given to the constructor
     * @return the stored type code
     * @throws ArrayIndexOutOfBoundsException if {@code tokenIndex} is outside the array
     */
    public int getTokenType(int tokenIndex) {
        return this.tokenTypes[tokenIndex];
    }

    /**
     * Builds a tokenizer.
     *
     * <p>For every merge rule the index of the merged token is resolved by looking up
     * the concatenation of the two token strings in the vocabulary.
     *
     * @param vocabulary    token string / index lookup
     * @param merges        ordered merge rules as pairs of token indices
     * @param regexPattern  chunk-splitting regular expression, may be {@code null}
     * @param specialTokens special-token text to token index; copied
     * @param tokenTypes    per-token type codes; stored as given (not copied)
     * @throws java.util.NoSuchElementException if the concatenation of a merge pair is not in the vocabulary
     */
    public Tokenizer(Vocabulary vocabulary, List<Pair<Integer, Integer>> merges, String regexPattern, Map<String, Integer> specialTokens, int[] tokenTypes) {
        this.vocabulary = vocabulary;
        this.compiledPattern = regexPattern != null ? Pattern.compile(regexPattern) : null;
        this.specialTokens = new HashMap<String, Integer>(specialTokens);
        this.merges = new HashMap<Pair<Integer, Integer>, Integer>();
        this.tokenTypes = tokenTypes;
        for (Pair<Integer, Integer> pair : merges) {
            int firstIndex = pair.first();
            int secondIndex = pair.second();
            int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow();
            this.merges.put(pair, mergeIndex);
        }
    }

    /** Encodes already byte-mapped text without recognising any special token. */
    private int[] encodeImpl(String text) {
        return this.encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
    }

    /**
     * Encodes text, turning every occurrence of an allowed special token into its
     * token index and encoding the text between occurrences with
     * {@link #encodeOrdinary(String)}.
     *
     * @param text           text to encode
     * @param allowedSpecial special-token strings to recognise; expected to be a
     *                       subset of the keys of {@link #getSpecialTokens()}
     * @return token indices in text order
     */
    List<Integer> encode(String text, Set<String> allowedSpecial) {
        assert (this.getSpecialTokens().keySet().containsAll(allowedSpecial));
        if (allowedSpecial.isEmpty()) {
            return this.encodeOrdinary(text);
        }
        // Longest alternative first so that a special token which is a prefix of
        // another one cannot shadow it; ties are ordered lexicographically to keep
        // the pattern deterministic regardless of the set's iteration order.
        String specialPattern = allowedSpecial.stream()
                .sorted(Comparator.<String>comparingInt(String::length).reversed().thenComparing(Comparator.naturalOrder()))
                .map(Pattern::quote)
                .collect(Collectors.joining("|"));
        Matcher matcher = Pattern.compile(specialPattern).matcher(text);
        List<Integer> ids = new ArrayList<Integer>();
        int consumed = 0;
        // String.split would discard the matched special tokens, so walk the
        // matches explicitly and emit both the gaps and the tokens themselves.
        while (matcher.find()) {
            if (matcher.start() > consumed) {
                ids.addAll(this.encodeOrdinary(text.substring(consumed, matcher.start())));
            }
            ids.add(this.getSpecialTokens().get(matcher.group()));
            consumed = matcher.end();
        }
        if (consumed < text.length()) {
            ids.addAll(this.encodeOrdinary(text.substring(consumed)));
        }
        return ids;
    }

    /** Collects every match of {@code pattern} in {@code text}, in order of appearance. */
    private static List<String> findAll(Pattern pattern, String text) {
        List<String> allMatches = new ArrayList<String>();
        Matcher matcher = pattern.matcher(text);
        while (matcher.find()) {
            allMatches.add(matcher.group());
        }
        return allMatches;
    }

    /**
     * Encodes text without any special-token handling: the text is cut into the
     * matches of the configured regular expression and each match is encoded on
     * its own.
     *
     * @param text text to encode
     * @return token indices in text order
     * @throws NullPointerException if the tokenizer was built without a regular expression
     */
    public List<Integer> encodeOrdinary(String text) {
        List<Integer> ids = new ArrayList<Integer>();
        for (String chunk : Tokenizer.findAll(this.compiledPattern, text)) {
            ids.addAll(this.encodeChunk(chunk));
        }
        return ids;
    }

    /** Counts how often each pair of adjacent ids occurs in {@code ids}. */
    private Map<Pair<Integer, Integer>, Integer> getStats(List<Integer> ids) {
        Map<Pair<Integer, Integer>, Integer> counts = new HashMap<Pair<Integer, Integer>, Integer>();
        for (int i = 0; i + 1 < ids.size(); ++i) {
            Pair<Integer, Integer> key = new Pair<Integer, Integer>(ids.get(i), ids.get(i + 1));
            counts.put(key, counts.getOrDefault(key, 0) + 1);
        }
        return counts;
    }

    /**
     * Encodes one chunk: every {@code char} is looked up in the vocabulary, then
     * the adjacent pair with the lowest merge index is merged repeatedly until no
     * adjacent pair has a merge rule.
     *
     * @throws java.util.NoSuchElementException if a single character is not in the vocabulary
     */
    private List<Integer> encodeChunk(String chunk) {
        List<Integer> ids = new ArrayList<Integer>();
        for (char c : chunk.toCharArray()) {
            int tokenIndex = this.vocabulary.getIndex(String.valueOf(c)).orElseThrow();
            ids.add(tokenIndex);
        }
        while (ids.size() >= 2) {
            // Pairs without a merge rule rank as Integer.MAX_VALUE, so they are
            // only selected when no mergeable pair is left.
            Pair<Integer, Integer> pair = this.getStats(ids).keySet().stream()
                    .min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE)))
                    .orElseThrow();
            if (!this.merges.containsKey(pair)) {
                break;
            }
            int mergedIndex = this.merges.get(pair);
            ids = Tokenizer.merge(ids, pair, mergedIndex);
        }
        return ids;
    }

    /**
     * Returns a copy of {@code ids} in which every non-overlapping occurrence of
     * {@code pair}, scanned left to right, is replaced by {@code idx2}.
     */
    private static List<Integer> merge(List<Integer> ids, Pair<Integer, Integer> pair, int idx2) {
        List<Integer> newids = new ArrayList<Integer>();
        int i = 0;
        while (i < ids.size()) {
            if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) {
                newids.add(idx2);
                i += 2;
                continue;
            }
            newids.add(ids.get(i));
            ++i;
        }
        return newids;
    }

    /**
     * Concatenates the vocabulary strings of the given tokens without undoing the
     * byte mapping.
     *
     * @param tokens token indices
     * @return the concatenated token strings
     */
    public String decodeImpl(List<Integer> tokens) {
        StringBuilder sb = new StringBuilder();
        for (int token : tokens) {
            sb.append(this.vocabulary.get(token));
        }
        return sb.toString();
    }

    /**
     * Builds the byte-to-code-point table: bytes 33..126, 161..172 and 174..255
     * map to themselves; every other byte value, in ascending order, maps to
     * 256, 257, ...
     */
    private static Map<Integer, Integer> bytesToUnicode() {
        List<Integer> bs = new ArrayList<Integer>();
        IntStream.rangeClosed(33, 126).forEach(bs::add);
        IntStream.rangeClosed(161, 172).forEach(bs::add);
        IntStream.rangeClosed(174, 255).forEach(bs::add);
        List<Integer> cs = new ArrayList<Integer>(bs);
        int n = 0;
        for (int b = 0; b < 256; ++b) {
            if (bs.contains(b)) continue;
            bs.add(b);
            cs.add(256 + n);
            ++n;
        }
        return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get));
    }

    /**
     * Encodes text: the UTF-8 bytes of {@code text} are mapped through
     * {@link #BYTE_ENCODER} and the resulting string is tokenized without
     * special-token recognition.
     *
     * @param text text to encode
     * @return token indices
     */
    public int[] encode(String text) {
        StringBuilder sb = new StringBuilder();
        for (byte b : text.getBytes(StandardCharsets.UTF_8)) {
            sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b)));
        }
        return this.encodeImpl(sb.toString());
    }

    /**
     * Same as {@link #encode(String)} but returns the token indices as a list.
     *
     * @param text text to encode
     * @return token indices
     */
    public List<Integer> encodeAsList(String text) {
        return Arrays.stream(this.encode(text)).boxed().collect(Collectors.toList());
    }

    /**
     * Decodes tokens to text: the token strings are concatenated, every code
     * point is mapped back to a byte through {@link #BYTE_DECODER}, and the bytes
     * are interpreted as UTF-8.
     *
     * @param tokens token indices
     * @return the decoded text
     * @throws NullPointerException if a token string contains a code point that
     *                              has no entry in {@link #BYTE_DECODER}
     */
    public String decode(List<Integer> tokens) {
        String decoded = this.decodeImpl(tokens);
        int[] decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray();
        byte[] rawBytes = new byte[decodedBytesAsInts.length];
        // Bound by the code-point count, not decoded.length(): the two differ as
        // soon as the string holds a surrogate pair.
        for (int i = 0; i < decodedBytesAsInts.length; ++i) {
            rawBytes[i] = (byte)decodedBytesAsInts[i];
        }
        return new String(rawBytes, StandardCharsets.UTF_8);
    }
}

