/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.training.loss.Loss;

/**
 * {@code YOLOv3Loss} is a {@link Loss} that sums a per-output detection loss over the first three
 * prediction outputs of a network.
 *
 * <p>Each output is reshaped to {@code (batch, 3, numClasses + 5, d2, d3)}, i.e. three anchors per
 * output with four box values, one confidence value and {@code numClasses} class scores per
 * anchor. The per-output loss is the sum of a box regression term, an objectness/class term for
 * cells selected as responsible for an object, and a no-object confidence term, divided by the
 * batch size.
 *
 * <p>Labels are read per picture as rows of {@code (class, x, y, width, height)}; {@code x +
 * width / 2} and {@code y + height / 2} are scaled by the grid size to locate the grid cell of
 * each object.
 *
 * <p>Instances remember the {@link NDManager} of the most recent {@link #evaluate(NDList, NDList)}
 * call, so one instance should not be evaluated from several threads at once.
 */
public final class YOLOv3Loss extends Loss {

    /** Default anchors as (width, height) pairs: three pairs for each of the three outputs. */
    private static final float[] PRESETANCHORS = {
        116.0f, 90.0f, 156.0f, 198.0f, 373.0f, 326.0f,
        30.0f, 61.0f, 62.0f, 45.0f, 59.0f, 119.0f,
        10.0f, 13.0f, 16.0f, 30.0f, 33.0f, 23.0f
    };

    /** Small constant that keeps logarithms and divisions away from zero. */
    private static final float EPSILON = 1.0E-7f;

    /** Number of prediction outputs consumed by {@link #evaluate(NDList, NDList)}. */
    private static final int NUM_OUTPUTS = 3;

    /** Number of anchors attached to every prediction output. */
    private static final int ANCHORS_PER_OUTPUT = 3;

    private final float[] anchors;
    private final int numClasses;
    private final int boxAttr;
    private final Shape inputShape;
    private final float ignoreThreshold;

    /** Manager captured by the latest {@code evaluate} call; {@code null} before the first call. */
    private NDManager manager;

    private YOLOv3Loss(Builder builder) {
        super(builder.name);
        // Copy so that later changes to the caller's array cannot alter this loss.
        this.anchors = builder.anchorsArray.clone();
        this.numClasses = builder.numClasses;
        this.boxAttr = builder.numClasses + 5;
        this.inputShape = builder.inputShape;
        this.ignoreThreshold = builder.ignoreThreshold;
    }

    /**
     * Returns a copy of the preset anchors.
     *
     * @return a new array holding the preset (width, height) anchor pairs
     */
    public static float[] getPresetAnchors() {
        return PRESETANCHORS.clone();
    }

    /**
     * Limits every element of an array to the range {@code [tMin, tMax]} using comparison masks.
     *
     * @param tList the array to clip
     * @param tMin the lower bound; smaller elements become {@code tMin}
     * @param tMax the upper bound; larger elements become {@code tMax}
     * @return the clipped array
     */
    public NDArray clipByTensor(NDArray tList, float tMin, float tMax) {
        NDArray lowerClipped = tList.gte(tMin).mul(tList).add(tList.lt(tMin).mul(tMin));
        return lowerClipped
                .lte(tMax)
                .mul(lowerClipped)
                .add(lowerClipped.gt(tMax).mul(tMax));
    }

    /**
     * Computes the element-wise squared error.
     *
     * @param prediction the predicted values
     * @param target the target values
     * @return {@code (prediction - target)^2} for every element
     */
    public NDArray mseLoss(NDArray prediction, NDArray target) {
        return prediction.sub(target).pow(2);
    }

