/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.internal.hppc.IntArrayList;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.index.codec.vectors.CentroidAssignments;
import org.elasticsearch.index.codec.vectors.DocIdsWriter;
import org.elasticsearch.index.codec.vectors.IVFVectorsWriter;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;

public class DefaultIVFVectorsWriter
extends IVFVectorsWriter {
    private static final Logger logger = LogManager.getLogger(DefaultIVFVectorsWriter.class);
    private final int vectorPerCluster;

    public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) throws IOException {
        super(state, rawVectorDelegate);
        this.vectorPerCluster = vectorPerCluster;
    }

    @Override
    long[] buildAndWritePostingsLists(FieldInfo fieldInfo, IVFVectorsWriter.CentroidSupplier centroidSupplier, FloatVectorValues floatVectorValues, IndexOutput postingsOutput, IntArrayList[] assignmentsByCluster) throws IOException {
        long[] offsets = new long[centroidSupplier.size()];
        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
        BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
        DocIdsWriter docIdsWriter = new DocIdsWriter();
        for (int c = 0; c < centroidSupplier.size(); ++c) {
            float[] centroid = centroidSupplier.centroid(c);
            binarizedByteVectorValues.centroid = centroid;
            IntArrayList cluster = assignmentsByCluster[c];
            offsets[c] = postingsOutput.getFilePointer();
            int size = cluster.size();
            postingsOutput.writeVInt(size);
            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct((float[])centroid, (float[])centroid)));
            docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, (DataOutput)postingsOutput);
            this.writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
        }
        if (logger.isDebugEnabled()) {
            DefaultIVFVectorsWriter.printClusterQualityStatistics(assignmentsByCluster);
        }
        return offsets;
    }

    private static void printClusterQualityStatistics(IntArrayList[] clusters) {
        float min = Float.MAX_VALUE;
        float max = Float.MIN_VALUE;
        float mean = 0.0f;
        float m2 = 0.0f;
        int count = 0;
        for (IntArrayList cluster : clusters) {
            ++count;
            if (cluster == null) continue;
            float delta = (float)cluster.size() - mean;
            m2 += delta * ((float)cluster.size() - (mean += delta / (float)count));
            min = Math.min(min, (float)cluster.size());
            max = Math.max(max, (float)cluster.size());
        }
        float variance = m2 / (float)(clusters.length - 1);
        logger.debug("Centroid count: {} min: {} max: {} mean: {} stdDev: {} variance: {}", new Object[]{clusters.length, Float.valueOf(min), Float.valueOf(max), Float.valueOf(mean), Math.sqrt(variance), Float.valueOf(variance)});
    }

    private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues) throws IOException {
        int cidx;
        int limit = cluster.size() - 16 + 1;
        OptimizedScalarQuantizer.QuantizationResult[] corrections = new OptimizedScalarQuantizer.QuantizationResult[16];
        for (cidx = 0; cidx < limit; cidx += 16) {
            int j;
            for (j = 0; j < 16; ++j) {
                int ord = cluster.get(cidx + j);
                byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
                postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
                corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord);
            }
            for (j = 0; j < 16; ++j) {
                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval()));
            }
            for (j = 0; j < 16; ++j) {
                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval()));
            }
            for (j = 0; j < 16; ++j) {
                int targetComponentSum = corrections[j].quantizedComponentSum();
                assert (targetComponentSum >= 0 && targetComponentSum <= 65535);
                postingsOutput.writeShort((short)targetComponentSum);
            }
            for (j = 0; j < 16; ++j) {
                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection()));
            }
        }
        while (cidx < cluster.size()) {
            int ord = cluster.get(cidx);
            byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
            OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
            DefaultIVFVectorsWriter.writeQuantizedValue(postingsOutput, binaryValue, correction);
            binarizedByteVectorValues.getCorrectiveTerms(ord);
            postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
            postingsOutput.writeInt(Float.floatToIntBits(correction.lowerInterval()));
            postingsOutput.writeInt(Float.floatToIntBits(correction.upperInterval()));
            postingsOutput.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
            assert (correction.quantizedComponentSum() >= 0 && correction.quantizedComponentSum() <= 65535);
            postingsOutput.writeShort((short)correction.quantizedComponentSum());
            ++cidx;
        }
    }

    @Override
    IVFVectorsWriter.CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) {
        return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo);
    }

    static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput) throws IOException {
        OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
        byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
        float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
        for (float[] centroid : centroids) {
            System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length);
            OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(centroidScratch, quantizedScratch, (byte)4, globalCentroid);
            DefaultIVFVectorsWriter.writeQuantizedValue(centroidOutput, quantizedScratch, result);
        }
        ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * 4).order(ByteOrder.LITTLE_ENDIAN);
        for (float[] centroid : centroids) {
            buffer.asFloatBuffer().put(centroid);
            centroidOutput.writeBytes(buffer.array(), buffer.array().length);
        }
    }

    @Override
    CentroidAssignments calculateAndWriteCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, IndexOutput centroidOutput, MergeState mergeState, float[] globalCentroid) throws IOException {
        return this.calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, globalCentroid, false);
    }

    @Override
    CentroidAssignments calculateAndWriteCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, IndexOutput centroidOutput, float[] globalCentroid) throws IOException {
        return this.calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, globalCentroid, true);
    }

    CentroidAssignments calculateAndWriteCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, IndexOutput centroidOutput, float[] globalCentroid, boolean cacheCentroids) throws IOException {
        long nanoTime = System.nanoTime();
        KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, this.vectorPerCluster);
        float[][] centroids = kMeansResult.centroids();
        int[] assignments = kMeansResult.assignments();
        int[] soarAssignments = kMeansResult.soarAssignments();
        for (float[] centroid : centroids) {
            for (int j = 0; j < centroid.length; ++j) {
                int n = j;
                globalCentroid[n] = globalCentroid[n] + centroid[j];
            }
        }
        int j = 0;
        while (j < globalCentroid.length) {
            int n = j++;
            globalCentroid[n] = globalCentroid[n] / (float)centroids.length;
        }
        DefaultIVFVectorsWriter.writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput);
        if (logger.isDebugEnabled()) {
            logger.debug("calculate centroids and assign vectors time ms: {}", new Object[]{(double)(System.nanoTime() - nanoTime) / 1000000.0});
            logger.debug("final centroid count: {}", new Object[]{centroids.length});
        }
        IntArrayList[] assignmentsByCluster = new IntArrayList[centroids.length];
        for (int c = 0; c < centroids.length; ++c) {
            int j2;
            IntArrayList cluster = new IntArrayList(this.vectorPerCluster);
            for (j2 = 0; j2 < assignments.length; ++j2) {
                if (assignments[j2] != c) continue;
                cluster.add(j2);
            }
            for (j2 = 0; j2 < soarAssignments.length; ++j2) {
                if (soarAssignments[j2] != c) continue;
                cluster.add(j2);
            }
            cluster.trimToSize();
            assignmentsByCluster[c] = cluster;
        }
        if (cacheCentroids) {
            return new CentroidAssignments(centroids, assignmentsByCluster);
        }
        return new CentroidAssignments(centroids.length, assignmentsByCluster);
    }

    static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections) throws IOException {
        indexOutput.writeBytes(binaryValue, binaryValue.length);
        indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
        indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval()));
        indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
        assert (corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 65535);
        indexOutput.writeShort((short)corrections.quantizedComponentSum());
    }

    static class BinarizedFloatVectorValues {
        private OptimizedScalarQuantizer.QuantizationResult corrections;
        private final byte[] binarized;
        private final byte[] initQuantized;
        private float[] centroid;
        private final FloatVectorValues values;
        private final OptimizedScalarQuantizer quantizer;
        private int lastOrd = -1;

        BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) {
            this.values = delegate;
            this.quantizer = quantizer;
            this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8];
            this.initQuantized = new byte[delegate.dimension()];
        }

        public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
            if (ord != this.lastOrd) {
                throw new IllegalStateException("attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + this.lastOrd);
            }
            return this.corrections;
        }

        public byte[] vectorValue(int ord) throws IOException {
            if (ord != this.lastOrd) {
                this.binarize(ord);
                this.lastOrd = ord;
            }
            return this.binarized;
        }

        private void binarize(int ord) throws IOException {
            this.corrections = this.quantizer.scalarQuantize(this.values.vectorValue(ord), this.initQuantized, (byte)1, this.centroid);
            BQVectorUtils.packAsBinary(this.initQuantized, this.binarized);
        }
    }

    static class OffHeapCentroidSupplier
    implements IVFVectorsWriter.CentroidSupplier {
        private final IndexInput centroidsInput;
        private final int numCentroids;
        private final int dimension;
        private final float[] scratch;
        private final long rawCentroidOffset;
        private int currOrd = -1;

        OffHeapCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo info) {
            this.centroidsInput = centroidsInput;
            this.numCentroids = numCentroids;
            this.dimension = info.getVectorDimension();
            this.scratch = new float[this.dimension];
            this.rawCentroidOffset = (this.dimension + 12 + 2) * numCentroids;
        }

        @Override
        public int size() {
            return this.numCentroids;
        }

        @Override
        public float[] centroid(int centroidOrdinal) throws IOException {
            if (centroidOrdinal == this.currOrd) {
                return this.scratch;
            }
            this.centroidsInput.seek(this.rawCentroidOffset + (long)centroidOrdinal * (long)this.dimension * 4L);
            this.centroidsInput.readFloats(this.scratch, 0, this.dimension);
            this.currOrd = centroidOrdinal;
            return this.scratch;
        }
    }
}

