package org.apache.flink.ml.classification.knn;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseMatrix;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/classification/knn/Knn.class */
public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
    private final Map<Param<?>, Object> paramMap = new HashMap();

    public Knn() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /* renamed from: fit, reason: merged with bridge method [inline-methods] */
    public KnnModel m0fit(Table... tableArr) {
        Preconditions.checkArgument(tableArr.length == 1);
        StreamTableEnvironment tableEnvironment = ((TableImpl) tableArr[0]).getTableEnvironment();
        KnnModel m1setModelData = new KnnModel().m1setModelData(tableEnvironment.fromDataStream(genModelData(computeNormSquare(tableEnvironment.toDataStream(tableArr[0])))));
        ReadWriteUtils.updateExistingParams(m1setModelData, getParamMap());
        return m1setModelData;
    }

    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    public void save(String str) throws IOException {
        ReadWriteUtils.saveMetadata(this, str);
    }

    public static Knn load(StreamExecutionEnvironment streamExecutionEnvironment, String str) throws IOException {
        return ReadWriteUtils.loadStageParam(str);
    }

    private static DataStream<KnnModelData> genModelData(DataStream<Tuple3<DenseVector, Double, Double>> dataStream) {
        DataStream<KnnModelData> mapPartition = DataStreamUtils.mapPartition(dataStream, new RichMapPartitionFunction<Tuple3<DenseVector, Double, Double>, KnnModelData>() { // from class: org.apache.flink.ml.classification.knn.Knn.1
            public void mapPartition(Iterable<Tuple3<DenseVector, Double, Double>> iterable, Collector<KnnModelData> collector) {
                ArrayList<Tuple3> arrayList = new ArrayList();
                Iterator<Tuple3<DenseVector, Double, Double>> it = iterable.iterator();
                while (it.hasNext()) {
                    arrayList.add(it.next());
                }
                int size = ((DenseVector) ((Tuple3) arrayList.get(0)).f0).size();
                DenseMatrix denseMatrix = new DenseMatrix(size, arrayList.size());
                DenseVector denseVector = new DenseVector(arrayList.size());
                DenseVector denseVector2 = new DenseVector(arrayList.size());
                int i = 0;
                for (Tuple3 tuple3 : arrayList) {
                    System.arraycopy(((DenseVector) tuple3.f0).values, 0, denseMatrix.values, i * size, size);
                    denseVector2.values[i] = ((Double) tuple3.f1).doubleValue();
                    int i2 = i;
                    i++;
                    denseVector.values[i2] = ((Double) tuple3.f2).doubleValue();
                }
                collector.collect(new KnnModelData(denseMatrix, denseVector, denseVector2));
            }
        });
        mapPartition.getTransformation().setParallelism(1);
        return mapPartition;
    }

    private DataStream<Tuple3<DenseVector, Double, Double>> computeNormSquare(DataStream<Row> dataStream) {
        return dataStream.map(new MapFunction<Row, Tuple3<DenseVector, Double, Double>>() { // from class: org.apache.flink.ml.classification.knn.Knn.2
            public Tuple3<DenseVector, Double, Double> map(Row row) {
                Double d = (Double) row.getField(Knn.this.getLabelCol());
                DenseVector denseVector = (DenseVector) row.getField(Knn.this.getFeaturesCol());
                return Tuple3.of(denseVector, d, Double.valueOf(Math.pow(BLAS.norm2(denseVector), 2.0d)));
            }
        });
    }
}