    /**
     * Computes the element-wise binary cross entropy.
     *
     * <p>The prediction is clipped to {@code [EPSILON, 1 - EPSILON]} first so that both
     * logarithms stay finite.
     *
     * @param prediction the predicted probabilities
     * @param target the target values
     * @return {@code -(target * log(p) + (1 - target) * log(1 - p))} for every element
     */
    public NDArray bceLoss(NDArray prediction, NDArray target) {
        NDArray clipped = clipByTensor(prediction, EPSILON, 1.0f - EPSILON);
        NDArray positiveTerm = clipped.log().mul(target);
        NDArray negativeTerm = clipped.mul(-1).add(1).log().mul(target.mul(-1).add(1));
        return positiveTerm.add(negativeTerm).mul(-1);
    }

    /**
     * Computes the total loss as the sum of the losses of the first three prediction outputs.
     *
     * @param labels a list holding exactly one label array
     * @param predictions a list whose first three entries are the prediction outputs
     * @return the summed loss
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        this.manager = predictions.getManager();
        NDArray[] lossComponents = new NDArray[NUM_OUTPUTS];
        for (int i = 0; i < NUM_OUTPUTS; ++i) {
            lossComponents[i] =
                    evaluateOneOutput(i, predictions.get(i), labels.singletonOrThrow());
        }
        return NDArrays.add(lossComponents);
    }

    /**
     * Computes the loss of a single prediction output.
     *
     * @param componentIndex the index of the output, which selects its three anchors
     * @param input the raw output; dimension 0 is the batch and dimensions 2 and 3 are the grid,
     *     and it must be reshapable to {@code (batch, 3, numClasses + 5, d2, d3)}
     * @param labels the label array, indexed by picture along dimension 0
     * @return the box, object and no-object losses summed and divided by the batch size
     */
    public NDArray evaluateOneOutput(int componentIndex, NDArray input, NDArray labels) {
        NDManager mgr = resolveManager(input);
        Shape shape = input.getShape();
        int batchSize = (int) shape.get(0);
        int inW = (int) shape.get(2);
        int inH = (int) shape.get(3);

        // (batch, 3 * boxAttr, inW, inH) -> (3, batch, inW, inH, boxAttr)
        NDArray prediction =
                input.reshape(batchSize, ANCHORS_PER_OUTPUT, boxAttr, inW, inH)
                        .transpose(1, 0, 3, 4, 2);
        NDArray x = Activation.sigmoid(prediction.get("...,0"));
        NDArray y = Activation.sigmoid(prediction.get("...,1"));
        NDArray w = prediction.get("...,2");
        NDArray h = prediction.get("...,3");
        // Confidence and class scores are kept batch-first: (batch, 3, inW, inH[, numClasses]).
        NDArray conf = Activation.sigmoid(prediction.get("...,4")).transpose(1, 0, 2, 3);
        NDArray predClass = Activation.sigmoid(prediction.get("...,5:")).transpose(1, 0, 2, 3, 4);

        NDList truthList = getTarget(labels, inH, inW);
        NDArray boxLossScale = truthList.get(0).transpose(1, 0, 2, 3);
        NDArray groundTruth = truthList.get(1);

        NDArray iou =
                calculateIOU(x, y, groundTruth.get("...,0:4"), componentIndex)
                        .transpose(1, 0, 2, 3);

        // Cells whose IoU does not exceed the threshold count as background ...
        NDArray noObjMask =
                NDArrays.where(
                        iou.lte(ignoreThreshold), mgr.ones(iou.getShape()), mgr.create(0.0f));
        // ... while the best-matching anchor of a cell is responsible for the object, provided
        // its IoU reaches half of the threshold.
        NDArray objMask = iou.argMax(1).oneHot(ANCHORS_PER_OUTPUT).transpose(0, 3, 1, 2);
        objMask =
                NDArrays.where(
                        iou.gte(ignoreThreshold / 2.0f), objMask, mgr.zeros(objMask.getShape()));
        // A responsible anchor is never penalised as background.
        noObjMask =
                NDArrays.where(objMask.eq(1.0f), mgr.zeros(noObjMask.getShape()), noObjMask);

        NDArray xTrue = groundTruth.get("...,0");
        NDArray yTrue = groundTruth.get("...,1");
        NDArray wTrue = groundTruth.get("...,2");
        NDArray hTrue = groundTruth.get("...,3");
        NDArray classTrue = groundTruth.get("...,4:").transpose(1, 0, 2, 3, 4);

        // Anchor sizes as a fraction of the network input, laid out as (3, batch, inW, inH).
        NDArray widths =
                anchorFractions(mgr, componentIndex, 0, inputShape.get(0))
                        .broadcast(inH, inW, batchSize, ANCHORS_PER_OUTPUT)
                        .transpose(3, 2, 1, 0);
        NDArray heights =
                anchorFractions(mgr, componentIndex, 1, inputShape.get(1))
                        .broadcast(inH, inW, batchSize, ANCHORS_PER_OUTPUT)
                        .transpose(3, 2, 1, 0);

        NDArray squaredBoxError =
                NDArrays.add(
                                mseLoss(x, xTrue),
                                mseLoss(y, yTrue),
                                mseLoss(w.exp().mul(widths), wTrue),
                                mseLoss(h.exp().mul(heights), hTrue))
                        .transpose(1, 0, 2, 3);
        NDArray boxLoss = objMask.mul(boxLossScale).mul(squaredBoxError).sum();

        NDArray classLoss = bceLoss(predClass, classTrue).sum(new int[] {4});
        NDArray confLoss = objMask.mul(conf.add(EPSILON).log().mul(-1).add(classLoss)).sum();

        NDArray noObjLoss =
                noObjMask.mul(conf.mul(-1).add(1.0f + EPSILON).log().mul(-1)).sum();

        return boxLoss.add(confLoss).add(noObjLoss).div(batchSize);
    }

