/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai.llama3.Tensor;

import java.util.Arrays;
import java.util.function.IntConsumer;
import java.util.function.LongConsumer;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import net.yacy.ai.llama3.Model.GGMLType;
import net.yacy.ai.llama3.Tensor.FloatTensor;

/**
 * Base implementation of {@link FloatTensor} whose numeric operations are written purely in terms
 * of {@link #size()}, {@link #getFloat(int)} and {@link #setFloat(int, float)}. Subclasses provide
 * element storage; non-final operations may be overridden with specialised versions.
 */
public abstract class AbstractFloatTensor
implements FloatTensor {

    /**
     * Decodes an IEEE 754 half-precision (binary16) bit pattern into a {@code float}.
     * Signed zeros, subnormals, infinities and NaN are handled; a NaN payload is shifted into the
     * upper mantissa bits of the result. Every binary16 value is exactly representable as a float.
     *
     * @param h the half-precision bits (all 16 bits of the short are used)
     * @return the equivalent {@code float}
     */
    public static final float float16ToFloat(short h) {
        final int hBits = h & 0xFFFF;
        final int sign = (hBits >>> 15) & 0x1;
        final int exp = (hBits >>> 10) & 0x1F;
        int mant = hBits & 0x3FF;
        final int fBits;
        if (exp == 0) {
            if (mant == 0) {
                fBits = sign << 31; // +0 or -0
            } else {
                // Subnormal half (value = mant * 2^-24): shift until the implicit leading bit
                // (bit 10) is set; every shift lowers the unbiased exponent (-14) by one.
                int shift = 0;
                while ((mant & 0x400) == 0) {
                    mant <<= 1;
                    ++shift;
                }
                mant &= 0x3FF; // the leading bit becomes implicit in the float encoding
                fBits = (sign << 31) | ((127 - 14 - shift) << 23) | (mant << 13);
            }
        } else if (exp == 0x1F) {
            // Infinity (mant == 0) or NaN (mant != 0).
            fBits = (sign << 31) | 0x7F800000 | (mant << 13);
        } else {
            // Normal number: re-bias the exponent from 15 (half) to 127 (float).
            fBits = (sign << 31) | ((exp - 15 + 127) << 23) | (mant << 13);
        }
        return Float.intBitsToFloat(fBits);
    }

    /**
     * Encodes a {@code float} as IEEE 754 half-precision (binary16) bits.
     * <p>
     * Extra mantissa bits are truncated (rounding toward zero), not rounded to nearest. Magnitudes
     * of 2^16 and above (float exponent too large for binary16) become signed infinity; magnitudes
     * in the binary16 subnormal range are encoded as subnormals; smaller magnitudes become signed
     * zero. NaN stays NaN and infinities stay infinities.
     *
     * @param f the value to encode
     * @return the half-precision bit pattern
     */
    public static final short floatToFloat16(float f) {
        // floatToIntBits canonicalises NaN, so the top mantissa bit is set for every NaN.
        final int fBits = Float.floatToIntBits(f);
        final int sign = (fBits >>> 16) & 0x8000;
        final int exp = (fBits >>> 23) & 0xFF;
        final int mant = fBits & 0x7FFFFF;

        if (exp == 0xFF) {
            // Infinity or NaN; make sure a NaN never collapses to an infinity encoding.
            int halfMant = mant >>> 13;
            if (mant != 0 && halfMant == 0) {
                halfMant = 0x200;
            }
            return (short) (sign | 0x7C00 | halfMant);
        }

        final int halfExp = exp - 127 + 15; // re-bias from 127 (float) to 15 (half)
        if (halfExp >= 0x1F) {
            // Too large for binary16 (exponent field 31 is reserved for Inf/NaN).
            return (short) (sign | 0x7C00);
        }
        if (halfExp <= 0) {
            // Below the binary16 normal range. Values smaller than the smallest subnormal
            // (2^-24) become signed zero; the guard also keeps the shift distance below 32,
            // because Java masks shift counts to their low 5 bits.
            if (halfExp < -10) {
                return (short) sign;
            }
            final int fullMant = mant | 0x800000; // restore the implicit leading bit
            return (short) (sign | (fullMant >>> (14 - halfExp)));
        }
        return (short) (sign | (halfExp << 10) | (mant >>> 13));
    }

    @Override
    public abstract int size();

    @Override
    public abstract float getFloat(int var1);

    @Override
    public abstract void setFloat(int var1, float var2);

    /** Returns the GGML storage type of this tensor (implemented by subclasses). */
    abstract GGMLType type();

    /**
     * Returns the product of the given dimensions.
     *
     * @param dimensions the tensor dimensions; each is asserted to be positive
     * @return the total number of elements
     * @throws ArithmeticException if the product overflows {@code int}
     * @throws java.util.NoSuchElementException if no dimensions are given
     */
    public static int numberOfElements(int ... dimensions) {
        assert (Arrays.stream(dimensions).allMatch(i -> i > 0));
        return Arrays.stream(dimensions).reduce(Math::multiplyExact).orElseThrow();
    }

    /**
     * Invokes {@code action} for every index in {@code [startInclusive, endExclusive)} using a
     * parallel {@link IntStream}. The single-index range {@code [0, 1)} is executed directly on the
     * calling thread instead.
     */
    public static void parallelFor(int startInclusive, int endExclusive, IntConsumer action) {
        if (startInclusive == 0 && endExclusive == 1) {
            action.accept(0);
            return;
        }
        IntStream.range(startInclusive, endExclusive).parallel().forEach(action);
    }

    /**
     * {@code long}-indexed variant of {@link #parallelFor(int, int, IntConsumer)}, backed by a
     * parallel {@link LongStream}.
     */
    public static void parallelForLong(long startInclusive, long endExclusive, LongConsumer action) {
        if (startInclusive == 0L && endExclusive == 1L) {
            action.accept(0L);
            return;
        }
        LongStream.range(startInclusive, endExclusive).parallel().forEach(action);
    }

    /**
     * Dot product of {@code size} elements of this tensor starting at {@code thisOffset} with
     * {@code size} elements of {@code that} starting at {@code thatOffset}.
     * <p>
     * The main loop accumulates into four independent partial sums, so the result may differ
     * from a strictly sequential summation by floating-point rounding.
     */
    @Override
    public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
        float sum0 = 0.0f;
        float sum1 = 0.0f;
        float sum2 = 0.0f;
        float sum3 = 0.0f;
        final int limit = size & ~3; // largest multiple of 4 not exceeding size
        for (int j = 0; j < limit; j += 4) {
            final int i = thisOffset + j;
            final int k = thatOffset + j;
            sum0 += this.getFloat(i) * that.getFloat(k);
            sum1 += this.getFloat(i + 1) * that.getFloat(k + 1);
            sum2 += this.getFloat(i + 2) * that.getFloat(k + 2);
            sum3 += this.getFloat(i + 3) * that.getFloat(k + 3);
        }
        float result = sum0 + sum1 + sum2 + sum3;
        // Remaining 0..3 elements; both offsets must be applied here as in the main loop.
        for (int j = limit; j < size; ++j) {
            result += this.getFloat(thisOffset + j) * that.getFloat(thatOffset + j);
        }
        return result;
    }

    /**
     * Matrix-vector product: treating this tensor as {@code dim0} consecutive rows of
     * {@code dim1} elements, sets {@code out[i] = dot(row i, that[0 .. dim1))} for every
     * {@code i < dim0}. Rows are processed via {@link #parallelFor(int, int, IntConsumer)}.
     */
    @Override
    public void matmul(FloatTensor that, FloatTensor out, int dim0, int dim1) {
        AbstractFloatTensor.parallelFor(0, dim0, i -> out.setFloat(i, this.dot(i * dim1, that, 0, dim1)));
    }

    /**
     * Batched matrix-vector product: for each {@code idx < context}, sets
     * {@code out[idx][i] = dot(row i of this, that[idx][0 .. dim1))} for every {@code i < dim0}.
     * All {@code context * dim0} row products are distributed over one
     * {@link #parallelForLong(long, long, LongConsumer)} range.
     * <p>
     * NOTE(review): {@code context} is assumed to be at most {@code that.length}; larger values
     * fail with an index error inside the parallel loop.
     *
     * @throws IllegalArgumentException if {@code that} and {@code out} differ in length
     */
    @Override
    public void matmul(int context, FloatTensor[] that, FloatTensor[] out, int dim0, int dim1) {
        if (that.length != out.length) {
            throw new IllegalArgumentException(String.format("that.len=%d, out.len=%d", that.length, out.length));
        }
        // Widen before multiplying so the task count cannot overflow int.
        AbstractFloatTensor.parallelForLong(0L, (long) dim0 * context, ti -> {
            int idxArr = (int)(ti / (long)dim0);
            int i = (int)(ti % (long)dim0);
            out[idxArr].setFloat(i, this.dot(i * dim1, that[idxArr], 0, dim1));
        });
    }

    /**
     * Left fold of {@code size} elements starting at {@code thisOffset}, beginning with
     * {@code seed}: {@code result = reduce.apply(result, element)} for each element in order.
     */
    @Override
    public float reduce(int thisOffset, int size, float seed, AggregateFunction reduce) {
        float result = seed;
        for (int i = 0; i < size; ++i) {
            result = reduce.apply(result, this.getFloat(thisOffset + i));
        }
        return result;
    }

    /** Sum of {@code size} elements starting at {@code thisOffset}. */
    private float sum(int thisOffset, int size) {
        return this.reduce(thisOffset, size, 0.0f, Float::sum);
    }

    /** Maximum of {@code size} elements starting at {@code thisOffset}; -Inf when size is 0. */
    private float max(int thisOffset, int size) {
        return this.reduce(thisOffset, size, Float.NEGATIVE_INFINITY, Float::max);
    }

    /**
     * Copies {@code size} elements from this tensor (starting at {@code thisOffset}) into
     * {@code that} (starting at {@code thatOffset}), in ascending index order.
     */
    @Override
    public void copyTo(int thisOffset, FloatTensor that, int thatOffset, int size) {
        int endOffset = thatOffset + size;
        for (int i = thatOffset; i < endOffset; ++i) {
            that.setFloat(i, this.getFloat(i - thatOffset + thisOffset));
        }
    }

    /**
     * Returns the index of the largest element. Ties resolve to the lowest index; NaN elements are
     * never selected after the first element (comparisons with NaN are false).
     */
    @Override
    public int argmax() {
        final int size = this.size();
        assert (size > 0);
        int maxIndex = 0;
        float maxValue = this.getFloat(0);
        for (int i = 1; i < size; ++i) {
            final float f = this.getFloat(i);
            if (f > maxValue) {
                maxValue = f;
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * Replaces each of the {@code size} elements starting at {@code thisOffset} with
     * {@code mapFunction.apply(element)}.
     *
     * @return this tensor
     */
    @Override
    public FloatTensor mapInPlace(int thisOffset, int size, FloatTensor.MapFunction mapFunction) {
        int endIndex = thisOffset + size;
        for (int i = thisOffset; i < endIndex; ++i) {
            this.setFloat(i, mapFunction.apply(this.getFloat(i)));
        }
        return this;
    }

    /** Applies {@code mapFunction} to every element of this tensor. */
    @Override
    public final FloatTensor mapInPlace(FloatTensor.MapFunction mapFunction) {
        return this.mapInPlace(0, this.size(), mapFunction);
    }

    /**
     * Like {@link #mapInPlace(int, int, FloatTensor.MapFunction)}, but the function also receives
     * the absolute index of the element in this tensor (not the index relative to
     * {@code thisOffset}).
     *
     * @return this tensor
     */
    public FloatTensor mapWithIndexInPlace(int thisOffset, int size, FloatTensor.MapWithIndexFunction mapWithIndexFunction) {
        int endOffset = thisOffset + size;
        for (int i = thisOffset; i < endOffset; ++i) {
            this.setFloat(i, mapWithIndexFunction.apply(this.getFloat(i), i));
        }
        return this;
    }

    /** Element-wise {@code this[thisOffset + j] += that[thatOffset + j]} for {@code j < size}. */
    private final FloatTensor addInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) {
        return this.mapWithIndexInPlace(thisOffset, size, (value, index2) -> value + that.getFloat(index2 - thisOffset + thatOffset));
    }

    /** Adds {@code that} element-wise to this tensor over this tensor's full size. */
    @Override
    public final FloatTensor addInPlace(FloatTensor that) {
        return this.addInPlace(0, that, 0, this.size());
    }

    /** Element-wise {@code this[thisOffset + j] *= that[thatOffset + j]} for {@code j < size}. */
    private final FloatTensor multiplyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) {
        return this.mapWithIndexInPlace(thisOffset, size, (value, index2) -> value * that.getFloat(index2 - thisOffset + thatOffset));
    }

    /** Multiplies this tensor element-wise by {@code that} over this tensor's full size. */
    @Override
    public final FloatTensor multiplyInPlace(FloatTensor that) {
        return this.multiplyInPlace(0, that, 0, this.size());
    }

    /** Divides each of {@code size} elements starting at {@code thisOffset} by {@code value}. */
    @Override
    public final FloatTensor divideInPlace(int thisOffset, int size, float value) {
        return this.mapInPlace(thisOffset, size, f -> f / value);
    }

    /** Sets each of {@code size} elements starting at {@code thisOffset} to {@code value}. */
    @Override
    public FloatTensor fillInPlace(int thisOffset, int size, float value) {
        return this.mapInPlace(thisOffset, size, unused -> value);
    }

    /**
     * Replaces the {@code size} elements starting at {@code thisOffset} with their softmax.
     * The maximum is subtracted before exponentiation so {@code Math.exp} cannot overflow.
     */
    @Override
    public final FloatTensor softmaxInPlace(int thisOffset, int size) {
        float maxVal = this.max(thisOffset, size);
        this.mapInPlace(thisOffset, size, f -> (float)Math.exp(f - maxVal));
        float sum = this.sum(thisOffset, size);
        return this.divideInPlace(thisOffset, size, sum);
    }

    /**
     * SAXPY: {@code this[thisOffset + i] = a * that[thatOffset + i] + this[thisOffset + i]} for
     * every {@code i < size}.
     *
     * @return this tensor
     */
    @Override
    public FloatTensor saxpyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size, float a) {
        for (int i = 0; i < size; ++i) {
            this.setFloat(thisOffset + i, a * that.getFloat(thatOffset + i) + this.getFloat(thisOffset + i));
        }
        return this;
    }

    /** Binary float operator used by {@link #reduce(int, int, float, AggregateFunction)}. */
    @FunctionalInterface
    public static interface AggregateFunction {
        public float apply(float var1, float var2);
    }
}

