package org.apache.flink.streaming.api.runners.python.beam.state;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.fnexecution.v1.FlinkFnApi;
import org.apache.flink.python.util.ProtoUtils;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.streaming.api.utils.ByteArrayWrapper;
import org.apache.flink.streaming.api.utils.ByteArrayWrapperSerializer;
import org.apache.flink.streaming.api.utils.PythonOperatorUtils;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
import org.apache.flink.util.Preconditions;

import org.apache.beam.model.fnexecution.v1.BeamFnApi;

import javax.annotation.Nullable;

import java.io.IOException;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;

/**
 * A {@link BeamStateStore} that answers Beam keyed state requests from a Flink {@link
 * KeyedStateBackend}. Bag user state requests are served as {@link ListState} and multimap side
 * input requests as {@link MapState}; state values are kept as opaque {@code byte[]}.
 *
 * <p>Every request carries the serialized key, an optional serialized window (used as the state
 * namespace) and a Base64-encoded {@link FlinkFnApi.StateDescriptor}. Serving a request switches
 * the backend's current key to the request's key.
 *
 * <p>NOTE(review): a single input buffer is reused across calls and the descriptor cache is a plain
 * {@link HashMap}, so instances presumably must be confined to one thread — confirm with callers.
 */
public class BeamKeyedStateStore implements BeamStateStore {

    /** Backend that owns the keyed state; its current key is switched on every request. */
    private final KeyedStateBackend<?> keyedStateBackend;

    /** Decodes the key bytes sent with each state request. */
    private final TypeSerializer<?> keySerializer;

    /** Decodes window bytes into a namespace; only required for requests that carry a window. */
    @Nullable private final TypeSerializer<?> namespaceSerializer;

    /** Serializer for state values, which are stored as raw byte arrays. */
    private final TypeSerializer<byte[]> valueSerializer =
            PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO.createSerializer(
                    new ExecutionConfig());

    /** Reusable input stream positioned over the bytes currently being deserialized. */
    private final ByteArrayInputStreamWithPos bais = new ByteArrayInputStreamWithPos();

    /** Data-input view over {@link #bais}, handed to the key/namespace serializers. */
    private final DataInputViewStreamWrapper baisWrapper = new DataInputViewStreamWrapper(bais);

    /** Descriptors created so far, keyed by their prefixed state name. */
    private final Map<String, StateDescriptor<?, ?>> stateDescriptorCache = new HashMap<>();

    /**
     * Creates a keyed state store.
     *
     * @param keyedStateBackend backend that stores the state
     * @param keySerializer serializer used to decode the key of each request
     * @param namespaceSerializer serializer used to decode the window of each request; may be
     *     {@code null} if requests never carry a window
     */
    public BeamKeyedStateStore(
            KeyedStateBackend<?> keyedStateBackend,
            TypeSerializer<?> keySerializer,
            @Nullable TypeSerializer<?> namespaceSerializer) {
        this.keyedStateBackend = keyedStateBackend;
        this.keySerializer = keySerializer;
        this.namespaceSerializer = namespaceSerializer;
    }

    /**
     * Returns the list state addressed by a bag user state request, after making the request's key
     * the backend's current key.
     *
     * @param request a state request whose state key is a bag user state key
     * @return the list state for the request's state name, scoped to the request's window, or to
     *     {@link VoidNamespace} when the request carries no window
     * @throws RuntimeException if the request is not a bag user state request, or if the state name
     *     is already registered as a different kind of state
     * @throws Exception if the key, window or state descriptor cannot be decoded, or the backend
     *     fails to provide the state
     */
    @Override
    @SuppressWarnings("unchecked") // see the cast on the cached descriptor below
    public ListState<byte[]> getListState(BeamFnApi.StateRequest request) throws Exception {
        if (!request.getStateKey().hasBagUserState()) {
            throw new RuntimeException("Unsupported keyed bag state request: " + request);
        }
        BeamFnApi.StateKey.BagUserState bagUserState = request.getStateKey().getBagUserState();

        setCurrentKeyFromBytes(bagUserState.getKey().toByteArray());

        FlinkFnApi.StateDescriptor protoDescriptor =
                parseStateDescriptor(bagUserState.getUserStateId());
        String stateName = BeamStateStore.PYTHON_STATE_PREFIX + protoDescriptor.getStateName();

        ListStateDescriptor<byte[]> listStateDescriptor;
        StateDescriptor<?, ?> cached = stateDescriptorCache.get(stateName);
        if (cached instanceof ListStateDescriptor) {
            // Safe: the only ListStateDescriptors put in the cache are created below with the
            // byte[] value serializer.
            listStateDescriptor = (ListStateDescriptor<byte[]>) cached;
        } else if (cached != null) {
            throw new RuntimeException(
                    String.format(
                            "State name corrupt detected: '%s' is used both as LIST state and '%s' state at the same time.",
                            stateName, cached.getType()));
        } else {
            listStateDescriptor = new ListStateDescriptor<>(stateName, valueSerializer);
            if (protoDescriptor.hasStateTtlConfig()) {
                listStateDescriptor.enableTimeToLive(
                        ProtoUtils.parseStateTtlConfigFromProto(
                                protoDescriptor.getStateTtlConfig()));
            }
            stateDescriptorCache.put(stateName, listStateDescriptor);
        }

        return getPartitionedState(bagUserState.getWindow().toByteArray(), listStateDescriptor);
    }