    /**
     * Builds the per-cell training targets from the labels.
     *
     * <p>For every picture a zero-filled grid is created and, for each labelled object, the cell
     * containing the object's centre receives the centre offset inside the cell, the object's
     * width and height, and its one-hot class. The same cell of the scale grid receives {@code 2 -
     * width * height}. Pictures without objects contribute all-zero grids.
     *
     * @param labels the label array, indexed by picture along dimension 0; each picture holds
     *     rows of {@code (class, x, y, width, height)}
     * @param inH the size of the second grid dimension
     * @param inW the size of the first grid dimension
     * @return a list of the box loss scale with shape {@code (3, batch, inW, inH)} and the ground
     *     truth with shape {@code (3, batch, inW, inH, numClasses + 4)}
     */
    public NDList getTarget(NDArray labels, int inH, int inW) {
        NDManager mgr = resolveManager(labels);
        int batchSize = (int) labels.size(0);
        Shape scaleShape = new Shape(inW, inH);
        Shape truthShape = new Shape(inW, inH, boxAttr - 1);

        NDList boxLossComponents = new NDList();
        NDList groundTruthComponents = new NDList();
        for (int batch = 0; batch < batchSize; ++batch) {
            NDArray boxLoss = mgr.zeros(scaleShape, DataType.FLOAT32);
            NDArray groundTruth = mgr.zeros(truthShape, DataType.FLOAT32);
            NDArray picture = labels.get(batch);
            // A picture without objects keeps its all-zero targets; it must still be added so
            // that the stacked result has one entry per picture and can be broadcast below.
            if (picture.size(0) != 0L) {
                fillTargets(picture, inH, inW, boxLoss, groundTruth);
            }
            boxLossComponents.add(boxLoss);
            groundTruthComponents.add(groundTruth);
        }

        NDArray boxLossScale =
                NDArrays.stack(boxLossComponents)
                        .broadcast(ANCHORS_PER_OUTPUT, batchSize, inW, inH);
        NDArray groundTruth =
                NDArrays.stack(groundTruthComponents)
                        .broadcast(ANCHORS_PER_OUTPUT, batchSize, inW, inH, boxAttr - 1);
        return new NDList(boxLossScale, groundTruth);
    }

