/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai.llama3;

import java.util.Comparator;
import java.util.Random;
import net.yacy.ai.llama3.Tensor.FloatTensor;

/**
 * Strategy for choosing the next token id from a tensor of per-token scores.
 *
 * <p>Samplers produced by {@link #selectSampler} overwrite the tensor they are given.
 */
@FunctionalInterface
interface Sampler {
    /** Greedy sampler: always picks the token chosen by {@code FloatTensor::argmax}. */
    Sampler ARGMAX = FloatTensor::argmax;

    /**
     * Selects a token id.
     *
     * @param logits per-token scores indexed by token id (the nested samplers expect probabilities)
     * @return the selected token id
     */
    int sampleToken(FloatTensor logits);

    /**
     * Builds the sampler described by the given decoding parameters.
     *
     * <p>A temperature of exactly {@code 0} yields {@link #ARGMAX}. Any other temperature yields a
     * sampler that, on every call, divides the logits by {@code temperature} and applies
     * {@code softmaxInPlace}. Both steps run in place, so the caller's tensor is overwritten. It
     * then delegates to a {@link CategoricalSampler} when {@code topp <= 0 || topp >= 1}, or to a
     * {@link ToppSampler} otherwise.
     *
     * @param vocabularySize capacity of the top-p scratch buffer; must be at least the logits size
     * @param temperature divisor applied to the logits; {@code 0} selects greedy decoding
     * @param topp nucleus threshold; values {@code <= 0} or {@code >= 1} disable top-p filtering
     * @param rngSeed seed of the single {@link Random} used by all draws of the returned sampler
     * @return the configured sampler
     */
    static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) {
        if (temperature == 0.0f) {
            return ARGMAX;
        }
        Random rng = new Random(rngSeed);
        Sampler innerSampler = (topp <= 0.0f || topp >= 1.0f)
                ? new CategoricalSampler(rng)
                : new ToppSampler(vocabularySize, topp, rng);
        return logits -> {
            logits.divideInPlace(0, logits.size(), temperature);
            logits.softmaxInPlace(0, logits.size());
            return innerSampler.sampleToken(logits);
        };
    }

    /**
     * Inverse-CDF sampler over the whole vocabulary.
     *
     * <p>It walks the entries in index order, accumulating them, and returns the first index whose
     * running sum exceeds a uniform draw from {@code [0, 1)}. It is intended for probabilities, such
     * as those produced by the softmax step in {@link Sampler#selectSampler}.
     */
    class CategoricalSampler implements Sampler {
        final Random rng;

        /** @param rng source of the uniform draw; may be shared with other samplers */
        public CategoricalSampler(Random rng) {
            this.rng = rng;
        }

        @Override
        public int sampleToken(FloatTensor logits) {
            float random0to1 = rng.nextFloat();
            float cdf = 0.0f;
            int n = logits.size();
            for (int i = 0; i < n; i++) {
                cdf += logits.getFloat(i);
                if (random0to1 < cdf) {
                    return i;
                }
            }
            // Float rounding can leave the total just below the draw; fall back to the last token.
            return n - 1;
        }
    }

    /**
     * Top-p ("nucleus") sampler.
     *
     * <p>It samples only among the smallest set of most probable tokens whose cumulative
     * probability exceeds {@code topp}, so very unlikely tokens are never drawn.
     *
     * <p>It expects {@code logits} to hold probabilities that sum to (about) 1. Each call
     * overwrites the {@link #indices} scratch array, so an instance must not be used by several
     * threads at once.
     */
    class ToppSampler implements Sampler {
        /** Scratch buffer of token ids, reused across calls; its length is the maximum logits size. */
        final int[] indices;
        final float topp;
        final Random rng;

        /**
         * @param maxNumberOfElements largest logits size this sampler will be asked to handle
         * @param topp nucleus threshold, expected to lie strictly between 0 and 1
         * @param rng source of the uniform draw; may be shared with other samplers
         */
        public ToppSampler(int maxNumberOfElements, float topp, Random rng) {
            this.indices = new int[maxNumberOfElements];
            this.topp = topp;
            this.rng = rng;
        }

        /** Swaps {@code array[from]} and {@code array[to]}. */
        static void swap(int[] array, int from, int to) {
            int tmp = array[from];
            array[from] = array[to];
            array[to] = tmp;
        }

        /**
         * Restores the heap property for the subtree rooted at {@code from}, where the heap is
         * {@code array[0..n)}.
         *
         * <p>The element the comparator orders first ends up at the root. With the reversed
         * comparator used by {@link #sampleToken}, this is a max-heap by probability.
         *
         * @param n heap size, i.e. the exclusive upper bound of the heap region
         */
        static void siftDown(int[] array, int from, int n, Comparator<Integer> comparator) {
            int next;
            int prev = from;
            while ((next = 2 * prev + 1) < n) {
                int r = 2 * prev + 2;
                if (r < n && comparator.compare(array[r], array[next]) < 0) {
                    next = r;
                }
                if (comparator.compare(array[next], array[prev]) >= 0) {
                    break;
                }
                swap(array, prev, next);
                prev = next;
            }
        }

        /**
         * Draws a token from the top-p nucleus of {@code logits}.
         *
         * @throws IllegalArgumentException if {@code logits} is empty or larger than the capacity
         *     given at construction
         */
        @Override
        public int sampleToken(FloatTensor logits) {
            int n = logits.size();
            if (n <= 0 || n > indices.length) {
                throw new IllegalArgumentException(
                        "logits size " + n + " is outside [1, " + indices.length + "]");
            }
            // Reversed so the most probable token compares smallest and sits at the heap root.
            Comparator<Integer> comparator = Comparator.comparingDouble(logits::getFloat).reversed();

            // Tokens with probability below (1 - topp) / (n - 1) cannot be part of the nucleus.
            // Partition them to the tail so only the candidates at the front need heap ordering.
            float cutoff = (1.0f - topp) / (n - 1);
            int head = 0;
            int tail = n - 1;
            for (int i = 0; i < n; i++) {
                if (logits.getFloat(i) >= cutoff) {
                    indices[head++] = i;
                } else {
                    indices[tail--] = i;
                }
            }
            int n0 = head;
            if (n0 == 0) {
                // The cutoff can exceed every probability only when topp < 1/n. In that case the
                // most probable token alone exceeds topp and forms the nucleus. indices[0..n)
                // already holds every token, so heapify all of them to select it below.
                n0 = n;
            }

            // Build a max-heap over the candidates in O(n0).
            for (int i = n0 / 2 - 1; i >= 0; i--) {
                siftDown(indices, i, n0, comparator);
            }

            // Heap-sort extraction. Each extracted maximum is parked just past the shrinking heap.
            // Stop once the cumulative probability exceeds topp. The nucleus is then
            // indices[lastIndex..n0), in descending probability order from n0 - 1 downward.
            float cumulativeProb = 0.0f;
            int lastIndex = 0;
            for (int i = n0 - 1; i >= 0; i--) {
                swap(indices, 0, i);
                cumulativeProb += logits.getFloat(indices[i]);
                if (cumulativeProb > topp) {
                    lastIndex = i;
                    break;
                }
                // The remaining heap is indices[0..i), i.e. exactly i elements.
                siftDown(indices, 0, i, comparator);
            }

            // Sample within the nucleus, scaling the draw by its total mass.
            float r = rng.nextFloat() * cumulativeProb;
            float cdf = 0.0f;
            for (int i = n0 - 1; i >= lastIndex; i--) {
                cdf += logits.getFloat(indices[i]);
                if (r < cdf) {
                    return indices[i];
                }
            }
            // Float rounding can leave the final cdf at or below r.
            return indices[lastIndex];
        }
    }
}