    /**
     * Returns the map state addressed by a multimap side input request, after making the request's
     * key the backend's current key.
     *
     * @param request a state request whose state key is a multimap side input key
     * @return the map state for the request's state name, scoped to the request's window, or to
     *     {@link VoidNamespace} when the request carries no window
     * @throws RuntimeException if the request is not a multimap side input request, or if the state
     *     name is already registered as a different kind of state
     * @throws Exception if the key, window or state descriptor cannot be decoded, or the backend
     *     fails to provide the state
     */
    @Override
    @SuppressWarnings("unchecked") // see the cast on the cached descriptor below
    public MapState<ByteArrayWrapper, byte[]> getMapState(BeamFnApi.StateRequest request)
            throws Exception {
        if (!request.getStateKey().hasMultimapSideInput()) {
            throw new RuntimeException("Unsupported keyed map state request: " + request);
        }
        BeamFnApi.StateKey.MultimapSideInput multimapSideInput =
                request.getStateKey().getMultimapSideInput();

        setCurrentKeyFromBytes(multimapSideInput.getKey().toByteArray());

        FlinkFnApi.StateDescriptor protoDescriptor =
                parseStateDescriptor(multimapSideInput.getSideInputId());
        String stateName = BeamStateStore.PYTHON_STATE_PREFIX + protoDescriptor.getStateName();

        MapStateDescriptor<ByteArrayWrapper, byte[]> mapStateDescriptor;
        StateDescriptor<?, ?> cached = stateDescriptorCache.get(stateName);
        if (cached instanceof MapStateDescriptor) {
            // Safe: the only MapStateDescriptors put in the cache are created below with
            // ByteArrayWrapper keys and byte[] values.
            mapStateDescriptor = (MapStateDescriptor<ByteArrayWrapper, byte[]>) cached;
        } else if (cached != null) {
            throw new RuntimeException(
                    String.format(
                            "State name corrupt detected: '%s' is used both as MAP state and '%s' state at the same time.",
                            stateName, cached.getType()));
        } else {
            mapStateDescriptor =
                    new MapStateDescriptor<>(
                            stateName, ByteArrayWrapperSerializer.INSTANCE, valueSerializer);
            if (protoDescriptor.hasStateTtlConfig()) {
                mapStateDescriptor.enableTimeToLive(
                        ProtoUtils.parseStateTtlConfigFromProto(
                                protoDescriptor.getStateTtlConfig()));
            }
            stateDescriptorCache.put(stateName, mapStateDescriptor);
        }

        return getPartitionedState(
                multimapSideInput.getWindow().toByteArray(), mapStateDescriptor);
    }

    /** Decodes a Base64-encoded, protobuf-serialized {@link FlinkFnApi.StateDescriptor}. */
    private static FlinkFnApi.StateDescriptor parseStateDescriptor(String encodedDescriptor)
            throws Exception {
        return FlinkFnApi.StateDescriptor.parseFrom(Base64.getDecoder().decode(encodedDescriptor));
    }

    /** Deserializes {@code keyBytes} with {@link #keySerializer} and sets it as the current key. */
    private void setCurrentKeyFromBytes(byte[] keyBytes) throws IOException {
        bais.setBuffer(keyBytes, 0, keyBytes.length);
        setCurrentKey(keySerializer.deserialize(baisWrapper));
    }

    /**
     * Sets {@code key} as the backend's current key via {@link
     * PythonOperatorUtils#setCurrentKeyForStreaming}. When the backend's key serializer is a
     * {@link RowDataSerializer}, the key is first converted with {@link
     * RowDataSerializer#toBinaryRow}.
     */
    @SuppressWarnings("unchecked")
    private void setCurrentKey(Object key) {
        // The backend's key type is erased here. This assumes keySerializer produces keys of the
        // backend's key type (or RowData when the backend uses a RowDataSerializer) — caller's
        // responsibility.
        KeyedStateBackend<Object> backend = (KeyedStateBackend<Object>) keyedStateBackend;
        TypeSerializer<?> backendKeySerializer = keyedStateBackend.getKeySerializer();
        if (backendKeySerializer instanceof RowDataSerializer) {
            PythonOperatorUtils.setCurrentKeyForStreaming(
                    backend, ((RowDataSerializer) backendKeySerializer).toBinaryRow((RowData) key));
        } else {
            PythonOperatorUtils.setCurrentKeyForStreaming(backend, key);
        }
    }

    /**
     * Fetches the state described by {@code stateDescriptor} from the backend. An empty {@code
     * windowBytes} selects {@link VoidNamespace}; otherwise the bytes are decoded with {@link
     * #namespaceSerializer}, which must then be non-null.
     */
    private <S extends State> S getPartitionedState(
            byte[] windowBytes, StateDescriptor<S, ?> stateDescriptor) throws Exception {
        if (windowBytes.length == 0) {
            return keyedStateBackend.getPartitionedState(
                    VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescriptor);
        }
        Preconditions.checkNotNull(
                namespaceSerializer,
                "Received a windowed state request, but no namespace serializer is configured.");
        return getWindowedState(namespaceSerializer, windowBytes, stateDescriptor);
    }

    /**
     * Decodes the namespace from {@code windowBytes} and fetches the state under it. The type
     * parameter {@code N} captures the wildcard of the namespace serializer so that the namespace
     * and its serializer agree without raw casts.
     */
    private <N, S extends State> S getWindowedState(
            TypeSerializer<N> nsSerializer,
            byte[] windowBytes,
            StateDescriptor<S, ?> stateDescriptor)
            throws Exception {
        bais.setBuffer(windowBytes, 0, windowBytes.length);
        N namespace = nsSerializer.deserialize(baisWrapper);
        return keyedStateBackend.getPartitionedState(namespace, nsSerializer, stateDescriptor);
    }
}
