package org.apache.flink.ml.classification.logisticregression;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.IterationConfig;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.ReplayableDataStreamList;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/LogisticRegression.class */
public class LogisticRegression implements Estimator<LogisticRegression, LogisticRegressionModel>, LogisticRegressionParams<LogisticRegression> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    /* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/LogisticRegression$CacheDataAndDoTrain.class */
    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]> implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>, IterationListener<double[]> {
        private final int globalBatchSize;
        private int localBatchSize;
        private final double learningRate;
        private final LogisticGradient logisticGradient;
        private DenseVector gradient;
        private DenseVector coefficient;
        private int coefficientDim;
        private ListState<DenseVector> coefficientState;
        private List<LabeledPointWithWeight> trainData;
        private ListState<LabeledPointWithWeight> trainDataState;
        private final Random random = new Random(2021);
        private List<LabeledPointWithWeight> miniBatchData;
        private double[] feedbackBuffer;
        private ListState<double[]> feedbackBufferState;
        private final OutputTag<LogisticRegressionModelData> modelDataOutputTag;

        public CacheDataAndDoTrain(LogisticGradient logisticGradient, int i, double d, OutputTag<LogisticRegressionModelData> outputTag) {
            this.logisticGradient = logisticGradient;
            this.globalBatchSize = i;
            this.learningRate = d;
            this.modelDataOutputTag = outputTag;
        }

        public void open() {
            int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
            int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            this.localBatchSize = this.globalBatchSize / numberOfParallelSubtasks;
            if (this.globalBatchSize % numberOfParallelSubtasks > indexOfThisSubtask) {
                this.localBatchSize++;
            }
            this.miniBatchData = new ArrayList(this.localBatchSize);
        }

        private List<LabeledPointWithWeight> getMiniBatchData(List<LabeledPointWithWeight> list, int i) {
            this.miniBatchData.clear();
            for (int i2 = 0; i2 < i; i2++) {
                this.miniBatchData.add(list.get(this.random.nextInt(list.size())));
            }
            return this.miniBatchData;
        }

        private void updateModel() {
            System.arraycopy(this.feedbackBuffer, 0, this.gradient.values, 0, this.gradient.size());
            BLAS.axpy((-this.learningRate) / this.feedbackBuffer[this.coefficientDim], this.gradient, this.coefficient);
        }

        public void onEpochWatermarkIncremented(int i, IterationListener.Context context, Collector<double[]> collector) throws Exception {
            if (i == 0) {
                this.coefficient = new DenseVector(this.feedbackBuffer);
                this.coefficientDim = this.coefficient.size();
                this.feedbackBuffer = new double[this.coefficientDim + 2];
                this.gradient = new DenseVector(this.coefficientDim);
            } else {
                updateModel();
            }
            Arrays.fill(this.gradient.values, 0.0d);
            if (this.trainData == null) {
                this.trainData = IteratorUtils.toList(((Iterable) this.trainDataState.get()).iterator());
            }
            this.miniBatchData = getMiniBatchData(this.trainData, this.localBatchSize);
            Tuple2<Double, Double> computeLoss = this.logisticGradient.computeLoss(this.miniBatchData, this.coefficient);
            this.logisticGradient.computeGradient(this.miniBatchData, this.coefficient, this.gradient);
            System.arraycopy(this.gradient.values, 0, this.feedbackBuffer, 0, this.gradient.size());
            this.feedbackBuffer[this.coefficientDim] = ((Double) computeLoss.f0).doubleValue();
            this.feedbackBuffer[this.coefficientDim + 1] = ((Double) computeLoss.f1).doubleValue();
            collector.collect(this.feedbackBuffer);
        }

        public void onIterationTerminated(IterationListener.Context context, Collector<double[]> collector) {
            this.trainDataState.clear();
            this.coefficientState.clear();
            this.feedbackBufferState.clear();
            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                updateModel();
                context.output(this.modelDataOutputTag, new LogisticRegressionModelData(this.coefficient));
            }
        }

        public void processElement1(StreamRecord<LabeledPointWithWeight> streamRecord) throws Exception {
            this.trainDataState.add(streamRecord.getValue());
        }

        public void processElement2(StreamRecord<double[]> streamRecord) {
            this.feedbackBuffer = (double[]) streamRecord.getValue();
        }

        public void initializeState(StateInitializationContext stateInitializationContext) throws Exception {
            super.initializeState(stateInitializationContext);
            this.trainDataState = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("trainDataState", TypeInformation.of(LabeledPointWithWeight.class)));
            this.coefficientState = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("coefficientState", TypeInformation.of(DenseVector.class)));
            OperatorStateUtils.getUniqueElement(this.coefficientState, "coefficientState").ifPresent(denseVector -> {
                this.coefficient = denseVector;
            });
            this.feedbackBufferState = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("feedbackBufferState", PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO));
            OperatorStateUtils.getUniqueElement(this.feedbackBufferState, "feedbackBufferState").ifPresent(dArr -> {
                this.feedbackBuffer = dArr;
            });
            if (this.coefficient != null) {
                this.coefficientDim = this.coefficient.size();
                this.gradient = new DenseVector(new double[this.coefficientDim]);
            }
        }

        public void snapshotState(StateSnapshotContext stateSnapshotContext) throws Exception {
            this.coefficientState.clear();
            if (this.coefficient != null) {
                this.coefficientState.add(this.coefficient);
            }
            this.feedbackBufferState.clear();
            if (this.feedbackBuffer != null) {
                this.feedbackBufferState.add(this.feedbackBuffer);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/LogisticRegression$GenInitModelData.class */
    public static class GenInitModelData extends AbstractStreamOperator<double[]> implements OneInputStreamOperator<LabeledPointWithWeight, double[]>, BoundedOneInput {
        private int dim;
        private ListState<Integer> dimState;

        private GenInitModelData() {
            this.dim = 0;
        }

        public void endInput() {
            this.output.collect(new StreamRecord(new double[this.dim]));
        }

        public void processElement(StreamRecord<LabeledPointWithWeight> streamRecord) {
            if (this.dim == 0) {
                this.dim = ((LabeledPointWithWeight) streamRecord.getValue()).getFeatures().size();
            } else if (this.dim != ((LabeledPointWithWeight) streamRecord.getValue()).getFeatures().size()) {
                throw new RuntimeException("The training data should all have same dimensions.");
            }
        }

        public void initializeState(StateInitializationContext stateInitializationContext) throws Exception {
            super.initializeState(stateInitializationContext);
            this.dimState = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("dimState", BasicTypeInfo.INT_TYPE_INFO));
            this.dim = ((Integer) OperatorStateUtils.getUniqueElement(this.dimState, "dimState").orElse(0)).intValue();
        }

        public void snapshotState(StateSnapshotContext stateSnapshotContext) throws Exception {
            this.dimState.clear();
            this.dimState.add(Integer.valueOf(this.dim));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/classification/logisticregression/LogisticRegression$TrainIterationBody.class */
    public static class TrainIterationBody implements IterationBody {
        private final LogisticGradient logisticGradient;
        private final int globalBatchSize;
        private final double learningRate;
        private final int maxIter;
        private final double tol;

        public TrainIterationBody(LogisticGradient logisticGradient, int i, double d, int i2, double d2) {
            this.logisticGradient = logisticGradient;
            this.globalBatchSize = i;
            this.learningRate = d;
            this.maxIter = i2;
            this.tol = d2;
        }

        public IterationBodyResult process(DataStreamList dataStreamList, DataStreamList dataStreamList2) {
            DataStream dataStream = dataStreamList.get(0);
            DataStream dataStream2 = dataStreamList2.get(0);
            OutputTag<LogisticRegressionModelData> outputTag = new OutputTag<LogisticRegressionModelData>("MODEL_OUTPUT") { // from class: org.apache.flink.ml.classification.logisticregression.LogisticRegression.TrainIterationBody.1
            };
            DataStream transform = dataStream2.connect(dataStream).transform("CacheDataAndDoTrain", PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO, new CacheDataAndDoTrain(this.logisticGradient, this.globalBatchSize, this.learningRate, outputTag));
            DataStreamList forEachRound = IterationBody.forEachRound(DataStreamList.of(new DataStream[]{transform}), dataStreamList3 -> {
                return DataStreamList.of(new DataStream[]{DataStreamUtils.allReduceSum(dataStreamList3.get(0))});
            });
            return new IterationBodyResult(DataStreamList.of(new DataStream[]{forEachRound.get(0)}), DataStreamList.of(new DataStream[]{transform.getSideOutput(outputTag)}), forEachRound.get(0).map(obj -> {
                double[] dArr = (double[]) obj;
                return Double.valueOf(dArr[dArr.length - 1] / dArr[dArr.length - 2]);
            }).flatMap(new TerminateOnMaxIterOrTol(Integer.valueOf(this.maxIter), Double.valueOf(this.tol))));
        }

        private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
            String implMethodName = serializedLambda.getImplMethodName();
            boolean z = -1;
            switch (implMethodName.hashCode()) {
                case -1953380414:
                    if (implMethodName.equals("lambda$process$97e3fc93$1")) {
                        z = false;
                        break;
                    }
                    break;
            }
            switch (z) {
                case false:
                    if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/classification/logisticregression/LogisticRegression$TrainIterationBody") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Double;")) {
                        return obj -> {
                            double[] dArr = (double[]) obj;
                            return Double.valueOf(dArr[dArr.length - 1] / dArr[dArr.length - 2]);
                        };
                    }
                    break;
            }
            throw new IllegalArgumentException("Invalid lambda deserialization");
        }
    }

    /** Creates an estimator whose parameter map is pre-populated with the default values. */
    public LogisticRegression() {
        // The map itself is created by the field initializer; this call only fills it.
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Returns the parameter map backing this estimator.
     *
     * @return the live internal map (not a copy), so changes made through it affect this instance
     */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    /**
     * Saves this estimator's metadata under the given path via
     * {@code ReadWriteUtils.saveMetadata}.
     *
     * @param path target location handed to {@code ReadWriteUtils.saveMetadata}
     * @throws IOException if writing the metadata fails
     */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Restores an estimator from the given path via {@code ReadWriteUtils.loadStageParam}.
     *
     * @param streamExecutionEnvironment not used by this method
     * @param str location handed to {@code ReadWriteUtils.loadStageParam}
     * @return the restored estimator
     * @throws IOException if reading fails
     */
    public static LogisticRegression load(StreamExecutionEnvironment streamExecutionEnvironment, String str) throws IOException {
        final LogisticRegression restored = ReadWriteUtils.loadStageParam(str);
        return restored;
    }

    /* renamed from: fit, reason: merged with bridge method [inline-methods] */
    /**
     * Trains a binomial logistic regression model on a single input table.
     *
     * <p>Each row is converted to a {@code LabeledPointWithWeight} using the configured features,
     * label and (optional) weight columns; when no weight column is configured the weight is 1.0.
     * Labels other than 0.0 and 1.0 are rejected when the row is processed.
     *
     * <p>NOTE(review): the method name carries a decompiler prefix (the marker in front of this
     * method records a rename from {@code fit}); it is kept so existing references still resolve.
     *
     * @param tableArr exactly one table holding the training data
     * @return a model whose model data is the table produced by training and whose parameters are
     *     updated from this estimator's parameter map
     * @throws IllegalArgumentException if the number of tables is not 1 or the multi-class option
     *     is neither "auto" nor "binomial"
     */
    public LogisticRegressionModel m4fit(Table... tableArr) {
        Preconditions.checkArgument(tableArr.length == 1);
        String multiClass = getMultiClass();
        Preconditions.checkArgument(
                "auto".equals(multiClass) || "binomial".equals(multiClass),
                "Multinomial classification is not supported yet. Supported options: [auto, binomial].");

        // The table only exposes its environment through the internal TableImpl type; the
        // explicit cast makes the assignment to the streaming bridge type well-typed.
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) tableArr[0]).getTableEnvironment();

        DataStream<LabeledPointWithWeight> trainData =
                tEnv.toDataStream(tableArr[0])
                        .map(
                                row -> {
                                    double weight =
                                            getWeightCol() == null
                                                    ? 1.0d
                                                    : (Double) row.getField(getWeightCol());
                                    double label = (Double) row.getField(getLabelCol());
                                    boolean isBinomial =
                                            Double.compare(0.0d, label) == 0
                                                    || Double.compare(1.0d, label) == 0;
                                    if (!isBinomial) {
                                        throw new RuntimeException(
                                                "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
                                    }
                                    return new LabeledPointWithWeight(
                                            (DenseVector) row.getField(getFeaturesCol()),
                                            label,
                                            weight);
                                });

        // Initial model: a zero vector sized from the feature dimension of the training data.
        DataStream<double[]> initModelData =
                trainData.transform(
                        "genInitModelData",
                        PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                        new GenInitModelData());

        DataStream<LogisticRegressionModelData> modelData = train(trainData, initModelData);
        LogisticRegressionModel model =
                new LogisticRegressionModel().m5setModelData(tEnv.fromDataStream(modelData));
        ReadWriteUtils.updateExistingParams(model, this.paramMap);
        return model;
    }

    /**
     * Runs the bounded training iteration and returns its first output stream.
     *
     * @param dataStream training records, passed to the iteration as a non-replayed data stream
     * @param dataStream2 initial model data, passed to the iteration as the variable stream
     * @return stream at index 0 of the iteration's outputs
     */
    private DataStream<LogisticRegressionModelData> train(DataStream<LabeledPointWithWeight> dataStream, DataStream<double[]> dataStream2) {
        DataStreamList initialVariables = DataStreamList.of(dataStream2);
        ReplayableDataStreamList cachedInputs = ReplayableDataStreamList.notReplay(dataStream);
        IterationConfig iterationConfig = IterationConfig.newBuilder().build();

        // All hyper-parameters are read from this estimator's params when the graph is built.
        LogisticGradient gradientFunction = new LogisticGradient(getReg());
        TrainIterationBody iterationBody =
                new TrainIterationBody(
                        gradientFunction,
                        getGlobalBatchSize(),
                        getLearningRate(),
                        getMaxIter(),
                        getTol());

        DataStreamList iterationOutputs =
                Iterations.iterateBoundedStreamsUntilTermination(
                        initialVariables, cachedInputs, iterationConfig, iterationBody);
        return iterationOutputs.get(0);
    }

    /**
     * Decompiler rendering of the compiler-synthesized lambda deserialization hook (note the
     * "synthetic" marker) for the row-mapping lambda used in the fit method.
     *
     * <p>It matches the serialized lambda by implementation-method name, method-handle kind,
     * functional interface ({@code MapFunction.map}) and signatures, and returns an equivalent
     * lambda; any other request is rejected.
     *
     * <p>Left exactly as decompiled. In this form it is not valid Java: a {@code boolean} is
     * initialised with {@code -1}, a bare lambda is returned as {@code Object}, and the lambda
     * body calls instance getters from a static method while the captured instance
     * ({@code logisticRegression}) is never used.
     *
     * <p>NOTE(review): javac presumably regenerates this method when the class is rebuilt from
     * source, in which case this copy should be dropped, not maintained by hand. Confirm.
     *
     * @param serializedLambda serialized form of the lambda being restored
     * @return a lambda equivalent to the one used in the fit method
     * @throws IllegalArgumentException if the serialized form does not describe that lambda
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        // Two-step switch: first map the method name to a selector, then dispatch on the selector.
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 712690700:
                if (implMethodName.equals("lambda$fit$af50211a$1")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/classification/logisticregression/LogisticRegression") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/types/Row;)Lorg/apache/flink/ml/common/feature/LabeledPointWithWeight;")) {
                    // The only captured argument is the enclosing estimator instance.
                    LogisticRegression logisticRegression = (LogisticRegression) serializedLambda.getCapturedArg(0);
                    // Same body as the row-mapping lambda in the fit method.
                    return row -> {
                        Double valueOf = Double.valueOf(getWeightCol() == null ? 1.0d : ((Double) row.getField(getWeightCol())).doubleValue());
                        Double d = (Double) row.getField(getLabelCol());
                        if (Double.compare(0.0d, d.doubleValue()) == 0 || Double.compare(1.0d, d.doubleValue()) == 0) {
                            return new LabeledPointWithWeight((DenseVector) row.getField(getFeaturesCol()), d.doubleValue(), valueOf.doubleValue());
                        }
                        throw new RuntimeException("Multinomial classification is not supported yet. Supported options: [auto, binomial].");
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
