package org.apache.flink.ml.classification.logisticregression;

import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;

/* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/LogisticGradient.class */
/**
 * Computes the weighted logistic loss and its gradient over a batch of labeled points, optionally
 * adding an L2 penalty on the coefficient vector.
 *
 * <p>Each label is mapped to a signed value {@code y = 2 * label - 1}, so a label of 0 becomes -1
 * and a label of 1 becomes +1 (presumably labels are binary 0/1 — confirm against callers). For a
 * point with features {@code x} and weight {@code w}, the contribution to the loss is {@code w *
 * log(1 + exp(-y * dot(x, coefficient)))}.
 */
public class LogisticGradient implements Serializable {
    /** Strength of the L2 penalty; when it is 0 the penalty is skipped entirely. */
    private final double l2;

    /**
     * Creates a loss/gradient calculator.
     *
     * @param l2 the L2 regularization strength; the penalty added to the loss is {@code l2 *
     *     ||coefficient||^2}
     */
    public LogisticGradient(double l2) {
        this.l2 = l2;
    }

    /**
     * Computes the total weight and the total weighted loss of the given points.
     *
     * <p>When {@code l2} is non-zero, {@code l2 * ||coefficient||^2} is added to the loss sum once
     * per call, not once per point. NOTE(review): callers that add up results from several calls
     * will count the penalty once per call — confirm that is intended.
     *
     * @param dataPoints the labeled, weighted points to evaluate
     * @param coefficient the current model coefficients
     * @return a tuple of (sum of point weights, weighted loss sum plus the L2 penalty)
     */
    public Tuple2<Double, Double> computeLoss(
            List<LabeledPointWithWeight> dataPoints, DenseVector coefficient) {
        double weightSum = 0.0;
        double lossSum = 0.0;
        for (LabeledPointWithWeight dataPoint : dataPoints) {
            lossSum += dataPoint.getWeight() * computeLoss(dataPoint, coefficient);
            weightSum += dataPoint.getWeight();
        }
        if (Double.compare(0.0, l2) != 0) {
            lossSum += l2 * Math.pow(BLAS.norm2(coefficient), 2.0);
        }
        return Tuple2.of(weightSum, lossSum);
    }

    /**
     * Adds the gradient of the weighted logistic loss over {@code dataPoints} into {@code
     * cumGradient}. The target vector is not cleared first, so results accumulate across calls.
     *
     * <p>When {@code l2} is non-zero, {@code 2 * l2 * coefficient} (the derivative of {@code l2 *
     * ||coefficient||^2}) is also added once per call.
     *
     * @param dataPoints the labeled, weighted points to differentiate over
     * @param coefficient the current model coefficients
     * @param cumGradient the vector the gradient is accumulated into; modified in place
     */
    public void computeGradient(
            List<LabeledPointWithWeight> dataPoints,
            DenseVector coefficient,
            DenseVector cumGradient) {
        for (LabeledPointWithWeight dataPoint : dataPoints) {
            computeGradient(dataPoint, coefficient, cumGradient);
        }
        if (Double.compare(0.0, l2) != 0) {
            BLAS.axpy(l2 * 2.0, coefficient, cumGradient);
        }
    }

    /** Returns the unweighted logistic loss {@code log(1 + exp(-y * dot))} of one point. */
    private double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) {
        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
        double labelScaled = 2.0 * dataPoint.getLabel() - 1.0;
        return log1pExp(-dot * labelScaled);
    }

    /**
     * Returns {@code log(1 + exp(x))} without overflow.
     *
     * <p>The naive formula overflows once {@code x} exceeds roughly 709: {@code Math.exp} returns
     * infinity, so the loss becomes infinite even though the true value is about {@code x}. For
     * positive {@code x} the identity {@code log(1 + e^x) = x + log(1 + e^-x)} keeps the exponent
     * non-positive. {@code Math.log1p} keeps precision when the exponential term is tiny.
     */
    private static double log1pExp(double x) {
        return x > 0 ? x + Math.log1p(Math.exp(-x)) : Math.log1p(Math.exp(x));
    }

    /** Adds the weighted gradient of one point's logistic loss into {@code cumGradient}. */
    private void computeGradient(
            LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) {
        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
        double labelScaled = 2.0 * dataPoint.getLabel() - 1.0;
        // d/d(dot) of log(1 + exp(-y * dot)) is -y / (exp(y * dot) + 1). If exp overflows to
        // infinity the multiplier simply goes to zero, so this form needs no special-casing.
        double multiplier =
                dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1.0));
        BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient);
    }
}
