/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai.llama3;

import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
import java.util.Set;
import java.util.function.IntConsumer;
import net.yacy.ai.llama3.ChatFormat;
import net.yacy.ai.llama3.Context;
import net.yacy.ai.llama3.Llama;
import net.yacy.ai.llama3.Model.ModelLoader;
import net.yacy.ai.llama3.Sampler;

/**
 * Static entry points for driving a {@link Llama} chat model: an interactive
 * console chat loop, a single-shot instruct call, a token-to-text helper and a
 * small {@link #main(String[])} demo.
 */
public class Llama3 {
    /**
     * Batch size handed to {@code Llama.createNewState}. Read once from the
     * {@code llama.BatchSize} system property; defaults to 16 when unset.
     */
    private static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);

    /** Model file used by {@link #main(String[])} when no path is given on the command line. */
    private static final Path DEFAULT_MODEL_PATH =
            Path.of("/Users/admin/git/yacy_search_server", "DATA", "LLMS", "Llama-3.2-1B-Instruct-Q4_0.gguf");

    /**
     * Runs a console chat: reads one line of user text from standard input per
     * turn, appends it (plus an empty assistant header) to the running
     * conversation, and prints the generated, non-special tokens to standard output
     * as they arrive.
     *
     * <p>The loop ends when standard input has no more lines, or when a response
     * does not end with one of the chat format's stop tokens (in which case
     * {@code "Ran out of context length..."} is printed).
     *
     * @param model   model providing the tokenizer and inference state
     * @param sampler sampler passed through to {@code Llama.generateTokens}
     * @param options supplies the optional {@code systemPrompt} and {@code maxTokens}
     */
    static void runInteractive(Llama model, Sampler sampler, Context options) {
        Llama.State state = null;
        List<Integer> conversationTokens = new ArrayList<>();
        ChatFormat chatFormat = new ChatFormat(model.tokenizer());
        conversationTokens.add(chatFormat.beginOfText);
        if (options.systemPrompt != null) {
            conversationTokens.addAll(chatFormat.encodeMessage(
                    new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt)));
        }
        int startPosition = 0;
        // Deliberately not closed: closing the Scanner would also close System.in.
        Scanner in = new Scanner(System.in);
        while (true) {
            System.out.print("> ");
            System.out.flush();
            if (!in.hasNextLine()) {
                // End of input (e.g. Ctrl-D or an exhausted pipe): leave quietly
                // instead of letting nextLine() throw NoSuchElementException.
                return;
            }
            String userText = in.nextLine();
            if (state == null) {
                // Created lazily, on the first turn that actually has user input.
                state = model.createNewState(BATCH_SIZE);
            }
            conversationTokens.addAll(chatFormat.encodeMessage(
                    new ChatFormat.Message(ChatFormat.Role.USER, userText)));
            conversationTokens.addAll(chatFormat.encodeHeader(
                    new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
            Set<Integer> stopTokens = chatFormat.getStopTokens();
            // Only the tokens appended since the previous turn are passed, together
            // with the position they start at; presumably `state` already holds the
            // earlier part of the conversation -- confirm against Llama.generateTokens.
            List<Integer> pendingTokens = conversationTokens.subList(startPosition, conversationTokens.size());
            List<Integer> responseTokens = Llama.generateTokens(model, state, startPosition, pendingTokens,
                    stopTokens, options.maxTokens, sampler, token -> printToken(model, token));
            // The full response, including a trailing stop token if any, stays in the history.
            conversationTokens.addAll(responseTokens);
            startPosition = conversationTokens.size();
            if (!endsWithStopToken(responseTokens, stopTokens)) {
                System.out.println("Ran out of context length...");
                return;
            }
        }
    }

    /**
     * Answers a single prompt: encodes the optional system prompt and the user
     * prompt from {@code options}, generates a response from position 0 in a fresh
     * state, and returns the generated tokens.
     *
     * @param model            model providing the tokenizer and inference state
     * @param sampler          sampler passed through to {@code Llama.generateTokens}
     * @param options          supplies {@code prompt}, optional {@code systemPrompt} and {@code maxTokens}
     * @param onTokenGenerated callback passed through to {@code Llama.generateTokens}
     * @return the generated tokens, with a trailing stop token removed if present
     */
    public static List<Integer> runInstructOnce(Llama model, Sampler sampler, Context options, IntConsumer onTokenGenerated) {
        Llama.State state = model.createNewState(BATCH_SIZE);
        ChatFormat chatFormat = new ChatFormat(model.tokenizer());
        List<Integer> promptTokens = new ArrayList<>();
        promptTokens.add(chatFormat.beginOfText);
        if (options.systemPrompt != null) {
            promptTokens.addAll(chatFormat.encodeMessage(
                    new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt)));
        }
        promptTokens.addAll(chatFormat.encodeMessage(
                new ChatFormat.Message(ChatFormat.Role.USER, options.prompt)));
        promptTokens.addAll(chatFormat.encodeHeader(
                new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
        Set<Integer> stopTokens = chatFormat.getStopTokens();
        List<Integer> responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens,
                options.maxTokens, sampler, onTokenGenerated);
        if (endsWithStopToken(responseTokens, stopTokens)) {
            // remove(int index): drops the last element, i.e. the stop token itself.
            responseTokens.remove(responseTokens.size() - 1);
        }
        return responseTokens;
    }

    /**
     * Decodes a token sequence to text using the model's tokenizer.
     *
     * @param model  model whose tokenizer is used
     * @param tokens tokens to decode
     * @return the decoded text
     */
    public static String toString(Llama model, List<Integer> tokens) {
        return model.tokenizer().decode(tokens);
    }

    /**
     * Demo: loads a model, answers one fixed prompt while streaming the output to
     * standard output, then prints the token count and throughput.
     *
     * @param args optional; {@code args[0]} is the path of the model file. When
     *             absent, a built-in default path is used.
     * @throws IOException declared for {@code ModelLoader.loadModel}
     */
    public static void main(String[] args) throws IOException {
        Path modelPath = (args != null && args.length > 0) ? Path.of(args[0]) : DEFAULT_MODEL_PATH;
        // NOTE(review): argument order looks like (prompt, systemPrompt, temp, topp, seed, maxTokens)
        // judging by the fields read below -- confirm against Context.
        Context options = new Context("Write a Java program which computes the first 42 prime numbers.", "Be a very good programmer.", 0.0f, 0.95f, 0L, 1024);
        Llama model = ModelLoader.loadModel(modelPath, 1024, true);
        long startTime = System.currentTimeMillis();
        Sampler sampler = Sampler.selectSampler(model.configuration().vocabularySize, options.temp, options.topp, options.seed);
        List<Integer> resultToken = Llama3.runInstructOnce(model, sampler, options, token -> printToken(model, token));
        // Clamp to 1 ms so a zero elapsed time cannot produce Infinity/NaN below.
        long elapsedMillis = Math.max(1L, System.currentTimeMillis() - startTime);
        System.out.println("\nToken: " + resultToken.size() + ", " + (double) resultToken.size() * 1000.0 / (double) elapsedMillis + " Tokens per second");
    }

    /** Prints the decoded text of {@code token} to standard output unless the tokenizer reports it as special. */
    private static void printToken(Llama model, int token) {
        if (!model.tokenizer().isSpecialToken(token)) {
            System.out.print(model.tokenizer().decode(List.of(token)));
        }
    }

    /** Returns whether {@code tokens} is non-empty and its last element is contained in {@code stopTokens}. */
    private static boolean endsWithStopToken(List<Integer> tokens, Set<Integer> stopTokens) {
        return !tokens.isEmpty() && stopTokens.contains(tokens.get(tokens.size() - 1));
    }
}

