package org.apache.flink.ml.clustering.kmeans;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.IterationConfig;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.ReplayableDataStreamList;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
import org.apache.flink.ml.common.distance.DistanceMeasure;
import org.apache.flink.ml.common.iteration.ForwardInputsOfLastRound;
import org.apache.flink.ml.common.iteration.TerminateOnMaxIter;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.streaming.api.windowing.windows.Window;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/clustering/kmeans/KMeans.class */
public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMeans> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    /* loaded from: input_file:org/apache/flink/ml/clustering/kmeans/KMeans$CentroidAccumulator.class */
    private static class CentroidAccumulator implements ReduceFunction<Tuple3<Integer, DenseVector, Long>> {
        private CentroidAccumulator() {
        }

        public Tuple3<Integer, DenseVector, Long> reduce(Tuple3<Integer, DenseVector, Long> tuple3, Tuple3<Integer, DenseVector, Long> tuple32) throws Exception {
            for (int i = 0; i < ((DenseVector) tuple3.f1).size(); i++) {
                double[] dArr = ((DenseVector) tuple3.f1).values;
                int i2 = i;
                dArr[i2] = dArr[i2] + ((DenseVector) tuple32.f1).values[i];
            }
            return new Tuple3<>(tuple3.f0, tuple3.f1, Long.valueOf(((Long) tuple3.f2).longValue() + ((Long) tuple32.f2).longValue()));
        }
    }

    /* loaded from: input_file:org/apache/flink/ml/clustering/kmeans/KMeans$CentroidAverager.class */
    private static class CentroidAverager implements MapFunction<Tuple3<Integer, DenseVector, Long>, DenseVector> {
        private CentroidAverager() {
        }

        public DenseVector map(Tuple3<Integer, DenseVector, Long> tuple3) {
            for (int i = 0; i < ((DenseVector) tuple3.f1).size(); i++) {
                double[] dArr = ((DenseVector) tuple3.f1).values;
                int i2 = i;
                dArr[i2] = dArr[i2] / ((Long) tuple3.f2).longValue();
            }
            return (DenseVector) tuple3.f1;
        }
    }

    /* loaded from: input_file:org/apache/flink/ml/clustering/kmeans/KMeans$CountAppender.class */
    private static class CountAppender implements MapFunction<Tuple2<Integer, DenseVector>, Tuple3<Integer, DenseVector, Long>> {
        private CountAppender() {
        }

        public Tuple3<Integer, DenseVector, Long> map(Tuple2<Integer, DenseVector> tuple2) {
            return Tuple3.of(tuple2.f0, tuple2.f1, 1L);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/ml/clustering/kmeans/KMeans$KMeansIterationBody.class */
    public static class KMeansIterationBody implements IterationBody {
        private final int maxIterationNum;
        private final DistanceMeasure distanceMeasure;

        public KMeansIterationBody(int i, DistanceMeasure distanceMeasure) {
            this.maxIterationNum = i;
            this.distanceMeasure = distanceMeasure;
        }

        public IterationBodyResult process(DataStreamList dataStreamList, DataStreamList dataStreamList2) {
            DataStream dataStream = dataStreamList.get(0);
            DataStream dataStream2 = dataStreamList2.get(0);
            SingleOutputStreamOperator flatMap = dataStream.flatMap(new TerminateOnMaxIter(this.maxIterationNum));
            DataStream transform = dataStream2.connect(dataStream.broadcast()).transform("SelectNearestCentroid", new TupleTypeInfo(new TypeInformation[]{BasicTypeInfo.INT_TYPE_INFO, DenseVectorTypeInfo.INSTANCE}), new SelectNearestCentroidOperator(this.distanceMeasure));
            final AllWindowFunction<DenseVector, DenseVector[], TimeWindow> allWindowFunction = new AllWindowFunction<DenseVector, DenseVector[], TimeWindow>() { // from class: org.apache.flink.ml.clustering.kmeans.KMeans.KMeansIterationBody.1
                public void apply(TimeWindow timeWindow, Iterable<DenseVector> iterable, Collector<DenseVector[]> collector) {
                    collector.collect(IteratorUtils.toList(iterable.iterator()).toArray(new DenseVector[0]));
                }

                public /* bridge */ /* synthetic */ void apply(Window window, Iterable iterable, Collector collector) throws Exception {
                    apply((TimeWindow) window, (Iterable<DenseVector>) iterable, (Collector<DenseVector[]>) collector);
                }
            };
            DataStream dataStream3 = IterationBody.forEachRound(DataStreamList.of(new DataStream[]{transform}), new IterationBody.PerRoundSubBody() { // from class: org.apache.flink.ml.clustering.kmeans.KMeans.KMeansIterationBody.2
                public DataStreamList process(DataStreamList dataStreamList3) {
                    return DataStreamList.of(new DataStream[]{dataStreamList3.get(0).map(new CountAppender()).keyBy(tuple3 -> {
                        return (Integer) tuple3.f0;
                    }).window(EndOfStreamWindows.get()).reduce(new CentroidAccumulator()).map(new CentroidAverager()).windowAll(EndOfStreamWindows.get()).apply(allWindowFunction)});
                }

                private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
                    String implMethodName = serializedLambda.getImplMethodName();
                    boolean z = -1;
                    switch (implMethodName.hashCode()) {
                        case 1021313028:
                            if (implMethodName.equals("lambda$process$796fd106$1")) {
                                z = false;
                                break;
                            }
                            break;
                    }
                    switch (z) {
                        case false:
                            if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/java/functions/KeySelector") && serializedLambda.getFunctionalInterfaceMethodName().equals("getKey") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/clustering/kmeans/KMeans$KMeansIterationBody$2") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/api/java/tuple/Tuple3;)Ljava/lang/Integer;")) {
                                return tuple3 -> {
                                    return (Integer) tuple3.f0;
                                };
                            }
                            break;
                    }
                    throw new IllegalArgumentException("Invalid lambda deserialization");
                }
            }).get(0);
            return new IterationBodyResult(DataStreamList.of(new DataStream[]{dataStream3}), DataStreamList.of(new DataStream[]{dataStream3.flatMap(new ForwardInputsOfLastRound())}), flatMap);
        }
    }

    /* loaded from: input_file:org/apache/flink/ml/clustering/kmeans/KMeans$SelectNearestCentroidOperator.class */
    private static class SelectNearestCentroidOperator extends AbstractStreamOperator<Tuple2<Integer, DenseVector>> implements TwoInputStreamOperator<DenseVector, DenseVector[], Tuple2<Integer, DenseVector>>, IterationListener<Tuple2<Integer, DenseVector>> {
        private final DistanceMeasure distanceMeasure;
        private ListState<DenseVector> points;
        private ListState<DenseVector[]> centroids;

        public SelectNearestCentroidOperator(DistanceMeasure distanceMeasure) {
            this.distanceMeasure = distanceMeasure;
        }

        public void initializeState(StateInitializationContext stateInitializationContext) throws Exception {
            super.initializeState(stateInitializationContext);
            this.points = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("points", DenseVector.class));
            this.centroids = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("centroids", ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE)));
        }

        public void processElement1(StreamRecord<DenseVector> streamRecord) throws Exception {
            this.points.add(streamRecord.getValue());
        }

        public void processElement2(StreamRecord<DenseVector[]> streamRecord) throws Exception {
            this.centroids.add(streamRecord.getValue());
        }

        public void onEpochWatermarkIncremented(int i, IterationListener.Context context, Collector<Tuple2<Integer, DenseVector>> collector) throws Exception {
            List list = IteratorUtils.toList(((Iterable) this.centroids.get()).iterator());
            if (list.size() != 1) {
                throw new RuntimeException("The operator received " + list.size() + " list of centroids in this round");
            }
            Vector[] vectorArr = (DenseVector[]) list.get(0);
            for (DenseVector denseVector : (Iterable) this.points.get()) {
                double d = Double.MAX_VALUE;
                int i2 = -1;
                for (int i3 = 0; i3 < vectorArr.length; i3++) {
                    double distance = this.distanceMeasure.distance(vectorArr[i3], denseVector);
                    if (distance < d) {
                        d = distance;
                        i2 = i3;
                    }
                }
                this.output.collect(new StreamRecord(Tuple2.of(Integer.valueOf(i2), denseVector)));
            }
            this.centroids.clear();
        }

        public void onIterationTerminated(IterationListener.Context context, Collector<Tuple2<Integer, DenseVector>> collector) {
            this.points.clear();
        }
    }

    /** Creates an estimator and fills its parameter map via {@link ParamUtils#initializeMapWithDefaultValues}. */
    public KMeans() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Builds the K-means training job for a single input table and returns the model that wraps
     * its result.
     *
     * <p>Decompiler note: this method was renamed from {@code fit} (merged with its bridge
     * method); the name is kept as-is here so that existing references keep resolving.
     *
     * <p>The features column of the input table is read as {@link DenseVector}s, {@code k}
     * random points are chosen as initial centroids, and a bounded iteration driven by
     * {@link KMeansIterationBody} refines them. The final centroid array is mapped to
     * {@link KMeansModelData} and attached to a new {@link KMeansModel}, onto which this
     * estimator's parameters are copied.
     *
     * @param tableArr the inputs; must contain exactly one table
     * @return the model backed by the (lazily computed) model data table
     * @throws IllegalArgumentException if {@code tableArr} does not contain exactly one table
     */
    public KMeansModel m12fit(Table... tableArr) {
        Preconditions.checkArgument(
                tableArr.length == 1,
                "KMeans expects exactly one input table but received %s.",
                tableArr.length);

        // TableImpl#getTableEnvironment() is declared to return the base TableEnvironment, so the
        // narrowing cast is required for the DataStream bridge methods used below.
        StreamTableEnvironment tableEnvironment =
                (StreamTableEnvironment) ((TableImpl) tableArr[0]).getTableEnvironment();

        DataStream<DenseVector> points =
                tableEnvironment
                        .toDataStream(tableArr[0])
                        .map(row -> (DenseVector) row.getField(getFeaturesCol()));

        DataStream<DenseVector[]> initialCentroids =
                selectRandomCentroids(points, getK(), getSeed());

        IterationConfig iterationConfig =
                IterationConfig.newBuilder()
                        .setOperatorLifeCycle(IterationConfig.OperatorLifeCycle.ALL_ROUND)
                        .build();
        IterationBody iterationBody =
                new KMeansIterationBody(
                        getMaxIter(), DistanceMeasure.getInstance(getDistanceMeasure()));

        // A typed local is needed here: DataStreamList#get is generic, and the constructor
        // reference below can only be resolved against a DataStream<DenseVector[]>.
        DataStream<DenseVector[]> finalCentroids =
                Iterations.iterateBoundedStreamsUntilTermination(
                                DataStreamList.of(initialCentroids),
                                ReplayableDataStreamList.notReplay(points),
                                iterationConfig,
                                iterationBody)
                        .get(0);

        Table modelDataTable =
                tableEnvironment.fromDataStream(finalCentroids.map(KMeansModelData::new));

        KMeansModel model = new KMeansModel().m13setModelData(modelDataTable);
        ReadWriteUtils.updateExistingParams(model, this.paramMap);
        return model;
    }

    /**
     * Persists this estimator by delegating to {@link ReadWriteUtils#saveMetadata}.
     *
     * @param str the target path handed to {@code saveMetadata}
     * @throws IOException if {@code saveMetadata} fails
     */
    @Override
    public void save(String str) throws IOException {
        final String path = str;
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Restores an estimator by delegating to {@link ReadWriteUtils#loadStageParam}.
     *
     * @param streamExecutionEnvironment not used by this method
     * @param str the path handed to {@code loadStageParam}
     * @return the loaded estimator
     * @throws IOException if {@code loadStageParam} fails
     */
    public static KMeans load(StreamExecutionEnvironment streamExecutionEnvironment, String str) throws IOException {
        KMeans restored = ReadWriteUtils.loadStageParam(str);
        return restored;
    }

    /**
     * Returns the parameter map backing this estimator.
     *
     * @return the internal map itself (not a copy)
     */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Picks {@code i} points from the input as initial centroids.
     *
     * <p>All points of the partition are collected into a list, shuffled with a {@link Random}
     * seeded by {@code j}, and the first {@code i} of them are emitted as one array record. The
     * parallelism of the resulting transformation is set to 1.
     *
     * @param dataStream the stream of candidate points
     * @param i the number of centroids to select
     * @param j the seed of the random shuffle
     * @return a stream carrying a single array with the selected centroids
     * @throws IllegalArgumentException (at job run time) if fewer than {@code i} points are
     *     available to choose from
     */
    public static DataStream<DenseVector[]> selectRandomCentroids(DataStream<DenseVector> dataStream, final int i, final long j) {
        DataStream<DenseVector[]> centroids = DataStreamUtils.mapPartition(dataStream, new MapPartitionFunction<DenseVector, DenseVector[]>() {
            @Override
            public void mapPartition(Iterable<DenseVector> iterable, Collector<DenseVector[]> collector) {
                // A typed list is required: on a raw List, toArray(T[]) is erased to Object[],
                // which a Collector<DenseVector[]> does not accept.
                List<DenseVector> candidates = new ArrayList<>();
                for (DenseVector point : iterable) {
                    candidates.add(point);
                }
                // Fail with a descriptive message instead of subList's bare
                // IndexOutOfBoundsException when there are not enough points.
                if (candidates.size() < i) {
                    throw new IllegalArgumentException(
                            "Cannot select " + i + " initial centroids from only "
                                    + candidates.size() + " input points.");
                }
                Collections.shuffle(candidates, new Random(j));
                collector.collect(candidates.subList(0, i).toArray(new DenseVector[0]));
            }
        });
        centroids.getTransformation().setParallelism(1);
        return centroids;
    }

    /**
     * Recreates one of this class's serializable lambdas from its {@link SerializedLambda} form.
     *
     * <p>The lambda is selected by its implementation method name and then validated against the
     * expected method kind, functional interface, interface method, signatures and implementing
     * class. Two lambdas are recognised: the row-to-vector {@code MapFunction} named
     * {@code lambda$fit$33d572ed$1}, and the {@code KMeansModelData} constructor reference
     * ({@code <init>}).
     *
     * <p>NOTE(review): the {@code synthetic} marker suggests this method was emitted by the
     * compiler and reproduced by a decompiler rather than written by hand — confirm whether it
     * should be present in source at all. Several statements below look like decompiler
     * renderings and not valid Java (see inline notes).
     *
     * @param serializedLambda the serialized form to resolve
     * @return the recreated lambda or constructor reference
     * @throws IllegalArgumentException if the serialized form matches none of the known lambdas
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        // NOTE(review): a boolean initialised with -1 and later switched on; presumably an int
        // selector (-1 = no match, 0/1 = case index) in the original bytecode — verify.
        boolean z = -1;
        // First pass: map the implementation method name to a selector value.
        switch (implMethodName.hashCode()) {
            case -550846040:
                if (implMethodName.equals("lambda$fit$33d572ed$1")) {
                    z = false;
                    break;
                }
                break;
            case 1818100338:
                if (implMethodName.equals("<init>")) {
                    z = true;
                    break;
                }
                break;
        }
        // Second pass: check the full descriptor for the selected lambda before recreating it.
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/clustering/kmeans/KMeans") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/types/Row;)Lorg/apache/flink/ml/linalg/DenseVector;")) {
                    // NOTE(review): the captured KMeans instance is read but never used; the
                    // lambda below calls getFeaturesCol() unqualified from a static method —
                    // presumably it should be invoked on this captured instance; verify.
                    KMeans kMeans = (KMeans) serializedLambda.getCapturedArg(0);
                    return row -> {
                        return (DenseVector) row.getField(getFeaturesCol());
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 8 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/clustering/kmeans/KMeansModelData") && serializedLambda.getImplMethodSignature().equals("([Lorg/apache/flink/ml/linalg/DenseVector;)V")) {
                    return KMeansModelData::new;
                }
                break;
        }
        // Nothing matched (unknown name or mismatching descriptor).
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
