/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.loss.Loss;
import java.util.Objects;

/**
 * A {@link Loss} that penalizes the squared values of a fixed list of parameter arrays.
 *
 * <p>{@link #evaluate(NDList, NDList)} returns {@code lambda * sum_i sum(w_i^2)}, where each
 * {@code w_i} is an array in the parameter list supplied at construction. The label and
 * prediction arguments are not used by this implementation.
 */
public class L2WeightDecay extends Loss {

    /** Multiplier applied to the accumulated sum of squared parameter values. */
    private final float lambda;

    /** Arrays whose squared values are summed; its manager also creates the result array. */
    private final NDList parameters;

    /**
     * Creates an L2 weight-decay loss named {@code "L2WeightDecay"} with {@code lambda = 1.0f}.
     *
     * @param parameters the arrays to penalize (held by reference, not copied)
     * @throws NullPointerException if {@code parameters} is {@code null}
     */
    public L2WeightDecay(NDList parameters) {
        this("L2WeightDecay", parameters);
    }

    /**
     * Creates an L2 weight-decay loss with the given name and {@code lambda = 1.0f}.
     *
     * @param name the name of this loss
     * @param parameters the arrays to penalize (held by reference, not copied)
     * @throws NullPointerException if {@code parameters} is {@code null}
     */
    public L2WeightDecay(String name, NDList parameters) {
        this(name, parameters, 1.0f);
    }

    /**
     * Creates an L2 weight-decay loss.
     *
     * @param name the name of this loss
     * @param parameters the arrays to penalize (held by reference, not copied)
     * @param lambda the multiplier applied to the summed squared values
     * @throws NullPointerException if {@code parameters} is {@code null}
     */
    public L2WeightDecay(String name, NDList parameters, float lambda) {
        super(name);
        this.lambda = lambda;
        // Reject null here instead of surfacing an NPE on the first evaluate() call.
        this.parameters = Objects.requireNonNull(parameters, "parameters");
    }

    /**
     * Returns the sum of the element-wise squares of {@code w}.
     *
     * @param w the array to reduce
     * @return {@code sum(w^2)}
     */
    private NDArray l2(NDArray w) {
        return w.square().sum();
    }

    /**
     * Computes {@code lambda} times the sum of squared values over all configured parameters.
     *
     * @param label not used
     * @param prediction not used
     * @return the weighted sum of squares, created through the parameter list's manager
     */
    @Override
    public NDArray evaluate(NDList label, NDList prediction) {
        // NOTE(review): the manager is taken from the parameter list itself; confirm the
        // behavior of NDList.getManager() when the list is empty (it may throw).
        NDManager manager = parameters.getManager();
        NDArray sum = manager.create(0.0f);
        for (NDArray wi : parameters) {
            // NOTE(review): the square()/sum() temporaries are presumably attached to each
            // parameter's manager; confirm they do not accumulate across training steps.
            sum.addi(l2(wi));
        }
        // lambda autoboxes to Float, selecting muli(Number) as before.
        return sum.muli(lambda);
    }
}