    /**
     * Computes, for each of the three anchors, the IoU between the anchor box placed at the
     * predicted centre and the ground truth box of the same cell.
     *
     * <p>Overlaps are clamped at zero so that disjoint boxes yield an IoU of zero instead of a
     * positive value produced by multiplying two negative extents.
     *
     * @param predx the predicted x centres, with the anchor as dimension 0 and the grid as
     *     dimensions 2 and 3
     * @param predy the predicted y centres, laid out like {@code predx}
     * @param groundTruth the ground truth with the anchor as dimension 0 and (x, y, width, height)
     *     along the last dimension
     * @param componentIndex the index of the output, which selects its three anchors
     * @return the IoU values stacked along a new leading anchor dimension
     */
    public NDArray calculateIOU(
            NDArray predx, NDArray predy, NDArray groundTruth, int componentIndex) {
        int inW = (int) predx.getShape().get(2);
        int inH = (int) predx.getShape().get(3);
        int strideW = (int) inputShape.get(0) / inW;
        int strideH = (int) inputShape.get(1) / inH;

        NDList iouComponent = new NDList();
        for (int i = 0; i < ANCHORS_PER_OUTPUT; ++i) {
            // Anchor size expressed in grid cells.
            float width = anchors[anchorOffset(componentIndex, i)] / (float) strideW;
            float height = anchors[anchorOffset(componentIndex, i) + 1] / (float) strideH;
            float halfWidth = width / 2.0f;
            float halfHeight = height / 2.0f;

            NDArray curPredx = predx.get(i);
            NDArray curPredy = predy.get(i);
            NDArray predLeft = curPredx.sub(halfWidth);
            NDArray predRight = curPredx.add(halfWidth);
            NDArray predTop = curPredy.sub(halfHeight);
            NDArray predBottom = curPredy.add(halfHeight);

            NDArray truth = groundTruth.get(i);
            NDArray trueX = truth.get("...,0");
            NDArray trueY = truth.get("...,1");
            // Ground truth sizes are fractions of the picture; convert them to grid cells.
            NDArray trueW = truth.get("...,2").mul(inW);
            NDArray trueH = truth.get("...,3").mul(inH);
            NDArray trueLeft = trueX.sub(trueW.div(2));
            NDArray trueRight = trueX.add(trueW.div(2));
            NDArray trueTop = trueY.sub(trueH.div(2));
            NDArray trueBottom = trueY.add(trueH.div(2));

            NDArray left = NDArrays.maximum(predLeft, trueLeft);
            NDArray right = NDArrays.minimum(predRight, trueRight);
            NDArray top = NDArrays.maximum(predTop, trueTop);
            NDArray bottom = NDArrays.minimum(predBottom, trueBottom);

            NDArray overlapW = right.sub(left).maximum(0.0f);
            NDArray overlapH = bottom.sub(top).maximum(0.0f);
            NDArray inter = overlapW.mul(overlapH);
            NDArray union = trueW.mul(trueH).add(width * height).sub(inter).add(EPSILON);
            iouComponent.add(inter.div(union));
        }
        return NDArrays.stack(iouComponent);
    }

    /**
     * Creates a builder to build a {@code YOLOv3Loss}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Returns the manager captured by {@code evaluate}, or the array's own manager if none. */
    private NDManager resolveManager(NDArray fallback) {
        return manager != null ? manager : fallback.getManager();
    }

    /** Returns the index in {@code anchors} of the width of one anchor of one output. */
    private static int anchorOffset(int componentIndex, int anchorIndex) {
        return componentIndex * 2 * ANCHORS_PER_OUTPUT + 2 * anchorIndex;
    }

    /** Creates the three anchor widths (part 0) or heights (part 1) divided by an input extent. */
    private NDArray anchorFractions(NDManager mgr, int componentIndex, int part, long extent) {
        float[] sizes = new float[ANCHORS_PER_OUTPUT];
        for (int i = 0; i < ANCHORS_PER_OUTPUT; ++i) {
            sizes[i] = anchors[anchorOffset(componentIndex, i) + part];
        }
        return mgr.create(sizes).div(extent);
    }

