/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai.llama3.Tensor;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.Arrays;
import net.yacy.ai.llama3.Model.GGMLType;
import net.yacy.ai.llama3.Tensor.AbstractFloatTensor;
import net.yacy.ai.llama3.Tensor.FloatTensor;
import net.yacy.ai.llama3.Tensor.Q4_0FloatTensor;

/**
 * A {@link FloatTensor} backed by a plain {@code float[]} holding F32 values.
 *
 * <p>Elements are read and written with ordinary array accesses. (A plain-mode
 * {@code VarHandle} has the same memory semantics, adds indirection, and needs explicit
 * {@code (float)} casts on {@code get} to compile because it is signature-polymorphic.)
 *
 * <p>No synchronization is performed by this class.
 */
public final class ArrayFloatTensor
extends AbstractFloatTensor
implements FloatTensor {
    /**
     * Backing storage: tensor element {@code i} is {@code values[i]}. The array is
     * shared, not copied, and is read directly by the fast paths in this class when
     * the other operand is also an {@code ArrayFloatTensor}.
     */
    public final float[] values;

    /**
     * Wraps {@code values} without copying; writes through this tensor are visible in
     * the caller's array and vice versa.
     *
     * @param values backing array, used as-is
     */
    public ArrayFloatTensor(float[] values) {
        this.values = values;
    }

    /**
     * Allocates a zero-filled tensor of {@code AbstractFloatTensor.numberOfElements(dims)}
     * elements (presumably the product of {@code dims} — confirm in AbstractFloatTensor).
     *
     * @param dims tensor dimensions
     * @return a new zero-initialized F32 tensor
     */
    public static FloatTensor allocate(int ... dims) {
        int numberOfElements = AbstractFloatTensor.numberOfElements(dims);
        return new ArrayFloatTensor(new float[numberOfElements]);
    }

    /** @return the number of elements, i.e. {@code values.length} */
    @Override
    public final int size() {
        return this.values.length;
    }

    /**
     * @param index2 element index
     * @return {@code values[index2]}
     * @throws ArrayIndexOutOfBoundsException if {@code index2} is outside {@code [0, size())}
     */
    @Override
    public final float getFloat(int index2) {
        return this.values[index2];
    }

    /**
     * Stores {@code value} at {@code values[index2]}.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code index2} is outside {@code [0, size())}
     */
    @Override
    public final void setFloat(int index2, float value) {
        this.values[index2] = value;
    }

    /** @return {@link GGMLType#F32}, since values are stored as 32-bit floats */
    @Override
    public GGMLType type() {
        return GGMLType.F32;
    }

    /**
     * Sets the {@code size} elements starting at {@code thisOffset} to {@code value}.
     *
     * @return this tensor, for chaining
     */
    @Override
    public final AbstractFloatTensor fillInPlace(int thisOffset, int size, float value) {
        Arrays.fill(this.values, thisOffset, thisOffset + size, value);
        return this;
    }

    /**
     * Replaces each of the {@code size} elements starting at {@code thisOffset} with
     * {@code mapFunction.apply(element)}, in ascending index order.
     *
     * @return this tensor, for chaining
     */
    @Override
    public final AbstractFloatTensor mapInPlace(int thisOffset, int size, FloatTensor.MapFunction mapFunction) {
        float[] v = this.values;
        int endIndex = thisOffset + size;
        for (int i = thisOffset; i < endIndex; ++i) {
            v[i] = mapFunction.apply(v[i]);
        }
        return this;
    }

    /**
     * Copies {@code size} elements from {@code this[thisOffset ..]} to
     * {@code that[thatOffset ..]}. A non-positive {@code size} copies nothing.
     */
    @Override
    public final void copyTo(int thisOffset, FloatTensor that, int thatOffset, int size) {
        if (size <= 0) {
            return;
        }
        if (that instanceof ArrayFloatTensor) {
            // System.arraycopy validates bounds before copying and behaves correctly
            // when that == this and the two ranges overlap.
            System.arraycopy(this.values, thisOffset, ((ArrayFloatTensor) that).values, thatOffset, size);
        } else {
            float[] src = this.values;
            for (int i = 0; i < size; ++i) {
                that.setFloat(thatOffset + i, src[thisOffset + i]);
            }
        }
    }

    /**
     * Returns the dot product of {@code this[thisOffset .. thisOffset+size)} and
     * {@code that[thatOffset .. thatOffset+size)}. The sum is accumulated in a single
     * {@code float}, in ascending index order.
     */
    @Override
    public final float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
        float[] a = this.values;
        float result = 0.0f;
        if (that instanceof ArrayFloatTensor) {
            float[] b = ((ArrayFloatTensor) that).values;
            for (int i = 0; i < size; ++i) {
                result += a[thisOffset + i] * b[thatOffset + i];
            }
        } else {
            for (int i = 0; i < size; ++i) {
                result += a[thisOffset + i] * that.getFloat(thatOffset + i);
            }
        }
        return result;
    }

    /**
     * Matrix-vector product. Treats {@code this} as {@code dim0} consecutive rows of
     * {@code dim1} elements and sets {@code out[i] = dot(row i, that[0 .. dim1))} for
     * every {@code i} in {@code [0, dim0)}, using {@code AbstractFloatTensor.parallelFor}.
     */
    @Override
    public final void matmul(FloatTensor that, FloatTensor out, int dim0, int dim1) {
        // The fast path writes straight into out's backing array, so it is the
        // *output* tensor's type that must be checked here, not that's.
        if (out instanceof ArrayFloatTensor) {
            float[] outValues = ((ArrayFloatTensor) out).values;
            AbstractFloatTensor.parallelFor(0, dim0, i -> {
                outValues[i] = this.dot(i * dim1, that, 0, dim1);
            });
        } else {
            AbstractFloatTensor.parallelFor(0, dim0, i -> out.setFloat(i, this.dot(i * dim1, that, 0, dim1)));
        }
    }

    /**
     * Batched matrix-vector product. For each {@code k} in {@code [0, context)} and each
     * row {@code i} in {@code [0, dim0)}, sets {@code out[k][i] = dot(row i, that[k])}.
     * All {@code dim0 * context} work items are distributed with
     * {@code AbstractFloatTensor.parallelForLong}.
     * NOTE(review): assumes {@code context <= that.length}; confirm at call sites.
     *
     * @throws IllegalArgumentException if {@code that.length != out.length}
     */
    @Override
    public final void matmul(int context, FloatTensor[] that, FloatTensor[] out, int dim0, int dim1) {
        if (that.length != out.length) {
            throw new IllegalArgumentException(String.format("that.len=%d, out.len=%d", that.length, out.length));
        }
        // Widen before multiplying: dim0 * context in int arithmetic can overflow.
        long total = (long) dim0 * context;
        AbstractFloatTensor.parallelForLong(0L, total, ti -> {
            int idxArr = (int) (ti / (long) dim0);
            int i = (int) (ti % (long) dim0);
            out[idxArr].setFloat(i, this.dot(i * dim1, that[idxArr], 0, dim1));
        });
    }

    /**
     * In-place {@code this[thisOffset+i] += a * that[thatOffset+i]} for
     * {@code i} in {@code [0, size)}.
     *
     * <p>When {@code that} is a {@link Q4_0FloatTensor}, its values are decoded in chunks
     * of at most {@code GGMLType.Q4_0.blockSize} elements into a scratch buffer obtained
     * from {@code Q4_0FloatTensor.scratchBuffer} (assumed to hold at least one block —
     * confirm in Q4_0FloatTensor).
     *
     * @return this tensor, for chaining
     */
    @Override
    public final FloatTensor saxpyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size, float a) {
        float[] dst = this.values;
        if (that instanceof Q4_0FloatTensor) {
            Q4_0FloatTensor qft = (Q4_0FloatTensor) that;
            float[] decodedBlock = Q4_0FloatTensor.scratchBuffer.get();
            int blockSize = GGMLType.Q4_0.blockSize;
            int chunkSize;
            for (int i = 0; i < size; i += chunkSize) {
                chunkSize = Math.min(size - i, blockSize);
                qft.getFloatArray(thatOffset + i, decodedBlock, 0, chunkSize);
                int base = thisOffset + i;
                for (int j = 0; j < chunkSize; ++j) {
                    dst[base + j] += a * decodedBlock[j];
                }
            }
        } else {
            for (int i = 0; i < size; ++i) {
                dst[thisOffset + i] += a * that.getFloat(thatOffset + i);
            }
        }
        return this;
    }
}

