/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.http.servlets;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.stream.Collectors;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import net.yacy.ai.LLM;
import net.yacy.cora.federate.solr.SolrType;
import net.yacy.cora.federate.solr.connector.EmbeddedSolrConnector;
import net.yacy.cora.protocol.Domains;
import net.yacy.search.Switchboard;
import net.yacy.search.schema.CollectionSchema;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrDocumentList;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.servlet.cache.Method;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import org.json.JSONTokener;

public class RAGProxyServlet
extends HttpServlet {
    private static final long serialVersionUID = 3411544789759643137L;

    /**
     * Text handed to {@code LLM.Context} when the servlet asks the LLM for search words
     * (see {@code searchWordsForPrompt}). It announces that the user prompt may carry an
     * 'Additional Information' section.
     */
    private static final String LLM_SYSTEM_PREFIX = "\n\nYou may receive additional expert knowledge in the user prompt after a 'Additional Information' headline to enhance your knowledge. Use it only if applicable.";

    /**
     * Separator inserted between the original user text and any appended material
     * (text attachments and search results).
     */
    private static final String LLM_USER_PREFIX = "\n\nAdditional Information:\n\nbelow you find a collection of texts that might be useful to generate a response. Do not discuss these documents, just use them to answer the question above.\n\n";

    /**
     * Proxies a chat-completion request to the configured LLM back end and relays the answer
     * line by line to the client.
     *
     * <p>Processing steps:
     * <ol>
     *   <li>CORS headers are set; requests that do not come from localhost are rejected with 403.</li>
     *   <li>Methods mapped to {@code Method.OTHER} are answered with an empty 200
     *       (presumably CORS pre-flight {@code OPTIONS} requests); everything but POST gets 405.</li>
     *   <li>The JSON body is parsed, the {@code model} entry is replaced by the configured model name
     *       and text attachments of all user messages are folded into the message text.</li>
     *   <li>If the last message carries {@code "search": true}, search words are generated from its
     *       text, the local index is queried and the result is appended to the message as markdown.</li>
     *   <li>The rewritten body is POSTed to {@code <hoststub>/v1/chat/completions}; on status 200 the
     *       upstream response is relayed.</li>
     * </ol>
     *
     * @param request  the servlet request; must be an {@link HttpServletRequest}
     * @param response the servlet response; must be an {@link HttpServletResponse}
     * @throws IOException      on I/O failure, malformed JSON in the request body or an invalid back-end URI
     * @throws ServletException declared by the servlet contract
     */
    @Override
    public void service(ServletRequest request, ServletResponse response) throws IOException, ServletException {
        response.setContentType("application/json;charset=utf-8");
        final HttpServletResponse hresponse = (HttpServletResponse) response;
        final HttpServletRequest hrequest = (HttpServletRequest) request;

        hresponse.setHeader("Access-Control-Allow-Origin", "*");
        hresponse.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE");
        hresponse.setHeader("Access-Control-Allow-Headers", "Content-Type, Authorization");

        // This proxy is for local use only. The return is essential: without it the request
        // would still be forwarded to the LLM after the error was sent.
        if (!Domains.isLocalhost(hrequest.getRemoteAddr())) {
            hresponse.sendError(HttpServletResponse.SC_FORBIDDEN);
            return;
        }

        final Method reqMethod = Method.getMethod(hrequest.getMethod());
        if (reqMethod == Method.OTHER) {
            hresponse.setStatus(HttpServletResponse.SC_OK);
            return;
        }
        if (reqMethod != Method.POST) {
            hresponse.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
            return;
        }

        // Read the complete request body; line breaks are dropped as in a compact JSON document.
        final StringBuilder bodyBuilder = new StringBuilder();
        final BufferedReader reader = request.getReader();
        String line;
        while ((line = reader.readLine()) != null) {
            bodyBuilder.append(line);
        }

        try {
            final JSONObject bodyObject = new JSONObject(bodyBuilder.toString());

            // The "model" entry of the request names a usage; unknown names fall back to chat.
            final String model = bodyObject.optString("model", LLM.LLMUsage.chat.name());
            LLM.LLMUsage usage = LLM.LLMUsage.chat;
            try {
                usage = LLM.LLMUsage.valueOf(model);
            } catch (IllegalArgumentException ignored) {
                // not a usage name: keep the chat default
            }
            final LLM.LLMModel llm4Chat = LLM.llmFromUsage(usage);
            final LLM.LLMModel llm4tldr = LLM.llmFromUsage(LLM.LLMUsage.tldr);
            bodyObject.put("model", llm4Chat.model);

            final JSONArray messages = bodyObject.optJSONArray("messages");
            if (messages == null || messages.length() == 0) {
                // nothing to answer; previously this ended in a NullPointerException
                hresponse.sendError(HttpServletResponse.SC_BAD_REQUEST);
                return;
            }

            // Fold text attachments of every user message into the message text.
            for (int i = 0; i < messages.length(); i++) {
                final JSONObject message = messages.getJSONObject(i);
                if (!message.optString("role", "").equals("user")) {
                    continue;
                }
                new UserObject(message).attachAttachment(LLM_USER_PREFIX);
            }

            // Retrieval augmentation is controlled by the last message only.
            final UserObject lastMessage = new UserObject(messages.getJSONObject(messages.length() - 1));
            final String prompt = lastMessage.getContentText();
            String searchResultQuery = "";
            String searchResultMarkdown = "";
            if (lastMessage.getSearch()) {
                searchResultQuery = this.searchWordsForPrompt(llm4tldr.llm, llm4tldr.model, prompt);
                searchResultMarkdown = RAGProxyServlet.searchResultsAsMarkdown(searchResultQuery, 10);
                lastMessage.setContentText(prompt + LLM_USER_PREFIX + searchResultMarkdown);
            }

            // JSON must be sent as UTF-8, independent of the platform default charset.
            final byte[] payload = bodyObject.toString().getBytes(StandardCharsets.UTF_8);
            final URL url = new URI(llm4Chat.llm.hoststub + "/v1/chat/completions").toURL();
            final ServletOutputStream out = response.getOutputStream();
            final HttpURLConnection conn = (HttpURLConnection) url.openConnection();
            try {
                conn.setRequestMethod("POST");
                conn.setRequestProperty("Content-Type", "application/json");
                final String apiKey = llm4Chat.llm.api_key;
                if (apiKey != null && !apiKey.isEmpty()) {
                    conn.setRequestProperty("Authorization", "Bearer " + apiKey);
                }
                conn.setDoOutput(true);
                try (OutputStream os = conn.getOutputStream()) {
                    os.write(payload);
                    os.flush();
                }

                final int status = conn.getResponseCode();
                hresponse.setStatus(status);
                if (status == HttpServletResponse.SC_OK) {
                    relayUpstreamResponse(conn, out, searchResultQuery, searchResultMarkdown);
                }
            } finally {
                // releases the upstream socket also when the client went away in the middle of the relay
                conn.disconnect();
            }
            out.close();
        } catch (URISyntaxException | JSONException e) {
            // keep the cause so the original stack trace is not lost
            throw new IOException(e.getMessage(), e);
        }
    }

    /**
     * Copies the upstream response to the client line by line, flushing after every line.
     *
     * <p>A separate thread reads the upstream lines into a queue so that reading and writing are
     * decoupled. If search results were used, the JSON object found in the first relayed line is
     * extended by the entries {@code search-filename} and {@code search-text-base64}; this only
     * happens when the line has a non-empty prefix in front of the first '{'.
     *
     * @param conn                 connection whose response body is relayed
     * @param out                  client output stream
     * @param searchResultQuery    search words used for the retrieval, used to build the file name
     * @param searchResultMarkdown markdown that was appended to the prompt; empty if no search was done
     * @throws IOException if writing to the client fails
     */
    private static void relayUpstreamResponse(final HttpURLConnection conn, final ServletOutputStream out,
            final String searchResultQuery, final String searchResultMarkdown) throws IOException {
        final LinkedBlockingQueue<String> inputQueue = new LinkedBlockingQueue<>();
        final String poison = "POISON"; // end-of-stream marker

        final Thread readerThread = new Thread(() -> {
            try (BufferedReader in = new BufferedReader(
                    new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
                String upstreamLine;
                while ((upstreamLine = in.readLine()) != null) {
                    inputQueue.put(upstreamLine);
                }
            } catch (IOException e) {
                // upstream closed or broke; the consumer is released by the marker below
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                // offer() on the unbounded queue neither blocks nor fails on a set interrupt flag,
                // so the consumer is always released, and exactly once
                inputQueue.offer(poison);
            }
        });
        readerThread.start();

        try {
            boolean firstLine = true;
            String inputLine;
            while (!(inputLine = inputQueue.take()).equals(poison)) {
                final int p = inputLine.indexOf('{');
                if (firstLine && searchResultMarkdown.length() > 0 && p > 0) {
                    try {
                        final JSONObject j = new JSONObject(new JSONTokener(inputLine.substring(p)));
                        j.put("search-filename", "search_result_" + searchResultQuery.replace(' ', '_') + ".md");
                        j.put("search-text-base64", Base64.getEncoder().encodeToString(searchResultMarkdown.getBytes(StandardCharsets.UTF_8)));
                        inputLine = inputLine.substring(0, p) + j.toString();
                    } catch (JSONException e) {
                        // the line is not valid JSON: relay it unmodified instead of aborting the stream
                    }
                }
                out.println(inputLine);
                out.flush();
                firstLine = false;
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Runs a full-text query against the embedded Solr index and returns the hits as JSON.
     *
     * <p>The query string is prefixed with the name of the {@code text_t} field, which is also set
     * as default field. Every hit becomes one object with the keys {@code url}, {@code title} and,
     * if requested, {@code text}; missing values are represented by empty strings.
     *
     * <p>This is a best-effort lookup: if the index is not available or the query fails, the
     * hits collected so far (possibly none) are returned.
     *
     * @param query2         the query; {@code null} or empty yields an empty result
     * @param count          maximum number of hits; values {@code <= 0} yield an empty result
     * @param includeSnippet if {@code true} the full document text is added under {@code text}
     * @return a JSON array of result objects, never {@code null}
     */
    public static JSONArray searchResults(String query2, int count, boolean includeSnippet) {
        final JSONArray results = new JSONArray();
        if (query2 == null || query2.length() == 0 || count <= 0) {
            return results;
        }
        final Switchboard sb = Switchboard.getSwitchboard();
        if (sb == null) {
            return results;
        }
        final EmbeddedSolrConnector connector = sb.index.fulltext().getDefaultEmbeddedConnector();
        if (connector == null) {
            return results;
        }

        final String urlField = CollectionSchema.sku.getSolrFieldName();
        final String titleField = CollectionSchema.title.getSolrFieldName();
        final String textField = CollectionSchema.text_t.getSolrFieldName();

        final SolrQuery params = new SolrQuery();
        params.setQuery(textField + ":" + query2);
        params.setRows(count);
        params.setStart(0);
        params.setFacet(false);
        params.clearSorts();
        // The title is read from every hit below, so it has to be part of the field list;
        // the (large) text field is only fetched when it is actually returned.
        if (includeSnippet) {
            params.setFields(urlField, titleField, textField);
        } else {
            params.setFields(urlField, titleField);
        }
        params.setIncludeScore(true);
        params.set("df", textField);

        try {
            final SolrDocumentList sdl = connector.getDocumentListByParams((ModifiableSolrParams) params);
            if (sdl == null) {
                return results;
            }
            for (final SolrDocument doc : sdl) {
                try {
                    final JSONObject result = new JSONObject(true);
                    final Object url = doc.getFieldValue(urlField);
                    result.put("url", url == null ? "" : url.toString().trim());
                    final String title = RAGProxyServlet.getOneString(doc, CollectionSchema.title);
                    result.put("title", title == null ? "" : title.trim());
                    if (includeSnippet) {
                        final Object text = doc.getFieldValue(textField);
                        result.put("text", text == null ? "" : text.toString().trim());
                    }
                    results.put(result);
                } catch (JSONException e) {
                    // skip a hit that cannot be represented, keep the others
                }
            }
        } catch (IOException | SolrException e) {
            // best effort: a failing query returns what was collected so far
        }
        return results;
    }

    /**
     * Searches the local index and renders the most relevant text passages as markdown.
     *
     * <p>For every hit with a title and text a {@link Snippet} is computed. Hits whose snippet is
     * empty are dropped. The remaining snippets are ordered by descending score and the
     * better-scoring half (rounded up, so a single hit is kept) is written as
     * <pre>
     * ## title
     * snippet text
     * Source: url
     * </pre>
     * followed by two blank lines. The {@code Source} line is omitted for hits without URL.
     *
     * @param query2 the search words, separated by white space
     * @param count  maximum number of hits requested from the index
     * @return the markdown text; empty if nothing relevant was found
     */
    public static String searchResultsAsMarkdown(String query2, int count) {
        final JSONArray searchResults = RAGProxyServlet.searchResults(query2, count, true);

        final List<Snippet> snippets = new ArrayList<>();
        for (int i = 0; i < searchResults.length(); i++) {
            final JSONObject hit = searchResults.optJSONObject(i);
            if (hit == null) {
                continue;
            }
            final String title = hit.optString("title", "");
            final String url = hit.optString("url", "");
            final String text = hit.optString("text", "");
            if (title.isEmpty() || text.isEmpty()) {
                continue;
            }
            final Snippet snippet = new Snippet(query2, text, url, title, 256);
            if (!snippet.getText().isEmpty()) {
                snippets.add(snippet);
            }
        }

        // Highest score first: the half that is kept must be the best one, not the worst.
        snippets.sort(Comparator.comparingDouble(Snippet::getScore).reversed());
        final int keep = (snippets.size() + 1) / 2;

        final StringBuilder sb = new StringBuilder();
        for (final Snippet snippet : snippets.subList(0, keep)) {
            sb.append("## ").append(snippet.getTitle()).append("\n");
            sb.append(snippet.getText()).append("\n");
            if (snippet.getURL().length() > 0) {
                sb.append("Source: ").append(snippet.getURL()).append("\n");
            }
            sb.append("\n\n");
        }
        return sb.toString();
    }

    /**
     * Cuts a text into consecutive pieces of at least {@code len} characters.
     *
     * <p>Every piece is extended beyond {@code len} until it ends with '.', '?' or '!' and the
     * following character is white space, or until the text ends. The pieces are not trimmed and,
     * concatenated, reproduce the input.
     *
     * @param text the text to cut; may be {@code null}
     * @param len  the minimum length of a piece
     * @return the pieces in order; empty if {@code text} is {@code null} or empty or {@code len <= 0}
     */
    public static List<String> slicer(String text, int len) {
        final List<String> slices = new ArrayList<>();
        if (text == null || len <= 0) {
            return slices;
        }
        final int total = text.length();
        int from = 0;
        while (from < total) {
            int to = Math.min(from + len, total);
            // move the cut forward to the next sentence end that is followed by white space
            while (to < total) {
                final char last = text.charAt(to - 1);
                final boolean sentenceEnd = last == '.' || last == '?' || last == '!';
                if (sentenceEnd && Character.isWhitespace(text.charAt(to))) {
                    break;
                }
                to++;
            }
            slices.add(text.substring(from, to));
            from = to;
        }
        return slices;
    }

    /**
     * Reads a field of a Solr document as a single string.
     *
     * <p>If the field holds a list, the first element is used. Missing fields, empty lists and
     * {@code null} elements yield an empty string; other values are converted with
     * {@code toString()}.
     *
     * @param doc   the document to read from
     * @param field the schema field; asserted to be multi-valued and of a string or text type
     * @return the first value of the field, never {@code null}
     */
    private static String getOneString(SolrDocument doc, CollectionSchema field) {
        assert (field.isMultiValued());
        assert (field.getType() == SolrType.string || field.getType() == SolrType.text_general);
        final Object r = doc.getFieldValue(field.getSolrFieldName());
        if (r == null) {
            return "";
        }
        if (r instanceof List) {
            // any list implementation is accepted; an empty one must not raise IndexOutOfBounds
            final List<?> values = (List<?>) r;
            if (values.isEmpty()) {
                return "";
            }
            final Object first = values.get(0);
            return first == null ? "" : first.toString();
        }
        return r.toString();
    }

    /**
     * Asks the LLM for search words that fit a prompt.
     *
     * <p>The answer strings are split at white space, lower-cased and de-duplicated while keeping
     * the order of first appearance.
     *
     * @param llm    the LLM back end to ask
     * @param model  the model name passed to the back end
     * @param prompt the user prompt for which search words are wanted
     * @return the search words joined by single blanks; empty if the LLM call fails or returns nothing
     */
    private String searchWordsForPrompt(LLM llm, String model, String prompt) {
        final String question = "Make a list of search words with low document frequency for the following prompt; use a JSON Array: " + prompt;
        try {
            final LLM.Context context = new LLM.Context(LLM_SYSTEM_PREFIX);
            context.addPrompt(question);
            final String[] answers = LLM.stringsFromChat(llm.chat(model, context, LLM.listSchema, 200));
            if (answers == null) {
                return "";
            }
            final Set<String> singlewords = new LinkedHashSet<>();
            for (final String answer : answers) {
                if (answer == null) {
                    continue;
                }
                // split at any white space; repeated blanks must not produce empty "words"
                for (final String word : answer.trim().split("\\s+")) {
                    if (!word.isEmpty()) {
                        singlewords.add(word.toLowerCase());
                    }
                }
            }
            return String.join(" ", singlewords);
        } catch (IOException | JSONException e) {
            // NOTE(review): no logger is in scope in this file; the trace is printed as before
            e.printStackTrace();
            return "";
        }
    }

    /**
     * Splits a query into its distinct lower-case words.
     *
     * @param query2 the query, words separated by white space; must not be {@code null}
     * @return the set of non-empty words in lower case
     */
    private static Set<String> querySet(String query2) {
        final String normalized = query2.trim().toLowerCase();
        final String[] tokens = normalized.split("\\s+");
        return Arrays.stream(tokens)
                .filter(token -> token.length() > 0)
                .collect(Collectors.toSet());
    }

    /**
     * Wraps a text into a JSON object shaped like a chat completion chunk.
     *
     * <p>The object carries the fixed entries {@code id=log}, {@code object=chat.completion.chunk},
     * {@code model=log}, {@code system_fingerprint=YaCy}, the current time in seconds as
     * {@code created} and one choice whose delta has the role {@code assistant} and the given
     * content.
     *
     * @param payload the text to place into the delta content
     * @return the chunk object; if a JSON error occurs the entries written so far are returned
     */
    private static JSONObject responseLine(String payload) {
        final JSONObject chunk = new JSONObject(true);
        try {
            chunk.put("id", "log");
            chunk.put("object", "chat.completion.chunk");
            chunk.put("created", System.currentTimeMillis() / 1000L);
            chunk.put("model", "log");
            chunk.put("system_fingerprint", "YaCy");

            // innermost part first: delta -> choice -> choices
            final JSONObject assistantDelta = new JSONObject(true);
            assistantDelta.put("role", "assistant");
            assistantDelta.put("content", payload);

            final JSONObject firstChoice = new JSONObject(true);
            firstChoice.put("index", 0);
            firstChoice.put("delta", assistantDelta);

            final JSONArray choiceList = new JSONArray();
            choiceList.put(firstChoice);
            chunk.put("choices", choiceList);
        } catch (JSONException ignored) {
            // return what has been assembled so far
        }
        return chunk;
    }

    /**
     * Thin wrapper around one chat message object.
     *
     * <p>The {@code content} entry of a message is either a plain string or an array of parts.
     * Parts of type {@code text} carry the text in {@code text}; parts of type {@code image_url}
     * carry a URL in {@code image_url.url}. Only data URLs are treated as attachments here;
     * parts with other URLs are left untouched.
     */
    public static final class UserObject {
        private final JSONObject userObject;

        /**
         * @param userObject the message object; it is modified in place by the mutating methods
         */
        public UserObject(JSONObject userObject) {
            this.userObject = userObject;
        }

        /**
         * Moves every attachment with a {@code text/} mime type into the message text.
         * The decoded attachment (read as UTF-8) is appended to the text after {@code prefix}
         * and the attachment part is removed. Finally the content is normalized.
         *
         * @param prefix separator written between the existing text and each attachment
         */
        public void attachAttachment(String prefix) {
            for (final DataURL dataUrl : this.getContentAttachments()) {
                if (!dataUrl.getMimetype().startsWith("text/")) {
                    continue;
                }
                final String attachment = new String(dataUrl.getData(), StandardCharsets.UTF_8);
                this.setContentText(this.getContentText() + prefix + attachment);
                this.removeContentAttachment(dataUrl);
            }
            this.normalize();
        }

        /**
         * @return the value of the boolean {@code search} entry of the message, {@code false} if absent
         */
        public boolean getSearch() {
            return this.userObject.optBoolean("search", false);
        }

        /**
         * @return the message text: the string content, or the text of the first {@code text} part
         *         of an array content; empty if there is none
         */
        public String getContentText() {
            final Object content = this.userObject.opt("content");
            if (content instanceof JSONArray) {
                final JSONArray parts = (JSONArray) content;
                for (int i = 0; i < parts.length(); i++) {
                    final JSONObject part = parts.optJSONObject(i);
                    if (part != null && "text".equals(part.optString("type"))) {
                        return part.optString("text", "");
                    }
                }
                return "";
            }
            // missing or non-string content is reported as empty text instead of null
            return content instanceof String ? (String) content : "";
        }

        /**
         * @return all data-URL attachments of an array content in order; empty for string content
         */
        public List<DataURL> getContentAttachments() {
            final List<DataURL> attachments = new ArrayList<>();
            final Object content = this.userObject.opt("content");
            if (content instanceof JSONArray) {
                final JSONArray parts = (JSONArray) content;
                for (int i = 0; i < parts.length(); i++) {
                    final DataURL dataUrl = dataUrlOf(parts.optJSONObject(i));
                    if (dataUrl != null) {
                        attachments.add(dataUrl);
                    }
                }
            }
            return attachments;
        }

        /**
         * Removes the first attachment part whose signature equals that of the given data URL
         * and normalizes the content afterwards. Does nothing for string content.
         *
         * @param delete_data_url the attachment to remove
         */
        public void removeContentAttachment(DataURL delete_data_url) {
            final Object content = this.userObject.opt("content");
            if (!(content instanceof JSONArray)) {
                return;
            }
            final JSONArray parts = (JSONArray) content;
            for (int i = 0; i < parts.length(); i++) {
                final DataURL candidate = dataUrlOf(parts.optJSONObject(i));
                if (candidate != null && candidate.getSiganture() == delete_data_url.getSiganture()) {
                    parts.remove(i);
                    break;
                }
            }
            this.normalize();
        }

        /**
         * Replaces an array content that consists of exactly one {@code text} part by the plain
         * text string. All other content shapes are left as they are.
         */
        public void normalize() {
            final Object content = this.userObject.opt("content");
            if (!(content instanceof JSONArray)) {
                return;
            }
            final JSONArray parts = (JSONArray) content;
            if (parts.length() != 1) {
                return;
            }
            // a single remaining image part is valid input and must not trip an assertion
            final JSONObject only = parts.optJSONObject(0);
            if (only == null || !"text".equals(only.optString("type"))) {
                return;
            }
            try {
                this.userObject.putOpt("content", only.optString("text", ""));
            } catch (JSONException ignored) {
                // key and value are non-null strings; nothing sensible to do if this fails
            }
        }

        /**
         * Sets the message text. For array content the first {@code text} part is updated; if the
         * array has no text part yet, one is appended so that the text is not lost. For any other
         * content the {@code content} entry is replaced by the text.
         *
         * @param text the new text
         */
        public void setContentText(String text) {
            final Object content = this.userObject.opt("content");
            try {
                if (content instanceof JSONArray) {
                    final JSONArray parts = (JSONArray) content;
                    for (int i = 0; i < parts.length(); i++) {
                        final JSONObject part = parts.optJSONObject(i);
                        if (part != null && "text".equals(part.optString("type"))) {
                            part.putOpt("text", text);
                            return;
                        }
                    }
                    if (text != null) {
                        final JSONObject textPart = new JSONObject(true);
                        textPart.put("type", "text");
                        textPart.put("text", text);
                        parts.put(textPart);
                    }
                    return;
                }
                this.userObject.put("content", text);
            } catch (JSONException ignored) {
                // keys are constant non-null strings; the message is left unchanged if this fails
            }
        }

        /**
         * Extracts the data URL of an {@code image_url} part.
         *
         * @param part a content part, may be {@code null}
         * @return the parsed data URL, or {@code null} if the part is no {@code image_url} part,
         *         has no URL, or the URL is not a decodable data URL (e.g. a plain http link)
         */
        private static DataURL dataUrlOf(JSONObject part) {
            if (part == null || !"image_url".equals(part.optString("type"))) {
                return null;
            }
            final JSONObject imageUrl = part.optJSONObject("image_url");
            if (imageUrl == null) {
                return null;
            }
            final String url = imageUrl.optString("url", "");
            if (url.isEmpty()) {
                return null;
            }
            try {
                return new DataURL(url);
            } catch (IllegalArgumentException e) {
                // not a data URL or invalid base64: not an attachment this class can handle
                return null;
            }
        }
    }

    /**
     * The passage of a document that matches a query best, together with its relevance score.
     *
     * <p>The document text is cut into chunks with {@code RAGProxyServlet.slicer}. Every chunk is
     * scored with a TF-IDF sum over the query words, where the chunks of this one document play
     * the role of the document collection. The snippet text is the best chunk joined with its
     * direct neighbours; the snippet score is the score of the best chunk.
     */
    public static class Snippet {
        private String text;
        private final String url;
        private final String title;
        private double score;

        /**
         * Computes the snippet. If no chunk contains a query word (or an argument is unusable)
         * the text is empty and the score is 0.
         *
         * @param query2         the query words, separated by white space
         * @param text           the full document text
         * @param url            the document URL, stored unchanged
         * @param title          the document title, stored unchanged
         * @param maxChunkLength chunk length passed to the slicer
         */
        public Snippet(String query2, String text, String url, String title, int maxChunkLength) {
            this.url = url;
            this.title = title;
            this.score = 0.0;
            this.text = "";

            if (text == null || text.isEmpty() || maxChunkLength <= 0 || query2 == null) {
                return;
            }
            final List<String> chunks = RAGProxyServlet.slicer(text, maxChunkLength);
            if (chunks.isEmpty()) {
                return;
            }
            final Set<String> queryWordSet = RAGProxyServlet.querySet(query2);
            if (queryWordSet.isEmpty()) {
                return;
            }

            final int totalChunks = chunks.size();
            final String[] chunksLowerCase = new String[totalChunks];
            for (int i = 0; i < totalChunks; i++) {
                chunksLowerCase[i] = chunks.get(i).toLowerCase();
            }

            // inverse document frequency per query word; a chunk counts if it contains the word as substring
            final Map<String, Double> idf = new HashMap<>();
            for (final String word : queryWordSet) {
                int docFreq = 0;
                for (final String chunk : chunksLowerCase) {
                    if (chunk.contains(word)) {
                        docFreq++;
                    }
                }
                idf.put(word, Math.log((double) totalChunks / (double) (docFreq + 1)) + 1.0);
            }

            // Score the chunks in index order and remember the first one with the highest score.
            // Scanning by index makes the choice between equally scored chunks deterministic.
            int topChunkIndex = -1;
            double topScore = 0.0;
            for (int i = 0; i < totalChunks; i++) {
                // term frequency: whole words of the chunk, stripped of punctuation
                final Map<String, Integer> tf = new HashMap<>();
                for (final String w : chunksLowerCase[i].split("\\s+")) {
                    final String cleanWord = w.replaceAll("[.,!?;:]", "");
                    if (!cleanWord.isEmpty() && queryWordSet.contains(cleanWord)) {
                        tf.merge(cleanWord, 1, Integer::sum);
                    }
                }
                double chunkScore = 0.0;
                for (final String word : queryWordSet) {
                    final int tfValue = tf.getOrDefault(word, 0);
                    chunkScore += (double) tfValue * idf.getOrDefault(word, 1.0);
                }
                if (chunkScore > topScore) {
                    topScore = chunkScore;
                    topChunkIndex = i;
                }
            }
            if (topChunkIndex < 0) {
                return; // no chunk matched: text stays empty, score stays 0
            }
            this.score = topScore;

            // best chunk plus one neighbour on each side, as far as they exist
            final int first = Math.max(0, topChunkIndex - 1);
            final int last = Math.min(totalChunks - 1, topChunkIndex + 1);
            this.text = String.join(" ", chunks.subList(first, last + 1));
        }

        /** @return the score of the best chunk, 0 if nothing matched */
        public double getScore() {
            return this.score;
        }

        /** @return the snippet text, empty if nothing matched */
        public String getText() {
            return this.text;
        }

        /** @return the URL given to the constructor */
        public String getURL() {
            return this.url;
        }

        /** @return the title given to the constructor */
        public String getTitle() {
            return this.title;
        }
    }

    /**
     * A parsed {@code data:} URL whose payload is base64 encoded.
     *
     * <p>The mime type is the part of the header in front of the first ';'. The signature is the
     * hash code of the undecoded payload string and is used to recognize equal attachments.
     */
    public static final class DataURL {
        private final String mimetype;
        private final byte[] data;
        private final int signature;

        /**
         * Parses a data URL of the form {@code data:<mimetype>[;param...],<base64 payload>}.
         *
         * @param data_url the URL to parse
         * @throws IllegalArgumentException if the URL is {@code null}, does not start with
         *         {@code data:}, has no comma, or the payload is not valid base64
         */
        public DataURL(String data_url) {
            final String scheme = "data:";
            if (data_url == null || !data_url.startsWith(scheme)) {
                throw new IllegalArgumentException("data url not valid: it must start with 'data:'");
            }
            final int separator = data_url.indexOf(',');
            if (separator < 0) {
                throw new IllegalArgumentException("data url not valid: it must contain a comma");
            }
            final String payload = data_url.substring(separator + 1);
            final String[] headerFields = data_url.substring(scheme.length(), separator).split(";");

            this.mimetype = headerFields[0];
            this.data = Base64.getDecoder().decode(payload);
            this.signature = payload.hashCode();
        }

        /** @return the mime type from the URL header; empty if the header names none */
        public String getMimetype() {
            return this.mimetype;
        }

        /** @return the decoded payload bytes */
        public byte[] getData() {
            return this.data;
        }

        /** @return the hash code of the base64 payload string */
        public int getSiganture() {
            return this.signature;
        }
    }
}