    /** Writes the targets of every object of one picture into the picture's grids. */
    private void fillTargets(
            NDArray picture, int inH, int inW, NDArray boxLoss, NDArray groundTruth) {
        // Box centre in grid units: corner plus half of the size, scaled by the grid size.
        NDArray xgt = picture.get("...,1").add(picture.get("...,3").div(2)).mul(inW);
        NDArray ygt = picture.get("...,2").add(picture.get("...,4").div(2)).mul(inH);
        NDArray wgt = picture.get("...,3");
        NDArray hgt = picture.get("...,4");
        NDArray objectClass = picture.get("...,0").oneHot(numClasses);

        int objectNum = (int) picture.size(0);
        for (int i = 0; i < objectNum; ++i) {
            float centerX = xgt.getFloat(i);
            float centerY = ygt.getFloat(i);
            float boxW = wgt.getFloat(i);
            float boxH = hgt.getFloat(i);
            // The integer part selects the cell, the fractional part is the regression target.
            int tx = (int) centerX;
            int ty = (int) centerY;
            String cell = tx + "," + ty;

            groundTruth.set(new NDIndex(cell + ",0"), centerX - (float) tx);
            groundTruth.set(new NDIndex(cell + ",1"), centerY - (float) ty);
            groundTruth.set(new NDIndex(cell + ",2"), boxW);
            groundTruth.set(new NDIndex(cell + ",3"), boxH);
            groundTruth.set(new NDIndex(cell + ",4:"), objectClass.get(i));
            boxLoss.set(new NDIndex(cell), 2.0f - boxW * boxH);
        }
    }

    /** The builder to construct a {@link YOLOv3Loss}. */
    public static class Builder {
        private String name = "YOLOv3Loss";
        private float[] anchorsArray = YOLOv3Loss.getPresetAnchors();
        private int numClasses = 20;
        private Shape inputShape = new Shape(419L, 419L);
        private float ignoreThreshold = 0.5f;

        /**
         * Sets the name of the loss.
         *
         * @param name the name of the loss
         * @return this {@code Builder}
         */
        public Builder setName(String name) {
            this.name = name;
            return this;
        }

        /**
         * Sets the anchors, given as consecutive (width, height) pairs.
         *
         * @param anchorsArray the anchors; must have the same length as the preset anchors
         * @return this {@code Builder}
         * @throws IllegalArgumentException if the array has the wrong length
         */
        public Builder setAnchorsArray(float[] anchorsArray) {
            if (anchorsArray.length != PRESETANCHORS.length) {
                throw new IllegalArgumentException(
                        String.format(
                                "setAnchorsArray requires anchors of length %d, but was given anchors of length %d instead",
                                PRESETANCHORS.length, anchorsArray.length));
            }
            // Copy so that the builder is not affected by later changes to the caller's array.
            this.anchorsArray = anchorsArray.clone();
            return this;
        }

        /**
         * Sets the number of object classes.
         *
         * @param numClasses the number of classes
         * @return this {@code Builder}
         */
        public Builder setNumClasses(int numClasses) {
            this.numClasses = numClasses;
            return this;
        }

        /**
         * Sets the shape of the network input; its first two dimensions are used to scale the
         * anchors.
         *
         * @param inputShape the input shape
         * @return this {@code Builder}
         */
        public Builder setInputShape(Shape inputShape) {
            this.inputShape = inputShape;
            return this;
        }

        /**
         * Sets the IoU threshold used to decide which cells count as background.
         *
         * @param ignoreThreshold the threshold
         * @return this {@code Builder}
         */
        public Builder optIgnoreThreshold(float ignoreThreshold) {
            this.ignoreThreshold = ignoreThreshold;
            return this;
        }

        /**
         * Builds the {@link YOLOv3Loss} with the configured values.
         *
         * @return the new loss
         */
        public YOLOv3Loss build() {
            return new YOLOv3Loss(this);
        }
    }
}

