/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai.llama3.Tensor;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.nio.ByteBuffer;
import net.yacy.ai.llama3.Model.GGMLType;
import net.yacy.ai.llama3.Tensor.AbstractFloatTensor;
import net.yacy.ai.llama3.Tensor.FloatTensor;

/**
 * Read-only {@link FloatTensor} backed by GGML {@code Q4_0}-quantized data held in a
 * {@link ByteBuffer}.
 *
 * <p>Layout, as decoded by this class: the buffer is a sequence of blocks, each
 * {@code GGMLType.Q4_0.typeSize} bytes long and holding {@code GGMLType.Q4_0.blockSize}
 * elements. A block starts with a 2-byte float16 scale, followed by packed 4-bit quants.
 * Element {@code j} of the first half of a block is stored in the low nibble of quant byte
 * {@code j}. Element {@code j} of the second half is stored in the high nibble of the same byte.
 * An element decodes to {@code (nibble - 8) * scale}.
 *
 * <p>The index arithmetic (shifting by {@code log2(blockSize)}, masking with
 * {@code blockSize - 1}) assumes {@code GGMLType.Q4_0.blockSize} is a power of two. All buffer
 * reads use absolute indices, so the buffer's position is never modified.
 */
public final class Q4_0FloatTensor extends AbstractFloatTensor implements FloatTensor {
    /** Number of logical float elements in this tensor. */
    private final int size;
    /** Raw Q4_0 block data, addressed from byte 0 with absolute reads. */
    private final ByteBuffer buffer;

    /** {@code log2(GGMLType.Q4_0.blockSize)}: shifting an element index by this yields its block index. */
    private static final int LOG2_QUANT_BLOCK_SIZE =
            Integer.numberOfTrailingZeros(GGMLType.Q4_0.blockSize);
    /** Elements per half block; also the number of packed quant bytes in each block. */
    private static final int QUANT_HALF_BLOCK = GGMLType.Q4_0.blockSize / 2;
    /** Size in bytes of the float16 scale stored at the start of every block. */
    private static final int QUANT_FLOAT16_BYTES = 2;

    /**
     * Per-thread scratch array of {@code GGMLType.Q4_0.blockSize} floats used by
     * {@link #copyTo}. Kept public for compatibility; its contents are overwritten on every
     * {@code copyTo} call.
     */
    public static final ThreadLocal<float[]> scratchBuffer =
            ThreadLocal.withInitial(() -> new float[GGMLType.Q4_0.blockSize]);

    /**
     * Creates a tensor view over {@code buffer}.
     *
     * @param size   number of logical elements
     * @param buffer Q4_0 block data; it is not copied, so later changes to it are visible here
     */
    public Q4_0FloatTensor(int size, ByteBuffer buffer) {
        this.size = size;
        this.buffer = buffer;
    }

    /** @return the number of logical elements */
    @Override
    public final int size() {
        return this.size;
    }

    /**
     * Always fails: quantized tensors are read-only.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public final void setFloat(int index, float value) {
        throw new UnsupportedOperationException("setFloat");
    }

    /** @return {@link GGMLType#Q4_0} */
    @Override
    public final GGMLType type() {
        return GGMLType.Q4_0;
    }

    /**
     * Decodes a single element.
     *
     * @param index element index in {@code [0, size)} (checked only when assertions are enabled)
     * @return the dequantized value
     */
    @Override
    public final float getFloat(int index) {
        assert 0 <= index && index < this.size;
        int blockOffset = (index >>> LOG2_QUANT_BLOCK_SIZE) * GGMLType.Q4_0.typeSize;
        int modIndex = index & (GGMLType.Q4_0.blockSize - 1);
        return (float) (quantAt(blockOffset, modIndex) - 8) * blockScale(blockOffset);
    }

    /**
     * Decodes {@code length} consecutive elements starting at {@code index} into
     * {@code out[outOffset .. outOffset + length)}. The range may start anywhere and may span
     * several blocks. Each block's scale is decoded only once.
     *
     * @param index     first element to decode
     * @param out       destination array
     * @param outOffset first destination slot
     * @param length    number of elements to decode; values {@code <= 0} decode nothing
     */
    public final void getFloatArray(int index, float[] out, int outOffset, int length) {
        final int blockSize = GGMLType.Q4_0.blockSize;
        final int end = index + length;
        int inPos = index;
        int outPos = outOffset;
        while (inPos < end) {
            int blockIndex = inPos >>> LOG2_QUANT_BLOCK_SIZE;
            int blockOffset = blockIndex * GGMLType.Q4_0.typeSize;
            float scale = blockScale(blockOffset);
            // Stop at whichever comes first: the end of this block or the end of the request.
            int blockEnd = Math.min(blockIndex * blockSize + blockSize, end);
            for (; inPos < blockEnd; ++inPos) {
                out[outPos++] = (float) (quantAt(blockOffset, inPos & (blockSize - 1)) - 8) * scale;
            }
        }
    }

