/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.store.embedding.redis;

import dev.langchain4j.community.store.embedding.redis.RedisJsonUtils;
import dev.langchain4j.community.store.embedding.redis.RedisMetadataFilterMapper;
import dev.langchain4j.community.store.embedding.redis.RedisRequestFailedException;
import dev.langchain4j.community.store.embedding.redis.RedisSchema;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.params.ScanParams;
import redis.clients.jedis.resps.ScanResult;
import redis.clients.jedis.search.Document;
import redis.clients.jedis.search.FTCreateParams;
import redis.clients.jedis.search.IndexDataType;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.RediSearchUtil;
import redis.clients.jedis.search.SearchResult;
import redis.clients.jedis.search.schemafields.SchemaField;
import redis.clients.jedis.search.schemafields.TextField;

/**
 * {@link EmbeddingStore} of {@link TextSegment}s backed by Redis.
 *
 * <p>Each embedding is written as a JSON document under the key {@code prefix + id} and is
 * queried through a search index (created with {@link IndexDataType#JSON}) using a KNN query.
 * The store owns its {@link JedisPooled} client; call {@link #close()} to release it.
 */
public class RedisEmbeddingStore implements EmbeddingStore<TextSegment>, AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(RedisEmbeddingStore.class);
    private static final String QUERY_TEMPLATE = "%s=>[ KNN %d @%s $BLOB AS %s ]";
    /** Alias under which the KNN distance is returned by {@link #search(EmbeddingSearchRequest)}. */
    private static final String SCORE_FIELD_NAME = "vector_score";
    private static final String DEFAULT_INDEX_NAME = "embedding-index";
    private static final String DEFAULT_PREFIX = "embedding:";
    /** Reply expected from Redis for a successful write/create command. */
    private static final String OK_REPLY = "OK";
    /** Number of documents fetched (and deleted) per round trip in {@link #removeAll(Filter)}. */
    private static final int REMOVE_PAGE_SIZE = 1000;

    private final JedisPooled client;
    private final RedisSchema schema;
    private final RedisMetadataFilterMapper filterMapper;

    /**
     * Creates a store connected to {@code host:port} and creates the search index if it does not exist yet.
     *
     * @param host           Redis host, must not be blank
     * @param port           Redis port, must not be null
     * @param user           user name; when null the client is created without credentials
     * @param password       password, only used when {@code user} is not null
     * @param indexName      index name, defaults to {@code "embedding-index"} when null
     * @param prefix         key prefix, defaults to {@code "embedding:"} when null
     * @param dimension      vector dimension; required only when the index has to be created
     * @param metadataConfig schema fields of the metadata entries, keyed by metadata key
     */
    public RedisEmbeddingStore(String host, Integer port, String user, String password, String indexName, String prefix, Integer dimension, Map<String, SchemaField> metadataConfig) {
        ValidationUtils.ensureNotBlank(host, "host");
        ValidationUtils.ensureNotNull(port, "port");
        int redisPort = port;
        // Build everything that can fail without a connection first, so no pool is leaked on failure.
        this.schema = buildSchema(indexName, prefix, dimension, metadataConfig);
        this.filterMapper = new RedisMetadataFilterMapper(metadataConfig);
        this.client = user == null ? new JedisPooled(host, redisPort) : new JedisPooled(host, redisPort, user, password);
        this.createIndexIfMissing(dimension);
    }

    /**
     * Creates a store connected to the given Redis URI and creates the search index if it does not exist yet.
     *
     * @param uri            Redis connection URI, must not be blank
     * @param indexName      index name, defaults to {@code "embedding-index"} when null
     * @param prefix         key prefix, defaults to {@code "embedding:"} when null
     * @param dimension      vector dimension; required only when the index has to be created
     * @param metadataConfig schema fields of the metadata entries, keyed by metadata key
     */
    public RedisEmbeddingStore(String uri, String indexName, String prefix, Integer dimension, Map<String, SchemaField> metadataConfig) {
        ValidationUtils.ensureNotBlank(uri, "uri");
        // Build everything that can fail without a connection first, so no pool is leaked on failure.
        this.schema = buildSchema(indexName, prefix, dimension, metadataConfig);
        this.filterMapper = new RedisMetadataFilterMapper(metadataConfig);
        this.client = new JedisPooled(uri);
        this.createIndexIfMissing(dimension);
    }

    /** Assembles the schema shared by both constructors, applying the default index name and prefix. */
    private static RedisSchema buildSchema(String indexName, String prefix, Integer dimension, Map<String, SchemaField> metadataConfig) {
        return RedisSchema.builder()
                .indexName(Utils.getOrDefault(indexName, DEFAULT_INDEX_NAME))
                .prefix(Utils.getOrDefault(prefix, DEFAULT_PREFIX))
                .dimension(dimension)
                .metadataConfig(Utils.copyIfNotNull(metadataConfig))
                .build();
    }

    /**
     * Creates the index when it is missing. The client is closed before the failure is
     * propagated, because a constructor that throws leaves the caller nothing to close.
     */
    private void createIndexIfMissing(Integer dimension) {
        try {
            if (!this.isIndexExist(this.schema.indexName())) {
                ValidationUtils.ensureNotNull(dimension, "dimension");
                this.createIndex(this.schema.indexName());
            }
        } catch (RuntimeException e) {
            try {
                this.client.close();
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Stores an embedding under a freshly generated id.
     *
     * @return the generated id
     */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.add(id, embedding);
        return id;
    }

    /** Stores an embedding under the given id, without an associated text segment. */
    public void add(String id, Embedding embedding) {
        this.addInternal(id, embedding, null);
    }

    /**
     * Stores an embedding together with its text segment under a freshly generated id.
     *
     * @return the generated id
     */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Stores all embeddings under freshly generated ids.
     *
     * @param embeddings embeddings to store; null is treated like an empty list
     * @return the generated ids, in the order of {@code embeddings}
     */
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = new ArrayList<>();
        if (embeddings != null) {
            for (int i = 0; i < embeddings.size(); ++i) {
                ids.add(Utils.randomUUID());
            }
        }
        this.addAll(ids, embeddings, null);
        return ids;
    }

    /**
     * Runs a KNN query for the request's embedding, restricted by the request's filter.
     *
     * @return the matches whose score is at least {@code request.minScore()}
     */
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        String queryString = String.format(
                QUERY_TEMPLATE,
                this.filterMapper.mapToFilter(request.filter()),
                request.maxResults(),
                this.schema.vectorFieldName(),
                SCORE_FIELD_NAME);
        Query query = new Query(queryString)
                .addParam("BLOB", RediSearchUtil.toByteArray(request.queryEmbedding().vector()))
                .setSortBy(SCORE_FIELD_NAME, true)
                .limit(0, request.maxResults())
                .dialect(2);
        SearchResult result = this.client.ftSearch(this.schema.indexName(), query);
        List<Document> documents = result.getDocuments();
        return new EmbeddingSearchResult<>(this.toEmbeddingMatch(documents, request.minScore()));
    }

    /**
     * Deletes the entries with the given ids.
     *
     * @param ids ids (without key prefix) to delete, must not be empty
     */
    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotEmpty(ids, "ids");
        String[] redisKeys = ids.stream().map(id -> this.schema.prefix() + id).toArray(String[]::new);
        this.client.del(redisKeys);
    }

    /**
     * Deletes every indexed entry matching the filter.
     *
     * <p>A search returns only a limited page of results, so matching documents are fetched
     * and deleted page by page until none is left.
     *
     * @param filter metadata filter, must not be null
     */
    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull(filter, "filter");
        String filterQuery = this.filterMapper.mapToFilter(filter);
        while (true) {
            // Only the document ids are needed, so skip transferring the JSON bodies.
            Query page = new Query(filterQuery).setNoContent().limit(0, REMOVE_PAGE_SIZE);
            List<Document> documents = this.client.ftSearch(this.schema.indexName(), page).getDocuments();
            if (documents == null || documents.isEmpty()) {
                // DEL requires at least one key, so never call it with an empty array.
                return;
            }
            String[] keys = documents.stream().map(Document::getId).toArray(String[]::new);
            long deleted = this.client.del(keys);
            // Stop on a short page (nothing left) or when nothing was deleted (no progress possible).
            if (deleted == 0 || documents.size() < REMOVE_PAGE_SIZE) {
                return;
            }
        }
    }

    /** Deletes every key that starts with the store's prefix, one SCAN batch at a time. */
    public void removeAll() {
        ScanParams params = new ScanParams();
        params.match(this.schema.prefix() + "*");
        String cursor = "0";
        do {
            ScanResult<String> scanResult = this.client.scan(cursor, params);
            List<String> keys = scanResult.getResult();
            // A SCAN iteration may legitimately return no keys; DEL needs at least one.
            if (keys != null && !keys.isEmpty()) {
                this.client.del(keys.toArray(new String[0]));
            }
            cursor = scanResult.getCursor();
        } while (!"0".equals(cursor));
    }

    /**
     * Creates a JSON index over the store's prefix using the schema's fields.
     *
     * @throws RedisRequestFailedException if Redis does not answer {@code OK}
     */
    private void createIndex(String indexName) {
        String res = this.client.ftCreate(
                indexName,
                FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.schema.prefix()),
                this.schema.toSchemaFields());
        if (!OK_REPLY.equals(res)) {
            log.error("create index error, msg={}", res);
            throw new RedisRequestFailedException("create index error, msg=" + res);
        }
    }

    /** Returns whether an index with the given name is listed by the server. */
    private boolean isIndexExist(String indexName) {
        return this.client.ftList().contains(indexName);
    }

    /** Single-entry convenience wrapper around {@link #addAll(List, List, List)}. */
    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        this.addAll(
                Collections.singletonList(id),
                Collections.singletonList(embedding),
                embedded == null ? null : Collections.singletonList(embedded));
    }

    /**
     * Writes the embeddings (and, when given, their text segments and metadata) in one pipeline.
     * Does nothing when {@code ids} or {@code embeddings} is null or empty.
     *
     * @param ids        ids of the entries, same size as {@code embeddings}
     * @param embeddings vectors to store
     * @param embedded   text segments, either null or the same size as {@code embeddings}
     * @throws RedisRequestFailedException if any pipelined write does not answer {@code OK}
     */
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("do not add empty embeddings to redis");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");

        List<Object> responses;
        try (Pipeline pipeline = this.client.pipelined()) {
            int size = ids.size();
            for (int i = 0; i < size; ++i) {
                Embedding embedding = embeddings.get(i);
                TextSegment textSegment = embedded == null ? null : embedded.get(i);
                Map<String, Object> fields = new HashMap<>();
                fields.put(this.schema.vectorFieldName(), embedding.vector());
                if (textSegment != null) {
                    fields.put(this.schema.scalarFieldName(), textSegment.text());
                    fields.putAll(textSegment.metadata().toMap());
                }
                String key = this.schema.prefix() + ids.get(i);
                pipeline.jsonSetWithEscape(key, RedisSchema.JSON_SET_PATH, fields);
            }
            responses = pipeline.syncAndReturnAll();
        }

        Optional<Object> errResponse = responses.stream().filter(response -> !OK_REPLY.equals(response)).findAny();
        if (errResponse.isPresent()) {
            log.error("add embedding failed, msg={}", errResponse.get());
            throw new RedisRequestFailedException("add embedding failed, msg=" + errResponse.get());
        }
    }

    /** Converts search hits into matches and drops those scoring below {@code minScore}. */
    private List<EmbeddingMatch<TextSegment>> toEmbeddingMatch(List<Document> documents, double minScore) {
        if (documents == null || documents.isEmpty()) {
            return new ArrayList<>();
        }
        return documents.stream()
                .map(this::toMatch)
                .filter(embeddingMatch -> embeddingMatch.score() >= minScore)
                .collect(Collectors.toList());
    }

    /** Rebuilds id, embedding and (when text is stored) text segment from one search hit. */
    private EmbeddingMatch<TextSegment> toMatch(Document document) {
        // Maps a distance d in [0, 2] onto a score in [0, 1].
        // NOTE(review): presumably the index uses cosine distance - confirm against RedisSchema.
        double score = (2.0 - Double.parseDouble(document.getString(SCORE_FIELD_NAME))) / 2.0;
        String id = document.getId().substring(this.schema.prefix().length());
        Map<String, Object> properties = RedisJsonUtils.toProperties(document.getString("$"));

        // Accept any numeric type: a JSON parser may yield integers for whole-number components.
        @SuppressWarnings("unchecked")
        List<Number> components = (List<Number>) properties.get(this.schema.vectorFieldName());
        List<Float> vector = components.stream().map(Number::floatValue).collect(Collectors.toList());
        Embedding embedding = Embedding.from(vector);

        String text = (String) properties.get(this.schema.scalarFieldName());
        TextSegment textSegment = null;
        if (text != null) {
            Map<String, Object> metadata = new HashMap<>();
            for (String metadataKey : this.schema.schemaFieldMap().keySet()) {
                Object value = properties.get(metadataKey);
                // Absent and null-valued entries are both skipped.
                if (value != null) {
                    metadata.put(metadataKey, value);
                }
            }
            textSegment = TextSegment.from(text, Metadata.from(metadata));
        }
        return new EmbeddingMatch<>(score, id, embedding, textSegment);
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Closes the underlying Redis client. */
    @Override
    public void close() {
        this.client.close();
    }

    /**
     * Builder for {@link RedisEmbeddingStore}. When a URI is set it takes precedence over
     * host/port/user/password.
     */
    public static class Builder {
        private String uri;
        private String host;
        private Integer port;
        private String user;
        private String password;
        private String indexName;
        private String prefix;
        private Integer dimension;
        private Map<String, SchemaField> metadataConfig = new HashMap<>();

        /** Sets the Redis connection URI; when non-null, host/port/user/password are ignored by {@link #build()}. */
        public Builder uri(String uri) {
            this.uri = uri;
            return this;
        }

        /** Sets the Redis host. */
        public Builder host(String host) {
            this.host = host;
            return this;
        }

        /** Sets the Redis port. */
        public Builder port(Integer port) {
            this.port = port;
            return this;
        }

        /** Sets the user name; when left null the client is created without credentials. */
        public Builder user(String user) {
            this.user = user;
            return this;
        }

        /** Sets the password, used only together with a user name. */
        public Builder password(String password) {
            this.password = password;
            return this;
        }

        /** Sets the index name; defaults to {@code "embedding-index"} when left null. */
        public Builder indexName(String indexName) {
            this.indexName = indexName;
            return this;
        }

        /** Sets the key prefix; defaults to {@code "embedding:"} when left null. */
        public Builder prefix(String prefix) {
            this.prefix = prefix;
            return this;
        }

        /** Sets the vector dimension; required when the index does not exist yet. */
        public Builder dimension(Integer dimension) {
            this.dimension = dimension;
            return this;
        }

        /**
         * @deprecated use {@link #metadataKeys(Collection)} instead
         */
        @Deprecated
        public Builder metadataFieldsName(Collection<String> metadataFieldsName) {
            return this.metadataKeys(metadataFieldsName);
        }

        /**
         * Registers each key as a text field at JSON path {@code $.<key>} with weight 1.0,
         * in addition to the metadata configuration already present.
         */
        public Builder metadataKeys(Collection<String> metadataKeys) {
            if (!Utils.isNullOrEmpty(metadataKeys)) {
                for (String metadataKey : metadataKeys) {
                    this.metadataConfig.put(metadataKey, TextField.of("$." + metadataKey).as(metadataKey).weight(1.0));
                }
            }
            return this;
        }

        /**
         * Replaces the metadata configuration. The map is copied so that later
         * {@link #metadataKeys(Collection)} calls neither modify the caller's map nor fail
         * on an immutable one; null is treated as an empty configuration.
         */
        public Builder metadataConfig(Map<String, SchemaField> metadataConfig) {
            this.metadataConfig = metadataConfig == null ? new HashMap<>() : new HashMap<>(metadataConfig);
            return this;
        }

        /** Builds the store, connecting by URI when one was set and by host/port otherwise. */
        public RedisEmbeddingStore build() {
            if (this.uri != null) {
                return new RedisEmbeddingStore(this.uri, this.indexName, this.prefix, this.dimension, this.metadataConfig);
            }
            return new RedisEmbeddingStore(this.host, this.port, this.user, this.password, this.indexName, this.prefix, this.dimension, this.metadataConfig);
        }
    }
}

