/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai.llama3.Model;

import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import net.yacy.ai.llama3.Llama;
import net.yacy.ai.llama3.Model.Arch;
import net.yacy.ai.llama3.Model.GGMLTensorEntry;
import net.yacy.ai.llama3.Model.GGUF;
import net.yacy.ai.llama3.Model.Pair;
import net.yacy.ai.llama3.Model.Tokenizer;
import net.yacy.ai.llama3.Model.Vocabulary;
import net.yacy.ai.llama3.Tensor.FloatTensor;

/**
 * Builds {@link Llama} instances (configuration, tokenizer and, optionally, weights)
 * from GGUF model files.
 *
 * <p>The model family is selected from the file name: names containing {@code "Llama"}
 * are loaded as Llama 3, names containing {@code "Qwen3"} go through the Qwen loader.
 * Only vocabularies whose {@code tokenizer.ggml.model} is {@code "gpt2"} are accepted.
 */
public final class ModelLoader {
    private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2";
    private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
    private static final String QWEN2_PATTERN = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";

    /** Index at which the Llama 3 loader assumes the special tokens start. */
    private static final int LLAMA_3_BASE_TOKENS = 128000;

    /**
     * Reads the token strings and scores from the GGUF metadata.
     *
     * @throws IllegalArgumentException if {@code tokenizer.ggml.model} is not {@code "gpt2"}
     */
    private static Vocabulary loadVocabulary(Map<String, Object> metadata) {
        String model = (String) metadata.get("tokenizer.ggml.model");
        if (!TOKENIZER_LLAMA_3_MODEL.equals(model)) {
            throw new IllegalArgumentException("expected gpt2 but found " + model);
        }
        String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens");
        float[] scores = (float[]) metadata.get("tokenizer.ggml.scores");
        return new Vocabulary(tokens, scores);
    }

    /**
     * Loads a model from a GGUF file, choosing the loader from the file name.
     *
     * @param ggufPath      path of the GGUF file; its file name must contain {@code "Llama"}
     *                      or {@code "Qwen3"} (case-sensitive)
     * @param contextLength requested context length, forwarded to the selected loader
     * @param loadWeights   whether to load the tensor weights; only honoured by the Llama 3
     *                      loader. NOTE(review): the Qwen3 path ignores this flag and always
     *                      loads weights - confirm whether that is intended.
     * @return the loaded model
     * @throws IOException if the file cannot be read or the file name matches no known model type
     * @throws IllegalArgumentException if required metadata or tensors are missing
     */
    public static Llama loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException {
        Path fileName = ggufPath.getFileName();
        // getFileName() is null for paths without name elements (e.g. a file system root).
        String name = fileName == null ? "" : fileName.toString();
        if (name.contains("Llama")) {
            return loadModelLlama3(ggufPath, contextLength, loadWeights);
        }
        if (name.contains("Qwen3")) {
            return loadModelQwen3(ggufPath, contextLength);
        }
        throw new IOException("model type unknown: " + ggufPath);
    }