    /**
     * Copies {@code size} decoded elements starting at {@code thisOffset} into {@code that}
     * starting at {@code thatOffset}. The copy goes through {@code that.setFloat} in chunks
     * decoded into this thread's {@link #scratchBuffer}.
     */
    @Override
    public void copyTo(int thisOffset, FloatTensor that, int thatOffset, int size) {
        float[] decoded = scratchBuffer.get();
        int done = 0;
        while (done < size) {
            int chunk = Math.min(decoded.length, size - done);
            this.getFloatArray(thisOffset + done, decoded, 0, chunk);
            for (int i = 0; i < chunk; ++i) {
                // Plain indexing: an uncast VarHandle.get(...) has static type Object and
                // cannot be passed to setFloat(int, float).
                that.setFloat(thatOffset + done + i, decoded[i]);
            }
            done += chunk;
        }
    }

    /**
     * Computes the dot product of {@code size} elements of this tensor starting at
     * {@code thisOffset} with {@code size} elements of {@code that} starting at
     * {@code thatOffset}.
     *
     * <p>Whole blocks are processed in place, and the block scale is applied once per block.
     * The block loop reads a block from its first quant, so it is only valid once
     * {@code thisOffset + j} lies on a block boundary. The unaligned head before that boundary
     * and any partial tail are therefore summed element by element.
     */
    @Override
    public final float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
        final int blockSize = GGMLType.Q4_0.blockSize;
        float result = 0.0f;
        int j = 0;

        // Unaligned head: advance until (thisOffset + j) is a multiple of blockSize.
        int alignmentBound = Math.min(size, -thisOffset & (blockSize - 1));
        for (; j < alignmentBound; ++j) {
            result += this.getFloat(thisOffset + j) * that.getFloat(thatOffset + j);
        }

        // Whole, block-aligned blocks.
        for (; j + blockSize <= size; j += blockSize) {
            int blockOffset = ((thisOffset + j) >>> LOG2_QUANT_BLOCK_SIZE) * GGMLType.Q4_0.typeSize;
            float scale = blockScale(blockOffset);
            int quantOffset = blockOffset + QUANT_FLOAT16_BYTES;
            int thatIndex = thatOffset + j;
            float blockResult = 0.0f;
            for (int i = 0; i < QUANT_HALF_BLOCK; ++i) {
                int packed = this.buffer.get(quantOffset + i) & 0xFF;
                // The low nibble is element i; the high nibble is element i + QUANT_HALF_BLOCK.
                blockResult += (float) ((packed & 0xF) - 8) * that.getFloat(thatIndex + i)
                        + (float) ((packed >>> 4) - 8) * that.getFloat(thatIndex + i + QUANT_HALF_BLOCK);
            }
            result += blockResult * scale;
        }

        // Partial tail.
        for (; j < size; ++j) {
            result += this.getFloat(thisOffset + j) * that.getFloat(thatOffset + j);
        }
        return result;
    }

    /** Returns the decoded float16 scale stored at the start of the block at {@code blockOffset}. */
    private float blockScale(int blockOffset) {
        return AbstractFloatTensor.float16ToFloat(this.buffer.getShort(blockOffset));
    }

    /**
     * Returns the raw 4-bit quant (0..15) for position {@code modIndex} in
     * {@code [0, blockSize)} of the block starting at byte {@code blockOffset}.
     */
    private int quantAt(int blockOffset, int modIndex) {
        boolean low = modIndex < QUANT_HALF_BLOCK;
        int byteIndex = blockOffset + QUANT_FLOAT16_BYTES + (low ? modIndex : modIndex - QUANT_HALF_BLOCK);
        int packed = this.buffer.get(byteIndex) & 0xFF;
        return low ? (packed & 0xF) : (packed >>> 4);
    }
}