    /**
     * Loads a Llama 3 model. The file channel needed for the tensors is opened only when
     * {@code loadWeights} is set, and is closed again if loading the weights fails.
     */
    private static Llama loadModelLlama3(Path ggufPath, int contextLength, boolean loadWeights) throws IOException {
        GGUF gguf = GGUF.loadModel(ggufPath);
        Map<String, Object> metadata = gguf.getMetadata();
        Vocabulary vocabulary = loadVocabulary(metadata);
        Tokenizer tokenizer = createLlama3Tokenizer(metadata, vocabulary);

        int headCount = requiredInt(metadata, "llama.attention.head_count");
        // Without an explicit key/value head count every head has its own key/value head.
        int keyValueHeadCount = metadata.containsKey("llama.attention.head_count_kv")
                ? requiredInt(metadata, "llama.attention.head_count_kv")
                : headCount;
        Llama.Configuration config = new Llama.Configuration(
                Arch.LLM_ARCH_LLAMA,
                requiredInt(metadata, "llama.embedding_length"),
                requiredInt(metadata, "llama.feed_forward_length"),
                requiredInt(metadata, "llama.block_count"),
                headCount,
                keyValueHeadCount,
                vocabulary.size(),
                requiredInt(metadata, "llama.context_length"),
                false,
                ((Float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", Float.valueOf(1.0E-5f))).floatValue(),
                ((Float) metadata.getOrDefault("llama.rope.freq_base", Float.valueOf(10000.0f))).floatValue())
                .withContextLength(contextLength);

        Llama.Weights weights = null;
        if (loadWeights) {
            FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ);
            boolean loaded = false;
            try {
                Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
                weights = loadWeightsLlama3(tensorEntries, config);
                loaded = true;
            } finally {
                // NOTE(review): on success the channel is left open because the tensor entries
                // may still read through it - confirm against GGUF.loadTensors. If the tensors
                // are memory-mapped, the channel could be closed here unconditionally.
                if (!loaded) {
                    closeQuietly(fileChannel);
                }
            }
        }
        return new Llama(config, tokenizer, weights);
    }

    /**
     * Builds the Llama 3 weights. RoPE frequency scaling is enabled when the file
     * contains a {@code rope_freqs} tensor.
     */
    private static Llama.Weights loadWeightsLlama3(Map<String, GGMLTensorEntry> tensorEntries, Llama.Configuration config) {
        boolean ropeScaling = tensorEntries.containsKey("rope_freqs");
        float scaleFactor = 8.0f;
        float loFreqFactor = 1.0f;
        // NOTE(review): the published Llama 3.1 rope_scaling settings appear to use a
        // high frequency factor of 4 - confirm whether 3 is intentional.
        float hiFreqFactor = 3.0f;
        int oldContextLength = 8192;
        Pair<float[], float[]> ropeFreqs = precomputeFreqsCis4Llama3(config.contextLength, config.headSize, config.ropeTheta, ropeScaling, scaleFactor, loFreqFactor, hiFreqFactor, oldContextLength);
        return buildWeights(tensorEntries, config, ropeFreqs, false);
    }

    /**
     * Loads a model through the Qwen path. The requested context length is replaced by the
     * model's own context length when it is negative or larger than what the model supports.
     *
     * <p>NOTE(review): the metadata keys use the {@code qwen2.} prefix and per-layer attention
     * bias tensors are required, while the architecture is tagged {@code LLM_ARCH_QWEN3} -
     * confirm that the GGUF files this is used with actually carry these keys and tensors.
     */
    private static Llama loadModelQwen3(Path ggufPath, int contextLength) throws IOException {
        GGUF gguf = GGUF.loadModel(ggufPath);
        Map<String, Object> metadata = gguf.getMetadata();
        Vocabulary vocabulary = loadVocabulary(metadata);
        Tokenizer tokenizer = createQwen2Tokenizer(metadata, vocabulary);

        int modelContextLength = requiredInt(metadata, "qwen2.context_length");
        if (contextLength < 0 || modelContextLength < contextLength) {
            contextLength = modelContextLength;
        }
        int headCount = requiredInt(metadata, "qwen2.attention.head_count");
        int keyValueHeadCount = metadata.containsKey("qwen2.attention.head_count_kv")
                ? requiredInt(metadata, "qwen2.attention.head_count_kv")
                : headCount;
        Llama.Configuration config = new Llama.Configuration(
                Arch.LLM_ARCH_QWEN3,
                requiredInt(metadata, "qwen2.embedding_length"),
                requiredInt(metadata, "qwen2.feed_forward_length"),
                requiredInt(metadata, "qwen2.block_count"),
                headCount,
                keyValueHeadCount,
                vocabulary.size(),
                contextLength,
                false,
                requiredFloat(metadata, "qwen2.attention.layer_norm_rms_epsilon"),
                requiredFloat(metadata, "qwen2.rope.freq_base"));

        Map<String, GGMLTensorEntry> tensorEntries = gguf.getTensorEntries();
        Pair<float[], float[]> ropeFreqs = precomputeFreqsCis4Qwen2(config.contextLength, config.headSize, config.ropeTheta);
        Llama.Weights weights = buildWeights(tensorEntries, config, ropeFreqs, true);
        return new Llama(config, tokenizer, weights);
    }

    /**
     * Assembles the weights shared by both model families.
     *
     * @param withAttentionBias whether the per-layer {@code attn_q/k/v.bias} tensors are loaded;
     *                          when false the three bias arguments are passed as {@code null}
     */
    private static Llama.Weights buildWeights(Map<String, GGMLTensorEntry> tensorEntries, Llama.Configuration config, Pair<float[], float[]> ropeFreqs, boolean withAttentionBias) {
        int layers = config.numberOfLayers;
        FloatTensor tokenEmbeddingTable = requiredTensor(tensorEntries, "token_embd.weight").loadQuantized();
        FloatBuffer[] attentionNorm = layerFloatBuffers(tensorEntries, layers, "attn_norm.weight");
        FloatTensor[] wq = layerQuantized(tensorEntries, layers, "attn_q.weight");
        FloatTensor[] wk = layerQuantized(tensorEntries, layers, "attn_k.weight");
        FloatTensor[] wv = layerQuantized(tensorEntries, layers, "attn_v.weight");
        FloatTensor[] qBias = withAttentionBias ? layerQuantized(tensorEntries, layers, "attn_q.bias") : null;
        FloatTensor[] kBias = withAttentionBias ? layerQuantized(tensorEntries, layers, "attn_k.bias") : null;
        FloatTensor[] vBias = withAttentionBias ? layerQuantized(tensorEntries, layers, "attn_v.bias") : null;
        FloatTensor[] wo = layerQuantized(tensorEntries, layers, "attn_output.weight");
        FloatBuffer[] ffnNorm = layerFloatBuffers(tensorEntries, layers, "ffn_norm.weight");
        FloatTensor[] ffnGate = layerQuantized(tensorEntries, layers, "ffn_gate.weight");
        FloatTensor[] ffnDown = layerQuantized(tensorEntries, layers, "ffn_down.weight");
        FloatTensor[] ffnUp = layerQuantized(tensorEntries, layers, "ffn_up.weight");
        FloatBuffer outputNorm = requiredTensor(tensorEntries, "output_norm.weight").toFloatBuffer();
        // Files without a dedicated output projection reuse the token embedding table.
        GGMLTensorEntry outputEntry = tensorEntries.get("output.weight");
        FloatTensor outputWeight = outputEntry != null ? outputEntry.loadQuantized() : tokenEmbeddingTable;
        return new Llama.Weights(tokenEmbeddingTable, attentionNorm, wq, wk, wv, qBias, kBias, vBias, wo, ffnNorm, ffnGate, ffnDown, ffnUp, outputNorm, FloatBuffer.wrap(ropeFreqs.first()), FloatBuffer.wrap(ropeFreqs.second()), outputWeight);
    }

    /** Creates the tokenizer for Llama 3; special tokens start at {@link #LLAMA_3_BASE_TOKENS}. */
    private static Tokenizer createLlama3Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
        List<Pair<Integer, Integer>> merges = loadMerges(metadata, vocabulary);
        Map<String, Integer> specialTokens = collectSpecialTokens(vocabulary, LLAMA_3_BASE_TOKENS);
        return new Tokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens, null);
    }

    /** Creates the Qwen tokenizer; special tokens start at the index of {@code <|endoftext|>}. */
    private static Tokenizer createQwen2Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
        int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type");
        List<Pair<Integer, Integer>> merges = loadMerges(metadata, vocabulary);
        int baseTokens = vocabulary.getIndex("<|endoftext|>").orElseThrow();
        Map<String, Integer> specialTokens = collectSpecialTokens(vocabulary, baseTokens);
        return new Tokenizer(vocabulary, merges, QWEN2_PATTERN, specialTokens, tokenTypes);
    }

    /**
     * Converts the {@code tokenizer.ggml.merges} lines ("left right") into pairs of
     * vocabulary indices.
     */
    private static List<Pair<Integer, Integer>> loadMerges(Map<String, Object> metadata, Vocabulary vocabulary) {
        String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
        if (mergeLines == null) {
            throw new IllegalArgumentException("missing GGUF metadata key tokenizer.ggml.merges");
        }
        return Arrays.stream(mergeLines)
                .map(line -> line.split(" "))
                .map(parts -> new Pair<Integer, Integer>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow()))
                .collect(Collectors.toList());
    }

    /**
     * Maps every token from index {@code baseTokens} up to the end of the vocabulary to
     * its index.
     */
    private static Map<String, Integer> collectSpecialTokens(Vocabulary vocabulary, int baseTokens) {
        int allTokens = vocabulary.size();
        List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).collect(Collectors.toList());
        assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
        return IntStream.range(0, specialTokensList.size())
                .boxed()
                .collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i));
    }

    /** Loads the quantized tensor {@code blk.<layer>.<suffix>} for every layer. */
    private static FloatTensor[] layerQuantized(Map<String, GGMLTensorEntry> tensorEntries, int layers, String suffix) {
        return loadArrayOfQuantized(layers, i -> requiredTensor(tensorEntries, "blk." + i + "." + suffix));
    }

    /** Loads the tensor {@code blk.<layer>.<suffix>} as a float buffer for every layer. */
    private static FloatBuffer[] layerFloatBuffers(Map<String, GGMLTensorEntry> tensorEntries, int layers, String suffix) {
        return loadArrayOfFloatBuffer(layers, i -> requiredTensor(tensorEntries, "blk." + i + "." + suffix));
    }

    /** Calls {@code loadQuantized()} on the entry supplied for each index in {@code [0, size)}. */
    private static FloatTensor[] loadArrayOfQuantized(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
        FloatTensor[] array = new FloatTensor[size];
        for (int i = 0; i < size; ++i) {
            array[i] = getTensorEntry.apply(i).loadQuantized();
        }
        return array;
    }

    /** Calls {@code toFloatBuffer()} on the entry supplied for each index in {@code [0, size)}. */
    private static FloatBuffer[] loadArrayOfFloatBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
        FloatBuffer[] array = new FloatBuffer[size];
        for (int i = 0; i < size; ++i) {
            array[i] = getTensorEntry.apply(i).toFloatBuffer();
        }
        return array;
    }

    /**
     * Precomputes the cosine and sine tables used for rotary position embeddings.
     *
     * <p>The tables are laid out position-major: entry {@code pos * (headSize / 2) + k} holds
     * the value for position {@code pos} and dimension pair {@code k}. With
     * {@code ropeScaling} enabled, frequencies whose wavelength exceeds
     * {@code oldContextLength / loFreqFactor} are divided by {@code scaleFactor}, those
     * shorter than {@code oldContextLength / hiFreqFactor} are kept, and the ones in between
     * are interpolated.
     *
     * @return a pair of (cosine table, sine table), each of length {@code contextLength * (headSize / 2)}
     */
    private static Pair<float[], float[]> precomputeFreqsCis4Llama3(int contextLength, int headSize, double theta, boolean ropeScaling, float scaleFactor, float loFreqFactor, float hiFreqFactor, float oldContextLength) {
        assert (headSize % 2 == 0);
        int pairs = headSize / 2;

        // The frequency depends only on the dimension pair, so compute it once per pair
        // instead of once per (position, pair).
        float[] freqs = new float[pairs];
        for (int k = 0; k < pairs; ++k) {
            int i = 2 * k;
            float freq = (float) (1.0 / Math.pow(theta, (double) i / (double) headSize));
            if (ropeScaling) {
                float loFreqWavelen = oldContextLength / loFreqFactor;
                float hiFreqWavelen = oldContextLength / hiFreqFactor;
                float wavelen = (float) (Math.PI * 2 / (double) freq);
                if (!(wavelen < hiFreqWavelen)) {
                    if (wavelen > loFreqWavelen) {
                        freq /= scaleFactor;
                    } else {
                        float smooth = (oldContextLength / wavelen - loFreqFactor) / (hiFreqFactor - loFreqFactor);
                        freq = (1.0f - smooth) * freq / scaleFactor + smooth * freq;
                    }
                }
            }
            freqs[k] = freq;
        }

        float[] cr = new float[contextLength * pairs];
        float[] ci = new float[contextLength * pairs];
        int n = 0;
        for (int pos = 0; pos < contextLength; ++pos) {
            for (float freq : freqs) {
                float val = (float) pos * freq;
                cr[n] = (float) Math.cos(val);
                ci[n] = (float) Math.sin(val);
                ++n;
            }
        }
        assert (contextLength * pairs == n);
        return new Pair<float[], float[]>(cr, ci);
    }

    /**
     * Precomputes the rotary cosine and sine tables without frequency scaling; same layout
     * as {@link #precomputeFreqsCis4Llama3}.
     */
    private static Pair<float[], float[]> precomputeFreqsCis4Qwen2(int contextLength, int headSize, double theta) {
        // The scaling parameters are ignored when ropeScaling is false.
        return precomputeFreqsCis4Llama3(contextLength, headSize, theta, false, 1.0f, 1.0f, 1.0f, 1.0f);
    }

    /** Returns the integer metadata value for {@code key}, failing with the key name if absent. */
    private static int requiredInt(Map<String, Object> metadata, String key) {
        Object value = metadata.get(key);
        if (value == null) {
            throw new IllegalArgumentException("missing GGUF metadata key " + key);
        }
        return (Integer) value;
    }

    /** Returns the float metadata value for {@code key}, failing with the key name if absent. */
    private static float requiredFloat(Map<String, Object> metadata, String key) {
        Object value = metadata.get(key);
        if (value == null) {
            throw new IllegalArgumentException("missing GGUF metadata key " + key);
        }
        return (Float) value;
    }

    /** Returns the tensor entry called {@code name}, failing with the tensor name if absent. */
    private static GGMLTensorEntry requiredTensor(Map<String, GGMLTensorEntry> tensorEntries, String name) {
        GGMLTensorEntry entry = tensorEntries.get(name);
        if (entry == null) {
            throw new IllegalArgumentException("missing tensor " + name);
        }
        return entry;
    }

    /** Closes the channel while a primary failure is already propagating. */
    private static void closeQuietly(FileChannel channel) {
        try {
            channel.close();
        } catch (IOException ignored) {
            // The failure that triggered the cleanup is the one worth reporting.
        }
    }
}

